import warnings
from typing import Any, Literal
import array_api_extra as xpx
from array_api._2024_12 import Array
from array_api_compat import array_namespace
from typing_extensions import deprecated
from ._torch_like import create_slice, select, take_slice
@deprecated(
    "This function is too slow thus no longer supported. "
    "For debugging purposes, please use "
    "shift_nth_row_n_steps_advanced_indexing instead.",
    category=DeprecationWarning,
)
def shift_nth_row_n_steps_for_loop_assign(
    a: Array,
    *,
    axis_row: int = -2,
    axis_shift: int = -1,
    cut_padding: bool = False,
    mode: Literal["fill"] = "fill",
    fill_values: Literal[0] = 0,
) -> Array:
    """
    Shift the nth row n steps to the right (loop + slice-assignment version).

    A zero-initialised output is allocated and every row ``i`` of ``a`` is
    written into it starting at offset ``i`` along ``axis_shift``, i.e.
    ``result[..., i, ..., j + i, ...] = a[..., i, ..., j, ...]``.
    Positions that receive no value stay zero.

    Parameters
    ----------
    a : Array
        The source array.
    axis_row : int, optional
        The axis that enumerates the rows, by default -2
    axis_shift : int, optional
        The axis along which each row is shifted, by default -1
    cut_padding : bool, optional
        If True the output has the same shape as ``a`` and whatever is
        shifted past the end of ``axis_shift`` is dropped, by default False
    mode : Literal["fill"], optional
        Only ``"fill"`` is declared. The argument is not read by this
        implementation; it exists to mirror the sibling functions.
    fill_values : Literal[0], optional
        Only ``0`` is declared. The argument is not read by this
        implementation; the padding is always zero.

    Returns
    -------
    Array
        The shifted array. If the input is (..., row, ..., shift, ...),
        the output is (..., row, ..., shift + row - 1, ...) when
        ``cut_padding`` is False, and has the input shape otherwise.

    Raises
    ------
    ValueError
        If ``axis_row`` and ``axis_shift`` refer to the same axis.

    Notes
    -----
    The result is filled through item assignment (``output[...] = ...``).
    NOTE(review): this presumably restricts the function to backends with
    mutable arrays -- confirm before using it with immutable-array backends.
    """
    xp = array_namespace(a)
    input_shape = list(a.shape)
    ndim = len(input_shape)
    # Normalise so negative axes can be compared and used as list indices.
    axis_row = axis_row % ndim
    axis_shift = axis_shift % ndim
    if axis_row == axis_shift:
        raise ValueError("axis_row and axis_shift should not be the same.")
    row_len = input_shape[axis_row]
    shift_len = input_shape[axis_shift]

    if cut_padding:
        output = xp.zeros_like(a)
        # Rows with index >= shift_len are shifted completely out of the
        # kept window, so they stay zero and need no assignment.
        for i in range(min(row_len, shift_len)):
            row = take_slice(a, i, i + 1, axis=axis_row)
            target = create_slice(
                ndim, [(axis_row, slice(i, i + 1)), (axis_shift, slice(i, None))]
            )
            # Only the first shift_len - i entries of the row still fit.
            output[target] = take_slice(row, 0, shift_len - i, axis=axis_shift)
        return output

    output_shape = list(input_shape)
    # The last row starts at offset row_len - 1 and is shift_len long.
    output_shape[axis_shift] = row_len + shift_len - 1
    output = xp.zeros(output_shape, dtype=a.dtype, device=a.device)
    for i in range(row_len):
        target = create_slice(
            ndim,
            [
                (axis_row, slice(i, i + 1)),
                (axis_shift, slice(i, i + shift_len)),
            ],
        )
        output[target] = take_slice(a, i, i + 1, axis=axis_row)
    return output
