-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
92 lines (82 loc) · 3.78 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Time: 2022-12-20-11-22
# Author: Xianxian Zeng
# Name: main.py
# Details: Model with Pytorch-Lightning for Fine-grained Hashing
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
import argparse
from datacsv import DInterface
from model import HInterface
def main(config):
    """Train a fine-grained hashing model with PyTorch Lightning.

    Args:
        config: argparse.Namespace carrying the experiment options
            (model_name, dataset, code_length, gpu, epoch, lr,
            batch_size, num_workers, and the dataset paths filled in
            by the __main__ block).
    """
    data_module = DInterface(config=config)
    model = HInterface(config=config)
    # Checkpoints and CSV logs both live under log/<model>_<dataset>/.
    saving_dir = "log/%s_%s/%s_%s" % (config.model_name, config.dataset, config.code_length, config.dataset)
    # Keep only the single checkpoint with the best validation mAP.
    checkpoint_callback = ModelCheckpoint(dirpath=saving_dir,
                                          filename='{epoch}-{val_mAP:.4f}',
                                          save_top_k=1,
                                          monitor='val_mAP',
                                          mode='max')
    csv_dir = "log/%s_%s" % (config.model_name, config.dataset)
    csv_name = "%s_%s" % (config.code_length, config.dataset)
    logger = CSVLogger(save_dir=csv_dir, name=csv_name)
    trainer = pl.Trainer(accelerator='gpu',
                         devices=config.gpu,
                         precision=16,               # mixed-precision training
                         # limit_train_batches=0.5
                         default_root_dir=saving_dir,
                         num_sanity_val_steps=-1,    # run full validation before training
                         # Honor the --epoch flag; previously hard-coded to 300,
                         # which silently ignored the CLI argument.
                         max_epochs=config.epoch,
                         # Trainer expects a list of callbacks.
                         callbacks=[checkpoint_callback],
                         logger=logger,
                         )
    trainer.fit(model, data_module)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # dataset selecting
    parser.add_argument('--dataset', default='cub', help='select dataset for experiment')
    # lr
    parser.add_argument('--lr', default=0.01, type=float, help='learning rate for training')
    # batch-size
    parser.add_argument('--batch_size', default=128, type=int, help='batch size for training')
    # epoch
    parser.add_argument('--epoch', default=300, type=int, help='epoch for training')
    # gpu
    parser.add_argument('--gpu', default=[0], help='gpu-id')
    # num_workers
    parser.add_argument('--num_workers', default=8, type=int, help='number workers for training')
    # model
    parser.add_argument('--model_name', default='resnet50', type=str, help='backbone model for experiment')
    # code length
    parser.add_argument('--code_length', default=32, type=int, help='code length for experiment')
    config = parser.parse_args()
    # Per-dataset settings: (number of classes, image root directory).
    # The csv paths all follow the uniform ./datacsv/<dataset>/ layout.
    DATASETS = {
        'cub': (200, '/workspace/data/CUB_200_2011'),
        'aircraft': (100, '/workspace/data/FGVC/data'),
        'food101': (101, '/workspace/dataset/fine-grained-dataset/food-101/images'),
        'nabirds': (555, '/workspace/dataset/fine-grained-dataset/nabirds/'),
        'vegfru': (292, '/workspace/dataset/fine-grained-dataset/vegfru-dataset/'),
    }
    if config.dataset not in DATASETS:
        # Previously the script printed this warning but still fell through
        # and called main() with an incomplete config, crashing later with
        # an AttributeError. Abort with a non-zero exit code instead.
        print("We have not provided the experiments of %s" % config.dataset)
        raise SystemExit(1)
    config.classlen, config.data_root = DATASETS[config.dataset]
    config.train_csv = './datacsv/%s/train.csv' % config.dataset
    config.test_csv = './datacsv/%s/test.csv' % config.dataset
    main(config)