-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
134 lines (112 loc) · 5.5 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import json
from model import spectroedanet
import torch
import torch.utils.data as torch_data
import torch.nn as nn
from loader.data_loader import PMEmoDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import root_mean_squared_error, r2_score
def unpack_data(data, device: torch.device):
    """Move one batch onto ``device`` and return its five components.

    Args:
        data: A 5-item batch in the order (spectrogram, EDA data, arousal
            label, valence label, music vector). Each item must provide a
            tensor-style ``.to()`` method.
        device: Target device for every component.

    Returns:
        A tuple ``(spectrogram, eda_data, arousal_label, valence_label,
        music_vector)`` with every element on ``device``. The music vector
        is additionally cast to ``torch.float32``; the other elements keep
        their dtype.
    """
    spec, eda, arousal, valence, music = data
    # Only the music vector gets a dtype conversion; the rest are plain moves.
    relocated = tuple(item.to(device) for item in (spec, eda, arousal, valence))
    music = music.to(device=device, dtype=torch.float32)
    return (*relocated, music)
def evaluate_model(model_run, test_loader):
    """Evaluate one saved SpectroEDANet checkpoint on the test set and print metrics.

    The model configuration is decoded from ``model_run``: the name is split
    on ``_`` and each of the tokens ``usesSpectrogram``, ``usesEDA``,
    ``usesMusic``, ``usesAttention``, ``predictsArousal`` and
    ``predictsValence`` switches the corresponding constructor flag on.
    Weights are loaded from ``checkpoints/{model_run}``.

    Args:
        model_run: Checkpoint file name inside ``checkpoints/``; its
            ``_``-separated tokens also select the model configuration.
        test_loader: Iterable of batches accepted by ``unpack_data``; must
            support ``len()`` (used to average the loss over batches).

    Returns:
        None. The summed per-batch MSE loss averaged over the number of
        batches, plus RMSE and R2 for each predicted target, are printed.
    """
    print("Evaluating Model: ", model_run)

    # Membership in the token set replaces the per-flag if/elif chain.
    flags = set(model_run.split('_'))
    model = spectroedanet.SpectroEDANet('usesSpectrogram' in flags,
                                        'usesEDA' in flags,
                                        'usesMusic' in flags,
                                        'usesAttention' in flags,
                                        'predictsArousal' in flags,
                                        'predictsValence' in flags)

    device_type = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_type)
    model.to(device)

    # Load the weights from checkpoints/<model_run> (no extension is appended).
    # NOTE(review): if entries in final_models.json carry a ".pth" suffix, the
    # last "_"-separated token would include it and not match a flag name --
    # confirm the naming scheme.
    model.load_state_dict(torch.load(f'checkpoints/{model_run}', map_location=device_type))

    # Evaluate the model on the test set
    model.eval()
    predicts_arousal = model.predictsArousal
    predicts_valence = model.predictsValence

    test_loss = 0.0
    arousal_preds = []
    valence_preds = []
    arousal_labels = []
    valence_labels = []
    criterion = nn.MSELoss()

    with torch.no_grad():
        for data in test_loader:
            spectrogram, eda_data, arousal_label, valence_label, music_vector = unpack_data(data, device)
            if predicts_arousal and predicts_valence:
                arousal_output, valence_output = model(spectrogram, eda_data, music_vector)
                arousal_loss = criterion(arousal_output, arousal_label)
                valence_loss = criterion(valence_output, valence_label)
                test_loss += arousal_loss.item() + valence_loss.item()
                arousal_preds.extend(arousal_output.cpu().numpy())
                valence_preds.extend(valence_output.cpu().numpy())
                arousal_labels.extend(arousal_label.cpu().numpy())
                valence_labels.extend(valence_label.cpu().numpy())
            elif predicts_arousal:
                output = model(spectrogram, eda_data, music_vector)
                # Accumulate across batches; plain assignment would keep only
                # the last batch's loss and skew the reported average.
                test_loss += criterion(output, arousal_label).item()
                arousal_preds.extend(output.cpu().numpy())
                arousal_labels.extend(arousal_label.cpu().numpy())
            elif predicts_valence:
                output = model(spectrogram, eda_data, music_vector)
                # Same accumulation as the arousal-only branch.
                test_loss += criterion(output, valence_label).item()
                valence_preds.extend(output.cpu().numpy())
                valence_labels.extend(valence_label.cpu().numpy())

    # Calculate and print test metrics
    if predicts_arousal and predicts_valence:
        test_arousal_rmse = root_mean_squared_error(arousal_labels, arousal_preds)
        test_valence_rmse = root_mean_squared_error(valence_labels, valence_preds)
        test_arousal_r2 = r2_score(arousal_labels, arousal_preds)
        test_valence_r2 = r2_score(valence_labels, valence_preds)
        print(f"Test Loss: {test_loss / len(test_loader):.4f}, "
              f"Arousal RMSE: {test_arousal_rmse:.4f}, "
              f"Valence RMSE: {test_valence_rmse:.4f}, "
              f"Arousal R2: {test_arousal_r2:.4f}, "
              f"Valence R2: {test_valence_r2:.4f}")
    elif predicts_arousal:
        test_arousal_rmse = root_mean_squared_error(arousal_labels, arousal_preds)
        test_arousal_r2 = r2_score(arousal_labels, arousal_preds)
        print(f"Test Loss: {test_loss / len(test_loader):.4f}, "
              f"Arousal RMSE: {test_arousal_rmse:.4f}, "
              f"Arousal R2: {test_arousal_r2:.4f}")
    elif predicts_valence:
        test_valence_rmse = root_mean_squared_error(valence_labels, valence_preds)
        test_valence_r2 = r2_score(valence_labels, valence_preds)
        print(f"Test Loss: {test_loss / len(test_loader):.4f}, "
              f"Valence RMSE: {test_valence_rmse:.4f}, "
              f"Valence R2: {test_valence_r2:.4f}")
def test():
    """Evaluate every model listed in ``final_models.json`` on the test split.

    Builds the PMEmo dataset from the ``dataset`` directory, takes a 20%
    test split with a fixed seed, and calls ``evaluate_model`` once per
    entry of the JSON file, printing a separator line before each run.
    """
    dataset = PMEmoDataset("dataset")
    # Fixed random_state keeps the split reproducible between runs; only the
    # test part is needed here.
    # NOTE(review): presumably the training script uses the same split
    # parameters so these samples are truly held out -- confirm.
    _, test_dataset = train_test_split(dataset, test_size=0.2, random_state=42)
    test_loader = torch_data.DataLoader(test_dataset, batch_size=32)

    # Context manager closes the file handle deterministically.
    with open('final_models.json', encoding='utf-8') as models_file:
        model_runs = json.load(models_file)

    for model_run in model_runs:
        print(f'---------- {model_run} --------')
        evaluate_model(model_run, test_loader)
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test()