Skip to content

Commit c776388

Browse files
authored
Merge pull request #410 from jcapriot/bugfix/tree_mesh_interp_zero_out
Fix logic error when zeros_outside was True
2 parents d08dd68 + 0253e0a commit c776388

File tree

3 files changed

+216
-285
lines changed

3 files changed

+216
-285
lines changed

discretize/_extensions/tree_ext.pyx

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5813,26 +5813,33 @@ cdef class _TreeMesh:
58135813
cell = self.tree.containing_cell(x, y, z)
58145814
row_inds = indices[indptr[i]:indptr[i+1]]
58155815
row_data = data[indptr[i]:indptr[i+1]]
5816+
was_outside = False
58165817
if zeros_out:
58175818
if x < cell.points[0].location[0]-eps:
58185819
row_data[:] = 0.0
58195820
row_inds[:] = 0
5821+
was_outside = True
58205822
elif x > cell.points[3].location[0]+eps:
58215823
row_data[:] = 0.0
58225824
row_inds[:] = 0
5825+
was_outside = True
58235826
elif y < cell.points[0].location[1]-eps:
58245827
row_data[:] = 0.0
58255828
row_inds[:] = 0
5829+
was_outside = True
58265830
elif y > cell.points[3].location[1]+eps:
58275831
row_data[:] = 0.0
58285832
row_inds[:] = 0
5833+
was_outside = True
58295834
elif dim == 3 and z < cell.points[0].location[2]-eps:
58305835
row_data[:] = 0.0
58315836
row_inds[:] = 0
5837+
was_outside = True
58325838
elif dim == 3 and z > cell.points[7].location[2]+eps:
58335839
row_data[:] = 0.0
58345840
row_inds[:] = 0
5835-
else:
5841+
was_outside = True
5842+
if not was_outside:
58365843
# look + dir and - dir away
58375844
if (
58385845
locations[i, dir] < cell.location[dir]
@@ -5964,26 +5971,33 @@ cdef class _TreeMesh:
59645971
cell = self.tree.containing_cell(x, y, z)
59655972
row_inds = indices[indptr[i]:indptr[i+1]]
59665973
row_data = data[indptr[i]:indptr[i+1]]
5974+
was_outside = False
59675975
if zeros_out:
59685976
if x < cell.points[0].location[0]-eps:
59695977
row_data[:] = 0.0
59705978
row_inds[:] = 0
5979+
was_outside = True
59715980
elif x > cell.points[3].location[0]+eps:
59725981
row_data[:] = 0.0
59735982
row_inds[:] = 0
5983+
was_outside = True
59745984
elif y < cell.points[0].location[1]-eps:
59755985
row_data[:] = 0.0
59765986
row_inds[:] = 0
5987+
was_outside = True
59775988
elif y > cell.points[3].location[1]+eps:
59785989
row_data[:] = 0.0
59795990
row_inds[:] = 0
5991+
was_outside = True
59805992
elif dim == 3 and z < cell.points[0].location[2]-eps:
59815993
row_data[:] = 0.0
59825994
row_inds[:] = 0
5995+
was_outside = True
59835996
elif dim == 3 and z > cell.points[7].location[2]+eps:
59845997
row_data[:] = 0.0
59855998
row_inds[:] = 0
5986-
else:
5999+
was_outside = True
6000+
if not was_outside:
59876001
# Find containing cells
59886002
# Decide order to search based on which face it is closest to
59896003
if dim == 3:
@@ -6219,26 +6233,33 @@ cdef class _TreeMesh:
62196233
cell = self.tree.containing_cell(x, y, z)
62206234
row_inds = indices[indptr[i]:indptr[i + 1]]
62216235
row_data = data[indptr[i]:indptr[i + 1]]
6236+
was_outside = False
62226237
if zeros_out:
62236238
if x < cell.points[0].location[0]-eps:
62246239
row_data[:] = 0.0
62256240
row_inds[:] = 0
6241+
was_outside = True
62266242
elif x > cell.points[3].location[0]+eps:
62276243
row_data[:] = 0.0
62286244
row_inds[:] = 0
6245+
was_outside = True
62296246
elif y < cell.points[0].location[1]-eps:
62306247
row_data[:] = 0.0
62316248
row_inds[:] = 0
6249+
was_outside = True
62326250
elif y > cell.points[3].location[1]+eps:
62336251
row_data[:] = 0.0
62346252
row_inds[:] = 0
6253+
was_outside = True
62356254
elif dim == 3 and z < cell.points[0].location[2]-eps:
62366255
row_data[:] = 0.0
62376256
row_inds[:] = 0
6257+
was_outside = True
62386258
elif dim == 3 and z > cell.points[7].location[2]+eps:
62396259
row_data[:] = 0.0
62406260
row_inds[:] = 0
6241-
else:
6261+
was_outside = True
6262+
if not was_outside:
62426263
# decide order to search based on distance to each faces
62436264
#
62446265
if (

tests/tree/test_tree.py

Lines changed: 11 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -438,99 +438,6 @@ def test_cell_bounds(self, mesh):
438438
np.testing.assert_equal(cell_bounds, cell_bounds_slow)
439439

440440

441-
class Test2DInterpolation(unittest.TestCase):
442-
def setUp(self):
443-
def topo(x):
444-
return np.sin(x * (2.0 * np.pi)) * 0.3 + 0.5
445-
446-
def function(cell):
447-
r = cell.center - np.array([0.5] * len(cell.center))
448-
dist1 = np.sqrt(r.dot(r)) - 0.08
449-
dist2 = np.abs(cell.center[-1] - topo(cell.center[0]))
450-
451-
dist = min([dist1, dist2])
452-
# if dist < 0.05:
453-
# return 5
454-
if dist < 0.05:
455-
return 6
456-
if dist < 0.2:
457-
return 5
458-
if dist < 0.3:
459-
return 4
460-
if dist < 1.0:
461-
return 3
462-
else:
463-
return 0
464-
465-
M = discretize.TreeMesh([64, 64], levels=6)
466-
M.refine(function)
467-
self.M = M
468-
469-
def test_fx(self):
470-
r = rng.random(self.M.nFx)
471-
P = self.M.get_interpolation_matrix(self.M.gridFx, "Fx")
472-
self.assertLess(np.abs(P[:, : self.M.nFx] * r - r).max(), TOL)
473-
474-
def test_fy(self):
475-
r = rng.random(self.M.nFy)
476-
P = self.M.get_interpolation_matrix(self.M.gridFy, "Fy")
477-
self.assertLess(np.abs(P[:, self.M.nFx :] * r - r).max(), TOL)
478-
479-
480-
class Test3DInterpolation(unittest.TestCase):
481-
def setUp(self):
482-
def function(cell):
483-
r = cell.center - np.array([0.5] * len(cell.center))
484-
dist = np.sqrt(r.dot(r))
485-
if dist < 0.2:
486-
return 4
487-
if dist < 0.3:
488-
return 3
489-
if dist < 1.0:
490-
return 2
491-
else:
492-
return 0
493-
494-
M = discretize.TreeMesh([16, 16, 16], levels=4)
495-
M.refine(function)
496-
# M.plot_grid(show_it=True)
497-
self.M = M
498-
499-
def test_Fx(self):
500-
r = rng.random(self.M.nFx)
501-
P = self.M.get_interpolation_matrix(self.M.gridFx, "Fx")
502-
self.assertLess(np.abs(P[:, : self.M.nFx] * r - r).max(), TOL)
503-
504-
def test_Fy(self):
505-
r = rng.random(self.M.nFy)
506-
P = self.M.get_interpolation_matrix(self.M.gridFy, "Fy")
507-
self.assertLess(
508-
np.abs(P[:, self.M.nFx : (self.M.nFx + self.M.nFy)] * r - r).max(), TOL
509-
)
510-
511-
def test_Fz(self):
512-
r = rng.random(self.M.nFz)
513-
P = self.M.get_interpolation_matrix(self.M.gridFz, "Fz")
514-
self.assertLess(np.abs(P[:, (self.M.nFx + self.M.nFy) :] * r - r).max(), TOL)
515-
516-
def test_Ex(self):
517-
r = rng.random(self.M.nEx)
518-
P = self.M.get_interpolation_matrix(self.M.gridEx, "Ex")
519-
self.assertLess(np.abs(P[:, : self.M.nEx] * r - r).max(), TOL)
520-
521-
def test_Ey(self):
522-
r = rng.random(self.M.nEy)
523-
P = self.M.get_interpolation_matrix(self.M.gridEy, "Ey")
524-
self.assertLess(
525-
np.abs(P[:, self.M.nEx : (self.M.nEx + self.M.nEy)] * r - r).max(), TOL
526-
)
527-
528-
def test_Ez(self):
529-
r = rng.random(self.M.nEz)
530-
P = self.M.get_interpolation_matrix(self.M.gridEz, "Ez")
531-
self.assertLess(np.abs(P[:, (self.M.nEx + self.M.nEy) :] * r - r).max(), TOL)
532-
533-
534441
class TestWrapAroundLevels(unittest.TestCase):
535442
def test_refine_func(self):
536443
mesh1 = discretize.TreeMesh((16, 16, 16))
@@ -649,5 +556,16 @@ def test_repr_html(self, mesh, finalize):
649556
assert len(output) != 0
650557

651558

559+
@pytest.mark.parametrize("attr", ["average_edge_to_face"])
560+
def test_caching(attr):
561+
mesh = discretize.TreeMesh([4, 4, 4])
562+
mesh.refine(-1)
563+
564+
attr1 = getattr(mesh, attr)
565+
attr2 = getattr(mesh, attr)
566+
567+
assert attr1 is attr2
568+
569+
652570
if __name__ == "__main__":
653571
unittest.main()

0 commit comments

Comments
 (0)