Skip to content

Commit bea58f9

Browse files
authored
Merge pull request #41 from kaylode/v1.0.1
Update to V1.0.1
2 parents 697c8f0 + 1024167 commit bea58f9

File tree

17 files changed

+216
-70
lines changed

17 files changed

+216
-70
lines changed

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ tabulate==0.8.10
1212
loguru==0.6.0
1313
seaborn==0.12.0
1414
wandb==0.13.3
15+
plotly==5.10.0
16+
matplotlib==3.4.3
1517

1618
## Classification
1719
timm

theseus/base/callbacks/checkpoint_callbacks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def save_checkpoint(self, trainer, iters, outname='last'):
5454
weights = {
5555
'model': trainer.model.model.state_dict(),
5656
'optimizer': trainer.optimizer.state_dict(),
57+
'scheduler': trainer.scheduler.state_dict(),
5758
'iters': iters,
5859
'best_value': self.best_value,
5960
}

theseus/base/callbacks/wandb_callbacks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def on_start(self, logs: Dict=None):
108108
# Save all config files
109109
self.wandb_logger.log_file(
110110
tag='configs',
111+
base_folder=self.save_dir,
111112
value = osp.join(self.save_dir, '*.yaml'))
112113

113114
# Init logging model for debug

theseus/base/datasets/dataset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ def __init__(self, datasets: Iterable[data.Dataset], **kwargs) -> None:
1818
self.classnames = datasets[0].classnames
1919
self.collate_fn = datasets[0].collate_fn
2020

21+
def __getattr__(self, attr):
22+
if hasattr(self, attr):
23+
return getattr(self, attr)
24+
25+
if hasattr(self.datasets[0], attr):
26+
return getattr(self.datasets[0], attr)
27+
28+
raise AttributeError
2129

2230
class ChainDataset(data.ConcatDataset):
2331
"""

theseus/base/losses/multi_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ def __init__(self, losses: Iterable[nn.Module], weights=None, **kwargs):
1111
self.losses = losses
1212
self.weights = [1.0 for _ in range(len(losses))] if weights is None else weights
1313

14-
def forward(self, pred: torch.Tensor, batch: Dict[str, Any], device: torch.device):
14+
def forward(self, outputs: Dict[str, Any], batch: Dict[str, Any], device: torch.device):
1515
"""
1616
Forward inputs and targets through multiple losses
1717
"""
1818
total_loss = 0
1919
total_loss_dict = {}
2020

2121
for weight, loss_fn in zip(self.weights, self.losses):
22-
loss, loss_dict = loss_fn(pred, batch, device)
22+
loss, loss_dict = loss_fn(outputs, batch, device)
2323
total_loss += (weight*loss)
2424
total_loss_dict.update(loss_dict)
2525

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from theseus.classification.callbacks.visualize_callbacks import ClassificationVisualizerCallbacks
2+
from theseus.classification.callbacks.gradcam_callbacks import GradCAMVisualizationCallbacks
23
from theseus.base.callbacks import CALLBACKS_REGISTRY
34

4-
CALLBACKS_REGISTRY.register(ClassificationVisualizerCallbacks)
5+
CALLBACKS_REGISTRY.register(ClassificationVisualizerCallbacks)
6+
CALLBACKS_REGISTRY.register(GradCAMVisualizationCallbacks)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from typing import Dict
2+
import matplotlib.pyplot as plt
3+
import torch
4+
from torchvision.transforms import functional as TFF
5+
6+
from theseus.base.callbacks.base_callbacks import Callbacks
7+
from theseus.utilities.loggers.observer import LoggerObserver
8+
from theseus.classification.utilities.gradcam import CAMWrapper, show_cam_on_image
9+
from theseus.utilities.visualization.visualizer import Visualizer
10+
11+
LOGGER = LoggerObserver.getLogger("main")
12+
13+
class GradCAMVisualizationCallbacks(Callbacks):
    """
    Callback that visualizes Grad-CAM heatmaps during training.

    After every validation epoch it runs GradCAM over the last validation
    batch, overlays the class-activation maps on the denormalized input
    images, and logs the resulting grid figure under "Validation/gradcam".
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.visualizer = Visualizer()

    @torch.enable_grad()  # GradCAM needs gradients even inside validation
    def on_val_epoch_end(self, logs: Dict = None):
        """Compute and log Grad-CAM overlays for the last validation batch.

        Args:
            logs: trainer-provided dict; this callback reads 'iters' and
                'last_batch' (which must contain an "inputs" tensor batch).
        """
        iters = logs['iters']
        last_batch = logs['last_batch']
        model = self.params['trainer'].model
        optimizer = self.params['trainer'].optimizer

        # Zero gradients in model and optimizer up front to suppress
        # stale-gradient warnings from the CAM backward passes.
        optimizer.zero_grad()
        model.zero_grad()

        # Visualize Grad Class Activation Mapping on model predictions
        LOGGER.text("Visualizing model predictions...", level=LoggerObserver.DEBUG)

        images = last_batch["inputs"]
        model.eval()

        ## Calculate GradCAM maps for the whole batch
        model_name = model.model.name

        try:
            grad_cam = CAMWrapper.get_method(
                name='gradcam',
                model=model.model.get_model(),
                model_name=model_name,
                use_cuda=next(model.parameters()).is_cuda)

            # Labels/scores are returned but not rendered by this callback.
            grayscale_cams, _label_indices, _scores = grad_cam(images, return_probs=True)
        except Exception:
            # Best-effort: some architectures are unsupported by the CAM
            # wrapper — log and skip visualization rather than crash training.
            LOGGER.text("Cannot calculate GradCAM", level=LoggerObserver.ERROR)
            return

        gradcam_batch = []
        for idx in range(len(grayscale_cams)):
            image = images[idx]
            grayscale_cam = grayscale_cams[idx, :]

            # Undo normalization so the overlay is rendered on the raw image.
            img_show = self.visualizer.denormalize(image)
            img_cam = show_cam_on_image(img_show, grayscale_cam, use_rgb=True)
            gradcam_batch.append(TFF.to_tensor(img_cam))

            if idx == 63:  # cap the grid at 64 images
                break

        # Assemble a single grid figure and ship it to the logger.
        gradcam_grid_img = self.visualizer.make_grid(gradcam_batch)
        fig = plt.figure(figsize=(8, 8))
        plt.imshow(gradcam_grid_img)
        plt.axis("off")
        plt.tight_layout(pad=0)
        LOGGER.log([{
            'tag': "Validation/gradcam",
            'value': fig,
            'type': LoggerObserver.FIGURE,
            'kwargs': {
                'step': iters
            }
        }])

