@@ -150,7 +150,7 @@ def analyze_gt(self, trainset, valset, iters):
150150 plt .clf () # Clear figure
151151 plt .close ()
152152
153- @torch .enable_grad () #enable grad for CAM
153+ @torch .no_grad () #enable grad for CAM
154154 def on_val_epoch_end (self , logs : Dict = None ):
155155 """
156156 After finish validation
@@ -160,41 +160,24 @@ def on_val_epoch_end(self, logs: Dict=None):
160160 last_batch = logs ['last_batch' ]
161161 model = self .params ['trainer' ].model
162162 valloader = self .params ['trainer' ].valloader
163- optimizer = self .params ['trainer' ].optimizer
164163
165- # Zeroing gradients in model and optimizer for supress warning
166- optimizer .zero_grad ()
167- model .zero_grad ()
168-
169- # Vizualize Grad Class Activation Mapping and model predictions
164+ # Vizualize model predictions
170165 LOGGER .text ("Visualizing model predictions..." , level = LoggerObserver .DEBUG )
171166
172167 images = last_batch ["inputs" ]
173168 targets = last_batch ["targets" ]
174169 model .eval ()
175170
176- ## Calculate GradCAM
177- model_name = model .model .name
178-
179- try :
180- grad_cam = CAMWrapper .get_method (
181- name = 'gradcam' ,
182- model = model .model .get_model (),
183- model_name = model_name , use_cuda = next (model .parameters ()).is_cuda )
184-
185- grayscale_cams , label_indices , scores = grad_cam (images , return_probs = True )
186-
187- except :
188- LOGGER .text ("Cannot calculate GradCAM" , level = LoggerObserver .ERROR )
189- return
171+ ## Get prediction on last batch
172+ outputs = model .model .get_prediction (last_batch , device = model .device )
173+ label_indices = outputs ['labels' ]
174+ scores = outputs ['confidences' ]
190175
191- gradcam_batch = []
192176 pred_batch = []
193- for idx in range (len (grayscale_cams )):
177+ for idx in range (len (images )):
194178 image = images [idx ]
195179 target = targets [idx ].item ()
196180 label = label_indices [idx ]
197- grayscale_cam = grayscale_cams [idx , :]
198181 score = scores [idx ]
199182
200183 img_show = self .visualizer .denormalize (image )
@@ -217,32 +200,12 @@ def on_val_epoch_end(self, logs: Dict=None):
217200 offset = 100
218201 )
219202
220- img_cam = show_cam_on_image (img_show , grayscale_cam , use_rgb = True )
221-
222- img_cam = TFF .to_tensor (img_cam )
223- gradcam_batch .append (img_cam )
224-
225203 pred_img = self .visualizer .get_image ()
226204 pred_img = TFF .to_tensor (pred_img )
227205 pred_batch .append (pred_img )
228206
229207 if idx == 63 : # limit number of images
230208 break
231-
232- # GradCAM images
233- gradcam_grid_img = self .visualizer .make_grid (gradcam_batch )
234- fig = plt .figure (figsize = (8 ,8 ))
235- plt .imshow (gradcam_grid_img )
236- plt .axis ("off" )
237- plt .tight_layout (pad = 0 )
238- LOGGER .log ([{
239- 'tag' : "Validation/gradcam" ,
240- 'value' : fig ,
241- 'type' : LoggerObserver .FIGURE ,
242- 'kwargs' : {
243- 'step' : iters
244- }
245- }])
246209
247210 # Prediction images
248211 pred_grid_img = self .visualizer .make_grid (pred_batch )
@@ -261,8 +224,4 @@ def on_val_epoch_end(self, logs: Dict=None):
261224
262225 plt .cla () # Clear axis
263226 plt .clf () # Clear figure
264- plt .close ()
265-
266- # Zeroing gradients in model and optimizer for safety
267- optimizer .zero_grad ()
268- model .zero_grad ()
227+ plt .close ()
0 commit comments