-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_spatial.py
More file actions
56 lines (45 loc) · 1.81 KB
/
Copy patheval_spatial.py
File metadata and controls
56 lines (45 loc) · 1.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from train_spatial import SpatialImageDataset
from spatial_model import SpatialBranch
def _resolve_checkpoint(checkpoint, repo_root):
    """Return the first existing location of *checkpoint* as a string.

    Candidates are tried in order: the path as given (absolute, or relative to
    the current working directory), then relative to *repo_root*, then
    relative to ``repo_root / "src"``.

    Raises:
        FileNotFoundError: if none of the candidates exists. The message lists
            every location that was searched.
    """
    candidates = [
        Path(checkpoint),
        repo_root / checkpoint,
        repo_root / "src" / checkpoint,
    ]
    for candidate in candidates:
        if candidate.exists():
            return str(candidate)
    searched = ", ".join(str(candidate) for candidate in candidates)
    raise FileNotFoundError(
        f"Checkpoint {str(checkpoint)!r} not found; searched: {searched}"
    )


def evaluate(
    checkpoint="best_spatial.pth",
    batch_size=32,
    *,
    test_roots=None,
    num_workers=4,
):
    """Evaluate a SpatialBranch checkpoint on the test image set and report metrics.

    Args:
        checkpoint: Path to a saved ``state_dict``. It is tried as given, then
            relative to the repository root, then relative to ``<repo>/src``.
        batch_size: Number of images per DataLoader batch.
        test_roots: Iterable of directories handed to ``SpatialImageDataset``.
            Defaults to ``[<repo>/dataset/hf]``.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A dict with ``"samples"`` (number of evaluated images), ``"loss"``
        (mean cross-entropy per sample) and ``"accuracy"`` (fraction of
        samples whose argmax prediction equals the label). The same figures
        are also printed.

    Raises:
        RuntimeError: if the test dataset contains no images.
        FileNotFoundError: if the checkpoint cannot be found in any of the
            searched locations.
    """
    # parents[1] assumes this file lives one directory below the repo root.
    repo_root = Path(__file__).resolve().parents[1]
    if test_roots is None:
        test_roots = [str(repo_root / "dataset" / "hf")]
    else:
        # The dataset has always received plain strings; keep that contract.
        test_roots = [str(root) for root in test_roots]

    test_set = SpatialImageDataset(test_roots, augment=False)
    if len(test_set) == 0:
        raise RuntimeError(f"No test images found in: {test_roots}")
    loader = DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SpatialBranch(num_classes=2).to(device)

    checkpoint_path = _resolve_checkpoint(checkpoint, repo_root)
    # weights_only=True restricts unpickling to tensors and plain containers.
    state = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.eval()

    correct = 0
    total = 0
    total_loss = 0.0
    criterion = torch.nn.CrossEntropyLoss()
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            logits = model(imgs)
            # The default reduction is the batch mean; scale it back to a
            # batch sum so the final average is weighted per sample, even
            # when the last batch is smaller.
            loss = criterion(logits, labels)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total_loss += loss.item() * labels.size(0)
            total += labels.size(0)

    acc = correct / max(1, total)
    avg_loss = total_loss / max(1, total)
    print(f"Test samples: {total}")
    print(f"Loss: {avg_loss:.4f}")
    print(f"Accuracy: {acc:.4f}")
    return {"samples": total, "loss": avg_loss, "accuracy": acc}
def _main():
    """Script entry point: run evaluate() with its default arguments."""
    evaluate()


if __name__ == "__main__":
    _main()