Skip to content

Commit 9529145

Browse files
authored
ENH: delegate create_diagonal function (#501)
1 parent ddd59e5 commit 9529145

File tree

3 files changed

+66
-48
lines changed

3 files changed

+66
-48
lines changed

src/array_api_extra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
argpartition,
55
atleast_nd,
66
cov,
7+
create_diagonal,
78
expand_dims,
89
isclose,
910
isin,
@@ -18,7 +19,6 @@
1819
from ._lib._funcs import (
1920
apply_where,
2021
broadcast_shapes,
21-
create_diagonal,
2222
default_dtype,
2323
kron,
2424
nunique,

src/array_api_extra/_delegation.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
__all__ = [
2222
"atleast_nd",
2323
"cov",
24+
"create_diagonal",
2425
"expand_dims",
2526
"isclose",
2627
"nan_to_num",
@@ -174,6 +175,67 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
174175
return _funcs.cov(m, xp=xp)
175176

176177

178+
def create_diagonal(
179+
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
180+
) -> Array:
181+
"""
182+
Construct a diagonal array.
183+
184+
Parameters
185+
----------
186+
x : array
187+
An array having shape ``(*batch_dims, k)``.
188+
offset : int, optional
189+
Offset from the leading diagonal (default is ``0``).
190+
Use positive ints for diagonals above the leading diagonal,
191+
and negative ints for diagonals below the leading diagonal.
192+
xp : array_namespace, optional
193+
The standard-compatible namespace for `x`. Default: infer.
194+
195+
Returns
196+
-------
197+
array
198+
An array having shape ``(*batch_dims, k+abs(offset), k+abs(offset))`` with `x`
199+
on the diagonal (offset by `offset`).
200+
201+
Examples
202+
--------
203+
>>> import array_api_strict as xp
204+
>>> import array_api_extra as xpx
205+
>>> x = xp.asarray([2, 4, 8])
206+
207+
>>> xpx.create_diagonal(x, xp=xp)
208+
Array([[2, 0, 0],
209+
[0, 4, 0],
210+
[0, 0, 8]], dtype=array_api_strict.int64)
211+
212+
>>> xpx.create_diagonal(x, offset=-2, xp=xp)
213+
Array([[0, 0, 0, 0, 0],
214+
[0, 0, 0, 0, 0],
215+
[2, 0, 0, 0, 0],
216+
[0, 4, 0, 0, 0],
217+
[0, 0, 8, 0, 0]], dtype=array_api_strict.int64)
218+
"""
219+
if xp is None:
220+
xp = array_namespace(x)
221+
222+
if x.ndim == 0:
223+
err_msg = "`x` must be at least 1-dimensional."
224+
raise ValueError(err_msg)
225+
226+
if is_torch_namespace(xp):
227+
return xp.diag_embed(x, offset=offset, dim1=-2, dim2=-1)
228+
229+
if (is_dask_namespace(xp) or is_cupy_namespace(xp)) and x.ndim < 2:
230+
return xp.diag(x, k=offset)
231+
232+
if (is_jax_namespace(xp) or is_numpy_namespace(xp)) and x.ndim < 3:
233+
batch_dim, n = eager_shape(x)[:-1], eager_shape(x, -1)[0] + abs(offset)
234+
return xp.reshape(xp.diag(x, k=offset), (*batch_dim, n, n))
235+
236+
return _funcs.create_diagonal(x, offset=offset, xp=xp)
237+
238+
177239
def expand_dims(
178240
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
179241
) -> Array:

src/array_api_extra/_lib/_funcs.py

Lines changed: 3 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -295,53 +295,9 @@ def one_hot(
295295

296296

297297
def create_diagonal(
298-
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
299-
) -> Array:
300-
"""
301-
Construct a diagonal array.
302-
303-
Parameters
304-
----------
305-
x : array
306-
An array having shape ``(*batch_dims, k)``.
307-
offset : int, optional
308-
Offset from the leading diagonal (default is ``0``).
309-
Use positive ints for diagonals above the leading diagonal,
310-
and negative ints for diagonals below the leading diagonal.
311-
xp : array_namespace, optional
312-
The standard-compatible namespace for `x`. Default: infer.
313-
314-
Returns
315-
-------
316-
array
317-
An array having shape ``(*batch_dims, k+abs(offset), k+abs(offset))`` with `x`
318-
on the diagonal (offset by `offset`).
319-
320-
Examples
321-
--------
322-
>>> import array_api_strict as xp
323-
>>> import array_api_extra as xpx
324-
>>> x = xp.asarray([2, 4, 8])
325-
326-
>>> xpx.create_diagonal(x, xp=xp)
327-
Array([[2, 0, 0],
328-
[0, 4, 0],
329-
[0, 0, 8]], dtype=array_api_strict.int64)
330-
331-
>>> xpx.create_diagonal(x, offset=-2, xp=xp)
332-
Array([[0, 0, 0, 0, 0],
333-
[0, 0, 0, 0, 0],
334-
[2, 0, 0, 0, 0],
335-
[0, 4, 0, 0, 0],
336-
[0, 0, 8, 0, 0]], dtype=array_api_strict.int64)
337-
"""
338-
if xp is None:
339-
xp = array_namespace(x)
340-
341-
if x.ndim == 0:
342-
err_msg = "`x` must be at least 1-dimensional."
343-
raise ValueError(err_msg)
344-
298+
x: Array, /, *, offset: int = 0, xp: ModuleType
299+
) -> Array: # numpydoc ignore=PR01,RT01
300+
"""See docstring in array_api_extra._delegation."""
345301
x_shape = eager_shape(x)
346302
batch_dims = x_shape[:-1]
347303
n = x_shape[-1] + abs(offset)

0 commit comments

Comments
 (0)