forked from bgruening/galaxytools
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
144 lines (117 loc) · 4.92 KB
/
main.py
File metadata and controls
144 lines (117 loc) · 4.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Predict images using AI models from BioImage.IO
"""
import argparse
import imageio
import numpy as np
import torch
import torch.nn.functional as F
def dynamic_resize(image: torch.Tensor, target_shape: tuple):
    """
    Resize an input tensor dynamically to the target shape.

    The spatial dimensions are resampled with ``F.interpolate`` (linear,
    bilinear or trilinear, depending on how many spatial dims there are).
    The channel dimension is then adjusted: it grows by tiling the existing
    channels and shrinks by keeping the leading channels.

    Parameters:
    - image: Input tensor with shape (C, D1, ..., DN), with N = 1, 2 or 3
      spatial dims. Linear interpolation modes need a floating-point dtype.
    - target_shape: Tuple specifying the target shape (C', D1', ..., DN').
      It must have the same number of entries as ``image.shape``.

    Returns:
    - Resized tensor with shape target_shape.

    Raises:
    - ValueError: if target_shape has a different number of dimensions than
      image, or if the number of spatial dimensions is not 1, 2 or 3.
    """
    input_shape = image.shape
    num_dims = len(input_shape)  # channels + spatial dimensions
    if len(target_shape) != num_dims:
        raise ValueError(
            f"Target shape {target_shape} must match input dimensions {num_dims}"
        )

    # F.interpolate only supports 1-3 spatial dims for these modes.
    # "bicubic" requires exactly 2 spatial dims, so (C, L) inputs must use
    # "linear". "nearest" cannot be combined with align_corners, so it is
    # not used as a fallback.
    interp_modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
    num_spatial = num_dims - 1
    if num_spatial not in interp_modes:
        raise ValueError(
            f"Only 1, 2 or 3 spatial dimensions are supported, got {num_spatial} "
            f"(input shape {tuple(input_shape)})"
        )
    interp_mode = interp_modes[num_spatial]

    target_channels = target_shape[0]  # first entry is the channel count
    target_spatial_size = target_shape[1:]  # remaining entries are spatial

    # F.interpolate expects a leading batch dimension: (N=1, C, D1, ..., DN).
    image = image.unsqueeze(0)
    image = F.interpolate(
        image, size=target_spatial_size, mode=interp_mode, align_corners=False
    )

    current_channels = image.shape[1]
    if target_channels > current_channels:
        # Grow by tiling whole copies of the channel stack, then append the
        # first `remainder` channels. Output channel k is input channel
        # k % current_channels.
        expand_factor, remainder = divmod(target_channels, current_channels)
        image = image.repeat(1, expand_factor, *[1] * (num_dims - 1))
        if remainder > 0:
            image = torch.cat([image, image[:, :remainder, ...]], dim=1)
    elif target_channels < current_channels:
        # Shrink by keeping the leading channels (no averaging is performed).
        image = image[:, :target_channels, ...]

    return image.squeeze(0)  # drop the batch dimension added above
def _parse_target_shape(size_spec: str) -> tuple:
    """Parse a comma-separated size string into a reversed tuple of non-1 ints.

    Entries are returned in reverse order of the string. Every entry equal
    to 1 is dropped, because singleton axes are squeezed out of the loaded
    image. Surrounding whitespace is ignored, and so are empty entries
    (e.g. from a trailing comma).
    """
    dims = []
    for token in reversed(size_spec.split(",")):
        token = token.strip()
        if not token:
            continue
        value = int(token)
        if value != 1:
            dims.append(value)
    return tuple(dims)


def main():
    """Run a BioImage.IO PyTorch model on one input image.

    Command-line arguments:
    - ``-im/--imaging_model``: path to the model file, loaded with ``torch.load``.
    - ``-ii/--image_file``: path to the input image (all embedded images are read).
    - ``-is/--image_size``: comma-separated input size expected by the model.
    - ``-ia/--image_axes``: the model's input axes. Only its length is used,
      to decide how many leading singleton dims to add.

    Writes ``output_predicted_image_matrix.npy`` (raw prediction) and
    ``output_predicted_image.tiff`` (prediction with singleton dims squeezed)
    to the current working directory.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-im", "--imaging_model", required=True, help="Input BioImage model"
    )
    arg_parser.add_argument(
        "-ii", "--image_file", required=True, help="Input image file"
    )
    arg_parser.add_argument(
        "-is", "--image_size", required=True, help="Input image file's size"
    )
    arg_parser.add_argument(
        "-ia", "--image_axes", required=True, help="Input image file's axes"
    )
    args = vars(arg_parser.parse_args())
    model_path = args["imaging_model"]
    input_image_path = args["image_file"]
    input_size = args["image_size"]
    input_axes = args["image_axes"]

    # Load all embedded images (e.g. every page of a multi-page TIF).
    # imageio v3's "read all and stack" selector is the Ellipsis object
    # `...`, not the string "...".
    test_data = imageio.v3.imread(input_image_path, index=...)
    test_data = np.squeeze(test_data.astype(np.float32))

    target_image_dim = _parse_target_shape(input_size)
    exp_test_data = torch.tensor(test_data)

    # If the image's axes are stored in reverse order, use the fully
    # transposed tensor when it matches the target shape exactly.
    reversed_order = list(reversed(range(exp_test_data.dim())))
    exp_test_data_T = exp_test_data.permute(*reversed_order)
    if exp_test_data_T.shape == target_image_dim:
        exp_test_data = exp_test_data_T

    if exp_test_data.shape != target_image_dim:
        # Add leading singleton dims until the ranks match, then resample.
        for i in range(len(target_image_dim) - exp_test_data.dim()):
            exp_test_data = exp_test_data.unsqueeze(i)
        try:
            exp_test_data = dynamic_resize(exp_test_data, target_image_dim)
        except Exception as e:
            raise RuntimeError(f"Error during resizing: {e}") from e

    # Expand with leading singleton dims up to the model's number of axes.
    current_dimension = len(exp_test_data.shape)
    target_dimension = len(input_axes)
    for i in range(target_dimension - current_dimension):
        exp_test_data = torch.unsqueeze(exp_test_data, i)

    # NOTE(review): torch.load unpickles full-module checkpoints, which can
    # execute arbitrary code from the file; only load trusted models.
    # PyTorch 2.6 changed the default to weights_only=True, which rejects
    # pickled modules and TorchScript archives. Confirm the torch version
    # pinned for this tool.
    # map_location="cpu" keeps CUDA-saved models loadable on CPU-only hosts
    # and on the same device as the CPU input tensor.
    model = torch.load(model_path, map_location="cpu")
    model.eval()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred_data = model(exp_test_data)

    # Save the raw prediction matrix.
    np.save("output_predicted_image_matrix.npy", pred_data.detach().numpy())

    # Squeeze singleton dims so the prediction can be written as a TIF image.
    pred_numpy = torch.squeeze(pred_data).detach().numpy()
    imageio.v3.imwrite("output_predicted_image.tiff", pred_numpy, extension=".tiff")


if __name__ == "__main__":
    main()