-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
270 lines (214 loc) · 8.66 KB
/
train.py
File metadata and controls
270 lines (214 loc) · 8.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
import argparse
from functools import partial
from multiprocessing import Process
import os
import random

import lightning
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from snntorch import spikegen
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import wandb

from models import SpikingLightningModel, SpikingMLP, SpikingGEM
from yin_yang_data_set.dataset import YinYangDataset
# --- experiment-wide constants -------------------------------------------
PROJECT = 'gen-robust'  # W&B project that hosts the sweep
CKPT_DIR = 'ckpts'  # root directory for the best-model checkpoints

# spiking-neuron / simulation settings
BETA = .9  # NOTE: for tau=10
NUM_STEPS = 100  # number of simulation time steps per sample

# optimisation / data loading
BS = 512  # NOTE: 16 for iris
NUM_EPOCHS = 300
NUM_WORKERS = 0

# GPUs used by the sweep; one agent process is launched per listed GPU
GPUS = [2, 5, 7]
NUM_PROCESSES = 1 * len(GPUS)
# Command-line interface. Both the model family and the dataset are mandatory
# and restricted to the values the rest of the script supports, so a typo is
# reported immediately (before W&B login or dataset loading) instead of
# surfacing later as a generic exception.
parser = argparse.ArgumentParser(
    description='Launch a W&B hyperparameter sweep for spiking networks.'
)
parser.add_argument(
    '-m', '--model', type=str, required=True, choices=['mlp', 'gem'],
    help='model family to train',
)
parser.add_argument(
    '-d', '--dataset', type=str, required=True, choices=['mnist', 'yinyang', 'iris'],
    help='dataset to train and evaluate on',
)
parser.add_argument(
    '-ha', '--hardware_aware', action='store_true',
    help='enable hardware-aware training with weight noise',
)
args = parser.parse_args()
# Build the train/validation/test splits for the selected dataset, together
# with the network layer sizes and the candidate gene counts for the GEM sweep.
if args.dataset == 'mnist':
    MNIST_DIR = 'datasets'
    # Rate-encode every image into NUM_STEPS spike frames on the fly; the
    # flattened feature dimension (28 * 28) matches layer_sizes[0] below.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(
            lambda x: spikegen.rate(x.unsqueeze(0), num_steps=NUM_STEPS, gain=.1)
            .flatten(start_dim=1)
        ),
    ])
    # download=True fetches the files only when they are missing from MNIST_DIR;
    # without it a fresh checkout fails with "Dataset not found".
    dataset = MNIST(MNIST_DIR, train=True, transform=transform, download=True)
    test_dataset = MNIST(MNIST_DIR, train=False, transform=transform, download=True)
    # hold out a fixed (seeded) subset of the training set for validation
    val_size = 5000
    train_dataset, val_dataset = random_split(
        dataset,
        [len(dataset) - val_size, val_size],
        generator=torch.Generator().manual_seed(1),
    )
    layer_sizes = [28**2, 28**2, 10]
    num_genes = [8, 64, 256]
elif args.dataset == 'yinyang':
    def rate_encoding(sample):
        """Convert one ``(features, label)`` sample into Poisson spike trains.

        The features are rate-encoded over NUM_STEPS time steps; the label is
        returned as a tensor.
        """
        x, y = sample
        x, y = torch.tensor(x).float(), torch.tensor(y)
        x = spikegen.rate(x.unsqueeze(0), num_steps=NUM_STEPS, gain=1.).squeeze()
        return x, y

    # initialize the dataset splits (distinct seeds give disjoint random draws)
    train_dataset = YinYangDataset(size=5000, seed=42, transform=rate_encoding)
    val_dataset = YinYangDataset(size=1000, seed=41, transform=rate_encoding)
    test_dataset = YinYangDataset(size=1000, seed=40, transform=rate_encoding)
    # TODO: update before training
    layer_sizes = [4, 32, 3]
    num_genes = [4, 8, 16]  # NOTE: it seems very beneficial to over parametrize
    # num_genes = [4, 16, 64]  # NOTE: it seems very beneficial to over parametrize
elif args.dataset == 'iris':
    # load the Iris dataset
    iris = datasets.load_iris()
    X, Y = iris['data'], iris['target']
    # normalize every feature to [0, 1] so it can serve as a firing rate
    scaler = MinMaxScaler()
    X = scaler.fit_transform(X)
    # split dataset into train, validation and test splits
    X_train_val, X_test, Y_train_val, Y_test = train_test_split(X, Y, test_size=0.3, random_state=1)
    X_train, X_val, Y_train, Y_val = train_test_split(X_train_val, Y_train_val, test_size=0.1, random_state=1)
    # Rate-encode once up front (fixed seed); unlike the other datasets the
    # encoding is therefore identical in every epoch. permute moves the sample
    # axis in front of the time axis.
    seed_everything(seed=1)
    X_train = spikegen.rate(torch.tensor(X_train), num_steps=NUM_STEPS).permute(1, 0, -1)
    X_val = spikegen.rate(torch.tensor(X_val), num_steps=NUM_STEPS).permute(1, 0, -1)
    X_test = spikegen.rate(torch.tensor(X_test), num_steps=NUM_STEPS).permute(1, 0, -1)
    # convert labels into tensors
    Y_train, Y_val, Y_test = torch.tensor(Y_train), torch.tensor(Y_val), torch.tensor(Y_test)
    # to PyTorch Datasets
    train_dataset = TensorDataset(X_train.float(), Y_train.long())
    val_dataset = TensorDataset(X_val.float(), Y_val.long())
    test_dataset = TensorDataset(X_test.float(), Y_test.long())
    layer_sizes = [4, 128, 3]
    num_genes = [2, 4, 6]
