Skip to content

Commit 51787ba

Browse files
authored
commit dislora (#11122)
* commit dislora * add DisLoRATrainer * “readme文件修改” * Delete docs/zh/llm/benchmark/rl/README.md
1 parent a4d16fa commit 51787ba

20 files changed

Lines changed: 1968 additions & 20 deletions

File tree

docs/zh/llm/alignment/ppo/README.md

Lines changed: 0 additions & 1 deletion
This file was deleted.
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
{
2+
"model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",
3+
"dataset_name_or_path": "/home/bjh/bjh/Dislora/cs_5_lite",
4+
"output_dir": "./checkpoints/dislora_ckpts_3",
5+
"dislora": true,
6+
"per_device_train_batch_size": 1,
7+
"gradient_accumulation_steps": 5,
8+
"num_train_epochs": 1,
9+
"learning_rate": 2e-05,
10+
"lr_scheduler_type": "linear",
11+
"warmup_steps": 30,
12+
"logging_steps": 1,
13+
"evaluation_strategy": "no",
14+
"save_strategy": "steps",
15+
"save_steps": 500,
16+
"src_length": 256,
17+
"max_length": 512,
18+
"bf16": true,
19+
"do_train": true,
20+
"do_eval": false,
21+
"disable_tqdm": false,
22+
"load_best_model_at_end": false,
23+
"eval_with_do_generation": false,
24+
"recompute": false,
25+
"save_total_limit": 5,
26+
"fp16_opt_level": "O2",
27+
"sharding": "stage3",
28+
"zero_padding": false,
29+
"use_flash_attention": false,
30+
"unified_checkpoint": false,
31+
"dislora_rank": 8,
32+
"dislora_dropout": 0.05,
33+
"target_modules": [".*q_proj.*", ".*v_proj.*", ".*k_proj.*", ".*o_proj.*"],
34+
"s_tsd": 8,
35+
"ortho_lambda": 1.0,
36+
"prefer_small_sigma": true
37+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
{
2+
"model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
3+
"dataset_name_or_path": "/home/bjh/bjh/Dislora/cs_5_lite",
4+
"output_dir": "./checkpoints/dislora_ckpts",
5+
"dislora": true,
6+
"per_device_train_batch_size": 1,
7+
"gradient_accumulation_steps": 1,
8+
"num_train_epochs": 1,
9+
"learning_rate": 2e-05,
10+
"lr_scheduler_type": "linear",
11+
"warmup_steps": 30,
12+
"logging_steps": 1,
13+
"evaluation_strategy": "no",
14+
"save_strategy": "steps",
15+
"save_steps": 500,
16+
"src_length": 256,
17+
"max_length": 512,
18+
"bf16": true,
19+
"do_train": true,
20+
"do_eval": false,
21+
"disable_tqdm": false,
22+
"load_best_model_at_end": false,
23+
"eval_with_do_generation": false,
24+
"recompute": false,
25+
"save_total_limit": 5,
26+
"fp16_opt_level": "O2",
27+
"sharding": "stage3",
28+
"zero_padding": false,
29+
"use_flash_attention": false,
30+
"unified_checkpoint": false,
31+
"dislora_rank": 8,
32+
"dislora_dropout": 0.05,
33+
"s_tsd": 8,
34+
"ortho_lambda": 1.0,
35+
"prefer_small_sigma": true
36+
}

llm/docs/finetune.md

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,33 @@ python merge_lokr_params.py \
177177
- `device`: 运行环境,默认为 gpu。
178178
</div>
179179

180-
#### 3.4.4 ReFT
180+
#### 3.4.5 DisLoRA
181+
```
182+
# 单卡DisLoRA
183+
python run_finetune.py ./config/llama/dislora_argument.json
184+
185+
# 多卡DisLoRA(暂不支持张量模型并行)
186+
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py ./config/llama/dislora_argument.json
187+
```
188+
为了后续的**压缩****静态图推理**方便,我们提供 DisLoRA 参数合并脚本,可以将 DisLoRA 参数合并到主干模型并保存相应的权重。
189+
```
190+
python merge_dislora_params.py \
191+
--model_name_or_path ./base_model \
192+
--dislora_path ./checkpoints/dislora_ckpts \
193+
--merge_dislora_model_path ./checkpoints/dislora_merge \
194+
--device "gpu" \
195+
--low_gpu_mem True
196+
```
197+
198+
<summary>&emsp; 脚本参数介绍</summary><div>
199+
200+
- `dislora_path`: DisLoRA 参数和配置路径,对 DisLoRA 参数进行初始化,默认为 None。
201+
- `model_name_or_path`: 必须,主干模型参数路径,默认为 None。
202+
- `merge_dislora_model_path`: 必须,合并参数后保存路径,默认为 None。
203+
- `device`: 运行环境,默认为 gpu。
204+
</div>
205+
206+
#### 3.4.6 ReFT
181207
```
182208
# 单卡ReFT
183209
python run_finetune.py ./config/llama/reft_argument.json
@@ -228,6 +254,8 @@ python ./predict/reft_predictor.py \
228254
- `vera_rank`: VeRA 算法中 rank(秩)的值,默认为8。
229255
- `lokr`: 是否开启 [LoKr](https://arxiv.org/abs/2309.14859) 微调策略,默认为 False。
230256
- `lokr_rank`: LoKr 算法中 rank(秩)的值,默认为8。
257+
- `dislora`: 是否开启 [DisLoRA] 微调策略,默认为 False。
258+
- `dislora_rank`: DisLoRA 算法中 rank(秩)的值,默认为8。
231259
- `use_long_sequence_strategies`: 是否使用长序列扩展策略,默认为 False。
232260
- `reft`: 是否开启 [ReFT](https://arxiv.org/abs/2404.03592) 微调策略,默认为 False。
233261
- `use_mora`: 是否开启 [MoRA](https://arxiv.org/abs/2405.12130) 微调策略,默认为 False。

llm/run_finetune.py

Lines changed: 108 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
)
3232
from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
3333
from paddlenlp.peft import (
34+
DisLoRAConfig,
35+
DisLoRAModel,
3436
LoKrConfig,
3537
LoKrModel,
3638
LoRAConfig,
@@ -68,7 +70,7 @@
6870
)
6971
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
7072
from paddlenlp.transformers.longlora import replace_llama_attn, set_group_size
71-
from paddlenlp.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer
73+
from paddlenlp.trl import DataConfig, DisLoRATrainer, ModelConfig, SFTConfig, SFTTrainer
7274
from paddlenlp.trl.llm_utils import (
7375
ZeroPaddingIterDatasetCallback,
7476
compute_metrics,
@@ -311,6 +313,15 @@ def neft_post_hook(module, input, output):
311313
tokenizer.pad_token_id = tokenizer.eos_token_id
312314

313315
train_ds, dev_ds, test_ds = create_dataset(data_args, training_args)
316+
317+
train_dataset_size = None
318+
if train_ds is not None and model_args.dislora:
319+
train_dataset_size = get_dataset_size(train_ds)
320+
if train_dataset_size is not None:
321+
logger.info(f"Original training dataset size: {train_dataset_size}")
322+
else:
323+
logger.warning("Unable to determine training dataset size for dynamic dash_flag calculation")
324+
314325
# TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later.
315326
if training_args.resume_from_checkpoint is not None and data_args.lazy:
316327
logger.info(
@@ -377,7 +388,9 @@ def neft_post_hook(module, input, output):
377388
if eval_zero_padding and test_ds is not None:
378389
test_ds = intoken_dataset(test_ds, tokenizer=tokenizer, max_length=data_args.max_length)
379390

380-
model = create_peft_model(model_args, reft_args, training_args, dtype, model_config, model, reft_layers)
391+
model = create_peft_model(
392+
model_args, reft_args, training_args, dtype, model_config, model, reft_layers, train_dataset_size
393+
)
381394

382395
def compute_metrics_do_generation(eval_preds):
383396
rouge1 = Rouge1()
@@ -441,19 +454,30 @@ def compute_metrics_do_generation(eval_preds):
441454
return_attention_mask=not model_args.flash_mask,
442455
pad_to_multiple_of=data_args.pad_to_multiple_of,
443456
)
444-
trainer = SFTTrainer(
445-
model=model,
446-
args=training_args,
447-
train_dataset=train_ds,
448-
eval_dataset=dev_ds,
449-
tokenizer=tokenizer,
450-
compute_metrics=metrics,
451-
data_collator=data_collator_fn if not model_args.reft else ReftDataCollator(data_collator=data_collator_fn),
452-
do_generation=data_args.eval_with_do_generation,
453-
callbacks=[ZeroPaddingIterDatasetCallback()] if isinstance(train_ds, ZeroPaddingIterableDataset) else None,
454-
gen_args=gen_args,
455-
data_args=data_args,
456-
)
457+
458+
if model_args.dislora and hasattr(model_args, "ortho_lambda"):
459+
training_args.dislora_ortho_lambda = model_args.ortho_lambda
460+
461+
trainer_kwargs = {
462+
"model": model,
463+
"args": training_args,
464+
"train_dataset": train_ds,
465+
"eval_dataset": dev_ds,
466+
"tokenizer": tokenizer,
467+
"compute_metrics": metrics,
468+
"data_collator": data_collator_fn if not model_args.reft else ReftDataCollator(data_collator=data_collator_fn),
469+
"do_generation": data_args.eval_with_do_generation,
470+
"callbacks": [ZeroPaddingIterDatasetCallback()] if isinstance(train_ds, ZeroPaddingIterableDataset) else None,
471+
"gen_args": gen_args,
472+
"data_args": data_args,
473+
}
474+
475+
if model_args.dislora:
476+
logger.info("Using DisLoRATrainer for training.")
477+
trainer = DisLoRATrainer(**trainer_kwargs)
478+
else:
479+
trainer = SFTTrainer(**trainer_kwargs)
480+
457481
trainable_parameters = [
458482
p for p in model.parameters() if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name)
459483
]
@@ -531,7 +555,9 @@ def save_to_aistudio(model_args, training_args, trainer):
531555
)
532556

533557

534-
def create_peft_model(model_args, reft_args, training_args, dtype, model_config, model, reft_layers):
558+
def create_peft_model(
559+
model_args, reft_args, training_args, dtype, model_config, model, reft_layers, train_dataset_size
560+
):
535561
if model_args.prefix_tuning:
536562
if training_args.pipeline_parallel_degree > 1:
537563
raise NotImplementedError("Prefix tuning is not implemented for pipeline parallelism.")
@@ -612,6 +638,53 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config,
612638
else:
613639
model = LoKrModel.from_pretrained(model=model, lokr_path=model_args.lokr_path)
614640

641+
if model_args.dislora:
642+
# Calculate dynamic dash_flag based on training configuration
643+
if train_dataset_size is not None and training_args.do_train:
644+
# Calculate warmup steps: len(train_data) * num_epochs // (batch_size * gradient_accumulation_steps * 3)
645+
effective_batch_size = (
646+
training_args.per_device_train_batch_size
647+
* training_args.gradient_accumulation_steps
648+
* training_args.dataset_world_size # Consider data parallel
649+
)
650+
calculated_dash_flag = (train_dataset_size * training_args.num_train_epochs) // (effective_batch_size * 3)
651+
652+
# Use calculated value if it's reasonable, otherwise fall back to model_args
653+
if calculated_dash_flag > 0:
654+
dash_flag = calculated_dash_flag
655+
logger.info(
656+
f"Calculated dynamic dash_flag: {dash_flag} based on dataset size: {train_dataset_size}, "
657+
f"epochs: {training_args.num_train_epochs}, effective batch size: {effective_batch_size}"
658+
)
659+
else:
660+
dash_flag = model_args.dash_flag
661+
logger.warning(
662+
f"Calculated dash_flag was {calculated_dash_flag}, using model_args.dash_flag: {dash_flag}"
663+
)
664+
else:
665+
dash_flag = getattr(model_args, "dash_flag", 50)
666+
if train_dataset_size is None:
667+
logger.info(
668+
f"Unable to calculate dynamic dash_flag (dataset size unknown), using configured dash_flag: {dash_flag}"
669+
)
670+
else:
671+
logger.info(f"Not in training mode, using configured dash_flag: {dash_flag}")
672+
if model_args.dislora_path is None:
673+
dislora_config = DisLoRAConfig(
674+
target_modules=model_args.target_modules
675+
if model_args.target_modules
676+
else get_lora_target_modules(model),
677+
r=model_args.dislora_rank,
678+
dislora_alpha=1.5 * model_args.dislora_rank,
679+
dislora_dropout=model_args.dislora_dropout,
680+
dtype=dtype,
681+
base_model_name_or_path=model_args.model_name_or_path,
682+
s_tsd=model_args.s_tsd,
683+
dash_flag=dash_flag, # Use calculated dash_flag
684+
ortho_lambda=model_args.ortho_lambda,
685+
)
686+
model = DisLoRAModel(model, dislora_config)
687+
615688
if model_args.reft:
616689
intervention_dtype = dtype
617690
intervention_params = {
@@ -751,5 +824,24 @@ def create_dataset(data_args, training_args):
751824
return train_ds, dev_ds, test_ds
752825

753826

827+
def get_dataset_size(dataset):
828+
"""Get the size of a dataset, handling both lazy and regular datasets"""
829+
if dataset is None:
830+
return None
831+
832+
try:
833+
if hasattr(dataset, "__len__"):
834+
return len(dataset)
835+
elif hasattr(dataset, "_length"):
836+
return dataset._length
837+
else:
838+
# For lazy datasets, we might need to iterate once to count
839+
logger.warning("Unable to determine dataset size directly for lazy loading dataset")
840+
return None
841+
except Exception as e:
842+
logger.warning(f"Error getting dataset size: {e}")
843+
return None
844+
845+
754846
if __name__ == "__main__":
755847
main()

0 commit comments

Comments
 (0)