diff --git a/ernie/loss/dpo.py b/ernie/loss/dpo.py index 914c9c2f3..29dddce73 100644 --- a/ernie/loss/dpo.py +++ b/ernie/loss/dpo.py @@ -312,6 +312,8 @@ def forward( policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, score_deltas ) loss = dpo_loss + sft_loss + if "iluvatar" in paddle.get_device(): + paddle.device.empty_cache() if self.use_infohub: infohub.policy_chosen_logps.append(policy_chosen_logps.detach()) infohub.policy_rejected_logps.append(policy_rejected_logps.detach()) diff --git a/examples/configs/iluvatar_gpu/ERNIE-4.5-21B-A3B/dpo/run_dpo_lora_8k.yaml b/examples/configs/iluvatar_gpu/ERNIE-4.5-21B-A3B/dpo/run_dpo_lora_8k.yaml new file mode 100644 index 000000000..52af7afa7 --- /dev/null +++ b/examples/configs/iluvatar_gpu/ERNIE-4.5-21B-A3B/dpo/run_dpo_lora_8k.yaml @@ -0,0 +1,105 @@ +### data +train_dataset_type: "erniekit" +eval_dataset_type: "erniekit" +train_dataset_path: "./examples/data/dpo-train.jsonl" +train_dataset_prob: "1.0" +eval_dataset_path: "./examples/data/dpo-eval.jsonl" +eval_dataset_prob: "1.0" +max_seq_len: 8192 +num_samples_each_epoch: 6000000 + +### model +model_name_or_path: baidu/ERNIE-4.5-21B-A3B-Paddle +moe_group: mp +fine_tuning: LoRA +lora_rank: 32 +lora_alpha: 128 +lora_plus_scale: 12 +rslora: True +fuse_rope: True +fuse_linear: True + +### finetuning +# base +stage: DPO +seed: 42 +do_train: True +do_eval: True +distributed_dataloader: True +dataloader_num_workers: 4 +batch_size: 1 +num_train_epochs: 1 +max_steps: 800 +max_evaluate_steps: 10000 +eval_steps: 20000 +evaluation_strategy: epoch +save_steps: 100 +save_total_limit: 5 +save_strategy: epoch +logging_steps: 1 +release_grads: True +gradient_accumulation_steps: 8 +logging_dir: ./vdl_log +output_dir: ./output +disable_tqdm: True + +# train +warmup_steps: 50 +learning_rate: 5.0e-7 +lr_scheduler_type: cosine +min_lr: 5.0e-7 +layerwise_lr_decay_bound: 1.0 +attention_probs_dropout_prob: 0.1 +dropout_warmup_steps: 100 + +# 
loss +offset_alpha: 0.0 +scale_loss: 8192 + +# optimizer +weight_decay: 0.1 +adam_epsilon: 1.0e-8 +adam_beta1: 0.9 +adam_beta2: 0.95 +offload_optim: True + +# performance +use_sp_callback: True +tensor_parallel_degree: 4 +tensor_parallel_config: "sync_param sync_grad sync_moment" +pipeline_parallel_degree: 1 +sharding_parallel_degree: 1 +sharding: stage1 +sequence_parallel: True +pipeline_parallel_config: disable_partial_send_recv enable_clear_every_step_cache disable_batch_p2p_comm +recompute: True +recompute_use_reentrant: True +compute_type: bf16 +fp16_opt_level: O2 +amp_master_grad: True +amp_custom_white_list: + - "lookup_table" + - "lookup_table_v2" + - "flash_attn" + - "matmul" + - "matmul_v2" + - "fused_gemm_epilogue" +amp_custom_black_list: + - "reduce_sum" + - "softmax_with_cross_entropy" + - "c_softmax_with_cross_entropy" + - "elementwise_div" + - "sin" + - "cos" +unified_checkpoint: True +unified_checkpoint_config: async_save + +use_flash_attention: True +use_sparse_head_and_loss_fn: False +use_attn_mask_startend_row_indices: False +use_sparse_flash_attn: False +moe_multimodal_dispatch_use_allgather: "v2-alltoall" +device: iluvatar_gpu +fuse_rms_norm: False + +