Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions keras/src/backend/common/remat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,31 @@ def test_remat_basic_call(self):
batch_size=batch_size,
verbose=0,
)

def test_remat_with_kwargs(self):
if backend.backend() in ("openvino", "numpy"):
self.skipTest(
"remat is not supported in openvino and numpy backends."
)

# Define a function that uses keyword arguments
def fn_with_kwargs(x, scale=1.0, offset=0.0):
return x * scale + offset

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)

# Test with keyword arguments
remat_fn = backend.core.remat(fn_with_kwargs)
result_with_kwargs = remat_fn(x, scale=2.0, offset=1.0)
expected = fn_with_kwargs(x, scale=2.0, offset=1.0)
self.assertAllClose(result_with_kwargs, expected)

# Test with default keyword arguments
result_with_defaults = remat_fn(x)
expected_defaults = fn_with_kwargs(x)
self.assertAllClose(result_with_defaults, expected_defaults)

# Test with partial keyword arguments
result_partial = remat_fn(x, scale=3.0)
expected_partial = fn_with_kwargs(x, scale=3.0)
self.assertAllClose(result_partial, expected_partial)
Comment on lines +132 to +146
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The test cases for different keyword argument scenarios are well-covered. For better maintainability and readability, you could consider parameterizing this test using subTest. This would make it more compact and easier to add new test cases in the future.

Suggested change
# Test with keyword arguments
remat_fn = backend.core.remat(fn_with_kwargs)
result_with_kwargs = remat_fn(x, scale=2.0, offset=1.0)
expected = fn_with_kwargs(x, scale=2.0, offset=1.0)
self.assertAllClose(result_with_kwargs, expected)
# Test with default keyword arguments
result_with_defaults = remat_fn(x)
expected_defaults = fn_with_kwargs(x)
self.assertAllClose(result_with_defaults, expected_defaults)
# Test with partial keyword arguments
result_partial = remat_fn(x, scale=3.0)
expected_partial = fn_with_kwargs(x, scale=3.0)
self.assertAllClose(result_partial, expected_partial)
remat_fn = backend.core.remat(fn_with_kwargs)
test_cases = [
("with_kwargs", {"scale": 2.0, "offset": 1.0}),
("with_defaults", {}),
("partial_kwargs", {"scale": 3.0}),
]
for name, kwargs in test_cases:
with self.subTest(msg=name):
result = remat_fn(x, **kwargs)
expected = fn_with_kwargs(x, **kwargs)
self.assertAllClose(result, expected)

4 changes: 3 additions & 1 deletion keras/src/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,9 @@ def remat(f):
"""

def wrapped(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(f, *args, use_reentrant=False)
return torch.utils.checkpoint.checkpoint(
f, *args, use_reentrant=False, **kwargs
)

return wrapped

Expand Down