Skip to content

Commit 503af5e

Browse files
committed
refactor interpolation tests
1 parent 3afee98 commit 503af5e

File tree

2 files changed

+192
-282
lines changed

2 files changed

+192
-282
lines changed

tests/tree/test_tree.py

Lines changed: 11 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -438,99 +438,6 @@ def test_cell_bounds(self, mesh):
438438
np.testing.assert_equal(cell_bounds, cell_bounds_slow)
439439

440440

441-
class Test2DInterpolation(unittest.TestCase):
442-
def setUp(self):
443-
def topo(x):
444-
return np.sin(x * (2.0 * np.pi)) * 0.3 + 0.5
445-
446-
def function(cell):
447-
r = cell.center - np.array([0.5] * len(cell.center))
448-
dist1 = np.sqrt(r.dot(r)) - 0.08
449-
dist2 = np.abs(cell.center[-1] - topo(cell.center[0]))
450-
451-
dist = min([dist1, dist2])
452-
# if dist < 0.05:
453-
# return 5
454-
if dist < 0.05:
455-
return 6
456-
if dist < 0.2:
457-
return 5
458-
if dist < 0.3:
459-
return 4
460-
if dist < 1.0:
461-
return 3
462-
else:
463-
return 0
464-
465-
M = discretize.TreeMesh([64, 64], levels=6)
466-
M.refine(function)
467-
self.M = M
468-
469-
def test_fx(self):
470-
r = rng.random(self.M.nFx)
471-
P = self.M.get_interpolation_matrix(self.M.gridFx, "Fx")
472-
self.assertLess(np.abs(P[:, : self.M.nFx] * r - r).max(), TOL)
473-
474-
def test_fy(self):
475-
r = rng.random(self.M.nFy)
476-
P = self.M.get_interpolation_matrix(self.M.gridFy, "Fy")
477-
self.assertLess(np.abs(P[:, self.M.nFx :] * r - r).max(), TOL)
478-
479-
480-
class Test3DInterpolation(unittest.TestCase):
481-
def setUp(self):
482-
def function(cell):
483-
r = cell.center - np.array([0.5] * len(cell.center))
484-
dist = np.sqrt(r.dot(r))
485-
if dist < 0.2:
486-
return 4
487-
if dist < 0.3:
488-
return 3
489-
if dist < 1.0:
490-
return 2
491-
else:
492-
return 0
493-
494-
M = discretize.TreeMesh([16, 16, 16], levels=4)
495-
M.refine(function)
496-
# M.plot_grid(show_it=True)
497-
self.M = M
498-
499-
def test_Fx(self):
500-
r = rng.random(self.M.nFx)
501-
P = self.M.get_interpolation_matrix(self.M.gridFx, "Fx")
502-
self.assertLess(np.abs(P[:, : self.M.nFx] * r - r).max(), TOL)
503-
504-
def test_Fy(self):
505-
r = rng.random(self.M.nFy)
506-
P = self.M.get_interpolation_matrix(self.M.gridFy, "Fy")
507-
self.assertLess(
508-
np.abs(P[:, self.M.nFx : (self.M.nFx + self.M.nFy)] * r - r).max(), TOL
509-
)
510-
511-
def test_Fz(self):
512-
r = rng.random(self.M.nFz)
513-
P = self.M.get_interpolation_matrix(self.M.gridFz, "Fz")
514-
self.assertLess(np.abs(P[:, (self.M.nFx + self.M.nFy) :] * r - r).max(), TOL)
515-
516-
def test_Ex(self):
517-
r = rng.random(self.M.nEx)
518-
P = self.M.get_interpolation_matrix(self.M.gridEx, "Ex")
519-
self.assertLess(np.abs(P[:, : self.M.nEx] * r - r).max(), TOL)
520-
521-
def test_Ey(self):
522-
r = rng.random(self.M.nEy)
523-
P = self.M.get_interpolation_matrix(self.M.gridEy, "Ey")
524-
self.assertLess(
525-
np.abs(P[:, self.M.nEx : (self.M.nEx + self.M.nEy)] * r - r).max(), TOL
526-
)
527-
528-
def test_Ez(self):
529-
r = rng.random(self.M.nEz)
530-
P = self.M.get_interpolation_matrix(self.M.gridEz, "Ez")
531-
self.assertLess(np.abs(P[:, (self.M.nEx + self.M.nEy) :] * r - r).max(), TOL)
532-
533-
534441
class TestWrapAroundLevels(unittest.TestCase):
535442
def test_refine_func(self):
536443
mesh1 = discretize.TreeMesh((16, 16, 16))
@@ -649,5 +556,16 @@ def test_repr_html(self, mesh, finalize):
649556
assert len(output) != 0
650557

651558

559+
@pytest.mark.parametrize("attr", ["average_edge_to_face"])
560+
def test_caching(attr):
561+
mesh = discretize.TreeMesh([4, 4, 4])
562+
mesh.refine(-1)
563+
564+
attr1 = getattr(mesh, attr)
565+
attr2 = getattr(mesh, attr)
566+
567+
assert attr1 is attr2
568+
569+
652570
if __name__ == "__main__":
653571
unittest.main()

0 commit comments

Comments
 (0)