diff --git a/doc/extending/inplace.rst b/doc/extending/inplace.rst index 74ffa58119..3680b99352 100644 --- a/doc/extending/inplace.rst +++ b/doc/extending/inplace.rst @@ -206,6 +206,101 @@ input(s)'s memory). From there, go to the previous section. Consider using :class:`DebugMode` when developing a new :class:`Op` that uses :attr:`Op.view_map` and/or :attr:`Op.destroy_map`. +The `inplace_on_inputs` Method +============================== + +PyTensor provides a method :meth:`Op.inplace_on_inputs` that allows an `Op` to +create a version of itself that operates inplace on as many of the requested +inputs as possible while avoiding inplace operations on non-requested inputs. + +This method takes a list of input indices where inplace operations are allowed +and returns an `Op` instance (possibly `self`) that will perform inplace operations only on +those inputs where it is safe and beneficial to do so. + +.. testcode:: + + import numpy as np + import pytensor + import pytensor.tensor as pt + from pytensor.graph.basic import Apply + from pytensor.graph.op import Op + from pytensor.tensor.blockwise import Blockwise + + class MyOpWithInplace(Op): + __props__ = ("destroy_a",) + + def __init__(self, destroy_a): + self.destroy_a = destroy_a + if destroy_a: + self.destroy_map = {0: [0]} + + def make_node(self, a): + return Apply(self, [a], [a.type()]) + + def perform(self, node, inputs, output_storage): + [a] = inputs + if not self.destroy_a: + a = a.copy() + a[0] += 1 + output_storage[0][0] = a + + def inplace_on_inputs(self, allowed_inplace_inputs): + if 0 in allowed_inplace_inputs: + return MyOpWithInplace(destroy_a=True) + return self + + a = pt.vector("a") + # Only Blockwise triggers inplace automatically for now + # Since the Blockwise isn't needed in this case, it will be removed after the inplace optimization + op = Blockwise(MyOpWithInplace(destroy_a=False), signature="(a)->(a)") + out = op(a) + + # Give PyTensor permission to inplace on user provided inputs + fn = 
pytensor.function([pytensor.In(a, mutable=True)], out) + + # Confirm that we have the inplace version of the Op + fn.dprint(print_destroy_map=True) + +.. testoutput:: + + Blockwise{MyOpWithInplace{destroy_a=True}, (a)->(a)} [id A] '' 5 + └─ a [id B] + +The output shows that the function now uses the inplace version (`destroy_a=True`). + +.. testcode:: + + # Test that inplace modification works + test_a = np.zeros(5) + result = fn(test_a) + print("Function result:", result) + print("Original array after function call:", test_a) + +.. testoutput:: + + Function result: [1. 0. 0. 0. 0.] + Original array after function call: [1. 0. 0. 0. 0.] + +Currently, this method is primarily used with Blockwise operations through PyTensor's +rewriting system, but it will be extended to support core ops directly in future versions. +The rewriting system automatically calls this method to optimize memory usage by +enabling inplace operations where they do not interfere with the computation graph's +correctness. + +When implementing this method in a custom `Op`: + +- Return a new instance of your `Op` with a :attr:`destroy_map` that reflects + the inplace operations on the allowed inputs. +- Ensure that inplace operations are performed only on inputs that are in the + `allowed_inplace_inputs` list. +- Return `self` if no inplace optimization is possible or beneficial. +- The returned `Op` should be functionally equivalent to the original, but with + better memory efficiency. + +.. note:: + This method is automatically used by PyTensor's optimization pipeline and typically + does not need to be called directly by user code. 
+ Inplace Rewriting and `DebugMode` ================================= diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index 3a00922c87..00490f7357 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -605,7 +605,6 @@ def make_thunk( def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": """Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs`.""" - # TODO: Document this in the Create your own Op docs # By default, do nothing return self