Skip to content

Commit 3121d71

Browse files
committed
fix implementation and add tests
1 parent 5f97b1b commit 3121d71

File tree

3 files changed

+187
-34
lines changed

3 files changed

+187
-34
lines changed

discretize/_extensions/tree.cpp

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -418,46 +418,42 @@ void Cell::refine_func(node_map_t& nodes, function test_func, double *xs, double
418418
}
419419
}
420420

421-
void Cell::refine_image(node_map_t& nodes, double* image, int_t *shape_cells, double *xs, double*ys, double *zs, bool diagonal_balance){
421+
void Cell::refine_image(node_map_t& nodes, double* image, int_t *shape_cells, double *xs, double*ys, double *zs, bool diag_balance){
422422
// early exit if my level is higher than or equal to target
423423
if (level == max_level){
424424
return;
425425
}
426-
int_t start_ix = points[0].location_ind[0]/2;
427-
int_t start_iy = points[0].location_ind[1]/2;
428-
int_t start_iz = n_dim == 2 ? 0 : points[0].location_ind[2]/2;
429-
int_t nx = shape_cells[0]
430-
int_t ny = shape_cells[1]
431-
int_t nz = shape_cells[2]
432-
433-
// same as 2**(max_level - level), but quicker since power of 2 and integers
434-
int_t span = 1<<(max_level - level);
435-
int_t span_z = n_dim == 2 ? 1 : span;
436-
int_t i_image = (nz * ny) * start_ix + nz * start_iy + start_iz;
437-
double val_start = image[0];
438-
bool subdivide = false;
426+
int_t start_ix = points[0]->location_ind[0]/2;
427+
int_t start_iy = points[0]->location_ind[1]/2;
428+
int_t start_iz = n_dim == 2 ? 0 : points[0]->location_ind[2]/2;
429+
int_t end_ix = points[3]->location_ind[0]/2;
430+
int_t end_iy = points[3]->location_ind[1]/2;
431+
int_t end_iz = n_dim == 2? 1 : points[7]->location_ind[2]/2;
432+
int_t nx = shape_cells[0];
433+
int_t ny = shape_cells[1];
434+
int_t nz = shape_cells[2];
435+
436+
int_t i_image = (nx * ny) * start_iz + nx * start_iy + start_ix;
437+
double val_start = image[i_image];
438+
bool all_unique = true;
439439

440440
// if any of the image data contained in the cell are different, subdivide myself
441-
for(int_t ix=0; ix<span && !subdivide; ++ix){
442-
for(int_t iy=0; iy<span && !subdivide; ++iy){
443-
for(int_t iz=0; iz<span_z && !subdivide; ++iz){
444-
subdivide = image[i_image] != val_start;
445-
i_image += 1;
441+
for(int_t iz=start_iz; iz<end_iz && all_unique; ++iz)
442+
for(int_t iy=start_iy; iy<end_iy && all_unique; ++iy)
443+
for(int_t ix=start_ix; ix<end_ix && all_unique; ++ix){
444+
i_image = (nx * ny) * iz + nx * iy + ix;
445+
all_unique = image[i_image] == val_start;
446446
}
447-
i_image += nz;
448-
}
449-
i_image += nz * ny;
450-
}
451-
if(subdivide){
447+
448+
if(!all_unique){
452449
if(is_leaf()){
453450
divide(nodes, xs, ys, zs, true, diag_balance);
454451
}
455452
// recurse into children
456453
for(int_t i = 0; i < (1<<n_dim); ++i){
457-
children[i]->refine_image(nodes, geom, p_level, xs, ys, zs, diag_balance);
454+
children[i]->refine_image(nodes, image, shape_cells, xs, ys, zs, diag_balance);
458455
}
459456
}
460-
461457
}
462458

463459
void Cell::divide(node_map_t& nodes, double* xs, double* ys, double* zs, bool balance, bool diag_balance){
@@ -938,15 +934,15 @@ void Tree::refine_function(function test_func, bool diagonal_balance){
938934
roots[iz][iy][ix]->refine_func(nodes, test_func, xs, ys, zs, diagonal_balance);
939935
};
940936

941-
void Tree:refine_image(double *image, bool diagonal_balance){
942-
int_t[3] shape_cells;
937+
void Tree::refine_image(double *image, bool diagonal_balance){
938+
int_t shape_cells[3];
943939
shape_cells[0] = nx/2;
944940
shape_cells[1] = ny/2;
945941
shape_cells[2] = nz/2;
946942
for(int_t iz=0; iz<nz_roots; ++iz)
947943
for(int_t iy=0; iy<ny_roots; ++iy)
948944
for(int_t ix=0; ix<nx_roots; ++ix)
949-
roots[iz][iy][ix]->refine_image(nodes, image, xs, ys, zs, diagonal_balance);
945+
roots[iz][iy][ix]->refine_image(nodes, image, shape_cells, xs, ys, zs, diagonal_balance);
950946
}
951947

952948

discretize/_extensions/tree_ext.pyx

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ cimport numpy as np
77
from libc.stdlib cimport malloc, free
88
from libcpp.vector cimport vector
99
from libcpp cimport bool
10-
from libc.math cimport INFINITYs
10+
from libc.math cimport INFINITY
1111

1212
from .tree cimport int_t, Tree as c_Tree, PyWrapper, Node, Edge, Face, Cell as c_Cell
1313
from . cimport geom
@@ -1229,21 +1229,30 @@ cdef class _TreeMesh:
12291229
refined to it's maximum level.
12301230
12311231
"""
1232-
cdef int max_level = self.max_level
12331232
if diagonal_balance is None:
12341233
diagonal_balance = self._diagonal_balance
12351234
cdef bool diag_balance = diagonal_balance
12361235

1237-
image = self._require_ndarray_with_dim('image', image, ndim=self.dim, dtype=np.float64)
1236+
image = np.require(image, dtype=np.float64, requirements="F")
1237+
cdef size_t n_expected = np.prod(self.shape_cells)
1238+
if image.size != n_expected:
1239+
raise ValueError(
1240+
f"image array size: {image.size} must match the total number of cells in the base tensor mesh: {n_expected}"
1241+
)
1242+
if image.ndim == 1:
1243+
image = image.reshape(self.shape_cells, order="F")
1244+
12381245
if image.shape != self.shape_cells:
12391246
raise ValueError(
12401247
f"image array shape: {image.shape} must match the base cell shapes: {self.shape_cells}"
12411248
)
12421249
if self.dim == 2:
12431250
image = image[..., None]
12441251

1245-
cdef double[:,:,::1] image_dat = image
1246-
self.tree.refine_image(&image[0, 0, 0], diagonal_balance)
1252+
cdef double[::1,:,:] image_dat = image
1253+
1254+
with self._tree_modify_lock:
1255+
self.tree.refine_image(&image_dat[0, 0, 0], diag_balance)
12471256
if finalize:
12481257
self.finalize()
12491258

tests/tree/test_refine.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import re
12
import discretize
23
import numpy as np
4+
import numpy.testing as npt
35
import pytest
46
from discretize.tests import assert_cell_intersects_geometric
57

@@ -490,3 +492,149 @@ def test_refine_plane3D():
490492
mesh2.refine_triangle(tris, -1)
491493

492494
assert mesh1.equals(mesh2)
495+
496+
497+
def _make_quadrant_model(mesh, order):
498+
shape_cells = mesh.shape_cells
499+
model = np.zeros(shape_cells, order="F" if order == "flat" else order)
500+
if mesh.dim == 2:
501+
model[: shape_cells[0] // 2, : shape_cells[1] // 2] = 1.0
502+
model[: shape_cells[0] // 4, : shape_cells[1] // 4] = 0.5
503+
else:
504+
model[: shape_cells[0] // 2, : shape_cells[1] // 2, : shape_cells[2] // 2] = 1.0
505+
model[: shape_cells[0] // 4, : shape_cells[1] // 4, : shape_cells[2] // 4] = 0.5
506+
if order == "flat":
507+
model = model.reshape(-1, order="F")
508+
return model
509+
510+
511+
@pytest.mark.parametrize(
512+
"tens_inp",
513+
[
514+
dict(h=[16, 16]),
515+
dict(h=[16, 32]),
516+
dict(h=[32, 16]),
517+
dict(h=[16, 16, 16]),
518+
dict(h=[16, 16, 8]),
519+
dict(h=[16, 8, 16]),
520+
dict(h=[8, 16, 16]),
521+
dict(h=[8, 8, 16]),
522+
dict(h=[8, 16, 8]),
523+
dict(h=[16, 8, 8]),
524+
],
525+
ids=[
526+
"16x16",
527+
"16x32",
528+
"32x16",
529+
"16x16x16",
530+
"16x16x8",
531+
"16x8x16",
532+
"8x16x16",
533+
"8x8x16",
534+
"8x16x8",
535+
"16x8x8",
536+
],
537+
)
538+
def test_refine_image_input_ordering(tens_inp):
539+
base_mesh = discretize.TensorMesh(**tens_inp)
540+
model_0 = _make_quadrant_model(base_mesh, order="flat")
541+
model_1 = _make_quadrant_model(base_mesh, order="C")
542+
model_2 = _make_quadrant_model(base_mesh, order="F")
543+
544+
tree0 = discretize.TreeMesh(base_mesh.h, base_mesh.origin)
545+
tree0.refine_image(model_0)
546+
547+
tree1 = discretize.TreeMesh(base_mesh.h, base_mesh.origin)
548+
tree1.refine_image(model_1)
549+
550+
tree2 = discretize.TreeMesh(base_mesh.h, base_mesh.origin)
551+
tree2.refine_image(model_2)
552+
553+
assert tree0.n_cells == tree1.n_cells == tree2.n_cells
554+
555+
for cell0, cell1, cell2 in zip(tree0, tree1, tree2):
556+
assert cell0.nodes == cell1.nodes == cell2.nodes
557+
558+
559+
@pytest.mark.parametrize(
560+
"tens_inp",
561+
[
562+
dict(h=[16, 16]),
563+
dict(h=[16, 32]),
564+
dict(h=[32, 16]),
565+
dict(h=[16, 16, 16]),
566+
dict(h=[16, 16, 8]),
567+
dict(h=[16, 8, 16]),
568+
dict(h=[8, 16, 16]),
569+
dict(h=[8, 8, 16]),
570+
dict(h=[8, 16, 8]),
571+
dict(h=[16, 8, 8]),
572+
],
573+
ids=[
574+
"16x16",
575+
"16x32",
576+
"32x16",
577+
"16x16x16",
578+
"16x16x8",
579+
"16x8x16",
580+
"8x16x16",
581+
"8x8x16",
582+
"8x16x8",
583+
"16x8x8",
584+
],
585+
)
586+
@pytest.mark.parametrize(
587+
"model_func",
588+
[
589+
lambda mesh: np.zeros(mesh.n_cells),
590+
lambda mesh: np.arange(mesh.n_cells, dtype=float),
591+
lambda mesh: _make_quadrant_model(mesh, order="flat"),
592+
],
593+
ids=["constant", "full", "quadrant"],
594+
)
595+
def test_refine_image(tens_inp, model_func):
596+
base_mesh = discretize.TensorMesh(**tens_inp)
597+
model = model_func(base_mesh)
598+
mesh = discretize.TreeMesh(base_mesh.h, base_mesh.origin, diagonal_balance=False)
599+
mesh.refine_image(model)
600+
601+
# for every cell in the tree mesh, all aligned cells in the tensor mesh
602+
# should have a single unique value.
603+
# quickest way is to generate a volume interp operator and look at indices in the
604+
# csr matrix
605+
interp_mat = discretize.utils.volume_average(base_mesh, mesh)
606+
607+
# ensure in canonical form:
608+
interp_mat.sum_duplicates()
609+
interp_mat.sort_indices()
610+
assert interp_mat.has_canonical_format
611+
612+
model = model.reshape(-1, order="F")
613+
for row in interp_mat:
614+
vals = model[row.indices]
615+
npt.assert_equal(vals, vals[0])
616+
617+
618+
def test_refine_image_bad_size():
619+
mesh = discretize.TreeMesh([32, 32])
620+
model = np.zeros(32 * 32 + 1)
621+
base_cells = np.prod(mesh.shape_cells)
622+
with pytest.raises(
623+
ValueError,
624+
match=re.escape(
625+
f"image array size: {len(model)} must match the total number of cells in the base tensor mesh: {base_cells}"
626+
),
627+
):
628+
mesh.refine_image(model)
629+
630+
631+
def test_refine_image_bad_shape():
632+
mesh = discretize.TreeMesh([32, 32])
633+
model = np.zeros((16, 64))
634+
with pytest.raises(
635+
ValueError,
636+
match=re.escape(
637+
f"image array shape: {model.shape} must match the base cell shapes: {mesh.shape_cells}"
638+
),
639+
):
640+
mesh.refine_image(model)

0 commit comments

Comments
 (0)