From ff85451fa500d55554a73e1dcc3d4da332594d8b Mon Sep 17 00:00:00 2001 From: Abhinavexists Date: Fri, 21 Nov 2025 23:26:44 +0530 Subject: [PATCH 1/6] fixed torch layer losses -> rematscope --- keras/src/backend/common/remat_test.py | 28 ++++++++++++++++++++++++++ keras/src/backend/torch/core.py | 9 ++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/keras/src/backend/common/remat_test.py b/keras/src/backend/common/remat_test.py index 2732f5da964a..77983f109f15 100644 --- a/keras/src/backend/common/remat_test.py +++ b/keras/src/backend/common/remat_test.py @@ -116,3 +116,31 @@ def test_remat_basic_call(self): batch_size=batch_size, verbose=0, ) + + def test_remat_with_kwargs(self): + if backend.backend() in ("openvino", "numpy"): + self.skipTest( + "remat is not supported in openvino and numpy backends." + ) + + # Define a function that uses keyword arguments + def fn_with_kwargs(x, scale=1.0, offset=0.0): + return x * scale + offset + + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + # Test with keyword arguments + remat_fn = backend.core.remat(fn_with_kwargs) + result_with_kwargs = remat_fn(x, scale=2.0, offset=1.0) + expected = fn_with_kwargs(x, scale=2.0, offset=1.0) + self.assertAllClose(result_with_kwargs, expected) + + # Test with default keyword arguments + result_with_defaults = remat_fn(x) + expected_defaults = fn_with_kwargs(x) + self.assertAllClose(result_with_defaults, expected_defaults) + + # Test with partial keyword arguments + result_partial = remat_fn(x, scale=3.0) + expected_partial = fn_with_kwargs(x, scale=3.0) + self.assertAllClose(result_partial, expected_partial) diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 877dc6909ea1..98c0d3992f51 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -5,6 +5,7 @@ import ml_dtypes import numpy as np import torch +from torch.utils.checkpoint import checkpoint from keras.src import tree from keras.src.backend.common import KerasVariable @@ -673,7 +674,13 @@ def remat(f): """ def wrapped(*args, **kwargs): - return torch.utils.checkpoint.checkpoint(f, *args, use_reentrant=False) + if not kwargs: + return checkpoint(f, *args, use_reentrant=False) + + def positional_wrapper(*pos_arg): + return f(*pos_arg, **kwargs) + + return checkpoint(positional_wrapper, *args, use_reentrant=False) return wrapped From 869f2461e53b4d03aae92accfd90b1c419977c33 Mon Sep 17 00:00:00 2001 From: Abhinavexists Date: Fri, 21 Nov 2025 23:46:39 +0530 Subject: [PATCH 2/6] pre-commit fix --- keras/src/backend/torch/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 98c0d3992f51..54b7af192440 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -676,10 +676,10 @@ def remat(f): def wrapped(*args, **kwargs): if not kwargs: return checkpoint(f, *args, use_reentrant=False) - + def positional_wrapper(*pos_arg): return f(*pos_arg, **kwargs) - + return checkpoint(positional_wrapper, *args, use_reentrant=False) return wrapped From bcef4323d22f30e11d34d1071055032e9f3f13c7 Mon Sep 17 00:00:00 2001 From: Abhinav Date: Fri, 21 Nov 2025 23:48:06 +0530 Subject: [PATCH 3/6] Update keras/src/backend/torch/core.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras/src/backend/torch/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 54b7af192440..4a3814d29626 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -677,8 +677,8 @@ def wrapped(*args, **kwargs): if not kwargs: return checkpoint(f, *args, use_reentrant=False) - def positional_wrapper(*pos_arg): - return f(*pos_arg, **kwargs) + def positional_wrapper(*pos_args): + return f(*pos_args, **kwargs) return checkpoint(positional_wrapper, *args, use_reentrant=False) From 5afc0db7142f40c321c5890950da61b5d2081c22 Mon Sep 17 00:00:00 2001 From: Abhinavexists Date: Thu, 4 Dec 2025 22:01:24 +0530 Subject: [PATCH 4/6] switched to kwargs to rather than pos_args --- keras/src/backend/torch/core.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 4a3814d29626..7eaf6584003b 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -674,13 +674,7 @@ def remat(f): """ def wrapped(*args, **kwargs): - if not kwargs: - return checkpoint(f, *args, use_reentrant=False) - - def positional_wrapper(*pos_args): - return f(*pos_args, **kwargs) - - return checkpoint(positional_wrapper, *args, use_reentrant=False) + return checkpoint(f, *args, use_reentrant=False, **kwargs) return wrapped From 3b07dc5287c214fb3861e093f6c80028dcef0a1b Mon Sep 17 00:00:00 2001 From: Abhinavexists Date: Thu, 4 Dec 2025 22:11:16 +0530 Subject: [PATCH 5/6] import fixes -> type check issues --- keras/src/backend/torch/core.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 7eaf6584003b..44ef0d3937fa 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -5,7 +5,6 @@ import ml_dtypes import numpy as np import torch -from torch.utils.checkpoint import checkpoint from keras.src import tree from keras.src.backend.common import KerasVariable @@ -674,7 +673,9 @@ def remat(f): """ def wrapped(*args, **kwargs): - return checkpoint(f, *args, use_reentrant=False, **kwargs) + return torch.utils.checkpoint.checkpoint( + f, *args, use_reentrant=False, **kwargs + ) return wrapped From 272c71b72b1d4f13434d91a332850cc90dca51c5 Mon Sep 17 00:00:00 2001 From: Abhinavexists Date: Thu, 4 Dec 2025 22:21:17 +0530 Subject: [PATCH 6/6] code_format_fix --- keras/src/backend/torch/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 44ef0d3937fa..6d4f22095ff2 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -675,7 +675,7 @@ def remat(f): def wrapped(*args, **kwargs): return torch.utils.checkpoint.checkpoint( f, *args, use_reentrant=False, **kwargs - ) + ) return wrapped