diff --git a/keras/src/backend/common/remat_test.py b/keras/src/backend/common/remat_test.py index 2732f5da964a..77983f109f15 100644 --- a/keras/src/backend/common/remat_test.py +++ b/keras/src/backend/common/remat_test.py @@ -116,3 +116,31 @@ def test_remat_basic_call(self): batch_size=batch_size, verbose=0, ) + + def test_remat_with_kwargs(self): + if backend.backend() in ("openvino", "numpy"): + self.skipTest( + "remat is not supported in openvino and numpy backends." + ) + + # Define a function that uses keyword arguments + def fn_with_kwargs(x, scale=1.0, offset=0.0): + return x * scale + offset + + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + # Test with keyword arguments + remat_fn = backend.core.remat(fn_with_kwargs) + result_with_kwargs = remat_fn(x, scale=2.0, offset=1.0) + expected = fn_with_kwargs(x, scale=2.0, offset=1.0) + self.assertAllClose(result_with_kwargs, expected) + + # Test with default keyword arguments + result_with_defaults = remat_fn(x) + expected_defaults = fn_with_kwargs(x) + self.assertAllClose(result_with_defaults, expected_defaults) + + # Test with partial keyword arguments + result_partial = remat_fn(x, scale=3.0) + expected_partial = fn_with_kwargs(x, scale=3.0) + self.assertAllClose(result_partial, expected_partial) diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 877dc6909ea1..6d4f22095ff2 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -673,7 +673,9 @@ def remat(f): """ def wrapped(*args, **kwargs): - return torch.utils.checkpoint.checkpoint(f, *args, use_reentrant=False) + return torch.utils.checkpoint.checkpoint( + f, *args, use_reentrant=False, **kwargs + ) return wrapped