Source code for shift_nth_row_n_steps._torch_like

from collections.abc import Callable, Sequence
from types import EllipsisType
from typing import Any

from array_api._2024_12 import Array
from array_api_compat import array_namespace


[docs] def create_slice( ndim: int, axis_and_key: Sequence[tuple[int, int | slice | EllipsisType | None]], *, default: int | slice | EllipsisType | None | Callable[[], int | slice | EllipsisType | None] = lambda: slice(None), ) -> tuple[int | slice | EllipsisType | None, ...]: """ Create a slice tuple with default values. Parameters ---------- ndim : int The number of dimensions. axis_and_key : Sequence[tuple[int, int | slice | EllipsisType | None]] The axis and key pair. default : int | slice | EllipsisType | None, optional The default value, by default slice(None,) Returns ------- tuple[int | slice | EllipsisType | None, ...] The slice tuple. """ if isinstance(default, Callable): # type: ignore default_ = default() # type: ignore else: default_ = default result = [default_] * ndim for axis, key in axis_and_key: result[axis] = key return tuple(result)
[docs] def take_slice(a: Array, start: int, end: int, *, axis: int) -> Array: """ numpy.take() alternative using slices. (faster) similar to torch.narrow(). Parameters ---------- a : Array The source array. start : int The index of the element to start from. end : int The index of the element to end at. axis : int The axis to take the slice from. Returns ------- Array The sliced array. """ ndim = a.ndim axis = axis % ndim return a[create_slice(ndim, [(axis, slice(start, end))])]
[docs] def narrow(a: Array, start: int, length: int, *, axis: int) -> Array: """ torch.narrow() in xp. Parameters ---------- a : Array The source array. start : int The index of the element to start from. length : int The length of the slice. axis : int The axis to narrow. Returns ------- Array The narrowed array. """ return take_slice(a, start, start + length, axis=axis)
[docs] def select(a: Array, index: int, *, axis: int) -> Array: """ torch.select() (!= numpy.select()) in xp. Parameters ---------- a : Array The source array. index : int The index of the element to select. axis : int The axis to select from. Returns ------- Array The selected array. """ ndim = a.ndim axis = axis % ndim return a[create_slice(ndim, [(axis, index)])]
def advanced_indexing_nan(a: Array, index: Any) -> Array: """ Advanced indexing with NaN. Parameters ---------- a : Array The source array. index : int The index of the element to select. Returns ------- Array The selected array. """ xp = array_namespace(a) a = a[xp.nan_to_num(index, 0)] a[xp.isnan(index)] = xp.nan return a