Skip to content

Commit a43079c

Browse files
Bug fix: PointSetBC supports multi-component outputs (#1058)
1 parent 6035ad5 commit a43079c

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

deepxde/icbc/boundary_conditions.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,9 @@ class PointSetBC:
171171
Args:
172172
points: An array of points where the corresponding target values are known and
173173
used for training.
174-
values: A 2D-array of values that gives the exact solution of the problem.
174+
values: A scalar or a 2D-array of values that gives the exact solution of the problem.
175175
component: Integer or a list of integers. The output components satisfying this BC.
176+
List of integers only supported for the backend PyTorch.
176177
batch_size: The number of points per minibatch, or `None` to return all points.
177178
This is only supported for the backend PyTorch.
178179
shuffle: Randomize the order on each pass through the data when batching.
@@ -183,10 +184,12 @@ def __init__(
183184
):
184185
self.points = np.array(points, dtype=config.real(np))
185186
self.values = bkd.as_tensor(values, dtype=config.real(bkd.lib))
186-
if isinstance(component, numbers.Number):
187-
self.component = [component]
188-
else:
189-
self.component = component
187+
self.component = component
188+
if isinstance(component, list) and backend_name != "pytorch":
189+
# TODO: Add support for multiple components in other backends
190+
raise RuntimeError(
191+
"multiple components only implemented for pytorch backend"
192+
)
190193
self.batch_size = batch_size
191194

192195
if batch_size is not None: # batch iterator and state
@@ -210,10 +213,31 @@ def collocation_points(self, X):
210213

211214
def error(self, X, inputs, outputs, beg, end, aux_var=None):
212215
if self.batch_size is not None:
216+
if isinstance(self.component, numbers.Number):
217+
return (
218+
outputs[beg:end, self.component : self.component + 1]
219+
- self.values[self.batch_indices]
220+
)
213221
return (
214222
outputs[beg:end, self.component]
215223
- self.values[self.batch_indices]
216224
)
225+
if isinstance(self.component, numbers.Number):
226+
return (
227+
outputs[beg:end, self.component : self.component + 1]
228+
- self.values
229+
)
230+
# When a concat is provided, the following code works 'fast' in paddle cpu,
231+
# and slow in both tensorflow backends, jax untested.
232+
# tf.gather can be used instead of for loop but is also slow
233+
# if len(self.component) > 1:
234+
# calculated_error = outputs[beg:end, self.component[0]] - self.values[:,0]
235+
# for i in range(1,len(self.component)):
236+
# tmp = outputs[beg:end, self.component[i]] - self.values[:,i]
237+
# calculated_error = bkd.lib.concat([calculated_error,tmp],axis=0)
238+
# else:
239+
# calculated_error = outputs[beg:end, self.component[0]] - self.values
240+
# return calculated_error
217241
return outputs[beg:end, self.component] - self.values
218242

219243

0 commit comments

Comments
 (0)