Skip to content

Commit c3d5345

Browse files
committed
cellpose: fix distributed resume + output skip logic
- Mount-path alias matching: source paths from saved run_signatures are now compared by basename fallback (_source_signature_matches) so that /media/... and /run/media/... mount prefixes don't prevent resume. _run_signatures_compatible strips the source key before comparing the rest of the signature, then delegates source comparison to the same basename-fallback matcher. - Auto-zarr cache key stabilisation: simplified cache key to (source basename + channel) only; bumped auto_zarr_cache_version to 3 so stale wide-keyed caches are ignored on first run. - OME-Zarr group root in legacy auto-zarr reuse probe: the root of an OME-Zarr group has no .shape; code now opens root and accesses ['s0'] before reading shape/dtype, so the probe no longer crashes silently. - Legacy hashed output detection: _legacy_output_candidates() globs {source_base}*{suffix}{ext} to find outputs written under the old hashed-filename scheme. Early-skip checks in both the non-interleaved and interleaved paths now cover these legacy names, preventing redundant re-processing when valid output already exists. - ConvPaint model path normalisation from saved settings: _normalize_loaded_path() tries the saved path as-is, then retries with /media/ <-> /run/media/ prefix swap. Applied when restoring convpaint_model_path from run_settings.json so that a path written under one mount alias resolves correctly under the other.
1 parent 290f193 commit c3d5345

1 file changed

Lines changed: 243 additions & 38 deletions

File tree

src/napari_tmidas/processing_functions/cellpose_segmentation.py

Lines changed: 243 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -367,59 +367,191 @@ def cellpose_segmentation(
367367
f"source_path={_source_filepath}"
368368
)
369369

370+
original_source_filepath = _source_filepath
371+
372+
def _source_signature_id(path: Union[str, None] = None) -> str:
373+
"""Return a stable source identifier robust to mount-path aliases."""
374+
value = _source_filepath if path is None else path
375+
if not value:
376+
return ""
377+
return os.path.basename(os.path.normpath(str(value)))
378+
379+
def _source_signature_matches(saved_source: Union[str, None]) -> bool:
380+
"""Match sources across absolute path, realpath, and basename fallback."""
381+
if not saved_source or not _source_filepath:
382+
return False
383+
384+
current = str(_source_filepath)
385+
saved = str(saved_source)
386+
387+
if current == saved:
388+
return True
389+
390+
try:
391+
if os.path.abspath(current) == os.path.abspath(saved):
392+
return True
393+
except Exception:
394+
pass
395+
396+
try:
397+
if os.path.realpath(current) == os.path.realpath(saved):
398+
return True
399+
except Exception:
400+
pass
401+
402+
return _source_signature_id(current) == _source_signature_id(saved)
403+
404+
def _run_signatures_compatible(
405+
existing_signature_json: str, current_signature_json: str
406+
) -> bool:
407+
"""Allow resume when only source path prefix differs across mounts."""
408+
if existing_signature_json == current_signature_json:
409+
return True
410+
411+
try:
412+
existing_signature = json.loads(existing_signature_json)
413+
current_signature = json.loads(current_signature_json)
414+
except Exception:
415+
return False
416+
417+
if not isinstance(existing_signature, dict) or not isinstance(
418+
current_signature, dict
419+
):
420+
return False
421+
422+
existing_source = existing_signature.pop("source", None)
423+
current_source = current_signature.pop("source", None)
424+
425+
if existing_signature != current_signature:
426+
return False
427+
428+
return _source_signature_matches(
429+
current_source if current_source is not None else existing_source
430+
)
431+
432+
def _normalize_loaded_path(path: str, fallback: str = "") -> str:
433+
"""Return path if it exists, else try swapping /media/ <-> /run/media/.
434+
435+
Handles the case where a saved run_settings.json was written with a
436+
different mount-path prefix (e.g. /media/... vs /run/media/...) than
437+
the one currently active. Falls back to *fallback* when neither
438+
variant resolves.
439+
"""
440+
if not path:
441+
return fallback
442+
if os.path.exists(path):
443+
return path
444+
swaps = [
445+
("/run/media/", "/media/"),
446+
("/media/", "/run/media/"),
447+
]
448+
for old_prefix, new_prefix in swaps:
449+
if path.startswith(old_prefix):
450+
candidate = new_prefix + path[len(old_prefix):]
451+
if os.path.exists(candidate):
452+
return candidate
453+
# Neither variant exists; return the saved path as-is (caller decides)
454+
return path
455+
370456
def _direct_output_path() -> Union[str, None]:
371-
if not (_output_folder and _output_suffix and _source_filepath):
457+
source_for_output = original_source_filepath or _source_filepath
458+
if not (_output_folder and _output_suffix and source_for_output):
372459
return None
373460

