-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcentralized_train.py
49 lines (37 loc) · 1.94 KB
/
centralized_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from torch.utils.data import DataLoader
from build_utils import build_dataset, build_optimizer
from checkpoint import save_model
from datasets.PFL_DocVQA import collate_fn
from eval import evaluate
from logger import Logger
from metrics import Evaluator
from train import fl_train
from utils import seed_everything
def train(model, config):
    """Train ``model`` for ``config.train_epochs`` epochs, validating after each one.

    Workflow:
      1. Call ``seed_everything(config.seed)`` and create an ``Evaluator``
         (case-insensitive) and a ``Logger``.
      2. Build the 'train' and 'val' datasets with ``build_dataset`` and wrap
         them in ``DataLoader``s (train shuffled, val not), both using
         ``config.batch_size`` and ``collate_fn``.
      3. Build the optimizer and LR scheduler via ``build_optimizer``.
      4. If ``config.eval_start`` is truthy, run one validation pass before any
         training, recorded as epoch 0.
      5. For each epoch: run ``fl_train``, validate, update the evaluator's
         global metrics, log them, and call ``save_model`` with
         ``update_best`` set from the evaluator's result.

    Args:
        model: The model to train; passed through to ``fl_train``,
            ``evaluate`` and ``save_model``.
        config: Run configuration. Read: ``train_epochs``, ``seed``,
            ``batch_size`` and optionally ``eval_start``. NOTE: mutated in
            place -- ``return_scores_by_sample`` and ``return_pred_answers``
            are both set to ``False``.
    """
    epochs = config.train_epochs

    seed_everything(config.seed)
    evaluator = Evaluator(case_sensitive=False)
    logger = Logger(config)
    logger.log_model_parameters(model)

    # Both datasets are built before either loader is created.
    train_split = build_dataset(config, 'train')
    val_split = build_dataset(config, 'val')
    train_data_loader = DataLoader(
        train_split,
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    val_data_loader = DataLoader(
        val_split,
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    n_train_batches = len(train_data_loader)
    logger.len_dataset = n_train_batches
    optimizer, lr_scheduler = build_optimizer(model, length_train_loader=n_train_batches, config=config)

    # Validation here only needs aggregate metrics, so per-sample scores and
    # predicted answers are switched off on the shared config object.
    config.return_scores_by_sample = False
    config.return_pred_answers = False

    def run_validation(epoch):
        """Evaluate on the val loader, update global metrics and log; return the update flag."""
        accuracy, anls, ret_prec, _, _ = evaluate(val_data_loader, model, evaluator, config)
        became_best = evaluator.update_global_metrics(accuracy, anls, epoch)
        logger.log_val_metrics(accuracy, anls, ret_prec, update_best=became_best)
        return became_best

    if getattr(config, 'eval_start', False):
        logger.current_epoch = 0
        run_validation(0)
        # NOTE(review): the loop below reassigns logger.current_epoch before it
        # is read again, so this increment has no effect when epochs >= 1 and
        # the first trained epoch is also recorded as epoch 0 -- confirm the
        # intended epoch numbering.
        logger.current_epoch += 1

    for epoch_ix in range(epochs):
        logger.current_epoch = epoch_ix
        fl_train(train_data_loader, model, optimizer, lr_scheduler, evaluator, logger)
        became_best = run_validation(epoch_ix)
        save_model(model, epoch_ix, config, update_best=became_best)