-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_selective_classification.py
More file actions
140 lines (112 loc) · 5.06 KB
/
test_selective_classification.py
File metadata and controls
140 lines (112 loc) · 5.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
import numpy as np
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import auc
# ---------- Data ----------
def get_testloader(batch_size=100, dataset="cifar100", root="./data",
                   num_workers=4, mean=None, std=None):
    """Build a DataLoader over a CIFAR test split.

    Args:
        batch_size: samples per batch.
        dataset: ``"cifar100"`` (default) or ``"cifar10"``; case-insensitive.
        root: directory the dataset is read from / downloaded into.
        num_workers: number of DataLoader worker processes.
        mean: optional per-channel normalization mean overriding the built-in
            statistics for ``dataset``.
        std: optional per-channel normalization std overriding the built-in
            statistics for ``dataset``.

    Returns:
        A non-shuffled DataLoader (fixed sample order) over the test split,
        transformed with ``ToTensor`` followed by ``Normalize``.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    # name -> (torchvision.datasets class name, default mean, default std).
    # The CIFAR-100 statistics are the values this loader always used.
    # NOTE(review): the CIFAR-10 statistics are the commonly published
    # per-channel values; confirm they match the training-time normalization.
    known = {
        "cifar10": ("CIFAR10",
                    (0.4914, 0.4822, 0.4465),
                    (0.2470, 0.2435, 0.2616)),
        "cifar100": ("CIFAR100",
                     (0.5071, 0.4867, 0.4408),
                     (0.2675, 0.2565, 0.2761)),
    }
    key = str(dataset).lower()
    if key not in known:
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected one of {sorted(known)}")
    cls_name, default_mean, default_std = known[key]

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(default_mean if mean is None else mean,
                             default_std if std is None else std),
    ])
    dataset_cls = getattr(torchvision.datasets, cls_name)
    testset = dataset_cls(root=root, train=False,
                          download=True, transform=transform_test)
    return DataLoader(testset, batch_size=batch_size, shuffle=False,
                      num_workers=num_workers, pin_memory=True)
# ---------- Confidence utils ----------
def nearest_boundary_distance(logits, yhat, W, b, topM=5):
    """Closed-form nearest decision boundary distance.

    For every row, the logit of class ``yhat`` is compared with its ``topM``
    highest-scoring rival classes, and the smallest ratio

        |z_yhat - z_j| / max(||W[yhat] - W[j]||, 1e-6)

    over those rivals is returned (one value per row). ``b`` is accepted but
    never read; presumably logits already include the bias, i.e.
    logits = features @ W.T + b -- confirm against the model head.
    """
    n_rows, n_classes = logits.shape
    rows = torch.arange(n_rows, device=logits.device)

    # Logit of the chosen class; then bury it so it can't be its own rival.
    own_logit = logits[rows, yhat]
    without_own = logits.clone()
    without_own[rows, yhat] = -1e9
    n_rivals = min(topM, n_classes - 1)
    rival_idx = torch.topk(without_own, k=n_rivals, dim=1).indices  # (B, M)

    # Logit margins to each rival.
    rival_logit = logits.gather(1, rival_idx)
    margins = (own_logit.unsqueeze(1) - rival_logit).abs()

    # Norm of the weight-row difference for each (chosen, rival) pair.
    weight_gaps = W[yhat].unsqueeze(1) - W[rival_idx]  # (B, M, D)
    gap_norms = torch.clamp(weight_gaps.norm(dim=2), min=1e-6)

    distances = margins / gap_norms
    return torch.min(distances, dim=1).values
def collect_predictions(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and collect predictions and confidences.

    Two confidence scores are produced for every sample:

    * ``confs_rr``: nearest-boundary distance from ``nearest_boundary_distance``,
      using the weights of the last ``torch.nn.Linear`` module in ``model``;
    * ``confs_msp``: maximum softmax probability.

    Args:
        model: classifier whose forward pass returns logits of shape
            (batch, num_classes). NOTE(review): the last ``nn.Linear`` in
            ``model.modules()`` order is assumed to be the layer producing
            these logits (true for a standard ResNet ``fc`` head) -- confirm
            for other architectures.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(preds, confs_rr, confs_msp, labels)`` as 1-D NumPy arrays with one
        entry per sample, in dataloader order.

    Raises:
        ValueError: if ``model`` contains no ``nn.Linear`` layer, if that
            layer's output width differs from the number of logits, or if
            ``dataloader`` yields no batches.
    """
    model.eval()

    # Take the last Linear module in traversal order as the classifier head.
    head = None
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            head = module
    if head is None:
        raise ValueError("model has no torch.nn.Linear layer; cannot compute "
                         "nearest-boundary distances")
    # The boundary computation only reads the weights; detach from autograd.
    W = head.weight.detach()
    b = None if head.bias is None else head.bias.detach()

    preds, confs_rr, confs_msp, labels = [], [], [], []
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # Guard against picking a Linear that is not the logit head, which
            # would make W[yhat] index the wrong rows.
            if logits.shape[1] != W.shape[0]:
                raise ValueError(
                    f"last Linear layer has {W.shape[0]} outputs but the model "
                    f"returned {logits.shape[1]} logits; it is not the "
                    f"classifier head")
            pred = logits.argmax(dim=1)
            # Confidence 1: nearest boundary distance
            conf_rr = nearest_boundary_distance(logits, pred, W, b)
            # Confidence 2: max softmax probability
            conf_msp = torch.softmax(logits, dim=1).max(dim=1).values
            preds.append(pred.cpu())
            confs_rr.append(conf_rr.cpu())
            confs_msp.append(conf_msp.cpu())
            labels.append(y.cpu())

    if not preds:
        raise ValueError("dataloader yielded no batches")
    return (torch.cat(preds).numpy(),
            torch.cat(confs_rr).numpy(),
            torch.cat(confs_msp).numpy(),
            torch.cat(labels).numpy())
