@@ -649,6 +649,36 @@ def slab_to_disk_zarr(source_5d_or_4d, index_slices, shape_zyx, chunks, tmp_dir)
649649 return dest, zarr_dir
650650
651651
652+ def run_zyx_slab(source_array, index_slices, shape_zyx, chunks, tmp_dir,
653+ name="", t_idx=None, c_idx=None):
654+ if _distributed_eval_fn is not None:
655+ slab_z, slab_dir = slab_to_disk_zarr(
656+ source_array,
657+ index_slices,
658+ shape_zyx,
659+ chunks,
660+ tmp_dir,
661+ )
662+ slab_mask = _select_mask_for_volume(
663+ _DISTRIBUTED_MASK,
664+ t_idx=t_idx,
665+ c_idx=c_idx,
666+ )
667+ try:
668+ return run_distributed_on_slab(
669+ slab_z,
670+ name,
671+ slab_mask=slab_mask,
672+ )
673+ finally:
674+ _shutil.rmtree(slab_dir, ignore_errors=True)
675+
676+ return process_volume(
677+ np.array(source_array[index_slices + (slice(None), slice(None), slice(None))]),
678+ name,
679+ )
680+
681+
652682def process_volume(image, name=""):
653683 print(f"\\ nProcessing {{name}}: shape={{image.shape}}, range={{np.min(image):.1f}}-{{np.max(image):.1f}}")
654684 sys.stdout.flush()
@@ -778,22 +808,16 @@ def main():
778808 for out_c, c in enumerate(channels_to_process):
779809 print(f"\\ n=== T={{t+1}}/{{T}}, C={{c+1}}/{{C}} ===")
780810 sys.stdout.flush()
781- if _distributed_eval_fn is not None:
782- slab_z, slab_dir = slab_to_disk_zarr(
783- zarr_array, (t, c), (Z, Y, X), chunk_3d, _tmp_slabs)
784- slab_mask = _select_mask_for_volume(
785- _DISTRIBUTED_MASK, t_idx=t, c_idx=c
786- )
787- try:
788- masks = run_distributed_on_slab(
789- slab_z,
790- f"T{{t+1}}C{{c+1}}",
791- slab_mask=slab_mask,
792- )
793- finally:
794- _shutil.rmtree(slab_dir, ignore_errors=True)
795- else:
796- masks = process_volume(np.array(zarr_array[t, c, :, :, :]), f"T{{t+1}}C{{c+1}}")
811+ masks = run_zyx_slab(
812+ zarr_array,
813+ (t, c),
814+ (Z, Y, X),
815+ chunk_3d,
816+ _tmp_slabs,
817+ name=f"T{{t+1}}C{{c+1}}",
818+ t_idx=t,
819+ c_idx=c,
820+ )
797821
798822 if TIMEPOINT_INDEX is not None:
799823 if n_out_channels == 1:
@@ -861,25 +885,15 @@ def main():
861885 for out_i, i in enumerate(indices):
862886 print(f"\\ n=== Timepoint {{i+1}}/{{dim1}} ===")
863887 sys.stdout.flush()
864- if _distributed_eval_fn is not None:
865- slab_z, slab_dir = slab_to_disk_zarr(
866- zarr_array, (i,), (Z, Y, X), chunk_3d, _tmp_slabs
867- )
868- slab_mask = _select_mask_for_volume(
869- _DISTRIBUTED_MASK, t_idx=i
870- )
871- try:
872- masks = run_distributed_on_slab(
873- slab_z,
874- f"T{{i+1}}",
875- slab_mask=slab_mask,
876- )
877- finally:
878- _shutil.rmtree(slab_dir, ignore_errors=True)
879- else:
880- masks = process_volume(
881- np.array(zarr_array[i, :, :, :]), f"T{{i+1}}"
882- )
888+ masks = run_zyx_slab(
889+ zarr_array,
890+ (i,),
891+ (Z, Y, X),
892+ chunk_3d,
893+ _tmp_slabs,
894+ name=f"T{{i+1}}",
895+ t_idx=i,
896+ )
883897
884898 if TIMEPOINT_INDEX is not None:
885899 result[:, :, :] = masks
@@ -911,25 +925,15 @@ def main():
911925 for out_i, i in enumerate(indices):
912926 print(f"\\ n=== Channel {{i+1}}/{{dim1}} ===")
913927 sys.stdout.flush()
914- if _distributed_eval_fn is not None:
915- slab_z, slab_dir = slab_to_disk_zarr(
916- zarr_array, (i,), (Z, Y, X), chunk_3d, _tmp_slabs
917- )
918- slab_mask = _select_mask_for_volume(
919- _DISTRIBUTED_MASK, c_idx=i
920- )
921- try:
922- masks = run_distributed_on_slab(
923- slab_z,
924- f"C{{i+1}}",
925- slab_mask=slab_mask,
926- )
927- finally:
928- _shutil.rmtree(slab_dir, ignore_errors=True)
929- else:
930- masks = process_volume(
931- np.array(zarr_array[i, :, :, :]), f"C{{i+1}}"
932- )
928+ masks = run_zyx_slab(
929+ zarr_array,
930+ (i,),
931+ (Z, Y, X),
932+ chunk_3d,
933+ _tmp_slabs,
934+ name=f"C{{i+1}}",
935+ c_idx=i,
936+ )
933937
934938 if len(indices) == 1:
935939 result[:, :, :] = masks
0 commit comments