class JAXApproximator(keras.Model):
+    """
+    Base class for approximators using JAX and Keras' stateless training interface.
+
+    This class enables stateless training and evaluation steps with JAX, supporting
+    JAX-compatible gradient computation and variable updates through the `StatelessScope`.
+
+    Notes
+    -----
+    Subclasses must implement:
+
+    - compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]
+    - _batch_size_from_data(self, data: dict[str, any]) -> int
+    """
+
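For orientation, a minimal subclass might look like the sketch below. The class name, the `summaries`/`targets` keys, and the single dense layer are illustrative assumptions, not part of this interface; the `_batch_size_from_data` half of the contract is sketched at the end of this section.

import jax
import keras


class MyApproximator(JAXApproximator):
    """Hypothetical concrete approximator; names and data keys are illustrative only."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(1)

    def compute_metrics(self, summaries, targets, stage: str = "training") -> dict[str, jax.Array]:
        # Use only JAX-traceable ops (keras.ops / jax.numpy) so the loss stays differentiable.
        predictions = self.dense(summaries)
        loss = keras.ops.mean((predictions - targets) ** 2)
        return {"loss": loss}

With the JAX backend selected (e.g., `KERAS_BACKEND=jax`), `keras.Model.fit` and `keras.Model.evaluate` then route through the stateless steps defined below.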
    # noinspection PyMethodOverriding
    def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]:
-        # implemented by each respective architecture
+        """
+        Compute and return a dictionary of metrics for the current batch.
+
+        This method is expected to be implemented by each subclass to compute
+        task-specific metrics using JAX arrays. It must be compatible with stateless
+        execution and differentiable under JAX's `grad` system.
+
+        Parameters
+        ----------
+        *args : tuple
+            Positional arguments passed to the metric computation function.
+        **kwargs
+            Keyword arguments passed to the metric computation function.
+
+        Returns
+        -------
+        dict of str to jax.Array
+            Dictionary containing named metric values as JAX arrays.
+        """
        raise NotImplementedError

    def stateless_compute_metrics(
@@ -19,17 +50,34 @@ def stateless_compute_metrics(
        stage: str = "training",
    ) -> (jax.Array, tuple):
        """
-        Things we do for jax:
-        1. Accept trainable variables as the first argument
-           (can be at any position as indicated by the argnum parameter
-           in autograd, but needs to be an explicit arg)
-        2. Accept, potentially modify, and return other state variables
-        3. Return just the loss tensor as the first value
-        4. Return all other values in a tuple as the second value
-
-        This ensures:
-        1. The function is stateless
-        2. The function can be differentiated with jax autograd
+        Stateless computation of metrics required for JAX autograd.
+
+        This method performs a stateless forward pass using the given model
+        variables and returns both the loss and auxiliary information for
+        further updates.
+
+        Parameters
+        ----------
+        trainable_variables : Any
+            Current values of the trainable weights.
+        non_trainable_variables : Any
+            Current values of non-trainable variables (e.g., batch norm statistics).
+        metrics_variables : Any
+            Current values of metric tracking variables.
+        data : dict of str to any
+            Input data dictionary passed to `compute_metrics`.
+        stage : str, default="training"
+            Whether the computation is for "training" or "validation".
+
+        Returns
+        -------
+        loss : jax.Array
+            Scalar loss tensor for gradient computation.
+        aux : tuple
+            Tuple containing:
+
+            - metrics (dict of str to jax.Array)
+            - updated non-trainable variables
+            - updated metrics variables
        """
        state_mapping = []
        state_mapping.extend(zip(self.trainable_variables, trainable_variables))
@@ -48,19 +96,55 @@ def stateless_compute_metrics(
        return metrics["loss"], (metrics, non_trainable_variables, metrics_variables)
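This return convention matches `jax.value_and_grad(..., has_aux=True)`: the loss comes first so it can be differentiated, and everything else is bundled as auxiliary output. A toy, self-contained illustration (the function and values are made up):

import jax
import jax.numpy as jnp


def loss_and_aux(params, x):
    loss = jnp.sum((params * x) ** 2)  # scalar loss: the first return value
    aux = {"mse": loss}                # everything else rides along as aux
    return loss, aux


grad_fn = jax.value_and_grad(loss_and_aux, has_aux=True)
(loss, aux), grads = grad_fn(jnp.ones(3), jnp.arange(3.0))
# grads has the same structure as the first argument, params.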

    def stateless_test_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
+        """
+        Stateless validation step compatible with JAX.
+
+        Parameters
+        ----------
+        state : tuple
+            Tuple of (trainable_variables, non_trainable_variables, metrics_variables).
+        data : dict of str to any
+            Input data for validation.
+
+        Returns
+        -------
+        metrics : dict of str to jax.Array
+            Dictionary of computed evaluation metrics.
+        state : tuple
+            Updated state tuple after evaluation.
+        """
        trainable_variables, non_trainable_variables, metrics_variables = state

        loss, aux = self.stateless_compute_metrics(
            trainable_variables, non_trainable_variables, metrics_variables, data=data, stage="validation"
        )
        metrics, non_trainable_variables, metrics_variables = aux

-        metrics_variables = self._update_loss(loss, metrics_variables)
+        metrics_variables = self._update_metrics(loss, metrics_variables, self._batch_size_from_data(data))

        state = trainable_variables, non_trainable_variables, metrics_variables
        return metrics, state
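If one drives evaluation manually rather than through `model.evaluate`, the step can be called in a purely functional loop. A hedged sketch, assuming `model` is a concrete subclass and `val_batches` an iterable of data dictionaries (both placeholders):

# Build the state tuple from the current variable values, then thread it
# through each step; the variables themselves are never mutated.
state = (
    [v.value for v in model.trainable_variables],
    [v.value for v in model.non_trainable_variables],
    [v.value for v in model.metrics_variables],
)
for batch in val_batches:
    metrics, state = model.stateless_test_step(state, batch)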

    def stateless_train_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
+        """
+        Stateless training step compatible with JAX autograd and stateless optimization.
+
+        Computes gradients and applies optimizer updates in a purely functional style.
+
+        Parameters
+        ----------
+        state : tuple
+            Tuple of (trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables).
+        data : dict of str to any
+            Input data for training.
+
+        Returns
+        -------
+        metrics : dict of str to jax.Array
+            Dictionary of computed training metrics.
+        state : tuple
+            Updated state tuple after training.
+        """
        trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables = state

        grad_fn = jax.value_and_grad(self.stateless_compute_metrics, has_aux=True)
@@ -74,23 +158,92 @@ def stateless_train_step(self, state: tuple, data: dict[str, any]) -> (dict[str,
            optimizer_variables, grads, trainable_variables
        )

-        metrics_variables = self._update_loss(loss, metrics_variables)
+        metrics_variables = self._update_metrics(loss, metrics_variables, self._batch_size_from_data(data))

        state = trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables
        return metrics, state
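The functional optimizer update used here follows the standard Keras 3 pattern for JAX: gradient and variable values go in, updated values come out, and nothing is mutated in place. A self-contained sketch with a toy variable list (all names are local to this example):

import keras
import jax.numpy as jnp

optimizer = keras.optimizers.SGD(learning_rate=0.1)
weights = [keras.Variable(jnp.ones((2,)))]
optimizer.build(weights)

grads = [jnp.full((2,), 0.5)]
trainable_values = [w.value for w in weights]
optimizer_values = [v.value for v in optimizer.variables]

# Purely functional update: returns new values instead of mutating state.
trainable_values, optimizer_values = optimizer.stateless_apply(
    optimizer_values, grads, trainable_values
)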

    def test_step(self, *args, **kwargs):
+        """
+        Alias to `stateless_test_step` for compatibility with `keras.Model`.
+
+        Parameters
+        ----------
+        *args, **kwargs : Any
+            Passed through to `stateless_test_step`.
+
+        Returns
+        -------
+        See `stateless_test_step`.
+        """
        return self.stateless_test_step(*args, **kwargs)

    def train_step(self, *args, **kwargs):
+        """
+        Alias to `stateless_train_step` for compatibility with `keras.Model`.
+
+        Parameters
+        ----------
+        *args, **kwargs : Any
+            Passed through to `stateless_train_step`.
+
+        Returns
+        -------
+        See `stateless_train_step`.
+        """
        return self.stateless_train_step(*args, **kwargs)

-    def _update_loss(self, loss: jax.Array, metrics_variables: any) -> any:
-        # update the loss progress bar, and possibly metrics variables along with it
+    def _update_metrics(self, loss: jax.Array, metrics_variables: any, sample_weight: any = None) -> any:
+        """
+        Update metric tracking variables in a stateless, JAX-compatible way.
+
+        This method updates the loss tracker (and any other Keras metrics)
+        and returns the updated metric variable states for downstream use.
+
+        Parameters
+        ----------
+        loss : jax.Array
+            Scalar loss used for metric tracking.
+        metrics_variables : Any
+            Current metric variable states.
+        sample_weight : Any, optional
+            Sample weights to apply during the update.
+
+        Returns
+        -------
+        metrics_variables : Any
+            Updated metric variable states.
+        """
        state_mapping = list(zip(self.metrics_variables, metrics_variables))
        with keras.StatelessScope(state_mapping) as scope:
-            self._loss_tracker.update_state(loss)
+            self._loss_tracker.update_state(loss, sample_weight=sample_weight)

+        # JAX is stateless, so we need to return the metrics as state for downstream functions
        metrics_variables = [scope.get_current_value(v) for v in self.metrics_variables]

        return metrics_variables
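`StatelessScope` is what keeps this update functional: variable assignments inside the scope are recorded in the scope rather than written to the variables, and the new values are read back with `get_current_value`. A minimal illustration (the metric and values are arbitrary):

import keras
import jax.numpy as jnp

tracker = keras.metrics.Mean(name="loss")
# Current (external) values for the tracker's state variables.
state_mapping = [(v, jnp.zeros_like(v.value)) for v in tracker.variables]

with keras.StatelessScope(state_mapping) as scope:
    tracker.update_state(jnp.asarray(2.0))

# The variables themselves are untouched; the updated values live in the scope.
new_values = [scope.get_current_value(v) for v in tracker.variables]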
+
+    # noinspection PyMethodOverriding
+    def _batch_size_from_data(self, data: any) -> int:
+        """Obtain the batch size from a batch of data.
+
+        To properly weight the metrics for batches of different sizes, the batch size of a given batch of data is
+        required. As the data structure differs between approximators, each concrete approximator has to specify
+        this method.
+
+        Parameters
+        ----------
+        data :
+            The data that are passed to `compute_metrics` as keyword arguments.
+
+        Returns
+        -------
+        batch_size : int
+            The batch size of the given data.
+        """
+        raise NotImplementedError(
+            "Correct calculation of the metrics requires obtaining the batch size from the supplied data "
+            "for proper weighting of metrics for batches with different sizes. Please implement the "
+            "_batch_size_from_data method for your approximator. For a given batch of data, it should "
+            "return the corresponding batch size."
+        )
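As a concrete illustration of this contract, a hypothetical approximator whose data dictionary holds batch-first arrays could infer the batch size from any entry. This override is an assumption for the example, not part of the base class:

import keras


class ArrayDictApproximator(JAXApproximator):
    """Hypothetical subclass whose data values are all batch-first arrays."""

    def _batch_size_from_data(self, data: dict[str, any]) -> int:
        # Every entry shares the leading batch axis, so any of them will do.
        first = next(iter(data.values()))
        return keras.ops.shape(first)[0]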