@@ -93,16 +93,17 @@ def __init__(
9393 self .num_boundary = num_boundary
9494 self .train_distribution = train_distribution
9595 if config .hvd is not None :
96- print (
97- "When parallel training via Horovod, num_domain and num_boundary are the numbers of points over each rank, not the total number of points."
98- )
9996 if self .train_distribution != "pseudo" :
10097 raise ValueError (
10198 "Parallel training via Horovod only supports pseudo train distribution."
10299 )
103- if config .parallel_scaling == "strong" :
104- raise ValueError (
105- "Strong scaling is not supported with tensorflow.compat.v1. Please use weak scaling."
100+ if config .parallel_scaling == "weak" :
101+ print (
102+ "For weak scaling, num_domain and num_boundary are the numbers of points over each rank, not the total number of points."
103+ )
104+ elif config .parallel_scaling == "strong" :
105+ print (
106+ "For strong scaling, num_domain and num_boundary are the total number of points."
106107 )
107108 self .anchors = None if anchors is None else anchors .astype (config .real (np ))
108109 self .exclusions = exclusions
@@ -171,6 +172,11 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
171172
172173 @run_if_all_none ("train_x" , "train_y" , "train_aux_vars" )
173174 def train_next_batch (self , batch_size = None ):
175+ if config .parallel_scaling == "strong" :
176+ # Todo: Split the domain training points over rank for strong scaling.
177+ raise ValueError (
178+ "Strong scaling is not supported yet with tensorflow.compat.v1. Please use weak scaling."
179+ )
174180 self .train_x_all = self .train_points ()
175181 self .bc_points () # Generate self.num_bcs and self.train_x_bc
176182 if self .bcs and config .hvd is not None :
0 commit comments