-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathconfig.yaml
31 lines (31 loc) · 3.37 KB
/
config.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
data_filename: "dataset/train_data.txt" # 总数据集路径,未进行训练集测试集划分
split_ratio: "9:1" # 训练集、验证集划分比例,格式:9:1
random_seed: 42 # 随机数种子
dict_path: "pre-train_models/nezha_base/vocab.txt" # 分词器字典路径
maxlen: 256 # 每条数据最大长度
rdrop_number: 2 # r-drop参数,2;当值为1时,不采取r-drop策略
batch_size: 6 # 当rdrop_number参数不为1时,该值取双数
pretrain_config_path: "pre-train_models/nezha_base/bert_config.json" # 预训练模型配置文件
pretrain_checkpoint_path: "pre-train_models/nezha_base/model.ckpt-900000" # 保存的预训练模型路径
pretrain_name: "nezha" # 预训练模型名称
# 可选模型名称:bert_bilstm_crf、bert_crf、bert_global_pointer、bert_softmax、bert_span
model_name: "bert_bilstm_crf" # 模型名称,据此程序选择相应的模型、损失函数等
dropout_rate: 0.3 # dropout概率
hidden_size: 128 # 隐藏层维度,主要是LSTM部分
crf_lr_multiplier: 1000 # crf层学习率的放大倍数
learning_rate: 2e-5 # 学习率
exclude_from_weight_decay: ['Norm', 'bias'] # 权重衰减排除项
paramwise_lr_schedule: {'bidirectional_1':200} # 分参数学习率调整参数,字典形式
ema_momentum: 0.999 # 动量参数
eplison: 0.5 # 对抗训练(FGM)约束因子,>=0;当值为0时,不采用对抗策略
log_path: "./" # 日志文件保存文件夹
log_name: "bert_bilstm_crf" # 日志文件名(代码中会自动拼接日期,并保存为txt文件)
best_path: "best_model_bert_bilstm_crf.weights" # 训练最优模型保存路径
epochs: 20 # 训练代数
model_path: "best_model_bert_bilstm_crf.weights" # 评估、预测时加载模型路径
test_data: "dataset/train_data.txt" # 测试数据集路径
predict_data: "dataset/predict_data.txt" # 预测数据路径
predict_path: "dataset/predict_results.txt" # 预测结果保存路径
web_server: False # 在预测时是否开启web服务
web_host: "0.0.0.0" # 主机ip
web_port: 8033 # 端口号