-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathbuild_utils.py
96 lines (64 loc) · 2.81 KB
/
build_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
from transformers import get_scheduler
def build_optimizer(model, config):
    """Create the AdamW optimizer for the wrapped network.

    Args:
        model: Wrapper object; the optimizer is built over
            ``model.model.parameters()``.
        config: Configuration object; ``config.lr`` is converted with
            ``float`` and used as the learning rate.

    Returns:
        A ``torch.optim.AdamW`` instance.
    """
    network_params = model.model.parameters()
    return torch.optim.AdamW(network_params, lr=float(config.lr))
"""
def build_optimizer(model, length_train_loader, config):
optimizer_class = getattr(transformers, 'AdamW')
optimizer = optimizer_class(model.model.parameters(), lr=float(config.lr))
num_training_steps = config.train_epochs * length_train_loader
if config.flower and config.fl_params.num_rounds:
num_training_steps = num_training_steps * config.fl_params.num_rounds * config.fl_params.iterations_per_fl_round
lr scheduler disabled due to malfunctioning in FL setup.
lr_scheduler = get_scheduler(
name="linear", optimizer=optimizer, num_warmup_steps=config.warmup_iterations, num_training_steps=num_training_steps
)
return optimizer, lr_scheduler
"""
def build_model(config):
    """Instantiate the model selected by ``config.model_name``.

    The name is compared case-insensitively. The model classes are imported
    lazily so that only the selected implementation is loaded.

    Args:
        config: Configuration object; ``model_name`` selects the class,
            ``device`` is where the wrapped network is moved, and the whole
            object is handed to the model constructor.

    Returns:
        The constructed ``T5`` or ``VT5`` wrapper, with ``model.model``
        moved to ``config.device``.

    Raises:
        ValueError: If ``config.model_name`` is neither 't5' nor 'vt5'.
    """
    available_models = ['t5', 'vt5']
    requested = config.model_name.lower()

    # Reject unknown names up front so the selection below stays a plain
    # two-way choice.
    if requested not in available_models:
        raise ValueError("Value '{:s}' for model selection not expected. Please choose one of {:}".format(config.model_name, ', '.join(available_models)))

    if requested == 't5':
        from models.T5 import T5 as model_class
    else:
        from models.VT5 import VT5 as model_class

    model = model_class(config)
    model.model.to(config.device)
    return model
def build_dataset(config, split, client_id=None):
    """Build the dataset selected by ``config.dataset_name`` for one split.

    Args:
        config: Configuration object; reads ``model_name``, ``dataset_name``,
            ``imdb_dir`` and ``images_dir``.
        split: Split identifier, passed through to the dataset constructor.
        client_id: Optional client identifier. It is forwarded in the dataset
            kwargs whenever it is not ``None`` (so ``0`` is forwarded too).

    Returns:
        A ``PFL_DocVQA`` dataset instance.

    Raises:
        ValueError: If ``config.dataset_name`` is not 'PFL-DocVQA'.
    """
    available_datasets = ['PFL-DocVQA']

    # Specify special params for data processing depending on the model used.
    dataset_kwargs = {}
    if config.model_name.lower() == 'vt5':
        dataset_kwargs['get_raw_ocr_data'] = True
        dataset_kwargs['use_images'] = True

    if client_id is not None:
        dataset_kwargs['client_id'] = client_id

    # Build dataset
    if config.dataset_name != 'PFL-DocVQA':
        # Name the rejected value so a misconfiguration is easy to diagnose.
        raise ValueError(
            "Value {!r} for dataset selection not expected. Please choose one of {}".format(
                config.dataset_name, ', '.join(available_datasets)
            )
        )

    from datasets.PFL_DocVQA import PFL_DocVQA
    return PFL_DocVQA(config.imdb_dir, config.images_dir, split, dataset_kwargs)
def build_provider_dataset(config, split, provider_to_doc, provider, client_id=None):
    """Build the dataset for one split, restricted to a single provider.

    Args:
        config: Configuration object; reads ``model_name``, ``dataset_name``,
            ``imdb_dir`` and ``images_dir``.
        split: Split identifier, passed through to the dataset constructor.
        provider_to_doc: Mapping indexed by provider; the entry for
            ``provider`` is passed to the dataset constructor as its indexes.
        provider: Key looked up in ``provider_to_doc``.
        client_id: Optional client identifier. It is forwarded in the dataset
            kwargs whenever it is not ``None``.

    Returns:
        A ``PFL_DocVQA`` dataset instance built with the provider's indexes.

    Raises:
        KeyError: If ``provider`` is missing from a dict-like
            ``provider_to_doc``.
        ValueError: If ``config.dataset_name`` is not 'PFL-DocVQA'.
    """
    available_datasets = ['PFL-DocVQA']

    # Specify special params for data processing depending on the model used.
    dataset_kwargs = {}
    if config.model_name.lower() == 'vt5':
        dataset_kwargs['get_raw_ocr_data'] = True
        dataset_kwargs['use_images'] = True

    # Compare against None rather than truthiness: a client id of 0 is falsy
    # and would otherwise be dropped silently.
    if client_id is not None:
        dataset_kwargs['client_id'] = client_id

    # Build dataset
    indexes = provider_to_doc[provider]
    if config.dataset_name != 'PFL-DocVQA':
        # Name the rejected value so a misconfiguration is easy to diagnose.
        raise ValueError(
            "Value {!r} for dataset selection not expected. Please choose one of {}".format(
                config.dataset_name, ', '.join(available_datasets)
            )
        )

    from datasets.PFL_DocVQA import PFL_DocVQA
    return PFL_DocVQA(config.imdb_dir, config.images_dir, split, dataset_kwargs, indexes)