-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
131 lines (104 loc) · 4.62 KB
/
train.py
File metadata and controls
131 lines (104 loc) · 4.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, trange

from dataset import VideoDataset
from model import DJMaxModel
def moving_avg(t):
    """Smooth per-key targets along the time axis with a depthwise triangular filter.

    Expects t of shape (B, T, 6) and returns a tensor of the same shape.

    NOTE(review): with a window of 3 the triangular taps evaluate to
    [0, 1, 0], which makes this conv an identity op — confirm whether a
    wider window was intended.
    """
    # conv1d expects (batch, channels, time), so move the key axis inward.
    seq = t.permute(0, 2, 1)
    window = 3
    mid = window // 2
    # Triangular taps: peak at the center, falling linearly to the edges.
    taps = torch.tensor([1 - abs(j - mid) / mid for j in range(window)])
    # One independent filter per key channel (depthwise, groups=6).
    taps = taps.repeat(6, 1).unsqueeze(1).to(seq.device)
    smoothed = F.conv1d(seq, taps, groups=6, padding="same")
    return smoothed.permute(0, 2, 1)
def dynamic_loss(predict, target):
    """Transition-weighted BCE-with-logits loss for key-press prediction.

    Both arguments have shape (B, T, 6): predict is raw logits, target is
    0/1 key-press labels over T frames for 6 keys.
    """
    # Frames where a key's label differs from the previous frame (roll(1, 1)
    # shifts one step along the time axis).
    delta = torch.abs(target - target.roll(1, 1))
    # roll wraps the final frame around to index 0; zero it so the wrap-around
    # cannot count as a transition.
    delta[:, 0] = 0
    # Up-weight elements at and around label transitions so presses/releases
    # dominate the loss over long idle stretches.
    # NOTE(review): roll(1, 2) and roll(1, -1) both shift along the KEY axis
    # (dim 2 == dim -1 for a 3-D tensor), so this mixes weight across adjacent
    # keys and adds that term twice; the intent may have been roll(-1, 1)
    # (the next frame in time) — confirm before changing.
    weight = torch.ones_like(target) + delta + delta.roll(1, 1) + delta.roll(1, 2) + delta.roll(1, -1)
    # Temporally soften the hard 0/1 targets; clamp keeps them valid
    # probabilities for BCE.
    target = moving_avg(target).clamp(max=1.0)
    # Use pos_weight to compensate for class imbalance
    # pos_weight = torch.tensor([21764 / 3112] * 6).to(device)
    # pos_weight = (torch.ones(6)*(771061/89765)).to(device)
    pos_weight = torch.tensor([25, 25, 25, 25, 25, 25], dtype=torch.float32).to(target.device)
    # we need pos_weight since most of the time a key isn't pressed; there are
    # way more 0's than 1's
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='none')
    # reduction='none' keeps per-element losses so they can be reweighted
    # before averaging.
    loss = criterion(predict, target)
    weighted = loss * weight
    return weighted.mean()
# Main training loop
def train(checkpoint: Optional[str] = None):
    """Train DJMaxModel on the preprocessed capture/key-log dataset.

    Args:
        checkpoint: optional path to a saved state dict named like
            ``checkpoint_epoch_<N>_....pth``; when given, weights are loaded
            and training resumes from epoch N for another 60 epochs.

    Side effects:
        Saves ``checkpoint_epoch_<epoch>_loss_<avg>.pth`` in the working
        directory after every epoch.
    """
    config = {
        'lr': 0.0005,       # Learning rate
        'batch_size': 1,    # Single video per batch
        'num_epochs': 60,   # Number of full passes over data
        'seq_length': 100   # Number of frames per training chunk
    }
    # Prefer CUDA, then Apple MPS, then fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")
    # Load the data!
    dataset = VideoDataset(
        "./capture_preprop", "./key_log_preprop"
    )
    # Use the configured batch size instead of a duplicated literal.
    train_loader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)
    # Initialize model
    model = DJMaxModel().to(device)
    epoch_last = 0
    # Load from an existing checkpoint
    if checkpoint is not None:
        model.load_state_dict(torch.load(checkpoint, weights_only=True))
        # The starting epoch is parsed from the filename:
        # checkpoint_epoch_<N>_... -> split('_')[2] == "<N>"
        epoch_last = int(checkpoint.split('_')[2].removesuffix('.pth'))
        config['num_epochs'] += epoch_last
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    criterion = dynamic_loss
    optimizer = optim.Adam(model.parameters(), lr=config['lr'])
    # Training over multiple epochs
    for epoch in range(epoch_last, config['num_epochs']):
        model.train()
        epoch_loss = 0.0
        num_chunks = 0
        # Loop over each sample in the dataset
        for video, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}'):
            video = video.to(device)    # shape: (B, T, C, H, W)
            labels = labels.to(device)  # shape: (B, T, 6)
            seq_len = video.shape[1]
            # Train on sequential chunks: whole videos are too large and
            # crash on limited-memory machines (MPS).
            for i in range(0, seq_len, config['seq_length']):
                chunk = video[:, i:i+config['seq_length']]         # (B, chunk_len, C, H, W)
                chunk_labels = labels[:, i:i+config['seq_length']] # (B, chunk_len, 6)
                chunk_len = chunk.shape[1]
                optimizer.zero_grad()
                # .contiguous() is required for training on the MPS backend.
                outputs = model(chunk.contiguous(), torch.tensor([chunk_len]).to(device))  # raw logits
                loss = criterion(outputs, chunk_labels)
                loss.backward()
                optimizer.step()
                num_chunks += 1
                epoch_loss += loss.item()
        # Average loss per chunk for reporting; guard against an empty
        # dataset so this never divides by zero.
        avg_loss = epoch_loss / max(num_chunks, 1)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
        # Save model checkpoint after each epoch
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch+1}_loss_{avg_loss:.4f}.pth")
if __name__ == "__main__":
    # Resume training from a hand-picked checkpoint; pass None (or edit this
    # call) to train from scratch. NOTE(review): the extra "_single_okey"
    # suffix still parses (split('_')[2] == "14"), but checkpoint names must
    # keep the checkpoint_epoch_<N>_... prefix for epoch parsing to work.
    train("checkpoint_epoch_14_loss_0.2887_single_okey.pth")