Skip to content

Commit 1d654e8

Browse files
committed
Minor fixes
1 parent 1a14b7c commit 1d654e8

8 files changed: +72 additions, −89 deletions

guides/distributed_training_with_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def get_datasets():
175175
# Keras provides a pure functional forward pass: model.stateless_call
176176
def compute_loss(trainable_variables, non_trainable_variables, x, y):
177177
y_pred, updated_non_trainable_variables = model.stateless_call(
178-
trainable_variables, non_trainable_variables, x
178+
trainable_variables, non_trainable_variables, x, training=True
179179
)
180180
loss_value = loss(y, y_pred)
181181
return loss_value, updated_non_trainable_variables

guides/ipynb/distributed_training_with_jax.ipynb

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,6 @@
4545
"most common setup for researchers and small-scale industry workflows."
4646
]
4747
},
48-
{
49-
"cell_type": "code",
50-
"execution_count": null,
51-
"metadata": {
52-
"colab_type": "code"
53-
},
54-
"outputs": [],
55-
"source": [
56-
"!pip install keras --upgrade --quiet"
57-
]
58-
},
5948
{
6049
"cell_type": "markdown",
6150
"metadata": {
@@ -225,7 +214,7 @@
225214
"# Keras provides a pure functional forward pass: model.stateless_call\n",
226215
"def compute_loss(trainable_variables, non_trainable_variables, x, y):\n",
227216
" y_pred, updated_non_trainable_variables = model.stateless_call(\n",
228-
" trainable_variables, non_trainable_variables, x\n",
217+
" trainable_variables, non_trainable_variables, x, training=True\n",
229218
" )\n",
230219
" loss_value = loss(y, y_pred)\n",
231220
" return loss_value, updated_non_trainable_variables\n",

guides/ipynb/writing_a_custom_training_loop_in_jax.ipynb

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,6 @@
1414
"**Description:** Writing low-level training & evaluation loops in JAX."
1515
]
1616
},
17-
{
18-
"cell_type": "code",
19-
"execution_count": null,
20-
"metadata": {
21-
"colab_type": "code"
22-
},
23-
"outputs": [],
24-
"source": [
25-
"!pip install keras --upgrade --quiet"
26-
]
27-
},
2817
{
2918
"cell_type": "markdown",
3019
"metadata": {
@@ -255,7 +244,7 @@
255244
"\n",
256245
"def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):\n",
257246
" y_pred, non_trainable_variables = model.stateless_call(\n",
258-
" trainable_variables, non_trainable_variables, x\n",
247+
" trainable_variables, non_trainable_variables, x, training=True\n",
259248
" )\n",
260249
" loss = loss_fn(y, y_pred)\n",
261250
" return loss, non_trainable_variables\n",
@@ -322,7 +311,7 @@
322311
" trainable_variables, non_trainable_variables, x, y\n",
323312
" )\n",
324313
" trainable_variables, optimizer_variables = optimizer.stateless_apply(\n",
325-
" grads, trainable_variables, optimizer_variables\n",
314+
" optimizer_variables, grads, trainable_variables\n",
326315
" )\n",
327316
" # Return updated state\n",
328317
" return loss, (\n",

guides/md/distributed_training_with_jax.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ optimizer.build(model.trainable_variables)
180180
# Keras provides a pure functional forward pass: model.stateless_call
181181
def compute_loss(trainable_variables, non_trainable_variables, x, y):
182182
y_pred, updated_non_trainable_variables = model.stateless_call(
183-
trainable_variables, non_trainable_variables, x
183+
trainable_variables, non_trainable_variables, x, training=True
184184
)
185185
loss_value = loss(y, y_pred)
186186
return loss_value, updated_non_trainable_variables
@@ -292,8 +292,8 @@ Data sharding
292292

293293
<div class="k-default-codeblock">
294294
```
295-
Epoch 0 loss: 0.43531758
296-
Epoch 1 loss: 0.5194763
295+
Epoch 0 loss: 0.28599858
296+
Epoch 1 loss: 0.23666474
297297
298298
```
299299
</div>

guides/md/writing_a_custom_training_loop_in_jax.md

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ variables.
184184

185185
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
186186
y_pred, non_trainable_variables = model.stateless_call(
187-
trainable_variables, non_trainable_variables, x
187+
trainable_variables, non_trainable_variables, x, training=True
188188
)
189189
loss = loss_fn(y, y_pred)
190190
return loss, non_trainable_variables
@@ -225,7 +225,7 @@ def train_step(state, data):
225225
trainable_variables, non_trainable_variables, x, y
226226
)
227227
trainable_variables, optimizer_variables = optimizer.stateless_apply(
228-
grads, trainable_variables, optimizer_variables
228+
optimizer_variables, grads, trainable_variables
229229
)
230230
# Return updated state
231231
return loss, (
@@ -305,37 +305,37 @@ for step, data in enumerate(train_dataset):
305305

306306
<div class="k-default-codeblock">
307307
```
308-
Training loss (for 1 batch) at step 0: 156.4785
308+
Training loss (for 1 batch) at step 0: 96.2726
309309
Seen so far: 32 samples
310-
Training loss (for 1 batch) at step 100: 2.5526
310+
Training loss (for 1 batch) at step 100: 2.0853
311311
Seen so far: 3232 samples
312-
Training loss (for 1 batch) at step 200: 1.8922
312+
Training loss (for 1 batch) at step 200: 0.6535
313313
Seen so far: 6432 samples
314-
Training loss (for 1 batch) at step 300: 1.2381
314+
Training loss (for 1 batch) at step 300: 1.2679
315315
Seen so far: 9632 samples
316-
Training loss (for 1 batch) at step 400: 0.4812
316+
Training loss (for 1 batch) at step 400: 0.7563
317317
Seen so far: 12832 samples
318-
Training loss (for 1 batch) at step 500: 2.3339
318+
Training loss (for 1 batch) at step 500: 0.7154
319319
Seen so far: 16032 samples
320-
Training loss (for 1 batch) at step 600: 0.5615
320+
Training loss (for 1 batch) at step 600: 1.0267
321321
Seen so far: 19232 samples
322-
Training loss (for 1 batch) at step 700: 0.6471
322+
Training loss (for 1 batch) at step 700: 0.6860
323323
Seen so far: 22432 samples
324-
Training loss (for 1 batch) at step 800: 1.6272
324+
Training loss (for 1 batch) at step 800: 0.7306
325325
Seen so far: 25632 samples
326-
Training loss (for 1 batch) at step 900: 0.9416
326+
Training loss (for 1 batch) at step 900: 0.4571
327327
Seen so far: 28832 samples
328-
Training loss (for 1 batch) at step 1000: 0.8152
328+
Training loss (for 1 batch) at step 1000: 0.6023
329329
Seen so far: 32032 samples
330-
Training loss (for 1 batch) at step 1100: 0.8838
330+
Training loss (for 1 batch) at step 1100: 0.9140
331331
Seen so far: 35232 samples
332-
Training loss (for 1 batch) at step 1200: 0.1278
332+
Training loss (for 1 batch) at step 1200: 0.4224
333333
Seen so far: 38432 samples
334-
Training loss (for 1 batch) at step 1300: 1.9234
334+
Training loss (for 1 batch) at step 1300: 0.6696
335335
Seen so far: 41632 samples
336-
Training loss (for 1 batch) at step 1400: 0.3413
336+
Training loss (for 1 batch) at step 1400: 0.1399
337337
Seen so far: 44832 samples
338-
Training loss (for 1 batch) at step 1500: 0.2429
338+
Training loss (for 1 batch) at step 1500: 0.5761
339339
Seen so far: 48032 samples
340340
341341
```
@@ -514,65 +514,65 @@ for step, data in enumerate(val_dataset):
514514

515515
<div class="k-default-codeblock">
516516
```
517-
Training loss (for 1 batch) at step 0: 96.4990
518-
Training accuracy: 0.0625
517+
Training loss (for 1 batch) at step 0: 70.8851
518+
Training accuracy: 0.09375
519519
Seen so far: 32 samples
520-
Training loss (for 1 batch) at step 100: 2.0447
521-
Training accuracy: 0.6064356565475464
520+
Training loss (for 1 batch) at step 100: 2.1930
521+
Training accuracy: 0.6596534848213196
522522
Seen so far: 3232 samples
523-
Training loss (for 1 batch) at step 200: 2.0184
524-
Training accuracy: 0.6934079527854919
523+
Training loss (for 1 batch) at step 200: 3.0249
524+
Training accuracy: 0.7352300882339478
525525
Seen so far: 6432 samples
526-
Training loss (for 1 batch) at step 300: 1.9111
527-
Training accuracy: 0.7303779125213623
526+
Training loss (for 1 batch) at step 300: 0.6004
527+
Training accuracy: 0.7588247656822205
528528
Seen so far: 9632 samples
529-
Training loss (for 1 batch) at step 400: 1.8042
530-
Training accuracy: 0.7555330395698547
529+
Training loss (for 1 batch) at step 400: 1.4633
530+
Training accuracy: 0.7736907601356506
531531
Seen so far: 12832 samples
532-
Training loss (for 1 batch) at step 500: 1.2200
533-
Training accuracy: 0.7659056782722473
532+
Training loss (for 1 batch) at step 500: 1.3367
533+
Training accuracy: 0.7826846241950989
534534
Seen so far: 16032 samples
535-
Training loss (for 1 batch) at step 600: 1.3437
536-
Training accuracy: 0.7793781161308289
535+
Training loss (for 1 batch) at step 600: 0.8767
536+
Training accuracy: 0.7930532693862915
537537
Seen so far: 19232 samples
538-
Training loss (for 1 batch) at step 700: 1.2409
539-
Training accuracy: 0.789318859577179
538+
Training loss (for 1 batch) at step 700: 0.3479
539+
Training accuracy: 0.8004636168479919
540540
Seen so far: 22432 samples
541-
Training loss (for 1 batch) at step 800: 1.6530
542-
Training accuracy: 0.7977527976036072
541+
Training loss (for 1 batch) at step 800: 0.3608
542+
Training accuracy: 0.8066869378089905
543543
Seen so far: 25632 samples
544-
Training loss (for 1 batch) at step 900: 0.4173
545-
Training accuracy: 0.8060488104820251
544+
Training loss (for 1 batch) at step 900: 0.7582
545+
Training accuracy: 0.8117369413375854
546546
Seen so far: 28832 samples
547-
Training loss (for 1 batch) at step 1000: 0.5543
548-
Training accuracy: 0.8100025057792664
547+
Training loss (for 1 batch) at step 1000: 1.3135
548+
Training accuracy: 0.8142170310020447
549549
Seen so far: 32032 samples
550-
Training loss (for 1 batch) at step 1100: 1.2699
551-
Training accuracy: 0.8160762786865234
550+
Training loss (for 1 batch) at step 1100: 1.0202
551+
Training accuracy: 0.8186308145523071
552552
Seen so far: 35232 samples
553-
Training loss (for 1 batch) at step 1200: 1.2621
554-
Training accuracy: 0.8213468194007874
553+
Training loss (for 1 batch) at step 1200: 0.6766
554+
Training accuracy: 0.822023332118988
555555
Seen so far: 38432 samples
556-
Training loss (for 1 batch) at step 1300: 0.8028
557-
Training accuracy: 0.8257350325584412
556+
Training loss (for 1 batch) at step 1300: 0.7606
557+
Training accuracy: 0.8257110118865967
558558
Seen so far: 41632 samples
559-
Training loss (for 1 batch) at step 1400: 1.0701
560-
Training accuracy: 0.8298090696334839
559+
Training loss (for 1 batch) at step 1400: 0.7657
560+
Training accuracy: 0.8290283679962158
561561
Seen so far: 44832 samples
562-
Training loss (for 1 batch) at step 1500: 0.3910
563-
Training accuracy: 0.8336525559425354
562+
Training loss (for 1 batch) at step 1500: 0.6563
563+
Training accuracy: 0.831653892993927
564564
Seen so far: 48032 samples
565-
Validation loss (for 1 batch) at step 0: 0.2482
566-
Validation accuracy: 0.835365355014801
565+
Validation loss (for 1 batch) at step 0: 0.1622
566+
Validation accuracy: 0.8329269289970398
567567
Seen so far: 32 samples
568-
Validation loss (for 1 batch) at step 100: 1.1641
569-
Validation accuracy: 0.8388938903808594
568+
Validation loss (for 1 batch) at step 100: 0.7455
569+
Validation accuracy: 0.8338780999183655
570570
Seen so far: 3232 samples
571-
Validation loss (for 1 batch) at step 200: 0.1201
572-
Validation accuracy: 0.8428196907043457
571+
Validation loss (for 1 batch) at step 200: 0.2738
572+
Validation accuracy: 0.836174488067627
573573
Seen so far: 6432 samples
574-
Validation loss (for 1 batch) at step 300: 0.0755
575-
Validation accuracy: 0.8471122980117798
574+
Validation loss (for 1 batch) at step 300: 0.1255
575+
Validation accuracy: 0.8390461206436157
576576
Seen so far: 9632 samples
577577
578578
```

guides/writing_a_custom_training_loop_in_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y)
177177

178178
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
179179
y_pred, non_trainable_variables = model.stateless_call(
180-
trainable_variables, non_trainable_variables, x
180+
trainable_variables, non_trainable_variables, x, training=True
181181
)
182182
loss = loss_fn(y, y_pred)
183183
return loss, non_trainable_variables

scripts/api_audit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,4 @@ def list_all_keras_ops():
9191

9292

9393
if __name__ == "__main__":
94-
list_all_keras_ops()
94+
audit_api_docs()

scripts/api_master.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,11 @@
614614
"title": "MelSpectrogram layer",
615615
"generate": ["keras.layers.MelSpectrogram"],
616616
},
617+
{
618+
"path": "stft_spectrogram",
619+
"title": "STFTSpectrogram layer",
620+
"generate": ["keras.layers.STFTSpectrogram"],
621+
},
617622
],
618623
},
619624
],

0 commit comments

Comments (0)