Skip to content
74 changes: 54 additions & 20 deletions cellmap_flow/blockwise/blockwise_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
import numpy as np
from funlib.geometry.coordinate import Coordinate
from funlib.persistence import Array, open_ds, prepare_ds
from zarr.storage import NestedDirectoryStore
from zarr.hierarchy import open_group
from zarr.storage import DirectoryStore

from functools import partial
from cellmap_flow.globals import g
from cellmap_flow.image_data_interface import ImageDataInterface
from cellmap_flow.inferencer import Inferencer
from cellmap_flow.utils.config_utils import build_models, load_config
from cellmap_flow.utils.serilization_utils import get_process_dataset
from cellmap_flow.utils.ds import generate_singlescale_metadata

logger = logging.getLogger(__name__)

Expand All @@ -36,6 +39,7 @@ def __init__(self, yaml_config: str, create=False):
self.output_path = self.config["output_path"]
self.output_path = Path(self.output_path)


output_channels = None
if "output_channels" in self.config:
output_channels = self.config["output_channels"].split(",")
Expand All @@ -50,20 +54,29 @@ def __init__(self, yaml_config: str, create=False):
if "workers" not in self.config:
logger.error("Missing required field in YAML: workers")
return

task_name = self.config["task_name"]
self.workers = self.config["workers"]
if self.workers <= 1:
logger.error("Workers should be greater than 1.")
return
self.cpu_workers = self.config.get("cpu_workers", 12)
if "create" in self.config:
# Only honor the yaml's `create` value when this process was started with
# create=True; client workers (create=False) must never re-create the output.
if "create" in self.config and create == True:
create = self.config["create"]
if isinstance(create, str):
logger.warning(
f"Type config[create] is str = {create}, better set a bool"
)
create = create.lower() == "true"

task_name = self.config["task_name"]
if "tmp_dir" not in self.config:
logger.error("Missing required field in YAML: tmp_dir, it is mandatory to track progress")
return

self.tmp_dir = Path(self.config["tmp_dir"]) / f"tmp_flow_daisy_progress_{task_name}"
if not self.tmp_dir.exists():
self.tmp_dir.mkdir(parents=True, exist_ok=True)

# Build model configuration objects
models = build_models(self.config["models"])
Expand Down Expand Up @@ -101,7 +114,7 @@ def __init__(self, yaml_config: str, create=False):
if json_data:
g.input_norms, g.postprocess = get_process_dataset(json_data)

self.inferencer = Inferencer(self.model_config)
self.inferencer = Inferencer(self.model_config, use_half_prediction=False)

self.idi_raw = ImageDataInterface(
self.input_path, voxel_size=self.input_voxel_size
Expand All @@ -118,26 +131,43 @@ def __init__(self, yaml_config: str, create=False):
print(f"type: {self.dtype}")
print(f"output_path: {self.output_path}")
for channel in self.output_channels:
ndim = len(self.block_shape)
if create:
try:
array = prepare_ds(
DirectoryStore(self.output_path / channel / "s0"),
NestedDirectoryStore(self.output_path / channel / "s0"),
output_shape,
dtype=self.dtype,
chunk_shape=self.block_shape,
voxel_size=self.output_voxel_size,
axis_names=["z", "y", "x"],
units=["nm", "nm", "nm"],
offset=(0, 0, 0),
axis_names=["z", "y", "x"][-ndim:],
units=["nanometer",]*ndim,
offset=(0,)*ndim,
)
except Exception as e:
raise Exception(
f"Failed to prepare {self.output_path/channel/'s0'} \n try deleting it manually and run again ! {e}"
)
try:
z_store = NestedDirectoryStore(self.output_path / channel)
zg = open_group(store=z_store, mode='a')
if 'multiscales' in zg.attrs:
raise ValueError(f'multiscales attribute already exists in {z_store.path}')
else:
zattrs = generate_singlescale_metadata(arr_name='s0',
voxel_size=self.output_voxel_size,
translation=[0.0,]*ndim,
units=['nanometer',]*ndim,
axes=['z', 'y', 'x'][-ndim:])
zg.attrs['multiscales'] = zattrs['multiscales']
except Exception as e:
raise Exception(
f"Failed to prepare ome-ngff metadata for {self.output_path/channel/'s0'}, {e}"
)
else:
try:
array = open_ds(
DirectoryStore(self.output_path / channel / "s0"),
NestedDirectoryStore(self.output_path / channel / "s0"),
"a",
)
except Exception as e:
Expand All @@ -156,9 +186,6 @@ def process_fn(self, block):

chunk_data = chunk_data.astype(self.dtype)

if self.output_arrays[0][block.write_roi].any():
return

for i, array in enumerate(self.output_arrays):

if chunk_data.shape == 3:
Expand All @@ -177,17 +204,23 @@ def process_fn(self, block):
self.output_voxel_size,
)
array[write_roi] = predictions.to_ndarray(write_roi)
logger.info(f"Processed block {block.id} with write ROI {write_roi}")



