@@ -171,8 +171,9 @@ class PointSetBC:
171171 Args:
172172 points: An array of points where the corresponding target values are known and
173173 used for training.
174- values: A 2D-array of values that gives the exact solution of the problem.
174+ values: A scalar or a 2D-array of values that gives the exact solution of the problem.
175175 component: Integer or a list of integers. The output components satisfying this BC.
176+ List of integers only supported for the backend PyTorch.
176177 batch_size: The number of points per minibatch, or `None` to return all points.
177178 This is only supported for the backend PyTorch.
178179 shuffle: Randomize the order on each pass through the data when batching.
@@ -183,10 +184,12 @@ def __init__(
183184 ):
184185 self .points = np .array (points , dtype = config .real (np ))
185186 self .values = bkd .as_tensor (values , dtype = config .real (bkd .lib ))
186- if isinstance (component , numbers .Number ):
187- self .component = [component ]
188- else :
189- self .component = component
187+ self .component = component
188+ if isinstance (component , list ) and backend_name != "pytorch" :
189+ # TODO: Add support for multiple components in other backends
190+ raise RuntimeError (
191+ "multiple components only implemented for pytorch backend"
192+ )
190193 self .batch_size = batch_size
191194
192195 if batch_size is not None : # batch iterator and state
@@ -210,10 +213,31 @@ def collocation_points(self, X):
210213
211214 def error (self , X , inputs , outputs , beg , end , aux_var = None ):
212215 if self .batch_size is not None :
216+ if isinstance (self .component , numbers .Number ):
217+ return (
218+ outputs [beg :end , self .component : self .component + 1 ]
219+ - self .values [self .batch_indices ]
220+ )
213221 return (
214222 outputs [beg :end , self .component ]
215223 - self .values [self .batch_indices ]
216224 )
225+ if isinstance (self .component , numbers .Number ):
226+ return (
227+ outputs [beg :end , self .component : self .component + 1 ]
228+ - self .values
229+ )
230+ # When a concat is provided, the following code works 'fast' in paddle cpu,
231+ # and slow in both tensorflow backends, jax untested.
232+ # tf.gather can be used instead of for loop but is also slow
233+ # if len(self.component) > 1:
234+ # calculated_error = outputs[beg:end, self.component[0]] - self.values[:,0]
235+ # for i in range(1,len(self.component)):
236+ # tmp = outputs[beg:end, self.component[i]] - self.values[:,i]
237+ # calculated_error = bkd.lib.concat([calculated_error,tmp],axis=0)
238+ # else:
239+ # calculated_error = outputs[beg:end, self.component[0]] - self.values
240+ # return calculated_error
217241 return outputs [beg :end , self .component ] - self .values
218242
219243
0 commit comments