1212from keras .src .utils .naming import auto_name
1313
1414
15- class KerasVariable :
15+ class Variable :
1616 """Represents a backend-agnostic variable in Keras.
1717
1818 A `Variable` acts as a container for state. It holds a tensor value and can
@@ -30,17 +30,25 @@ class KerasVariable:
3030 dtype type (`"float32"` if never configured).
3131 trainable: Optional. Boolean indicating if variable is trainable.
3232 Defaults to `True`.
33+ autocast: Optional. Boolean indicating whether the variable supports
34+ autocasting. If `True`, the layer may first convert the variable
35+ to the compute data type when accessed. Defaults to `True`.
36+ aggregation: Optional. String specifying how a distributed variable will
37+ be aggregated. This serves as a semantic annotation, to be taken
38+ into account by downstream backends or users. Defaults to `"mean"`.
3339 name: Optional. A unique name for the variable. Automatically generated
3440 if not set.
3541
3642 Attributes:
37- name: The name of the variable (string).
38- path: The path of the variable within the Keras model or layer (string).
39- dtype: The data type of the variable (string).
4043 shape: The shape of the variable (tuple of integers).
4144 ndim: The number of dimensions of the variable (integer).
45+ dtype: The data type of the variable (string).
4246 trainable: Whether the variable is trainable (boolean).
47+ autocast: Whether the variable supports autocasting (boolean).
48+ aggregation: How a distributed variable will be aggregated (string).
4349 value: The current value of the variable (NumPy array or tensor).
50+ name: The name of the variable (string).
51+ path: The path of the variable within the Keras model or layer (string).
4452
4553 Examples:
4654
@@ -101,20 +109,19 @@ def __init__(
101109 "one of {'none', 'mean', 'sum', 'only_first_replica'}. "
102110 f"Received: aggregation={ aggregation } "
103111 )
104- self .name = name
112+ self ._name = name
105113 parent_path = current_path ()
106114 if parent_path :
107- self .path = current_path () + "/" + self . name
115+ self ._path = current_path () + "/" + name
108116 else :
109- self .path = self .name
110- dtype = standardize_dtype (dtype )
111- self ._dtype = dtype
117+ self ._path = name
118+ self ._dtype = standardize_dtype (dtype )
112119 self ._shape = None
113120 self ._initializer = None
114121 self ._regularizer = None
115122 self ._constraint = None
116- self ._trainable = trainable
117- self ._autocast = autocast
123+ self ._trainable = bool ( trainable )
124+ self ._autocast = bool ( autocast )
118125 self ._aggregation = aggregation
119126 # `self._overwrite_with_gradient` is an internal property to determine
120127 # whether this variable should be overwritten by the computed gradient.
@@ -163,7 +170,7 @@ def __init__(
163170 self ._initialize_with_initializer (initializer )
164171 else :
165172 self ._initialize (initializer )
166- self ._shape = tuple (self ._value .shape )
173+ self ._shape = self . _validate_shape (self ._value .shape )
167174 self ._ndim = len (self ._shape )
168175
169176 def _deferred_initialize (self ):
@@ -201,10 +208,12 @@ def numpy(self):
201208
202209 @property
203210 def aggregation (self ):
211+ """The strategy for aggregating this variable."""
204212 return self ._aggregation
205213
206214 @property
207215 def value (self ):
216+ """The current value of the variable (NumPy array or backend tensor)."""
208217 if in_stateless_scope ():
209218 scope = get_stateless_scope ()
210219 value = scope .get_current_value (self )
@@ -246,30 +255,46 @@ def assign_sub(self, value):
246255
247256 @property
248257 def dtype (self ):
258+ """The data type of the variable."""
249259 autocast_scope = get_autocast_scope ()
250260 if (
251261 self ._autocast
252262 and autocast_scope is not None
253263 and is_float_dtype (self ._dtype )
254264 ):
255- return autocast_scope .dtype
256- return self ._dtype
265+ dtype = autocast_scope .dtype
266+ else :
267+ dtype = self ._dtype
268+ return backend .standardize_dtype (dtype )
257269
258270 @property
259271 def shape (self ):
272+ """The shape of the variable."""
260273 return self ._shape
261274
262275 @property
263276 def ndim (self ):
277+ """The number of dimensions of the variable."""
264278 return self ._ndim
265279
266280 @property
267281 def trainable (self ):
282+ """Whether the variable is trainable."""
268283 return self ._trainable
269284
270285 @trainable .setter
271286 def trainable (self , value ):
272- self ._trainable = value
287+ self ._trainable = bool (value )
288+
289+ @property
290+ def name (self ):
291+ """The name of the variable."""
292+ return self ._name
293+
294+ @property
295+ def path (self ):
296+ """The path of the variable within the Keras model or layer."""
297+ return self ._path
273298
274299 @property
275300 def overwrite_with_gradient (self ):
@@ -326,9 +351,13 @@ def constraint(self, value):
326351 self ._constraint = value
327352
328353 def __repr__ (self ):
354+ value = None
355+ if hasattr (self , "_value" ) and self ._value is not None :
356+ value = backend .core .convert_to_numpy (self ._value )
357+ value_str = f", value={ value } " if value is not None else ""
329358 return (
330- f"<KerasVariable shape ={ self .shape } , dtype ={ self .dtype } , "
331- f"path ={ self .path } >"
359+ f"<Variable path ={ self .path } , shape ={ self .shape } , "
360+ f"dtype ={ self .dtype } { value_str } >"
332361 )
333362
334363 def _initialize (self , value ):
@@ -573,7 +602,7 @@ def get_autocast_scope():
573602class AutocastScope :
574603 """Context manager that enables the autocasting of float variables.
575604
576- Under this context manager, float `KerasVariables `s will be cast to `dtype`
605+ Under this context manager, float `Variable`s will be cast to `dtype`
577606 (note that `dtype` must also be float).
578607 """
579608
0 commit comments