Skip to content

Commit 35f36f1

Browse files
committed
Fix Cellpose resume export and auto-load cached settings
1 parent dd1d146 commit 35f36f1

4 files changed

Lines changed: 299 additions & 5 deletions

File tree

src/napari_tmidas/_file_selector.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import annotations
1414

1515
import concurrent.futures
16+
import glob
1617
import inspect
1718
import json
1819
import os
@@ -2003,6 +2004,44 @@ def get_parameter_values(self) -> Dict[str, Any]:
20032004

20042005
return values
20052006

2007+
def apply_parameter_values(self, values: Dict[str, Any]):
2008+
"""Apply parameter values to existing widgets when keys are present."""
2009+
for param_name, raw_value in (values or {}).items():
2010+
if param_name not in self.param_widgets:
2011+
continue
2012+
2013+
widget = self.param_widgets[param_name]
2014+
param_info = self.parameters.get(param_name, {})
2015+
param_type = param_info.get("type", str)
2016+
widget_type = param_info.get("widget_type")
2017+
2018+
try:
2019+
if isinstance(widget, QComboBox):
2020+
if widget_type == "channel_selector":
2021+
target = "all" if raw_value in ("all", None) else str(raw_value)
2022+
idx = widget.findData(target)
2023+
if idx < 0:
2024+
idx = widget.findData(raw_value)
2025+
if idx >= 0:
2026+
widget.setCurrentIndex(idx)
2027+
else:
2028+
idx = widget.findText(str(raw_value))
2029+
if idx >= 0:
2030+
widget.setCurrentIndex(idx)
2031+
elif isinstance(widget, QCheckBox):
2032+
widget.setChecked(bool(raw_value))
2033+
elif isinstance(widget, QSpinBox):
2034+
widget.setValue(int(raw_value))
2035+
elif isinstance(widget, QDoubleSpinBox):
2036+
widget.setValue(float(raw_value))
2037+
elif isinstance(widget, QLineEdit):
2038+
widget.setText(
2039+
"" if raw_value is None else str(param_type(raw_value))
2040+
)
2041+
except Exception:
2042+
# Ignore invalid values and keep current/default widget value.
2043+
continue
2044+
20062045
def update_channel_selector(self, file_list: List[str]):
20072046
"""
20082047
Update the channel selector widget based on the loaded files.
@@ -3276,6 +3315,8 @@ def update_function_info(self, function_name: str):
32763315
# Update channel selector if the function has a channel parameter
32773316
if any(p.get("widget_type") == "channel_selector" for p in parameters.values()):
32783317
self.param_widget_instance.update_channel_selector(self.file_list)
3318+
# Auto-load cached Cellpose interleaved settings into UI controls.
3319+
self._maybe_apply_cached_cellpose_settings(function_name)
32793320
# Initial update of thread count based on current use_cpu value
32803321
if "use_cpu" in parameters:
32813322
use_cpu_value = parameters["use_cpu"].get("default", False)
@@ -3290,6 +3331,74 @@ def update_function_info(self, function_name: str):
32903331
self.param_widget_instance
32913332
)
32923333

3334+
def _maybe_apply_cached_cellpose_settings(self, function_name: str):
3335+
"""Populate Cellpose parameters from the most recent cached run settings."""
3336+
if "cellpose" not in str(function_name).lower():
3337+
return
3338+
if not getattr(self, "file_list", None):
3339+
return
3340+
if not hasattr(self, "param_widget_instance") or not hasattr(
3341+
self.param_widget_instance, "apply_parameter_values"
3342+
):
3343+
return
3344+
3345+
first_file = self.file_list[0]
3346+
source_base = os.path.splitext(os.path.basename(first_file))[0]
3347+
source_parent = os.path.dirname(os.path.abspath(first_file))
3348+
3349+
patterns = [
3350+
os.path.join(
3351+
source_parent,
3352+
"tmp",
3353+
"cellpose_timepoint_cache",
3354+
f"{source_base}*_interleaved_ch*_*/run_settings.json",
3355+
),
3356+
os.path.join(
3357+
source_parent,
3358+
"tmp",
3359+
"cellpose_auto_zarr",
3360+
"tmp",
3361+
"cellpose_timepoint_cache",
3362+
f"{source_base}*_interleaved_ch*_*/run_settings.json",
3363+
),
3364+
]
3365+
3366+
settings_candidates = []
3367+
for pattern in patterns:
3368+
settings_candidates.extend(
3369+
p for p in glob.glob(pattern) if os.path.isfile(p)
3370+
)
3371+
3372+
if not settings_candidates:
3373+
return
3374+
3375+
latest_settings = max(settings_candidates, key=os.path.getmtime)
3376+
try:
3377+
with open(latest_settings, encoding="utf-8") as f:
3378+
loaded = json.load(f)
3379+
3380+
loaded_signature = loaded.get("run_signature", {})
3381+
if not isinstance(loaded_signature, dict):
3382+
return
3383+
3384+
# Map saved run signature keys to current widget parameters.
3385+
widget_values = {
3386+
k: v
3387+
for k, v in loaded_signature.items()
3388+
if k in getattr(self.param_widget_instance, "parameters", {})
3389+
}
3390+
if widget_values:
3391+
self.param_widget_instance.apply_parameter_values(widget_values)
3392+
print(
3393+
"Loaded cached Cellpose parameters into widget from: "
3394+
f"{latest_settings}"
3395+
)
3396+
except Exception as exc:
3397+
print(
3398+
"Warning: could not load cached Cellpose settings into widget "
3399+
f"({exc})"
3400+
)
3401+
32933402
def _update_thread_count_for_gpu(self, use_cpu: bool):
32943403
"""
32953404
Update thread count widget based on use_cpu parameter.

