|
31 | 31 | ) |
32 | 32 | from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL |
33 | 33 | from paddlenlp.peft import ( |
| 34 | + DisLoRAConfig, |
| 35 | + DisLoRAModel, |
34 | 36 | LoKrConfig, |
35 | 37 | LoKrModel, |
36 | 38 | LoRAConfig, |
|
68 | 70 | ) |
69 | 71 | from paddlenlp.transformers.configuration_utils import LlmMetaConfig |
70 | 72 | from paddlenlp.transformers.longlora import replace_llama_attn, set_group_size |
71 | | -from paddlenlp.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer |
| 73 | +from paddlenlp.trl import DataConfig, DisLoRATrainer, ModelConfig, SFTConfig, SFTTrainer |
72 | 74 | from paddlenlp.trl.llm_utils import ( |
73 | 75 | ZeroPaddingIterDatasetCallback, |
74 | 76 | compute_metrics, |
@@ -311,6 +313,15 @@ def neft_post_hook(module, input, output): |
311 | 313 | tokenizer.pad_token_id = tokenizer.eos_token_id |
312 | 314 |
|
313 | 315 | train_ds, dev_ds, test_ds = create_dataset(data_args, training_args) |
| 316 | + |
| 317 | + train_dataset_size = None |
| 318 | + if train_ds is not None and model_args.dislora: |
| 319 | + train_dataset_size = get_dataset_size(train_ds) |
| 320 | + if train_dataset_size is not None: |
| 321 | + logger.info(f"Original training dataset size: {train_dataset_size}") |
| 322 | + else: |
| 323 | + logger.warning("Unable to determine training dataset size for dynamic dash_flag calculation") |
| 324 | + |
314 | 325 | # TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later. |
315 | 326 | if training_args.resume_from_checkpoint is not None and data_args.lazy: |
316 | 327 | logger.info( |
@@ -377,7 +388,9 @@ def neft_post_hook(module, input, output): |
377 | 388 | if eval_zero_padding and test_ds is not None: |
378 | 389 | test_ds = intoken_dataset(test_ds, tokenizer=tokenizer, max_length=data_args.max_length) |
379 | 390 |
|
380 | | - model = create_peft_model(model_args, reft_args, training_args, dtype, model_config, model, reft_layers) |
| 391 | + model = create_peft_model( |
| 392 | + model_args, reft_args, training_args, dtype, model_config, model, reft_layers, train_dataset_size |
| 393 | + ) |
381 | 394 |
|
382 | 395 | def compute_metrics_do_generation(eval_preds): |
383 | 396 | rouge1 = Rouge1() |
@@ -441,19 +454,30 @@ def compute_metrics_do_generation(eval_preds): |
441 | 454 | return_attention_mask=not model_args.flash_mask, |
442 | 455 | pad_to_multiple_of=data_args.pad_to_multiple_of, |
443 | 456 | ) |
444 | | - trainer = SFTTrainer( |
445 | | - model=model, |
446 | | - args=training_args, |
447 | | - train_dataset=train_ds, |
448 | | - eval_dataset=dev_ds, |
449 | | - tokenizer=tokenizer, |
450 | | - compute_metrics=metrics, |
451 | | - data_collator=data_collator_fn if not model_args.reft else ReftDataCollator(data_collator=data_collator_fn), |
452 | | - do_generation=data_args.eval_with_do_generation, |
453 | | - callbacks=[ZeroPaddingIterDatasetCallback()] if isinstance(train_ds, ZeroPaddingIterableDataset) else None, |
454 | | - gen_args=gen_args, |
455 | | - data_args=data_args, |
456 | | - ) |
| 457 | + |
| 458 | + if model_args.dislora and hasattr(model_args, "ortho_lambda"): |
| 459 | + training_args.dislora_ortho_lambda = model_args.ortho_lambda |
| 460 | + |
| 461 | + trainer_kwargs = { |
| 462 | + "model": model, |
| 463 | + "args": training_args, |
| 464 | + "train_dataset": train_ds, |
| 465 | + "eval_dataset": dev_ds, |
| 466 | + "tokenizer": tokenizer, |
| 467 | + "compute_metrics": metrics, |
| 468 | + "data_collator": data_collator_fn if not model_args.reft else ReftDataCollator(data_collator=data_collator_fn), |
| 469 | + "do_generation": data_args.eval_with_do_generation, |
| 470 | + "callbacks": [ZeroPaddingIterDatasetCallback()] if isinstance(train_ds, ZeroPaddingIterableDataset) else None, |
| 471 | + "gen_args": gen_args, |
| 472 | + "data_args": data_args, |
| 473 | + } |
| 474 | + |
| 475 | + if model_args.dislora: |
| 476 | + logger.info("Using DisLoRATrainer for training.") |
| 477 | + trainer = DisLoRATrainer(**trainer_kwargs) |
| 478 | + else: |
| 479 | + trainer = SFTTrainer(**trainer_kwargs) |
| 480 | + |
457 | 481 | trainable_parameters = [ |
458 | 482 | p for p in model.parameters() if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name) |
459 | 483 | ] |
@@ -531,7 +555,9 @@ def save_to_aistudio(model_args, training_args, trainer): |
531 | 555 | ) |
532 | 556 |
|
533 | 557 |
|
534 | | -def create_peft_model(model_args, reft_args, training_args, dtype, model_config, model, reft_layers): |
| 558 | +def create_peft_model( |
| 559 | + model_args, reft_args, training_args, dtype, model_config, model, reft_layers, train_dataset_size |
| 560 | +): |
535 | 561 | if model_args.prefix_tuning: |
536 | 562 | if training_args.pipeline_parallel_degree > 1: |
537 | 563 | raise NotImplementedError("Prefix tuning is not implemented for pipeline parallelism.") |
@@ -612,6 +638,53 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config, |
612 | 638 | else: |
613 | 639 | model = LoKrModel.from_pretrained(model=model, lokr_path=model_args.lokr_path) |
614 | 640 |
|
| 641 | + if model_args.dislora: |
| 642 | + # Calculate dynamic dash_flag based on training configuration |
| 643 | + if train_dataset_size is not None and training_args.do_train: |
| 644 | + # Calculate warmup steps: len(train_data) * num_epochs // (batch_size * gradient_accumulation_steps * 3) |
| 645 | + effective_batch_size = ( |
| 646 | + training_args.per_device_train_batch_size |
| 647 | + * training_args.gradient_accumulation_steps |
| 648 | + * training_args.dataset_world_size # Consider data parallel |
| 649 | + ) |
| 650 | + calculated_dash_flag = (train_dataset_size * training_args.num_train_epochs) // (effective_batch_size * 3) |
| 651 | + |
| 652 | + # Use calculated value if it's reasonable, otherwise fall back to model_args |
| 653 | + if calculated_dash_flag > 0: |
| 654 | + dash_flag = calculated_dash_flag |
| 655 | + logger.info( |
| 656 | + f"Calculated dynamic dash_flag: {dash_flag} based on dataset size: {train_dataset_size}, " |
| 657 | + f"epochs: {training_args.num_train_epochs}, effective batch size: {effective_batch_size}" |
| 658 | + ) |
| 659 | + else: |
| 660 | + dash_flag = model_args.dash_flag |
| 661 | + logger.warning( |
| 662 | + f"Calculated dash_flag was {calculated_dash_flag}, using model_args.dash_flag: {dash_flag}" |
| 663 | + ) |
| 664 | + else: |
| 665 | + dash_flag = getattr(model_args, "dash_flag", 50) |
| 666 | + if train_dataset_size is None: |
| 667 | + logger.info( |
| 668 | + f"Unable to calculate dynamic dash_flag (dataset size unknown), using configured dash_flag: {dash_flag}" |
| 669 | + ) |
| 670 | + else: |
| 671 | + logger.info(f"Not in training mode, using configured dash_flag: {dash_flag}") |
| 672 | + if model_args.dislora_path is None: |
| 673 | + dislora_config = DisLoRAConfig( |
| 674 | + target_modules=model_args.target_modules |
| 675 | + if model_args.target_modules |
| 676 | + else get_lora_target_modules(model), |
| 677 | + r=model_args.dislora_rank, |
| 678 | + dislora_alpha=1.5 * model_args.dislora_rank, |
| 679 | + dislora_dropout=model_args.dislora_dropout, |
| 680 | + dtype=dtype, |
| 681 | + base_model_name_or_path=model_args.model_name_or_path, |
| 682 | + s_tsd=model_args.s_tsd, |
| 683 | + dash_flag=dash_flag, # Use calculated dash_flag |
| 684 | + ortho_lambda=model_args.ortho_lambda, |
| 685 | + ) |
| 686 | + model = DisLoRAModel(model, dislora_config) |
| 687 | + |
615 | 688 | if model_args.reft: |
616 | 689 | intervention_dtype = dtype |
617 | 690 | intervention_params = { |
@@ -751,5 +824,24 @@ def create_dataset(data_args, training_args): |
751 | 824 | return train_ds, dev_ds, test_ds |
752 | 825 |
|
753 | 826 |
|
def get_dataset_size(dataset):
    """Return the number of examples in ``dataset``, or ``None`` if unknown.

    The size is looked up on a best-effort basis:

    1. ``len(dataset)`` is tried first.
    2. If that fails (no ``__len__``, or a ``__len__`` that raises), the
       private ``_length`` attribute is used when it is present and not None.
    3. Otherwise a warning is logged and ``None`` is returned.

    Args:
        dataset: Any dataset object, or ``None``.

    Returns:
        The dataset size, or ``None`` when ``dataset`` is ``None`` or its size
        cannot be determined. This function never raises because of a
        misbehaving ``__len__``.
    """
    if dataset is None:
        return None

    len_error = None
    try:
        return len(dataset)
    except Exception as e:  # deliberate best-effort: a custom __len__ may raise anything
        # Keep the error for the log message below; ``e`` is unbound once the
        # except block exits.
        len_error = e

    # Some datasets expose ``__len__`` but raise from it (presumably iterable /
    # lazy datasets -- TODO confirm against the dataset classes in use), so the
    # ``_length`` fallback must be tried even when ``__len__`` exists.
    length = getattr(dataset, "_length", None)
    if length is not None:
        return length

    if hasattr(dataset, "__len__"):
        # ``__len__`` exists but failed, and no usable ``_length`` was found.
        logger.warning(f"Error getting dataset size: {len_error}")
    else:
        # For lazy datasets, we might need to iterate once to count
        logger.warning("Unable to determine dataset size directly for lazy loading dataset")
    return None
| 844 | + |
| 845 | + |
754 | 846 | if __name__ == "__main__": |
755 | 847 | main() |
0 commit comments