374461
source_base = os.path.splitext(
375-
os.path.basename(_source_filepath)
462+
os.path.basename(source_for_output)
376463
)[0]
377464
output_ext = ".zarr" if _output_format == "zarr" else ".tif"
378465
return os.path.join(
379466
_output_folder,
380467
f"{source_base}{_output_suffix}{output_ext}",
381468
)
382469

383-
def _write_interleaved_checkpoint_output(
384-
checkpoint: zarr.Array, checkpoint_path: str
385-
) -> Union[str, None]:
386-
output_path = _direct_output_path()
387-
if not output_path:
388-
return None
470+
def _legacy_output_candidates() -> list:
471+
"""Find legacy output names for the same original source basename."""
472+
source_for_output = original_source_filepath or _source_filepath
473+
if not (_output_folder and _output_suffix and source_for_output):
474+
return []
475+
476+
source_base = os.path.splitext(os.path.basename(source_for_output))[0]
477+
output_ext = ".zarr" if _output_format == "zarr" else ".tif"
478+
pattern = os.path.join(
479+
_output_folder,
480+
f"{source_base}*{_output_suffix}{output_ext}",
481+
)
482+
candidates = [
483+
p
484+
for p in glob.glob(pattern)
485+
if os.path.abspath(p) != os.path.abspath(_direct_output_path() or "")
486+
]
487+
return sorted(candidates, key=os.path.getmtime, reverse=True)
389488

390-
def _existing_tiff_is_valid(path: str) -> bool:
489+
def _existing_output_is_valid(
490+
path: str, expected_shape: Union[tuple, None] = None
491+
) -> bool:
492+
"""Check whether an existing output looks complete for safe reuse."""
493+
output_format = str(_output_format).lower()
494+
495+
if output_format in {"tif", "tiff"}:
496+
if not os.path.isfile(path):
497+
return False
391498
try:
392-
expected_shape = tuple(int(s) for s in checkpoint.shape)
393499
with tifffile.TiffFile(path) as tif:
394500
series = tif.series[0] if tif.series else None
395501
if series is None:
396502
return False
397503
actual_shape = tuple(int(s) for s in series.shape)
398504
actual_dtype = np.dtype(series.dtype)
399-
expected_dtype = np.dtype(np.uint32)
400-
return (
401-
actual_shape == expected_shape
402-
and actual_dtype == expected_dtype
403-
)
505+
if expected_shape is not None:
506+
expected_shape = tuple(int(s) for s in expected_shape)
507+
if actual_shape != expected_shape:
508+
return False
509+
return actual_dtype == np.dtype(np.uint32)
404510
except Exception:
405511
return False
406512

513+
if output_format == "zarr":
514+
# For zarr outputs, existence is a practical completion signal.
515+
return os.path.isdir(path)
516+
517+
return os.path.exists(path)
518+
519+
def _write_interleaved_checkpoint_output(
520+
checkpoint: zarr.Array, checkpoint_path: str
521+
) -> Union[str, None]:
522+
output_path = _direct_output_path()
523+
if not output_path:
524+
return None
525+
407526
if (
408527
skip_overwrite_existing_valid_output
409-
and _output_format.lower() in {"tif", "tiff"}
410-
and os.path.isfile(output_path)
411-
and _existing_tiff_is_valid(output_path)
528+
and _existing_output_is_valid(
529+
output_path,
530+
expected_shape=tuple(int(s) for s in checkpoint.shape),
531+
)
412532
):
413533
print(
414-
"Skipping output write: existing TIFF appears complete and valid -> "
534+
"Skipping output write: existing output appears complete and valid -> "
415535
f"{output_path}"
416536
)
417537
return output_path
418538

