-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathconfig_demo_dn2.py
More file actions
69 lines (65 loc) · 2.11 KB
/
config_demo_dn2.py
File metadata and controls
69 lines (65 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch.nn as nn
import torch.optim as optim
import model as model # Change here if you have a different `model.py` file
import encoder as encoder # Change here if you have a different `encoder.py` file
# Experiment configuration for the DeepNeo_2_Custom demo run.
#
# This block only defines data: one nested dict bound to `config`.
# NOTE(review): presumably a separate train/test driver imports this module
# and reads these keys — confirm against that script before renaming anything.
# Model, encoder, criterion and optimizer are stored as callables (they are
# not called here); the matching `*_args` dicts presumably hold the keyword
# arguments the driver passes when building them.
config = dict(
    # Output locations and reproducibility.
    chkp_name="demo_deepneo",
    chkp_path="models",
    log_file="train.log",
    plot_path="plots",
    seed=100,
    # Network and input encoder, taken from the local `model` / `encoder`
    # modules imported above.
    model=model.DeepNeo_2_Custom,
    model_args=dict(),
    encoder=encoder.deepneo_2,
    encoder_args=dict(),
    CrossValidation=dict(
        num_folds=5,
    ),
    Data=dict(
        # NOTE(review): these paths begin with "~"; this assumes the loader
        # expands the home directory (e.g. via os.path.expanduser) — confirm.
        epi_path="~/project/IMG/data/final/mhc2_full_human_train.csv",
        epi_args=dict(
            epi_header='Epi_Seq',
            hla_header='HLA_Name',
            tgt_header='Target',
            # The key is spelled "seperator" on purpose: the loader presumably
            # looks up this exact spelling, so do not correct it here alone.
            seperator=",",
        ),
        hla_path="~/project/IMG/data/final/HLA2_IMGT_light.csv",
        hla_args=dict(
            hla_header='HLA_Name',
            seq_header='HLA_Seq',
            seperator=",",
        ),
        # NOTE(review): the file name says "train_set" although it is configured
        # as test data — confirm this is intended.
        test_path="~/project/IMG/data/final/tcell_human_train_set.csv",
        test_args=dict(
            epi_header='Epi_Seq',
            hla_header='HLA_Name',
            tgt_header='Target',
            seperator=",",
        ),
        num_workers=8,
        # Presumably the fraction of training rows held out for validation.
        val_size=0.2,
    ),
    Train=dict(
        batch_size=128,
        num_epochs=100,
        # Presumably the early-stopping patience, in epochs.
        patience=10,
        # Set True only when the model implements a regularize method.
        regularize=True,
        criterion=nn.BCELoss,
        optimizer=optim.AdamW,
        optimizer_args=dict(
            lr=1e-5,
            # weight_decay=0.01,  # same value as AdamW's documented default
        ),
        use_scheduler=False,
        chkp_prefix="best",
        transfer=True,
    ),
    Test=dict(
        batch_size=1024,
        chkp_prefix="best",
        # Presumably features from the layer named by `target_layer` are
        # written to this HDF5 file — confirm in the test driver.
        feat_path="feat.h5",
        target_layer="fc",
    ),
)