else:
    raise ValueError('Invalid dataset')
# Sweep name: "<model>-<hidden layer sizes joined by '-'>", with "-ha"
# appended for hardware-aware runs (e.g. "gem-32-ha").
hidden_sizes = '-'.join(map(str, layer_sizes[1:-1]))
sweep_suffix = f'-{hidden_sizes}'
if args.hardware_aware:
    sweep_suffix = f'{sweep_suffix}-ha'
sweep_name = args.model + sweep_suffix
# Register a W&B grid sweep. Model, dataset, architecture, seed and the
# hardware-aware flag are fixed by the command line, so the grid spans the
# learning rates and, for GEM only, the candidate gene counts.
wandb.login()
gene_values = num_genes if args.model == 'gem' else [0]
sweep_parameters = {
    'model': {'values': [args.model]},
    'dataset': {'values': [args.dataset]},
    'layer_sizes': {'values': [layer_sizes]},
    'num_genes': {'values': gene_values},
    'learning_rate': {'values': [3e-2, 3e-3, 3e-4]},
    'seed': {'values': [1]},
    'hardware_aware': {'values': [args.hardware_aware]},
}
CONFIG = {
    'name': sweep_name,
    'method': 'grid',
    'metric': {'goal': 'minimize', 'name': 'val_loss'},
    'parameters': sweep_parameters,
}
sweep_id = wandb.sweep(CONFIG, project=PROJECT)
def _collect_rng_states(device):
    """Snapshot every global RNG that ``seed_everything`` reseeds.

    Args:
        device: ``torch.device`` of the module being validated; a CUDA
            generator state is captured only for that single device.

    Returns:
        dict with the Python, NumPy and torch CPU RNG states, plus the CUDA
        RNG state of ``device`` when it is a CUDA device.
    """
    states = {
        'python': random.getstate(),
        'numpy': np.random.get_state(),
        'torch': torch.get_rng_state(),
    }
    if device.type == 'cuda':
        states['cuda'] = torch.cuda.get_rng_state(device)
    return states


def _restore_rng_states(states, device):
    """Restore RNG states previously captured by ``_collect_rng_states``."""
    random.setstate(states['python'])
    np.random.set_state(states['numpy'])
    torch.set_rng_state(states['torch'])
    if 'cuda' in states:
        torch.cuda.set_rng_state(states['cuda'], device)


class SetValidationSeed(lightning.Callback):
    """Keep the rate encoding of validation samples identical across validation epochs.

    The global RNGs are reseeded with ``seed`` when validation starts. As a
    result, the stochastic (Poisson) spike encoding produced on the fly by the
    validation data pipeline is the same every epoch.

    The RNG states that were active before validation are saved and restored
    when validation ends. Without that, every training epoch would resume
    from the same post-validation RNG state and repeat the same shuffling
    order and the same training spike encodings.
    """

    def __init__(self, seed=1):
        """Store the seed used for validation (default 1)."""
        super().__init__()
        self.seed = seed
        self._saved_states = None

    def on_validation_start(self, trainer, pl_module):
        self._saved_states = _collect_rng_states(pl_module.device)
        seed_everything(seed=self.seed, workers=True)

    def on_validation_end(self, trainer, pl_module):
        # hand the training loop back exactly the RNG stream it had before
        if self._saved_states is not None:
            _restore_rng_states(self._saved_states, pl_module.device)
            self._saved_states = None


class ChangeWeightNoiseSeed(lightning.Callback):
    """Drive weight-noise injection for hardware-aware training.

    The callback does the following:

    - enables weight noise (coefficient of variation ``weight_cv``) at the
      start of every training epoch;
    - increments the model's weight-noise seed before every training batch;
    - disables the noise (``weight_cv = 0``) for validation and once training
      has ended.
    """

    def __init__(self, weight_cv):
        """Set the coefficient of variation used for weight noise."""
        super().__init__()
        self.weight_cv = weight_cv

    def on_train_epoch_start(self, trainer, pl_module):
        # enable weight noise
        pl_module.model.weight_cv = self.weight_cv

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # new seed per step; presumably the model draws its noise from
        # weight_noise_seed so every batch sees a fresh sample — TODO confirm
        pl_module.model.weight_noise_seed += 1

    def on_validation_start(self, trainer, pl_module):
        # disable weight noise; it is re-enabled at the next training epoch start
        # NOTE(review): with val_check_interval < 1 the rest of the epoch would run noise-free
        pl_module.model.weight_cv = 0.

    def on_train_end(self, trainer, pl_module):
        # disable weight noise so the test that follows fit() is noise-free
        pl_module.model.weight_cv = 0.