539+
if skip_overwrite_existing_valid_output:
540+
for legacy_output in _legacy_output_candidates():
541+
if _existing_output_is_valid(
542+
legacy_output,
543+
expected_shape=tuple(int(s) for s in checkpoint.shape),
544+
):
545+
print(
546+
"Skipping output write: existing legacy output appears "
547+
f"complete and valid -> {legacy_output}"
548+
)
549+
return legacy_output
550+
419551
os.makedirs(_output_folder, exist_ok=True)
420552
write_labels_with_source_metadata(
421553
labels=checkpoint,
422-
source_path=_source_filepath,
554+
source_path=original_source_filepath or _source_filepath,
423555
output_path=output_path,
424556
output_format=_output_format,
425557
dim_order=dim_order,
@@ -456,12 +588,10 @@ def _existing_tiff_is_valid(path: str) -> bool:
456588
auto_root = os.path.join(tmp_root, "cellpose_auto_zarr")
457589
os.makedirs(auto_root, exist_ok=True)
458590

591+
channel_tag = str(channel).replace(os.sep, "_")
459592
cache_payload = {
460-
"source": source_abs,
461-
"source_mtime": os.path.getmtime(source_abs),
462-
"shape": tuple(int(s) for s in image.shape),
463-
"dtype": str(getattr(image, "dtype", "unknown")),
464-
"dim_order": str(dim_order),
593+
"auto_zarr_cache_version": 3,
594+
"source": _source_signature_id(source_abs),
465595
"channel": str(channel),
466596
}
467597
cache_key = hashlib.sha1(
@@ -470,9 +600,52 @@ def _existing_tiff_is_valid(path: str) -> bool:
470600

471601
source_base = os.path.splitext(os.path.basename(source_abs))[0]
472602
auto_zarr_path = os.path.join(
473-
auto_root, f"{source_base}_cellpose_{cache_key}.zarr"
603+
auto_root,
604+
f"{source_base}_cellpose_ch{channel_tag}_{cache_key}.zarr",
474605
)
475606

607+
if not os.path.exists(auto_zarr_path):
608+
expected_shape = tuple(int(s) for s in image.shape)
609+
expected_dtype = str(getattr(image, "dtype", "unknown"))
610+
legacy_glob = os.path.join(
611+
auto_root,
612+
f"{source_base}_cellpose*.zarr",
613+
)
614+
legacy_candidates = sorted(
615+
[
616+
p
617+
for p in glob.glob(legacy_glob)
618+
if os.path.isdir(p) and p != auto_zarr_path
619+
],
620+
key=os.path.getmtime,
621+
reverse=True,
622+
)
623+
624+
for legacy_path in legacy_candidates:
625+
try:
626+
legacy_root = zarr.open(legacy_path, mode="r")
627+
legacy_arr = (
628+
legacy_root["s0"]
629+
if hasattr(legacy_root, "__contains__")
630+
and "s0" in legacy_root
631+
else legacy_root
632+
)
633+
legacy_shape = tuple(int(s) for s in legacy_arr.shape)
634+
legacy_dtype = str(np.dtype(legacy_arr.dtype))
635+
if (
636+
legacy_shape == expected_shape
637+
and legacy_dtype == expected_dtype
638+
):
639+
auto_zarr_path = legacy_path
640+
print(
641+
"Distributed segmentation: reusing compatible "
642+
"cached auto-converted zarr: "
643+
f"{auto_zarr_path}"
644+
)
645+
break
646+
except Exception:
647+
continue
648+
476649
if not os.path.exists(auto_zarr_path):
477650
axes = str(dim_order).upper() if dim_order else ""
478651
if len(axes) != image.ndim:
@@ -1005,9 +1178,8 @@ def _prepare_runtime_distributed_mask(
10051178
if not isinstance(loaded_signature, dict):
10061179
loaded_signature = {}
10071180

1008-
same_source = (
1009-
os.path.abspath(_source_filepath)
1010-
== loaded_signature.get("source")
1181+
same_source = _source_signature_matches(
1182+
loaded_signature.get("source")
10111183
)
10121184
same_channel = str(channel) == str(
10131185
loaded_signature.get("channel", channel)
@@ -1053,11 +1225,14 @@ def _prepare_runtime_distributed_mask(
10531225
diameter = float(
10541226
loaded_signature.get("diameter", diameter)
10551227
)
1056-
convpaint_model_path = str(
1057-
loaded_signature.get(
1058-
"convpaint_model_path",
1059-
convpaint_model_path,
1060-
)
1228+
convpaint_model_path = _normalize_loaded_path(
1229+
str(
1230+
loaded_signature.get(
1231+
"convpaint_model_path",
1232+
convpaint_model_path,
1233+
)
1234+
),
1235+
fallback=convpaint_model_path,
10611236
)
10621237
convpaint_image_downsample = int(
10631238
loaded_signature.get(
@@ -1126,7 +1301,7 @@ def _prepare_runtime_distributed_mask(
11261301

11271302
# Cache ConvPaint masks on disk so reruns can reuse them.
11281303
mask_cache_key_payload = {
1129-
"source": os.path.abspath(_source_filepath),
1304+
"source": _source_signature_id(_source_filepath),
11301305
"source_mtime": os.path.getmtime(_source_filepath),
11311306
"mask_cache_version": 3,
11321307
"channel": channel,
@@ -1166,13 +1341,41 @@ def _prepare_runtime_distributed_mask(
11661341

11671342
slab_shape = tuple(int(s) for s in image.shape[1:])
11681343
checkpoint_shape = (selected_count, *slab_shape)
1344+
1345+
preexisting_output = _direct_output_path()
1346+
if (
1347+
skip_overwrite_existing_valid_output
1348+
and preexisting_output
1349+
and _existing_output_is_valid(
1350+
preexisting_output,
1351+
expected_shape=checkpoint_shape,
1352+
)
1353+
):
1354+
print(
1355+
"Skipping interleaved processing: existing output appears "
1356+
f"complete and valid -> {preexisting_output}"
1357+
)
1358+
return preexisting_output
1359+
1360+
if skip_overwrite_existing_valid_output:
1361+
for legacy_output in _legacy_output_candidates():
1362+
if _existing_output_is_valid(
1363+
legacy_output,
1364+
expected_shape=checkpoint_shape,
1365+
):
1366+
print(
1367+
"Skipping interleaved processing: existing legacy output "
1368+
f"appears complete and valid -> {legacy_output}"
1369+
)
1370+
return legacy_output
1371+
11691372
z_chunk = max(1, min(16, slab_shape[0]))
11701373
y_chunk = max(1, min(512, slab_shape[1]))
11711374
x_chunk = max(1, min(512, slab_shape[2]))
11721375
checkpoint_chunks = (1, z_chunk, y_chunk, x_chunk)
11731376

11741377
run_signature = {
1175-
"source": os.path.abspath(_source_filepath),
1378+
"source": _source_signature_id(_source_filepath),
11761379
"channel": channel,
11771380
"distributed_blocksize_yx": int(distributed_blocksize_yx),
11781381
"flow_threshold": float(flow_threshold),
@@ -1258,7 +1461,9 @@ def _prepare_runtime_distributed_mask(
12581461
existing_sig = existing.attrs.get("run_signature", "")
12591462
if (
12601463
tuple(existing.shape) != checkpoint_shape
1261-
or existing_sig != run_signature_json
1464+
or not _run_signatures_compatible(
1465+
existing_sig, run_signature_json
1466+
)
12621467
):
12631468
print(
12641469
"Interleaved checkpoint exists but is incompatible; "
@@ -1824,7 +2029,7 @@ def _prepare_runtime_distributed_mask(
18242029
# Cache key intentionally excludes selected_timepoints so different
18252030
# intervals can reuse already completed per-timepoint outputs.
18262031
tp_cache_signature = {
1827-
"source": os.path.abspath(_source_filepath),
2032+
"source": _source_signature_id(_source_filepath),
18282033
"source_mtime": os.path.getmtime(_source_filepath),
18292034
"channel": channel,
18302035
"flow_threshold": float(flow_threshold),

0 commit comments

Comments
 (0)