-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_model.py
More file actions
152 lines (120 loc) · 5.29 KB
/
test_model.py
File metadata and controls
152 lines (120 loc) · 5.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
from model import DJMaxModel
import cv2
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
def load_npy_file(path):
    """Load and return a numpy array stored at *path*.

    Raises:
        FileNotFoundError: if *path* does not exist on disk.
    """
    if os.path.exists(path):
        return np.load(path)
    raise FileNotFoundError(f"File not found: {path}")
def preprocess_video(video):
    """Convert a video array to a float tensor resized to 112x64 per frame.

    Accepts a (T, H, W, 3) channel-last array (transposed to channel-first)
    or an already channel-first (T, C, H, W) array, and returns a float
    tensor of shape (T, C, 112, 64) via bilinear interpolation.
    """
    channels_last = video.shape[-1] == 3
    if channels_last:
        video = np.transpose(video, (0, 3, 1, 2))  # -> (T, C, H, W)
    frames = torch.from_numpy(video).float()
    resized = torch.nn.functional.interpolate(frames, size=(112, 64), mode='bilinear')
    return resized
def test_video_and_labels(video_path, label_path, model, device, chunk_size=100,
                          frame_start=200, frame_duration=200):
    """Evaluate *model* on a saved video against ground-truth key labels.

    Runs the model over the video in chunks, saves a plot comparing the raw
    per-key probabilities to the ground truth over a frame window, then
    prints per-key accuracy and micro-averaged F1/precision/recall computed
    from the probabilities thresholded at 0.5.

    Args:
        video_path: path to a .npy video array; (T, H, W, C) expected.
        label_path: path to a .npy label array; (T, 6) expected — TODO confirm
            labels are binary per key, as the accuracy metrics assume.
        model: trained model; called as model(batch, lengths) and expected to
            return (1, seq_len, 6) logits.
        device: torch device to run inference on.
        chunk_size: number of frames per forward pass.
        frame_start: first frame of the plotted window (default 200, matching
            the previous hard-coded value).
        frame_duration: number of frames in the plotted window (default 200).
    """
    video = load_npy_file(video_path)    # (T, H, W, C)
    labels = load_npy_file(label_path)   # (T, 6)
    print(f"Original video shape: {video.shape}")
    video = preprocess_video(video)
    print(f"Processed video shape: {video.shape}")

    # Collect raw sigmoid probabilities and the matching ground truth.
    all_probs = []
    all_true = []
    model.eval()
    with torch.no_grad():
        for start_idx in range(0, len(video), chunk_size):
            chunk = video[start_idx:start_idx + chunk_size].to(device)
            chunk_labels = labels[start_idx:start_idx + chunk_size]
            # Model expects a batch dimension plus a tensor of sequence lengths.
            outputs = model(chunk.unsqueeze(0), torch.tensor([len(chunk)]).to(device))
            probs = torch.sigmoid(outputs).cpu().numpy()[0]  # (seq_len, 6)
            all_probs.append(probs)
            all_true.append(chunk_labels)

    all_probs = np.concatenate(all_probs)
    all_true = np.concatenate(all_true)

    # Visualization: raw probabilities (not thresholded) vs. ground truth,
    # one subplot per key, over the requested frame window.
    frame_window = (frame_start, frame_start + frame_duration)
    plt.figure(figsize=(15, 10))
    for k in range(6):
        plt.subplot(2, 3, k + 1)
        plt.plot(all_true[frame_window[0]:frame_window[1], k], label='Ground Truth', alpha=0.7)
        plt.plot(all_probs[frame_window[0]:frame_window[1], k], label='Predicted', alpha=0.7)
        plt.title(f'Key {k} (frame {frame_window[0]} to frame {frame_window[1]})')
        plt.legend()
    plt.tight_layout()
    plt.savefig('predictions_vs_truth.png')
    plt.close()

    # Metrics: binarize probabilities at 0.5. (Previously the threshold was
    # commented out in a confusing no-op float cast and applied only here.)
    all_preds = all_probs > 0.5
    print("\n=== Evaluation Metrics ===")
    for k in range(6):
        acc = accuracy_score(all_true[:, k], all_preds[:, k])
        print(f"Key {k} Accuracy: {acc:.4f}")
    print("\nOverall Metrics:")
    print(f"Micro F1: {f1_score(all_true, all_preds, average='micro'):.4f}")
    print(f"Micro Precision: {precision_score(all_true, all_preds, average='micro'):.4f}")
    print(f"Micro Recall: {recall_score(all_true, all_preds, average='micro'):.4f}")
def video_to_npy(video_path: str, output_path: str):
    """Read all frames of *video_path* with OpenCV and save them as one
    numpy array at *output_path*.

    Raises:
        IOError: if OpenCV cannot open the video file (previously this case
            silently saved an empty array).
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video: {video_path}")
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        # Release the capture even if frame decoding raises mid-loop.
        cap.release()
    frames = np.array(frames)
    np.save(output_path, frames)
    print(f"Saved {len(frames)} frames to {output_path}")
def process_all_mp4_in_directory(mp4_dir, output_dir, label_dir):
    """Convert every .mp4 file in *mp4_dir* to a .npy array in *output_dir*.

    Files that already have a converted .npy in *output_dir* are skipped, so
    the function is safe to re-run incrementally.

    Args:
        mp4_dir: directory scanned (non-recursively) for .mp4 files.
        output_dir: destination directory for .npy arrays; created on demand.
        label_dir: currently unused — kept so existing callers that pass a
            label directory keep working.
    """
    for filename in os.listdir(mp4_dir):
        if not filename.endswith(".mp4"):
            continue
        base_filename = os.path.splitext(filename)[0]
        video_npy_path = os.path.join(output_dir, f"{base_filename}.npy")
        # Skip videos that were already converted in a previous run.
        if os.path.exists(video_npy_path):
            continue
        os.makedirs(output_dir, exist_ok=True)
        video_to_npy(os.path.join(mp4_dir, filename), video_npy_path)
def main():
    """Load the trained model on the best available device and evaluate it
    on a pre-converted video / key-log pair."""
    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    device = torch.device(backend)
    print(f"Using device: {device}")

    model = DJMaxModel().to(device)
    state = torch.load("checkpoint_epoch_14_loss_0.2887_single_okey.pth", map_location=device)
    model.load_state_dict(state)

    # To regenerate the .npy input from the source .mp4, run video_to_npy(...)
    # on the capture file first.
    label_path = "key_log-1746310117.151055.npy"
    test_video_and_labels("trimmed_capture-1746310117.151055.npy", label_path, model, device)


if __name__ == "__main__":
    main()