# Patch summary: make the torch backend's `remat` wrapper forward keyword
# arguments to the wrapped function (the removed call passed only `*args`),
# and add a test that calls a remat-wrapped function with kwargs.
# NOTE(review): line breaks and indentation were reconstructed from a
# whitespace-collapsed copy of this patch -- verify the context lines against
# the target files before applying.
diff --git a/keras/src/backend/common/remat_test.py b/keras/src/backend/common/remat_test.py
index 2732f5da964a..77983f109f15 100644
--- a/keras/src/backend/common/remat_test.py
+++ b/keras/src/backend/common/remat_test.py
@@ -116,3 +116,31 @@ def test_remat_basic_call(self):
             batch_size=batch_size,
             verbose=0,
         )
+
+    def test_remat_with_kwargs(self):
+        if backend.backend() in ("openvino", "numpy"):
+            self.skipTest(
+                "remat is not supported in openvino and numpy backends."
+            )
+
+        # Output depends on scale/offset, so dropped kwargs would show up.
+        def fn_with_kwargs(x, scale=1.0, offset=0.0):
+            return x * scale + offset
+
+        x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+
+        # All kwargs given: result must match calling the function directly.
+        remat_fn = backend.core.remat(fn_with_kwargs)
+        result_with_kwargs = remat_fn(x, scale=2.0, offset=1.0)
+        expected = fn_with_kwargs(x, scale=2.0, offset=1.0)
+        self.assertAllClose(result_with_kwargs, expected)
+
+        # No kwargs given: the wrapped function's own defaults apply.
+        result_with_defaults = remat_fn(x)
+        expected_defaults = fn_with_kwargs(x)
+        self.assertAllClose(result_with_defaults, expected_defaults)
+
+        # Only `scale` given: `offset` falls back to its default.
+        result_partial = remat_fn(x, scale=3.0)
+        expected_partial = fn_with_kwargs(x, scale=3.0)
+        self.assertAllClose(result_partial, expected_partial)
diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py
index 877dc6909ea1..4a3814d29626 100644
--- a/keras/src/backend/torch/core.py
+++ b/keras/src/backend/torch/core.py
@@ -5,6 +5,7 @@
 import ml_dtypes
 import numpy as np
 import torch
+from torch.utils.checkpoint import checkpoint
 
 from keras.src import tree
 from keras.src.backend.common import KerasVariable
@@ -673,7 +674,13 @@ def remat(f):
     """
 
     def wrapped(*args, **kwargs):
-        return torch.utils.checkpoint.checkpoint(f, *args, use_reentrant=False)
+        if not kwargs:  # nothing to bind; checkpoint f directly
+            return checkpoint(f, *args, use_reentrant=False)
+
+        def positional_wrapper(*pos_args):  # binds kwargs via closure
+            return f(*pos_args, **kwargs)
+
+        return checkpoint(positional_wrapper, *args, use_reentrant=False)
 
     return wrapped
 