Skip to content

Commit 1768af0

Browse files
committed
Fix dtype of Geometry.boundary_normal
1 parent fcbcb41 commit 1768af0

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

deepxde/geometry/geometry_2d.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
from .geometry import Geometry
99
from .geometry_nd import Hypercube
1010
from .sampler import sample
11+
from .. import config
1112
from ..utils import vectorize
1213

1314

1415
class Disk(Geometry):
1516
def __init__(self, center, radius):
16-
self.center, self.radius = np.array(center), radius
17+
self.center = np.array(center, dtype=config.real(np))
18+
self.radius = radius
1719
super(Disk, self).__init__(
1820
2, (self.center - radius, self.center + radius), 2 * radius
1921
)
@@ -181,9 +183,9 @@ def __init__(self, x1, x2, x3):
181183
self.area = -self.area
182184
x2, x3 = x3, x2
183185

184-
self.x1 = np.array(x1)
185-
self.x2 = np.array(x2)
186-
self.x3 = np.array(x3)
186+
self.x1 = np.array(x1, dtype=config.real(np))
187+
self.x2 = np.array(x2, dtype=config.real(np))
188+
self.x3 = np.array(x3, dtype=config.real(np))
187189

188190
self.v12 = self.x2 - self.x1
189191
self.v23 = self.x3 - self.x2
@@ -338,7 +340,7 @@ class Polygon(Geometry):
338340
"""
339341

340342
def __init__(self, vertices):
341-
self.vertices = np.array(vertices)
343+
self.vertices = np.array(vertices, dtype=config.real(np))
342344
if len(vertices) == 3:
343345
raise ValueError("The polygon is a triangle. Use Triangle instead.")
344346
if Rectangle.is_valid(self.vertices):

deepxde/geometry/geometry_nd.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def on_boundary(self, x):
4141
return np.logical_and(self.inside(x), _on_boundary)
4242

4343
def boundary_normal(self, x):
44-
_n = np.isclose(x, self.xmin) * -1.0 + np.isclose(x, self.xmax) * 1.0
44+
_n = -np.isclose(x, self.xmin).astype(config.real(np)) + np.isclose(
45+
x, self.xmax
46+
)
4547
# For vertices, the normal is averaged for all directions
4648
idx = np.count_nonzero(_n, axis=-1) > 1
4749
if np.any(idx):
@@ -97,7 +99,8 @@ def periodic_point(self, x, component):
9799

98100
class Hypersphere(Geometry):
99101
def __init__(self, center, radius):
100-
self.center, self.radius = np.array(center), radius
102+
self.center = np.array(center, dtype=config.real(np))
103+
self.radius = radius
101104
super(Hypersphere, self).__init__(
102105
len(center), (self.center - radius, self.center + radius), 2 * radius
103106
)
@@ -144,7 +147,7 @@ def random_points(self, n, random="pseudo"):
144147
def random_boundary_points(self, n, random="pseudo"):
145148
"""http://mathworld.wolfram.com/HyperspherePointPicking.html"""
146149
if random == "pseudo":
147-
X = np.random.normal(size=(n, self.dim))
150+
X = np.random.normal(size=(n, self.dim)).astype(config.real(np))
148151
else:
149152
U = sample(n, self.dim, random)
150153
X = stats.norm.ppf(U)

0 commit comments

Comments
 (0)