src/napari_tmidas/_tests/test_ome_output_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pytest
55
import tifffile
6+
import zarr
67

78
from napari_tmidas.processing_functions.ome_output_utils import (
89
write_labels_with_source_metadata,
@@ -147,3 +148,37 @@ def _failing_imwrite(path, *args, **kwargs):
147148

148149
assert not output_path.exists()
149150
assert not list(tmp_path.glob("*.tmp-*"))
151+
152+
153+
def test_write_labels_with_source_metadata_streams_zarr_array_to_ome_tiff(
154+
tmp_path,
155+
):
156+
labels_path = tmp_path / "labels_cache.zarr"
157+
labels = zarr.open_array(
158+
str(labels_path),
159+
mode="w",
160+
shape=(2, 3, 16, 16),
161+
chunks=(1, 1, 16, 16),
162+
dtype=np.uint32,
163+
)
164+
labels[:] = np.arange(2 * 3 * 16 * 16, dtype=np.uint32).reshape(
165+
2, 3, 16, 16
166+
)
167+
168+
output_path = tmp_path / "labels_streamed.ome.tif"
169+
returned = write_labels_with_source_metadata(
170+
labels=labels,
171+
source_path=None,
172+
output_path=str(output_path),
173+
output_format="tiff",
174+
dim_order="TZYX",
175+
)
176+
177+
assert returned == str(output_path)
178+
assert output_path.exists()
179+
180+
with tifffile.TiffFile(output_path) as tif:
181+
assert tif.is_ome
182+
arr = tif.asarray()
183+
assert arr.dtype == np.uint32
184+
assert arr.shape == (2, 3, 16, 16)

src/napari_tmidas/processing_functions/cellpose_segmentation.py

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import os
1515
import shutil
1616
import time
17+
import glob
1718
from contextlib import suppress
1819
from tempfile import NamedTemporaryFile
1920
from typing import Union
@@ -272,6 +273,11 @@ def transpose_dimensions(img, dim_order):
272273
"max": 200,
273274
"description": "Z-batching for ConvPaint auto-mask generation (0 disables batching).",
274275
},
276+
"auto_load_saved_interleaved_settings": {
277+
"type": bool,
278+
"default": True,
279+
"description": "When restarting interleaved distributed+ConvPaint runs, automatically reuse the most recent cached run settings for the same source and channel.",
280+
},
275281
},
276282
)
277283
def cellpose_segmentation(
@@ -299,6 +305,7 @@ def cellpose_segmentation(
299305
convpaint_use_cpu: bool = False,
300306
convpaint_force_dedicated_env: bool = False,
301307
convpaint_z_batch_size: int = 0,
308+
auto_load_saved_interleaved_settings: bool = True,
302309
timepoint_start: int = 0,
303310
timepoint_end: int = -1,
304311
timepoint_step: int = 1,
@@ -939,6 +946,143 @@ def _prepare_runtime_distributed_mask(
939946
source_base = os.path.splitext(
940947
os.path.basename(_source_filepath)
941948
)[0]
949+
channel_tag = str(channel).replace(os.sep, "_")
950+
951+
if auto_load_saved_interleaved_settings:
952+
settings_glob = os.path.join(
953+
tmp_root,
954+
"cellpose_timepoint_cache",
955+
f"{source_base}_interleaved_ch{channel_tag}_*",
956+
"run_settings.json",
957+
)
958+
settings_candidates = [
959+
p for p in glob.glob(settings_glob) if os.path.isfile(p)
960+
]
961+
if settings_candidates:
962+
latest_settings = max(
963+
settings_candidates, key=os.path.getmtime
964+
)
965+
try:
966+
with open(latest_settings, encoding="utf-8") as f:
967+
loaded = json.load(f)
968+
969+
loaded_signature = loaded.get("run_signature", {})
970+
if not isinstance(loaded_signature, dict):
971+
loaded_signature = {}
972+
973+
same_source = (
974+
os.path.abspath(_source_filepath)
975+
== loaded_signature.get("source")
976+
)
977+
same_channel = str(channel) == str(
978+
loaded_signature.get("channel", channel)
979+
)
980+
981+
if same_source and same_channel:
982+
distributed_blocksize_yx = int(
983+
loaded_signature.get(
984+
"distributed_blocksize_yx",
985+
distributed_blocksize_yx,
986+
)
987+
)
988+
flow_threshold = float(
989+
loaded_signature.get(
990+
"flow_threshold", flow_threshold
991+
)
992+
)
993+
cellprob_threshold = float(
994+
loaded_signature.get(
995+
"cellprob_threshold",
996+
cellprob_threshold,
997+
)
998+
)
999+
anisotropy = loaded_signature.get(
1000+
"anisotropy", anisotropy
1001+
)
1002+
if anisotropy is not None:
1003+
anisotropy = float(anisotropy)
1004+
flow3D_smooth = int(
1005+
loaded_signature.get(
1006+
"flow3D_smooth", flow3D_smooth
1007+
)
1008+
)
1009+
tile_norm_blocksize = int(
1010+
loaded_signature.get(
1011+
"tile_norm_blocksize",
1012+
tile_norm_blocksize,
1013+
)
1014+
)
1015+
batch_size = int(
1016+
loaded_signature.get("batch_size", batch_size)
1017+
)
1018+
diameter = float(
1019+
loaded_signature.get("diameter", diameter)
1020+
)
1021+
convpaint_model_path = str(
1022+
loaded_signature.get(
1023+
"convpaint_model_path",
1024+
convpaint_model_path,
1025+
)
1026+
)
1027+
convpaint_image_downsample = int(
1028+
loaded_signature.get(
1029+
"convpaint_image_downsample",
1030+
convpaint_image_downsample,
1031+
)
1032+
)
1033+
convpaint_background_label = int(
1034+
loaded_signature.get(
1035+
"convpaint_background_label",
1036+
convpaint_background_label,
1037+
)
1038+
)
1039+
convpaint_mask_dilation = int(
1040+
loaded_signature.get(
1041+
"convpaint_mask_dilation",
1042+
convpaint_mask_dilation,
1043+
)
1044+
)
1045+
convpaint_min_object_fraction_of_median = float(
1046+
loaded_signature.get(
1047+
"convpaint_min_object_fraction_of_median",
1048+
convpaint_min_object_fraction_of_median,
1049+
)
1050+
)
1051+
clip_final_labels_to_convpaint_mask = bool(
1052+
loaded_signature.get(
1053+
"clip_final_labels_to_convpaint_mask",
1054+
clip_final_labels_to_convpaint_mask,
1055+
)
1056+
)
1057+
timepoint_start = int(
1058+
loaded_signature.get(
1059+
"timepoint_start", timepoint_start
1060+
)
1061+
)
1062+
timepoint_end = int(
1063+
loaded_signature.get(
1064+
"timepoint_end", timepoint_end
1065+
)
1066+
)
1067+
timepoint_step = int(
1068+
loaded_signature.get(
1069+
"timepoint_step", timepoint_step
1070+
)
1071+
)
1072+
print(
1073+
"Auto-loaded interleaved run settings from: "
1074+
f"{latest_settings}"
1075+
)
1076+
else:
1077+
print(
1078+
"Ignoring cached run settings due to source/channel mismatch: "
1079+
f"{latest_settings}"
1080+
)
1081+
except Exception as exc:
1082+
print(
1083+
"Warning: failed to load cached interleaved settings "
1084+
f"from {latest_settings} ({exc})"
1085+
)
9421086
print(
9431087
"Distributed+ConvPaint interleaved mode: "
9441088
f"processing {selected_count} selected timepoints "
@@ -980,7 +1124,6 @@ def _prepare_runtime_distributed_mask(
9801124

9811125
# Persist interleaved slab results so failed runs can resume from
9821126
# the last completed slab on restart.
983-
channel_tag = str(channel).replace(os.sep, "_")
9841127
checkpoint_path = os.path.join(
9851128
tmp_root,
9861129
f"{source_base}_cellpose_interleaved_ch{channel_tag}.zarr",

src/napari_tmidas/processing_functions/ome_output_utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,20 @@ def write_labels_with_source_metadata(
177177
)
178178

179179
if can_stream:
180-
def _iter_slabs():
181-
for i in range(labels_shape[0]):
182-
yield np.asarray(labels[i], dtype=labels_dtype)
180+
# tifffile expects page-wise data chunks for multidimensional OME
181+
# writes; yield YX planes in C order across all leading indices.
182+
def _iter_planes():
183+
lead_shape = labels_shape[:-2]
184+
if not lead_shape:
185+
yield np.asarray(labels, dtype=labels_dtype)
186+
return
187+
188+
for lead_idx in np.ndindex(*lead_shape):
189+
yield np.asarray(labels[lead_idx], dtype=labels_dtype)
183190

184191
tifffile.imwrite(
185192
tmp_output_path,
186-
data=_iter_slabs(),
193+
data=_iter_planes(),
187194
shape=labels_shape,
188195
dtype=labels_dtype,
189196
ome=True,

0 commit comments

Comments
 (0)