@deprecated(
    "This function is too slow thus no longer supported. "
    "For debugging purposes, please use "
    "shift_nth_row_n_steps_advanced_indexing instead.",
    category=DeprecationWarning,
)
def shift_nth_row_n_steps_for_loop_concat(
    a: Array,
    *,
    axis_row: int = -2,
    axis_shift: int = -1,
    cut_padding: bool = False,
    mode: Literal["fill"] = "fill",
    fill_values: Literal[0] = 0,
) -> Array:
    """
    Shift the nth row n steps to the right (loop + concatenation version).

    Every row ``i`` of ``a`` is padded with ``i`` zeros on the left (and,
    unless ``cut_padding`` is set, with zeros on the right up to a common
    length); the padded rows are then joined along ``axis_row``, so that
    ``result[..., i, ..., j + i, ...] = a[..., i, ..., j, ...]``.

    Parameters
    ----------
    a : Array
        The source array.
    axis_row : int, optional
        The axis that enumerates the rows, by default -2
    axis_shift : int, optional
        The axis along which each row is shifted, by default -1
    cut_padding : bool, optional
        If True the output has the same shape as ``a`` and whatever is
        shifted past the end of ``axis_shift`` is dropped, by default False
    mode : Literal["fill"], optional
        Only ``"fill"`` is declared. The argument is not read by this
        implementation; it exists to mirror the sibling functions.
    fill_values : Literal[0], optional
        Only ``0`` is declared. The argument is not read by this
        implementation; the padding is always zero.

    Returns
    -------
    Array
        The shifted array. If the input is (..., row, ..., shift, ...),
        the output is (..., row, ..., shift + row - 1, ...) when
        ``cut_padding`` is False, and has the input shape otherwise.

    Raises
    ------
    ValueError
        If ``axis_row`` and ``axis_shift`` refer to the same axis.
    """
    xp = array_namespace(a)
    input_shape = list(a.shape)
    ndim = len(input_shape)
    # Normalise so negative axes can be compared and used as list indices.
    axis_row = axis_row % ndim
    axis_shift = axis_shift % ndim
    if axis_row == axis_shift:
        raise ValueError("axis_row and axis_shift should not be the same.")
    row_len = input_shape[axis_row]
    shift_len = input_shape[axis_shift]

    def _zeros_like_row(row: Array, length: int) -> Array:
        # Zero block shaped like `row`, but `length` long along axis_shift.
        zero_shape = list(row.shape)
        zero_shape[axis_shift] = length
        return xp.zeros(zero_shape, dtype=a.dtype, device=a.device)

    shifted_rows = []
    for i in range(row_len):
        row = take_slice(a, i, i + 1, axis=axis_row)
        if cut_padding:
            # Never pad by more than the row length; the kept part of the
            # row shrinks accordingly so the total stays shift_len.
            n_left = min(i, shift_len)
            parts = [
                _zeros_like_row(row, n_left),
                take_slice(row, 0, shift_len - n_left, axis=axis_shift),
            ]
        else:
            parts = [
                _zeros_like_row(row, i),
                row,
                _zeros_like_row(row, row_len - 1 - i),
            ]
        shifted_rows.append(xp.concat(parts, axis=axis_shift))
    # Each piece still carries its length-1 row axis, so joining along
    # axis_row rebuilds the row dimension without a squeeze/stack round trip.
    return xp.concat(shifted_rows, axis=axis_row)
def shift_nth_row_n_steps_advanced_indexing(
    a: Array,
    *,
    axis_row: int = -2,
    axis_shift: int = -1,
    cut_padding: bool = False,
    mode: Literal["fill", "roll", "abs"] = "fill",
    fill_values: float = 0,
) -> Array:
    """
    Shift the nth row n steps to the right (integer-array indexing version).

    The row and shift axes are moved to the last two positions, a 2-D table
    of source indices ``j - i`` is built, the values are gathered with it,
    and every position whose source index falls outside the row is
    overwritten with ``fill_values``. In effect
    ``result[..., i, ..., j + i, ...] = a[..., i, ..., j, ...]``.

    Parameters
    ----------
    a : Array
        The source array.
    axis_row : int, optional
        The axis that enumerates the rows, by default -2
    axis_shift : int, optional
        The axis along which each row is shifted, by default -1
    cut_padding : bool, optional
        If True the output keeps the input length along ``axis_shift`` and
        whatever is shifted past the end is dropped, by default False
    mode : Literal["fill", "roll", "abs"], optional
        Accepted for signature parity with ``shift_nth_row_n_steps``, but
        not read by this implementation: the result is always the
        ``"fill"`` behaviour, whatever value is passed.
    fill_values : float, optional
        The constant written to the padded positions, by default 0

    Returns
    -------
    Array
        The shifted array. If the input is (..., row, ..., shift, ...),
        the output is (..., row, ..., shift + row - 1, ...) when
        ``cut_padding`` is False, and has the input shape otherwise.
    """
    xp = array_namespace(a)
    # Work on a view whose last two axes are (row, shift).
    axis_row_ = -2
    axis_shift_ = -1
    a = xp.moveaxis(a, (axis_row, axis_shift), (axis_row_, axis_shift_))
    shape = a.shape
    n_row = shape[axis_row_]
    n_shift = shape[axis_shift_]
    n_out = n_shift if cut_padding else n_shift + n_row - 1

    i_row = xp.arange(n_row)[:, None]
    # Source column for output column j of row i is j - i.
    i_shift = xp.arange(n_out)[None, :] - i_row
    # Collapse every out-of-range source index onto the sentinel -1. As an
    # index, -1 is still valid (last element), so the gather below succeeds
    # and the sentinel positions are overwritten afterwards.
    i_shift = xp.clip(i_shift, -1, n_shift)
    if not cut_padding:
        # Only without cut_padding can j - i reach n_shift (right padding).
        i_shift = xp.where(i_shift == n_shift, -1, i_shift)
    is_padding = i_shift == -1

    gathered = a[
        create_slice(
            len(shape),
            [(axis_row_, i_row), (axis_shift_, i_shift)],
            default=slice(None),
        )
    ]
    # The 2-D boolean mask spans the last two axes, hence one fewer index
    # entry than there are dimensions.
    gathered = xpx.at(
        gathered,
        create_slice(len(shape) - 1, [(-1, is_padding)], default=slice(None)),
    ).set(fill_values)
    return xp.moveaxis(gathered, (axis_row_, axis_shift_), (axis_row, axis_shift))
