Skip to content

Commit 1cf93c3

Browse files
CopilotricardoV94
authored andcommitted
Implement axis=None raveling behavior symbolically in CumOp
1 parent 79444a3 commit 1cf93c3

File tree

5 files changed

+70
-142
lines changed

5 files changed

+70
-142
lines changed

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,15 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
3737
mode = op.mode
3838
ndim = cast(TensorVariable, node.outputs[0]).ndim
3939

40-
if axis is not None:
41-
if axis < 0:
42-
axis = ndim + axis
43-
if axis < 0 or axis >= ndim:
44-
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
45-
46-
reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
47-
reaxis_first_inv = tuple(np.argsort(reaxis_first))
40+
reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
41+
reaxis_first_inv = tuple(np.argsort(reaxis_first))
4842

4943
if mode == "add":
50-
if axis is None or ndim == 1:
44+
if ndim == 1:
5145

5246
@numba_basic.numba_njit
5347
def cumop(x):
54-
return np.cumsum(x)
48+
return np.cumsum(x, axis=axis)
5549

5650
else:
5751

@@ -71,11 +65,11 @@ def cumop(x):
7165
return res.transpose(reaxis_first_inv)
7266

7367
else:
74-
if axis is None or ndim == 1:
68+
if ndim == 1:
7569

7670
@numba_basic.numba_njit
7771
def cumop(x):
78-
return np.cumprod(x)
72+
return np.cumprod(x, axis=axis)
7973

8074
else:
8175

@@ -92,7 +86,7 @@ def cumop(x):
9286
for m in range(1, x.shape[axis]):
9387
res[m] = res[m - 1] * x_axis_first[m]
9488

95-
return res.transpose(reaxis_first)
89+
return res.transpose(reaxis_first_inv)
9690

9791
return cumop
9892

pytensor/link/pytorch/dispatch/extra_ops.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,10 @@ def pytorch_funcify_Cumop(op, **kwargs):
1010
mode = op.mode
1111

1212
def cumop(x):
13-
if axis is None:
14-
x = x.reshape(-1)
15-
dim = 0
16-
else:
17-
dim = axis
1813
if mode == "add":
19-
return torch.cumsum(x, dim=dim)
14+
return torch.cumsum(x, dim=axis)
2015
else:
21-
return torch.cumprod(x, dim=dim)
16+
return torch.cumprod(x, dim=axis)
2217

2318
return cumop
2419

pytensor/tensor/extra_ops.py

Lines changed: 51 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import warnings
22
from collections.abc import Collection, Iterable
3+
from textwrap import dedent
34

45
import numpy as np
56

@@ -20,7 +21,6 @@
2021
from pytensor.npy_2_compat import (
2122
normalize_axis_index,
2223
npy_2_compat_header,
23-
numpy_axis_is_none_flag,
2424
old_np_unique,
2525
)
2626
from pytensor.raise_op import Assert
@@ -48,7 +48,7 @@
4848
from pytensor.tensor.math import sum as pt_sum
4949
from pytensor.tensor.shape import Shape_i
5050
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
51-
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
51+
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes
5252
from pytensor.tensor.utils import normalize_reduce_axis
5353
from pytensor.tensor.variable import TensorVariable
5454
from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
@@ -294,30 +294,24 @@ class CumOp(COp):
294294
__props__ = ("axis", "mode")
295295
check_input = False
296296
params_type = ParamsType(
297-
c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
297+
axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
298298
)
299299

300-
def __init__(self, axis: int | None = None, mode="add"):
300+
def __init__(self, axis: int, mode="add"):
301301
if mode not in ("add", "mul"):
302302
raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
303-
if not (isinstance(axis, int) or axis is None):
304-
raise TypeError("axis must be an integer or None.")
303+
if not isinstance(axis, int):
304+
raise TypeError("axis must be an integer.")
305+
if axis < 0:
306+
raise ValueError("axis must be non-negative.")
305307
self.axis = axis
306308
self.mode = mode
307309

308-
@property
309-
def c_axis(self) -> int:
310-
if self.axis is None:
311-
return numpy_axis_is_none_flag
312-
return self.axis
313-
314310
def make_node(self, x):
315311
x = ptb.as_tensor_variable(x)
316312
out_type = x.type()
317313

318-
if self.axis is None:
319-
out_type = vector(dtype=x.dtype) # Flatten
320-
elif self.axis >= x.ndim or self.axis < -x.ndim:
314+
if self.axis >= x.ndim:
321315
raise ValueError(f"axis(={self.axis}) out of bounds")
322316

323317
return Apply(self, [x], [out_type])
@@ -330,21 +324,10 @@ def perform(self, node, inputs, output_storage):
330324
else:
331325
z[0] = np.cumprod(x, axis=self.axis)
332326

