-
Notifications
You must be signed in to change notification settings - Fork 7
Description
An error occurred when I tried to reproduce the SimCSE tri_encoder results with the following script:
#!/bin/bash
# Launches run_sts.py with training, evaluation and prediction enabled.
# Each setting below falls back to the shown default unless the environment
# variable named inside its ${NAME:-default} expansion is set and non-empty.

model=${MODEL:-princeton-nlp/sup-simcse-roberta-base}  # pre-trained model
encoding=${ENCODER_TYPE:-tri_encoder}  # cross_encoder, bi_encoder, tri_encoder
lr=${LR:-3e-5}  # learning rate
wd=${WD:-0.1}  # weight decay
transform=${TRANSFORM:-False}  # whether to use an additional linear layer after the encoder
objective=${OBJECTIVE:-triplet_mse}  # mse, triplet, triplet_mse
triencoder_head=${TRIENCODER_HEAD:-hadamard}  # hadamard, concat (set for tri_encoder)
seed=${SEED:-42}
output_dir=${OUTPUT_DIR:-tri_simcse_base}

# Run identifier assembled from the hyper-parameters; it becomes the last
# component of the output directory.
config="enc_${encoding}_lr${lr}_wd${wd}_trans${transform}_obj${objective}tri${triencoder_head}s${seed}"

train_file=${TRAIN_FILE:-data/csts_train.csv}
eval_file=${EVAL_FILE:-data/csts_validation.csv}
test_file=${TEST_FILE:-data/csts_test.csv}

# The arguments live in an array so the whole command is ONE invocation
# (no backslash continuations that can get lost) and every value is passed
# as a single word, even if it contains spaces or glob characters.
run_args=(
  # ${model//\//} removes every '/' from the model name so it forms a
  # single path component.
  --output_dir "${output_dir}/${model//\//}/${config}"
  --model_name_or_path "${model}"
  --objective "${objective}"
  --encoding_type "${encoding}"
  --pooler_type cls
  --freeze_encoder False
  --transform "${transform}"
  --triencoder_head "${triencoder_head}"
  --max_seq_length 512
  --train_file "${train_file}"
  --validation_file "${eval_file}"
  --test_file "${test_file}"
  --condition_only False
  --sentences_only False
  --do_train
  --do_eval
  --do_predict
  --evaluation_strategy epoch
  --per_device_train_batch_size 8
  --gradient_accumulation_steps 4
  --learning_rate "${lr}"
  --weight_decay "${wd}"
  --max_grad_norm 0.0
  --num_train_epochs 3
  --lr_scheduler_type linear
  --warmup_ratio 0.1
  --log_level info
  --disable_tqdm True
  --save_strategy epoch
  --save_total_limit 1
  --seed "${seed}"
  --data_seed "${seed}"
  --fp16 True
  --log_time_interval 15
)

python run_sts.py "${run_args[@]}"
error:
[INFO|trainer.py:621] 2023-09-11 21:59:05,853 >> Using cuda_amp half precision backend
Traceback (most recent call last):
File "run_sts.py", line 568, in <module>
main()
File "run_sts.py", line 503, in main
train_result = trainer.train(resume_from_checkpoint=checkpoint)
File "/home/cgt/anaconda3/envs/CSTS/lib/python3.8/site-packages/transformers/trainer.py", line 1662, in train
return inner_training_loop(
File "/home/cgt/anaconda3/envs/CSTS/lib/python3.8/site-packages/transformers/trainer.py", line 1674, in _inner_training_loop
train_dataloader = self.get_train_dataloader()
File "/home/cgt/CSTS/c-sts/utils/sts/triplet_trainer.py", line 95, in get_train_dataloader
train_sampler = TripletBatchSampler(
File "/home/cgt/CSTS/c-sts/utils/sts/triplet_trainer.py", line 40, in __init__
self.pairs = self._get_idx_pairs(self.trainer.train_dataset, sentence1_key, sentence2_key)
File "/home/cgt/CSTS/c-sts/utils/sts/triplet_trainer.py", line 47, in _get_idx_pairs
pairs[datum[sentence1_key] + '' + datum[sentence2_key]].append(ix)
KeyError: 'Input.sent_1'