1
1
import warnings
2
2
from collections .abc import Collection , Iterable
3
+ from textwrap import dedent
3
4
4
5
import numpy as np
5
6
20
21
from pytensor .npy_2_compat import (
21
22
normalize_axis_index ,
22
23
npy_2_compat_header ,
23
- numpy_axis_is_none_flag ,
24
24
old_np_unique ,
25
25
)
26
26
from pytensor .raise_op import Assert
48
48
from pytensor .tensor .math import sum as pt_sum
49
49
from pytensor .tensor .shape import Shape_i
50
50
from pytensor .tensor .subtensor import advanced_inc_subtensor1 , set_subtensor
51
- from pytensor .tensor .type import TensorType , dvector , int_dtypes , integer_dtypes , vector
51
+ from pytensor .tensor .type import TensorType , dvector , int_dtypes , integer_dtypes
52
52
from pytensor .tensor .utils import normalize_reduce_axis
53
53
from pytensor .tensor .variable import TensorVariable
54
54
from pytensor .utils import LOCAL_BITWIDTH , PYTHON_INT_BITWIDTH
@@ -294,30 +294,24 @@ class CumOp(COp):
294
294
# The two configuration attributes of the Op; both are set in ``__init__``.
__props__ = ("axis", "mode")
check_input = False
# Fields exposed to the generated C code, which reads them as
# ``params->axis`` and ``params->mode`` (the latter compared against the
# ``MODE_ADD`` / ``MODE_MUL`` enum constants declared here).
params_type = ParamsType(
    axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
)
299
299
300
def __init__(self, axis: int, mode="add"):
    """Configure a cumulative sum or product along a single axis.

    Parameters
    ----------
    axis
        Non-negative axis to accumulate over. Python integers and NumPy
        integer scalars are accepted; the value is stored as a plain
        ``int``.
    mode
        ``"add"`` for a cumulative sum, ``"mul"`` for a cumulative product.

    Raises
    ------
    ValueError
        If ``mode`` is not ``"add"`` or ``"mul"``, or if ``axis`` is negative.
    TypeError
        If ``axis`` is not an integer.
    """
    if mode not in ("add", "mul"):
        raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
    # NumPy integer scalars (e.g. values derived from shapes or argmax) are
    # valid axes too; normalise them so ``self.axis`` is always a plain int.
    if not isinstance(axis, (int, np.integer)):
        raise TypeError("axis must be an integer.")
    axis = int(axis)
    if axis < 0:
        raise ValueError("axis must be non-negative.")
    self.axis = axis
    self.mode = mode
307
309
308
- @property
309
- def c_axis (self ) -> int :
310
- if self .axis is None :
311
- return numpy_axis_is_none_flag
312
- return self .axis
313
-
314
310
def make_node(self, x):
    """Build the Apply node for input ``x``.

    The single output has the same type as ``x`` once it has been converted
    to a tensor variable.

    Raises
    ------
    ValueError
        If ``self.axis`` is not smaller than the number of dimensions of
        ``x``.
    """
    x = ptb.as_tensor_variable(x)

    # ``self.axis`` is validated as non-negative in ``__init__``, so only the
    # upper bound has to be checked here.
    if self.axis >= x.ndim:
        raise ValueError(f"axis(={self.axis}) out of bounds")

    return Apply(self, [x], [x.type()])
@@ -330,21 +324,10 @@ def perform(self, node, inputs, output_storage):
330
324
else :
331
325
z [0 ] = np .cumprod (x , axis = self .axis )
332
326
333
- def grad (self , inputs , output_gradients ):
327
+ def L_op (self , inputs , outputs , output_gradients ):
334
328
(x ,) = inputs
335
329
(gi ,) = output_gradients
336
330
337
- if self .axis is None :
338
- if self .mode == "add" :
339
- return [cumsum (gi [::- 1 ])[::- 1 ].reshape (x .shape )]
340
- elif self .mode == "mul" :
341
- fx = cumprod (x , axis = self .axis )
342
- return [cumsum ((fx * gi )[::- 1 ])[::- 1 ].reshape (x .shape ) / x ]
343
- else :
344
- raise NotImplementedError (
345
- f'{ type (self ).__name__ } : unknown gradient for mode "{ self .mode } "'
346
- )
347
-
348
331
reverse_slicing = [slice (None , None , None )] * gi .ndim
349
332
reverse_slicing [self .axis ] = slice (None , None , - 1 )
350
333
reverse_slicing = tuple (reverse_slicing )
@@ -361,9 +344,6 @@ def grad(self, inputs, output_gradients):
361
344
)
362
345
363
346
def infer_shape(self, fgraph, node, shapes):
    """Return the input shapes unchanged.

    The output is built with the same type as the input, so its shape is
    exactly the shape of the input.
    """
    return shapes
