Skip to content

Commit 21b33c7

Browse files
committed
PointSet supports default value
1 parent 3c2daed commit 21b33c7

File tree

1 file changed

+9
-15
lines changed

1 file changed

+9
-15
lines changed

deepxde/boundary_conditions.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ def error(self, X, inputs, outputs, beg, end):
4343

4444

4545
class DirichletBC(BC):
46-
"""Dirichlet boundary conditions: y(x) = func(x).
47-
"""
46+
"""Dirichlet boundary conditions: y(x) = func(x)."""
4847

4948
def __init__(self, geom, func, on_boundary, component=0):
5049
super(DirichletBC, self).__init__(geom, on_boundary, component)
@@ -60,8 +59,7 @@ def error(self, X, inputs, outputs, beg, end):
6059

6160

6261
class NeumannBC(BC):
63-
"""Neumann boundary conditions: dy/dn(x) = func(x).
64-
"""
62+
"""Neumann boundary conditions: dy/dn(x) = func(x)."""
6563

6664
def __init__(self, geom, func, on_boundary, component=0):
6765
super(NeumannBC, self).__init__(geom, on_boundary, component)
@@ -74,8 +72,7 @@ def error(self, X, inputs, outputs, beg, end):
7472

7573

7674
class RobinBC(BC):
77-
"""Robin boundary conditions: dy/dn(x) = func(x, y).
78-
"""
75+
"""Robin boundary conditions: dy/dn(x) = func(x, y)."""
7976

8077
def __init__(self, geom, func, on_boundary, component=0):
8178
super(RobinBC, self).__init__(geom, on_boundary, component)
@@ -88,8 +85,7 @@ def error(self, X, inputs, outputs, beg, end):
8885

8986

9087
class PeriodicBC(BC):
91-
"""Periodic boundary conditions on component_x.
92-
"""
88+
"""Periodic boundary conditions on component_x."""
9389

9490
def __init__(self, geom, component_x, on_boundary, derivative_order=0, component=0):
9591
super(PeriodicBC, self).__init__(geom, on_boundary, component)
@@ -138,8 +134,7 @@ def error(self, X, inputs, outputs, beg, end):
138134

139135

140136
class PointSet(object):
141-
"""A set of points.
142-
"""
137+
"""A set of points."""
143138

144139
def __init__(self, points):
145140
self.points = np.array(points)
@@ -150,12 +145,11 @@ def inside(self, x):
150145
axis=-1,
151146
)
152147

153-
def values_to_func(self, values):
148+
def values_to_func(self, values, default_value=0):
154149
def func(x):
155-
return np.matmul(
156-
np.all(np.isclose(x[:, np.newaxis, :], self.points), axis=-1),
157-
values,
158-
)
150+
pt_equal = np.all(np.isclose(x[:, np.newaxis, :], self.points), axis=-1)
151+
not_inside = np.logical_not(np.any(pt_equal, axis=-1, keepdims=True))
152+
return np.matmul(pt_equal, values) + default_value * not_inside
159153

160154
return func
161155

0 commit comments

Comments
 (0)