@@ -48,6 +48,8 @@ def compile(self, layer_sizes, f_model, domain, bcs, isAdaptive=False,
         # self.X_f_in = np.asarray(tmp)
         self.X_f_in = [tf.cast(np.reshape(vec, (-1, 1)), tf.float32) for i, vec in enumerate(self.domain.X_f.T)]
         self.u_model = neural_net(self.layer_sizes)
+        self.batch = None
+        self.batch_indx_map = None
         self.lambdas = self.dict_adaptive = self.lambdas_map = None
         self.isAdaptive = isAdaptive

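For readers skimming the hunk above: the list comprehension it keeps splits the collocation matrix `domain.X_f` of shape `(N_f, d)` into `d` column tensors of shape `(N_f, 1)`, one per input dimension of the residual network. A minimal standalone sketch of that transformation, assuming a hypothetical two-dimensional `(x, t)` domain:

```python
import numpy as np
import tensorflow as tf

X_f = np.random.rand(100, 2)  # hypothetical (x, t) collocation points, N_f = 100
X_f_in = [tf.cast(np.reshape(vec, (-1, 1)), tf.float32) for vec in X_f.T]
print([v.shape for v in X_f_in])  # [TensorShape([100, 1]), TensorShape([100, 1])]
```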
@@ -83,11 +85,12 @@ def update_loss(self):
         #####################################
         # Check if adaptive is allowed
         if self.isAdaptive:
-            idx_lambda_bcs = self.lambdas_map['bcs'][0]
+            if len(self.lambdas_map['bcs']) > 0:
+                idx_lambda_bcs = self.lambdas_map['bcs'][0]

         for counter_bc, bc in enumerate(self.bcs):
             loss_bc = 0.
-            # Check if the current BS is adaptive
+            # Check if the current BC is adaptive
             if self.isAdaptive:
                 isBC_adaptive = self.dict_adaptive["BCs"][counter_bc]
             else:
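The new `len(...) > 0` guard matters when a model is adaptive on the residual but not on any boundary condition: indexing `self.lambdas_map['bcs'][0]` on an empty list would raise an `IndexError` before any loss is computed. A toy illustration (the dict layout here is inferred from the indexing in this diff):

```python
# hypothetical lambdas_map for a model with adaptive residual weights only
lambdas_map = {'bcs': [], 'residual': [0]}

# old behaviour: lambdas_map['bcs'][0] raises IndexError
if len(lambdas_map['bcs']) > 0:
    idx_lambda_bcs = lambdas_map['bcs'][0]  # only reached when some BC is adaptive
```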
@@ -142,7 +145,16 @@ def update_loss(self):
         # Residual Equations
         #####################################
         # pass through the forward method
-        f_u_preds = self.f_model(self.u_model, *self.X_f_in)
+        if self.n_batches > 1:
+            # The collocation points are split based on the batch_indx_map
+            # generated at the beginning of this epoch in models.train_op_inner.apply_grads
+            X_batch = []
+            for x_in in self.X_f_in:
+                indx_on_batch = self.batch_indx_map[self.batch * self.batch_sz:(self.batch + 1) * self.batch_sz]
+                X_batch.append(tf.gather(x_in, indx_on_batch))
+            f_u_preds = self.f_model(self.u_model, *X_batch)
+        else:
+            f_u_preds = self.f_model(self.u_model, *self.X_f_in)

         # If it is only one residual, just convert it to a tuple of one element
         if not isinstance(f_u_preds, tuple):
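The minibatch branch above relies on two pieces of state set elsewhere: `batch_indx_map`, a permutation of the `N_f` collocation indices regenerated once per epoch, and `batch`, the counter of the current step within the epoch. The same contiguous slice of the shuffled map is gathered from every input column, so the point coordinates stay aligned across dimensions. A self-contained sketch of the pattern, with hypothetical sizes standing in for the instance attributes:

```python
import tensorflow as tf

N_f, batch_sz, batch = 100, 25, 2                  # hypothetical sizes
batch_indx_map = tf.random.shuffle(tf.range(N_f))  # reshuffled once per epoch
x_in = tf.random.uniform((N_f, 1))                 # one input column, e.g. x

# rows of x_in that belong to minibatch number `batch`
indx_on_batch = batch_indx_map[batch * batch_sz:(batch + 1) * batch_sz]
x_batch = tf.gather(x_in, indx_on_batch)           # shape (25, 1)
```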
@@ -153,13 +165,23 @@ def update_loss(self):
            # Check if the current Residual is adaptive
            if self.isAdaptive:
                isRes_adaptive = self.dict_adaptive["residual"][counter_res]
-                idx_lambda_res = self.lambdas_map['residual'][0]
                if isRes_adaptive:
+                    idx_lambda_res = self.lambdas_map['residual'][0]
+                    lambdas2loss = self.lambdas[idx_lambda_res]
+
+                    if self.n_batches > 1:
+                        # select the lambdas for the current minibatch
+                        lambdas2loss = tf.gather(lambdas2loss, indx_on_batch)
+
                    if self.g is not None:
-                        loss_r = g_MSE(f_u_pred, constant(0.0), self.g(self.lambdas[idx_lambda_res]))
+                        loss_r = g_MSE(f_u_pred, constant(0.0), self.g(lambdas2loss))
                    else:
-                        loss_r = MSE(f_u_pred, constant(0.0), self.lambdas[idx_lambda_res])
+                        loss_r = MSE(f_u_pred, constant(0.0), lambdas2loss)
                    idx_lambda_res += 1
+                else:
+                    # If the model is adaptive but this residual is not,
+                    # the residual loss should still be computed
+                    loss_r = MSE(f_u_pred, constant(0.0))
            else:
                loss_r = MSE(f_u_pred, constant(0.0))

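A design note on the gather of `lambdas2loss`: the self-adaptive weights are defined per collocation point, so the minibatch must select them with the exact same `indx_on_batch` used for the points themselves, otherwise a residual would be scaled by another point's weight. A hedged sketch of the weighted residual loss (`weighted_mse` below is a stand-in for the library's three-argument `MSE`, which I am assuming computes a lambda-weighted mean squared error):

```python
import tensorflow as tf

def weighted_mse(pred, target, weights=None):
    # stand-in for MSE(pred, target, lambdas): mean of lambda-weighted squared error
    sq = tf.square(pred - target)
    return tf.reduce_mean(weights * sq) if weights is not None else tf.reduce_mean(sq)

lambdas = tf.ones((100, 1))           # one adaptive weight per collocation point
indx_on_batch = tf.range(25)          # indices already chosen for this minibatch
f_u_pred = tf.random.normal((25, 1))  # residuals evaluated on those same points

lambdas2loss = tf.gather(lambdas, indx_on_batch)
loss_r = weighted_mse(f_u_pred, tf.constant(0.0), lambdas2loss)
```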
@@ -177,8 +199,18 @@ def grad(self):
         return loss_value, grads

     def fit(self, tf_iter=0, newton_iter=0, batch_sz=None, newton_eager=True):
-        if self.isAdaptive and (batch_sz is not None):
-            raise Exception("Currently we dont support minibatching for adaptive PINNs")
+
+        # Batch size for the collocation points is adjustable; it defaults to N_f (full batch)
+        N_f = self.X_f_len[0]
+        self.batch_sz = batch_sz if batch_sz is not None else N_f
+        self.n_batches = N_f // self.batch_sz
+
+        if self.isAdaptive and self.dist:
+            raise Exception("Currently we don't support distributed training for adaptive PINNs")
+
+        if self.n_batches > 1 and self.dist:
+            raise Exception("Currently we don't support distributed minibatch training")
+
         if self.dist:
             BUFFER_SIZE = len(self.X_f_in[0])
             EPOCHS = tf_iter
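Worth noting: `self.n_batches = N_f // self.batch_sz` uses integer division, so when `batch_sz` does not evenly divide `N_f` the leftover points are simply not visited that epoch (since `batch_indx_map` is regenerated each epoch, different points are left out each time). A quick illustration with hypothetical numbers:

```python
N_f, batch_sz = 100, 30
n_batches = N_f // batch_sz            # 3 batches of 30 points
leftover = N_f - n_batches * batch_sz  # 10 points skipped this epoch
print(n_batches, leftover)             # 3 10
```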
@@ -194,13 +226,6 @@ def fit(self, tf_iter=0, newton_iter=0, batch_sz=None, newton_eager=True):

             print("Number of GPU devices: {}".format(self.strategy.num_replicas_in_sync))

-            self.batch_sz = batch_sz if batch_sz is not None else len(self.X_f_in[0])
-            # weights_idx = tensor(list(range(len(self.x_f))), dtype=tf.int32)
-            # print(weights_idx)
-            # print(tf.gather(self.col_weights, weights_idx))
-            N_f = len(self.X_f_in[0])
-            self.n_batches = N_f // self.batch_sz
-
             BATCH_SIZE_PER_REPLICA = self.batch_sz
             GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * self.strategy.num_replicas_in_sync

@@ -229,7 +254,7 @@ def fit(self, tf_iter=0, newton_iter=0, batch_sz=None, newton_eager=True):
             fit_dist(self, tf_iter=tf_iter, newton_iter=newton_iter, batch_sz=batch_sz, newton_eager=newton_eager)

         else:
-            fit(self, tf_iter=tf_iter, newton_iter=newton_iter, batch_sz=batch_sz, newton_eager=newton_eager)
+            fit(self, tf_iter=tf_iter, newton_iter=newton_iter, newton_eager=newton_eager)

     # L-BFGS implementation from https://github.com/pierremtb/PINNs-TF2.0
     def get_loss_and_flat_grad(self):
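For completeness, a hypothetical call site after this change; the `compile` and `fit` signatures match the diff, but the `model` object and the concrete argument values are assumptions for illustration only:

```python
# hedged usage sketch: adaptive PINN trained on minibatches of collocation points
model.compile(layer_sizes, f_model, domain, bcs, isAdaptive=True)
model.fit(tf_iter=10000, batch_sz=512)  # 512 collocation points per training step
```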