Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/coregister_arcticdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def main():
"--coreg_steps",
type=str,
nargs="+",
choices=["VerticalShift", "ICP", "NuthKaab", "AffineCoreg", "DhMinimize"],
choices=["VerticalShift", "ICP", "NuthKaab", "AffineCoreg", "DhMinimize", "Deramp"],
default=None,
help="Coregistration steps (default: VerticalShift ICP NuthKaab)",
)
Expand Down
2 changes: 1 addition & 1 deletion scripts/fetch_and_coregister.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def main():

parser.add_argument("--date_range", type=str, nargs=2, metavar=("START_DATE", "END_DATE"))

parser.add_argument("--min_valid_fraction", type=float, default=0.0)
parser.add_argument("--min_valid_fraction", type=float, default=0.5)

parser.add_argument("--intersection_threshold", type=float, default=0.8)

Expand Down
268 changes: 209 additions & 59 deletions src/coregister.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,31 @@
gdal.UseExceptions()


def _calculate_nmad(dem1_values, dem2_values):
"""
Calculate NMAD between two DEM arrays.

Parameters
----------
dem1_values : np.ndarray
First DEM array
dem2_values : np.ndarray
Second DEM array

Returns
-------
float
NMAD value, or np.nan if insufficient valid pixels
"""
dh = dem2_values - dem1_values
valid_mask = ~np.isnan(dh)

if np.count_nonzero(valid_mask) < 100:
return np.nan

return xdem.spatialstats.nmad(dh[valid_mask])


def build_coreg_pipeline(steps):
"""
Build xdem coregistration pipeline from list of step names.
Expand All @@ -41,6 +66,7 @@ def build_coreg_pipeline(steps):
"NuthKaab": xdem.coreg.NuthKaab,
"AffineCoreg": xdem.coreg.AffineCoreg,
"DhMinimize": xdem.coreg.DhMinimize,
"Deramp": xdem.coreg.Deramp,
}

if not steps:
Expand All @@ -53,10 +79,13 @@ def build_coreg_pipeline(steps):
return pipeline


def select_reference_dem(dem_stack, ref_index=None, ref_date=None):
def select_reference_dem(dem_stack, ref_index=None, ref_date=None, subsample=10):
"""
Select reference DEM from stack.

If no ref_index or ref_date provided, selects DEM with lowest median NMAD
to all other DEMs in the stack.

Parameters
----------
dem_stack : xr.DataArray
Expand All @@ -65,6 +94,8 @@ def select_reference_dem(dem_stack, ref_index=None, ref_date=None):
Index of reference DEM
ref_date : str, optional
Date of reference DEM (YYYY-MM-DD)
subsample : int
Subsampling factor for NMAD calculation (default: 10)

Returns
-------
Expand All @@ -76,32 +107,140 @@ def select_reference_dem(dem_stack, ref_index=None, ref_date=None):

if ref_date is not None:
ref_timestamp = pd.Timestamp(ref_date)
time_diffs = [
abs((pd.Timestamp(t.values) - ref_timestamp).total_seconds()) for t in dem_stack.time
]
time_diffs = np.abs(
[(pd.Timestamp(t.values) - ref_timestamp).total_seconds() for t in dem_stack.time]
)
ref_index = int(np.argmin(time_diffs))
return ref_index, dem_stack.isel(time=ref_index)

valid_counts = []
for i in range(len(dem_stack.time)):
da = dem_stack.isel(time=i)
valid_count = (~np.isnan(da.values)).sum()
valid_counts.append(valid_count)
# Automatic selection based on lowest median NMAD
n_dems = len(dem_stack.time)

max_valid = max(valid_counts)
max_indices = [i for i, v in enumerate(valid_counts) if v == max_valid]
if n_dems == 1:
return 0, dem_stack.isel(time=0)

if len(max_indices) == 1:
ref_index = max_indices[0]
else:
times = [pd.Timestamp(dem_stack.time.values[i]) for i in max_indices]
mean_time = pd.Timestamp(np.mean([t.value for t in times]))
time_diffs = [abs((t - mean_time).total_seconds()) for t in times]
ref_index = max_indices[np.argmin(time_diffs)]
# Subsample for efficiency
dem_arrays = dem_stack.values[:, ::subsample, ::subsample]

# Calculate all pairwise differences at once
# Shape: (n_dems, n_dems, height, width)
dh_matrix = dem_arrays[:, np.newaxis, :, :] - dem_arrays[np.newaxis, :, :, :]

# Calculate NMAD for each pair efficiently
median_nmads = np.full(n_dems, np.inf)

for i in range(n_dems):
# Get all differences for DEM i (excluding self-comparison)
dh_i = np.concatenate([dh_matrix[i, :i], dh_matrix[i, i + 1 :]], axis=0)

