Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions discretize/_extensions/tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,44 @@ void Cell::refine_func(node_map_t& nodes, function test_func, double *xs, double
}
}

void Cell::refine_image(node_map_t& nodes, double* image, int_t *shape_cells, double *xs, double*ys, double *zs, bool diag_balance){
// early exit if my level is higher than or equal to target
if (level == max_level){
return;
}
int_t start_ix = points[0]->location_ind[0]/2;
int_t start_iy = points[0]->location_ind[1]/2;
int_t start_iz = n_dim == 2 ? 0 : points[0]->location_ind[2]/2;
int_t end_ix = points[3]->location_ind[0]/2;
int_t end_iy = points[3]->location_ind[1]/2;
int_t end_iz = n_dim == 2? 1 : points[7]->location_ind[2]/2;
int_t nx = shape_cells[0];
int_t ny = shape_cells[1];
int_t nz = shape_cells[2];

int_t i_image = (nx * ny) * start_iz + nx * start_iy + start_ix;
double val_start = image[i_image];
bool all_unique = true;

// if any of the image data contained in the cell are different, subdivide myself
for(int_t iz=start_iz; iz<end_iz && all_unique; ++iz)
for(int_t iy=start_iy; iy<end_iy && all_unique; ++iy)
for(int_t ix=start_ix; ix<end_ix && all_unique; ++ix){
i_image = (nx * ny) * iz + nx * iy + ix;
all_unique = image[i_image] == val_start;
}

if(!all_unique){
if(is_leaf()){
divide(nodes, xs, ys, zs, true, diag_balance);
}
// recurse into children
for(int_t i = 0; i < (1<<n_dim); ++i){
children[i]->refine_image(nodes, image, shape_cells, xs, ys, zs, diag_balance);
}
}
}

