-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathmain.py
More file actions
76 lines (64 loc) · 2.93 KB
/
main.py
File metadata and controls
76 lines (64 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
from dataset import MRIDataset
from torch.utils.data import DataLoader, Dataset, RandomSampler
from yAwareContrastiveLearning import yAwareCLModel
from losses import GeneralizedSupervisedNTXenLoss
from torch.nn import CrossEntropyLoss
from models.densenet import densenet121
from models.unet import UNet
import argparse
from config import Config, PRETRAINING, FINE_TUNING
def build_datasets(config, mode):
    """Return the ``(train, validation)`` datasets for this run.

    Args:
        config: the ``Config`` instance built for this run.
        mode: the mode requested on the command line (``PRETRAINING`` or
            ``FINE_TUNING``).
    """
    # NOTE(review): if Config stores the mode it is constructed with, this
    # condition is always true and the placeholder branch below is dead.
    if config.mode == mode:
        dataset_train = MRIDataset(config, training=True)
        dataset_val = MRIDataset(config, validation=True)
        return dataset_train, dataset_val
    ## Fill with your target dataset
    # As written, a bare Dataset() has no `collate_fn`, so make_loader would
    # fail on it; replace these with real datasets before relying on this path.
    return Dataset(), Dataset()


def make_loader(dataset, config, *, shuffle):
    """Wrap ``dataset`` in a DataLoader configured from ``config``.

    Args:
        dataset: a dataset exposing a ``collate_fn`` method.
        config: provides ``batch_size``, ``pin_mem`` and ``num_cpu_workers``.
        shuffle: when True, sample in random order (training). When False,
            iterate in fixed order: DataLoader then uses a SequentialSampler.
            This keeps validation batches, and therefore the validation loss,
            comparable from one epoch to the next.
    """
    sampler = RandomSampler(dataset) if shuffle else None
    return DataLoader(dataset,
                      batch_size=config.batch_size,
                      sampler=sampler,
                      collate_fn=dataset.collate_fn,
                      pin_memory=config.pin_mem,
                      num_workers=config.num_cpu_workers)


def build_net(config):
    """Instantiate the backbone named by ``config.model`` for ``config.mode``.

    In PRETRAINING mode the network is built as an encoder/simCLR head.
    Otherwise it is built as a classifier with ``config.num_classes`` outputs.

    Raises:
        ValueError: if ``config.model`` is neither "DenseNet" nor "UNet".
    """
    pretraining = config.mode == PRETRAINING
    if config.model == "DenseNet":
        if pretraining:
            return densenet121(mode="encoder", drop_rate=0.0)
        return densenet121(mode="classifier", drop_rate=0.0,
                           num_classes=config.num_classes)
    if config.model == "UNet":
        return UNet(config.num_classes,
                    mode="simCLR" if pretraining else "classif")
    raise ValueError(f"Unknown model: {config.model}")


def build_loss(config):
    """Return the training criterion for ``config.mode``.

    PRETRAINING uses the generalized supervised NT-Xent loss with an RBF
    kernel. Any other mode is treated as fine-tuning (matching build_net and
    the final dispatch in main) and uses cross-entropy.
    """
    if config.mode == PRETRAINING:
        return GeneralizedSupervisedNTXenLoss(temperature=config.temperature,
                                              kernel='rbf',
                                              sigma=config.sigma,
                                              return_logits=True)
    return CrossEntropyLoss()


def main():
    """Parse the CLI mode, assemble data/model/loss and run training."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["pretraining", "finetuning"], required=True,
                        help="Set the training mode. Do not forget to configure config.py accordingly !")
    args = parser.parse_args()
    mode = PRETRAINING if args.mode == "pretraining" else FINE_TUNING

    config = Config(mode)

    dataset_train, dataset_val = build_datasets(config, mode)
    loader_train = make_loader(dataset_train, config, shuffle=True)
    loader_val = make_loader(dataset_val, config, shuffle=False)

    net = build_net(config)
    loss = build_loss(config)

    model = yAwareCLModel(net, loss, loader_train, loader_val, config)
    if config.mode == PRETRAINING:
        model.pretraining()
    else:
        model.fine_tuning()


if __name__ == "__main__":
    main()