diff --git a/cellmap_flow/blockwise/blockwise_processor.py b/cellmap_flow/blockwise/blockwise_processor.py index ba177ea..e171f2f 100644 --- a/cellmap_flow/blockwise/blockwise_processor.py +++ b/cellmap_flow/blockwise/blockwise_processor.py @@ -6,13 +6,16 @@ import numpy as np from funlib.geometry.coordinate import Coordinate from funlib.persistence import Array, open_ds, prepare_ds +from zarr.storage import NestedDirectoryStore +from zarr.hierarchy import open_group from zarr.storage import DirectoryStore - +from functools import partial from cellmap_flow.globals import g from cellmap_flow.image_data_interface import ImageDataInterface from cellmap_flow.inferencer import Inferencer from cellmap_flow.utils.config_utils import build_models, load_config from cellmap_flow.utils.serilization_utils import get_process_dataset +from cellmap_flow.utils.ds import generate_singlescale_metadata logger = logging.getLogger(__name__) @@ -36,6 +39,7 @@ def __init__(self, yaml_config: str, create=False): self.output_path = self.config["output_path"] self.output_path = Path(self.output_path) + output_channels = None if "output_channels" in self.config: output_channels = self.config["output_channels"].split(",") @@ -50,12 +54,15 @@ def __init__(self, yaml_config: str, create=False): if "workers" not in self.config: logger.error("Missing required field in YAML: workers") return + + task_name = self.config["task_name"] self.workers = self.config["workers"] if self.workers <= 1: logger.error("Workers should be greater than 1.") return self.cpu_workers = self.config.get("cpu_workers", 12) - if "create" in self.config: + # Added and create == True to fix client error when create: True in the yaml, so when it is a client it will not be changed + if "create" in self.config and create == True: create = self.config["create"] if isinstance(create, str): logger.warning( @@ -63,7 +70,13 @@ def __init__(self, yaml_config: str, create=False): ) create = create.lower() == "true" - task_name = 
self.config["task_name"] + if "tmp_dir" not in self.config: + logger.error("Missing required field in YAML: tmp_dir, it is mandatory to track progress") + return + + self.tmp_dir = Path(self.config["tmp_dir"]) / f"tmp_flow_daisy_progress_{task_name}" + if not self.tmp_dir.exists(): + self.tmp_dir.mkdir(parents=True, exist_ok=True) # Build model configuration objects models = build_models(self.config["models"]) @@ -101,7 +114,7 @@ def __init__(self, yaml_config: str, create=False): if json_data: g.input_norms, g.postprocess = get_process_dataset(json_data) - self.inferencer = Inferencer(self.model_config) + self.inferencer = Inferencer(self.model_config, use_half_prediction=False) self.idi_raw = ImageDataInterface( self.input_path, voxel_size=self.input_voxel_size @@ -118,26 +131,43 @@ def __init__(self, yaml_config: str, create=False): print(f"type: {self.dtype}") print(f"output_path: {self.output_path}") for channel in self.output_channels: + ndim = len(self.block_shape) if create: try: array = prepare_ds( - DirectoryStore(self.output_path / channel / "s0"), + NestedDirectoryStore(self.output_path / channel / "s0"), output_shape, dtype=self.dtype, chunk_shape=self.block_shape, voxel_size=self.output_voxel_size, - axis_names=["z", "y", "x"], - units=["nm", "nm", "nm"], - offset=(0, 0, 0), + axis_names=["z", "y", "x"][-ndim:], + units=["nanometer",]*ndim, + offset=(0,)*ndim, ) except Exception as e: raise Exception( f"Failed to prepare {self.output_path/channel/'s0'} \n try deleting it manually and run again ! 
{e}" ) + try: + z_store = NestedDirectoryStore(self.output_path / channel) + zg = open_group(store=z_store, mode='a') + if 'multiscales' in zg.attrs: + raise ValueError(f'multiscales attribute already exists in {z_store.path}') + else: + zattrs = generate_singlescale_metadata(arr_name='s0', + voxel_size=self.output_voxel_size, + translation=[0.0,]*ndim, + units=['nanometer',]*ndim, + axes=['z', 'y', 'x'][-ndim:]) + zg.attrs['multiscales'] = zattrs['multiscales'] + except Exception as e: + raise Exception( + f"Failed to prepare ome-ngff metadata for {self.output_path/channel/'s0'}, {e}" + ) else: try: array = open_ds( - DirectoryStore(self.output_path / channel / "s0"), + NestedDirectoryStore(self.output_path / channel / "s0"), "a", ) except Exception as e: @@ -156,9 +186,6 @@ def process_fn(self, block): chunk_data = chunk_data.astype(self.dtype) - if self.output_arrays[0][block.write_roi].any(): - return - for i, array in enumerate(self.output_arrays): if chunk_data.shape == 3: @@ -177,7 +204,8 @@ def process_fn(self, block): self.output_voxel_size, ) array[write_roi] = predictions.to_ndarray(write_roi) - logger.info(f"Processed block {block.id} with write ROI {write_roi}") + + def client(self): client = daisy.Client() @@ -185,9 +213,14 @@ def client(self): with client.acquire_block() as block: if block is None: break - self.process_fn(block) + try: + self.process_fn(block) - block.status = daisy.BlockStatus.SUCCESS + block.status = daisy.BlockStatus.SUCCESS + (self.tmp_dir / f"{block.block_id[1]}").touch() + except Exception as e: + logger.error(f"Error processing block {block}: {e}") + block.status = daisy.BlockStatus.FAILED def run(self): @@ -217,17 +250,18 @@ def run(self): self.queue, ncpu=self.cpu_workers, ), - read_write_conflict=True, + check_function=partial(check_block, self.tmp_dir), + read_write_conflict=False, fit="overhang", max_retries=0, timeout=None, num_workers=self.workers, ) - daisy.run_blockwise([task]) - # , multiprocessing= False - - + 
def generate_singlescale_metadata(
    arr_name: str,
    voxel_size: list,
    translation: list,
    units: "list | str",
    axes: list,
) -> dict:
    """Build OME-NGFF v0.4 ``multiscales`` attributes for one array.

    The result describes a single-scale image: one entry in ``datasets``
    pointing at ``arr_name``, with a per-dataset scale and translation and
    an identity scale at the multiscale level.

    Args:
        arr_name: Path of the array inside the group (e.g. ``"s0"``).
        voxel_size: Per-axis scale written to the dataset's ``scale``
            transformation. Stored as passed, not copied.
        translation: Per-axis offset written to the dataset's
            ``translation`` transformation. Stored as passed, not copied.
        units: One unit name per axis, or a single string that is applied
            to every axis.
        axes: Axis names, in array order. Every axis is emitted with
            ``"type": "space"``.

    Returns:
        A dict of the form ``{"multiscales": [ {...} ]}`` holding the
        ``axes``, ``coordinateTransformations``, ``datasets``, ``name`` and
        ``version`` keys.
    """
    # A bare string would otherwise be zipped character by character
    # ("nanometer" -> "n", "a", "n", ...), yielding bogus per-axis units and
    # silently dropping axes. Broadcast it to one unit per axis instead.
    if isinstance(units, str):
        units = [units] * len(axes)

    axes_metadata = [
        {"name": axis, "type": "space", "unit": unit}
        for axis, unit in zip(axes, units)
    ]

    # Multiscale-level transform is the identity; the real voxel size and
    # offset live on the dataset entry below.
    identity_scale = {"scale": [1.0] * len(voxel_size), "type": "scale"}

    dataset_entry = {
        "coordinateTransformations": [
            {"scale": voxel_size, "type": "scale"},
            {"translation": translation, "type": "translation"},
        ],
        "path": arr_name,
    }

    multiscale = {
        "axes": axes_metadata,
        "coordinateTransformations": [identity_scale],
        "datasets": [dataset_entry],
        "name": "/",
        "version": "0.4",
    }
    return {"multiscales": [multiscale]}