|
14 | 14 | import os |
15 | 15 | import shutil |
16 | 16 | import time |
| 17 | +import glob |
17 | 18 | from contextlib import suppress |
18 | 19 | from tempfile import NamedTemporaryFile |
19 | 20 | from typing import Union |
@@ -272,6 +273,11 @@ def transpose_dimensions(img, dim_order): |
272 | 273 | "max": 200, |
273 | 274 | "description": "Z-batching for ConvPaint auto-mask generation (0 disables batching).", |
274 | 275 | }, |
| 276 | + "auto_load_saved_interleaved_settings": { |
| 277 | + "type": bool, |
| 278 | + "default": True, |
| 279 | + "description": "When restarting interleaved distributed+ConvPaint runs, automatically reuse the most recent cached run settings for the same source and channel.", |
| 280 | + }, |
275 | 281 | }, |
276 | 282 | ) |
277 | 283 | def cellpose_segmentation( |
@@ -299,6 +305,7 @@ def cellpose_segmentation( |
299 | 305 | convpaint_use_cpu: bool = False, |
300 | 306 | convpaint_force_dedicated_env: bool = False, |
301 | 307 | convpaint_z_batch_size: int = 0, |
| 308 | + auto_load_saved_interleaved_settings: bool = True, |
302 | 309 | timepoint_start: int = 0, |
303 | 310 | timepoint_end: int = -1, |
304 | 311 | timepoint_step: int = 1, |
@@ -939,6 +946,143 @@ def _prepare_runtime_distributed_mask( |
939 | 946 | source_base = os.path.splitext( |
940 | 947 | os.path.basename(_source_filepath) |
941 | 948 | )[0] |
| 949 | + channel_tag = str(channel).replace(os.sep, "_") |
| 950 | + |
| 951 | + if auto_load_saved_interleaved_settings: |
| 952 | + settings_glob = os.path.join( |
| 953 | + tmp_root, |
| 954 | + "cellpose_timepoint_cache", |
| 955 | + f"{source_base}_interleaved_ch{channel_tag}_*", |
| 956 | + "run_settings.json", |
| 957 | + ) |
| 958 | + settings_candidates = [ |
| 959 | + p for p in glob.glob(settings_glob) if os.path.isfile(p) |
| 960 | + ] |
| 961 | + if settings_candidates: |
| 962 | + latest_settings = max( |
| 963 | + settings_candidates, key=os.path.getmtime |
| 964 | + ) |
| 965 | + try: |
| 966 | + with open(latest_settings, encoding="utf-8") as f: |
| 967 | + loaded = json.load(f) |
| 968 | + |
| 969 | + loaded_signature = loaded.get("run_signature", {}) |
| 970 | + if not isinstance(loaded_signature, dict): |
| 971 | + loaded_signature = {} |
| 972 | + |
| 973 | + same_source = ( |
| 974 | + os.path.abspath(_source_filepath) |
| 975 | + == loaded_signature.get("source") |
| 976 | + ) |
| 977 | + same_channel = str(channel) == str( |
| 978 | + loaded_signature.get("channel", channel) |
| 979 | + ) |
| 980 | + |
| 981 | + if same_source and same_channel: |
| 982 | + distributed_blocksize_yx = int( |
| 983 | + loaded_signature.get( |
| 984 | + "distributed_blocksize_yx", |
| 985 | + distributed_blocksize_yx, |
| 986 | + ) |
| 987 | + ) |
| 988 | + flow_threshold = float( |
| 989 | + loaded_signature.get( |
| 990 | + "flow_threshold", flow_threshold |
| 991 | + ) |
| 992 | + ) |
| 993 | + cellprob_threshold = float( |
| 994 | + loaded_signature.get( |
| 995 | + "cellprob_threshold", |
| 996 | + cellprob_threshold, |
| 997 | + ) |
| 998 | + ) |
| 999 | + anisotropy = loaded_signature.get( |
| 1000 | + "anisotropy", anisotropy |
| 1001 | + ) |
| 1002 | + if anisotropy is not None: |
| 1003 | + anisotropy = float(anisotropy) |
| 1004 | + flow3D_smooth = int( |
| 1005 | + loaded_signature.get( |
| 1006 | + "flow3D_smooth", flow3D_smooth |
| 1007 | + ) |
| 1008 | + ) |
| 1009 | + tile_norm_blocksize = int( |
| 1010 | + loaded_signature.get( |
| 1011 | + "tile_norm_blocksize", |
| 1012 | + tile_norm_blocksize, |
| 1013 | + ) |
| 1014 | + ) |
| 1015 | + batch_size = int( |
| 1016 | + loaded_signature.get("batch_size", batch_size) |
| 1017 | + ) |
| 1018 | + diameter = float( |
| 1019 | + loaded_signature.get("diameter", diameter) |
| 1020 | + ) |
| 1021 | + convpaint_model_path = str( |
| 1022 | + loaded_signature.get( |
| 1023 | + "convpaint_model_path", |
| 1024 | + convpaint_model_path, |
| 1025 | + ) |
| 1026 | + ) |
| 1027 | + convpaint_image_downsample = int( |
| 1028 | + loaded_signature.get( |
| 1029 | + "convpaint_image_downsample", |
| 1030 | + convpaint_image_downsample, |
| 1031 | + ) |
| 1032 | + ) |
| 1033 | + convpaint_background_label = int( |
| 1034 | + loaded_signature.get( |
| 1035 | + "convpaint_background_label", |
| 1036 | + convpaint_background_label, |
| 1037 | + ) |
| 1038 | + ) |
| 1039 | + convpaint_mask_dilation = int( |
| 1040 | + loaded_signature.get( |
| 1041 | + "convpaint_mask_dilation", |
| 1042 | + convpaint_mask_dilation, |
| 1043 | + ) |
| 1044 | + ) |
| 1045 | + convpaint_min_object_fraction_of_median = float( |
| 1046 | + loaded_signature.get( |
| 1047 | + "convpaint_min_object_fraction_of_median", |
| 1048 | + convpaint_min_object_fraction_of_median, |
| 1049 | + ) |
| 1050 | + ) |
| 1051 | + clip_final_labels_to_convpaint_mask = bool( |
| 1052 | + loaded_signature.get( |
| 1053 | + "clip_final_labels_to_convpaint_mask", |
| 1054 | + clip_final_labels_to_convpaint_mask, |
| 1055 | + ) |
| 1056 | + ) |
| 1057 | + timepoint_start = int( |
| 1058 | + loaded_signature.get( |
| 1059 | + "timepoint_start", timepoint_start |
| 1060 | + ) |
| 1061 | + ) |
| 1062 | + timepoint_end = int( |
| 1063 | + loaded_signature.get( |
| 1064 | + "timepoint_end", timepoint_end |
| 1065 | + ) |
| 1066 | + ) |
| 1067 | + timepoint_step = int( |
| 1068 | + loaded_signature.get( |
| 1069 | + "timepoint_step", timepoint_step |
| 1070 | + ) |
| 1071 | + ) |
| 1072 | + print( |
| 1073 | + "Auto-loaded interleaved run settings from: " |
| 1074 | + f"{latest_settings}" |
| 1075 | + ) |
| 1076 | + else: |
| 1077 | + print( |
| 1078 | + "Ignoring cached run settings due to source/channel mismatch: " |
| 1079 | + f"{latest_settings}" |
| 1080 | + ) |
| 1081 | + except Exception as exc: |
| 1082 | + print( |
| 1083 | + "Warning: failed to load cached interleaved settings " |
| 1084 | + f"from {latest_settings} ({exc})" |
| 1085 | + ) |
942 | 1086 | print( |
943 | 1087 | "Distributed+ConvPaint interleaved mode: " |
944 | 1088 | f"processing {selected_count} selected timepoints " |
@@ -980,7 +1124,6 @@ def _prepare_runtime_distributed_mask( |
980 | 1124 |
|
981 | 1125 | # Persist interleaved slab results so failed runs can resume from |
982 | 1126 | # the last completed slab on restart. |
983 | | - channel_tag = str(channel).replace(os.sep, "_") |
984 | 1127 | checkpoint_path = os.path.join( |
985 | 1128 | tmp_root, |
986 | 1129 | f"{source_base}_cellpose_interleaved_ch{channel_tag}.zarr", |
|
0 commit comments