@@ -70,6 +70,125 @@ def forward(self, x):
7070
7171
7272
class DoubleConvB(nn.Module):
    """Two stacked 3x3 conv -> BatchNorm -> ReLU stages; spatial size is preserved."""

    def __init__(self, in_channels, out_channels):
        super(DoubleConvB, self).__init__()
        # bias=False on both convs: the following BatchNorm layers absorb any bias.
        stages = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to x (N, C_in, H, W) -> (N, C_out, H, W)."""
        return self.conv(x)
86+
87+
class Decoder(nn.Module):
    """Upsampling half of the U-Net.

    Holds alternating (ConvTranspose2d, DoubleConvB) pairs, deepest level
    first.  Skip connections are read from the attached encoder, which is
    expected to populate ``skip_connections`` during its own forward pass
    (deepest feature first).
    """

    def __init__(self, encoder, out_channels, features):
        super(Decoder, self).__init__()
        self.ups = nn.ModuleList()
        self.encoder = encoder
        self.out_channels = out_channels

        # One (upsample, double-conv) pair per resolution level, deepest first.
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConvB(feature * 2, feature))

    def forward(self, x):
        # Even slots are the transposed convs; odd slots refine after concat.
        for level, upsample in enumerate(self.ups[::2]):
            x = upsample(x)
            skip = self.encoder.skip_connections[level]

            # Max-pooling floors odd spatial sizes on the way down, so the
            # upsampled map can be one pixel short; resize to match the skip.
            if x.shape != skip.shape:
                x = TF.resize(x, size=skip.shape[2:])

            x = self.ups[2 * level + 1](torch.cat((skip, x), dim=1))

        return x
112+
113+
class Encoder(nn.Module):
    """Downsampling half of the U-Net: DoubleConvB stages separated by 2x2 max-pools.

    After each forward pass, ``skip_connections`` holds the pre-pool feature
    map of every stage, ordered deepest first for the decoder.
    """

    def __init__(self, in_channels, features):
        super(Encoder, self).__init__()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.in_channels = in_channels

        # Each stage widens channels to the next entry in `features`.
        channels = in_channels
        for feature in features:
            self.downs.append(DoubleConvB(channels, feature))
            channels = feature

    def forward(self, x):
        self.skip_connections = []
        for stage in self.downs:
            x = stage(x)
            self.skip_connections.append(x)
            x = self.pool(x)

        # Deepest-first so the decoder can index skip maps by level.
        self.skip_connections.reverse()

        return x
136+
137+
138+
class UNETNew(nn.Module):
    """U-Net segmentation model split into Encoder / bottleneck / Decoder parts.

    Config keys: 'in_channels' (default 3), 'out_channels' (default 4) and
    'features' (required: per-level channel counts, shallowest first).

    Args:
        config: configuration dict as above; None uses the built-in defaults.
        state_dict: optional component dict produced by get_statedict().
        pretrain: if True, freeze the encoder's parameters.
        device: accepted for API compatibility; not used in this constructor.
    """

    def __init__(self, config=None, state_dict=None, pretrain=False, device="cuda"):
        super(UNETNew, self).__init__()
        if config is None:
            # Avoid a mutable default argument; equivalent to the old inline dict.
            config = {'in_channels': 3, 'out_channels': 4,
                      'features': [32, 64, 128, 256, 512]}

        # dict.get replaces the original bare try/except fallbacks.
        in_channels = config.get('in_channels', 3)
        out_channels = config.get('out_channels', 4)
        features = config['features']  # required; no sensible default here

        # Bottleneck doubles the deepest feature width.
        self.bottleneck_size = features[-1] * 2

        self.encoder = Encoder(in_channels, features)
        self.decoder = Decoder(self.encoder, out_channels, features)
        # NOTE(review): bottleneck uses DoubleConv (defined earlier in this
        # file) while encoder/decoder use DoubleConvB — confirm this is
        # intentional and not a typo.
        self.bottleneck = DoubleConv(features[-1], self.bottleneck_size)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

        if state_dict:
            self.load_from_dict(state_dict)

        if pretrain:
            # Bug fix: the original assigned False to the method attribute
            # (`self.encoder.requires_grad_ = False`), which freezes nothing.
            # Calling it actually disables gradients on encoder parameters.
            self.encoder.requires_grad_(False)

    def get_statedict(self):
        """Return per-component state dicts under stable string keys."""
        return {"Encoder": self.encoder.state_dict(),
                "Bottleneck": self.bottleneck.state_dict(),
                "Decoder": self.decoder.state_dict(),
                "LastConv": self.final_conv.state_dict()}

    def load_from_dict(self, dict):  # param name kept for caller compatibility
        """Load component weights saved by get_statedict().

        The final conv is optional so encoder/bottleneck/decoder weights can
        be reused when the number of output classes changes.
        """
        self.encoder.load_state_dict(dict["Encoder"])
        self.bottleneck.load_state_dict(dict["Bottleneck"])
        self.decoder.load_state_dict(dict["Decoder"])

        try:
            self.final_conv.load_state_dict(dict["LastConv"])
        except (KeyError, RuntimeError):
            # Missing or shape-mismatched head: keep the randomly initialized
            # final conv instead of failing the whole load.
            print("Final conv not initialized.")

    def forward(self, x):
        """Run encoder -> bottleneck -> decoder -> 1x1 classification conv."""
        x = self.encoder(x)
        x = self.bottleneck(x)
        x = self.decoder(x)

        return self.final_conv(x)
189+
190+
191+
73192class NeuralSegmentator (BaseSegmentator ):
74193 def __init__ (self , images , path = "assets/model.pth.tar" ):
75194 super ().__init__ (images )
0 commit comments