-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmetrics.py
More file actions
65 lines (54 loc) · 2.19 KB
/
metrics.py
File metadata and controls
65 lines (54 loc) · 2.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#!/usr/bin/env python3
# @file metrics.py
# @author Benedikt Mersch [mersch@igg.uni-bonn.de]
# Copyright (c) 2022 Benedikt Mersch, all rights reserved
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
class ClassificationMetrics(nn.Module):
    """Accumulate a confusion matrix and derive moving/static metrics from it.

    The running matrix ``conf_matrix`` is indexed as ``[pred, gt]``: rows are
    predicted classes and columns are ground-truth classes. The metric helpers
    treat class index 1 as the "moving" class and class index 0 as the
    "static" class.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        # Plain attribute, not a registered buffer: it is not moved by
        # ``Module.to()`` and is not saved in ``state_dict``.
        self.conf_matrix = torch.zeros(
            (self.n_classes, self.n_classes), dtype=torch.long
        )

    def compute_confusion_matrix(self, pred_logits: torch.Tensor, gt_labels: torch.Tensor):
        """Add one batch of predictions to ``self.conf_matrix`` in place.

        Args:
            pred_logits: Class scores with the class axis at ``dim=1``; the
                predicted label is the argmax over that axis.
            gt_labels: Ground-truth class indices matching the shape of the
                predicted labels (``pred_logits`` with ``dim=1`` reduced).
        """
        # Softmax is strictly monotonic, so the argmax of the raw logits is the
        # argmax of their softmax; computing the softmax only costs time.
        pred_labels = torch.argmax(pred_logits, dim=1)
        # index_put_ requires the values (and any non-CPU indices) to live on
        # the matrix's device, so bring the batch to wherever the matrix is.
        device = self.conf_matrix.device
        pred_labels = pred_labels.to(device=device, dtype=torch.long)
        gt_labels = gt_labels.to(device=device, dtype=torch.long)
        # With accumulate=True every (pred, gt) pair adds 1, including
        # pairs that repeat within the batch.
        self.conf_matrix.index_put_(
            (pred_labels, gt_labels), torch.ones_like(gt_labels), accumulate=True
        )

    def getStats(self, confusion_matrix, class_idx: int = 1):
        """Return ``(tp, fp, fn)`` counts for one class of a ``[pred, gt]`` matrix.

        Args:
            confusion_matrix: Square matrix indexed as ``[pred, gt]``.
            class_idx: Class to score; defaults to 1, the moving class.

        Returns:
            Tuple of 0-d tensors ``(tp, fp, fn)``.
        """
        tp = confusion_matrix.diag()[class_idx]
        # Row sum = everything predicted as the class; column sum = everything
        # whose ground truth is the class.
        fp = confusion_matrix.sum(dim=1)[class_idx] - tp
        fn = confusion_matrix.sum(dim=0)[class_idx] - tp
        return tp, fp, fn

    @staticmethod
    def _iou(tp, fp, fn):
        """Return ``tp / (tp + fp + fn)``, or 0 when the union is empty."""
        # The tiny epsilon turns 0 / 0 into 0 instead of NaN.
        return tp / (tp + fp + fn + 1e-15)

    @staticmethod
    def _precision(tp, fp):
        """Return ``tp / (tp + fp)``, or 0 when nothing was predicted."""
        return tp / (tp + fp + 1e-15)

    def getIoU(self, confusion_matrix):
        """Return the IoU of the moving class (index 1)."""
        tp, fp, fn = self.getStats(confusion_matrix)
        return self._iou(tp, fp, fn)

    def getacc(self, confusion_matrix):
        """Return ``tp / (tp + fp)`` for the moving class (index 1).

        Despite the name this is the moving-class precision, not the overall
        accuracy; the name is kept for existing callers.
        """
        tp, fp, _ = self.getStats(confusion_matrix)
        return self._precision(tp, fp)

    def getStaticIoU(self, confusion_matrix):
        """Return the IoU of the static class (index 0).

        Uses class-0 counts only, so the result is a true per-class IoU even
        when ``n_classes > 2``. For two classes it equals
        ``tn / (tn + fp + fn)`` with the moving-class ``fp``/``fn``.
        """
        tp, fp, fn = self.getStats(confusion_matrix, class_idx=0)
        return self._iou(tp, fp, fn)

    def getStaticAcc(self, confusion_matrix):
        """Return ``tp / (tp + fp)`` for the static class (index 0).

        Like ``getacc``, this is a precision: correct static predictions over
        all static predictions.
        """
        tp, fp, _ = self.getStats(confusion_matrix, class_idx=0)
        return self._precision(tp, fp)