void Cell::divide(node_map_t& nodes, double* xs, double* ys, double* zs, bool balance, bool diag_balance){
// Gaurd against dividing a cell that is already at the max level
if (level == max_level){
Expand Down Expand Up @@ -896,6 +934,18 @@ void Tree::refine_function(function test_func, bool diagonal_balance){
roots[iz][iy][ix]->refine_func(nodes, test_func, xs, ys, zs, diagonal_balance);
};

void Tree::refine_image(double *image, bool diagonal_balance){
int_t shape_cells[3];
shape_cells[0] = nx/2;
shape_cells[1] = ny/2;
shape_cells[2] = nz/2;
for(int_t iz=0; iz<nz_roots; ++iz)
for(int_t iy=0; iy<ny_roots; ++iy)
for(int_t ix=0; ix<nx_roots; ++ix)
roots[iz][iy][ix]->refine_image(nodes, image, shape_cells, xs, ys, zs, diagonal_balance);
}


void Tree::finalize_lists(){
for(int_t iz=0; iz<nz_roots; ++iz)
for(int_t iy=0; iy<ny_roots; ++iy)
Expand Down
3 changes: 3 additions & 0 deletions discretize/_extensions/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class Cell{
void insert_cell(node_map_t &nodes, double *new_center, int_t p_level, double* xs, double *ys, double *zs, bool diag_balance=false);

void refine_func(node_map_t& nodes, function test_func, double *xs, double *ys, double* zs, bool diag_balance=false);
void refine_image(node_map_t& nodes, double* image, int_t *shape_cells, double *xs, double*ys, double *zs, bool diagonal_balance=false);

inline bool is_leaf(){ return children[0]==NULL;};
void spawn(node_map_t& nodes, Cell *kids[8], double* xs, double *ys, double *zs);
Expand Down Expand Up @@ -217,6 +218,8 @@ class Tree{

void refine_function(function test_func, bool diagonal_balance=false);

void refine_image(double* image, bool diagonal_balance=false);

template <class T>
void refine_geom(const T& geom, int_t p_level, bool diagonal_balance=false){
for(int_t iz=0; iz<nz_roots; ++iz)
Expand Down
2 changes: 2 additions & 0 deletions discretize/_extensions/tree.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ cdef extern from "tree.h":

void refine_geom[T](const T&, int_t, bool)

void refine_image(double*, bool)

void number()
void initialize_roots()
void insert_cell(double *new_center, int_t p_level, bool)
Expand Down
49 changes: 49 additions & 0 deletions discretize/_extensions/tree_ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1214,6 +1214,55 @@ cdef class _TreeMesh:
if finalize:
self.finalize()

def refine_image(self, image, finalize=True, diagonal_balance=None):
"""Refine using an ND image, ensuring that each cell contains exactly one unique value.

This function takes an N-dimensional image, defined on the underlying fine tensor mesh,
and recursively subdivides each cell if that cell contains more than 1 unique value in the
image. This is useful when using the `TreeMesh` to represent an exact compressed form of an input
model.

Parameters
----------
image : (shape_cells) numpy.ndarray
Must have the same shape as the base tensor mesh (`TreeMesh.shape_cells`), as if every cell on this mesh was
refined to it's maximum level.
finalize : bool, optional
Whether to finalize after inserting point(s)
diagonal_balance : bool or None, optional
Whether to balance cells diagonally in the refinement, `None` implies using
the same setting used to instantiate the `TreeMesh`.

"""
if diagonal_balance is None:
diagonal_balance = self._diagonal_balance
cdef bool diag_balance = diagonal_balance

image = np.require(image, dtype=np.float64, requirements="F")
cdef size_t n_expected = np.prod(self.shape_cells)
if image.size != n_expected:
raise ValueError(
f"image array size: {image.size} must match the total number of cells in the base tensor mesh: {n_expected}"
)
if image.ndim == 1:
image = image.reshape(self.shape_cells, order="F")

if image.shape != self.shape_cells:
raise ValueError(
f"image array shape: {image.shape} must match the base cell shapes: {self.shape_cells}"
)
if self.dim == 2:
image = image[..., None]

cdef double[::1,:,:] image_dat = image

with self._tree_modify_lock:
self.tree.refine_image(&image_dat[0, 0, 0], diag_balance)
if finalize:
self.finalize()



def finalize(self):
"""Finalize the :class:`~discretize.TreeMesh`.

Expand Down
148 changes: 148 additions & 0 deletions tests/tree/test_refine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re
import discretize
import numpy as np
import numpy.testing as npt
import pytest
from discretize.tests import assert_cell_intersects_geometric

Expand Down Expand Up @@ -490,3 +492,149 @@ def test_refine_plane3D():
mesh2.refine_triangle(tris, -1)

assert mesh1.equals(mesh2)


def _make_quadrant_model(mesh, order):
shape_cells = mesh.shape_cells
model = np.zeros(shape_cells, order="F" if order == "flat" else order)
if mesh.dim == 2:
model[: shape_cells[0] // 2, : shape_cells[1] // 2] = 1.0
model[: shape_cells[0] // 4, : shape_cells[1] // 4] = 0.5
else:
model[: shape_cells[0] // 2, : shape_cells[1] // 2, : shape_cells[2] // 2] = 1.0
model[: shape_cells[0] // 4, : shape_cells[1] // 4, : shape_cells[2] // 4] = 0.5
if order == "flat":
model = model.reshape(-1, order="F")
return model


@pytest.mark.parametrize(
"tens_inp",
[
dict(h=[16, 16]),
dict(h=[16, 32]),
dict(h=[32, 16]),
dict(h=[16, 16, 16]),
dict(h=[16, 16, 8]),
dict(h=[16, 8, 16]),
dict(h=[8, 16, 16]),
dict(h=[8, 8, 16]),
dict(h=[8, 16, 8]),
dict(h=[16, 8, 8]),
],
ids=[
"16x16",
"16x32",
"32x16",
"16x16x16",
"16x16x8",
"16x8x16",
"8x16x16",
"8x8x16",
"8x16x8",
"16x8x8",
],
)
def test_refine_image_input_ordering(tens_inp):
base_mesh = discretize.TensorMesh(**tens_inp)
model_0 = _make_quadrant_model(base_mesh, order="flat")
model_1 = _make_quadrant_model(base_mesh, order="C")
model_2 = _make_quadrant_model(base_mesh, order="F")

tree0 = discretize.TreeMesh(base_mesh.h, base_mesh.origin)
tree0.refine_image(model_0)

tree1 = discretize.TreeMesh(base_mesh.h, base_mesh.origin)
tree1.refine_image(model_1)

tree2 = discretize.TreeMesh(base_mesh.h, base_mesh.origin)
tree2.refine_image(model_2)

assert tree0.n_cells == tree1.n_cells == tree2.n_cells

for cell0, cell1, cell2 in zip(tree0, tree1, tree2):
assert cell0.nodes == cell1.nodes == cell2.nodes


@pytest.mark.parametrize(
"tens_inp",
[
dict(h=[16, 16]),
dict(h=[16, 32]),
dict(h=[32, 16]),
dict(h=[16, 16, 16]),
dict(h=[16, 16, 8]),
dict(h=[16, 8, 16]),
dict(h=[8, 16, 16]),
dict(h=[8, 8, 16]),
dict(h=[8, 16, 8]),
dict(h=[16, 8, 8]),
],
ids=[
"16x16",
"16x32",
"32x16",
"16x16x16",
"16x16x8",
"16x8x16",
"8x16x16",
"8x8x16",
"8x16x8",
"16x8x8",
],
)
@pytest.mark.parametrize(
"model_func",
[
lambda mesh: np.zeros(mesh.n_cells),
lambda mesh: np.arange(mesh.n_cells, dtype=float),
lambda mesh: _make_quadrant_model(mesh, order="flat"),
],
ids=["constant", "full", "quadrant"],
)
def test_refine_image(tens_inp, model_func):
base_mesh = discretize.TensorMesh(**tens_inp)
model = model_func(base_mesh)
mesh = discretize.TreeMesh(base_mesh.h, base_mesh.origin, diagonal_balance=False)
mesh.refine_image(model)

# for every cell in the tree mesh, all aligned cells in the tensor mesh
# should have a single unique value.
# quickest way is to generate a volume interp operator and look at indices in the
# csr matrix
interp_mat = discretize.utils.volume_average(base_mesh, mesh)

# ensure in canonical form:
interp_mat.sum_duplicates()
interp_mat.sort_indices()
assert interp_mat.has_canonical_format

model = model.reshape(-1, order="F")
for row in interp_mat:
vals = model[row.indices]
npt.assert_equal(vals, vals[0])


def test_refine_image_bad_size():
mesh = discretize.TreeMesh([32, 32])
model = np.zeros(32 * 32 + 1)
base_cells = np.prod(mesh.shape_cells)
with pytest.raises(
ValueError,
match=re.escape(
f"image array size: {len(model)} must match the total number of cells in the base tensor mesh: {base_cells}"
),
):
mesh.refine_image(model)


def test_refine_image_bad_shape():
mesh = discretize.TreeMesh([32, 32])
model = np.zeros((16, 64))
with pytest.raises(
ValueError,
match=re.escape(
f"image array shape: {model.shape} must match the base cell shapes: {mesh.shape_cells}"
),
):
mesh.refine_image(model)