-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_lenet.py
More file actions
30 lines (21 loc) · 902 Bytes
/
test_lenet.py
File metadata and controls
30 lines (21 loc) · 902 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torchvision
from torch import nn, optim
import torchvision.transforms as transforms
from datasets import get_food101_dataset, get_cifar10_dataset
from models import LeNetModel, vgg13
from trainer import Trainer
# Module-level settings, kept here so existing references still resolve.
# NOTE(review): neither `batch_size` nor `criterion` is passed to
# get_food101_dataset() or Trainer() below, so they currently have no effect
# on training. Wire them through once those helpers' signatures are confirmed.
batch_size = 10
criterion = nn.CrossEntropyLoss()

# Food-101 class names used to build a small two-class subset.
selected_classes = ['foie_gras', 'tacos']


def main():
    """Build the model and data loaders, then train and evaluate.

    Returns:
        The Trainer instance, so an interactive caller can inspect it
        after the run.
    """
    # Alternative model: LeNetModel((512, 512), 2)
    # NOTE(review): vgg13() is built with its default arguments. Confirm its
    # output layer size matches len(selected_classes).
    model = vgg13()

    # Alternative dataset: get_cifar10_dataset()
    train_loader, test_loader = get_food101_dataset(
        selected_classes=selected_classes)

    trainer = Trainer(model, train_loader, test_loader)
    trainer.train()
    trainer.test()
    return trainer


# Guard the entry point. Test runners such as pytest collect test_*.py files
# by default, and without this guard, importing the module would start a full
# training run as a side effect.
if __name__ == "__main__":
    main()