
Commit 35db4c1

Author: KangyuMac
Commit message: add verbose to dde.model.train and dde.model.compile
Parent: 40cd7e5

2 files changed: +29 additions, -23 deletions
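For context, a minimal usage sketch of the keyword this commit introduces. The geometry/PDE/network setup below is illustrative only and not part of the commit; the point is that passing verbose=0 to compile and train suppresses the "Compiling model...", "Training model...", per-step loss, and timing output.

    import deepxde as dde

    # Illustrative problem setup; any dde.data.Data and dde.nn network would do.
    geom = dde.geometry.Interval(0, 1)

    def pde(x, y):
        # Simple residual: y'' = 0
        return dde.grad.hessian(y, x)

    data = dde.data.PDE(geom, pde, [], num_domain=16)
    net = dde.nn.FNN([1] + [20] * 2 + [1], "tanh", "Glorot normal")
    model = dde.Model(data, net)

    # New in this commit: verbose=0 silences compile/train progress output.
    model.compile("adam", lr=1e-3, verbose=0)
    losshistory, train_state = model.train(iterations=100, verbose=0)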

deepxde/model.py

Lines changed: 27 additions & 22 deletions
@@ -65,6 +65,7 @@ def compile(
         decay=None,
         loss_weights=None,
         external_trainable_variables=None,
+        verbose=1,
     ):
         """Configures the model for training.

@@ -114,8 +115,9 @@ def compile(
                 physics systems that need to be recovered. If the backend is
                 tensorflow.compat.v1, `external_trainable_variables` is ignored, and all
                 trainable ``dde.Variable`` objects are automatically collected.
+            verbose (Integer): Controls the verbosity of the compile process.
         """
-        if config.rank == 0:
+        if verbose > 0 and config.rank == 0:
             print("Compiling model...")
         self.opt_name = optimizer
         loss_fn = losses_module.get(loss)
@@ -585,6 +587,7 @@ def train(
         model_restore_path=None,
         model_save_path=None,
         epochs=None,
+        verbose=1,
     ):
         """Trains the model.

@@ -610,6 +613,7 @@ def train(
             model_save_path (String): Prefix of filenames created for the checkpoint.
             epochs (Integer): Deprecated alias to `iterations`. This will be removed in
                 a future version.
+            verbose (Integer): Controls the verbosity of the train process.
         """
         if iterations is None and epochs is not None:
             print(
@@ -635,36 +639,36 @@ def train(
         if model_restore_path is not None:
             self.restore(model_restore_path, verbose=1)

-        if config.rank == 0:
+        if verbose > 0 and config.rank == 0:
             print("Training model...\n")
         self.stop_training = False
         self.train_state.set_data_train(*self.data.train_next_batch(self.batch_size))
         self.train_state.set_data_test(*self.data.test())
-        self._test()
+        self._test(verbose=verbose)
         self.callbacks.on_train_begin()
         if optimizers.is_external_optimizer(self.opt_name):
             if backend_name == "tensorflow.compat.v1":
-                self._train_tensorflow_compat_v1_scipy(display_every)
+                self._train_tensorflow_compat_v1_scipy(display_every, verbose=verbose)
             elif backend_name == "tensorflow":
-                self._train_tensorflow_tfp()
+                self._train_tensorflow_tfp(verbose=verbose)
             elif backend_name == "pytorch":
-                self._train_pytorch_lbfgs()
+                self._train_pytorch_lbfgs(verbose=verbose)
             elif backend_name == "paddle":
-                self._train_paddle_lbfgs()
+                self._train_paddle_lbfgs(verbose=verbose)
         else:
             if iterations is None:
                 raise ValueError("No iterations for {}.".format(self.opt_name))
-            self._train_sgd(iterations, display_every)
+            self._train_sgd(iterations, display_every, verbose=verbose)
         self.callbacks.on_train_end()

-        if config.rank == 0:
+        if verbose > 0 and config.rank == 0:
             print("")
             display.training_display.summary(self.train_state)
         if model_save_path is not None:
             self.save(model_save_path, verbose=1)
         return self.losshistory, self.train_state

-    def _train_sgd(self, iterations, display_every):
+    def _train_sgd(self, iterations, display_every, verbose=1):
         for i in range(iterations):
             self.callbacks.on_epoch_begin()
             self.callbacks.on_batch_begin()
@@ -681,15 +685,15 @@ def _train_sgd(self, iterations, display_every):
             self.train_state.epoch += 1
             self.train_state.step += 1
             if self.train_state.step % display_every == 0 or i + 1 == iterations:
-                self._test()
+                self._test(verbose=verbose)

             self.callbacks.on_batch_end()
             self.callbacks.on_epoch_end()

             if self.stop_training:
                 break

-    def _train_tensorflow_compat_v1_scipy(self, display_every):
+    def _train_tensorflow_compat_v1_scipy(self, display_every, verbose=1):
         def loss_callback(loss_train, loss_test, *args):
             self.train_state.epoch += 1
             self.train_state.step += 1
@@ -703,7 +707,8 @@ def loss_callback(loss_train, loss_test, *args):
                     self.train_state.loss_test,
                     None,
                 )
-                display.training_display(self.train_state)
+                if verbose > 0:
+                    display.training_display(self.train_state)
             for cb in self.callbacks.callbacks:
                 if type(cb).__name__ == "VariableValue":
                     cb.epochs_since_last += 1
@@ -736,9 +741,9 @@ def loss_callback(loss_train, loss_test, *args):
             fetches=fetches,
             loss_callback=loss_callback,
         )
-        self._test()
+        self._test(verbose=verbose)

-    def _train_tensorflow_tfp(self):
+    def _train_tensorflow_tfp(self, verbose=1):
         # There is only one optimization step. If using multiple steps with/without
         # previous_optimizer_results, L-BFGS failed to reach a small error. The reason
         # could be that tfp.optimizer.lbfgs_minimize will start from scratch for each
@@ -756,12 +761,12 @@ def _train_tensorflow_tfp(self):
             n_iter += results.num_iterations.numpy()
             self.train_state.epoch += results.num_iterations.numpy()
             self.train_state.step += results.num_iterations.numpy()
-            self._test()
+            self._test(verbose=verbose)

             if results.converged or results.failed:
                 break

-    def _train_pytorch_lbfgs(self):
+    def _train_pytorch_lbfgs(self, verbose=1):
         prev_n_iter = 0
         while prev_n_iter < optimizers.LBFGS_options["maxiter"]:
             self.callbacks.on_epoch_begin()
@@ -784,15 +789,15 @@ def _train_pytorch_lbfgs(self):
             self.train_state.epoch += n_iter - prev_n_iter
             self.train_state.step += n_iter - prev_n_iter
             prev_n_iter = n_iter
-            self._test()
+            self._test(verbose=verbose)

             self.callbacks.on_batch_end()
             self.callbacks.on_epoch_end()

             if self.stop_training:
                 break

-    def _train_paddle_lbfgs(self):
+    def _train_paddle_lbfgs(self, verbose=1):
         prev_n_iter = 0

         while prev_n_iter < optimizers.LBFGS_options["maxiter"]:
@@ -816,15 +821,15 @@ def _train_paddle_lbfgs(self):
             self.train_state.epoch += n_iter - prev_n_iter
             self.train_state.step += n_iter - prev_n_iter
             prev_n_iter = n_iter
-            self._test()
+            self._test(verbose=verbose)

             self.callbacks.on_batch_end()
             self.callbacks.on_epoch_end()

             if self.stop_training:
                 break

-    def _test(self):
+    def _test(self, verbose=1):
         # TODO Now only print the training loss in rank 0. The correct way is to print the average training loss of all ranks.
         (
             self.train_state.y_pred_train,
@@ -867,7 +872,7 @@ def _test(self):
             or np.isnan(self.train_state.loss_test).any()
         ):
             self.stop_training = True
-        if config.rank == 0:
+        if verbose > 0 and config.rank == 0:
             display.training_display(self.train_state)

     def predict(self, x, operator=None, callbacks=None):
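The recurring pattern in this file is to thread verbose from train() down into the per-backend helpers and _test(), and to gate every print on both the verbosity and the process rank so parallel workers stay quiet. A standalone sketch of that gating idea follows; the function names here (log, run_training) are illustrative, not DeepXDE API.

    # Minimal sketch of verbose/rank gating; `rank` stands in for dde.config.rank
    # in a multi-process (e.g. Horovod) run.
    def log(message, verbose=1, rank=0):
        # Print only when verbosity is enabled and only on the root process,
        # so replicated workers do not emit duplicate progress lines.
        if verbose > 0 and rank == 0:
            print(message)

    def run_training(iterations, display_every=1000, verbose=1):
        for step in range(1, iterations + 1):
            # ... one optimization step would go here ...
            if step % display_every == 0 or step == iterations:
                # Forward the keyword so nested helpers stay silent too.
                log(f"step {step}", verbose=verbose)

    run_training(3000, verbose=1)  # prints step 1000, 2000, 3000
    run_training(3000, verbose=0)  # silent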

deepxde/utils/internal.py

Lines changed: 2 additions & 1 deletion
@@ -18,10 +18,11 @@ def timing(f):

     @wraps(f)
     def wrapper(*args, **kwargs):
+        verbose = kwargs.get('verbose', 1)
         ts = timeit.default_timer()
         result = f(*args, **kwargs)
         te = timeit.default_timer()
-        if config.rank == 0:
+        if verbose > 0 and config.rank == 0:
             print("%r took %f s\n" % (f.__name__, te - ts))
             sys.stdout.flush()
         return result
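Note that the timing decorator only inspects kwargs, so the "%r took %f s" line is suppressed when the decorated method is called with verbose passed as a keyword (e.g. model.train(..., verbose=0)); a positionally passed verbose would not be seen. A self-contained sketch of the same pattern (the rank check is omitted here for brevity):

    import sys
    import timeit
    from functools import wraps

    def timing(f):
        """Decorator that reports the wall-clock time of the wrapped call."""

        @wraps(f)
        def wrapper(*args, **kwargs):
            # Mirror of the change above: honor a verbose keyword if the caller
            # passed one; default to printing (verbose=1) otherwise.
            verbose = kwargs.get("verbose", 1)
            ts = timeit.default_timer()
            result = f(*args, **kwargs)
            te = timeit.default_timer()
            if verbose > 0:  # the real helper also requires config.rank == 0
                print("%r took %f s\n" % (f.__name__, te - ts))
                sys.stdout.flush()
            return result

        return wrapper

    @timing
    def slow_add(a, b, verbose=1):
        return a + b

    slow_add(1, 2)             # prints "'slow_add' took ... s"
    slow_add(1, 2, verbose=0)  # silent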
