Skip to content
Merged
89 changes: 89 additions & 0 deletions tests/test_coreg/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,95 @@ def test_fit_and_apply__pipeline(self) -> None:
for k in coreg_fit_and_apply.pipeline[1].meta.keys()
)

@pytest.mark.parametrize(
"combination",
[
("raster", "raster", False, "raster", "passes", ""),
("raster", "raster", False, "array", "error", "Input mask array"),
("raster", "raster", True, "raster", "passes", ""),
("array", "raster", True, "raster", "passes", ""),
("raster", "array", True, "raster", "passes", ""),
("array", "array", True, "raster", "passes", ""),
("pc", "raster", False, "raster", "passes", ""),
("raster", "pc", False, "raster", "passes", ""),
("pc", "array", True, "array", "error", "Input mask array"),
("array", "pc", True, "array", "error", "Input mask array"),
],
) # type: ignore
def test_fit_and_apply__cropped_mask(self, combination: tuple[str, str, str, str, str, str]) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing series of tests, thanks a lot 😉

"""
Assert that the same mask, no matter its projection, gives the same results after a fit_and_apply (by shift
output values). NuthKaab has been chosen if this case but the method doesn't change anything.

The 'combination' param contains this in order:
1. The ref_type : raster, array or pc for pointclouds
2. The tba_type : raster, array or pc for pointclouds
3. If the fit_and_apply needs ref_dem.transform and ref_dem.crs
4. The mask_type : raster or array
6. The expected outcome of the test
7. The error message (if applicable)
"""

ref_type, tba_type, info, mask_type, result, text = combination

# Init data
ref_dem, tba_dem, mask = load_examples()
inlier_mask = ~mask.create_mask(ref_dem)

# Load dem_ref info if needed
transform = None
crs = None
if info:
transform = ref_dem.transform
crs = ref_dem.crs

# Crop mask
nrows, ncols = inlier_mask.shape
inlier_mask_crop = inlier_mask.icrop((0, 0, ncols - 10, nrows - 10))

# And reprojected the cropped mask to have the same size as before
inlier_mask_crop_proj = inlier_mask_crop.reproject(ref_dem, resampling=rio.warp.Resampling.nearest)

# Evaluate the type of the inputs
if ref_type == "array":
ref_dem = ref_dem.data
elif ref_type == "pc":
ref_dem = ref_dem.to_pointcloud().ds
ref_dem.rename(columns={"b1": "z"}, inplace=True)
if tba_type == "array":
tba_dem = tba_dem.data
elif tba_type == "pc":
tba_dem = tba_dem.to_pointcloud().ds
tba_dem.rename(columns={"b1": "z"}, inplace=True)
if mask_type == "array":
inlier_mask_crop = inlier_mask_crop.data

list_shift = ["shift_x", "shift_y", "shift_z"]
warnings.filterwarnings("ignore") # to do the process until the end

# Use VerticalShift as a representative example.
nuthkaab_ref = xdem.coreg.NuthKaab()
nuthkaab_ref.fit_and_apply(
ref_dem, tba_dem, inlier_mask=inlier_mask_crop_proj, transform=transform, crs=crs, random_state=42
)
shifts_ref = [nuthkaab_ref.meta["outputs"]["affine"][k] for k in list_shift] # type: ignore

nuthkaab_crop = xdem.coreg.NuthKaab()
if result == "error":
with pytest.raises(ValueError, match=re.escape(text)):
nuthkaab_crop.fit_and_apply(
ref_dem, tba_dem, inlier_mask=inlier_mask_crop, transform=transform, crs=crs, random_state=42
)
return
else:
nuthkaab_crop.fit_and_apply(
ref_dem, tba_dem, inlier_mask=inlier_mask_crop, transform=transform, crs=crs, random_state=42
)
shifts_crop = [nuthkaab_crop.meta["outputs"]["affine"][k] for k in list_shift] # type: ignore

# Check the output shifts match
assert shifts_ref == pytest.approx(shifts_crop)

@pytest.mark.parametrize(
"combination",
[
Expand Down
32 changes: 29 additions & 3 deletions xdem/coreg/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
except ImportError:
_HAS_P3D = False


# Map each workflow name to a function and optimizer
fit_workflows = {
"norder_polynomial": {"func": polynomial_1d, "optimizer": robust_norder_polynomial_fit},
Expand Down Expand Up @@ -121,6 +120,8 @@
"icp_picky": "Picky closest pair selection",
"cpd_weight": "Weight of CPD outlier removal",
}


#####################################
# Generic functions for preprocessing
###########################################
Expand All @@ -143,6 +144,19 @@ def _preprocess_coreg_fit_raster_raster(
f"'reference_dem': {reference_dem}, 'dem_to_be_aligned': {dem_to_be_aligned}"
)

if inlier_mask is not None:
# If inlier_mask has not the same shape of the input dem, reproject it
if reference_dem.shape != inlier_mask.shape:
if isinstance(inlier_mask, gu.Raster):
if isinstance(reference_dem, gu.Raster):
inlier_mask = inlier_mask.reproject(reference_dem, resampling=rio.warp.Resampling.nearest)
else:
ref = Raster.from_array(data=reference_dem, transform=transform, crs=crs)
inlier_mask = inlier_mask.reproject(ref, resampling=rio.warp.Resampling.nearest)
# in case of mask is a array
else:
raise ValueError("Input mask array can't be a different size array as input elevation.")

# If both DEMs are Rasters, validate that 'dem_to_be_aligned' is in the right grid. Then extract its data.
if isinstance(dem_to_be_aligned, gu.Raster) and isinstance(reference_dem, gu.Raster):
dem_to_be_aligned = dem_to_be_aligned.reproject(reference_dem, silent=True)
Expand Down Expand Up @@ -241,6 +255,20 @@ def _preprocess_coreg_fit_raster_point(
) -> tuple[NDArrayf, gpd.GeoDataFrame, NDArrayb, affine.Affine, rio.crs.CRS, Literal["Area", "Point"] | None]:
"""Pre-processing and checks of fit for raster-point input."""

if inlier_mask is not None:
# If inlier_mask has not the same shape of the input dem, reproject it
if raster_elev.shape != inlier_mask.shape:
if isinstance(inlier_mask, gu.Raster):
if isinstance(raster_elev, gu.Raster):
inlier_mask = inlier_mask.reproject(raster_elev, resampling=rio.warp.Resampling.nearest)
else:
raster_rst = Raster.from_array(data=raster_elev, transform=transform, crs=crs)
inlier_mask = inlier_mask.reproject(raster_rst, resampling=rio.warp.Resampling.nearest)

# If inlier_mask is an array, it is not possible to reproject it
else:
raise ValueError("Input mask array can't be a different size array as input elevation.")

# TODO: Convert to point cloud once class is done
# TODO: Raise warnings consistently with raster-raster function, see Amelie's Dask PR? #525
if isinstance(raster_elev, gu.Raster):
Expand Down Expand Up @@ -1704,7 +1732,6 @@ class OutAffineDict(TypedDict, total=False):


class InputCoregDict(TypedDict, total=False):

random: InRandomDict
fitorbin: InFitOrBinDict
iterative: InIterativeDict
Expand Down Expand Up @@ -2695,7 +2722,6 @@ def info(self, as_str: bool = False) -> None | str:
# Get the pipeline information for each step as a string
final_str = []
for i, step in enumerate(self.pipeline):

final_str.append(f"Pipeline step {i}:\n" f"################\n")
step_str = step.info(as_str=True)
final_str.append(step_str)
Expand Down
Loading