diff --git a/pyproject.toml b/pyproject.toml index e768b9e..17193f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,8 @@ dependencies = [ "pyreadline3; sys_platform == 'win32'", "cellmap-flow@git+https://github.com/janelia-cellmap/cellmap-flow@1ece404", "pykdtree", + "fastremap", + "connected-components-3d", ] [project.optional-dependencies] diff --git a/src/cellmap_segmentation_challenge/cli/evaluate.py b/src/cellmap_segmentation_challenge/cli/evaluate.py index 5ed0ffb..2c879f6 100644 --- a/src/cellmap_segmentation_challenge/cli/evaluate.py +++ b/src/cellmap_segmentation_challenge/cli/evaluate.py @@ -1,8 +1,8 @@ import os import click -from cellmap_segmentation_challenge.evaluate import ( - INSTANCE_CLASSES, +from cellmap_segmentation_challenge.evaluate import INSTANCE_CLASSES +from cellmap_segmentation_challenge import ( PROCESSED_PATH, SUBMISSION_PATH, TRUTH_PATH, diff --git a/src/cellmap_segmentation_challenge/cli/fetch_data.py b/src/cellmap_segmentation_challenge/cli/fetch_data.py index 08633d3..e4fb52c 100644 --- a/src/cellmap_segmentation_challenge/cli/fetch_data.py +++ b/src/cellmap_segmentation_challenge/cli/fetch_data.py @@ -441,3 +441,4 @@ def fetch_data_cli( log.unbind("save_location") log.info(f"Done after {time.time() - fetch_save_start:0.3f}s") log.info(f"Data saved to {dest_path_abs}") + pool.shutdown(wait=True) diff --git a/src/cellmap_segmentation_challenge/evaluate.py b/src/cellmap_segmentation_challenge/evaluate.py index e8b4a8b..d3c1e0a 100644 --- a/src/cellmap_segmentation_challenge/evaluate.py +++ b/src/cellmap_segmentation_challenge/evaluate.py @@ -1,18 +1,17 @@ import argparse import json import os -from time import time, sleep +from time import time import zipfile import numpy as np import zarr from scipy.optimize import linear_sum_assignment from scipy.spatial.distance import dice -from skimage.measure import label as relabel +from fastremap import remap, unique +import cc3d from skimage.transform import rescale -from 
pykdtree.kdtree import KDTree as cKDTree - from sklearn.metrics import accuracy_score, jaccard_score from tqdm import tqdm from upath import UPath @@ -21,7 +20,7 @@ from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed -from .config import PROCESSED_PATH, SUBMISSION_PATH, TRUTH_PATH +from .config import SUBMISSION_PATH, TRUTH_PATH from .utils import TEST_CROPS_DICT import logging @@ -54,122 +53,236 @@ MAX_SEMANTIC_THREADS = int(os.getenv("MAX_SEMANTIC_THREADS", 20)) PER_INSTANCE_THREADS = int(os.getenv("PER_INSTANCE_THREADS", 16)) # submitted_# of instances / ground_truth_# of instances -INSTANCE_RATIO_CUTOFF = float(os.getenv("INSTANCE_RATIO_CUTOFF", 50)) PRECOMPUTE_LIMIT = int(os.getenv("PRECOMPUTE_LIMIT", 1e7)) -DEBUG = os.getenv("DEBUG", "False") != "False" +DEBUG = os.getenv("DEBUG", "False").lower() != "false" + + +def iou_matrix(gt: np.ndarray, pred: np.ndarray) -> np.ndarray | None: + """ + Compute IoU between all GT and Pred instance IDs. + Assumes IDs are sequential starting at 1 (0 is background). + Returns float32 array of shape (num_gt_ids, num_pred_ids). + """ + INSTANCE_RATIO_CUTOFF = float(os.getenv("INSTANCE_RATIO_CUTOFF", 50)) + + if gt.shape != pred.shape: + raise ValueError("gt and pred must have the same shape") + + # 1D views without copying if possible + g = np.ravel(gt) + p = np.ravel(pred) + + # Number of instances (sequential ids -> max id) + nG = int(g.max()) if g.size else 0 + nP = int(p.max()) if p.size else 0 + + # Early exits + if nG == 0 or nP == 0: + if nG == 0 and nP > 0: + logging.info("No GT instances; returning empty IoU with pred columns.") + if nP == 0 and nG > 0: + logging.info("No Pred instances; returning empty IoU with gt rows.") + return np.zeros((nG, nP), dtype=np.float32) + + if (nP / nG) > INSTANCE_RATIO_CUTOFF: + logging.warning( + f"WARNING: Skipping {nP} instances in submission, {nG} in ground truth, " + f"because there are too many instances in the submission." 
+ ) + return None + + # Foreground (non-background) mask for each side and for pairwise overlaps + g_fg = g > 0 + p_fg = p > 0 + fg = g_fg & p_fg + + # ---- Per-object areas (sizes) ---- + # Use uint32 where possible to reduce memory; cast to int64 for safety if needed. + gt_sizes = np.bincount((g[g_fg].astype(np.int64) - 1), minlength=nG)[:, None] + pr_sizes = np.bincount((p[p_fg].astype(np.int64) - 1), minlength=nP)[None, :] + + # ---- Intersections for observed pairs only (sparse counting) ---- + gi = g[fg].astype(np.int64) - 1 + pj = p[fg].astype(np.int64) - 1 + if gi.size == 0: + # No overlaps anywhere -> IoU is all zeros + return np.zeros((nG, nP), dtype=np.float32) + + # Encode pairs to a single 64-bit key and count only present pairs + # Use unsigned to avoid negative-overflow corner cases. + gi_u = gi.astype(np.uint64) + pj_u = pj.astype(np.uint64) + key = gi_u * np.uint64(nP) + pj_u + uniq_keys, counts = np.unique(key, return_counts=True) + rows = (uniq_keys // np.uint64(nP)).astype(np.int64) + cols = (uniq_keys % np.uint64(nP)).astype(np.int64) -class spoof_precomputed: - def __init__(self, array, ids): - self.array = array - self.ids = ids - self.index = -1 + # ---- IoU only for observed pairs, then scatter into dense matrix ---- + # union_ij = gt_sizes[i] + pr_sizes[j] - inter_ij + inter_ij = counts.astype(np.int64) + union_ij = gt_sizes[rows, 0] + pr_sizes[0, cols] - inter_ij + with np.errstate(divide="ignore", invalid="ignore"): + iou_vals = (inter_ij / union_ij).astype(np.float32) - def __getitem__(self, ids): - if isinstance(ids, int): - return np.array(self.array == self.ids[ids], dtype=bool) - return np.array([self.array == self.ids[i] for i in ids], dtype=bool) + iou = np.zeros((nG, nP), dtype=np.float32) + iou[rows, cols] = iou_vals # all other entries remain 0 - def __len__(self): - return len(self.ids) + return iou def optimized_hausdorff_distances( truth_label, - matched_pred_label, + pred_label, voxel_size, hausdorff_distance_max, 
method="standard", + percentile: float | None = None, ): - # Get unique truth IDs, excluding the background (0) - truth_ids = np.unique(truth_label) - truth_ids = truth_ids[truth_ids != 0] # Exclude background - if len(truth_ids) == 0: - return [] - - def get_distance(i): - # Skip if both masks are empty - truth_mask = truth_label == truth_ids[i] - pred_mask = matched_pred_label == truth_ids[i] - if not np.any(truth_mask) and not np.any(pred_mask): - return 0 - - # Compute Hausdorff distance for the current pair + """ + Compute per-truth-instance Hausdorff-like distances against the (already remapped) + prediction using multithreading. Returns a 1D float32 numpy array whose i-th + entry corresponds to truth_ids[i]. + + Parameters + ---------- + truth_label : np.ndarray + Ground-truth instance label volume (0 == background). + pred_label : np.ndarray + Prediction instance label volume that has already been remapped to align + with the GT ids (0 == background). + voxel_size : Sequence[float] + Physical voxel sizes in Z, Y, X (or Y, X) order. + hausdorff_distance_max : float + Cap for distances (use np.inf for uncapped). + method : {"standard", "modified", "percentile"} + "standard" -> classic Hausdorff (max of directed maxima) + "modified" -> mean of directed distances, then max of the two means + "percentile" -> use the given percentile of directed distances (requires + `percentile` to be provided). + percentile : float | None + Percentile (0-100) used when method=="percentile". 
+ """ + # Unique GT ids (exclude background = 0) + truth_ids = unique(truth_label) + truth_ids = truth_ids[truth_ids != 0] + true_num = int(truth_ids.size) + if true_num == 0: + return np.empty((0,), dtype=np.float32) + + voxel_size = np.asarray(voxel_size, dtype=np.float64) + + def get_distance(i: int): + tid = int(truth_ids[i]) + truth_mask = truth_label == tid + pred_mask = pred_label == tid + # Note: because tid comes from truth_label, truth_mask has at least one voxel + # Compute directed/undirected Hausdorff according to method h_dist = compute_hausdorff_distance( truth_mask, pred_mask, voxel_size, hausdorff_distance_max, - method, + method=method, + percentile=percentile, ) - return i, h_dist + return i, float(h_dist) + + dists = np.empty((true_num,), dtype=np.float32) - # Initialize list for distances - hausdorff_distances = np.empty(len(truth_ids)) if DEBUG: - # Use tqdm for progress tracking - bar = tqdm( - range(len(truth_ids)), + for i in tqdm( + range(true_num), desc="Computing Hausdorff distances", leave=True, dynamic_ncols=True, - total=len(truth_ids), - ) - # Compute the cost matrix - for i in bar: - i, h_dist = get_distance(i) - hausdorff_distances[i] = h_dist + total=true_num, + ): + idx, h = get_distance(i) + dists[idx] = h else: with ThreadPoolExecutor(max_workers=PER_INSTANCE_THREADS) as executor: - for i, h_dist in tqdm( - executor.map(get_distance, range(len(truth_ids))), + for idx, h in tqdm( + executor.map(get_distance, range(true_num)), desc="Computing Hausdorff distances", - total=len(truth_ids), + total=true_num, dynamic_ncols=True, ): - hausdorff_distances[i] = h_dist + dists[idx] = h - return hausdorff_distances + return dists -def compute_hausdorff_distance(image0, image1, voxel_size, max_distance, method): +def compute_hausdorff_distance( + image0, + image1, + voxel_size, + max_distance: float, + method: str = "standard", + percentile: float | None = None, +): """ - Compute the Hausdorff distance between two binary masks, optimized 
for pre-vectorized inputs. + Compute the (undirected) Hausdorff-like distance between two binary masks using + Euclidean distance transforms (EDT), which is generally faster and more memory + friendly than building KD-trees for large 3D volumes. + + Parameters + ---------- + image0, image1 : np.ndarray (bool or int) + Binary masks (True/1 = foreground). They should already be aligned to the + same voxel grid. + voxel_size : Sequence[float] + Physical voxel sizes in Z, Y, X (or Y, X) order; passed to EDT via + `sampling` to support anisotropy. + max_distance : float + Distances are clipped to this value (use np.inf for no clipping). + method : {"standard", "modified", "percentile"} + "standard": classic Hausdorff -> max(max(dist(A→B)), max(dist(B→A))) + "modified": robust mean -> max(mean(dist(A→B)), mean(dist(B→A))) + "percentile": robust percentile -> max(P(dist(A→B)), P(dist(B→A))) where P + is the given percentile. + percentile : float | None + Percentile in [0, 100]; required if method == "percentile". 
""" - # Extract nonzero points - a_points = np.argwhere(image0) - b_points = np.argwhere(image1) + from scipy.ndimage import distance_transform_edt + + a = np.asarray(image0, dtype=bool) + b = np.asarray(image1, dtype=bool) # Handle empty sets - if len(a_points) == 0 and len(b_points) == 0: - return 0 - elif len(a_points) == 0 or len(b_points) == 0: - return np.inf - - # Scale points by voxel size - a_points = a_points * np.array(voxel_size) - b_points = b_points * np.array(voxel_size) - - # Build KD-trees once - a_tree = cKDTree(a_points) - b_tree = cKDTree(b_points) - - # Query distances - # fwd = a_tree.query(b_points, k=1, distance_upper_bound=max_distance)[0] - # bwd = b_tree.query(a_points, k=1, distance_upper_bound=max_distance)[0] - fwd = a_tree.query(b_points, k=1)[0] - bwd = b_tree.query(a_points, k=1)[0] - - # Replace "inf" with `max_distance` for numerical stability - # fwd[fwd == np.inf] = max_distance - # bwd[bwd == np.inf] = max_distance - fwd[fwd > max_distance] = max_distance - bwd[bwd > max_distance] = max_distance + a_n = int(a.sum()) + b_n = int(b.sum()) + if a_n == 0 and b_n == 0: + return 0.0 + if a_n == 0 or b_n == 0: + return float(max_distance) + + voxel_size = np.asarray(voxel_size, dtype=np.float64) + + # Directed distances via EDT to the *other* set's foreground + # distance_transform_edt computes distance to nearest zero -> pass ~mask so that + # zeros are at foreground voxels of the other set. 
+ dist_to_b = distance_transform_edt(~b, sampling=voxel_size) + dist_to_a = distance_transform_edt(~a, sampling=voxel_size) + + fwd = dist_to_b[a] + bwd = dist_to_a[b] if method == "standard": - return max(fwd.max(), bwd.max()) + d = max(fwd.max(initial=0.0), bwd.max(initial=0.0)) elif method == "modified": - return max(fwd.mean(), bwd.mean()) + d = max(fwd.mean() if fwd.size else 0.0, bwd.mean() if bwd.size else 0.0) + elif method == "percentile": + if percentile is None: + raise ValueError("'percentile' must be provided when method='percentile'") + d = max( + float(np.percentile(fwd, percentile)) if fwd.size else 0.0, + float(np.percentile(bwd, percentile)) if bwd.size else 0.0, + ) + else: + raise ValueError("method must be one of {'standard', 'modified', 'percentile'}") + + return float(min(d, max_distance)) def score_instance( @@ -196,114 +309,35 @@ def score_instance( logging.info("Scoring instance segmentation...") # Relabel the predicted instance labels to be consistent with the ground truth instance labels logging.info("Relabeling predicted instance labels...") - pred_label = relabel(pred_label, connectivity=len(pred_label.shape)) + pred_label = cc3d.connected_components(pred_label) - # Get unique IDs, excluding background (assumed to be 0) - truth_ids = np.unique(truth_label) - truth_ids = truth_ids[truth_ids != 0] - - pred_ids = np.unique(pred_label) - pred_ids = pred_ids[pred_ids != 0] + # Compute the IoU cost matrix between the predicted and ground truth instance labels + cost_matrix = iou_matrix(truth_label, pred_label) - # Skip if the submission has way too many instances - if len(truth_ids) > 0 and len(pred_ids) / len(truth_ids) > INSTANCE_RATIO_CUTOFF: - logging.warning( - f"WARNING: Skipping {len(pred_ids)} instances in submission, {len(truth_ids)} in ground truth, because there are too many instances in the submission." 
- ) + if cost_matrix is None: + # Too many instances in submission, skip scoring return { "accuracy": 0, "hausdorff_distance": np.inf, "normalized_hausdorff_distance": 0, "combined_score": 0, } - - # Flatten the labels for vectorized computation - truth_flat = truth_label.flatten() - pred_flat = pred_label.flatten() - - matched_pred_label = np.zeros_like(pred_label) - - if len(pred_ids) > 0: - - # Precompute binary masks for all `truth_ids` - if len(truth_flat) * len(truth_ids) > PRECOMPUTE_LIMIT: - truth_binary_masks = spoof_precomputed(truth_flat, truth_ids) - else: - logging.info("Precomputing binary masks for all `truth_ids`...") - truth_binary_masks = np.array( - [(truth_flat == tid) for tid in truth_ids], dtype=bool - ) - - def get_cost(j): - # Find all `truth_ids` that overlap with this prediction mask - pred_mask = pred_flat == pred_ids[j] - relevant_truth_ids = np.unique(truth_flat[pred_mask]) - relevant_truth_ids = relevant_truth_ids[relevant_truth_ids != 0] - relevant_truth_indices = np.where(np.isin(truth_ids, relevant_truth_ids))[0] - relevant_truth_masks = truth_binary_masks[relevant_truth_indices] - - if relevant_truth_indices.size == 0: - return [], j, [] - - tp = relevant_truth_masks[:, pred_mask].sum(1) - fn = (relevant_truth_masks[:, pred_mask == 0]).sum(1) - fp = (relevant_truth_masks[:, pred_mask] == 0).sum(1) - - # Compute Jaccard scores - jaccard_scores = tp / (tp + fp + fn) - - # Fill in the cost matrix for this `j` (prediction) - return relevant_truth_indices, j, jaccard_scores - - # Initialize the cost matrix - logging.info( - f"Initializing cost matrix of {len(truth_ids)} x {len(pred_ids)} (true x pred)..." 
- ) - cost_matrix = np.zeros((len(truth_ids), len(pred_ids))) - - # Compute the cost matrix - if DEBUG: - # Use tqdm for progress tracking - bar = tqdm( - range(pred_ids), - desc="Computing cost matrix", - leave=True, - dynamic_ncols=True, - total=len(pred_ids), - ) - # Compute the cost matrix - for j in bar: - relevant_truth_indices, j, jaccard_scores = get_cost(j) - cost_matrix[relevant_truth_indices, j] = jaccard_scores - else: - with ThreadPoolExecutor(max_workers=PER_INSTANCE_THREADS) as executor: - for relevant_truth_indices, j, jaccard_scores in tqdm( - executor.map(get_cost, range(len(pred_ids))), - desc="Computing cost matrix in parallel", - dynamic_ncols=True, - total=len(pred_ids), - leave=True, - ): - cost_matrix[relevant_truth_indices, j] = jaccard_scores - + elif cost_matrix.size > 0: # Match the predicted instances to the ground truth instances logging.info("Calculating linear sum assignment...") row_inds, col_inds = linear_sum_assignment(cost_matrix, maximize=True) - # Contruct the volume for the matched instances - for i, j in tqdm( - zip(col_inds, row_inds), - desc="Relabeling matched instances", - dynamic_ncols=True, - ): - if pred_ids[i] == 0 or truth_ids[j] == 0: - # Don't score the background - continue - pred_mask = pred_label == pred_ids[i] - matched_pred_label[pred_mask] = truth_ids[j] + # Construct the volume for the matched instances + mapping = {0: 0} # background maps to background + mapping.update( + {pred_id + 1: truth_id + 1 for truth_id, pred_id in zip(row_inds, col_inds)} + ) + pred_label = remap( + pred_label, mapping, in_place=True, preserve_missing_labels=True + ) hausdorff_distances = optimized_hausdorff_distances( - truth_label, matched_pred_label, voxel_size, hausdorff_distance_max + truth_label, pred_label, voxel_size, hausdorff_distance_max ) else: # No predictions to match @@ -311,12 +345,12 @@ def get_cost(j): # Compute the scores logging.info("Computing accuracy score...") - accuracy = accuracy_score(truth_flat, 
matched_pred_label.flatten()) + accuracy = accuracy_score(truth_label.flatten(), pred_label.flatten()) hausdorff_dist = np.mean(hausdorff_distances) if len(hausdorff_distances) > 0 else 0 normalized_hausdorff_dist = 1.01 ** ( -hausdorff_dist / np.linalg.norm(voxel_size) - ) # normalize Hausdorff distance to [0, 1] using the maximum distance represented by a voxel. 32 is arbitrarily chosen to have a reasonable range - combined_score = (accuracy * normalized_hausdorff_dist) ** 0.5 + ) # normalize Hausdorff distance to [0, 1] using the maximum distance represented by a voxel + combined_score = (accuracy * normalized_hausdorff_dist) ** 0.5 # geometric mean logging.info(f"Accuracy: {accuracy:.4f}") logging.info(f"Hausdorff Distance: {hausdorff_dist:.4f}") logging.info(f"Normalized Hausdorff Distance: {normalized_hausdorff_dist:.4f}") @@ -326,7 +360,7 @@ def get_cost(j): "hausdorff_distance": hausdorff_dist, "normalized_hausdorff_distance": normalized_hausdorff_dist, "combined_score": combined_score, - } + } # type: ignore def score_semantic(pred_label, truth_label) -> dict[str, float]: @@ -347,16 +381,18 @@ def score_semantic(pred_label, truth_label) -> dict[str, float]: # Flatten the label volumes and convert to binary pred_label = (pred_label > 0.0).flatten() truth_label = (truth_label > 0.0).flatten() - # Compute the scores + # Compute the scores if np.sum(truth_label + pred_label) == 0: # If there are no true positives, set the scores to 1 logging.debug("No true positives found. 
Setting scores to 1.") dice_score = 1 + iou_score = 1 else: dice_score = 1 - dice(truth_label, pred_label) + iou_score = jaccard_score(truth_label, pred_label, zero_division=1) scores = { - "iou": jaccard_score(truth_label, pred_label, zero_division=1), + "iou": iou_score, "dice_score": dice_score if not np.isnan(dice_score) else 1, } @@ -444,6 +480,7 @@ def empty_label_score( label, crop_name, instance_classes=INSTANCE_CLASSES, truth_path=TRUTH_PATH ): if label in instance_classes: + truth_path = UPath(truth_path) return { "accuracy": 0, "hausdorff_distance": 0, diff --git a/src/cellmap_segmentation_challenge/utils/__init__.py b/src/cellmap_segmentation_challenge/utils/__init__.py index 67c5b87..426b096 100644 --- a/src/cellmap_segmentation_challenge/utils/__init__.py +++ b/src/cellmap_segmentation_challenge/utils/__init__.py @@ -35,6 +35,7 @@ ], "submission": [ "package_submission", + "zip_submission", "save_numpy_class_arrays_to_zarr", "save_numpy_class_labels_to_zarr", ], diff --git a/src/cellmap_segmentation_challenge/utils/crops.py b/src/cellmap_segmentation_challenge/utils/crops.py index 8f81bd8..5676439 100644 --- a/src/cellmap_segmentation_challenge/utils/crops.py +++ b/src/cellmap_segmentation_challenge/utils/crops.py @@ -2,6 +2,7 @@ import os import fsspec +import numpy as np from upath import UPath from typing_extensions import Self from yarl import URL diff --git a/src/cellmap_segmentation_challenge/utils/submission.py b/src/cellmap_segmentation_challenge/utils/submission.py index a37d1db..e0f8ad0 100644 --- a/src/cellmap_segmentation_challenge/utils/submission.py +++ b/src/cellmap_segmentation_challenge/utils/submission.py @@ -178,7 +178,7 @@ def package_crop(crop, zarr_group, overwrite, input_search_path=PROCESSED_PATH): label_array.attrs["translation"] = crop.translation label_array.attrs["shape"] = crop.shape - return crop_path + return crop_path.path def package_submission( @@ -219,16 +219,24 @@ def package_submission( overwrite=overwrite, 
input_search_path=input_search_path, ) + successful_crops = 0 for crop_path in tqdm( pool.map(partial_package_crop, TEST_CROPS), total=len(TEST_CROPS), dynamic_ncols=True, desc="Packaging crops...", ): - tqdm.write(f"Packaged {crop_path}") - + if "skipping" in crop_path.lower(): + tqdm.write(f"{crop_path} skipped.") + else: + tqdm.write(f"Packaged {crop_path}") + successful_crops += 1 + logging.info(f"Packaged {successful_crops}/{len(TEST_CROPS)} crops.") logging.info(f"Saved submission to {output_path}") + if successful_crops == 0: + raise RuntimeError("No crops were packaged; submission zarr is empty.") + logging.info("Zipping submission...") zip_submission(output_path) diff --git a/tests/test_all.py b/tests/test_all.py index 0e1bd94..cfd6ac6 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,27 +1,21 @@ # %% -import json import pytest -import shutil import os from upath import UPath -import numpy as np -from skimage.transform import rescale from cellmap_segmentation_challenge.utils import ( - simulate_predictions_accuracy, - simulate_predictions_iou_binary, download_file, ) - from cellmap_segmentation_challenge import RAW_NAME, CROP_NAME -from cellmap_segmentation_challenge.utils.submission import ( - zip_submission, - save_numpy_class_arrays_to_zarr, -) ERROR_TOLERANCE = 0.1 +skip_in_ci = pytest.mark.skipif( + os.getenv("CI") == "true", + reason="Skipped in CI", +) + @pytest.fixture(autouse=True) def reset_env(): @@ -58,6 +52,7 @@ def setup_temp_path(tmp_path_factory): yield REPO_ROOT, BASE_DATA_PATH, SEARCH_PATH, PREDICTIONS_PATH, PROCESSED_PATH, SUBMISSION_PATH +@skip_in_ci @pytest.mark.dependency() def test_fetch_test_crops(setup_temp_path): from cellmap_segmentation_challenge.cli import fetch_data_cli @@ -83,7 +78,7 @@ def test_fetch_test_crops(setup_temp_path): ) -# %% +@skip_in_ci @pytest.mark.dependency() def test_fetch_data(setup_temp_path): from cellmap_segmentation_challenge.cli import fetch_data_cli @@ -98,7 +93,7 @@ def 
test_fetch_data(setup_temp_path): ) = setup_temp_path fetch_data_cli.callback( - crops="116,118", + crops="9", raw_padding=0, dest=BASE_DATA_PATH.path, access_mode="append", @@ -109,7 +104,7 @@ def test_fetch_data(setup_temp_path): ) -# %% +@skip_in_ci @pytest.mark.dependency(depends=["test_fetch_data"]) def test_train(setup_temp_path): ( @@ -136,14 +131,23 @@ def test_train(setup_temp_path): classes=["mito", "er"], search_path=SEARCH_PATH, csv_path=REPO_ROOT / "datasplit.csv", - validation_prob=0.5, + validation_prob=0.0, ) + # There's only one crop, so copy it to the validate set + + with open(REPO_ROOT / "datasplit.csv", "r") as f: + line = f.readlines()[0] + parts = line.strip().split(",") + parts[0] = '"validate"' + with open(REPO_ROOT / "datasplit.csv", "a") as f: + f.write(",".join(parts) + "\n") + train_cli.callback(REPO_ROOT / "train_config.py") -# %% -@pytest.mark.dependency(depends=["test_fetch_data"]) +@skip_in_ci +@pytest.mark.dependency(depends=["test_train"]) def test_predict(setup_temp_path): from cellmap_segmentation_challenge.cli import predict_cli @@ -156,14 +160,14 @@ def test_predict(setup_temp_path): SUBMISSION_PATH, ) = setup_temp_path - download_file( - "https://raw.githubusercontent.com/janelia-cellmap/cellmap-segmentation-challenge/refs/heads/main/tests/train_config.py", - REPO_ROOT / "train_config.py", - ) + # download_file( + # "https://raw.githubusercontent.com/janelia-cellmap/cellmap-segmentation-challenge/refs/heads/main/tests/train_config.py", + # REPO_ROOT / "train_config.py", + # ) predict_cli.callback( REPO_ROOT / "train_config.py", - crops="116", + crops="9", output_path=PREDICTIONS_PATH, skip_orthoplanes=True, overwrite=True, @@ -173,7 +177,7 @@ def test_predict(setup_temp_path): ) -# %% +@skip_in_ci @pytest.mark.dependency(depends=["test_fetch_test_crops"]) def test_predict_test_crops(setup_temp_path): from cellmap_segmentation_challenge.cli import predict_cli @@ -187,10 +191,10 @@ def test_predict_test_crops(setup_temp_path): 
SUBMISSION_PATH, ) = setup_temp_path - download_file( - "https://raw.githubusercontent.com/janelia-cellmap/cellmap-segmentation-challenge/refs/heads/main/tests/train_config.py", - REPO_ROOT / "train_config.py", - ) + # download_file( + # "https://raw.githubusercontent.com/janelia-cellmap/cellmap-segmentation-challenge/refs/heads/main/tests/train_config.py", + # REPO_ROOT / "train_config.py", + # ) predict_cli.callback( REPO_ROOT / "train_config.py", @@ -204,7 +208,7 @@ def test_predict_test_crops(setup_temp_path): ) -# %% +@skip_in_ci @pytest.mark.dependency(depends=["test_predict"]) def test_process(setup_temp_path): from cellmap_segmentation_challenge.cli import process_cli @@ -234,8 +238,8 @@ def test_process(setup_temp_path): ) -# %% -@pytest.mark.dependency(depends=["test_fetch_data"]) +@skip_in_ci +@pytest.mark.dependency(depends=["test_process"]) def test_pack_results(setup_temp_path): from cellmap_segmentation_challenge.cli import package_submission_cli @@ -248,110 +252,6 @@ def test_pack_results(setup_temp_path): SUBMISSION_PATH, ) = setup_temp_path - truth_path = REPO_ROOT / "data" / "truth.zarr" - package_submission_cli.callback( - PROCESSED_PATH, truth_path.path, overwrite=True, max_workers=os.cpu_count() + PROCESSED_PATH, SUBMISSION_PATH, overwrite=True, max_workers=os.cpu_count() ) - - -# %% -@pytest.mark.parametrize( - "scale, iou, accuracy", - [ - (None, None, None), - (2, None, None), # 2x resolution - (0.5, None, None), # 0.5x resolution - (None, 0.8, 0.8), # 0.8 iou, 0.8 accuracy - (2, 0.8, 0.8), # 2x resolution, 0.8 iou, 0.8 accuracy - (0.5, 0.8, 0.8), # 0.5x resolution, 0.8 iou, 0.8 accuracy - ], -) -@pytest.mark.dependency(depends=["test_pack_results"]) -def test_evaluate(setup_temp_path, scale, iou, accuracy): - from cellmap_segmentation_challenge.cli import evaluate_cli - from cellmap_segmentation_challenge.evaluate import INSTANCE_CLASSES - import zarr - - ( - REPO_ROOT, - BASE_DATA_PATH, - SEARCH_PATH, - PREDICTIONS_PATH, - PROCESSED_PATH, 
- SUBMISSION_PATH, - ) = setup_temp_path - - truth_path = REPO_ROOT / "data" / "truth.zarr" - - if any([scale, iou, accuracy]): - submission_path = REPO_ROOT / "data" / "submission.zarr" - if submission_path.exists(): - # Remove the submission zarr if it already exists - shutil.rmtree(submission_path) - submission_zarr = zarr.open(submission_path, mode="w") - truth_zarr = zarr.open(truth_path, mode="r") - for crop in truth_zarr.keys(): - crop_zarr = truth_zarr[crop] - submission_zarr.create_group(crop) - for label in crop_zarr.keys(): - label_zarr = crop_zarr[label] - attrs = label_zarr.attrs.asdict() - truth = label_zarr[:] - pred = truth.copy() - - if iou is not None and label not in INSTANCE_CLASSES: - pred = simulate_predictions_iou_binary(pred, iou) - if accuracy is not None and label in INSTANCE_CLASSES: - pred = simulate_predictions_accuracy(pred, accuracy) - - if scale: - pred = rescale(pred, scale, order=0, preserve_range=True) - old_voxel_size = attrs["voxel_size"] - new_voxel_size = [s / scale for s in attrs["voxel_size"]] - attrs["voxel_size"] = new_voxel_size - # Adjust the translation - attrs["translation"] = [ - t + (n - o) / 2 - for t, o, n in zip( - attrs["translation"], old_voxel_size, new_voxel_size - ) - ] - - save_numpy_class_arrays_to_zarr( - submission_path, - crop, - [label], - [pred], - attrs=attrs, - ) - else: - submission_path = truth_path - zip_submission(submission_path) - - evaluate_cli.callback( - submission_path.with_suffix(".zip"), - result_file=REPO_ROOT / "result.json", - truth_path=truth_path, - instance_classes=",".join(INSTANCE_CLASSES), - ) - - # Check the results: - with open(REPO_ROOT / "result_submitted_only.json") as f: - results = json.load(f) - - if iou is None and accuracy is None: - assert ( - 1 - results["overall_score"] < ERROR_TOLERANCE - ), f"Overall score should be 1 but is: {results['overall_score']}" - else: - # Check all accuracy scores and ious - for label, scores in results["label_scores"].items(): - if label 
in INSTANCE_CLASSES: - assert ( - np.abs((accuracy or 1) - scores["accuracy"]) < ERROR_TOLERANCE - ), f"Accuracy score for {label} should be {(accuracy or 1)} but is: {scores['accuracy']}" - else: - assert ( - np.abs((iou or 1) - scores["iou"]) < ERROR_TOLERANCE - ), f"IoU score for {label} should be {(iou or 1)} but is: {scores['iou']}" diff --git a/tests/test_crop_manifest.csv b/tests/test_crop_manifest.csv index 54a49b5..e678b30 100644 --- a/tests/test_crop_manifest.csv +++ b/tests/test_crop_manifest.csv @@ -1,3 +1,3 @@ crop_name,dataset,class_label,voxel_size,translation,shape -118,jrc_ctl-id8-1,er,[3.48;4.0;4.0],[9107.16;5772.0;19260.0],[200;200;200] -118,jrc_ctl-id8-1,mito,[13.92;16.0;16.0],[9112.38;5778.0;19266.0],[50;50;50] \ No newline at end of file +9,jrc_hela-2,mito,[2.62;2.0;2.0],[8429.85;1679.0;11599.0],[106;200;200] +9,jrc_hela-2,er,[2.62;2.0;2.0],[8429.85;1679.0;11599.0],[106;200;200] \ No newline at end of file diff --git a/tests/test_crops.py b/tests/test_crops.py new file mode 100644 index 0000000..31b187e --- /dev/null +++ b/tests/test_crops.py @@ -0,0 +1,82 @@ +"""Unit tests for crops functions in cellmap_segmentation_challenge.utils.crops""" + +from cellmap_segmentation_challenge.utils.crops import ( + TestCropRow, + CropRow, +) + + +class TestTestCropRowDataclass: + """Tests for TestCropRow dataclass""" + + def test_from_csv_row_simple(self): + """Test creating TestCropRow from a simple CSV row""" + row = "116,jrc_hela-2,er,[4.0;4.0;4.0],[100.0;200.0;300.0],[64;64;64]" + crop = TestCropRow.from_csv_row(row) + + assert crop.id == 116 + assert crop.dataset == "jrc_hela-2" + assert crop.class_label == "er" + assert crop.voxel_size == (4.0, 4.0, 4.0) + assert crop.translation == (100.0, 200.0, 300.0) + assert crop.shape == (64, 64, 64) + + def test_from_csv_row_different_shapes(self): + """Test creating TestCropRow with different shape dimensions""" + row = "234,jrc_cos7-1a,nuc,[8.0;8.0;8.0],[1000.0;2000.0;3000.0],[128;128;128]" + crop = 
TestCropRow.from_csv_row(row) + + assert crop.id == 234 + assert crop.dataset == "jrc_cos7-1a" + assert crop.class_label == "nuc" + assert crop.voxel_size == (8.0, 8.0, 8.0) + assert crop.translation == (1000.0, 2000.0, 3000.0) + assert crop.shape == (128, 128, 128) + + def test_from_csv_row_float_voxel_sizes(self): + """Test creating TestCropRow with non-integer voxel sizes""" + row = "118,jrc_hela-3,mito,[2.5;2.5;2.5],[50.0;100.0;150.0],[32;32;32]" + crop = TestCropRow.from_csv_row(row) + + assert crop.id == 118 + assert crop.voxel_size == (2.5, 2.5, 2.5) + + def test_from_csv_row_negative_translations(self): + """Test creating TestCropRow with negative translations""" + row = "120,jrc_hela-2,cell,[-4.0;4.0;4.0],[-100.0;200.0;300.0],[64;64;64]" + crop = TestCropRow.from_csv_row(row) + + assert crop.voxel_size == (-4.0, 4.0, 4.0) + assert crop.translation == (-100.0, 200.0, 300.0) + + +class TestCropRowDataclass: + """Tests for CropRow dataclass""" + + def test_from_csv_row_simple(self): + """Test creating CropRow from a simple CSV row""" + row = "116,jrc_hela-2,xy,s3://janelia-cellmap-0/jrc_hela-2/jrc_hela-2.n5/groundtruth,s3://janelia-cellmap-0/jrc_hela-2/jrc_hela-2.n5/em/fibsem-uint8" + crop = CropRow.from_csv_row(row) + + assert crop.id == 116 + assert crop.dataset == "jrc_hela-2" + assert crop.alignment == "xy" + assert str(crop.gt_source).startswith("s3://janelia-cellmap-0/jrc_hela-2") + assert str(crop.em_url).startswith("s3://janelia-cellmap-0/jrc_hela-2") + + def test_from_csv_row_different_alignment(self): + """Test creating CropRow with different alignment""" + row = "234,jrc_cos7-1a,xz,s3://janelia-cellmap-0/jrc_cos7-1a/jrc_cos7-1a.n5/groundtruth,s3://janelia-cellmap-0/jrc_cos7-1a/jrc_cos7-1a.n5/em/fibsem-uint8" + crop = CropRow.from_csv_row(row) + + assert crop.id == 234 + assert crop.dataset == "jrc_cos7-1a" + assert crop.alignment == "xz" + + def test_from_csv_row_different_datasets(self): + """Test creating CropRow from different datasets""" + row 
= "342,jrc_jurkat-1,xy,s3://janelia-cellmap-0/jrc_jurkat-1/jrc_jurkat-1.n5/groundtruth,s3://janelia-cellmap-0/jrc_jurkat-1/jrc_jurkat-1.n5/em/fibsem-uint8" + crop = CropRow.from_csv_row(row) + + assert crop.dataset == "jrc_jurkat-1" + assert crop.id == 342 diff --git a/tests/test_evaluate_metrics.py b/tests/test_evaluate_metrics.py new file mode 100644 index 0000000..20acc45 --- /dev/null +++ b/tests/test_evaluate_metrics.py @@ -0,0 +1,535 @@ +import zipfile +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import zarr +from fastremap import unique + +from cellmap_segmentation_challenge import evaluate as ev +from cellmap_segmentation_challenge.utils import zip_submission + + +# ------------------------ +# Helpers / tiny dataclasses +# ------------------------ + + +@dataclass +class DummyCrop: + voxel_size: tuple + shape: tuple + translation: tuple + + +# ------------------------ +# iou_matrix tests +# ------------------------ + + +def test_iou_matrix_basic(): + """ + gt: + 0 1 + 2 2 + + pred: + 0 1 + 0 2 + + gt id 1: 1 voxel, pred id 1: 1 voxel, intersection 1 -> IoU 1 + gt id 2: 2 voxels, pred id 2: 1 voxel, intersection 1 -> IoU 1/2 + """ + gt = np.array([[0, 1], [2, 2]], dtype=np.int32) + pred = np.array([[0, 1], [0, 2]], dtype=np.int32) + + iou = ev.iou_matrix(gt, pred) + + assert iou.shape == (2, 2) + # (gt1, pred1) + assert np.isclose(iou[0, 0], 1.0) + # (gt2, pred2) = 1 / (2 + 1 - 1) = 0.5 + assert np.isclose(iou[1, 1], 0.5) + # all other entries should be 0 + assert np.isclose(iou[0, 1], 0.0) + assert np.isclose(iou[1, 0], 0.0) + + +def test_iou_matrix_no_gt_instances(): + gt = np.zeros((3, 3), dtype=np.int32) + pred = np.array([[0, 1, 1], [0, 0, 2], [0, 0, 0]], dtype=np.int32) + + iou = ev.iou_matrix(gt, pred) + # nG = 0, nP = 2 -> shape (0,2) + assert iou.shape == (0, 2) + + +def test_iou_matrix_no_pred_instances(): + gt = np.array([[0, 1, 1], [0, 2, 2], [0, 0, 0]], dtype=np.int32) + pred = np.zeros_like(gt) + + iou = 
ev.iou_matrix(gt, pred) + # nG = 2, nP = 0 -> shape (2,0) + assert iou.shape == (2, 0) + + +def test_iou_matrix_too_many_pred_instances(monkeypatch): + # force INSTANCE_RATIO_CUTOFF low to trigger None + monkeypatch.setenv("INSTANCE_RATIO_CUTOFF", "1") + gt = np.array([[0, 1, 0, 0]], dtype=np.int32) # 1 instance + pred = np.array([[0, 1, 2, 3]], dtype=np.int32) # 3 instances + + res = ev.iou_matrix(gt, pred) + assert res is None + + +# ------------------------ +# Hausdorff distance tests +# ------------------------ + + +def test_compute_hausdorff_distance_identical_masks(): + mask = np.array([[0, 1, 0], [0, 1, 0], [0, 0, 0]], dtype=bool) + d = ev.compute_hausdorff_distance( + mask, mask, voxel_size=(1.0, 1.0), max_distance=np.inf, method="standard" + ) + assert np.isclose(d, 0.0) + + +def test_compute_hausdorff_distance_separated_points(): + a = np.zeros((1, 5), dtype=bool) + b = np.zeros((1, 5), dtype=bool) + a[0, 0] = True + b[0, 3] = True # distance 3 along x + + d_std = ev.compute_hausdorff_distance( + a, b, voxel_size=(1.0, 1.0), max_distance=np.inf, method="standard" + ) + assert np.isclose(d_std, 3.0) + + d_mod = ev.compute_hausdorff_distance( + a, b, voxel_size=(1.0, 1.0), max_distance=np.inf, method="modified" + ) + # only one distance each direction -> mean == max == 3 + assert np.isclose(d_mod, 3.0) + + +def test_compute_hausdorff_distance_percentile(): + a = np.zeros((1, 5), dtype=bool) + b = np.zeros((1, 5), dtype=bool) + # A has foreground at 0,1; B has at 3,4. + a[0, 0] = True + a[0, 1] = True + b[0, 3] = True + b[0, 4] = True + # Distances from each point in A to B: + # A(0)->3 = 3, A(1)->3 = 2 (closest). + # Similarly B->A: 3->1=2, 4->1=3. + # So forward distances [3,2], backward [2,3]. + # 50th percentile is ~2.5 each side. 
+ d_p50 = ev.compute_hausdorff_distance( + a, + b, + voxel_size=(1.0, 1.0), + max_distance=np.inf, + method="percentile", + percentile=50, + ) + assert np.isclose(d_p50, 2.5, atol=1e-6) + + +def test_compute_hausdorff_distance_empty_sets(): + max_distance = 5.0 + a = np.zeros((4, 4), dtype=bool) + b = np.zeros_like(a) + d = ev.compute_hausdorff_distance( + a, b, voxel_size=(1.0, 1.0), max_distance=max_distance + ) + assert np.isclose(d, 0.0) + + a[0, 0] = True + d2 = ev.compute_hausdorff_distance( + a, b, voxel_size=(1.0, 1.0), max_distance=max_distance + ) + assert np.isclose(d2, max_distance) + + +def test_optimized_hausdorff_distances_per_instance(): + # two instances, perfectly matched + truth = np.array([[0, 1, 1], [0, 2, 2]], dtype=np.int32) + pred = truth.copy() + voxel_size = (1.0, 1.0) + + dists = ev.optimized_hausdorff_distances( + truth, pred, voxel_size, hausdorff_distance_max=np.inf, method="standard" + ) + # there are two non-zero ids + ids = unique(truth) + ids = ids[ids != 0] + assert dists.shape == (ids.size,) + assert np.allclose(dists, 0.0) + + +# ------------------------ +# score_instance tests +# ------------------------ + + +def test_score_instance_perfect_match(): + label = np.array([[0, 1, 1], [0, 2, 2]], dtype=np.int32) + scores = ev.score_instance(label, label, voxel_size=(1.0, 1.0)) + + assert np.isclose(scores["accuracy"], 1.0) + assert np.isclose(scores["hausdorff_distance"], 0.0) + assert np.isclose(scores["normalized_hausdorff_distance"], 1.0) + assert np.isclose(scores["combined_score"], 1.0) + + +def test_score_instance_simple_shift(): + # GT has one instance [0,0] and [0,1] + truth = np.array([[1, 1, 0], [0, 0, 0]], dtype=np.int32) + # Prediction shifted one voxel to the right: [0,1],[0,2] (and needs renumbering) + pred = np.array([[0, 2, 2], [0, 0, 0]], dtype=np.int32) + voxel_size = (1.0, 1.0) + scores = ev.score_instance(pred, truth, voxel_size) + + # Accuracy should not be 1 but positive + assert 0.0 < scores["accuracy"] < 1.0 
+ # Hausdorff distance is 1 (each point moves 1 voxel) + assert np.isclose(scores["hausdorff_distance"], 1.0) + + +# ------------------------ +# score_semantic tests +# ------------------------ + + +def test_score_semantic_perfect_match(): + truth = np.array([[0, 1], [1, 1]], dtype=float) + pred = truth.copy() + scores = ev.score_semantic(pred, truth) + assert np.isclose(scores["iou"], 1.0) + assert np.isclose(scores["dice_score"], 1.0) + + +def test_score_semantic_partial_overlap(): + truth = np.array([[0, 1], [1, 1]], dtype=float) + # prediction misses one positive voxel + pred = np.array([[0, 1], [0, 1]], dtype=float) + + scores = ev.score_semantic(pred, truth) + + # manual IoU: TP = 2, FP = 0, FN = 1 -> IoU = 2 / (2+0+1) = 2/3 + assert np.isclose(scores["iou"], 2 / 3) + # manual Dice: 2TP / (2TP + FP + FN) = 4 / (4 + 0 + 1) = 0.8 + assert np.isclose(scores["dice_score"], 0.8) + + +def test_score_semantic_no_foreground(): + truth = np.zeros((3, 3), dtype=float) + pred = np.zeros_like(truth) + scores = ev.score_semantic(pred, truth) + assert np.isclose(scores["iou"], 1.0) + assert np.isclose(scores["dice_score"], 1.0) + + +# ------------------------ +# resize_array tests +# ------------------------ + + +def test_resize_array_pad_and_crop(): + arr = np.ones((2, 2), dtype=np.int32) + # First confirm padding up to (4,4) + padded = ev.resize_array(arr, (4, 4), pad_value=0) + assert padded.shape == (4, 4) + # original ones centered + assert np.all(padded[1:3, 1:3] == 1) + + # Crop down from (4,4) to (2,2) (should give central 2x2) + cropped = ev.resize_array(padded, (2, 2), pad_value=0) + assert cropped.shape == (2, 2) + assert np.all(cropped == 1) + + +def test_resize_array_only_crop(): + arr = np.arange(16).reshape(4, 4) + target = (2, 2) + out = ev.resize_array(arr, target) + + assert out.shape == target + # Center crop: indices [1:3, 1:3] + expected = arr[1:3, 1:3] + assert np.array_equal(out, expected) + + +# ------------------------ +# match_crop_space tests 
(single-scale) +# ------------------------ + + +def test_match_crop_space_no_rescale_no_translation(tmp_path): + """Single-scale array with matching voxel size and translation=0: should just crop/pad.""" + arr = np.arange(16, dtype=np.uint8).reshape(4, 4) + root = tmp_path / "vol.zarr" + ds = zarr.open(str(root), mode="w", shape=arr.shape, dtype=arr.dtype) + ds[:] = arr + ds.attrs["voxel_size"] = (1.0, 1.0) + ds.attrs["translation"] = (0.0, 0.0) + + # same voxel size, same shape + out = ev.match_crop_space( + str(root), + "sem_label", + voxel_size=(1.0, 1.0), + shape=(4, 4), + translation=(0.0, 0.0), + ) + assert np.array_equal(out, arr) + + # same voxel size, smaller shape -> centered crop + out2 = ev.match_crop_space( + str(root), + "sem_label", + voxel_size=(1.0, 1.0), + shape=(2, 2), + translation=(1.0, 1.0), + ) + assert out2.shape == (2, 2) + assert np.array_equal(out2, arr[1:3, 1:3]) + + +def test_match_crop_space_rescale_instance(tmp_path): + """ + Simple rescale case for instance label: + input voxel_size=(2,2), target=(1,1). We expect a 2x upsample along each axis. 
+ """ + arr = np.zeros((2, 2), dtype=np.uint8) + arr[0, 0] = 1 + root = tmp_path / "inst.zarr" + ds = zarr.open(str(root), mode="w", shape=arr.shape, dtype=arr.dtype) + ds[:] = arr + ds.attrs["voxel_size"] = (2.0, 2.0) + ds.attrs["translation"] = (0.0, 0.0) + + out_shape = (4, 4) + out = ev.match_crop_space( + str(root), + "instance", + voxel_size=(1.0, 1.0), + shape=out_shape, + translation=(0.0, 0.0), + ) + assert out.shape == out_shape + # nearest neighbor upsampling should preserve instance id in top-left 2x2 region + assert np.all(out[0:2, 0:2] == 1) + + +# ------------------------ +# empty_label_score & missing_volume_score +# ------------------------ + + +def _create_simple_volume( + path: Path, + crop_name: str, + label_name: str, + arr: np.ndarray, + voxel_size=(1.0, 1.0, 1.0), +): + ds = zarr.open( + str(path / crop_name), + path=label_name, + mode="w", + shape=arr.shape, + dtype=arr.dtype, + ) + ds[:] = arr + ds.attrs["voxel_size"] = voxel_size + + +def test_empty_label_score_instance(tmp_path): + truth_root = tmp_path / "truth.zarr" + arr = np.zeros((2, 2, 2), dtype=np.uint8) + _create_simple_volume(truth_root, "crop1", "instance", arr) + + scores = ev.empty_label_score( + label="instance", + crop_name="crop1", + instance_classes=["instance"], + truth_path=truth_root.as_posix(), + ) + # num_voxels should match volume size + assert scores["num_voxels"] == arr.size + assert scores["is_missing"] is True + assert scores["accuracy"] == 0 + + +def test_missing_volume_score_mixed_labels(tmp_path): + truth_root = tmp_path / "truth_volume.zarr" + arr_inst = np.zeros((2, 2, 2), dtype=np.uint8) + arr_sem = np.zeros((2, 2, 2), dtype=np.uint8) + + _create_simple_volume(truth_root, "crop1", "instance", arr_inst) + _create_simple_volume(truth_root, "crop1", "sem", arr_sem) + + scores = ev.missing_volume_score( + truth_volume_path=(truth_root / "crop1").as_posix(), + instance_classes=["instance"], + ) + + assert set(scores.keys()) == {"instance", "sem"} + assert 
scores["instance"]["is_missing"] is True + assert scores["sem"]["is_missing"] is True + assert scores["instance"]["accuracy"] == 0.0 + assert scores["sem"]["iou"] == 0.0 + + +# ------------------------ +# combine_scores tests +# ------------------------ + + +def test_combine_scores_instance_and_semantic(): + # Two volumes, one instance, one semantic + scores = { + "crop1": { + "instance": { + "accuracy": 1.0, + "hausdorff_distance": 0.0, + "normalized_hausdorff_distance": 1.0, + "combined_score": 1.0, + "num_voxels": 8, + "voxel_size": (1.0, 1.0, 1.0), + "is_missing": False, + } + }, + "crop2": { + "sem": { + "iou": 0.5, + "dice_score": 2 / 3, + "num_voxels": 8, + "voxel_size": (1.0, 1.0, 1.0), + "is_missing": False, + } + }, + } + + combined = ev.combine_scores( + scores, include_missing=True, instance_classes=["instance"] + ) + + ls = combined["label_scores"] + assert np.isclose(ls["instance"]["combined_score"], 1.0) + assert np.isclose(ls["sem"]["iou"], 0.5) + assert "overall_instance_score" in combined + assert "overall_semantic_score" in combined + assert "overall_score" in combined + # only one of each type -> overall scores are just these + assert np.isclose(combined["overall_instance_score"], 1.0) + assert np.isclose(combined["overall_semantic_score"], 0.5) + + +# ------------------------ +# score_label & score_submission-style integration +# ------------------------ + + +def test_score_label_instance_integration(monkeypatch, tmp_path): + """ + Small integration test: zarr truth + pred, dummy TEST_CROPS_DICT entry, + score_label should return instance metrics consistent with score_instance. 
+ """ + # Arrange mini truth volume + crop_name = "crop1" + label_name = "instance" + truth_root = tmp_path / "truth.zarr" + + arr3d = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]], dtype=np.uint8) + _create_simple_volume(truth_root, crop_name, label_name, arr3d) + + # matching pred volume with same data + pred_root = tmp_path / "pred.zarr" + _create_simple_volume(pred_root, crop_name, label_name, arr3d) + + # monkeypatch TEST_CROPS_DICT for this one crop/label + dummy_crop = DummyCrop( + voxel_size=(1.0, 1.0, 1.0), shape=arr3d.shape, translation=(0.0, 0.0, 0.0) + ) + monkeypatch.setattr( + ev, + "TEST_CROPS_DICT", + {(1, label_name): dummy_crop}, + raising=False, + ) + + crop_name_str = crop_name # "crop1" + pred_label_path = ev.UPath(pred_root.as_posix()) / crop_name_str / label_name + + crop_out, label_out, results = ev.score_label( + pred_label_path=pred_label_path, + label_name=label_name, + crop_name=crop_name_str, + truth_path=ev.UPath(truth_root.as_posix()), + instance_classes=["instance"], + ) + + assert crop_out == crop_name_str + assert label_out == label_name + assert np.isclose(results["accuracy"], 1.0) + assert np.isclose(results["hausdorff_distance"], 0.0) + assert results["is_missing"] is False + + +def test_score_submission_debug_serial(monkeypatch, tmp_path): + """ + End-to-end-ish test: + - create a truth volume with a single crop and label + - create matching prediction volume + - zip it + - run score_submission in DEBUG mode (serial scoring) + """ + # ensure DEBUG=True so score_submission uses serial path + monkeypatch.setattr(ev, "DEBUG", True) + + # Create truth.zarr/crop1/instance + truth_root = tmp_path / "truth.zarr" + crop_name = "crop1" + label_name = "instance" + + arr = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]], dtype=np.uint8) + _create_simple_volume(truth_root, crop_name, label_name, arr) + + # Create submission zarr directory structure: submission_root/crop1/instance + submission_root = tmp_path / "submission.zarr" + 
_create_simple_volume(submission_root, crop_name, label_name, arr) + + # Patch TEST_CROPS_DICT + dummy_crop = DummyCrop( + voxel_size=(1.0, 1.0, 1.0), shape=arr.shape, translation=(0.0, 0.0, 0.0) + ) + monkeypatch.setattr( + ev, + "TEST_CROPS_DICT", + {(1, label_name): dummy_crop}, + raising=False, + ) + + # Zip the submission_root contents so that unzip_file will create + # a directory with crop1 directly inside. + zip_path = zip_submission(submission_root) + + # Run score_submission with explicit truth_path and instance_classes + scores = ev.score_submission( + submission_path=zip_path.as_posix(), + result_file=None, + truth_path=truth_root.as_posix(), + instance_classes=[label_name], + ) + + # We expect perfect instance score + assert np.isclose(scores["overall_instance_score"], 1.0) + # No semantic labels -> overall_semantic_score is nan, but that’s fine; + # just ensure label_scores present and correct. + assert "label_scores" in scores + assert np.isclose(scores["label_scores"][label_name]["accuracy"], 1.0) diff --git a/tests/test_iou_matrix.py b/tests/test_iou_matrix.py new file mode 100644 index 0000000..af62e7f --- /dev/null +++ b/tests/test_iou_matrix.py @@ -0,0 +1,164 @@ +import os +import numpy as np +from cellmap_segmentation_challenge.evaluate import iou_matrix + + +def test_basic_iou_small_grid(monkeypatch): + """ + gt IDs: 1..2 ; pred IDs: 1..2 ; 0 is background + Expected IoU: + [[0.5, 0.0], + [0.0, 0.75]] + """ + gt = np.array( + [ + [0, 1, 1], + [0, 0, 2], + [0, 2, 2], + ], + dtype=np.int32, + ) + + pred = np.array( + [ + [0, 1, 0], + [0, 2, 2], + [0, 2, 2], + ], + dtype=np.int32, + ) + + # Make the cutoff permissive so it doesn't trip in this test + monkeypatch.setenv("INSTANCE_RATIO_CUTOFF", "999") + + iou = iou_matrix(gt, pred) + assert iou.shape == (2, 2) + np.testing.assert_allclose( + iou, np.array([[0.5, 0.0], [0.0, 0.75]], dtype=np.float32), rtol=1e-6, atol=1e-6 + ) + assert iou.dtype == np.float32 + + +def test_empty_pred(): + gt = 
np.array( + [ + [0, 1, 1], + [0, 0, 2], + [0, 2, 2], + ], + dtype=np.int32, + ) + pred = np.zeros_like(gt) + iou = iou_matrix(gt, pred) + # nG=2, nP=0 → (2,0) array of zeros + assert iou.shape == (2, 0) + assert iou.size == 0 + + +def test_empty_gt(): + gt = np.zeros((3, 3), dtype=np.int32) + pred = np.array( + [ + [0, 1, 0], + [0, 2, 2], + [0, 2, 2], + ], + dtype=np.int32, + ) + iou = iou_matrix(gt, pred) + # nG=0, nP=2 → (0,2) array of zeros + assert iou.shape == (0, 2) + assert iou.size == 0 + + +def test_rectangle_more_preds(monkeypatch): + """ + Non-square case: nG=2, nP=3; verify shape and a couple of values. + """ + gt = np.array( + [ + [0, 1, 1], + [0, 0, 2], + [0, 2, 2], + ], + dtype=np.int32, + ) + pred = np.array( + [ + [0, 1, 3], + [0, 2, 3], + [0, 2, 2], + ], + dtype=np.int32, + ) + + # Ensure cutoff won't trigger + monkeypatch.setenv("INSTANCE_RATIO_CUTOFF", "999") + iou = iou_matrix(gt, pred) + assert iou.shape == (2, 3) + + # Quick sanity checks: + # gt1∩pred1 = 1 pixel, |gt1|=2, |pred1|=1 → IoU = 1/(2+1-1)=0.5 + # gt2∩pred2 >= 2, |gt2|=3, |pred2|>=2 → IoU should be > 0 + assert abs(float(iou[0, 0]) - 0.5) < 1e-6 + assert iou[1, 1] > 0.0 + + +def test_ratio_cutoff_triggers(monkeypatch): + """ + Force the INSTANCE_RATIO_CUTOFF branch to return None. + """ + gt = np.array( + [ + [1, 1, 0], + [2, 2, 0], + [0, 0, 0], + ], + dtype=np.int32, + ) # nG = 2 + pred = np.array( + [ + [1, 2, 3], + [4, 5, 0], + [0, 0, 0], + ], + dtype=np.int32, + ) # nP = 5 + + # Set cutoff small so nP/nG = 2.5 exceeds it + monkeypatch.setenv("INSTANCE_RATIO_CUTOFF", "0.1") + + out = iou_matrix(gt, pred) + assert out is None + + +def test_matches_naive_reference(): + """ + Cross-check against a naive python/dict reference on a tiny random mask. 
+ """ + rng = np.random.default_rng(0) + H, W = 6, 7 + nG, nP = 3, 4 + gt = rng.integers(0, nG + 1, size=(H, W), dtype=np.int32) + pred = rng.integers(0, nP + 1, size=(H, W), dtype=np.int32) + + # Fast impl + fast = iou_matrix(gt, pred) + + # Naive reference + gt_counts = {k: int((gt == k).sum()) for k in range(1, nG + 1)} + pr_counts = {k: int((pred == k).sum()) for k in range(1, nP + 1)} + inter = np.zeros((nG, nP), dtype=np.int64) + for y in range(H): + for x in range(W): + g = int(gt[y, x]) + p = int(pred[y, x]) + if g > 0 and p > 0: + inter[g - 1, p - 1] += 1 + gt_sizes = np.array([gt_counts[i] for i in range(1, nG + 1)])[:, None] + pr_sizes = np.array([pr_counts[j] for j in range(1, nP + 1)])[None, :] + union = gt_sizes + pr_sizes - inter + with np.errstate(divide="ignore", invalid="ignore"): + ref = np.where(union > 0, inter / union, 0.0).astype(np.float32) + + np.testing.assert_allclose(fast, ref, rtol=1e-6, atol=1e-6) diff --git a/tests/test_loss.py b/tests/test_loss.py new file mode 100644 index 0000000..51902bf --- /dev/null +++ b/tests/test_loss.py @@ -0,0 +1,129 @@ +"""Unit tests for loss functions in cellmap_segmentation_challenge.utils.loss""" + +import torch + +from cellmap_segmentation_challenge.utils.loss import CellMapLossWrapper + + +class TestCellMapLossWrapper: + """Tests for CellMapLossWrapper class""" + + def test_init_with_mse_loss(self): + """Test initialization with MSE loss""" + loss_wrapper = CellMapLossWrapper(torch.nn.MSELoss) + assert isinstance(loss_wrapper.loss_fn, torch.nn.MSELoss) + assert loss_wrapper.kwargs["reduction"] == "none" + + def test_init_with_bce_loss(self): + """Test initialization with BCE loss""" + loss_wrapper = CellMapLossWrapper(torch.nn.BCELoss) + assert isinstance(loss_wrapper.loss_fn, torch.nn.BCELoss) + + def test_calc_loss_no_nans(self): + """Test calc_loss with no NaN values""" + loss_wrapper = CellMapLossWrapper(torch.nn.MSELoss) + outputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + targets = 
torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + + loss = loss_wrapper.calc_loss(outputs, targets) + + # With identical outputs and targets, MSE should be 0 + assert torch.allclose(loss, torch.tensor(0.0)) + + def test_calc_loss_with_nans(self): + """Test calc_loss with NaN values in targets""" + loss_wrapper = CellMapLossWrapper(torch.nn.MSELoss) + outputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + targets = torch.tensor([[1.0, float("nan")], [3.0, 4.0]]) + + loss = loss_wrapper.calc_loss(outputs, targets) + + # Loss should be computed only for non-NaN values + # Expected loss = ((1-1)^2 + (3-3)^2 + (4-4)^2) / 3 = 0 + assert torch.allclose(loss, torch.tensor(0.0)) + + def test_calc_loss_all_nans(self): + """Test calc_loss when all targets are NaN""" + loss_wrapper = CellMapLossWrapper(torch.nn.MSELoss) + outputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + targets = torch.tensor( + [[float("nan"), float("nan")], [float("nan"), float("nan")]] + ) + + loss = loss_wrapper.calc_loss(outputs, targets) + + # When all targets are NaN, the loss is 0 (no valid pixels to compute loss on) + assert torch.allclose(loss, torch.tensor(0.0)) or torch.isnan(loss) + + def test_forward_tensor_inputs(self): + """Test forward with tensor inputs""" + loss_wrapper = CellMapLossWrapper(torch.nn.MSELoss) + outputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + targets = torch.tensor([[1.5, 2.5], [3.5, 4.5]]) + + loss = loss_wrapper.forward(outputs, targets) + + # MSE = ((1-1.5)^2 + (2-2.5)^2 + (3-3.5)^2 + (4-4.5)^2) / 4 = 0.25 + assert torch.allclose(loss, torch.tensor(0.25), atol=1e-6) + + def test_forward_dict_inputs_matching_dicts(self): + """Test forward with matching dictionary inputs""" + loss_wrapper = CellMapLossWrapper(torch.nn.MSELoss) + outputs = { + "class1": torch.tensor([[1.0, 2.0]]), + "class2": torch.tensor([[3.0, 4.0]]), + } + targets = { + "class1": torch.tensor([[1.0, 2.0]]), + "class2": torch.tensor([[3.0, 4.0]]), + } + + loss = loss_wrapper.forward(outputs, targets) + + # Perfect 
match, loss should be 0 + assert torch.allclose(loss, torch.tensor(0.0)) + + def test_forward_dict_inputs_with_nans(self): + """Test forward with dictionary inputs containing NaN values""" + loss_wrapper = CellMapLossWrapper(torch.nn.MSELoss) + outputs = { + "class1": torch.tensor([[1.0, 2.0]]), + "class2": torch.tensor([[3.0, 4.0]]), + } + targets = { + "class1": torch.tensor([[1.0, float("nan")]]), + "class2": torch.tensor([[3.0, 4.0]]), + } + + loss = loss_wrapper.forward(outputs, targets) + + # Loss from class1: only first element counted (0) + # Loss from class2: both elements (0) + # Average of 0 and 0 = 0 + assert torch.allclose(loss, torch.tensor(0.0)) + + def test_forward_dict_targets_list_outputs(self): + """Test forward with dict targets and list/tuple outputs""" + loss_wrapper = CellMapLossWrapper(torch.nn.MSELoss) + outputs = [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0]])] + targets = { + "class1": torch.tensor([[1.0, 2.0]]), + "class2": torch.tensor([[3.0, 4.0]]), + } + + loss = loss_wrapper.forward(outputs, targets) + + # Perfect match, loss should be 0 + assert torch.allclose(loss, torch.tensor(0.0)) + + def test_bce_loss_with_nans(self): + """Test with BCE loss and NaN values""" + loss_wrapper = CellMapLossWrapper(torch.nn.BCELoss) + outputs = torch.tensor([[0.5, 0.8], [0.3, 0.9]]) + targets = torch.tensor([[1.0, float("nan")], [0.0, 1.0]]) + + loss = loss_wrapper.calc_loss(outputs, targets) + + # Loss should be computed only for non-NaN values + assert not torch.isnan(loss) + assert loss >= 0 # BCE loss is always non-negative diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..69c36c2 --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,169 @@ +"""Unit tests for security functions in cellmap_segmentation_challenge.utils.security""" + +import tempfile +import os + +from cellmap_segmentation_challenge.utils.security import ( + analyze_script, + Config, +) + + +class TestAnalyzeScript: + """Tests 
for analyze_script function""" + + def test_analyze_safe_script(self): + """Test analysis of a safe script""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import numpy as np +import torch + +def safe_function(): + x = np.array([1, 2, 3]) + return x * 2 +""" + ) + f.flush() + try: + is_safe, issues = analyze_script(f.name) + assert is_safe is True + assert len(issues) == 0 + finally: + os.unlink(f.name) + + def test_analyze_script_with_disallowed_import(self): + """Test detection of disallowed imports""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import os +import sys + +def unsafe_function(): + os.system("echo test") +""" + ) + f.flush() + try: + is_safe, issues = analyze_script(f.name) + assert is_safe is False + assert len(issues) > 0 + assert any("os" in issue for issue in issues) + finally: + os.unlink(f.name) + + def test_analyze_script_with_disallowed_function(self): + """Test detection of disallowed function calls""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +def unsafe_function(): + exec("print('hello')") + compile("x = 1", "", "exec") +""" + ) + f.flush() + try: + is_safe, issues = analyze_script(f.name) + assert is_safe is False + assert len(issues) > 0 + assert any("exec" in issue.lower() for issue in issues) + finally: + os.unlink(f.name) + + def test_analyze_script_with_from_import(self): + """Test detection of disallowed from imports""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from subprocess import run + +def unsafe_function(): + run(["echo", "test"]) +""" + ) + f.flush() + try: + is_safe, issues = analyze_script(f.name) + assert is_safe is False + assert len(issues) > 0 + assert any("subprocess" in issue for issue in issues) + finally: + os.unlink(f.name) + + +class TestConfig: + """Tests for Config class""" + + def 
test_config_initialization(self): + """Test Config initialization with kwargs""" + config = Config(learning_rate=0.001, batch_size=32, epochs=10) + assert config.learning_rate == 0.001 + assert config.batch_size == 32 + assert config.epochs == 10 + + def test_config_to_dict(self): + """Test Config to_dict method""" + config = Config(learning_rate=0.001, batch_size=32) + config_dict = config.to_dict() + assert config_dict == {"learning_rate": 0.001, "batch_size": 32} + + def test_config_get_existing_key(self): + """Test Config get method with existing key""" + config = Config(learning_rate=0.001, batch_size=32) + assert config.get("learning_rate") == 0.001 + assert config.get("batch_size") == 32 + + def test_config_get_missing_key_with_default(self): + """Test Config get method with missing key and default""" + config = Config(learning_rate=0.001) + assert config.get("batch_size", 16) == 16 + assert config.get("epochs", 10) == 10 + + def test_config_get_missing_key_no_default(self): + """Test Config get method with missing key and no default""" + config = Config(learning_rate=0.001) + assert config.get("batch_size") is None + + def test_config_serialize_simple_types(self): + """Test Config serialize with simple data types""" + config = Config( + learning_rate=0.001, batch_size=32, model_name="unet", use_cuda=True + ) + serialized = config.serialize() + assert serialized["learning_rate"] == 0.001 + assert serialized["batch_size"] == 32 + assert serialized["model_name"] == "unet" + assert serialized["use_cuda"] is True + + def test_config_serialize_skips_modules_and_functions(self): + """Test that serialize skips modules, classes, and functions""" + import torch + + def my_function(): + pass + + config = Config(learning_rate=0.001, module=torch, function=my_function) + serialized = config.serialize() + + # Should only include simple types + assert "learning_rate" in serialized + assert "module" not in serialized + assert "function" not in serialized + + def 
test_config_serialize_skips_private_attributes(self):
+        """Test that serialize skips private attributes (with __)"""
+        config = Config(learning_rate=0.001, __private_value=42)
+        serialized = config.serialize()
+        assert "learning_rate" in serialized
+        assert "__private_value" not in serialized
+
+    def test_config_serialize_converts_complex_types_to_string(self):
+        """Test that serialize converts complex types to strings"""
+        config = Config(learning_rate=0.001, shape=(64, 64, 64), data=[1, 2, 3])
+        serialized = config.serialize()
+        assert serialized["learning_rate"] == 0.001
+        assert isinstance(serialized["shape"], str)
+        assert isinstance(serialized["data"], str)
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..967239a
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,182 @@
+"""Unit tests for utility functions in cellmap_segmentation_challenge.utils"""
+
+import pytest
+import numpy as np
+import tempfile
+import os
+
+from cellmap_segmentation_challenge.utils.utils import (
+    format_coordinates,
+    format_string,
+    download_file,
+    simulate_predictions_iou_binary,
+    simulate_predictions_accuracy,
+)
+
+
+class TestFormatCoordinates:
+    """Tests for format_coordinates function"""
+
+    def test_format_coordinates_simple_list(self):
+        """Test formatting a simple list of coordinates"""
+        coordinates = [1, 2, 3]
+        result = format_coordinates(coordinates)
+        assert result == "[1;2;3]"
+
+    def test_format_coordinates_float_list(self):
+        """Test formatting coordinates with float values"""
+        coordinates = [1.5, 2.5, 3.5]
+        result = format_coordinates(coordinates)
+        assert result == "[1.5;2.5;3.5]"
+
+    def test_format_coordinates_single_value(self):
+        """Test formatting a single coordinate"""
+        coordinates = [42]
+        result = format_coordinates(coordinates)
+        assert result == "[42]"
+
+    def test_format_coordinates_empty_list(self):
+        """Test formatting an empty list"""
+        coordinates = []
+        result = format_coordinates(coordinates)
+        assert result == "[]"
+
+
+class TestFormatString:
+    """Tests for format_string function"""
+
+    def test_format_string_all_keys_present(self):
+        """Test formatting when all keys are present in the string"""
+        string = "Hello {name}, you are {age} years old"
+        format_kwargs = {"name": "Alice", "age": 30}
+        result = format_string(string, format_kwargs)
+        assert result == "Hello Alice, you are 30 years old"
+
+    def test_format_string_partial_keys(self):
+        """Test formatting when only some keys are present"""
+        string = "Hello {name}"
+        format_kwargs = {"name": "Bob", "age": 25}
+        result = format_string(string, format_kwargs)
+        assert result == "Hello Bob"
+
+    def test_format_string_no_keys_in_string(self):
+        """Test when no format keys are in the string"""
+        string = "Hello World"
+        format_kwargs = {"name": "Charlie"}
+        result = format_string(string, format_kwargs)
+        assert result == "Hello World"
+
+    def test_format_string_missing_required_key(self):
+        """Test when a required key is missing from format_kwargs"""
+        string = "Hello {name}"
+        format_kwargs = {"age": 30}
+        result = format_string(string, format_kwargs)
+        # Should return the original string with placeholders preserved
+        assert result == "Hello {name}"
+
+
+class TestDownloadFile:
+    """Tests for download_file function"""
+
+    def test_download_file_success(self):
+        """Test successful file download from a real URL"""
+        # Use a small, reliable test file from the repository itself
+        url = "https://raw.githubusercontent.com/janelia-cellmap/cellmap-segmentation-challenge/refs/heads/main/LICENSE"
+
+        with tempfile.NamedTemporaryFile(mode='wb', delete=False) as f:
+            dest_path = f.name
+
+        try:
+            download_file(url, dest_path)
+
+            # Verify the file was downloaded and has content
+            assert os.path.exists(dest_path)
+            with open(dest_path, 'r', encoding='utf-8') as fh:
+                content = fh.read()
+            assert len(content) > 0
+            # LICENSE text should mention MIT or at least the word "License"
+            assert "MIT" in content or "License" in content
+        finally:
+            if os.path.exists(dest_path):
+                os.unlink(dest_path)
+
+    def test_download_file_invalid_url(self):
+        """Test file download with invalid URL"""
+        import requests
+
+        with tempfile.NamedTemporaryFile(mode='wb', delete=False) as f:
+            dest_path = f.name
+
+        try:
+            with pytest.raises(requests.exceptions.RequestException):
+                download_file("http://invalid-url-that-does-not-exist-12345.com/file.txt", dest_path)
+        finally:
+            if os.path.exists(dest_path):
+                os.unlink(dest_path)
+
+
+class TestSimulatePredictionsIouBinary:
+    """Tests for simulate_predictions_iou_binary function"""
+
+    def test_simulate_predictions_perfect_iou(self):
+        """Test with perfect IOU (1.0)"""
+        labels = np.array([[[1, 1], [0, 0]], [[1, 0], [1, 1]]])
+        result = simulate_predictions_iou_binary(labels, 1.0)
+
+        # With IOU 1.0, all positive labels should remain positive
+        assert result.shape == labels.shape
+        assert np.all((result > 0) == (labels > 0))
+
+    def test_simulate_predictions_zero_iou(self):
+        """Test with zero IOU"""
+        labels = np.array([[[1, 1], [0, 0]], [[1, 0], [1, 1]]])
+        result = simulate_predictions_iou_binary(labels, 0.0)
+
+        # With IOU 0.0, all positive labels should become 0
+        assert result.shape == labels.shape
+        assert np.all(result == 0)
+
+    def test_simulate_predictions_partial_iou(self):
+        """Test with partial IOU (0.5)"""
+        np.random.seed(42)  # For reproducibility
+        labels = np.ones((10, 10, 10))
+        result = simulate_predictions_iou_binary(labels, 0.5)
+
+        assert result.shape == labels.shape
+        # Approximately half should remain positive
+        positive_ratio = np.sum(result > 0) / result.size
+        assert 0.3 < positive_ratio < 0.7  # Allow some variance
+
+
+class TestSimulatePredictionsAccuracy:
+    """Tests for simulate_predictions_accuracy function"""
+
+    def test_simulate_predictions_perfect_accuracy(self):
+        """Test with perfect accuracy (1.0)"""
+        np.random.seed(42)
+        true_labels = np.array([[[1, 1], [0, 0]], [[1, 0], [1, 1]]])
+        result = simulate_predictions_accuracy(true_labels, 1.0)
+
+        # With accuracy 1.0, result should match input (after relabeling)
+        assert result.shape == true_labels.shape
+        # All pixels should be classified correctly (binary)
+        assert np.sum((result > 0) == (true_labels > 0)) == result.size
+
+    def test_simulate_predictions_zero_accuracy(self):
+        """Test with zero accuracy"""
+        np.random.seed(42)
+        true_labels = np.ones((5, 5, 5), dtype=int)
+        result = simulate_predictions_accuracy(true_labels, 0.0)
+
+        # With accuracy 0.0 every voxel is misclassified, so no foreground remains
+        assert result.shape == true_labels.shape and np.all(result == 0)
+
+    def test_simulate_predictions_partial_accuracy(self):
+        """Test with partial accuracy (0.8)"""
+        np.random.seed(42)
+        true_labels = np.random.randint(0, 2, size=(10, 10, 10))
+        result = simulate_predictions_accuracy(true_labels, 0.8)
+
+        assert result.shape == true_labels.shape
+        # Background plus at least one labeled instance should be present
+        assert len(np.unique(result)) >= 2
diff --git a/tests/train_config.py b/tests/train_config.py
index 3d627b3..881f131 100644
--- a/tests/train_config.py
+++ b/tests/train_config.py
@@ -2,21 +2,21 @@
 from upath import UPath
 
 import torch
-from cellmap_segmentation_challenge.models import ResNet, UNet_2D
+from cellmap_segmentation_challenge.models import UNet_2D
 
 # %% Set hyperparameters and other configurations
 learning_rate = 0.0001  # learning rate for the optimizer
 batch_size = 1  # batch size for the dataloader
 input_array_info = {
-    "shape": (1, 64, 64),
-    "scale": (8, 8, 8),
+    "shape": (64, 64),
+    "scale": (8, 8),
 }  # shape and voxel size of the data to load for the input
 target_array_info = {
-    "shape": (1, 64, 64),
-    "scale": (8, 8, 8),
+    "shape": (64, 64),
+    "scale": (8, 8),
 }  # shape and voxel size of the data to load for the target
 epochs = 1  # number of epochs to train the model for
-iterations_per_epoch = 3  # number of iterations per epoch
+iterations_per_epoch = 2  # number of iterations per epoch
 random_seed = 42  # random seed for reproducibility
 classes = ["mito", "er"]  # list of classes to segment