Skip to content

Latest commit

 

History

History
153 lines (110 loc) · 3.79 KB

File metadata and controls

153 lines (110 loc) · 3.79 KB

对 50k IMDB 电影评论的情绪进行分类的附加实验

 

步骤 1:安装依赖项

通过以下方式安装额外的依赖项

pip install -r requirements-extra.txt

 

步骤 2:下载数据集

代码使用来自 IMDb 的 50k 电影评论(数据集来源) 来预测电影评论是正面的还是负面的。

运行以下代码以创建 train.csvvalidation.csvtest.csv 数据集:

python download_prepare_dataset.py

 

步骤 3:运行模型

主要章节中使用的 124M GPT-2 模型,从预训练权重开始,并对所有权重进行微调:

python train_gpt.py --trainable_layers "all" --num_epochs 1
Ep 1(步骤 000000):训练损失 3.706,Val 损失 3.853
Ep 1(步骤 000050):训练损失 0.682,Val 损失 0.706
...
Ep 1(步骤 004300):训练损失 0.199,Val 损失 0.285
Ep 1(步骤 004350):训练损失 0.188,Val 损失 0.208
训练准确率:95.62% |验证准确率:95.00%
训练用时 9.48 分钟。

在完整数据集上进行评估...

训练准确率:95.64%
验证准确率:92.32%
测试准确率:91.88%



340M 参数编码器样式 BERT 模型:

python train_bert_hf.py --trainable_layers "all" --num_epochs 1 --model "bert"
Ep 1 (步骤 000000):训练损失 0.848,验证损失 0.775
Ep 1 (步骤 000050):训练损失 0.655,验证损失 0.682
...
Ep 1 (步骤 004300):训练损失 0.146,验证损失0.318
Ep 1(步骤 004350):训练损失 0.204,验证损失 0.217
训练准确率:92.50% | 验证准确率:88.75%
训练在 7.65 分钟内完成。

在完整数据集上进行评估...

训练准确率:94.35%
验证准确率:90.74%
测试准确率:90.89%



66M 参数编码器样式 DistilBERT 模型(从 340M 参数 BERT 模型中提炼而来),从预训练权重开始,仅训练最后一个转换器块和输出层:

python train_bert_hf.py --trainable_layers "all" --num_epochs 1 --model "distilbert"
Ep 1(步骤 000000):训练损失 0.693,验证损失 0.688
Ep 1(步骤 000050):训练损失0.452,Val 损失 0.460
...
Ep 1(步骤 004300):训练损失 0.179,Val 损失 0.272
Ep 1(步骤 004350):训练损失 0.199,Val 损失 0.182
训练准确率:95.62% | 验证准确率:91.25%
训练在 4.26 分钟内完成。

在完整数据集上进行评估...

训练准确率:95.30%
验证准确率:91.12%
测试准确率:91.40%



355M 参数编码器样式 RoBERTa 模型,从预训练权重开始,仅训练最后一个转换器块和输出层:

python train_bert_hf.py --trainable_layers "last_block" --num_epochs 1 --model "roberta"
Ep 1 (步骤 000000):训练损失 0.695,验证损失 0.698
Ep 1 (步骤 000050):训练损失 0.670,验证损失 0.690
...
Ep 1 (步骤 004300):训练损失 0.126,验证损失 0.149
Ep 1(步骤 004350):训练损失 0.211,验证损失 0.138
训练准确率:92.50% | 验证准确率:94.38%
训练在 7.20 分钟内完成。

对完整数据集进行评估...

训练准确率:93.44%
验证准确率:93.02%
测试准确率:92.95%



以 scikit-learn 逻辑回归分类器作为基线:

python train_sklearn_logreg.py
虚拟分类器:
训练准确率:50.01%
验证准确率:50.14%
测试准确率:49.91%

逻辑回归分类器:
训练准确率:99.80%
验证准确率:88.62%
测试准确率:88.85%