# Calculate valid pixels across all comparisons
valid_mask = ~np.isnan(dh_i)

# Calculate NMAD for each comparison
nmads_i = []
for j in range(n_dems - 1):
dh_j = dh_i[j]
valid_j = valid_mask[j]

if np.count_nonzero(valid_j) >= 100:
nmad_j = xdem.spatialstats.nmad(dh_j[valid_j])
nmads_i.append(nmad_j)

if nmads_i:
median_nmads[i] = np.median(nmads_i)

# Select DEM with lowest median NMAD
ref_index = int(np.argmin(median_nmads))

return ref_index, dem_stack.isel(time=ref_index)


def create_inlier_mask_from_dh(
dem_stack, ref_idx, slope_min=2, slope_max=30, nmad_threshold=1.0, logger=None
):
"""
Create inlier mask based on elevation differences between first and last DEM.
Falls back to slope-based mask if insufficient valid pixels.

Parameters
----------
dem_stack : xr.DataArray
Stack of DEMs with time dimension
ref_idx : int
Index of reference DEM
slope_min : float
Minimum slope threshold for fallback mask (degrees)
slope_max : float
Maximum slope threshold for fallback mask (degrees)
nmad_threshold : float
NMAD threshold multiplier for inlier selection
logger : logging.Logger, optional
Logger instance

Returns
-------
np.ndarray
Boolean inlier mask
"""
if logger is None:
logger = logging.getLogger(__name__)

ref_da = dem_stack.isel(time=ref_idx)
ref_dem = xdem.DEM.from_xarray(ref_da)

# Get first and last DEMs
first_dem = dem_stack.isel(time=0)
last_dem = dem_stack.isel(time=-1)

# Calculate NMAD using helper function
nmad = _calculate_nmad(first_dem.values, last_dem.values)

# Check if we have valid NMAD calculation
if np.isnan(nmad):
logger.warning(
"Insufficient valid pixels for NMAD-based inlier mask. "
f"Falling back to slope-based mask ({slope_min}-{slope_max} degrees)."
)
slope_raster = xdem.terrain.slope(ref_dem)
inlier_mask = (slope_raster.data.filled(np.nan) > slope_min) & (
slope_raster.data.filled(np.nan) < slope_max
)
inlier_mask &= ~np.isnan(ref_dem.data.filled(np.nan))
logger.info(f"Using slope-based inlier mask with {np.count_nonzero(inlier_mask)} pixels")
return inlier_mask

# Calculate elevation difference for mask creation
dh = last_dem.values - first_dem.values
valid_mask = ~np.isnan(dh)
valid_count = np.count_nonzero(valid_mask)

logger.info(f"Valid pixels for dh calculation: {valid_count}")
logger.info(f"Elevation difference NMAD: {nmad:.3f} m")
logger.info(f"Using NMAD threshold: {nmad_threshold} * NMAD = {nmad_threshold * nmad:.3f} m")

# Create NMAD-based inlier mask
inlier_mask = valid_mask & (np.abs(dh) < (nmad_threshold * nmad))
inlier_pixels = np.count_nonzero(inlier_mask)

logger.info(f"NMAD-based inlier mask contains {inlier_pixels} pixels")

# Check if NMAD mask has sufficient pixels
min_pixels_required = 100
if inlier_pixels < min_pixels_required:
logger.warning(
f"NMAD-based mask has too few pixels ({inlier_pixels} < {min_pixels_required}). "
f"Falling back to slope-based mask ({slope_min}-{slope_max} degrees)."
)
slope_raster = xdem.terrain.slope(ref_dem)
inlier_mask = (slope_raster.data.filled(np.nan) > slope_min) & (
slope_raster.data.filled(np.nan) < slope_max
)
inlier_mask &= ~np.isnan(ref_dem.data.filled(np.nan))
logger.info(f"Using slope-based inlier mask with {np.count_nonzero(inlier_mask)} pixels")

return inlier_mask


