Skip to content

Commit d2dd603

Browse files
committed
move metrics from train.py to separate file
1 parent cc1d04f commit d2dd603

File tree

2 files changed

+79
-80
lines changed

2 files changed

+79
-80
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import torch
2+
from tqdm import tqdm
3+
import torch.nn.functional as F
4+
5+
def infer_model(classifier, device, img, loader, model, ttype):
    """Produce classifier logits for one batch.

    When the loader's dataset carries no precomputed feature store, the raw
    batch is first encoded by ``model``; otherwise ``img`` is assumed to
    already hold features. Features are cast to ``ttype`` before the
    classifier head is applied.

    Returns the classifier output (logits) for the batch.
    """
    if loader.dataset.feature_store is None:
        features = model.encode(img.to(device))
    else:
        # Features were precomputed; the "image" tensor already holds them.
        features = img.to(device)
    return classifier(features.to(ttype))
13+
14+
def score_loss(model, classifier, loader, device, ttype):
    """Compute the average KL-divergence loss over ``loader``.

    For each batch, logits are obtained via ``infer_model``, converted to
    log-probabilities, and compared against the target distributions ``y``
    with batch-mean KL divergence. The per-batch losses are averaged over
    the number of batches, printed, and returned.

    Returns:
        float: mean KL-divergence loss per batch.
    """
    criterion = torch.nn.KLDivLoss(reduction='batchmean')
    with torch.no_grad():
        sum_loss = 0.0
        for y, img in tqdm(loader):
            out = infer_model(classifier, device, img, loader, model, ttype)
            log_probs = F.log_softmax(out, dim=1)
            # Targets y are expected to be probability distributions
            # (KLDivLoss takes log-probs as input, probs as target).
            loss = criterion(log_probs, y.to(device))
            sum_loss += loss.item()
        sum_loss /= len(loader)

    print("Average loss:", sum_loss)
    # Return the value so callers can record it instead of only seeing stdout.
    return sum_loss
26+
27+
28+
def symmetric_topk_mass_recall(logits, targets, k=10):
    """Symmetric top-k probability-mass recall between two distributions.

    Measures how much target mass falls inside the prediction's top-k
    classes and, symmetrically, how much predicted mass falls inside the
    target's top-k classes, averaging the two directions.

    Args:
        logits: unnormalized predictions, shape (..., num_classes).
        targets: target probability distributions, same shape as ``logits``.
        k: number of top classes to consider; clamped to ``num_classes``.

    Returns:
        Scalar tensor: mean symmetric recall over the batch (in [0, 1] when
        ``targets`` are valid probability distributions).
    """
    probs = F.softmax(logits, dim=-1)

    # Guard: topk raises if k exceeds the class count (e.g. k=17 on fewer
    # than 17 classes), so clamp it to the last-dimension size.
    k = min(k, logits.size(-1))

    # Direction 1: target mass captured by the prediction's top-k classes.
    topk_pred_idx = probs.topk(k, dim=-1).indices
    target_mass_in_pred_topk = torch.gather(targets, dim=-1, index=topk_pred_idx)
    recall_p_to_q = target_mass_in_pred_topk.sum(dim=-1)

    # Direction 2: predicted mass captured by the target's top-k classes.
    topk_target_idx = targets.topk(k, dim=-1).indices
    pred_mass_in_target_topk = torch.gather(probs, dim=-1, index=topk_target_idx)
    recall_q_to_p = pred_mass_in_target_topk.sum(dim=-1)

    symmetric_recall = 0.5 * (recall_p_to_q + recall_q_to_p)

    return symmetric_recall.mean()
43+
44+
def symmetric_topk_recall_score(model, classifier, loader, device, ttype, k=17):
    """Average symmetric top-k mass recall over ``loader``.

    Runs inference on every batch, squeezes logits/targets, and averages
    ``symmetric_topk_mass_recall`` across batches. The result is printed
    and returned.

    Args:
        k: top-k cutoff passed to the per-batch metric (default 17,
           matching the original hard-coded value).

    Returns:
        float: mean symmetric top-k recall per batch.
    """
    with torch.no_grad():
        average_recall = 0.0
        for y, img in tqdm(loader):
            out = infer_model(classifier, device, img, loader, model, ttype).squeeze()
            y = y.squeeze().to(device)
            recall = symmetric_topk_mass_recall(out, y, k=k)
            average_recall += recall.item()
        average_recall /= len(loader)
    print("Average symmetric top-k recall:", average_recall)
    # Return the value so callers can record it instead of only seeing stdout.
    return average_recall
54+
55+
def distribution_iou(logits, targets):
    """Soft IoU between predicted and target probability distributions.

    The prediction is obtained by softmaxing ``logits``; intersection and
    union are the element-wise min and max of the two distributions summed
    over the class axis. A small epsilon keeps the ratio finite.

    Returns:
        Scalar tensor: mean IoU over the batch.
    """
    probs = F.softmax(logits, dim=-1)

    overlap = torch.minimum(probs, targets).sum(dim=-1)
    combined = torch.maximum(probs, targets).sum(dim=-1)

    iou = overlap / (combined + 1e-8)
    return iou.mean()