333-
def grad(self, inputs, output_gradients):
327+
def L_op(self, inputs, outputs, output_gradients):
334328
(x,) = inputs
335329
(gi,) = output_gradients
336330

337-
if self.axis is None:
338-
if self.mode == "add":
339-
return [cumsum(gi[::-1])[::-1].reshape(x.shape)]
340-
elif self.mode == "mul":
341-
fx = cumprod(x, axis=self.axis)
342-
return [cumsum((fx * gi)[::-1])[::-1].reshape(x.shape) / x]
343-
else:
344-
raise NotImplementedError(
345-
f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
346-
)
347-
348331
reverse_slicing = [slice(None, None, None)] * gi.ndim
349332
reverse_slicing[self.axis] = slice(None, None, -1)
350333
reverse_slicing = tuple(reverse_slicing)
@@ -361,9 +344,6 @@ def grad(self, inputs, output_gradients):
361344
)
362345

363346
def infer_shape(self, fgraph, node, shapes):
364-
if self.axis is None and len(shapes[0]) > 1:
365-
return [(prod(shapes[0]),)] # Flatten
366-
367347
return shapes
368348

369349
def c_support_code_apply(self, node: Apply, name: str) -> str:
@@ -376,61 +356,43 @@ def c_code(self, node, name, inames, onames, sub):
376356
fail = sub["fail"]
377357
params = sub["params"]
378358

