Skip to content

Commit 852581b

Browse files
committed
Refactor Cellpose env-manager slab execution
1 parent 8fed5e2 commit 852581b

1 file changed

Lines changed: 58 additions & 54 deletions

File tree

src/napari_tmidas/processing_functions/cellpose_env_manager.py

Lines changed: 58 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,36 @@ def slab_to_disk_zarr(source_5d_or_4d, index_slices, shape_zyx, chunks, tmp_dir)
649649
return dest, zarr_dir
650650
651651
652+
def run_zyx_slab(source_array, index_slices, shape_zyx, chunks, tmp_dir,
653+
name="", t_idx=None, c_idx=None):
654+
if _distributed_eval_fn is not None:
655+
slab_z, slab_dir = slab_to_disk_zarr(
656+
source_array,
657+
index_slices,
658+
shape_zyx,
659+
chunks,
660+
tmp_dir,
661+
)
662+
slab_mask = _select_mask_for_volume(
663+
_DISTRIBUTED_MASK,
664+
t_idx=t_idx,
665+
c_idx=c_idx,
666+
)
667+
try:
668+
return run_distributed_on_slab(
669+
slab_z,
670+
name,
671+
slab_mask=slab_mask,
672+
)
673+
finally:
674+
_shutil.rmtree(slab_dir, ignore_errors=True)
675+
676+
return process_volume(
677+
np.array(source_array[index_slices + (slice(None), slice(None), slice(None))]),
678+
name,
679+
)
680+
681+
652682
def process_volume(image, name=""):
653683
print(f"\\nProcessing {{name}}: shape={{image.shape}}, range={{np.min(image):.1f}}-{{np.max(image):.1f}}")
654684
sys.stdout.flush()
@@ -778,22 +808,16 @@ def main():
778808
for out_c, c in enumerate(channels_to_process):
779809
print(f"\\n=== T={{t+1}}/{{T}}, C={{c+1}}/{{C}} ===")
780810
sys.stdout.flush()
781-
if _distributed_eval_fn is not None:
782-
slab_z, slab_dir = slab_to_disk_zarr(
783-
zarr_array, (t, c), (Z, Y, X), chunk_3d, _tmp_slabs)
784-
slab_mask = _select_mask_for_volume(
785-
_DISTRIBUTED_MASK, t_idx=t, c_idx=c
786-
)
787-
try:
788-
masks = run_distributed_on_slab(
789-
slab_z,
790-
f"T{{t+1}}C{{c+1}}",
791-
slab_mask=slab_mask,
792-
)
793-
finally:
794-
_shutil.rmtree(slab_dir, ignore_errors=True)
795-
else:
796-
masks = process_volume(np.array(zarr_array[t, c, :, :, :]), f"T{{t+1}}C{{c+1}}")
811+
masks = run_zyx_slab(
812+
zarr_array,
813+
(t, c),
814+
(Z, Y, X),
815+
chunk_3d,
816+
_tmp_slabs,
817+
name=f"T{{t+1}}C{{c+1}}",
818+
t_idx=t,
819+
c_idx=c,
820+
)
797821
798822
if TIMEPOINT_INDEX is not None:
799823
if n_out_channels == 1:
@@ -861,25 +885,15 @@ def main():
861885
for out_i, i in enumerate(indices):
862886
print(f"\\n=== Timepoint {{i+1}}/{{dim1}} ===")
863887
sys.stdout.flush()
864-
if _distributed_eval_fn is not None:
865-
slab_z, slab_dir = slab_to_disk_zarr(
866-
zarr_array, (i,), (Z, Y, X), chunk_3d, _tmp_slabs
867-
)
868-
slab_mask = _select_mask_for_volume(
869-
_DISTRIBUTED_MASK, t_idx=i
870-
)
871-
try:
872-
masks = run_distributed_on_slab(
873-
slab_z,
874-
f"T{{i+1}}",
875-
slab_mask=slab_mask,
876-
)
877-
finally:
878-
_shutil.rmtree(slab_dir, ignore_errors=True)
879-
else:
880-
masks = process_volume(
881-
np.array(zarr_array[i, :, :, :]), f"T{{i+1}}"
882-
)
888+
masks = run_zyx_slab(
889+
zarr_array,
890+
(i,),
891+
(Z, Y, X),
892+
chunk_3d,
893+
_tmp_slabs,
894+
name=f"T{{i+1}}",
895+
t_idx=i,
896+
)
883897
884898
if TIMEPOINT_INDEX is not None:
885899
result[:, :, :] = masks
@@ -911,25 +925,15 @@ def main():
911925
for out_i, i in enumerate(indices):
912926
print(f"\\n=== Channel {{i+1}}/{{dim1}} ===")
913927
sys.stdout.flush()
914-
if _distributed_eval_fn is not None:
915-
slab_z, slab_dir = slab_to_disk_zarr(
916-
zarr_array, (i,), (Z, Y, X), chunk_3d, _tmp_slabs
917-
)
918-
slab_mask = _select_mask_for_volume(
919-
_DISTRIBUTED_MASK, c_idx=i
920-
)
921-
try:
922-
masks = run_distributed_on_slab(
923-
slab_z,
924-
f"C{{i+1}}",
925-
slab_mask=slab_mask,
926-
)
927-
finally:
928-
_shutil.rmtree(slab_dir, ignore_errors=True)
929-
else:
930-
masks = process_volume(
931-
np.array(zarr_array[i, :, :, :]), f"C{{i+1}}"
932-
)
928+
masks = run_zyx_slab(
929+
zarr_array,
930+
(i,),
931+
(Z, Y, X),
932+
chunk_3d,
933+
_tmp_slabs,
934+
name=f"C{{i+1}}",
935+
c_idx=i,
936+
)
933937
934938
if len(indices) == 1:
935939
result[:, :, :] = masks

0 commit comments

Comments
 (0)