@@ -65,6 +65,7 @@ def compile(
         decay=None,
         loss_weights=None,
         external_trainable_variables=None,
+        verbose=1,
     ):
         """Configures the model for training.

@@ -114,8 +115,9 @@ def compile(
                 physics systems that need to be recovered. If the backend is
                 tensorflow.compat.v1, `external_trainable_variables` is ignored, and all
                 trainable ``dde.Variable`` objects are automatically collected.
+            verbose (Integer): Controls the verbosity of the compile process.
         """
-        if config.rank == 0:
+        if verbose > 0 and config.rank == 0:
             print("Compiling model...")
         self.opt_name = optimizer
         loss_fn = losses_module.get(loss)
@@ -585,6 +587,7 @@ def train(
         model_restore_path=None,
         model_save_path=None,
         epochs=None,
+        verbose=1,
     ):
         """Trains the model.

@@ -610,6 +613,7 @@ def train(
             model_save_path (String): Prefix of filenames created for the checkpoint.
             epochs (Integer): Deprecated alias to `iterations`. This will be removed in
                 a future version.
+            verbose (Integer): Controls the verbosity of the train process.
         """
         if iterations is None and epochs is not None:
             print(
@@ -635,36 +639,36 @@ def train(
         if model_restore_path is not None:
             self.restore(model_restore_path, verbose=1)

-        if config.rank == 0:
+        if verbose > 0 and config.rank == 0:
             print("Training model...\n")
         self.stop_training = False
         self.train_state.set_data_train(*self.data.train_next_batch(self.batch_size))
         self.train_state.set_data_test(*self.data.test())
-        self._test()
+        self._test(verbose=verbose)
         self.callbacks.on_train_begin()
         if optimizers.is_external_optimizer(self.opt_name):
             if backend_name == "tensorflow.compat.v1":
-                self._train_tensorflow_compat_v1_scipy(display_every)
+                self._train_tensorflow_compat_v1_scipy(display_every, verbose=verbose)
             elif backend_name == "tensorflow":
-                self._train_tensorflow_tfp()
+                self._train_tensorflow_tfp(verbose=verbose)
             elif backend_name == "pytorch":
-                self._train_pytorch_lbfgs()
+                self._train_pytorch_lbfgs(verbose=verbose)
             elif backend_name == "paddle":
-                self._train_paddle_lbfgs()
+                self._train_paddle_lbfgs(verbose=verbose)
         else:
             if iterations is None:
                 raise ValueError("No iterations for {}.".format(self.opt_name))
-            self._train_sgd(iterations, display_every)
+            self._train_sgd(iterations, display_every, verbose=verbose)
         self.callbacks.on_train_end()

-        if config.rank == 0:
+        if verbose > 0 and config.rank == 0:
            print("")
            display.training_display.summary(self.train_state)
         if model_save_path is not None:
             self.save(model_save_path, verbose=1)
         return self.losshistory, self.train_state

-    def _train_sgd(self, iterations, display_every):
+    def _train_sgd(self, iterations, display_every, verbose=1):
         for i in range(iterations):
             self.callbacks.on_epoch_begin()
             self.callbacks.on_batch_begin()
@@ -681,15 +685,15 @@ def _train_sgd(self, iterations, display_every):
             self.train_state.epoch += 1
             self.train_state.step += 1
             if self.train_state.step % display_every == 0 or i + 1 == iterations:
-                self._test()
+                self._test(verbose=verbose)

             self.callbacks.on_batch_end()
             self.callbacks.on_epoch_end()

             if self.stop_training:
                 break

-    def _train_tensorflow_compat_v1_scipy(self, display_every):
+    def _train_tensorflow_compat_v1_scipy(self, display_every, verbose=1):
         def loss_callback(loss_train, loss_test, *args):
             self.train_state.epoch += 1
             self.train_state.step += 1
@@ -703,7 +707,8 @@ def loss_callback(loss_train, loss_test, *args):
                     self.train_state.loss_test,
                     None,
                 )
-                display.training_display(self.train_state)
+                if verbose > 0:
+                    display.training_display(self.train_state)
             for cb in self.callbacks.callbacks:
                 if type(cb).__name__ == "VariableValue":
                     cb.epochs_since_last += 1
@@ -736,9 +741,9 @@ def loss_callback(loss_train, loss_test, *args):
             fetches=fetches,
             loss_callback=loss_callback,
         )
-        self._test()
+        self._test(verbose=verbose)

-    def _train_tensorflow_tfp(self):
+    def _train_tensorflow_tfp(self, verbose=1):
         # There is only one optimization step. If using multiple steps with/without
         # previous_optimizer_results, L-BFGS failed to reach a small error. The reason
         # could be that tfp.optimizer.lbfgs_minimize will start from scratch for each
@@ -756,12 +761,12 @@ def _train_tensorflow_tfp(self):
             n_iter += results.num_iterations.numpy()
             self.train_state.epoch += results.num_iterations.numpy()
             self.train_state.step += results.num_iterations.numpy()
-            self._test()
+            self._test(verbose=verbose)

             if results.converged or results.failed:
                 break

-    def _train_pytorch_lbfgs(self):
+    def _train_pytorch_lbfgs(self, verbose=1):
         prev_n_iter = 0
         while prev_n_iter < optimizers.LBFGS_options["maxiter"]:
             self.callbacks.on_epoch_begin()
@@ -784,15 +789,15 @@ def _train_pytorch_lbfgs(self):
             self.train_state.epoch += n_iter - prev_n_iter
             self.train_state.step += n_iter - prev_n_iter
             prev_n_iter = n_iter
-            self._test()
+            self._test(verbose=verbose)

             self.callbacks.on_batch_end()
             self.callbacks.on_epoch_end()

             if self.stop_training:
                 break

-    def _train_paddle_lbfgs(self):
+    def _train_paddle_lbfgs(self, verbose=1):
         prev_n_iter = 0

         while prev_n_iter < optimizers.LBFGS_options["maxiter"]:
@@ -816,15 +821,15 @@ def _train_paddle_lbfgs(self):
             self.train_state.epoch += n_iter - prev_n_iter
             self.train_state.step += n_iter - prev_n_iter
             prev_n_iter = n_iter
-            self._test()
+            self._test(verbose=verbose)

             self.callbacks.on_batch_end()
             self.callbacks.on_epoch_end()

             if self.stop_training:
                 break

-    def _test(self):
+    def _test(self, verbose=1):
         # TODO Now only print the training loss in rank 0. The correct way is to print the average training loss of all ranks.
         (
             self.train_state.y_pred_train,
@@ -867,7 +872,7 @@ def _test(self):
             or np.isnan(self.train_state.loss_test).any()
         ):
             self.stop_training = True
-        if config.rank == 0:
+        if verbose > 0 and config.rank == 0:
             display.training_display(self.train_state)

     def predict(self, x, operator=None, callbacks=None):
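
For reference, a minimal usage sketch of the new `verbose` flag. The toy Poisson
setup below is illustrative only (it is not part of this commit) and assumes
DeepXDE's documented `dde.geometry` / `dde.data.PDE` / `dde.nn.FNN` API; note
that in older releases `DirichletBC` lives at `dde.DirichletBC` rather than
`dde.icbc.DirichletBC`.

    import deepxde as dde

    # Toy 1D Poisson problem: -u'' = 2 on (-1, 1), u(-1) = u(1) = 0.
    def pde(x, y):
        dy_xx = dde.grad.hessian(y, x)
        return -dy_xx - 2

    geom = dde.geometry.Interval(-1, 1)
    bc = dde.icbc.DirichletBC(geom, lambda x: 0, lambda _, on_boundary: on_boundary)
    data = dde.data.PDE(geom, pde, bc, num_domain=16, num_boundary=2)
    net = dde.nn.FNN([1, 32, 32, 1], "tanh", "Glorot normal")
    model = dde.Model(data, net)

    # verbose=0 silences the "Compiling model..." / "Training model..." banners
    # and the periodic loss table printed via _test(); verbose=1 (the default)
    # keeps the previous behavior.
    model.compile("adam", lr=1e-3, verbose=0)
    losshistory, train_state = model.train(iterations=1000, verbose=0)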