379-
if self.axis is None:
380-
axis_code = "int axis = NPY_RAVEL_AXIS;\n"
381-
else:
382-
axis_code = f"int axis = {params}->c_axis;\n"
383-
384-
code = (
385-
axis_code
386-
+ f"""
387-
#undef NPY_UF_DBG_TRACING
388-
#define NPY_UF_DBG_TRACING 1
389-
390-
if (axis == 0 && PyArray_NDIM({x}) == 1)
391-
axis = NPY_RAVEL_AXIS;
392-
npy_intp shape[1] = {{ PyArray_SIZE({x}) }};
393-
if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
394-
{{
395-
Py_XDECREF({z});
396-
{z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x}));
397-
}}
359+
return dedent(
360+
f"""
361+
int axis = {params}->axis;
398362
399-
else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
400-
{{
401-
Py_XDECREF({z});
402-
{z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
403-
}}
363+
if (!({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
364+
{{
365+
Py_XDECREF({z});
366+
{z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
367+
if (!{z}){{ {fail} }};
368+
}}
369+
370+
{{
371+
372+
PyObject * t = NULL;
373+
if({params}->mode == MODE_ADD)
374+
t = PyArray_CumSum({x}, axis, PyArray_TYPE({x}), {z});
375+
else if({params}->mode == MODE_MUL)
376+
t = PyArray_CumProd({x}, axis, PyArray_TYPE({x}), {z});
404377
405-
if (!{z})
378+
if (!t){{
406379
{fail};
407-
{{
408-
409-
PyObject * t = NULL;
410-
if({params}->mode == MODE_ADD)
411-
t = PyArray_CumSum(
412-
{x}, axis,
413-
PyArray_TYPE({x}), {z});
414-
else if({params}->mode == MODE_MUL)
415-
t = PyArray_CumProd(
416-
{x}, axis,
417-
PyArray_TYPE({x}), {z});
418-
419-
if (!t){{
420-
{fail};
421-
}}
422-
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
423-
Py_XDECREF(t);
424380
}}
381+
382+
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
383+
Py_XDECREF(t);
384+
}}
425385
"""
426386
)
427387

428-
return code
429-
430388
def c_code_cache_version(self):
431-
return (9,)
389+
return (10,)
432390

433391
def __str__(self):
392+
if self.mode == "add":
393+
return f"Cumsum{{axis={self.axis}}}"
394+
elif self.mode == "mul":
395+
return f"Cumprod{{axis={self.axis}}}"
434396
return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}"
435397

436398

@@ -451,6 +413,12 @@ def cumsum(x, axis=None):
451413
.. versionadded:: 0.7
452414
453415
"""
416+
x = ptb.as_tensor_variable(x)
417+
if axis is None:
418+
x = x.ravel()
419+
axis = 0
420+
else:
421+
axis = normalize_axis_index(axis, x.ndim)
454422
return CumOp(axis=axis, mode="add")(x)
455423

456424

@@ -471,6 +439,12 @@ def cumprod(x, axis=None):
471439
.. versionadded:: 0.7
472440
473441
"""
442+
x = ptb.as_tensor_variable(x)
443+
if axis is None:
444+
x = x.ravel()
445+
axis = 0
446+
else:
447+
axis = normalize_axis_index(axis, x.ndim)
474448
return CumOp(axis=axis, mode="mul")(x)
475449

476450

@@ -479,18 +453,8 @@ def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
479453
"""Vectorize the CumOp to work on a batch of inputs."""
480454
[original_x] = node.inputs
481455
batch_ndim = batch_x.ndim - original_x.ndim
482-
axis = op.axis
483-
if axis is None and original_x.ndim == 1:
484-
axis = 0
485-
elif axis is not None:
486-
axis = normalize_axis_index(op.axis, original_x.ndim)
487-
488-
if axis is None:
489-
# Ravel all unbatched dimensions and perform CumOp on the last axis
490-
batch_x_raveled = [batch_x.flatten(ndim=batch_ndim + 1) for x in batch_x]
491-
return type(op)(axis=-1, mode=op.mode).make_node(batch_x_raveled)
492-
else:
493-
return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
456+
# op.axis is already normalized and non-negative
457+
return type(op)(axis=op.axis + batch_ndim, mode=op.mode).make_node(batch_x)
494458

495459

496460
def diff(x, n=1, axis=-1):

tests/link/pytorch/test_extra_ops.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,39 +5,13 @@
55
from tests.link.pytorch.test_basic import compare_pytorch_and_py
66

77

8-
@pytest.mark.parametrize(
9-
"dtype",
10-
["float64", "int64"],
11-
)
12-
@pytest.mark.parametrize(
13-
"axis",
14-
[None, 1, (0,)],
15-
)
8+
@pytest.mark.parametrize("dtype", ["float64", "int64"])
9+
@pytest.mark.parametrize("axis", [None, -1])
1610
def test_pytorch_CumOp(axis, dtype):
17-
"""Test PyTorch conversion of the `CumOp` `Op`."""
18-
19-
# Create a symbolic input for the first input of `CumOp`
2011
a = pt.matrix("a", dtype=dtype)
21-
22-
# Create test value
2312
test_value = np.arange(9, dtype=dtype).reshape((3, 3))
24-
25-
# Create the output variable
26-
if isinstance(axis, tuple):
27-
with pytest.raises(TypeError, match="axis must be an integer or None."):
28-
out = pt.cumsum(a, axis=axis)
29-
with pytest.raises(TypeError, match="axis must be an integer or None."):
30-
out = pt.cumprod(a, axis=axis)
31-
else:
32-
out = pt.cumsum(a, axis=axis)
33-
34-
# Pass the inputs and outputs to the testing function
35-
compare_pytorch_and_py([a], [out], [test_value])
36-
37-
# For the second mode of CumOp
38-
out = pt.cumprod(a, axis=axis)
39-
40-
compare_pytorch_and_py([a], [out], [test_value])
13+
outs = [pt.cumsum(a, axis=axis), pt.cumprod(a, axis=axis)]
14+
compare_pytorch_and_py([a], outs, [test_value])
4115

4216

4317
@pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)])

tests/tensor/test_extra_ops.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ class TestCumOp(utt.InferShapeTester):
194194
def setup_method(self):
195195
super().setup_method()
196196
self.op_class = CumOp
197-
self.op = CumOp()
197+
self.op = CumOp(axis=0)
198198

199199
def test_cum_op(self):
200200
x = tensor3("x")
@@ -225,17 +225,18 @@ def test_infer_shape(self):
225225
x = tensor3("x")
226226
a = np.random.random((3, 5, 2)).astype(config.floatX)
227227

228-
# Test axis=None
229-
self._compile_and_check([x], [self.op(x)], [a], self.op_class)
228+
# Test default axis=None
229+
self._compile_and_check([x], [cumsum(x)], [a], self.op_class)
230230

231231
for axis in range(-len(a.shape), len(a.shape)):
232232
self._compile_and_check([x], [cumsum(x, axis=axis)], [a], self.op_class)
233233

234234
def test_grad(self):
235235
a = np.random.random((3, 5, 2)).astype(config.floatX)
236236

237-
utt.verify_grad(self.op_class(mode="add"), [a]) # Test axis=None
238-
utt.verify_grad(self.op_class(mode="mul"), [a]) # Test axis=None
237+
# Test default axis=None using cumsum/cumprod functions
238+
utt.verify_grad(lambda x: cumsum(x), [a]) # Test axis=None for cumsum
239+
utt.verify_grad(lambda x: cumprod(x), [a]) # Test axis=None for cumprod
239240

240241
for axis in range(-len(a.shape), len(a.shape)):
241242
utt.verify_grad(self.op_class(axis=axis, mode="add"), [a], eps=4e-4)

0 commit comments

Comments
 (0)