-
Notifications
You must be signed in to change notification settings - Fork 173
Expand file tree
/
Copy pathtrain.py
More file actions
29 lines (23 loc) · 773 Bytes
/
train.py
File metadata and controls
29 lines (23 loc) · 773 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# -*- coding: utf-8 -*-
# @Author : ssbuild
# @Time : 2023/10/12 10:50
import os
from config import global_args
# Supported ``trainer_backend`` values mapped to the module whose ``main``
# callable runs training for that backend. Modules are imported lazily so only
# the selected backend's dependencies are loaded.
_BACKEND_MODULES = {
    "pl": "training.train_pl",
    "hf": "training.train_hf",
    "cl": "training.train_cl",
    "ac": "training.train_ac",
}


def main(trainer_backend=None):
    """Run training with the selected trainer backend.

    Args:
        trainer_backend: Key into ``_BACKEND_MODULES`` ("pl", "hf", "cl" or
            "ac"). When ``None`` (the default, used by the script entry point
            and ``_mp_fn``), the value is read from
            ``config.global_args["trainer_backend"]``.

    Raises:
        ValueError: if ``trainer_backend`` is not a supported backend. The
            message lists the supported values.
    """
    import importlib

    if trainer_backend is None:
        trainer_backend = global_args["trainer_backend"]
    try:
        module_name = _BACKEND_MODULES[trainer_backend]
    except (KeyError, TypeError):
        # TypeError covers unhashable values (e.g. a list from a config file),
        # which should be reported as unsupported just like unknown strings.
        supported = ", ".join(sorted(_BACKEND_MODULES))
        raise ValueError(
            f"{trainer_backend} NotImplemented; "
            f"supported trainer_backend values: {supported}"
        ) from None
    # Import only the chosen backend, then delegate to its ``main``.
    main_execute = importlib.import_module(module_name).main
    main_execute()
def _mp_fn(index):
    """Per-process entry point used by ``xla_spawn`` (TPUs).

    The spawner passes ``index`` (presumably the process ordinal — confirm
    against the launcher); training does not need it, so it is discarded and
    the regular ``main`` entry point is run.
    """
    del index  # unused: main() takes its settings from the config module
    main()
if __name__ == "__main__":
    # Executed as a script (not imported): train with the backend named in
    # config.global_args["trainer_backend"].
    main()