train_motor.py
import pathlib
import pickle
import sys

import datasets
import torch
import transformers

import femr.models.config
import femr.models.processor
import femr.models.tokenizer
import femr.models.transformer


def main():
    pretraining_data = pathlib.Path('pretraining_data')

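    # Load the ontology built during data preparation.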
    ontology_path = pretraining_data / 'ontology.pkl'
    with open(ontology_path, 'rb') as f:
        ontology = pickle.load(f)

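    # Load the tokenizer, passing the ontology it was built with.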
    tokenizer_path = pretraining_data / 'tokenizer'
    tokenizer = femr.models.tokenizer.FEMRTokenizer.from_pretrained(tokenizer_path, ontology=ontology)

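    # Load the pickled MOTOR pretraining task definition and wrap it,
    # together with the tokenizer, in a batch processor.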
    task_path = pretraining_data / 'motor_task.pkl'
    with open(task_path, 'rb') as f:
        motor_task = pickle.load(f)

    processor = femr.models.processor.FEMRBatchProcessor(tokenizer, motor_task)

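    # Load the precomputed train/validation batches from disk; the validation
    # set is truncated to its first 120 batches.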
    train_batches_path = pretraining_data / 'train_batches'
    train_batches = datasets.Dataset.load_from_disk(train_batches_path)

    val_batches_path = pretraining_data / 'val_batches'
    val_batches = datasets.Dataset.load_from_disk(val_batches_path)
    val_batches = val_batches.select(range(120))

    # Finally, given the batches, we can train CLMBR,
    # using Hugging Face's Trainer to drive the training loop.
    transformer_config = femr.models.config.FEMRTransformerConfig(
        vocab_size=tokenizer.vocab_size,
        is_hierarchical=tokenizer.is_hierarchical,
        n_layers=6,
        use_normed_ages=True,
        use_bias=False,
        hidden_act='swiglu',
    )

    config = femr.models.config.FEMRModelConfig.from_transformer_task_configs(
        transformer_config, motor_task.get_task_config()
    )

    model = femr.models.transformer.FEMRModel(config)
    model = model.to(torch.device("cuda"))

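    # The learning rate is read from the first command line argument.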
    learning_rate = float(sys.argv[1])

    trainer_config = transformers.TrainingArguments(
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        learning_rate=learning_rate,
        output_dir='tmp_trainer_' + sys.argv[1],
        remove_unused_columns=False,
        bf16=True,
        weight_decay=0.1,
        adam_beta2=0.95,
        report_to="tensorboard",
        num_train_epochs=100,
        warmup_steps=500,
        logging_strategy='steps',
        logging_steps=500,
        disable_tqdm=True,
        evaluation_strategy='steps',
        eval_steps=500,
        prediction_loss_only=True,
        dataloader_num_workers=12,
        save_total_limit=1,
        load_best_model_at_end=True,
    )

    trainer = transformers.Trainer(
        model=model,
        data_collator=processor.collate,
        train_dataset=train_batches,
        eval_dataset=val_batches,
        args=trainer_config,
    )

    trainer.train()


if __name__ == "__main__":
    main()
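
# Usage sketch (assuming the pretraining_data/ directory has been prepared
# and a CUDA device is available); the single argument is the learning rate,
# and 1e-4 below is only an illustrative value:
#
#     python train_motor.py 1e-4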