def client(self):
    """Daisy worker loop: acquire blocks and process them until none remain.

    For each block, runs ``self.process_fn``; on success the block is marked
    ``SUCCESS`` and an empty per-block marker file is touched in
    ``self.tmp_dir`` so that ``check_block`` can skip it on re-runs. On any
    exception the error is logged and the block is marked ``FAILED`` instead
    of crashing the worker.
    """
    client = daisy.Client()
    while True:
        with client.acquire_block() as block:
            if block is None:
                break
            try:
                self.process_fn(block)
                block.status = daisy.BlockStatus.SUCCESS
                # block_id looks like a (task_name, block_index) pair in
                # daisy; the integer index names the progress marker file.
                # TODO confirm against the pinned daisy version.
                (self.tmp_dir / f"{block.block_id[1]}").touch()
            except Exception as e:
                logger.error(f"Error processing block {block}: {e}")
                block.status = daisy.BlockStatus.FAILED

def run(self):

Expand Down Expand Up @@ -217,17 +250,18 @@ def run(self):
self.queue,
ncpu=self.cpu_workers,
),
read_write_conflict=True,
check_function=partial(check_block, self.tmp_dir),
read_write_conflict=False,
fit="overhang",
max_retries=0,
timeout=None,
num_workers=self.workers,
)

daisy.run_blockwise([task])
# , multiprocessing= False


task_state = daisy.run_blockwise([task])
logger.info(f"Task state: {task_state}")
def check_block(tmp_dir, block: daisy.Block) -> bool:
    """Return True if *block* has already been processed successfully.

    Progress is tracked via empty marker files that the worker touches in
    ``tmp_dir`` after processing each block (see ``client``), so completed
    blocks are skipped when a task is resumed.

    Args:
        tmp_dir: Directory (``pathlib.Path``) holding per-block marker files.
        block: The daisy block whose completion status is queried.

    Returns:
        True when the marker file for this block exists.
    """
    # block.block_id appears to be a (task_name, block_index) pair in daisy;
    # index 1 is the unique integer index of the block within the task.
    # TODO confirm against the pinned daisy version.
    return (tmp_dir / f"{block.block_id[1]}").exists()
Comment on lines +263 to +264
Copy link

Copilot AI Sep 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The magic index [1] for block.block_id is unclear. Consider adding a comment explaining what this index represents or using a named constant.

Copilot uses AI. Check for mistakes.
def spawn_worker(name, yaml_config, charge_group, queue, ncpu=12):
def run_worker():
if not Path("prediction_logs").exists():
Expand Down
5 changes: 4 additions & 1 deletion cellmap_flow/cli/fly_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def main():
if "charge_group" not in data:
raise ValueError("charge_group is required in the YAML file")
charge_group = data["charge_group"]

input_size = tuple(data.get("input_size", (178, 178, 178)))
output_size = tuple(data.get("output_size", (56, 56, 56)))
g.charge_group = charge_group
threads = []
for run_name, run_items in data["runs"].items():
Expand All @@ -51,6 +52,8 @@ def main():
input_voxel_size=res,
output_voxel_size=res,
name=run_name,
input_size=input_size,
output_size=output_size,
)
model_command = f"fly -c {model_config.checkpoint_path} -ch {','.join(model_config.channels)} -ivs {','.join(map(str,model_config.input_voxel_size))} -ovs {','.join(map(str,model_config.output_voxel_size))}"
command = f"{SERVER_COMMAND} {model_command} -d {data_path}"
Expand Down
30 changes: 30 additions & 0 deletions cellmap_flow/utils/ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,36 @@

from cellmap_flow.globals import g

def generate_singlescale_metadata(
    arr_name: str,
    voxel_size: list,
    translation: list,
    units: list,
    axes: list,
):
    """Build OME-NGFF (version 0.4) ``multiscales`` metadata for one scale level.

    Args:
        arr_name: Path of the scale-level array inside the group, e.g. ``"s0"``.
        voxel_size: Physical size of a voxel along each axis.
        translation: Physical offset of the array origin along each axis.
        units: Physical unit per axis, e.g. ``["nanometer"] * 3``.
        axes: Axis names, e.g. ``["z", "y", "x"]``.

    Returns:
        A dict with a single ``"multiscales"`` key, suitable for assignment
        into a zarr group's ``.attrs``.
    """
    ndim = len(voxel_size)
    z_attrs: dict = {
        "multiscales": [
            {
                # All axes in this pipeline are spatial.
                "axes": [
                    {"name": axis, "type": "space", "unit": unit}
                    for axis, unit in zip(axes, units)
                ],
                # The group-level transform is identity; the per-dataset
                # transform below carries the actual scale and translation.
                "coordinateTransformations": [
                    {"scale": [1.0] * ndim, "type": "scale"}
                ],
                "datasets": [
                    {
                        "coordinateTransformations": [
                            {"scale": voxel_size, "type": "scale"},
                            {"translation": translation, "type": "translation"},
                        ],
                        "path": arr_name,
                    }
                ],
                "name": "/",
                "version": "0.4",
            }
        ]
    }
    return z_attrs



def get_scale_info(zarr_grp):
attrs = zarr_grp.attrs
Expand Down