from collections.abc import Callable
import json
+import sys
from typing import Any, Optional
import warnings

from absl import logging
import chex
import haiku as hk
+from IPython import display
import jax
from jax.example_libraries import optimizers
import jax.numpy as jnp
@@ -366,7 +368,7 @@ def train_network(
    loss: str = 'mse',
    log_losses_every: int = 10,
    do_plot: bool = False,
-    print_or_log: str = 'print',
+    report_progress_by: str = 'print',
) -> tuple[hk.Params, optax.OptState, dict[str, np.ndarray]]:
  """Trains a network.

@@ -388,7 +390,9 @@ def train_network(
    log_losses_every: How many training steps between each time we check for
      errors and log the loss
    do_plot: Boolean that controls whether a learning curve is plotted
-    print_or_log: Whether to print progress to screen or log it to absl.logging
+    report_progress_by: Mode for reporting real-time progress. Options are
+      "print" (print to the console, or update an IPython display in place on
+      Colab), "log" (use absl logging), and "none" (no output).

  Returns:
    params: Trained parameters
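For reference, the new keyword is passed like any other train_network option. A minimal call sketch, with the leading arguments elided because they fall outside this diff (the `...` is a placeholder, not the real signature):

params, opt_state, losses = train_network(
    ...,  # network, data, and optimizer arguments, unchanged by this change
    loss='mse',
    log_losses_every=10,
    do_plot=False,
    report_progress_by='log',  # 'print' (default), 'log', or 'none'
)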
@@ -560,7 +564,7 @@ def train_step(
    )

    # Check for errors and report progress
-    if step % log_losses_every == 0:
+    if step % log_losses_every == 0 or step == n_steps - 1:
      if nan_in_dict(params):
        raise ValueError('NaN in params')
      if np.isnan(loss):
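The added `or step == n_steps - 1` clause makes the final step report even when n_steps is not a multiple of log_losses_every. A quick standalone check of the new cadence, assuming steps count from 0 as the comparison implies:

n_steps, log_losses_every = 25, 10
reported = [
    step for step in range(n_steps)
    if step % log_losses_every == 0 or step == n_steps - 1
]
print(reported)  # [0, 10, 20, 24]: the off-cadence final step is included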
@@ -580,19 +584,26 @@ def train_step(
          f'Training Loss: {loss:.2e}. '
          f'Validation Loss: {l_validation:.2e}'
      )
-      if print_or_log == 'print':
-        print(log_str, end='\r')
-      else:
+
+      if report_progress_by == 'print':
+        # If we're running on colab, use display, otherwise use print
+        on_colab = 'google.colab' in sys.modules
+        if on_colab:
+          display.clear_output(wait=True)
+          display.display(log_str)
+        else:
+          print(log_str)
+      elif report_progress_by == 'log':
        logging.info(log_str)
+      elif report_progress_by == 'none':
+        pass
+      else:
+        warnings.warn(
+            f'Unknown report_progress_by mode: {report_progress_by}'
+        )

-  # If we actually did any training, print final loss and make a nice plot
+  # If we actually did any training, make a nice plot
  if n_steps > 1 and do_plot:
-    if print_or_log == 'print':
-      print(
-          f'Step {n_steps} of {n_steps}. '
-          f'Training Loss: {loss:.2e}. '
-          f'Validation Loss: {l_validation:.2e}'
-      )

    plt.figure()
    plt.semilogy(training_loss, color='black')
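The 'print' branch detects Colab by checking sys.modules: Colab runtimes import google.colab at kernel startup, so its presence there is a cheap environment probe that avoids a try/except import. A minimal standalone version of the check (running_in_colab is a hypothetical helper name, not part of this codebase):

import sys

def running_in_colab() -> bool:
  # google.colab is pre-imported by Colab kernels, so membership in
  # sys.modules distinguishes Colab from a plain console or script run.
  return 'google.colab' in sys.modules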