# [docs]  (Sphinx "view source" link marker left over from HTML extraction)
def shift_nth_row_n_steps(
    a: Array,
    *,
    axis_row: int = -2,
    axis_shift: int = -1,
    cut_padding: bool = False,
    mode: Literal["fill", "roll", "abs"] = "fill",
    fill_values: float = 0,
) -> Array:
    """
    Shift the nth row n steps to the right (pad + reshape version).

    With ``r`` rows of length ``s``: every row is padded on the right to
    length ``s + r``, the last two axes are flattened, the trailing ``r``
    elements are dropped and the buffer is reshaped to rows of length
    ``s + r - 1``. Because each new row is one element shorter than a
    padded row, row ``i`` starts ``i`` elements later, which gives
    ``result[..., i, ..., j + i, ...] = a[..., i, ..., j, ...]``.

    Parameters
    ----------
    a : Array
        The source array.
    axis_row : int, optional
        The axis that enumerates the rows, by default -2
    axis_shift : int, optional
        The axis along which each row is shifted, by default -1
    cut_padding : bool, optional
        If True only the first ``s`` entries along ``axis_shift`` are kept,
        by default False
    mode : Literal["fill", "roll", "abs"], optional
        The padding mode handed to the backend pad function, by default
        "fill"

        - "fill" -> ``constant`` padding with ``fill_values``
        - "roll" -> ``wrap`` padding (``circular`` for torch)
        - "abs"  -> ``reflect`` padding

        The filler in front of row ``i`` is read from the padded tail of
        row ``i - 1``. NOTE(review): "roll" and "abs" therefore look
        meaningful only when all rows hold the same values -- confirm
        against the callers.
    fill_values : float, optional
        The constant value to fill, by default 0.
        Only used when ``mode="fill"``.

    Returns
    -------
    Array
        The shifted array. If the input is (..., row, ..., shift, ...),
        the output is (..., row, ..., shift + row - 1, ...) when
        ``cut_padding`` is False, and has the input shape otherwise.

    Raises
    ------
    KeyError
        If ``mode`` is not one of "fill", "roll" or "abs".

    Warns
    -----
    UserWarning
        If ``cut_padding`` is True and the shift axis is shorter than the
        row axis.
    """
    # Local import: keeps this block self-contained with respect to the
    # module-level import list.
    from array_api_compat import is_torch_namespace

    xp = array_namespace(a)
    # Work on a view whose last two axes are (row, shift).
    axis_row_ = -2
    axis_shift_ = -1
    a = xp.moveaxis(a, (axis_row, axis_shift), (axis_row_, axis_shift_))
    shape = a.shape
    l_row = shape[axis_row_]
    l_shift = shape[axis_shift_]
    if cut_padding and l_shift < l_row:
        warnings.warn(
            "cut_padding is True, but s < r, which results in redundant computation.",
            stacklevel=2,
        )

    # Translate the public mode names to the backend padding modes.
    mode_ = {
        "fill": "constant",
        "roll": "wrap",
        "abs": "reflect",
    }[mode]

    # Pad the shift axis on the right: [r, s] -> [r, s + r].
    # Detect torch through the namespace helper rather than a substring of
    # str(xp): a module repr embeds its file path, which may contain
    # "torch" for an unrelated backend.
    kwargs: dict[str, Any]
    if is_torch_namespace(xp):
        if mode_ == "wrap":
            mode_ = "circular"  # torch's name for wrap-around padding
        kwargs = {"mode": mode_}
        if mode_ == "constant":
            kwargs["value"] = fill_values
        # torch lists pad widths starting from the last axis.
        output = xp.nn.functional.pad(a, (0, l_row), **kwargs)
    else:
        kwargs = {"mode": mode_}
        if mode_ == "constant":
            kwargs["constant_values"] = fill_values
        output = xp.pad(
            a,
            [(0, 0)] * (len(shape) - 1) + [(0, l_row)],
            **kwargs,
        )

    # Merge the last two axes: [..., r, s + r] -> [..., r * (s + r)].
    # The length is spelled out so the reshape stays well defined even when
    # a leading dimension is zero.
    batch_shape = tuple(shape[:-2])
    output = xp.reshape(output, (*batch_shape, l_row * (l_shift + l_row)))
    # Drop the trailing r elements: [r * (s + r)] -> [r * (s + r - 1)].
    output = take_slice(output, 0, (l_shift + l_row - 1) * l_row, axis=axis_shift_)
    # Re-split into rows that are one element shorter than the padded rows.
    output = xp.reshape(output, (*batch_shape, l_row, l_shift + l_row - 1))
    if cut_padding:
        output = take_slice(output, 0, l_shift, axis=axis_shift_)
    # Put the two working axes back where the caller had them.
    return xp.moveaxis(output, (axis_row_, axis_shift_), (axis_row, axis_shift))