theseus/classification/callbacks/visualize_callbacks.py

Lines changed: 8 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def analyze_gt(self, trainset, valset, iters):
150150
plt.clf() # Clear figure
151151
plt.close()
152152

153-
@torch.enable_grad() #enable grad for CAM
153+
@torch.no_grad() # gradients not needed here; GradCAM (which needs grad) moved to GradCAMVisualizationCallbacks
154154
def on_val_epoch_end(self, logs: Dict=None):
155155
"""
156156
After finish validation
@@ -160,41 +160,24 @@ def on_val_epoch_end(self, logs: Dict=None):
160160
last_batch = logs['last_batch']
161161
model = self.params['trainer'].model
162162
valloader = self.params['trainer'].valloader
163-
optimizer = self.params['trainer'].optimizer
164163

165-
# Zeroing gradients in model and optimizer for supress warning
166-
optimizer.zero_grad()
167-
model.zero_grad()
168-
169-
# Vizualize Grad Class Activation Mapping and model predictions
164+
# Visualize model predictions
170165
LOGGER.text("Visualizing model predictions...", level=LoggerObserver.DEBUG)
171166

172167
images = last_batch["inputs"]
173168
targets = last_batch["targets"]
174169
model.eval()
175170

176-
## Calculate GradCAM
177-
model_name = model.model.name
178-
179-
try:
180-
grad_cam = CAMWrapper.get_method(
181-
name='gradcam',
182-
model=model.model.get_model(),
183-
model_name=model_name, use_cuda=next(model.parameters()).is_cuda)
184-
185-
grayscale_cams, label_indices, scores = grad_cam(images, return_probs=True)
186-
187-
except:
188-
LOGGER.text("Cannot calculate GradCAM", level=LoggerObserver.ERROR)
189-
return
171+
## Get prediction on last batch
172+
outputs = model.model.get_prediction(last_batch, device=model.device)
173+
label_indices = outputs['labels']
174+
scores = outputs['confidences']
190175

191-
gradcam_batch = []
192176
pred_batch = []
193-
for idx in range(len(grayscale_cams)):
177+
for idx in range(len(images)):
194178
image = images[idx]
195179
target = targets[idx].item()
196180
label = label_indices[idx]
197-
grayscale_cam = grayscale_cams[idx, :]
198181
score = scores[idx]
199182

200183
img_show = self.visualizer.denormalize(image)
@@ -217,32 +200,12 @@ def on_val_epoch_end(self, logs: Dict=None):
217200
offset=100
218201
)
219202

220-
img_cam =show_cam_on_image(img_show, grayscale_cam, use_rgb=True)
221-
222-
img_cam = TFF.to_tensor(img_cam)
223-
gradcam_batch.append(img_cam)
224-
225203
pred_img = self.visualizer.get_image()
226204
pred_img = TFF.to_tensor(pred_img)
227205
pred_batch.append(pred_img)
228206

229207
if idx == 63: # limit number of images
230208
break
231-
232-
# GradCAM images
233-
gradcam_grid_img = self.visualizer.make_grid(gradcam_batch)
234-
fig = plt.figure(figsize=(8,8))
235-
plt.imshow(gradcam_grid_img)
236-
plt.axis("off")
237-
plt.tight_layout(pad=0)
238-
LOGGER.log([{
239-
'tag': "Validation/gradcam",
240-
'value': fig,
241-
'type': LoggerObserver.FIGURE,
242-
'kwargs': {
243-
'step': iters
244-
}
245-
}])
246209

247210
# Prediction images
248211
pred_grid_img = self.visualizer.make_grid(pred_batch)
@@ -261,8 +224,4 @@ def on_val_epoch_end(self, logs: Dict=None):
261224

262225
plt.cla() # Clear axis
263226
plt.clf() # Clear figure
264-
plt.close()
265-
266-
# Zeroing gradients in model and optimizer for safety
267-
optimizer.zero_grad()
268-
model.zero_grad()
227+
plt.close()

theseus/classification/metrics/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
from .confusion_matrix import *
77
from .errorcases import *
88
from .projection import *
9+
from .precision_recall import *
910

1011
METRIC_REGISTRY.register(Accuracy)
1112
METRIC_REGISTRY.register(BalancedAccuracyMetric)
1213
METRIC_REGISTRY.register(F1ScoreMetric)
1314
METRIC_REGISTRY.register(ConfusionMatrix)
1415
METRIC_REGISTRY.register(ErrorCases)
15-
METRIC_REGISTRY.register(EmbeddingProjection)
16+
METRIC_REGISTRY.register(EmbeddingProjection)
17+
METRIC_REGISTRY.register(PrecisionRecall)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from sklearn.metrics import precision_score, recall_score
2+
from typing import Any, Dict
3+
4+
from theseus.base.metrics.metric_template import Metric
5+
from theseus.classification.utilities.logits import logits2labels
6+
7+
class PrecisionRecall(Metric):
    """
    Precision / Recall metric (supports sklearn averaging modes such as
    'micro', 'macro', 'weighted').

    Args:
        average: averaging strategy forwarded to sklearn's
            precision_score / recall_score.
        label_type: label decoding mode forwarded to logits2labels
            (e.g. 'multiclass').
        threshold: decision threshold for thresholded label types
            (read from kwargs, default 0.5).
    """
    def __init__(self, average='weighted', label_type: str = 'multiclass', **kwargs):
        super().__init__(**kwargs)
        self.average = average
        self.type = label_type
        self.threshold = kwargs.get('threshold', 0.5)
        self.reset()

    def update(self, outputs: Dict[str, Any], batch: Dict[str, Any]):
        """
        Accumulate predictions and targets from one batch.
        """
        targets = batch["targets"]
        outputs = outputs["outputs"]

        outputs = logits2labels(outputs, label_type=self.type, threshold=self.threshold)
        targets = targets.squeeze()

        # NOTE(review): assumes CPU tensors — .numpy() raises on CUDA
        # tensors; confirm callers move validation batches to CPU.
        preds = outputs.numpy().tolist()
        labels = targets.numpy().tolist()

        # A batch of size 1 squeezes down to a 0-d tensor whose .tolist()
        # is a bare scalar; normalize back to lists so += concatenates
        # instead of raising TypeError.
        if not isinstance(preds, list):
            preds = [preds]
        if not isinstance(labels, list):
            labels = [labels]

        self.preds += preds
        self.targets += labels

    def value(self):
        """Return {'<avg>-precision': ..., '<avg>-recall': ...} computed
        over every batch accumulated since the last reset()."""
        precision = precision_score(self.targets, self.preds, average=self.average, zero_division=1)
        recall = recall_score(self.targets, self.preds, average=self.average, zero_division=1)
        return {f"{self.average}-precision": precision, f"{self.average}-recall": recall}

    def reset(self):
        """Clear all accumulated predictions and targets."""
        self.targets = []
        self.preds = []

0 commit comments

Comments
 (0)