[Liger] liger DPO support #2568

Open · wants to merge 18 commits into base: main
17 changes: 16 additions & 1 deletion docs/source/reducing_memory_usage.md
@@ -16,7 +16,7 @@ Sequence lengths in the dataset can vary widely. When data is batched, sequences

To reduce memory usage, it’s important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.

<hfoptions id="dpo">
<hfoptions id="truncation">
<hfoption id="DPO">

DPO truncation is applied first to the prompt and to the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence.
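
For example, a minimal sketch of these settings (the specific lengths are illustrative, not recommendations):

```python
from trl import DPOConfig

# Prompt and completion are truncated first, then the combined
# prompt + completion sequence is truncated to `max_length`.
training_args = DPOConfig(
    output_dir="dpo-model",  # illustrative
    max_prompt_length=512,
    max_completion_length=512,
    max_length=1024,
)
```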
@@ -94,6 +94,21 @@ Packing may cause batch contamination, where adjacent sequences influence one another.

</Tip>

## Liger for reducing peak memory usage

[To complete]

<hfoptions id="liger">
<hfoption id="DPO">

To reduce peak memory usage during DPO training, enable the Liger loss in the `DPOConfig`:

```python
from trl import DPOConfig

training_args = DPOConfig(..., use_liger_loss=True)
```
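
If the policy model does not expose a `get_decoder` method, the `base_model_attribute_name` option (added alongside `use_liger_loss`) tells the trainer which attribute holds the base model. A minimal sketch combining the two options, with illustrative values:

```python
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="dpo-liger",             # illustrative
    use_liger_loss=True,                # enable the Liger DPO loss
    base_model_attribute_name="model",  # attribute read when the model has no `get_decoder`
)
```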

</hfoption>
</hfoptions>

## Disabling model gathering for generation in online methods

When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).
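
A minimal sketch of how this gathering is typically disabled, assuming the online config exposes a `ds3_gather_for_generation` flag (as `OnlineDPOConfig` does in recent TRL releases):

```python
from trl import OnlineDPOConfig

# Keep ZeRO-3 shards in place during generation instead of gathering
# the full model on one GPU: slower generation, lower peak memory.
training_args = OnlineDPOConfig(
    output_dir="online-dpo",  # illustrative
    ds3_gather_for_generation=False,
)
```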
79 changes: 78 additions & 1 deletion tests/test_dpo_trainer.py
@@ -29,7 +29,12 @@
    PreTrainedTokenizerBase,
    is_vision_available,
)
from transformers.testing_utils import require_peft, require_torch_gpu_if_bnb_not_multi_backend_enabled, require_vision
from transformers.testing_utils import (
    require_liger_kernel,
    require_peft,
    require_torch_gpu_if_bnb_not_multi_backend_enabled,
    require_vision,
)

from trl import DPOConfig, DPOTrainer, FDivergenceType

@@ -1263,6 +1268,78 @@ def dummy_compute_metrics(*args, **kwargs):

self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)

    @require_liger_kernel
    @parameterized.expand([(0.1,), (0.5,)])
    def test_dpo_trainer_with_liger(self, beta):
        """Test DPO trainer with Liger loss enabled.

        This test verifies that:
        1. Training runs successfully with Liger loss
        2. Model parameters update as expected
        3. Loss values are reasonable and finite
        4. Training works with both default and custom beta values
        """

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                do_eval=True,
                eval_steps=1,
                max_steps=3,
                remove_unused_columns=False,
                gradient_accumulation_steps=1,
                learning_rate=9e-1,
                eval_strategy="steps",
                beta=beta,
                use_liger_loss=True,  # Enable Liger loss
                report_to="none",
            )

            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

            trainer = DPOTrainer(
                model=self.model,
                ref_model=self.ref_model,  # Add reference model
                args=training_args,
                processing_class=self.tokenizer,
                train_dataset=dummy_dataset["train"],
                eval_dataset=dummy_dataset["test"],
            )

            # Store initial parameters
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            # Train the model
            train_output = trainer.train()

            # Verify training completed successfully
            self.assertIsNotNone(train_output)
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Verify loss is finite
            self.assertTrue(np.isfinite(trainer.state.log_history[-1]["train_loss"]))

            # Check parameters have been updated
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                # Only check non-zero parameters
                if param.sum() != 0:
                    self.assertFalse(torch.equal(param, new_param))
                    # Verify new parameters are finite
                    self.assertTrue(torch.isfinite(new_param).all())

            # Verify model can still do forward pass after training
            dummy_batch = next(iter(trainer.get_train_dataloader()))
            model_inputs = {
                "input_ids": dummy_batch["prompt_input_ids"],
                "attention_mask": dummy_batch["prompt_attention_mask"],
            }
            with torch.no_grad():
                output = trainer.model(**model_inputs)
            self.assertIsNotNone(output)
            self.assertIsNone(output.loss)


@require_vision
class DPOVisionTrainerTester(unittest.TestCase):
17 changes: 17 additions & 0 deletions trl/trainer/dpo_config.py
@@ -120,6 +120,11 @@ class DPOConfig(TrainingArguments):
- `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
- `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.

        use_liger_loss (`bool`, *optional*, defaults to `False`):
            Whether to use Liger loss.
        base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
            Name of the attribute in the model that contains the base model. This is used to retrieve the base model
            when `use_liger_loss` is `True` and the model does not have a `get_decoder` method.
        beta (`float`, *optional*, defaults to `0.1`):
            Parameter controlling the deviation from the reference model. Higher β means less deviation from the
            reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
@@ -303,6 +308,18 @@ class DPOConfig(TrainingArguments):
            ],
        },
    )
    use_liger_loss: bool = field(
        default=False,
        metadata={"help": "Whether to use Liger loss."},
    )
    base_model_attribute_name: str = field(
        default="model",
        metadata={
            "help": "Name of the attribute in the model that contains the base model. This is used to retrieve the "
            "base model when `use_liger_loss` is `True` and the model does not have a `get_decoder` method."
        },
    )
    beta: float = field(
        default=0.1,
        metadata={
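
For illustration, the base-model lookup that `base_model_attribute_name` documents could look roughly like the sketch below. This assumes a plain `getattr` fallback and is not necessarily the PR's actual implementation:

```python
def get_base_model(model, args):
    """Return the base model used for the Liger DPO loss (hypothetical helper)."""
    if hasattr(model, "get_decoder"):
        return model.get_decoder()
    # Fall back to the attribute configured via DPOConfig.base_model_attribute_name
    return getattr(model, args.base_model_attribute_name)
```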