Our experiments are based on the Fairseq repo. For the requirements and installation of Fairseq and Apex, please refer to that repo.
Our experiments for this task use the following package versions:
torch.__version__ = '1.10.1+cu111'
torchvision.__version__ = '0.11.2+cu111'
torchaudio.__version__ = '0.10.1+cu111'
fairseq.__version__ = '0.12.2'
If you want to strictly follow our environment, please refer to our released Docker image xyxie/adan-image:fairseq.
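To compare a local environment against the versions listed above, a minimal sketch along the following lines may help. The helper names (`installed_version`, `version_mismatches`) are hypothetical and not part of Fairseq or Adan, and the sketch assumes that the installed wheels report the same version strings (including the `+cu111` suffix) through package metadata.

```python
from importlib import metadata

# Versions listed in this README. Assumption: the package metadata reports
# the same strings as the `__version__` attributes shown above.
EXPECTED = {
    "torch": "1.10.1+cu111",
    "torchvision": "0.11.2+cu111",
    "torchaudio": "0.10.1+cu111",
    "fairseq": "0.12.2",
}


def installed_version(package):
    """Return the installed version of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def version_mismatches(expected, lookup=installed_version):
    """Map each package whose version differs to (expected, found)."""
    mismatches = {}
    for package, wanted in expected.items():
        found = lookup(package)
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for package, (wanted, found) in version_mismatches(EXPECTED).items():
        print(f"{package}: expected {wanted}, found {found}")
```

An empty result means the environment matches the listed versions; a differing version does not necessarily break training, it only means the setup deviates from the one used here.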
Please first copy the file adan.py into the directory path/to/fairseq/fairseq/optim. You can then choose Adan as the optimizer in the config file. See the following example for pre-training:
optimizer:
  _name: adan
  weight_decay: 0.02
  adan_betas: (0.98,0.92,0.99)
  adan_eps: 1e-08
The following steps are modified from Fairseq-Roberta. For completeness, we list some key steps here.
Data should be preprocessed following the language modeling format. That is, each document should be separated by an empty line (only useful with --sample-break-mode complete_doc). All lines will be concatenated as a 1D text stream during training.
In the following steps, we use the BookCorpus and Wikipedia datasets to demonstrate how to preprocess raw text data with the GPT-2 BPE.
wget https://the-eye.eu/public/AI/pile_preliminary_components/books1.tar.gz
tar -zxvf books1.tar.gz -C ./bert-corpus/
pip install datasets
from datasets import load_dataset
dataset = load_dataset("wikipedia", "20220301.en")
- For the Wikipedia dataset, read each line of the JSON Lines file, replace every \n in the text field with a space, and write the result as one line (with \n appended) to the new file all_data.raw.
- For the BookCorpus dataset, read the content of each book, replace every \n with a space, and write the content of the book as one line in all_data.raw, terminated with \n.
- Split all_data.raw into wiki.train.raw and wiki.dev.raw with a ratio of 99:1. Set wiki.test.raw = wiki.dev.raw for compatibility with fairseq. Note that the encoding loop below reads wiki.valid.raw, so the dev split needs to be available under that name.
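The flattening and splitting steps above can be sketched as follows. This is an illustration under stated assumptions, not the script used for the experiments: the function names are hypothetical, the input is taken as any iterable of document strings (for example the text fields of the Wikipedia dataset or the contents of each BookCorpus file), and the README does not say whether lines are shuffled before the 99:1 split, so the seeded shuffle here is an assumption.

```python
import random


def build_corpus(documents):
    """Return one line per non-empty document, with newlines replaced by spaces."""
    lines = []
    for text in documents:
        line = text.replace("\n", " ").strip()
        if line:
            lines.append(line)
    return lines


def split_train_dev(lines, train_fraction=0.99, seed=0):
    """Split lines into (train, dev) parts.

    Assumption: lines are shuffled with a fixed seed before splitting;
    the original steps only specify the 99:1 ratio.
    """
    shuffled = list(lines)
    random.Random(seed).shuffle(shuffled)
    cut = round(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]


def write_lines(path, lines):
    """Write each line to `path`, terminated by a newline."""
    with open(path, "w", encoding="utf-8") as handle:
        for line in lines:
            handle.write(line + "\n")
```

A typical use would call `build_corpus` on the combined documents, write the result to all_data.raw with `write_lines`, and then write the two parts returned by `split_train_dev` to wiki.train.raw and wiki.dev.raw.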
mkdir -p gpt2_bpe
wget -O gpt2_bpe/encoder.json https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
wget -O gpt2_bpe/vocab.bpe https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
for SPLIT in train valid test; do \
python -m examples.roberta.multiprocessing_bpe_encoder \
--encoder-json gpt2_bpe/encoder.json \
--vocab-bpe gpt2_bpe/vocab.bpe \
--inputs bert-corpus/wiki.${SPLIT}.raw \
--outputs bert-corpus/wiki.${SPLIT}.bpe \
--keep-empty \
--workers 60; \
done
wget -O gpt2_bpe/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
fairseq-preprocess \
--only-source \
--srcdict gpt2_bpe/dict.txt \
--trainpref bert-corpus/wiki.train.bpe \
--validpref bert-corpus/wiki.valid.bpe \
--testpref bert-corpus/wiki.test.bpe \
--destdir data-bin/bert-corpus \
--workers 60
Put the provided config files into the directory path/to/fairseq/examples/roberta/config/pretraining.
DATA_DIR=/path/to/fairseq/bert-corpus
fairseq-hydra-train -m --config-dir examples/roberta/config/pretraining \
--config-name ${NAME} task.data=$DATA_DIR \
checkpoint.save_dir=/path/to/save_dir/
We can optionally resume the training of the released BERT-base model by adding checkpoint.restore_file=/path/to/model.pt. Note, in our experiments, we use Adan to train BERT-base from scratch. You can use the following config files to train BERT-base with Adam or Adan:
| NAME | Optimizer | Config | Download |
|---|---|---|---|
| bert-base | Adam | config | log/model |
| bert-adan | Adan | config | log/model |
The above command assumes training on 8x40GB A100 GPUs. Each GPU uses a batch size of 32 sequences (dataset.batch_size). If you have fewer GPUs or GPUs with less memory, you may need to reduce dataset.batch_size and increase dataset.update_freq to compensate. Alternatively, if you have more GPUs, you can decrease dataset.update_freq accordingly to improve the training speed.
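As a rough guide for this adjustment, the effective batch size in fairseq is the product of the number of GPUs, dataset.batch_size, and dataset.update_freq. The sketch below computes the update_freq that keeps this product fixed on different hardware. The function names are hypothetical, and the reference update_freq of 1 is only a placeholder assumption; the actual value should be read from the config file in use.

```python
def effective_batch_size(num_gpus, batch_size, update_freq):
    """Sequences contributing to one optimizer step."""
    return num_gpus * batch_size * update_freq


def compensating_update_freq(target_effective_batch, num_gpus, batch_size):
    """Return the update_freq that reproduces the target effective batch size."""
    per_step = num_gpus * batch_size
    if target_effective_batch % per_step != 0:
        raise ValueError(
            f"{target_effective_batch} is not a multiple of {per_step}; "
            "pick a batch size that divides the target evenly"
        )
    return target_effective_batch // per_step


if __name__ == "__main__":
    # 8 GPUs and 32 sequences per GPU come from the text above;
    # update_freq=1 is a hypothetical reference value.
    target = effective_batch_size(num_gpus=8, batch_size=32, update_freq=1)
    # Example: 4 GPUs that only fit 16 sequences each.
    print(compensating_update_freq(target, num_gpus=4, batch_size=16))
```

Keeping the effective batch size unchanged is what allows the learning rate and schedule in the provided configs to be reused without retuning.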
1) Download the data from the GLUE website using the following commands:
wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
python download_glue_data.py --data_dir glue_data --tasks all
There are some problems with downloading MRPC and MNLI; hence, we skip the MRPC task and download the MNLI data from unofficial sources.
./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
glue_task_name is one of the following: {ALL, QQP, MNLI, QNLI, RTE, STS-B, SST-2, CoLA}. Use ALL to preprocess all the GLUE tasks.
Example fine-tuning command for the RTE task:
TASK=RTE;
python path/to/fairseq/examples/roberta/config/finetuning/acc_test.py --avg_num 1 \
--data_path /path/to/fairseq/GLUE/glue_data/$TASK \
--bin_path /path/to/fairseq/GLUE/$TASK-bin \
--pre_path /path/to/fairseq/bert-adan/checkpoint_best.pt \
--finetune_path /path/to/fairseq/bert-fintune/adan/$TASK/ \
--task rte-adan
- avg_num: number of repetitions.
- data_path: path to the data of the GLUE task, e.g., CoLA, MNLI, etc.
- bin_path: similar to data_path, but the path to the binarized data after preprocessing.
- pre_path: path to the pre-trained model.
- finetune_path: path to save/load the fine-tuned model.
- task: config name; please refer to the fine-tuning directory for the additional config files for each of the GLUE tasks.
- These command-line arguments and hyperparameters were tested on one Nvidia A100 GPU with 40GB of memory for each task. Depending on the GPU memory available to you, you can increase --update-freq and reduce --batch-size.
After training the model in the previous step, we can perform inference with the checkpoints in the finetune_path directory using the following command:
TASK=RTE;
python path/to/fairseq/examples/roberta/config/finetuning/acc_test.py --inference \
--data_path /path/to/fairseq/GLUE/glue_data/$TASK \
--bin_path /path/to/fairseq/GLUE/$TASK-bin \
--pre_path /path/to/fairseq/bert-adan/checkpoint_best.pt \
--finetune_path /path/to/fairseq/bert-fintune/adan/$TASK/ \
--task rte-adan
This should give:
| GLUE-Task | Metric | Result | Config |
|---|---|---|---|
| CoLA | Matthews corr. | 64.6 | config |
| SST-2 | Accuracy | 93.2 | config |
| STS-B | Pearson corr. | 89.3 | config |
| QQP | Accuracy | 91.2 | config |
| MNLI | Matched acc./Mismatched acc. | 85.7/85.6 | config |
| QNLI | Accuracy | 91.3 | config |
| RTE | Accuracy | 73.3 | config |
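For readers unfamiliar with the metrics in the table, the sketch below gives their textbook definitions in plain Python. It is an illustration only: the function names are hypothetical, the evaluation script used here may compute the metrics through other libraries, and the table values appear to be these quantities scaled by 100.

```python
import math


def accuracy(predictions, labels):
    """Fraction of predictions equal to the labels."""
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


def matthews_corr(predictions, labels):
    """Matthews correlation coefficient for binary labels (0 or 1)."""
    pairs = list(zip(predictions, labels))
    tp = sum(p == 1 and y == 1 for p, y in pairs)
    tn = sum(p == 0 and y == 0 for p, y in pairs)
    fp = sum(p == 1 and y == 0 for p, y in pairs)
    fn = sum(p == 0 and y == 1 for p, y in pairs)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # Convention: return 0 when any marginal count is zero.
    return 0.0 if denom == 0 else (tp * tn - fp * fn) / denom


def pearson_corr(xs, ys):
    """Pearson correlation between two equally long numeric sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    std_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    std_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    return cov / (std_x * std_y)
```

Accuracy applies to the classification tasks, the Matthews correlation to CoLA, and the Pearson correlation to the regression task STS-B.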