Skip to content

Commit 529e162

Browse files
authored
Fix assigning a value to a variable within an autocast scope. (#21864)
Previously `assign` would incorrectly cast the value to assign to the autocast dtype instead of the true dtype of the variable. Because on JAX and OpenVino variables are just a reference to an array, this would cause the variable value to change dtypes.
1 parent e048ae4 commit 529e162

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

keras/src/backend/common/variables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def value(self):
276276
return self._maybe_autocast(self._value)
277277

278278
def assign(self, value):
279-
value = self._convert_to_tensor(value, dtype=self.dtype)
279+
value = self._convert_to_tensor(value, dtype=self._dtype)
280280
if not shape_equal(value.shape, self.shape):
281281
raise ValueError(
282282
"The shape of the target variable and "

keras/src/backend/common/variables_test.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ def test_trainable_setter(self):
176176
v.trainable = False
177177
self.assertFalse(v._value.requires_grad)
178178

179-
def test_autocasting(self):
180-
"""Tests autocasting of float variables."""
179+
def test_autocasting_float(self):
180+
# Tests autocasting of float variables
181181
v = backend.Variable(
182182
initializer=initializers.RandomNormal(),
183183
shape=(2, 2),
@@ -191,6 +191,33 @@ def test_autocasting(self):
191191
)
192192
self.assertEqual(backend.standardize_dtype(v.value.dtype), "float32")
193193

194+
def test_autocasting_float_assign(self):
195+
# Tests assigning value to variable within an autocast scope
196+
v = backend.Variable(
197+
initializer=initializers.RandomNormal(),
198+
shape=(2, 2),
199+
dtype="float32",
200+
)
201+
self.assertEqual(v.dtype, "float32")
202+
self.assertEqual(backend.standardize_dtype(v.value.dtype), "float32")
203+
204+
# Assign float16 value within float16 scope
205+
with AutocastScope("float16"):
206+
self.assertEqual(
207+
backend.standardize_dtype(v.value.dtype), "float16"
208+
)
209+
v.assign(ops.ones((2, 2), "float16"))
210+
self.assertEqual(backend.standardize_dtype(v.value.dtype), "float32")
211+
212+
# Assign float32 value within float16 scope
213+
with AutocastScope("float16"):
214+
self.assertEqual(
215+
backend.standardize_dtype(v.value.dtype), "float16"
216+
)
217+
v.assign(ops.zeros((2, 2), "float32"))
218+
self.assertEqual(backend.standardize_dtype(v.value.dtype), "float32")
219+
220+
def test_autocasting_int(self):
194221
# Test non-float variables are not affected
195222
v = backend.Variable(
196223
initializer=initializers.Ones(),
@@ -204,6 +231,7 @@ def test_autocasting(self):
204231
with AutocastScope("float16"):
205232
self.assertEqual(backend.standardize_dtype(v.value.dtype), "int32")
206233

234+
def test_autocasting_float_with_autocast_off(self):
207235
# Test autocast argument
208236
v = backend.Variable(
209237
initializer=initializers.RandomNormal(),

0 commit comments

Comments
 (0)