368
348
369
349
def c_support_code_apply (self , node : Apply , name : str ) -> str :
@@ -376,61 +356,43 @@ def c_code(self, node, name, inames, onames, sub):
376
356
fail = sub ["fail" ]
377
357
params = sub ["params" ]
378
358
379
- if self .axis is None :
380
- axis_code = "int axis = NPY_RAVEL_AXIS;\n "
381
- else :
382
- axis_code = f"int axis = { params } ->c_axis;\n "
383
-
384
- code = (
385
- axis_code
386
- + f"""
387
- #undef NPY_UF_DBG_TRACING
388
- #define NPY_UF_DBG_TRACING 1
389
-
390
- if (axis == 0 && PyArray_NDIM({ x } ) == 1)
391
- axis = NPY_RAVEL_AXIS;
392
- npy_intp shape[1] = {{ PyArray_SIZE({ x } ) }};
393
- if(axis == NPY_RAVEL_AXIS && !({ z } && PyArray_DIMS({ z } )[0] == shape[0]))
394
- {{
395
- Py_XDECREF({ z } );
396
- { z } = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({ x } ));
397
- }}
359
+ return dedent (
360
+ f"""
361
+ int axis = { params } ->axis;
398
362
399
- else if(axis != NPY_RAVEL_AXIS && !({ z } && PyArray_CompareLists(PyArray_DIMS({ z } ), PyArray_DIMS({ x } ), PyArray_NDIM({ x } ))))
400
- {{
401
- Py_XDECREF({ z } );
402
- { z } = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({ x } ), PyArray_DIMS({ x } ), PyArray_TYPE({ x } ));
403
- }}
363
+ if (!({ z } && PyArray_CompareLists(PyArray_DIMS({ z } ), PyArray_DIMS({ x } ), PyArray_NDIM({ x } ))))
364
+ {{
365
+ Py_XDECREF({ z } );
366
+ { z } = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({ x } ), PyArray_DIMS({ x } ), PyArray_TYPE({ x } ));
367
+ if (!{ z } ){{ { fail } }};
368
+ }}
369
+
370
+ {{
371
+
372
+ PyObject * t = NULL;
373
+ if({ params } ->mode == MODE_ADD)
374
+ t = PyArray_CumSum({ x } , axis, PyArray_TYPE({ x } ), { z } );
375
+ else if({ params } ->mode == MODE_MUL)
376
+ t = PyArray_CumProd({ x } , axis, PyArray_TYPE({ x } ), { z } );
404
377
405
- if (!{ z } )
378
+ if (!t){{
406
379
{ fail } ;
407
- {{
408
-
409
- PyObject * t = NULL;
410
- if({ params } ->mode == MODE_ADD)
411
- t = PyArray_CumSum(
412
- { x } , axis,
413
- PyArray_TYPE({ x } ), { z } );
414
- else if({ params } ->mode == MODE_MUL)
415
- t = PyArray_CumProd(
416
- { x } , axis,
417
- PyArray_TYPE({ x } ), { z } );
418
-
419
- if (!t){{
420
- { fail } ;
421
- }}
422
- // Because PyArray_CumSum/CumProd returns a newly created reference on t.
423
- Py_XDECREF(t);
424
380
}}
381
+
382
+ // Because PyArray_CumSum/CumProd returns a newly created reference on t.
383
+ Py_XDECREF(t);
384
+ }}
425
385
"""
426
386
)
427
387
428
- return code
429
-
430
388
def c_code_cache_version(self):
    """Return the version tag of the generated C code.

    The tag was raised from 9 to 10 together with the rewrite of the C
    implementation.
    """
    return (10,)
432
390
433
391
def __str__(self):
    """Return a short label such as ``Cumsum{axis=0}`` or ``Cumprod{axis=1}``."""
    if self.mode == "add":
        label = "Cumsum"
    elif self.mode == "mul":
        label = "Cumprod"
    else:
        # Fallback for any other mode: class name with the raw attributes.
        return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}"
    return f"{label}{{axis={self.axis}}}"
435
397
436
398
@@ -451,6 +413,12 @@ def cumsum(x, axis=None):
451
413
.. versionadded:: 0.7
452
414
453
415
"""
416
+ x = ptb .as_tensor_variable (x )
417
+ if axis is None :
418
+ x = x .ravel ()
419
+ axis = 0
420
+ else :
421
+ axis = normalize_axis_index (axis , x .ndim )
454
422
return CumOp (axis = axis , mode = "add" )(x )
455
423
456
424
@@ -471,6 +439,12 @@ def cumprod(x, axis=None):
471
439
.. versionadded:: 0.7
472
440
473
441
"""
442
+ x = ptb .as_tensor_variable (x )
443
+ if axis is None :
444
+ x = x .ravel ()
445
+ axis = 0
446
+ else :
447
+ axis = normalize_axis_index (axis , x .ndim )
474
448
return CumOp (axis = axis , mode = "mul" )(x )
475
449
476
450
@@ -479,18 +453,8 @@ def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
479
453
"""Vectorize the CumOp to work on a batch of inputs."""
480
454
[original_x ] = node .inputs
481
455
batch_ndim = batch_x .ndim - original_x .ndim
482
- axis = op .axis
483
- if axis is None and original_x .ndim == 1 :
484
- axis = 0
485
- elif axis is not None :
486
- axis = normalize_axis_index (op .axis , original_x .ndim )
487
-
488
- if axis is None :
489
- # Ravel all unbatched dimensions and perform CumOp on the last axis
490
- batch_x_raveled = [batch_x .flatten (ndim = batch_ndim + 1 ) for x in batch_x ]
491
- return type (op )(axis = - 1 , mode = op .mode ).make_node (batch_x_raveled )
492
- else :
493
- return type (op )(axis = axis + batch_ndim , mode = op .mode ).make_node (batch_x )
456
+ # op.axis is already normalized and non-negative
457
+ return type (op )(axis = op .axis + batch_ndim , mode = op .mode ).make_node (batch_x )
494
458
495
459
496
460
def diff (x , n = 1 , axis = - 1 ):
0 commit comments