|
7 | 7 | import imageio |
8 | 8 | import numpy as np |
9 | 9 | import torch |
| 10 | +import torch.nn.functional as F |
10 | 11 |
|
11 | 12 |
|
12 | | -def find_dim_order(user_in_shape, input_image): |
| 13 | +def dynamic_resize(image: torch.Tensor, target_shape: tuple): |
13 | 14 | """ |
14 | | - Find the correct order of input image's |
15 | | - shape. For a few models, the order of input size |
16 | | - mentioned in the RDF.yaml file is reversed compared |
17 | | - to the input image's original size. If it is reversed, |
18 | | - transpose the image to find correct order of image's |
19 | | - dimensions. |
| 15 | + Resize an input tensor dynamically to the target shape. |
| 16 | +
|
| 17 | + Parameters: |
| 18 | + - image: Input tensor with shape (C, D1, D2, ..., DN) (any number of spatial dims) |
| 19 | + - target_shape: Tuple specifying the target shape (C', D1', D2', ..., DN') |
| 20 | +
|
| 21 | + Returns: |
| 22 | + - Resized tensor with target shape target_shape. |
20 | 23 | """ |
21 | | - image_shape = list(input_image.shape) |
22 | | - # reverse the input shape provided from RDF.yaml file |
23 | | - correct_order = user_in_shape.split(",")[::-1] |
24 | | - # remove 1s from the original dimensions |
25 | | - correct_order = [int(i) for i in correct_order if i != "1"] |
26 | | - if (correct_order[0] == image_shape[-1]) and (correct_order != image_shape): |
27 | | - input_image = torch.tensor(input_image.transpose()) |
28 | | - return input_image, correct_order |
| 24 | + # Extract input shape |
| 25 | + input_shape = image.shape |
| 26 | + num_dims = len(input_shape) # Includes channels and spatial dimensions |
| 27 | + |
| 28 | + # Ensure target shape matches the number of dimensions |
| 29 | + if len(target_shape) != num_dims: |
| 30 | + raise ValueError( |
| 31 | + f"Target shape {target_shape} must match input dimensions {num_dims}" |
| 32 | + ) |
| 33 | + |
| 34 | + # Extract target channels and spatial sizes |
| 35 | + target_channels = target_shape[0] # First element is the target channel count |
| 36 | + target_spatial_size = target_shape[1:] # Remaining elements are spatial dimensions |
| 37 | + |
| 38 | + # Add batch dim (N=1) for resizing |
| 39 | + image = image.unsqueeze(0) |
| 40 | + |
| 41 | + # Choose the best interpolation mode based on dimensionality |
| 42 | + if num_dims == 4: |
| 43 | + interp_mode = "trilinear" |
| 44 | + elif num_dims == 3: |
| 45 | + interp_mode = "bilinear" |
| 46 | + elif num_dims == 2: |
| 47 | + interp_mode = "bicubic" |
| 48 | + else: |
| 49 | + interp_mode = "nearest" |
| 50 | + |
| 51 | + # Resize spatial dimensions dynamically |
| 52 | + image = F.interpolate( |
| 53 | + image, size=target_spatial_size, mode=interp_mode, align_corners=False |
| 54 | + ) |
| 55 | + |
| 56 | + # Adjust channels if necessary |
| 57 | + current_channels = image.shape[1] |
| 58 | + |
| 59 | + if target_channels > current_channels: |
| 60 | + # Expand channels by repeating existing ones |
| 61 | + expand_factor = target_channels // current_channels |
| 62 | + remainder = target_channels % current_channels |
| 63 | + image = image.repeat(1, expand_factor, *[1] * (num_dims - 1)) |
| 64 | + |
| 65 | + if remainder > 0: |
| 66 | + extra_channels = image[ |
| 67 | + :, :remainder, ... |
| 68 | + ] # Take the first few channels to match target |
| 69 | + image = torch.cat([image, extra_channels], dim=1) |
| 70 | + |
| 71 | + elif target_channels < current_channels: |
| 72 | + # Reduce channels by averaging adjacent ones |
| 73 | + image = image[:, :target_channels, ...] # Simply slice to reduce channels |
| 74 | + return image.squeeze(0) # Remove batch dimension before returning |
29 | 75 |
|
30 | 76 |
|
31 | 77 | if __name__ == "__main__": |
32 | 78 | arg_parser = argparse.ArgumentParser() |
33 | | - arg_parser.add_argument("-im", "--imaging_model", required=True, help="Input BioImage model") |
34 | | - arg_parser.add_argument("-ii", "--image_file", required=True, help="Input image file") |
35 | | - arg_parser.add_argument("-is", "--image_size", required=True, help="Input image file's size") |
| 79 | + arg_parser.add_argument( |
| 80 | + "-im", "--imaging_model", required=True, help="Input BioImage model" |
| 81 | + ) |
| 82 | + arg_parser.add_argument( |
| 83 | + "-ii", "--image_file", required=True, help="Input image file" |
| 84 | + ) |
| 85 | + arg_parser.add_argument( |
| 86 | + "-is", "--image_size", required=True, help="Input image file's size" |
| 87 | + ) |
| 88 | + arg_parser.add_argument( |
| 89 | + "-ia", "--image_axes", required=True, help="Input image file's axes" |
| 90 | + ) |
36 | 91 |
|
37 | 92 | # get argument values |
38 | 93 | args = vars(arg_parser.parse_args()) |
39 | 94 | model_path = args["imaging_model"] |
40 | 95 | input_image_path = args["image_file"] |
| 96 | + input_size = args["image_size"] |
41 | 97 |
|
42 | 98 | # load all embedded images in TIF file |
43 | 99 | test_data = imageio.v3.imread(input_image_path, index="...") |
44 | | - test_data = np.squeeze(test_data) |
45 | 100 | test_data = test_data.astype(np.float32) |
| 101 | + test_data = np.squeeze(test_data) |
46 | 102 |
|
47 | | - # assess the correct dimensions of TIF input image |
48 | | - input_image_shape = args["image_size"] |
49 | | - im_test_data, shape_vals = find_dim_order(input_image_shape, test_data) |
| 103 | + target_image_dim = input_size.split(",")[::-1] |
| 104 | + target_image_dim = [int(i) for i in target_image_dim if i != "1"] |
| 105 | + target_image_dim = tuple(target_image_dim) |
| 106 | + |
| 107 | + exp_test_data = torch.tensor(test_data) |
| 108 | + # check if image dimensions are reversed |
| 109 | + reversed_order = list(reversed(range(exp_test_data.dim()))) |
| 110 | + exp_test_data_T = exp_test_data.permute(*reversed_order) |
| 111 | + if exp_test_data_T.shape == target_image_dim: |
| 112 | + exp_test_data = exp_test_data_T |
| 113 | + if exp_test_data.shape != target_image_dim: |
| 114 | + for i in range(len(target_image_dim) - exp_test_data.dim()): |
| 115 | + exp_test_data = exp_test_data.unsqueeze(i) |
| 116 | + try: |
| 117 | + exp_test_data = dynamic_resize(exp_test_data, target_image_dim) |
| 118 | + except Exception as e: |
| 119 | + raise RuntimeError(f"Error during resizing: {e}") from e |
| 120 | + |
| 121 | + current_dimension = len(exp_test_data.shape) |
| 122 | + input_axes = args["image_axes"] |
| 123 | + target_dimension = len(input_axes) |
| 124 | + # expand input image based on the number of target dimensions |
| 125 | + for i in range(target_dimension - current_dimension): |
| 126 | + exp_test_data = torch.unsqueeze(exp_test_data, i) |
50 | 127 |
|
51 | 128 | # load model |
52 | 129 | model = torch.load(model_path) |
53 | 130 | model.eval() |
54 | 131 |
|
55 | | - # find the number of dimensions required by the model |
56 | | - target_dimension = 0 |
57 | | - for param in model.named_parameters(): |
58 | | - target_dimension = len(param[1].shape) |
59 | | - break |
60 | | - current_dimension = len(list(im_test_data.shape)) |
61 | | - |
62 | | - # update the dimensions of input image if the required image by |
63 | | - # the model is smaller |
64 | | - slices = tuple(slice(0, s_val) for s_val in shape_vals) |
65 | | - |
66 | | - # apply the slices to the reshaped_input |
67 | | - im_test_data = im_test_data[slices] |
68 | | - exp_test_data = torch.tensor(im_test_data) |
69 | | - |
70 | | - # expand input image's dimensions |
71 | | - for i in range(target_dimension - current_dimension): |
72 | | - exp_test_data = torch.unsqueeze(exp_test_data, i) |
73 | | - |
74 | 132 | # make prediction |
75 | 133 | pred_data = model(exp_test_data) |
76 | 134 | pred_data_output = pred_data.detach().numpy() |
|
0 commit comments