We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 52e01be commit 0fed389Copy full SHA for 0fed389
array_api_compat/torch/_aliases.py
@@ -548,8 +548,12 @@ def count_nonzero(
548
) -> Array:
549
result = torch.count_nonzero(x, dim=axis)
550
if keepdims:
551
- if axis is not None:
+ if isinstance(axis, int):
552
return result.unsqueeze(axis)
553
+ elif isinstance(axis, tuple):
554
+ n_axis = [x.ndim + ax if ax < 0 else ax for ax in axis]
555
+ sh = [1 if i in n_axis else x.shape[i] for i in range(x.ndim)]
556
+ return torch.reshape(result, sh)
557
return _axis_none_keepdims(result, x.ndim, keepdims)
558
else:
559
return result
0 commit comments