Skip to content

Commit cb66c34

Browse files
author
valhassan
committed
Refactor ScriptModel with enhanced preprocessing and SegmentationScriptModel support
1 parent e9fe56b commit cb66c34

File tree

1 file changed

+50
-11
lines changed

1 file changed

+50
-11
lines changed
+50-11
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,67 @@
11
import torch
22
import torch.nn as nn
33
from torch.nn import functional as F
4+
from tools.utils import normalization, standardization
45

56
class ScriptModel(nn.Module):
6-
def __init__(self, model, data_module, num_classes, from_logits = True, aux_head = False):
7+
def __init__(self,
8+
model,
9+
device = torch.device("cpu"),
10+
num_classes = 1,
11+
input_shape = (1, 3, 512, 512),
12+
mean = [0.0, 0.0, 0.0],
13+
std = [1.0, 1.0, 1.0],
14+
image_min = 0,
15+
image_max = 255,
16+
norm_min = 0.0,
17+
norm_max = 1.0,
18+
from_logits = True):
719
super().__init__()
8-
self.model = model
9-
self.data_module = data_module
20+
self.device = device
1021
self.num_classes = num_classes
1122
self.from_logits = from_logits
12-
self.aux_head = aux_head
23+
self.mean = torch.tensor(mean).reshape(len(mean), 1)
24+
self.std = torch.tensor(std).reshape(len(std), 1)
25+
self.image_min = int(image_min)
26+
self.image_max = int(image_max)
27+
self.norm_min = float(norm_min)
28+
self.norm_max = float(norm_max)
29+
30+
dummy_input = torch.rand(input_shape).to(self.device)
31+
self.traced_model = torch.jit.trace(model.eval(), dummy_input)
32+
1333
def forward(self, x):
14-
sample = {"image": x}
15-
preprocessed = self.data_module._model_script_preprocess(sample)
16-
output = self.model(preprocessed["image"])
17-
if isinstance(output, dict):
18-
output = output['out']
34+
x = normalization(x, self.image_min, self.image_max, self.norm_min, self.norm_max)
35+
x = standardization(x, self.mean, self.std)
36+
output = self.traced_model(x)
1937
if self.from_logits:
2038
if self.num_classes > 1:
2139
output = F.softmax(output, dim=1)
2240
else:
2341
output = F.sigmoid(output)
2442
return output
2543

26-
def script_model(model, data_module, num_classes, from_logits = True):
27-
return ScriptModel(model, data_module, num_classes, from_logits)
44+
class SegmentationScriptModel(ScriptModel):
45+
"""
46+
This class is used to script a model that returns a NamedTuple.
47+
48+
for example:
49+
class SegmentationOutput(NamedTuple):
50+
out: torch.Tensor
51+
aux: Optional[torch.Tensor]
52+
"""
53+
def __init__(self, model, **kwargs):
54+
super().__init__(model, **kwargs)
55+
56+
def forward(self, x):
57+
x = normalization(x, self.image_min, self.image_max, self.norm_min, self.norm_max)
58+
x = standardization(x, self.mean, self.std)
59+
output = self.traced_model(x)
60+
output = output[0] # Extract main output from NamedTuple
61+
if self.from_logits:
62+
if self.num_classes > 1:
63+
output = F.softmax(output, dim=1)
64+
else:
65+
output = F.sigmoid(output)
66+
return output
2867

0 commit comments

Comments
 (0)