Skip to content

Commit e6058f1

Browse files
authored
Merge pull request #406 from jcapriot/refine_image
Add functionality to refine a `TreeMesh` using an "image"
2 parents cfc9f2e + 0e7dde7 commit e6058f1

File tree

5 files changed

+252
-0
lines changed

5 files changed

+252
-0
lines changed

discretize/_extensions/tree.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,44 @@ void Cell::refine_func(node_map_t& nodes, function test_func, double *xs, double
418418
}
419419
}
420420

421+
void Cell::refine_image(node_map_t& nodes, double* image, int_t *shape_cells, double *xs, double*ys, double *zs, bool diag_balance){
422+
// early exit if my level is higher than or equal to target
423+
if (level == max_level){
424+
return;
425+
}
426+
int_t start_ix = points[0]->location_ind[0]/2;
427+
int_t start_iy = points[0]->location_ind[1]/2;
428+
int_t start_iz = n_dim == 2 ? 0 : points[0]->location_ind[2]/2;
429+
int_t end_ix = points[3]->location_ind[0]/2;
430+
int_t end_iy = points[3]->location_ind[1]/2;
431+
int_t end_iz = n_dim == 2? 1 : points[7]->location_ind[2]/2;
432+
int_t nx = shape_cells[0];
433+
int_t ny = shape_cells[1];
434+
int_t nz = shape_cells[2];
435+
436+
int_t i_image = (nx * ny) * start_iz + nx * start_iy + start_ix;
437+
double val_start = image[i_image];
438+
bool all_unique = true;
439+
440+
// if any of the image data contained in the cell are different, subdivide myself
441+
for(int_t iz=start_iz; iz<end_iz && all_unique; ++iz)
442+
for(int_t iy=start_iy; iy<end_iy && all_unique; ++iy)
443+
for(int_t ix=start_ix; ix<end_ix && all_unique; ++ix){
444+
i_image = (nx * ny) * iz + nx * iy + ix;
445+
all_unique = image[i_image] == val_start;
446+
}
447+
448+
if(!all_unique){
449+
if(is_leaf()){
450+
divide(nodes, xs, ys, zs, true, diag_balance);
451+
}
452+
// recurse into children
453+
for(int_t i = 0; i < (1<<n_dim); ++i){
454+
children[i]->refine_image(nodes, image, shape_cells, xs, ys, zs, diag_balance);
455+
}
456+
}
457+
}
458+
421459
void Cell::divide(node_map_t& nodes, double* xs, double* ys, double* zs, bool balance, bool diag_balance){
422460
// Gaurd against dividing a cell that is already at the max level
423461
if (level == max_level){
@@ -896,6 +934,18 @@ void Tree::refine_function(function test_func, bool diagonal_balance){
896934
roots[iz][iy][ix]->refine_func(nodes, test_func, xs, ys, zs, diagonal_balance);
897935
};
898936

937+
void Tree::refine_image(double *image, bool diagonal_balance){
938+
int_t shape_cells[3];
939+
shape_cells[0] = nx/2;
940+
shape_cells[1] = ny/2;
941+
shape_cells[2] = nz/2;
942+
for(int_t iz=0; iz<nz_roots; ++iz)
943+
for(int_t iy=0; iy<ny_roots; ++iy)
944+
for(int_t ix=0; ix<nx_roots; ++ix)
945+
roots[iz][iy][ix]->refine_image(nodes, image, shape_cells, xs, ys, zs, diagonal_balance);
946+
}
947+
948+
899949
void Tree::finalize_lists(){
900950
for(int_t iz=0; iz<nz_roots; ++iz)
901951
for(int_t iy=0; iy<ny_roots; ++iy)

discretize/_extensions/tree.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ class Cell{
136136
void insert_cell(node_map_t &nodes, double *new_center, int_t p_level, double* xs, double *ys, double *zs, bool diag_balance=false);
137137

138138
void refine_func(node_map_t& nodes, function test_func, double *xs, double *ys, double* zs, bool diag_balance=false);
139+
void refine_image(node_map_t& nodes, double* image, int_t *shape_cells, double *xs, double*ys, double *zs, bool diagonal_balance=false);
139140

140141
inline bool is_leaf(){ return children[0]==NULL;};
141142
void spawn(node_map_t& nodes, Cell *kids[8], double* xs, double *ys, double *zs);
@@ -217,6 +218,8 @@ class Tree{
217218

218219
void refine_function(function test_func, bool diagonal_balance=false);
219220

221+
void refine_image(double* image, bool diagonal_balance=false);
222+
220223
template <class T>
221224
void refine_geom(const T& geom, int_t p_level, bool diagonal_balance=false){
222225
for(int_t iz=0; iz<nz_roots; ++iz)

discretize/_extensions/tree.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ cdef extern from "tree.h":
9494

9595
void refine_geom[T](const T&, int_t, bool)
9696

97+
void refine_image(double*, bool)
98+
9799
void number()
98100
void initialize_roots()
99101
void insert_cell(double *new_center, int_t p_level, bool)

discretize/_extensions/tree_ext.pyx

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,55 @@ cdef class _TreeMesh:
12141214
if finalize:
12151215
self.finalize()
12161216

1217+
def refine_image(self, image, finalize=True, diagonal_balance=None):
1218+
"""Refine using an ND image, ensuring that each cell contains exactly one unique value.
1219+
1220+
This function takes an N-dimensional image, defined on the underlying fine tensor mesh,
1221+
and recursively subdivides each cell if that cell contains more than 1 unique value in the
1222+
image. This is useful when using the `TreeMesh` to represent an exact compressed form of an input
1223+
model.
1224+
1225+
Parameters
1226+
----------
1227+
image : (shape_cells) numpy.ndarray
1228+
Must have the same shape as the base tensor mesh (`TreeMesh.shape_cells`), as if every cell on this mesh was
1229+
refined to it's maximum level.
1230+
finalize : bool, optional
1231+
Whether to finalize after inserting point(s)
1232+
diagonal_balance : bool or None, optional
1233+
Whether to balance cells diagonally in the refinement, `None` implies using
1234+
the same setting used to instantiate the `TreeMesh`.
1235+
1236+
"""
1237+
if diagonal_balance is None:
1238+
diagonal_balance = self._diagonal_balance
1239+
cdef bool diag_balance = diagonal_balance
1240+
1241+
image = np.require(image, dtype=np.float64, requirements="F")
1242+
cdef size_t n_expected = np.prod(self.shape_cells)
1243+
if image.size != n_expected:
1244+
raise ValueError(
1245+
f"image array size: {image.size} must match the total number of cells in the base tensor mesh: {n_expected}"
1246+
)
1247+
if image.ndim == 1:
1248+
image = image.reshape(self.shape_cells, order="F")
1249+
1250+
if image.shape != self.shape_cells:
1251+
raise ValueError(
1252+
f"image array shape: {image.shape} must match the base cell shapes: {self.shape_cells}"
1253+
)
1254+
if self.dim == 2:
1255+
image = image[..., None]
1256+
1257+
cdef double[::1,:,:] image_dat = image
1258+
1259+
with self._tree_modify_lock:
1260+
self.tree.refine_image(&image_dat[0, 0, 0], diag_balance)
1261+
if finalize:
1262+
self.finalize()
1263+
1264+
1265+
12171266
def finalize(self):
12181267
"""Finalize the :class:`~discretize.TreeMesh`.
12191268

tests/tree/test_refine.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import re
12
import discretize
23
import numpy as np
4+
import numpy.testing as npt
35
import pytest
46
from discretize.tests import assert_cell_intersects_geometric
57

@@ -490,3 +492,149 @@ def test_refine_plane3D():
490492
mesh2.refine_triangle(tris, -1)
491493

492494
assert mesh1.equals(mesh2)
495+
496+
497+
def _make_quadrant_model(mesh, order):
498+
shape_cells = mesh.shape_cells
499+
model = np.zeros(shape_cells, order="F" if order == "flat" else order)
500+
if mesh.dim == 2:
501+
model[: shape_cells[0] // 2, : shape_cells[1] // 2] = 1.0
502+
model[: shape_cells[0] // 4, : shape_cells[1] // 4] = 0.5
503+
else:
504+
model[: shape_cells[0] // 2, : shape_cells[1] // 2, : shape_cells[2] // 2] = 1.0
505+
model[: shape_cells[0] // 4, : shape_cells[1] // 4, : shape_cells[2] // 4] = 0.5
506+
if order == "flat":
507+
model = model.reshape(-1, order="F")
508+
return model
509+
510+
511+
@pytest.mark.parametrize(
512+
"tens_inp",
513+
[
514+
dict(h=[16, 16]),
515+
dict(h=[16, 32]),
516+
dict(h=[32, 16]),
517+
dict(h=[16, 16, 16]),
518+
dict(h=[16, 16, 8]),
519+
dict(h=[16, 8, 16]),
520+
dict(h=[8, 16, 16]),
521+
dict(h=[8, 8, 16]),
522+
dict(h=[8, 16, 8]),
523+
dict(h=[16, 8, 8]),
524+
],
525+
ids=[
526+
"16x16",
527+
"16x32",
528+
"32x16",
529+
"16x16x16",
530+
"16x16x8",
531+
"16x8x16",
532+
"8x16x16",
533+
"8x8x16",
534+
"8x16x8",
535+
"16x8x8",
536+
],
537+
)
538+
def test_refine_image_input_ordering(tens_inp):
539+
base_mesh = discretize.TensorMesh(**tens_inp)
540+
model_0 = _make_quadrant_model(base_mesh, order="flat")
541+
model_1 = _make_quadrant_model(base_mesh, order="C")
542+
model_2 = _make_quadrant_model(base_mesh, order="F")
543+
544+
tree0 = discretize.TreeMesh(base_mesh.h, base_mesh.origin)
545+
tree0.refine_image(model_0)
546+
547+
tree1 = discretize.TreeMesh(base_mesh.h, base_mesh.origin)
548+
tree1.refine_image(model_1)
549+
550+
tree2 = discretize.TreeMesh(base_mesh.h, base_mesh.origin)
551+
tree2.refine_image(model_2)
552+
553+
assert tree0.n_cells == tree1.n_cells == tree2.n_cells
554+
555+
for cell0, cell1, cell2 in zip(tree0, tree1, tree2):
556+
assert cell0.nodes == cell1.nodes == cell2.nodes
557+
558+
559+
@pytest.mark.parametrize(
560+
"tens_inp",
561+
[
562+
dict(h=[16, 16]),
563+
dict(h=[16, 32]),
564+
dict(h=[32, 16]),
565+
dict(h=[16, 16, 16]),
566+
dict(h=[16, 16, 8]),
567+
dict(h=[16, 8, 16]),
568+
dict(h=[8, 16, 16]),
569+
dict(h=[8, 8, 16]),
570+
dict(h=[8, 16, 8]),
571+
dict(h=[16, 8, 8]),
572+
],
573+
ids=[
574+
"16x16",
575+
"16x32",
576+
"32x16",
577+
"16x16x16",
578+
"16x16x8",
579+
"16x8x16",
580+
"8x16x16",
581+
"8x8x16",
582+
"8x16x8",
583+
"16x8x8",
584+
],
585+
)
586+
@pytest.mark.parametrize(
587+
"model_func",
588+
[
589+
lambda mesh: np.zeros(mesh.n_cells),
590+
lambda mesh: np.arange(mesh.n_cells, dtype=float),
591+
lambda mesh: _make_quadrant_model(mesh, order="flat"),
592+
],
593+
ids=["constant", "full", "quadrant"],
594+
)
595+
def test_refine_image(tens_inp, model_func):
596+
base_mesh = discretize.TensorMesh(**tens_inp)
597+
model = model_func(base_mesh)
598+
mesh = discretize.TreeMesh(base_mesh.h, base_mesh.origin, diagonal_balance=False)
599+
mesh.refine_image(model)
600+
601+
# for every cell in the tree mesh, all aligned cells in the tensor mesh
602+
# should have a single unique value.
603+
# quickest way is to generate a volume interp operator and look at indices in the
604+
# csr matrix
605+
interp_mat = discretize.utils.volume_average(base_mesh, mesh)
606+
607+
# ensure in canonical form:
608+
interp_mat.sum_duplicates()
609+
interp_mat.sort_indices()
610+
assert interp_mat.has_canonical_format
611+
612+
model = model.reshape(-1, order="F")
613+
for row in interp_mat:
614+
vals = model[row.indices]
615+
npt.assert_equal(vals, vals[0])
616+
617+
618+
def test_refine_image_bad_size():
619+
mesh = discretize.TreeMesh([32, 32])
620+
model = np.zeros(32 * 32 + 1)
621+
base_cells = np.prod(mesh.shape_cells)
622+
with pytest.raises(
623+
ValueError,
624+
match=re.escape(
625+
f"image array size: {len(model)} must match the total number of cells in the base tensor mesh: {base_cells}"
626+
),
627+
):
628+
mesh.refine_image(model)
629+
630+
631+
def test_refine_image_bad_shape():
632+
mesh = discretize.TreeMesh([32, 32])
633+
model = np.zeros((16, 64))
634+
with pytest.raises(
635+
ValueError,
636+
match=re.escape(
637+
f"image array shape: {model.shape} must match the base cell shapes: {mesh.shape_cells}"
638+
),
639+
):
640+
mesh.refine_image(model)

0 commit comments

Comments
 (0)