62+
63+
def IoU_score(model, classifier, loader, device, ttype):
    """Average distribution IoU over ``loader``.

    Runs inference on every batch, squeezes logits/targets, and averages
    ``distribution_iou`` across batches. The result is printed and
    returned.

    Returns:
        float: mean IoU per batch.
    """
    with torch.no_grad():
        # Renamed from "average_recall": this accumulates an IoU, not a recall.
        average_iou = 0.0
        for y, img in tqdm(loader):
            out = infer_model(classifier, device, img, loader, model, ttype).squeeze()
            y = y.squeeze().to(device)
            iou = distribution_iou(out, y)
            average_iou += iou.item()
        average_iou /= len(loader)
    print("Average IoU score:", average_iou)
    # Return the value so callers can record it instead of only seeing stdout.
    return average_iou

vla/benchmarks/class_distribution_using_segmentation/train.py

Lines changed: 7 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from model import SimpleClassifier
1010
from features import FeatureStore
1111
from precompute import precompute_features
12+
from metrics import score_loss, IoU_score, symmetric_topk_recall_score
1213

1314
ttype = torch.float32
1415

@@ -112,80 +113,6 @@ def train_classifier(
112113

113114
return classifier
114115

115-
def score_loss(model, classifier, loader, device):
116-
criterion = nn.KLDivLoss(reduction='batchmean')
117-
with torch.no_grad():
118-
sum_loss = 0
119-
for y, img in tqdm(loader):
120-
if loader.dataset.feature_store is None:
121-
feats = model.encode(img.to(device))
122-
else:
123-
feats = img.to(device)
124-
feats = feats.to(ttype)
125-
out = classifier(feats)
126-
log_probs = F.log_softmax(out, dim=1)
127-
loss = criterion(log_probs, y.to(device))
128-
sum_loss += loss.item()
129-
sum_loss /= len(loader)
130-
131-
print("Average loss:", sum_loss)
132-
133-
def symmetric_topk_mass_recall(logits, targets, k=10):
134-
135-
probs = F.softmax(logits, dim=-1)
136-
137-
topk_pred_idx = probs.topk(k, dim=-1).indices
138-
target_mass_in_pred_topk = torch.gather(targets, dim=-1, index=topk_pred_idx)
139-
recall_p_to_q = target_mass_in_pred_topk.sum(dim=-1)
140-
141-
topk_target_idx = targets.topk(k, dim=-1).indices
142-
pred_mass_in_target_topk = torch.gather(probs, dim=-1, index=topk_target_idx)
143-
recall_q_to_p = pred_mass_in_target_topk.sum(dim=-1)
144-
145-
symmetric_recall = 0.5 * (recall_p_to_q + recall_q_to_p)
146-
147-
return symmetric_recall.mean()
148-
149-
def symmetric_topk_recall_score(model, classifier, loader, device):
150-
with torch.no_grad():
151-
average_recall = 0
152-
for y, img in tqdm(loader):
153-
if loader.dataset.feature_store is None:
154-
feats = model.encode(img.to(device))
155-
else:
156-
feats = img.to(device)
157-
feats = feats.to(ttype)
158-
out = classifier(feats).squeeze()
159-
y = y.squeeze().to(device)
160-
recall = symmetric_topk_mass_recall(out, y, k=17)
161-
average_recall += recall.item()
162-
average_recall /= len(loader)
163-
print("Average symmetric top-k recall:", average_recall)
164-
165-
def distribution_iou(logits, targets):
166-
probs = F.softmax(logits, dim=-1)
167-
168-
intersection = torch.minimum(probs, targets).sum(dim=-1)
169-
union = torch.maximum(probs, targets).sum(dim=-1)
170-
171-
return (intersection / (union + 1e-8)).mean()
172-
173-
def IoU_score(model, classifier, loader, device):
174-
with torch.no_grad():
175-
average_recall = 0
176-
for y, img in tqdm(loader):
177-
if loader.dataset.feature_store is None:
178-
feats = model.encode(img.to(device))
179-
else:
180-
feats = img.to(device)
181-
feats = feats.to(ttype)
182-
out = classifier(feats).squeeze()
183-
y = y.squeeze().to(device)
184-
recall = distribution_iou(out, y)
185-
average_recall += recall.item()
186-
average_recall /= len(loader)
187-
print("Average IoU score:", average_recall)
188-
189116
def benchmark(model, preprocessor, train_json="train_los_dataset.json", test_json="test_los_dataset.json",
190117
use_precomputed_features=True, random_seed=None,
191118
generalization_set_folder="", config_path="example_config.json"):
@@ -253,11 +180,11 @@ def benchmark(model, preprocessor, train_json="train_los_dataset.json", test_jso
253180

254181
classifier = train_classifier(model, classifier, train_loader, test_loader, device)
255182
print("Computing score on test set:\n")
256-
symmetric_topk_recall_score(model, classifier, score_loader, device)
257-
IoU_score(model, classifier, score_loader, device)
258-
score_loss(model, classifier, score_loader, device)
183+
symmetric_topk_recall_score(model, classifier, score_loader, device, ttype)
184+
IoU_score(model, classifier, score_loader, device, ttype)
185+
score_loss(model, classifier, score_loader, device, ttype)
259186
if generalization_dataset_manager is not None:
260187
print("Computing score on generalization set:\n")
261-
symmetric_topk_recall_score(model, classifier, generalization_loader, device)
262-
IoU_score(model, classifier, generalization_loader, device)
263-
score_loss(model, classifier, generalization_loader, device)
188+
symmetric_topk_recall_score(model, classifier, generalization_loader, device, ttype)
189+
IoU_score(model, classifier, generalization_loader, device, ttype)
190+
score_loss(model, classifier, generalization_loader, device, ttype)

0 commit comments

Comments
 (0)