forked from xxyliuyang/qianyan_similarity
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_simcse.py
36 lines (28 loc) · 936 Bytes
/
run_simcse.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import json
import shutil
from allennlp.commands import main
import sys
# Launcher script: builds an `allennlp train` command line for the SimCSE
# experiment and hands it to AllenNLP's `main()` through sys.argv.

# CUDA device index passed to the trainer. Defaults to -1 (presumably the
# CPU setting -- confirm against the AllenNLP trainer docs); it is replaced
# only when exactly one command-line argument is given, parsed as an integer.
device = -1
if len(sys.argv) == 2:
    device = int(sys.argv[1])

exp_name = "simcse_base"
force = "force"

# Training config and output directory for this experiment.
config_file = "experiments/simcse_base.jsonnet"
overrides = json.dumps({"trainer": {"cuda_device": device}})
serialization_dir = "records/simcse/{}".format(exp_name)

# `force` controls whether the output directory is overwritten.
# Validate it with an explicit raise instead of `assert`: assert statements
# are removed when Python runs with -O, which would let a mistyped value
# pass silently.
if force not in ("force", "not_force"):
    raise ValueError("Please confirm whether to overwrite the output folder.")
if force == "force":
    # ignore_errors=True: failures during removal (e.g. the directory does
    # not exist yet) are ignored rather than raised.
    shutil.rmtree(serialization_dir, ignore_errors=True)

# Assemble the command into sys.argv
sys.argv = [
    "allennlp",  # command name, not used by main
    "train",
    config_file,
    "--include-package", "extends",  # extension package (path) for the model
    "-o", overrides,  # overrides parameters from the config
    "--file-friendly-logging",
    "-s", serialization_dir,
]
main()