def train(device_id):
    """Run one W&B sweep trial on GPU ``device_id``.

    The trial fits the model, then tests the checkpoint with the lowest
    ``val_loss``. Hyperparameters (layer sizes, learning rate, seed and, for
    GEM, the number of genes) come from the ``wandb.config`` of the run
    created here. The module-level ``args`` and ``*_dataset`` objects select
    the model family and the data.

    Raises:
        ValueError: if ``args.model`` is neither ``'mlp'`` nor ``'gem'``.
    """
    # set hyperparameters
    run = wandb.init()
    config = wandb.config
    layer_sizes = config['layer_sizes']
    lr = config['learning_rate']
    seed = config['seed']

    # initialize data loaders under a fixed seed, independent of the trial seed
    seed_everything(seed=1, workers=True)
    train_dataloader = DataLoader(train_dataset, batch_size=BS, shuffle=True, num_workers=NUM_WORKERS)
    val_dataloader = DataLoader(val_dataset, batch_size=BS, num_workers=NUM_WORKERS)
    test_dataloader = DataLoader(test_dataset, batch_size=BS, num_workers=NUM_WORKERS)

    # the trial seed controls model initialization and training randomness
    seed_everything(seed=seed, workers=True)
    if args.model == 'mlp':
        network = SpikingMLP(layer_sizes=layer_sizes, beta=BETA, num_steps=NUM_STEPS, bias=False)
    elif args.model == 'gem':
        network = SpikingGEM(
            layer_sizes=layer_sizes,
            num_genes=config['num_genes'],
            beta=BETA,
            num_steps=NUM_STEPS,
            bias=False,
        )
    else:
        raise ValueError('Invalid model')
    model = SpikingLightningModel(
        network,
        num_classes=layer_sizes[-1],
        max_epochs=NUM_EPOCHS,
        lr=lr,  # NOTE: very important for GEM
    )

    # define callbacks: keep the checkpoint with the lowest validation loss
    ckpt_best = ModelCheckpoint(
        dirpath=os.path.join(CKPT_DIR, sweep_name),
        filename=run.name,
        monitor='val_loss',
        mode='min',
    )
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    callbacks = [ckpt_best, lr_monitor, SetValidationSeed()]
    if args.hardware_aware:
        callbacks.append(ChangeWeightNoiseSeed(weight_cv=0.1))  # TODO: set CV

    trainer = Trainer(
        max_epochs=NUM_EPOCHS,
        callbacks=callbacks,
        accelerator='gpu',
        devices=[device_id],
        logger=WandbLogger(),
        enable_progress_bar=False,
    )
    # train
    trainer.fit(model, train_dataloader, val_dataloader)
    # test the best checkpoint; rate encoding is stochastic, so fix the seed
    seed_everything(seed=1, workers=True)
    trainer.test(dataloaders=test_dataloader, ckpt_path='best')
    run.finish()
# parallelize sweep on multiple processes
def run_agent(device_id):
    """Run a W&B sweep agent whose trials all train on GPU ``device_id``.

    Meant to be used as a ``multiprocessing.Process`` target. Every trial the
    agent receives from sweep ``sweep_id`` calls ``train(device_id=device_id)``.
    """
    trial = partial(train, device_id=device_id)
    wandb.agent(sweep_id, function=trial, project=PROJECT)
def assign_devices(gpus, num_processes):
    """Return the GPU id assigned to each of ``num_processes`` agent processes.

    Processes are spread as evenly as possible over ``gpus``. When the count
    does not divide evenly, the first GPUs take one extra process. GPUs keep
    their listed order and each GPU's processes are contiguous, for example
    ``assign_devices([2, 5, 7], 4) == [2, 2, 5, 7]``. An empty ``gpus`` list
    yields no processes.
    """
    if not gpus:
        return []
    per_gpu, extra = divmod(num_processes, len(gpus))
    devices = []
    for rank, gpu in enumerate(gpus):
        devices += (per_gpu + (rank < extra)) * [gpu]
    return devices


# Launch one sweep agent per assigned device and wait for all of them.
# The __main__ guard stops child processes from relaunching agents when the
# 'spawn' start method re-imports this module.
# NOTE(review): under 'spawn', the other top-level code (argument parsing,
# wandb.sweep) would still re-run in each child.
if __name__ == '__main__':
    processes = []
    for device_id in assign_devices(GPUS, NUM_PROCESSES):
        process = Process(target=run_agent, args=(device_id,))
        process.start()
        processes.append(process)
    # wait for all processes to finish
    for process in processes:
        process.join()
    # surface agent crashes instead of finishing silently
    for process in processes:
        if process.exitcode != 0:
            print(f'Sweep agent {process.name} exited with code {process.exitcode}')