Skip to content

Commit 43de4af

Browse files
authored
Merge pull request #385 from ev-br/expand_dims_tuples
ENH: add axis tuple support to torch.expand_dims
2 parents 0da540c + 35b631f commit 43de4af

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -694,9 +694,24 @@ def triu(x: Array, /, *, k: int = 0) -> Array:
694694
return torch.triu(x, k)
695695

696696
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
697-
def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
698-
return torch.unsqueeze(x, axis)
697+
def expand_dims(x: Array, /, *, axis: int | tuple[int, ...]) -> Array:
698+
if isinstance(axis, int):
699+
return torch.unsqueeze(x, axis)
700+
else:
701+
# follow https://github.com/numpy/numpy/blob/maintenance/2.4.x/numpy/lib/_shape_base_impl.py#L596-L602
702+
y_ndim = x.ndim + len(axis)
703+
704+
# normalize
705+
n_axis = tuple(ax + y_ndim if ax < 0 else ax for ax in axis)
706+
if (len(n_axis) != len(set(n_axis)) or
707+
_builtin_any(ax < 0 or ax >= y_ndim for ax in n_axis)
708+
):
709+
raise ValueError(f"{axis=} not allowed for {x.shape = }")
710+
711+
shape_it = iter(x.shape)
712+
shape = [1 if ax in n_axis else next(shape_it) for ax in range(y_ndim)]
699713

714+
return torch.reshape(x, shape)
700715

701716
def astype(
702717
x: Array,

0 commit comments

Comments
 (0)