From 4dbf169dca81bb5e9e26d69e89f2fa59de2bcf5f Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Sun, 15 Jun 2025 20:55:09 -0700 Subject: [PATCH 1/7] upgrade the segmentation to use cellpose sam --- biahub/segment.py | 42 +++++++++++++++++++++++------------------- biahub/settings.py | 9 ++++++--- pyproject.toml | 2 +- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/biahub/segment.py b/biahub/segment.py index 96ffcdfa..ce710a5d 100644 --- a/biahub/segment.py +++ b/biahub/segment.py @@ -83,13 +83,21 @@ def segment_data( ) czyx_data[c_idx] = func(czyx_data[c_idx], **kwargs) + # Reorder channels based on channels_for_segmentation + cellpose_czyx = np.zeros( + (3, *czyx_data_to_segment.shape[1:]), dtype=czyx_data_to_segment.dtype + ) + for i, channel in enumerate(model_args.channels): + if channel is not None: + cellpose_czyx[i] = czyx_data_to_segment[channel] + # Apply the segmentation model = models.CellposeModel( - model_type=model_args.path_to_model, gpu=gpu, device=device + gpu=gpu, device=device, pretrained_model=model_args.path_to_model ) segmentation, _, _ = model.eval( - czyx_data_to_segment, channel_axis=0, z_axis=1, **model_args.eval_args - ) # noqa: python-no-eval + cellpose_czyx, channel_axis=0, z_axis=1, **model_args.eval_args + ) if z_slice_2D is not None and isinstance(z_slice_2D, int): segmentation = segmentation[np.newaxis, ...] czyx_segmentation.append(segmentation) @@ -147,23 +155,19 @@ def segment_cli( if model_args.z_slice_2D is not None and isinstance(model_args.z_slice_2D, int): Z = 1 # Ensure channel names exist in the dataset - if not all(channel in channel_names for channel in model_args.eval_args["channels"]): + if not all(channel in channel_names for channel in model_args.channels): raise ValueError( - f"Channels {model_args.eval_args['channels']} not found in dataset {channel_names}" + f"Channels {model_args.channels} not found in dataset {channel_names}" ) - # Channel strings to indices with the cellpose offset of 1 - model_args.eval_args["channels"] = [ - channel_names.index(channel) + 1 for channel in model_args.eval_args["channels"] - ] + # Channel strings to indices to be used in cellpose. Hiding this from the + model_args.channels = [channel_names.index(channel) for channel in model_args.channels] # NOTE:List of channels, either of length 2 or of length number of images by 2. # First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue). - # Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue). - if len(model_args.eval_args["channels"]) < 2: - model_args.eval_args["channels"].append(0) + # Second element of list is the optional nuclear channel or organelle channel. Only 3 channels are supported. The rest are ignored + if len(model_args.channels) < 2: + model_args.channels.append(0) - click.echo( - f"Segmenting with model {model_name} using channels {model_args.eval_args['channels']}" - ) + click.echo(f"Segmenting with model {model_name} using channels {model_args.channels}") if ( "anisotropy" not in model_args.eval_args or model_args.eval_args["anisotropy"] is None @@ -202,7 +206,7 @@ def segment_cli( # Estimate resources num_cpus, gb_ram_request = estimate_resources(shape=segmentation_shape, ram_multiplier=20) num_gpus = 1 - slurm_time = np.ceil(np.max([80, T * 2.5])).astype(int) + slurm_time = np.ceil(np.max([120, T * Z * 5])).astype(int) slurm_array_parallelism = 100 # Prepare SLURM arguments slurm_args = { @@ -236,9 +240,9 @@ def segment_cli( jobs.append( executor.submit( process_single_position, - segment_data, - input_position_path, - output_position_path, + func=segment_data, + input_position_path=input_position_path, + output_position_path=output_position_path, input_channel_indices=[list(range(C))], output_channel_indices=[list(range(C_segment))], num_processes=np.min([20, int(num_cpus * 0.8)]), diff --git a/biahub/settings.py b/biahub/settings.py index b0f8b143..10990434 100644 --- a/biahub/settings.py +++ b/biahub/settings.py @@ -14,7 +14,6 @@ PositiveInt, field_validator, model_validator, - validator, ) @@ -393,10 +392,12 @@ class PreprocessingFunctions(BaseModel): class SegmentationModel(BaseModel): path_to_model: str eval_args: Dict[str, Any] + channels: list[str] z_slice_2D: Optional[int] = None preprocessing: list[PreprocessingFunctions] = [] - @validator("eval_args", pre=True) + @field_validator("eval_args") + @classmethod def validate_eval_args(cls, value): # Retrieve valid arguments dynamically if cellpose is required valid_args = get_valid_eval_args() @@ -410,7 +411,8 @@ def validate_eval_args(cls, value): return value - @validator("z_slice_2D") + @field_validator("z_slice_2D") + @classmethod def check_z_slice_with_do_3D(cls, z_slice_2D, values): # Only run this check if z_slice is provided (not None) and do_3D exists in eval_args if z_slice_2D is not None: @@ -421,6 +423,7 @@ def check_z_slice_with_do_3D(cls, z_slice_2D, values): "If 'z_slice_2D' is provided, 'do_3D' in 'eval_args' must be set to False." ) z_slice_2D = 0 + return z_slice_2D diff --git a/pyproject.toml b/pyproject.toml index fc7eb3e8..ba7271e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ dev = [ ] segmentation = [ - "cellpose", + "cellpose>=4.0.4", ] build = ["build", "twine"] From a234a4115d1844aeffab15d142f254ff7bad5377 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Sun, 15 Jun 2025 21:08:18 -0700 Subject: [PATCH 2/7] update the segmentation config --- biahub/segment.py | 17 ++++++++---- settings/example_segmentation_settings.yml | 32 +++++++++++++--------- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/biahub/segment.py b/biahub/segment.py index ce710a5d..204bf99a 100644 --- a/biahub/segment.py +++ b/biahub/segment.py @@ -1,3 +1,5 @@ +import warnings + from pathlib import Path import click @@ -161,11 +163,16 @@ def segment_cli( ) # Channel strings to indices to be used in cellpose. Hiding this from the model_args.channels = [channel_names.index(channel) for channel in model_args.channels] - # NOTE:List of channels, either of length 2 or of length number of images by 2. - # First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue). - # Second element of list is the optional nuclear channel or organelle channel. Only 3 channels are supported. The rest are ignored - if len(model_args.channels) < 2: - model_args.channels.append(0) + # NOTE: Cellpose requires 3 channels. If the channels list is less than 3, the first channel is repeated. + + if len(model_args.channels) < 3: + model_args.channels.extend( + [model_args.channels[0]] * (3 - len(model_args.channels)) + ) + else: + warnings.warn( + f"Model {model_name} has more than 3 channels. Only the first 3 channels will be used." + ) click.echo(f"Segmenting with model {model_name} using channels {model_args.channels}") if ( diff --git a/settings/example_segmentation_settings.yml b/settings/example_segmentation_settings.yml index a9fa6bf0..37dddfb7 100644 --- a/settings/example_segmentation_settings.yml +++ b/settings/example_segmentation_settings.yml @@ -3,18 +3,17 @@ models: # One can instantiate as many models membrane: - path_to_model: "/path/to/nucleus/model or name of built-in cellpose model (e.g. cyto3 or nuclei)" - # These are the common model.CellposeModel().eval() arguments used, but one can add more. - # For more information, see https://cellpose.readthedocs.io/en/latest/api.html#id0 + path_to_model: 'cpsam' # Default: Cellpose-SAM model. Path to the pretrained model. eval_args: diameter: 65 - channels: ['mem', 'nuc'] #The channel count for Cellpose starts at 1 cellprob_threshold: 0.4 invert: false - do_3D: false # Optional, if false, 2D segmentation is performed. + do_3D: true # Optional, if false, 2D segmentation is performed. if true, z_slice and channel_axis must be provided. anisotropy: 3.26 min_size: 8000 - z_slice_2D: 10 # Optional, if null, 3D segmentation is performed and checks eval_args.do_3D=True + normalize: {"tile_norm_blocksize": 0} # Optional, if 0, the whole image is used. Cellpose suggests 100-200 if one sees imhomogeneity + z_slice_2D: null # Optional, if null, 3D segmentation is performed and checks eval_args.do_3D=True + channels: ['mem', 'nuc'] preprocessing: - function: skimage.exposure.rescale_intensity #configurable callables like rescaling intensity kwargs: {"out_range": [0, 1]} @@ -22,16 +21,23 @@ models: - function: skimage.exposure.equalize_adapthist kwargs: {"clip_limit": 0.01,"kernel_size":[5, 32, 32]} channel: 'mem' + # One can instantiate as many models nucleus: - path_to_model: "/path/to/nucleus/model or name of built-in cellpose model (e.g. cyto3 or nuclei)" - # These are the common model.CellposeModel().eval() arguments used, but one can add more. - # For more information, see https://cellpose.readthedocs.io/en/latest/api.html#id0 + path_to_model: 'cpsam' # Default: Cellpose-SAM model. Path to the pretrained model. eval_args: - diameter: 60 - channels: ['nuc'] #For nucleus segmentation, only one channel is required. We populate the other channel with zero. - cellprob_threshold: 0.0 + diameter: 55 + cellprob_threshold: 0.4 invert: false - do_3D: true + do_3D: true # Optional, if false, 2D segmentation is performed. if true, z_slice and channel_axis must be provided. anisotropy: 3.26 min_size: 8000 + normalize: {"tile_norm_blocksize": 0} # Optional, if 0, the whole image is used. Cellpose suggests 100-200 if one sees imhomogeneity z_slice_2D: null # Optional, if null, 3D segmentation is performed and checks eval_args.do_3D=True + channels: ['nuc','mem'] + preprocessing: + - function: skimage.exposure.rescale_intensity #configurable callables like rescaling intensity + kwargs: {"out_range": [0, 1]} + channel: 'nuc' + - function: skimage.exposure.equalize_adapthist + kwargs: {"clip_limit": 0.01,"kernel_size":[5, 32, 32]} + channel: 'nuc' From 6d9245c75e980793ae68e2fe669b23894526ee0c Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 2 Jul 2025 17:42:12 -0700 Subject: [PATCH 3/7] fix bug when using 2d zslice --- biahub/segment.py | 11 +++++++---- biahub/settings.py | 16 +++++++++------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/biahub/segment.py b/biahub/segment.py index 204bf99a..0842945c 100644 --- a/biahub/segment.py +++ b/biahub/segment.py @@ -66,9 +66,12 @@ def segment_data( for i, (model_name, model_args) in enumerate(segmentation_models.items()): click.echo(f"Segmenting with model {model_name}") z_slice_2D = model_args.z_slice_2D - czyx_data_to_segment = ( - czyx_data[:, z_slice_2D : z_slice_2D + 1] if z_slice_2D is not None else czyx_data - ) + if z_slice_2D is not None: + czyx_data_to_segment = czyx_data[:, z_slice_2D] + z_axis = None + else: + czyx_data_to_segment = czyx_data + z_axis = 1 # Apply preprocessing functions preprocessing_functions = model_args.preprocessing for preproc in preprocessing_functions: @@ -98,7 +101,7 @@ def segment_data( gpu=gpu, device=device, pretrained_model=model_args.path_to_model ) segmentation, _, _ = model.eval( - cellpose_czyx, channel_axis=0, z_axis=1, **model_args.eval_args + cellpose_czyx, channel_axis=0, z_axis=z_axis, **model_args.eval_args ) if z_slice_2D is not None and isinstance(z_slice_2D, int): segmentation = segmentation[np.newaxis, ...] diff --git a/biahub/settings.py b/biahub/settings.py index 10990434..bdebb660 100644 --- a/biahub/settings.py +++ b/biahub/settings.py @@ -413,15 +413,17 @@ def validate_eval_args(cls, value): @field_validator("z_slice_2D") @classmethod - def check_z_slice_with_do_3D(cls, z_slice_2D, values): + def check_z_slice_with_do_3D(cls, z_slice_2D, info): # Only run this check if z_slice is provided (not None) and do_3D exists in eval_args if z_slice_2D is not None: - eval_args = values.get("eval_args", {}) - do_3D = eval_args.get("do_3D", None) - if do_3D: - raise ValueError( - "If 'z_slice_2D' is provided, 'do_3D' in 'eval_args' must be set to False." - ) + # In Pydantic v2, we need to access the data differently + if hasattr(info, 'data') and 'eval_args' in info.data: + eval_args = info.data['eval_args'] + do_3D = eval_args.get("do_3D", None) + if do_3D: + raise ValueError( + "If 'z_slice_2D' is provided, 'do_3D' in 'eval_args' must be set to False." + ) z_slice_2D = 0 return z_slice_2D From 8101134db1ad631053722bb2d24acbefe7ca4760 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 3 Jul 2025 08:36:04 -0700 Subject: [PATCH 4/7] bump cellpose that supports fp16 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ba7271e4..39d90731 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ dev = [ ] segmentation = [ - "cellpose>=4.0.4", + "cellpose>=4.0.5", ] build = ["build", "twine"] From 95cb565068e179bdb974c4609fd85ca7b48131af Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 3 Jul 2025 16:42:25 -0700 Subject: [PATCH 5/7] modify default submitit parameters and optimizing segment_data func --- biahub/segment.py | 75 +++++++++++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 28 deletions(-) diff --git a/biahub/segment.py b/biahub/segment.py index 0842945c..9b2ad01d 100644 --- a/biahub/segment.py +++ b/biahub/segment.py @@ -29,10 +29,8 @@ def segment_data( segmentation_models: dict, gpu: bool = True, ) -> np.ndarray: - from cellpose import models - """ - Segment a CZYX image using a Cellpose segmentation model + Segment a CZYX image using Cellpose segmentation models. Parameters ---------- @@ -48,8 +46,8 @@ def segment_data( np.ndarray A CZYX segmentation image """ + from cellpose import models - # Segmenetation in cpu or gpu if gpu: try: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -61,22 +59,35 @@ def segment_data( click.echo(f"Using device: {device}") + # Pre-load unique models to avoid redundant loading + unique_models = {} + for model_name, model_args in segmentation_models.items(): + model_path = model_args.path_to_model + if model_path not in unique_models: + click.echo(f"Loading model: {model_path}") + unique_models[model_path] = models.CellposeModel( + gpu=gpu, device=device, pretrained_model=model_path + ) + czyx_segmentation = [] - # Process each model in a loop - for i, (model_name, model_args) in enumerate(segmentation_models.items()): - click.echo(f"Segmenting with model {model_name}") + + # Process each model + for model_name, model_args in segmentation_models.items(): + click.echo(f"Starting segmentation with model {model_name}") + + # Extract the data we need for this model 2D or 3D z_slice_2D = model_args.z_slice_2D if z_slice_2D is not None: - czyx_data_to_segment = czyx_data[:, z_slice_2D] + czyx_data_to_segment = czyx_data[:, z_slice_2D].copy() z_axis = None else: - czyx_data_to_segment = czyx_data + czyx_data_to_segment = czyx_data.copy() z_axis = 1 - # Apply preprocessing functions - preprocessing_functions = model_args.preprocessing - for preproc in preprocessing_functions: + click.echo(f"Segmenting {model_name} with z_axis {z_axis}") + # Apply preprocessing specific to this model + for preproc in model_args.preprocessing: func = preproc.function - kwargs = preproc.kwargs + kwargs = preproc.kwargs.copy() c_idx = preproc.channel # Convert list to tuple for out_range if needed @@ -86,9 +97,9 @@ def segment_data( click.echo( f"Processing with {func.__name__} with kwargs {kwargs} to channel {c_idx}" ) - czyx_data[c_idx] = func(czyx_data[c_idx], **kwargs) + czyx_data_to_segment[c_idx] = func(czyx_data_to_segment[c_idx], **kwargs) - # Reorder channels based on channels_for_segmentation + # Prepare cellpose input cellpose_czyx = np.zeros( (3, *czyx_data_to_segment.shape[1:]), dtype=czyx_data_to_segment.dtype ) @@ -96,19 +107,27 @@ def segment_data( if channel is not None: cellpose_czyx[i] = czyx_data_to_segment[channel] - # Apply the segmentation - model = models.CellposeModel( - gpu=gpu, device=device, pretrained_model=model_args.path_to_model - ) + # Get pre-loaded model and run segmentation + click.echo(f"Running segmentation for {model_args.path_to_model}") + model = unique_models[model_args.path_to_model] segmentation, _, _ = model.eval( cellpose_czyx, channel_axis=0, z_axis=z_axis, **model_args.eval_args ) + + # Handle 2D output formatting if z_slice_2D is not None and isinstance(z_slice_2D, int): segmentation = segmentation[np.newaxis, ...] + czyx_segmentation.append(segmentation) - czyx_segmentation = np.stack(czyx_segmentation, axis=0) - return czyx_segmentation + # Clean up intermediate arrays + del cellpose_czyx, czyx_data_to_segment + + # Clean up GPU memory + if gpu and device.type == 'cuda': + torch.cuda.empty_cache() + + return np.stack(czyx_segmentation, axis=0) @click.command("segment") @@ -177,7 +196,7 @@ def segment_cli( f"Model {model_name} has more than 3 channels. Only the first 3 channels will be used." ) - click.echo(f"Segmenting with model {model_name} using channels {model_args.channels}") + click.echo(f"Segmenting {model_name} using channels {model_args.channels}") if ( "anisotropy" not in model_args.eval_args or model_args.eval_args["anisotropy"] is None @@ -214,17 +233,17 @@ def segment_cli( ) # Estimate resources - num_cpus, gb_ram_request = estimate_resources(shape=segmentation_shape, ram_multiplier=20) + num_cpus, gb_ram_request = estimate_resources(shape=segmentation_shape, ram_multiplier=10) num_gpus = 1 - slurm_time = np.ceil(np.max([120, T * Z * 5])).astype(int) - slurm_array_parallelism = 100 + slurm_time = np.ceil(np.max([120, T * Z * 10])).astype(int) + slurm_array_parallelism = 9 # Prepare SLURM arguments slurm_args = { "slurm_job_name": "segment", "slurm_gres": f"gpu:{num_gpus}", "slurm_mem_per_cpu": f"{gb_ram_request}G", - "slurm_cpus_per_task": np.max([int(20 * 1.3), num_cpus]), - "slurm_array_parallelism": slurm_array_parallelism, # process up to 20 positions at a time + "slurm_cpus_per_task": np.max([int(slurm_array_parallelism * 2), num_cpus]), + "slurm_array_parallelism": slurm_array_parallelism, "slurm_time": slurm_time, "slurm_partition": "gpu", } @@ -255,7 +274,7 @@ def segment_cli( output_position_path=output_position_path, input_channel_indices=[list(range(C))], output_channel_indices=[list(range(C_segment))], - num_processes=np.min([20, int(num_cpus * 0.8)]), + num_processes=np.min([slurm_array_parallelism, int(num_cpus * 0.8)]), segmentation_models=segment_args, ) ) From 744375947c7d2819c50f28c03ca0c8192a2d3d76 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 21 Jul 2025 17:43:29 -0700 Subject: [PATCH 6/7] setting to uint32 --- biahub/segment.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/biahub/segment.py b/biahub/segment.py index 9b2ad01d..fedaa98a 100644 --- a/biahub/segment.py +++ b/biahub/segment.py @@ -127,7 +127,7 @@ def segment_data( if gpu and device.type == 'cuda': torch.cuda.empty_cache() - return np.stack(czyx_segmentation, axis=0) + return np.stack(czyx_segmentation, axis=0).astype(np.uint32) @click.command("segment") @@ -150,6 +150,7 @@ def segment_cli( -i ./input.zarr/*/*/* \ -c ./segment_params.yml \ -o ./output.zarr + """ # Convert string paths to Path objects @@ -223,6 +224,7 @@ def segment_cli( segmentation_shape = (T, C_segment, Z, Y, X) # Create a zarr store output to mirror the input + # Note, dtype is set to uint32. Change this if one envisions having more than 2^32 labels. create_empty_plate( store_path=output_dirpath, position_keys=[path.parts[-3:] for path in input_position_dirpaths], @@ -230,6 +232,7 @@ def segment_cli( shape=segmentation_shape, chunks=None, scale=scale, + dtype=np.uint32, ) # Estimate resources @@ -274,7 +277,7 @@ def segment_cli( output_position_path=output_position_path, input_channel_indices=[list(range(C))], output_channel_indices=[list(range(C_segment))], - num_processes=np.min([slurm_array_parallelism, int(num_cpus * 0.8)]), + num_processes=np.min([5, int(num_cpus * 0.8)]), segmentation_models=segment_args, ) ) From 202f943cce6302a4a02fb22dfcbdb46a21ccb07f Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 30 Jul 2025 17:06:47 -0700 Subject: [PATCH 7/7] enforcing gpu in the config as optional --- biahub/segment.py | 18 ++++++++++-------- biahub/settings.py | 1 + 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/biahub/segment.py b/biahub/segment.py index fedaa98a..14b285ed 100644 --- a/biahub/segment.py +++ b/biahub/segment.py @@ -51,12 +51,12 @@ def segment_data( if gpu: try: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + click.echo(f"Using GPU: {device}") except torch.cuda.CudaError: click.echo("No GPU available. Using CPU") device = torch.device("cpu") else: device = torch.device("cpu") - click.echo(f"Using device: {device}") # Pre-load unique models to avoid redundant loading @@ -66,7 +66,9 @@ def segment_data( if model_path not in unique_models: click.echo(f"Loading model: {model_path}") unique_models[model_path] = models.CellposeModel( - gpu=gpu, device=device, pretrained_model=model_path + gpu=True if device.type == 'cuda' else False, + device=device, + pretrained_model=model_path, ) czyx_segmentation = [] @@ -174,9 +176,9 @@ def segment_cli( # Load the segmentation models with their respective configurations # TODO: implement logic for 2D segmentation. Have a slicing parameter - segment_args = settings.models - C_segment = len(segment_args) - for model_name, model_args in segment_args.items(): + + C_segment = len(settings.models) + for model_name, model_args in settings.models.items(): if model_args.z_slice_2D is not None and isinstance(model_args.z_slice_2D, int): Z = 1 # Ensure channel names exist in the dataset @@ -228,7 +230,7 @@ def segment_cli( create_empty_plate( store_path=output_dirpath, position_keys=[path.parts[-3:] for path in input_position_dirpaths], - channel_names=[model_name + "_labels" for model_name in segment_args.keys()], + channel_names=[model_name + "_labels" for model_name in settings.models.keys()], shape=segmentation_shape, chunks=None, scale=scale, @@ -238,7 +240,7 @@ def segment_cli( # Estimate resources num_cpus, gb_ram_request = estimate_resources(shape=segmentation_shape, ram_multiplier=10) num_gpus = 1 - slurm_time = np.ceil(np.max([120, T * Z * 10])).astype(int) + slurm_time = np.ceil(np.max([600, T * Z * 10])).astype(int) slurm_array_parallelism = 9 # Prepare SLURM arguments slurm_args = { @@ -278,7 +280,7 @@ def segment_cli( input_channel_indices=[list(range(C))], output_channel_indices=[list(range(C_segment))], num_processes=np.min([5, int(num_cpus * 0.8)]), - segmentation_models=segment_args, + segmentation_models=settings.models, ) ) diff --git a/biahub/settings.py b/biahub/settings.py index bdebb660..d1a827e8 100644 --- a/biahub/settings.py +++ b/biahub/settings.py @@ -432,3 +432,4 @@ def check_z_slice_with_do_3D(cls, z_slice_2D, info): class SegmentationSettings(BaseModel): models: Dict[str, SegmentationModel] model_config = {"extra": "forbid", "protected_namespaces": ()} + gpu: bool = True