from keras.src.utils.naming import auto_name


-class KerasVariable:
+class Variable:
    """Represents a backend-agnostic variable in Keras.

    A `Variable` acts as a container for state. It holds a tensor value and can
@@ -30,17 +30,25 @@ class KerasVariable:
            dtype type (`"float32"` if never configured).
        trainable: Optional. Boolean indicating if variable is trainable.
            Defaults to `True`.
+        autocast: Optional. Boolean indicating whether the variable supports
+            autocasting. If `True`, the layer may first convert the variable
+            to the compute data type when accessed. Defaults to `True`.
+        aggregation: Optional. String specifying how a distributed variable will
+            be aggregated. This serves as a semantic annotation, to be taken
+            into account by downstream backends or users. Defaults to `"mean"`.
        name: Optional. A unique name for the variable. Automatically generated
            if not set.

    Attributes:
-        name: The name of the variable (string).
-        path: The path of the variable within the Keras model or layer (string).
-        dtype: The data type of the variable (string).
        shape: The shape of the variable (tuple of integers).
        ndim: The number of dimensions of the variable (integer).
+        dtype: The data type of the variable (string).
        trainable: Whether the variable is trainable (boolean).
+        autocast: Whether the variable supports autocasting (boolean).
+        aggregation: How a distributed variable will be aggregated (string).
        value: The current value of the variable (NumPy array or tensor).
+        name: The name of the variable (string).
+        path: The path of the variable within the Keras model or layer (string).

    Examples:

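To make the new constructor surface concrete, here is a minimal usage sketch (not part of the diff) of the arguments and attributes documented above. It assumes the public `keras.Variable` export resolves to this class; the name `demo_var` is purely illustrative.

```python
import numpy as np

from keras import Variable  # assumed public alias for the class renamed in this diff

# A non-callable initializer (e.g. a NumPy array) sets the value directly;
# `autocast` and `aggregation` fall back to their documented defaults.
v = Variable(
    initializer=np.zeros((2, 3), dtype="float32"),
    trainable=True,
    name="demo_var",
)

print(v.shape)        # (2, 3)
print(v.ndim)         # 2
print(v.dtype)        # "float32"
print(v.trainable)    # True
print(v.aggregation)  # "mean" (default)
print(v.name)         # "demo_var"
print(v.path)         # "demo_var" at top level; prefixed by the parent path inside a layer
```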
@@ -101,20 +109,19 @@ def __init__(
                "one of {'none', 'mean', 'sum', 'only_first_replica'}. "
                f"Received: aggregation={aggregation}"
            )
-        self.name = name
+        self._name = name
        parent_path = current_path()
        if parent_path:
-            self.path = current_path() + "/" + self.name
+            self._path = current_path() + "/" + name
        else:
-            self.path = self.name
-        dtype = standardize_dtype(dtype)
-        self._dtype = dtype
+            self._path = name
+        self._dtype = standardize_dtype(dtype)
        self._shape = None
        self._initializer = None
        self._regularizer = None
        self._constraint = None
-        self._trainable = trainable
-        self._autocast = autocast
+        self._trainable = bool(trainable)
+        self._autocast = bool(autocast)
        self._aggregation = aggregation
        # `self._overwrite_with_gradient` is an internal property to determine
        # whether this variable should be overwritten by the computed gradient.
@@ -163,7 +170,7 @@ def __init__(
                self._initialize_with_initializer(initializer)
            else:
                self._initialize(initializer)
-            self._shape = tuple(self._value.shape)
+            self._shape = self._validate_shape(self._value.shape)
            self._ndim = len(self._shape)

    def _deferred_initialize(self):
@@ -201,10 +208,12 @@ def numpy(self):

    @property
    def aggregation(self):
+        """The strategy for aggregating this variable."""
        return self._aggregation

    @property
    def value(self):
+        """The current value of the variable (numpy array or backend tensor)."""
        if in_stateless_scope():
            scope = get_stateless_scope()
            value = scope.get_current_value(self)
@@ -246,30 +255,46 @@ def assign_sub(self, value):

    @property
    def dtype(self):
+        """The data type of the variable."""
        autocast_scope = get_autocast_scope()
        if (
            self._autocast
            and autocast_scope is not None
            and is_float_dtype(self._dtype)
        ):
-            return autocast_scope.dtype
-        return self._dtype
+            dtype = autocast_scope.dtype
+        else:
+            dtype = self._dtype
+        return backend.standardize_dtype(dtype)

    @property
    def shape(self):
+        """The shape of the variable."""
        return self._shape

    @property
    def ndim(self):
+        """The number of dimensions of the variable."""
        return self._ndim

    @property
    def trainable(self):
+        """Whether the variable is trainable."""
        return self._trainable

    @trainable.setter
    def trainable(self, value):
-        self._trainable = value
+        self._trainable = bool(value)
+
+    @property
+    def name(self):
+        """The name of the variable."""
+        return self._name
+
+    @property
+    def path(self):
+        """The path of the variable within the Keras model or layer."""
+        return self._path

    @property
    def overwrite_with_gradient(self):
@@ -326,9 +351,13 @@ def constraint(self, value):
        self._constraint = value

    def __repr__(self):
+        value = None
+        if hasattr(self, "_value") and self._value is not None:
+            value = backend.core.convert_to_numpy(self._value)
+        value_str = f", value={value}" if value is not None else ""
        return (
-            f"<KerasVariable shape={self.shape}, dtype={self.dtype}, "
-            f"path={self.path}>"
+            f"<Variable path={self.path}, shape={self.shape}, "
+            f"dtype={self.dtype}{value_str}>"
        )

    def _initialize(self, value):
@@ -573,7 +602,7 @@ def get_autocast_scope():
class AutocastScope:
    """Context manager that enables the autocasting of float variables.

-    Under this context manager, float `KerasVariable`s will be cast to `dtype`
+    Under this context manager, float `Variable`s will be cast to `dtype`
    (note that `dtype` must also be float).
    """
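As a rough illustration of the reworked `dtype` property and the `AutocastScope` behavior described above, the sketch below assumes `AutocastScope` remains importable from `keras.src.backend.common.variables` (the module this diff appears to modify) and that `keras.Variable` is the public alias for the renamed class.

```python
import numpy as np

from keras import Variable  # assumed public alias
from keras.src.backend.common.variables import AutocastScope  # assumed module path

v = Variable(np.ones((2,), dtype="float32"), name="w")
print(v.dtype)  # "float32" outside any autocast scope
print(v)        # the new __repr__ reports path, shape, dtype and the current value

# Float variables created with autocast=True report the scope's dtype while
# the scope is active, per the `dtype` property in this diff.
with AutocastScope("float16"):
    print(v.dtype)  # "float16"

print(v.dtype)  # "float32" again once the scope exits
```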