
Commit a4d63a2

kevin-j-miller authored and copybara-github committed
Modify train_network to automatically detect when running on colab and print accordingly.
PiperOrigin-RevId: 755485284
1 parent 972e96c commit a4d63a2
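
Note on the change: the core of it is the Colab check "'google.colab' in sys.modules". A minimal, self-contained sketch of that idiom follows; the helper name report_progress is illustrative, not part of the library.

import sys

from IPython import display


def report_progress(log_str: str) -> None:
  # Illustrative helper (not from the library): show a progress string,
  # using IPython display when running under Colab so the line updates
  # in place, and a plain print otherwise.
  if 'google.colab' in sys.modules:
    display.clear_output(wait=True)
    display.display(log_str)
  else:
    print(log_str)


report_progress('Step 10 of 1000. Training Loss: 1.23e-02.')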

File tree

1 file changed (+23, -12 lines)

disentangled_rnns/library/rnn_utils.py

Lines changed: 23 additions & 12 deletions
@@ -16,12 +16,14 @@

 from collections.abc import Callable
 import json
+import sys
 from typing import Any, Optional
 import warnings

 from absl import logging
 import chex
 import haiku as hk
+from IPython import display
 import jax
 from jax.example_libraries import optimizers
 import jax.numpy as jnp
@@ -366,7 +368,7 @@ def train_network(
     loss: str = 'mse',
     log_losses_every: int = 10,
     do_plot: bool = False,
-    print_or_log: str = 'print',
+    report_progress_by: str = 'print',
 ) -> tuple[hk.Params, optax.OptState, dict[str, np.ndarray]]:
   """Trains a network.

@@ -388,7 +390,9 @@ def train_network(
     log_losses_every: How many training steps between each time we check for
       errors and log the loss
     do_plot: Boolean that controls whether a learning curve is plotted
-    print_or_log: Whether to print progress to screen or log it to absl.logging
+    report_progress_by: Mode for reporting real-time progress. Options are
+      "display" for displaying to an ipython notebook, "print" for printing to
+      the console, "log" for using absl logging, and "none" for no output.

   Returns:
     params: Trained parameters
@@ -560,7 +564,7 @@ def train_step(
     )

     # Check for errors and report progress
-    if step % log_losses_every == 0:
+    if step % log_losses_every == 0 or step == n_steps - 1:
       if nan_in_dict(params):
         raise ValueError('NaN in params')
       if np.isnan(loss):
@@ -580,19 +584,26 @@ def train_step(
           f'Training Loss: {loss:.2e}. '
           f'Validation Loss: {l_validation:.2e}'
       )
-      if print_or_log == 'print':
-        print(log_str, end='\r')
-      else:
+
+      if report_progress_by == 'print':
+        # If we're running on colab, use display, otherwise use print
+        on_colab = 'google.colab' in sys.modules
+        if on_colab:
+          display.clear_output(wait=True)
+          display.display(log_str)
+        else:
+          print(log_str)
+      elif report_progress_by == 'log':
         logging.info(log_str)
+      elif report_progress_by == 'none':
+        pass
+      else:
+        warnings.warn(
+            f'Unknown report_progress_by mode: {report_progress_by}'
+        )

   # If we actually did any training, print final loss and make a nice plot
   if n_steps > 1 and do_plot:
-    if print_or_log == 'print':
-      print(
-          f'Step {n_steps} of {n_steps}. '
-          f'Training Loss: {loss:.2e}. '
-          f'Validation Loss: {l_validation:.2e}'
-      )

     plt.figure()
     plt.semilogy(training_loss, color='black')

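Usage sketch for the renamed argument: only n_steps, do_plot, and report_progress_by appear in this diff; make_network, training_dataset, and validation_dataset are assumed names used here for illustration.

# Sketch only: argument names other than n_steps, do_plot, and
# report_progress_by are assumptions, not taken from this diff.
params, opt_state, losses = rnn_utils.train_network(
    make_network,
    training_dataset,
    validation_dataset,
    n_steps=1000,
    do_plot=True,
    report_progress_by='print',  # also accepts 'log' or 'none'
)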