def coregister_arcticdem_stack(
input_dir,
output_dir,
Expand Down Expand Up @@ -130,9 +269,9 @@ def coregister_arcticdem_stack(
coreg_steps : list of str, optional
Coregistration steps, defaults to ['VerticalShift', 'ICP', 'NuthKaab']
slope_min : float
Minimum slope threshold for inlier mask (degrees)
Minimum slope threshold for fallback inlier mask (degrees)
slope_max : float
Maximum slope threshold for inlier mask (degrees)
Maximum slope threshold for fallback inlier mask (degrees)
resolution : float
Output resolution in meters
generate_hillshade : bool
Expand Down Expand Up @@ -195,19 +334,19 @@ def coregister_arcticdem_stack(
f"Selected reference DEM at index {ref_idx}, date {pd.to_datetime(ref_da.time.values)}"
)

# Create reference DEM object for coregistration
ref_dem = xdem.DEM.from_xarray(ref_da)

logger.info("Generating inlier mask from reference DEM")
slope_raster = xdem.terrain.slope(ref_dem)

inlier_mask_array = (slope_raster.data.filled(np.nan) > slope_min) & (
slope_raster.data.filled(np.nan) < slope_max
logger.info("Generating inlier mask from elevation differences")
inlier_mask_array = create_inlier_mask_from_dh(
dem_stack, ref_idx, slope_min, slope_max, nmad_threshold=1, logger=logger
)
inlier_mask_array &= ~np.isnan(ref_dem.data.filled(np.nan))

logger.info(f"Using {np.count_nonzero(inlier_mask_array)} pixels as stable ground")

if generate_slope_files:
logger.info("Generating slope raster for reference DEM")
slope_raster = xdem.terrain.slope(ref_dem)
slope_output = output_dir / "reference_slope.tif"
slope_da = xr.DataArray(
slope_raster.data,
Expand Down Expand Up @@ -301,39 +440,50 @@ def coregister_arcticdem_stack(
# Extract transformation information
coreg_file.write(f"DEM {i + 1}/{len(dem_stack.time)}: {datetime_str}\n")

# Get cumulative transformation from the final matrix
final_matrix = pipeline_i.to_matrix()
cumulative_shift_x = final_matrix[0, 3]
cumulative_shift_y = final_matrix[1, 3]
cumulative_shift_z = final_matrix[2, 3]

# Write cumulative transformation (summary)
coreg_file.write(" Cumulative Transformation:\n")
coreg_file.write(" Shifts (m):\n")
coreg_file.write(f" Easting (X): {cumulative_shift_x:>10.3f}\n")
coreg_file.write(f" Northing (Y): {cumulative_shift_y:>10.3f}\n")
coreg_file.write(f" Vertical (Z): {cumulative_shift_z:>10.3f}\n")
coreg_file.write(" Transformation Matrix:\n")
for row in final_matrix:
coreg_file.write(f" {row}\n")

# Extract rotations from final matrix
try:
rotations = xdem.coreg.AffineCoreg.to_rotations(final_matrix)
coreg_file.write(" Rotations (degrees):\n")
coreg_file.write(f" X-axis: {np.degrees(rotations[0]):>10.6f}\n")
coreg_file.write(f" Y-axis: {np.degrees(rotations[1]):>10.6f}\n")
coreg_file.write(f" Z-axis: {np.degrees(rotations[2]):>10.6f}\n")
except Exception:
pass

coreg_file.write("\n")

# Log to console
logger.info(
f" Cumulative shifts - X: {cumulative_shift_x:.3f} m, "
f"Y: {cumulative_shift_y:.3f} m, Z: {cumulative_shift_z:.3f} m"
)
# Check if the pipeline is fully affine before trying to get a cumulative matrix
is_affine_pipeline = all(step._is_affine for step in pipeline_i.pipeline)

if is_affine_pipeline:
# Get cumulative transformation from the final matrix
final_matrix = pipeline_i.to_matrix()
cumulative_shift_x = final_matrix[0, 3]
cumulative_shift_y = final_matrix[1, 3]
cumulative_shift_z = final_matrix[2, 3]

# Write cumulative transformation (summary)
coreg_file.write(" Cumulative Transformation:\n")
coreg_file.write(" Shifts (m):\n")
coreg_file.write(f" Easting (X): {cumulative_shift_x:>10.3f}\n")
coreg_file.write(f" Northing (Y): {cumulative_shift_y:>10.3f}\n")
coreg_file.write(f" Vertical (Z): {cumulative_shift_z:>10.3f}\n")
coreg_file.write(" Transformation Matrix:\n")
for row in final_matrix:
coreg_file.write(f" {row}\n")

# Extract rotations from final matrix
try:
rotations = xdem.coreg.AffineCoreg.to_rotations(final_matrix)
coreg_file.write(" Rotations (degrees):\n")
coreg_file.write(f" X-axis: {np.degrees(rotations[0]):>10.6f}\n")
coreg_file.write(f" Y-axis: {np.degrees(rotations[1]):>10.6f}\n")
coreg_file.write(f" Z-axis: {np.degrees(rotations[2]):>10.6f}\n")
except Exception:
pass

coreg_file.write("\n")

# Log to console
logger.info(
f" Cumulative shifts - X: {cumulative_shift_x:.3f} m, "
f"Y: {cumulative_shift_y:.3f} m, Z: {cumulative_shift_z:.3f} m"
)
else:
logger.info(
" Pipeline contains non-affine steps. Cumulative matrix not calculated."
)
coreg_file.write(
" Pipeline contains non-affine steps. Cumulative matrix not applicable.\n\n"
)

# Write detailed per-step transformations
coreg_file.write(" Individual Step Transformations:\n")
Expand Down