Merged
Changes from 15 commits
39 changes: 38 additions & 1 deletion tools/bioimaging/bioimage_inference.xml
@@ -2,7 +2,7 @@
<description>with PyTorch</description>
<macros>
<token name="@TOOL_VERSION@">2.4.1</token>
<token name="@VERSION_SUFFIX@">0</token>
<token name="@VERSION_SUFFIX@">1</token>
</macros>
<creator>
<organization name="European Galaxy Team" url="https://galaxyproject.org/eu/" />
@@ -30,12 +30,14 @@
--imaging_model '$input_imaging_model'
--image_file '$input_image_file'
--image_size '$input_image_input_size'
--image_axes '$input_image_input_axes'
]]>
</command>
<inputs>
<param name="input_imaging_model" type="data" format="zip" label="BioImage.IO model" help="Please upload a BioImage.IO model."/>
<param name="input_image_file" type="data" format="tiff,png" label="Input image" help="Please provide an input image for the analysis."/>
<param name="input_image_input_size" type="text" label="Size of the input image" help="Provide the size of the input image. See the chosen model's RDF file to find the correct input size. For example: for the BioImage.IO model MitochondriaEMSegmentationBoundaryModel, the input size is 256 x 256 x 32 x 1. Enter the size as 256,256,32,1."/>
<param name="input_image_input_axes" type="text" label="Axes of the input image" help="Provide the input axes of the input image. See the chosen model's RDF file to find the correct axes. For example: for the BioImage.IO model MitochondriaEMSegmentationBoundaryModel, the input axes is 'bzyx'"/>
Owner: could this be a select?

Owner: btw. a text param can also have select options ...

Contributor: can't it be extracted from the RDF so that the user doesn't need to know this?

Contributor (Author):
The parameter axes provides the real shape (number of dimensions) of the input image/matrix required by the model to make inferences. The parameter input_size is needed to slice pixels from the different input image dimensions. So I think both parameters are important for determining the target size of the image passed to the model. The RDF file is not available inside the tool, so we cannot extract the axes information from it. Also, in a few RDF files the axes information is not straightforward and is written like this (from the 3d-unet-arabidopsis-apical-stem-cells model):

inputs:
  - axes:
      - type: batch
      - channel_names:
          - channel0
        id: channel
        type: channel
      - id: z
        scale: 1
        size: 100
        type: space
      - id: 'y'
        scale: 1
        size: 128
        type: space
      - id: x
        scale: 1
        size: 128
        type: space

Users have to infer the correct axes string from the structure above; here it is bczyx. Other RDF files give this information directly as bczyx (bczyx == batch, channel, z, y, x).

I can provide this as a select option.
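
(Editorial illustration, not part of this PR: assuming the RDF follows the structure quoted above and has already been parsed with PyYAML, the axes string could be derived roughly as below. The helper name and the rdf.yaml path are invented.)

    import yaml

    AXIS_LETTER = {"batch": "b", "channel": "c", "time": "t"}

    def axes_string_from_rdf(rdf: dict) -> str:
        # Walk the first input's axes list and map each axis to one letter;
        # spatial axes carry their letter directly in the "id" field (x, y, z).
        letters = []
        for axis in rdf["inputs"][0]["axes"]:
            letters.append(AXIS_LETTER.get(axis.get("type"), axis.get("id", "")))
        return "".join(letters)

    with open("rdf.yaml") as fh:  # hypothetical path to the model's RDF
        print(axes_string_from_rdf(yaml.safe_load(fh)))  # expected: "bczyx"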

Contributor: @FynnBe I understood that the info in the RDF file is standardised, right?

</inputs>
<outputs>
<data format="tif" name="output_predicted_image" from_work_dir="output_predicted_image.tif" label="Predicted image"></data>
@@ -46,16 +48,50 @@
<param name="input_imaging_model" value="input_imaging_model.zip" location="https://zenodo.org/api/records/6647674/files/weights-torchscript.pt/content"/>
<param name="input_image_file" value="input_image_file.tif" location="https://zenodo.org/api/records/6647674/files/sample_input_0.tif/content"/>
<param name="input_image_input_size" value="256,256,1,1"/>
<param name="input_image_input_axes" value="bcyx"/>
<output name="output_predicted_image" file="output_nucleisegboundarymodel.tif" compare="sim_size" delta="100" />
<output name="output_predicted_image_matrix" file="output_nucleisegboundarymodel_matrix.npy" compare="sim_size" delta="100" />
</test>
<test>
<param name="input_imaging_model" value="input_imaging_model.zip" location="https://zenodo.org/api/records/6647674/files/weights-torchscript.pt/content"/>
<param name="input_image_file" value="input_nucleisegboundarymodel.png"/>
<param name="input_image_input_size" value="256,256,1,1"/>
<param name="input_image_input_axes" value="bcyx"/>
<output name="output_predicted_image" file="output_nucleisegboundarymodel.tif" compare="sim_size" delta="100" />
<output name="output_predicted_image_matrix" file="output_nucleisegboundarymodel_matrix.npy" compare="sim_size" delta="100" />
</test>
<test>
<param name="input_imaging_model" value="input_imaging_model.zip" location="https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/emotional-cricket/1.1/files/torchscript_tracing.pt"/>
<param name="input_image_file" value="input_image_file.tif" location="https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/emotional-cricket/1.1/files/sample_input_0.tif"/>
<param name="input_image_input_size" value="128,128,100,1"/>
<param name="input_image_input_axes" value="bczyx"/>
<output name="output_predicted_image" file="output_3d-unet-arabidopsis-apical-stem-cells.tif" compare="sim_size" delta="100" />
<output name="output_predicted_image_matrix" file="output_3d-unet-arabidopsis-apical-stem-cells.npy" compare="sim_size" delta="100" />
</test>
<test>
<param name="input_imaging_model" value="input_imaging_model.zip" location="https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/emotional-cricket/1.1/files/torchscript_tracing.pt"/>
<param name="input_image_file" value="input_3d-unet-arabidopsis-apical-stem-cells.png" location="https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/emotional-cricket/1.1/files/raw.png"/>
<param name="input_image_input_size" value="128,128,100,1"/>
<param name="input_image_input_axes" value="bczyx"/>
<output name="output_predicted_image" file="output_3d-unet-arabidopsis-apical-stem-cells.tif" compare="sim_size" delta="100" />
<output name="output_predicted_image_matrix" file="output_3d-unet-arabidopsis-apical-stem-cells.npy" compare="sim_size" delta="100" />
</test>
<test>
<param name="input_imaging_model" value="input_imaging_model.zip" location="https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/organized-badger/1/files/weights-torchscript.pt"/>
<param name="input_image_file" value="input_platynereisemnucleisegmentationboundarymodel.tif" location="https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/organized-badger/1/files/sample_input_0.tif"/>
<param name="input_image_input_size" value="256,256,32,1"/>
<param name="input_image_input_axes" value="bczyx"/>
<output name="output_predicted_image" file="output_platynereisemnucleisegmentationboundarymodel.tif" compare="sim_size" delta="100" />
<output name="output_predicted_image_matrix" file="output_platynereisemnucleisegmentationboundarymodel.npy" compare="sim_size" delta="100" />
</test>
<test>
<param name="input_imaging_model" value="input_imaging_model.zip" location="https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/thoughtful-turtle/1/files/torchscript_tracing.pt"/>
<param name="input_image_file" value="input_3d-unet-lateral-root-primordia-cells.tif" location="https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/thoughtful-turtle/1/files/sample_input_0.tif"/>
<param name="input_image_input_size" value="128,128,100,1"/>
<param name="input_image_input_axes" value="bczyx"/>
<output name="output_predicted_image" file="output_3d-unet-lateral-root-primordia-cells.tif" compare="sim_size" delta="100" />
<output name="output_predicted_image_matrix" file="output_3d-unet-lateral-root-primordia-cells.npy" compare="sim_size" delta="100" />
</test>
</tests>
<help>
<![CDATA[
Expand All @@ -67,6 +103,7 @@
- BioImage.IO model: Add one of the models from the Galaxy file uploader by choosing a "remote" file at "ML Models/bioimaging-models"
- Image to be analyzed: Provide an image as a TIF/PNG file
- Provide the necessary input size for the model. This information can be found in the RDF file of each model (RDF file > config > test_information > inputs > size)
- Provide the axes of the input image. This information can also be found in the RDF file of each model (RDF file > inputs > axes). An example axes value is 'bczyx' for the 3D U-Net Arabidopsis Lateral Root Primordia model

**Output files**
- Predicted image: Predicted image using the BioImage.IO model
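As a quick, illustrative sanity check of the help text above (values taken from the 3D U-Net test case in this PR, not from the tool itself), the size and axes strings map onto the model's expected input tensor shape roughly as follows:

    # Illustrative only: relate the two Galaxy text parameters to a tensor shape.
    axes = "bczyx"            # RDF file > inputs > axes
    size = "128,128,100,1"    # RDF file > config > test_information > inputs > size
    dims = [int(v) for v in size.split(",")[::-1]]   # reverse -> [1, 100, 128, 128]
    shape = [1] * (len(axes) - len(dims)) + dims     # pad the leading batch axis
    print(dict(zip(axes, shape)))  # {'b': 1, 'c': 1, 'z': 100, 'y': 128, 'x': 128}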
140 changes: 99 additions & 41 deletions tools/bioimaging/main.py
@@ -7,70 +7,128 @@
import imageio
import numpy as np
import torch
import torch.nn.functional as F


def find_dim_order(user_in_shape, input_image):
def dynamic_resize(image: torch.Tensor, target_shape: tuple):
"""
Find the correct order of input image's
shape. For a few models, the order of input size
mentioned in the RDF.yaml file is reversed compared
to the input image's original size. If it is reversed,
transpose the image to find correct order of image's
dimensions.
Resize an input tensor dynamically to the target shape.

Parameters:
- image: Input tensor with shape (C, D1, D2, ..., DN) (any number of spatial dims)
- target_shape: Tuple specifying the target shape (C', D1', D2', ..., DN')

Returns:
- Resized tensor with target shape target_shape.
"""
image_shape = list(input_image.shape)
# reverse the input shape provided from RDF.yaml file
correct_order = user_in_shape.split(",")[::-1]
# remove 1s from the original dimensions
correct_order = [int(i) for i in correct_order if i != "1"]
if (correct_order[0] == image_shape[-1]) and (correct_order != image_shape):
input_image = torch.tensor(input_image.transpose())
return input_image, correct_order
# Extract input shape
input_shape = image.shape
num_dims = len(input_shape) # Includes channels and spatial dimensions

# Ensure target shape matches the number of dimensions
if len(target_shape) != num_dims:
raise ValueError(
f"Target shape {target_shape} must match input dimensions {num_dims}"
)

# Extract target channels and spatial sizes
target_channels = target_shape[0] # First element is the target channel count
target_spatial_size = target_shape[1:] # Remaining elements are spatial dimensions

# Add batch dim (N=1) for resizing
image = image.unsqueeze(0)

# Choose the best interpolation mode based on dimensionality
if num_dims == 4:
interp_mode = "trilinear"
elif num_dims == 3:
interp_mode = "bilinear"
elif num_dims == 2:
interp_mode = "bicubic"
else:
interp_mode = "nearest"

# Resize spatial dimensions dynamically
image = F.interpolate(
image, size=target_spatial_size, mode=interp_mode, align_corners=False
)

# Adjust channels if necessary
current_channels = image.shape[1]

if target_channels > current_channels:
# Expand channels by repeating existing ones
expand_factor = target_channels // current_channels
remainder = target_channels % current_channels
image = image.repeat(1, expand_factor, *[1] * (num_dims - 1))

if remainder > 0:
extra_channels = image[
:, :remainder, ...
] # Take the first few channels to match target
image = torch.cat([image, extra_channels], dim=1)

elif target_channels < current_channels:
# Reduce channels by slicing off the extra ones
image = image[:, :target_channels, ...] # Simply slice to reduce channels
return image.squeeze(0) # Remove batch dimension before returning


if __name__ == "__main__":
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("-im", "--imaging_model", required=True, help="Input BioImage model")
arg_parser.add_argument("-ii", "--image_file", required=True, help="Input image file")
arg_parser.add_argument("-is", "--image_size", required=True, help="Input image file's size")
arg_parser.add_argument(
"-im", "--imaging_model", required=True, help="Input BioImage model"
)
arg_parser.add_argument(
"-ii", "--image_file", required=True, help="Input image file"
)
arg_parser.add_argument(
"-is", "--image_size", required=True, help="Input image file's size"
)
arg_parser.add_argument(
"-ia", "--image_axes", required=True, help="Input image file's axes"
)

# get argument values
args = vars(arg_parser.parse_args())
model_path = args["imaging_model"]
input_image_path = args["image_file"]
input_size = args["image_size"]

# load all embedded images in TIF file
test_data = imageio.v3.imread(input_image_path, index="...")
test_data = np.squeeze(test_data)
test_data = test_data.astype(np.float32)
test_data = np.squeeze(test_data)

# assess the correct dimensions of TIF input image
input_image_shape = args["image_size"]
im_test_data, shape_vals = find_dim_order(input_image_shape, test_data)
target_image_dim = input_size.split(",")[::-1]
target_image_dim = [int(i) for i in target_image_dim if i != "1"]
target_image_dim = tuple(target_image_dim)

exp_test_data = torch.tensor(test_data)
# check if image dimensions are reversed
reversed_order = list(reversed(range(exp_test_data.dim())))
exp_test_data_T = exp_test_data.permute(*reversed_order)
if exp_test_data_T.shape == target_image_dim:
exp_test_data = exp_test_data_T
if exp_test_data.shape != target_image_dim:
for i in range(len(target_image_dim) - exp_test_data.dim()):
exp_test_data = exp_test_data.unsqueeze(i)
try:
exp_test_data = dynamic_resize(exp_test_data, target_image_dim)
except Exception as e:
raise RuntimeError(f"Error during resizing: {e}") from e

current_dimension = len(exp_test_data.shape)
input_axes = args["image_axes"]
target_dimension = len(input_axes)
# expand input image based on the number of target dimensions
for i in range(target_dimension - current_dimension):
exp_test_data = torch.unsqueeze(exp_test_data, i)

# load model
model = torch.load(model_path)
model.eval()

# find the number of dimensions required by the model
target_dimension = 0
for param in model.named_parameters():
target_dimension = len(param[1].shape)
break
current_dimension = len(list(im_test_data.shape))

# update the dimensions of input image if the required image by
# the model is smaller
slices = tuple(slice(0, s_val) for s_val in shape_vals)

# apply the slices to the reshaped_input
im_test_data = im_test_data[slices]
exp_test_data = torch.tensor(im_test_data)

# expand input image's dimensions
for i in range(target_dimension - current_dimension):
exp_test_data = torch.unsqueeze(exp_test_data, i)

# make prediction
pred_data = model(exp_test_data)
pred_data_output = pred_data.detach().numpy()
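
For reference, a minimal usage sketch of the new dynamic_resize helper (editorial illustration with a dummy tensor; it assumes the updated tools/bioimaging/main.py is importable and is not part of the PR):

    import torch

    from main import dynamic_resize  # assumes tools/bioimaging/main.py is on the path

    # Mirror main.py's handling of --image_size "128,128,100,1": reverse the
    # values and drop the 1s, giving the target (z, y, x) extent (100, 128, 128).
    target_image_dim = tuple(int(v) for v in "128,128,100,1".split(",")[::-1] if v != "1")

    image = torch.rand(80, 96, 96)                      # dummy volume with the wrong extent
    resized = dynamic_resize(image, target_image_dim)   # -> torch.Size([100, 128, 128])

    # main.py then unsqueezes leading axes until the tensor has one dimension per
    # letter in --image_axes ("bczyx" -> 5), yielding shape (1, 1, 100, 128, 128):
    for i in range(len("bczyx") - resized.dim()):
        resized = resized.unsqueeze(i)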
Binary file modified tools/bioimaging/test-data/output_nucleisegboundarymodel.tif