# ---------- Metrics ----------
def compute_risk_aurc(preds, confs, labels, coverage_target=0.95):
    """Selective-classification metrics from predictions and confidence scores.

    Samples are ranked by decreasing confidence. Keeping the top ``k`` gives
    coverage ``k / N`` and a selective risk equal to the error rate among
    those ``k`` samples.

    Args:
        preds: predicted labels, shape (N,).
        confs: confidence scores, shape (N,); higher means more confident.
        labels: ground-truth labels, shape (N,).
        coverage_target: coverage at which the selective risk is reported;
            must lie in (0, 1].

    Returns:
        ``(risk_at_target, aurc, eaurc, naurc)``:

        * risk_at_target -- selective risk at the smallest coverage that is
          >= ``coverage_target``;
        * aurc -- mean selective risk over all N coverage levels;
        * eaurc -- ``aurc`` minus the AURC of an oracle ranking;
        * naurc -- ``eaurc / (AURC_random - AURC_oracle)``: 0 for an oracle
          ranking, 1 for a random ranking (in expectation).

    Raises:
        ValueError: on empty input or ``coverage_target`` outside (0, 1].
    """
    preds = np.asarray(preds)
    confs = np.asarray(confs)
    labels = np.asarray(labels)
    if preds.size == 0:
        raise ValueError("compute_risk_aurc needs at least one sample")
    if not 0.0 < coverage_target <= 1.0:
        raise ValueError(
            f"coverage_target must be in (0, 1], got {coverage_target!r}")

    correct = (preds == labels).astype(np.int32)
    N = len(correct)
    k = np.arange(1, N + 1)  # number of samples kept at each coverage level

    # Most confident first; a stable sort breaks confidence ties by original
    # sample order so the result is reproducible.
    order = np.argsort(-confs, kind="stable")
    correct_sorted = correct[order]
    coverage = k / N
    risk = 1.0 - np.cumsum(correct_sorted) / k

    # Selective risk at the first coverage level reaching the target.
    idx_target = np.searchsorted(coverage, coverage_target)
    risk_at_target = risk[idx_target]

    # AURC (numerical integration over the N discrete coverage levels)
    aurc = np.mean(risk)

    # Oracle ranking: all correct samples first, all errors last.
    n_err = N - correct.sum()
    risk_opt = np.clip(k - (N - n_err), 0, n_err) / k
    aurc_opt = np.mean(risk_opt)

    # Random ranking: every top-k subset is a uniform sample, so the expected
    # selective risk equals the overall error rate at every k. The random
    # baseline AURC is therefore the error rate itself (not half of it).
    aurc_rand = 1.0 - correct.mean()

    eaurc = aurc - aurc_opt
    denom = aurc_rand - aurc_opt
    # denom is 0 only when every sample is correct or every sample is wrong;
    # any ranking is then optimal and eaurc is 0 as well.
    naurc = eaurc / denom if denom > 0 else 0.0
    return risk_at_target, aurc, eaurc, naurc
# ---------- Main ----------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = "resnet56"
    num_classes = 10  # must match the checkpoint loaded below

    # Load model
    from train import ResNet56  # adjust import if needed
    model = ResNet56(model_name=model_name, num_classes=num_classes).to(device)
    # Alternative checkpoints used previously:
    #   f"/home/microway/Selective Classification/resnet_train/checkpoints/{model_name}_best_model.pt"
    #   f"/home/microway/Selective Classification/resnet_train/{model_name}_rat_best.pth"
    ckpt_path = "/home/microway/Selective Classification/resnet_train/checkpoints/cifar10_resnet56_best.pth"
    ckpt = torch.load(ckpt_path, map_location=device)
    # The weights may be wrapped under one of these keys, or the file may be
    # a bare state dict.
    if "model_state_dict" in ckpt:
        ckpt = ckpt["model_state_dict"]
    elif "backbone_state_dict" in ckpt:
        ckpt = ckpt["backbone_state_dict"]
    model.load_state_dict(ckpt)

    # Get test data
    testloader = get_testloader(batch_size=100)
    # A test set with a different label space than the model yields
    # meaningless metrics (predictions in [0, num_classes) vs. labels that
    # can fall outside that range), so refuse to evaluate.
    data_classes = getattr(testloader.dataset, "classes", None)
    if data_classes is not None and len(data_classes) != num_classes:
        raise ValueError(
            f"model has {num_classes} classes but the test set has "
            f"{len(data_classes)}; evaluate on the dataset the checkpoint "
            f"was trained on")

    # Collect preds + confidence scores
    preds, confs_rr, confs_msp, labels = collect_predictions(model, testloader, device)

    # Evaluate both
    risk95_rr, aurc_rr, eaurc_rr, naurc_rr = compute_risk_aurc(preds, confs_rr, labels)
    risk95_msp, aurc_msp, eaurc_msp, naurc_msp = compute_risk_aurc(preds, confs_msp, labels)

    print("=== Selective Classification Results ===")
    print(f"Nearest-Boundary Distance | Risk@95%: {risk95_rr:.4f} | AURC: {aurc_rr:.4f} | EAURC: {eaurc_rr:.4f} | NAURC: {naurc_rr:.4f}")
    print(f"Max Softmax Probability | Risk@95%: {risk95_msp:.4f} | AURC: {aurc_msp:.4f} | EAURC: {eaurc_msp:.4f} | NAURC: {naurc_msp:.4f}")