@@ -184,7 +184,7 @@ variables.
184184
185185def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
186186 y_pred, non_trainable_variables = model.stateless_call(
187- trainable_variables, non_trainable_variables, x
187+ trainable_variables, non_trainable_variables, x, training=True
188188 )
189189 loss = loss_fn(y, y_pred)
190190 return loss, non_trainable_variables
@@ -225,7 +225,7 @@ def train_step(state, data):
225225 trainable_variables, non_trainable_variables, x, y
226226 )
227227 trainable_variables, optimizer_variables = optimizer.stateless_apply(
228- grads, trainable_variables, optimizer_variables
228+ optimizer_variables, grads, trainable_variables
229229 )
230230 # Return updated state
231231 return loss, (
@@ -305,37 +305,37 @@ for step, data in enumerate(train_dataset):
305305
306306<div class="k-default-codeblock">
307307```
308- Training loss (for 1 batch) at step 0: 156.4785
308+ Training loss (for 1 batch) at step 0: 96.2726
309309Seen so far: 32 samples
310- Training loss (for 1 batch) at step 100: 2.5526
310+ Training loss (for 1 batch) at step 100: 2.0853
311311Seen so far: 3232 samples
312- Training loss (for 1 batch) at step 200: 1.8922
312+ Training loss (for 1 batch) at step 200: 0.6535
313313Seen so far: 6432 samples
314- Training loss (for 1 batch) at step 300: 1.2381
314+ Training loss (for 1 batch) at step 300: 1.2679
315315Seen so far: 9632 samples
316- Training loss (for 1 batch) at step 400: 0.4812
316+ Training loss (for 1 batch) at step 400: 0.7563
317317Seen so far: 12832 samples
318- Training loss (for 1 batch) at step 500: 2.3339
318+ Training loss (for 1 batch) at step 500: 0.7154
319319Seen so far: 16032 samples
320- Training loss (for 1 batch) at step 600: 0.5615
320+ Training loss (for 1 batch) at step 600: 1.0267
321321Seen so far: 19232 samples
322- Training loss (for 1 batch) at step 700: 0.6471
322+ Training loss (for 1 batch) at step 700: 0.6860
323323Seen so far: 22432 samples
324- Training loss (for 1 batch) at step 800: 1.6272
324+ Training loss (for 1 batch) at step 800: 0.7306
325325Seen so far: 25632 samples
326- Training loss (for 1 batch) at step 900: 0.9416
326+ Training loss (for 1 batch) at step 900: 0.4571
327327Seen so far: 28832 samples
328- Training loss (for 1 batch) at step 1000: 0.8152
328+ Training loss (for 1 batch) at step 1000: 0.6023
329329Seen so far: 32032 samples
330- Training loss (for 1 batch) at step 1100: 0.8838
330+ Training loss (for 1 batch) at step 1100: 0.9140
331331Seen so far: 35232 samples
332- Training loss (for 1 batch) at step 1200: 0.1278
332+ Training loss (for 1 batch) at step 1200: 0.4224
333333Seen so far: 38432 samples
334- Training loss (for 1 batch) at step 1300: 1.9234
334+ Training loss (for 1 batch) at step 1300: 0.6696
335335Seen so far: 41632 samples
336- Training loss (for 1 batch) at step 1400: 0.3413
336+ Training loss (for 1 batch) at step 1400: 0.1399
337337Seen so far: 44832 samples
338- Training loss (for 1 batch) at step 1500: 0.2429
338+ Training loss (for 1 batch) at step 1500: 0.5761
339339Seen so far: 48032 samples
340340
341341```
@@ -514,65 +514,65 @@ for step, data in enumerate(val_dataset):
514514
515515<div class="k-default-codeblock">
516516```
517- Training loss (for 1 batch) at step 0: 96.4990
518- Training accuracy: 0.0625
517+ Training loss (for 1 batch) at step 0: 70.8851
518+ Training accuracy: 0.09375
519519Seen so far: 32 samples
520- Training loss (for 1 batch) at step 100: 2.0447
521- Training accuracy: 0.6064356565475464
520+ Training loss (for 1 batch) at step 100: 2.1930
521+ Training accuracy: 0.6596534848213196
522522Seen so far: 3232 samples
523- Training loss (for 1 batch) at step 200: 2.0184
524- Training accuracy: 0.6934079527854919
523+ Training loss (for 1 batch) at step 200: 3.0249
524+ Training accuracy: 0.7352300882339478
525525Seen so far: 6432 samples
526- Training loss (for 1 batch) at step 300: 1.9111
527- Training accuracy: 0.7303779125213623
526+ Training loss (for 1 batch) at step 300: 0.6004
527+ Training accuracy: 0.7588247656822205
528528Seen so far: 9632 samples
529- Training loss (for 1 batch) at step 400: 1.8042
530- Training accuracy: 0.7555330395698547
529+ Training loss (for 1 batch) at step 400: 1.4633
530+ Training accuracy: 0.7736907601356506
531531Seen so far: 12832 samples
532- Training loss (for 1 batch) at step 500: 1.2200
533- Training accuracy: 0.7659056782722473
532+ Training loss (for 1 batch) at step 500: 1.3367
533+ Training accuracy: 0.7826846241950989
534534Seen so far: 16032 samples
535- Training loss (for 1 batch) at step 600: 1.3437
536- Training accuracy: 0.7793781161308289
535+ Training loss (for 1 batch) at step 600: 0.8767
536+ Training accuracy: 0.7930532693862915
537537Seen so far: 19232 samples
538- Training loss (for 1 batch) at step 700: 1.2409
539- Training accuracy: 0.789318859577179
538+ Training loss (for 1 batch) at step 700: 0.3479
539+ Training accuracy: 0.8004636168479919
540540Seen so far: 22432 samples
541- Training loss (for 1 batch) at step 800: 1.6530
542- Training accuracy: 0.7977527976036072
541+ Training loss (for 1 batch) at step 800: 0.3608
542+ Training accuracy: 0.8066869378089905
543543Seen so far: 25632 samples
544- Training loss (for 1 batch) at step 900: 0.4173
545- Training accuracy: 0.8060488104820251
544+ Training loss (for 1 batch) at step 900: 0.7582
545+ Training accuracy: 0.8117369413375854
546546Seen so far: 28832 samples
547- Training loss (for 1 batch) at step 1000: 0.5543
548- Training accuracy: 0.8100025057792664
547+ Training loss (for 1 batch) at step 1000: 1.3135
548+ Training accuracy: 0.8142170310020447
549549Seen so far: 32032 samples
550- Training loss (for 1 batch) at step 1100: 1.2699
551- Training accuracy: 0.8160762786865234
550+ Training loss (for 1 batch) at step 1100: 1.0202
551+ Training accuracy: 0.8186308145523071
552552Seen so far: 35232 samples
553- Training loss (for 1 batch) at step 1200: 1.2621
554- Training accuracy: 0.8213468194007874
553+ Training loss (for 1 batch) at step 1200: 0.6766
554+ Training accuracy: 0.822023332118988
555555Seen so far: 38432 samples
556- Training loss (for 1 batch) at step 1300: 0.8028
557- Training accuracy: 0.8257350325584412
556+ Training loss (for 1 batch) at step 1300: 0.7606
557+ Training accuracy: 0.8257110118865967
558558Seen so far: 41632 samples
559- Training loss (for 1 batch) at step 1400: 1.0701
560- Training accuracy: 0.8298090696334839
559+ Training loss (for 1 batch) at step 1400: 0.7657
560+ Training accuracy: 0.8290283679962158
561561Seen so far: 44832 samples
562- Training loss (for 1 batch) at step 1500: 0.3910
563- Training accuracy: 0.8336525559425354
562+ Training loss (for 1 batch) at step 1500: 0.6563
563+ Training accuracy: 0.831653892993927
564564Seen so far: 48032 samples
565- Validation loss (for 1 batch) at step 0: 0.2482
566- Validation accuracy: 0.835365355014801
565+ Validation loss (for 1 batch) at step 0: 0.1622
566+ Validation accuracy: 0.8329269289970398
567567Seen so far: 32 samples
568- Validation loss (for 1 batch) at step 100: 1.1641
569- Validation accuracy: 0.8388938903808594
568+ Validation loss (for 1 batch) at step 100: 0.7455
569+ Validation accuracy: 0.8338780999183655
570570Seen so far: 3232 samples
571- Validation loss (for 1 batch) at step 200: 0.1201
572- Validation accuracy: 0.8428196907043457
571+ Validation loss (for 1 batch) at step 200: 0.2738
572+ Validation accuracy: 0.836174488067627
573573Seen so far: 6432 samples
574- Validation loss (for 1 batch) at step 300: 0.0755
575- Validation accuracy: 0.8471122980117798
574+ Validation loss (for 1 batch) at step 300: 0.1255
575+ Validation accuracy: 0.8390461206436157
576576Seen so far: 9632 samples
577577
578578```
0 commit comments