1
1
import torch
2
2
import torch .nn as nn
3
3
from torch .nn import functional as F
4
+ from tools .utils import normalization , standardization
4
5
5
6
class ScriptModel (nn .Module ):
6
- def __init__ (self , model , data_module , num_classes , from_logits = True , aux_head = False ):
7
+ def __init__ (self ,
8
+ model ,
9
+ device = torch .device ("cpu" ),
10
+ num_classes = 1 ,
11
+ input_shape = (1 , 3 , 512 , 512 ),
12
+ mean = [0.0 , 0.0 , 0.0 ],
13
+ std = [1.0 , 1.0 , 1.0 ],
14
+ image_min = 0 ,
15
+ image_max = 255 ,
16
+ norm_min = 0.0 ,
17
+ norm_max = 1.0 ,
18
+ from_logits = True ):
7
19
super ().__init__ ()
8
- self .model = model
9
- self .data_module = data_module
20
+ self .device = device
10
21
self .num_classes = num_classes
11
22
self .from_logits = from_logits
12
- self .aux_head = aux_head
23
+ self .mean = torch .tensor (mean ).reshape (len (mean ), 1 )
24
+ self .std = torch .tensor (std ).reshape (len (std ), 1 )
25
+ self .image_min = int (image_min )
26
+ self .image_max = int (image_max )
27
+ self .norm_min = float (norm_min )
28
+ self .norm_max = float (norm_max )
29
+
30
+ dummy_input = torch .rand (input_shape ).to (self .device )
31
+ self .traced_model = torch .jit .trace (model .eval (), dummy_input )
32
+
13
33
def forward (self , x ):
14
- sample = {"image" : x }
15
- preprocessed = self .data_module ._model_script_preprocess (sample )
16
- output = self .model (preprocessed ["image" ])
17
- if isinstance (output , dict ):
18
- output = output ['out' ]
34
+ x = normalization (x , self .image_min , self .image_max , self .norm_min , self .norm_max )
35
+ x = standardization (x , self .mean , self .std )
36
+ output = self .traced_model (x )
19
37
if self .from_logits :
20
38
if self .num_classes > 1 :
21
39
output = F .softmax (output , dim = 1 )
22
40
else :
23
41
output = F .sigmoid (output )
24
42
return output
25
43
26
- def script_model (model , data_module , num_classes , from_logits = True ):
27
- return ScriptModel (model , data_module , num_classes , from_logits )
44
+ class SegmentationScriptModel (ScriptModel ):
45
+ """
46
+ This class is used to script a model that returns a NamedTuple.
47
+
48
+ for example:
49
+ class SegmentationOutput(NamedTuple):
50
+ out: torch.Tensor
51
+ aux: Optional[torch.Tensor]
52
+ """
53
+ def __init__ (self , model , ** kwargs ):
54
+ super ().__init__ (model , ** kwargs )
55
+
56
+ def forward (self , x ):
57
+ x = normalization (x , self .image_min , self .image_max , self .norm_min , self .norm_max )
58
+ x = standardization (x , self .mean , self .std )
59
+ output = self .traced_model (x )
60
+ output = output [0 ] # Extract main output from NamedTuple
61
+ if self .from_logits :
62
+ if self .num_classes > 1 :
63
+ output = F .softmax (output , dim = 1 )
64
+ else :
65
+ output = F .sigmoid (output )
66
+ return output
28
67
0 commit comments