Adding DPO training #7

Open
Goekdeniz-Guelmez wants to merge 24 commits into main from adding-dpo-training
Commits (24)
85f6b5d
udpate lora_config.yaml
Goekdeniz-Guelmez Mar 14, 2025
e22beba
update datasets.py
Goekdeniz-Guelmez Mar 14, 2025
db5fc85
update lora.py
Goekdeniz-Guelmez Mar 14, 2025
3d75414
update lora_config.yaml + add dpo_trainer.py
Goekdeniz-Guelmez Mar 14, 2025
03eb6bf
update LORA.md
Goekdeniz-Guelmez Mar 14, 2025
2b91010
fix typo
Goekdeniz-Guelmez Mar 14, 2025
4c2c2f7
fix lora.py
Goekdeniz-Guelmez Mar 14, 2025
4623aa2
update acknowledgements.md + nits
Goekdeniz-Guelmez Mar 14, 2025
72dd971
Merge branch 'ml-explore:main' into adding-dpo-training
Goekdeniz-Guelmez Mar 17, 2025
3a18487
Merge branch 'ml-explore:main' into adding-dpo-training
Goekdeniz-Guelmez Mar 18, 2025
dd2eb66
Merge branch 'main' into adding-dpo-training
Goekdeniz-Guelmez Mar 18, 2025
a73ecc6
formatting
Goekdeniz-Guelmez Mar 18, 2025
26c45b4
Merge branch 'main' into adding-dpo-training
Goekdeniz-Guelmez Mar 19, 2025
d277bfa
makiing key names customizable
Goekdeniz-Guelmez Mar 19, 2025
b41b8f9
Merge branch 'ml-explore:main' into adding-dpo-training
Goekdeniz-Guelmez Mar 24, 2025
e4d993a
Merge branch 'ml-explore:main' into adding-dpo-training
Goekdeniz-Guelmez Mar 25, 2025
79f0d2e
nits
Goekdeniz-Guelmez Mar 25, 2025
6d437b7
nits
Goekdeniz-Guelmez Mar 25, 2025
652bcbe
Merge branch 'ml-explore:main' into adding-dpo-training
Goekdeniz-Guelmez Mar 27, 2025
1af2e9a
Merge branch 'main' into adding-dpo-training
Goekdeniz-Guelmez Mar 27, 2025
97fc155
fix
Goekdeniz-Guelmez Mar 27, 2025
4281a8d
Merge branch 'main' into adding-dpo-training
Goekdeniz-Guelmez Mar 31, 2025
1305482
Merge branch 'ml-explore:main' into adding-dpo-training
Goekdeniz-Guelmez Apr 17, 2025
70e3eea
Merge branch 'main' into adding-dpo-training
Goekdeniz-Guelmez Apr 21, 2025
2 changes: 1 addition & 1 deletion ACKNOWLEDGMENTS.md
@@ -9,4 +9,4 @@ MLX LM was developed with contributions from the following individuals:

- Shunta Saito: Added support for PLaMo models.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Gökdeniz Gülmez: Added support for the following architectures: OpenBMB's `MiniCPM` and `MiniCPM3`, Kyutai's `Helium`, State-Space's`Mamba v1`, Z.ai & THUKEG's `GLM4`, and Allenai's `OLMoE`; Added support for the following training algorithms: `full-fine-tuning`; Added support for the following other features: `Multiple Optimizers to choose for training`.
- Gökdeniz Gülmez: Added support for the following architectures: OpenBMB's `MiniCPM` and `MiniCPM3`, Kyutai's `Helium`, State-Space's`Mamba v1`, Z.ai & THUKEG's `GLM4`, and Allenai's `OLMoE`; Added support for the following training algorithms: `full-fine-tuning`, and `Direct Preference Optimization (DPO)`; Added support for the following other features: `Multiple Optimizers to choose for training`.
33 changes: 33 additions & 0 deletions mlx_lm/LORA.md
@@ -18,6 +18,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:

- [Run](#Run)
- [Fine-tune](#Fine-tune)
- [DPO-Training](#DPO-Training)
- [Evaluate](#Evaluate)
- [Generate](#Generate)
- [Fuse](#Fuse)
@@ -84,6 +85,38 @@ ignore the prompt and compute loss for just the completion by passing
datasets. For `chat` datasets the final message in the message list is
considered the completion. See the [dataset section](#Data) for more details.

### DPO Training

Direct Preference Optimization (DPO) fine-tunes a model on human preference data: for each prompt, a preferred (chosen) and a less preferred (rejected) response. To use DPO training, set the training mode to `dpo`:

```shell
mlx_lm.lora \
--model <path_to_model> \
--train \
--training-mode dpo \
--data <path_to_data> \
--beta 0.1
```

DPO training accepts the following additional parameters; a sketch of how they enter the loss follows the list:

- `--beta`: Controls the strength of the DPO loss (default: 0.1)
- `--dpo-loss-type`: Choose between "sigmoid" (default), "hinge", "ipo", or "dpop" loss functions
- `--delta`: Margin parameter for hinge loss (default: 50.0)
- `--reference-model-path`: Path to a reference model for DPO training
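
The sketch below illustrates how `--beta`, `--delta`, and `--dpo-loss-type` might enter the preference loss. It is written against public descriptions of DPO, IPO, and DPOP rather than the exact code in `dpo_trainer.py`, and it follows the CLI help in treating `--delta` as a DPOP penalty weight. The inputs are the summed log-probabilities of the chosen and rejected completions under the policy and the frozen reference model.

```python
import mlx.core as mx

def dpo_style_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected,
                   beta=0.1, delta=50.0, loss_type="sigmoid"):
    # Preference margin: how much more the policy favors the chosen response
    # over the rejected one, relative to the reference model.
    logits = (pi_chosen - ref_chosen) - (pi_rejected - ref_rejected)
    if loss_type == "sigmoid":
        # Standard DPO: -log(sigmoid(beta * logits)), written stably.
        return mx.logaddexp(0.0, -beta * logits)
    if loss_type == "hinge":
        return mx.maximum(0.0, 1.0 - beta * logits)
    if loss_type == "ipo":
        # IPO regresses the margin toward 1 / (2 * beta).
        return (logits - 1.0 / (2.0 * beta)) ** 2
    if loss_type == "dpop":
        # DPOP adds a penalty, weighted by delta, when the chosen response's
        # log-probability drops below the reference model's.
        penalty = mx.maximum(0.0, ref_chosen - pi_chosen)
        return mx.logaddexp(0.0, -(beta * logits - delta * penalty))
    raise ValueError(f"Unknown loss type: {loss_type}")
```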

For DPO training, the data should be in JSONL format with the following structure:

```jsonl
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
```

If the prompt template accepts a system message, you can extend the dataset with an additional `"system"` field:

```jsonl
{"system": "You are a helpful assistant", "prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
```
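
For reference, here is a minimal sketch of how one preference record could be turned into the token sequences a DPO loss compares. It assumes an HF-style tokenizer exposing `apply_chat_template`; the field names follow the defaults above, and the PR's actual dataset code in `datasets.py` may build the sequences differently.

```python
import json

def encode_preference_record(record: dict, tokenizer) -> dict:
    # Shared conversation prefix: optional system message plus the user prompt.
    messages = []
    if "system" in record:
        messages.append({"role": "system", "content": record["system"]})
    messages.append({"role": "user", "content": record["prompt"]})

    def encode(completion: str) -> list[int]:
        full = messages + [{"role": "assistant", "content": completion}]
        return tokenizer.apply_chat_template(full, tokenize=True)

    # The loss compares the two continuations of the same prompt.
    return {"chosen": encode(record["chosen"]), "rejected": encode(record["rejected"])}

# Usage with one JSONL line:
# record = json.loads('{"prompt": "...", "chosen": "...", "rejected": "..."}')
# tokens = encode_preference_record(record, tokenizer)
```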

### Evaluate

To compute test set perplexity use:
18 changes: 17 additions & 1 deletion mlx_lm/examples/lora_config.yaml
@@ -7,6 +7,18 @@ train: true
# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora

# The training mode: "normal" or "dpo".
training_mode: normal

# If you set training_mode to "dpo"
# beta: 0.1
# The dpo-loss-type: "sigmoid", "hinge", "ipo", or "dpop"
# dpo_loss_type: "sigmoid"
# is_reference_free: False
# delta: 50.0
# If reference_model_path is not given, the same model is used as the reference
# reference_model_path: "mlx_model"

# The Optimizer with its possible inputs
optimizer: adamw
# optimizer_config:
@@ -86,4 +98,8 @@ lora_parameters:
# valid_split: "train[-100:]"
# prompt_feature: "text"
# completion_feature: "summary"

# for DPO training
# prompt_feature: "text"
# system_feature: "system"
# chosen_feature: "chosen"
# rejected_feature: "rejected"
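
As the commented keys above suggest, DPO datasets whose columns use different names can be remapped. The sketch below shows what such remapping amounts to; the exact handling lives in the PR's `datasets.py` and may differ.

```python
def remap_dpo_record(record, prompt_feature="prompt", system_feature="system",
                     chosen_feature="chosen", rejected_feature="rejected"):
    # Translate a dataset row with custom column names into the canonical
    # prompt / system / chosen / rejected fields used for DPO training.
    out = {
        "prompt": record[prompt_feature],
        "chosen": record[chosen_feature],
        "rejected": record[rejected_feature],
    }
    if system_feature in record:
        out["system"] = record[system_feature]
    return out
```
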
163 changes: 131 additions & 32 deletions mlx_lm/lora.py
@@ -15,6 +15,7 @@

from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.dpo_trainer import DPOTrainingArgs, evaluate_dpo, train_dpo
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
build_schedule,
@@ -44,6 +45,7 @@
"model": "mlx_model",
"train": False,
"fine_tune_type": "lora",
"training_mode": "normal",
"optimizer": "adam",
"optimizer_config": {
"adam": {},
@@ -69,6 +71,11 @@
"lr_schedule": None,
"lora_parameters": {"rank": 8, "dropout": 0.0, "scale": 10.0},
"mask_prompt": False,
# DPO args
"beta": 0.1,
"dpo_loss_type": "sigmoid",
"delta": 50.0,
"reference_model_path": None,
}


@@ -101,6 +108,12 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--training-mode",
type=str,
choices=["normal", "dpo"],
help="Training mode: normal or DPO",
)
parser.add_argument(
"--optimizer",
type=str,
@@ -181,6 +194,30 @@ def build_parser():
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")

# DPO args
parser.add_argument(
"--beta",
type=float,
help="Temperature parameter for DPO training.",
default=0.1,
)
parser.add_argument(
"--dpo-loss-type",
type=str,
help="DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'.",
choices=["sigmoid", "hinge", "ipo", "dpop"],
default="sigmoid",
)
parser.add_argument(
"--delta", type=float, help="Delta parameter for DPOP loss type.", default=50.0
)
parser.add_argument(
"--reference-model-path",
type=str,
help="Path to reference model weights. If None, uses the same model.",
default=None,
)
return parser


@@ -227,18 +264,7 @@ def train_model(
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")

# init training args
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
)
model.train()

# Initialize the selected optimizer
lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
@@ -255,31 +281,104 @@

opt = opt_class(learning_rate=lr, **optimizer_config)

# Train model
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)
if args.training_mode == "dpo":
training_args = DPOTrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
loss_type=args.dpo_loss_type,
delta=args.delta,
reference_model_path=args.reference_model_path,
)

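# Use a separate frozen reference model when a path is given; otherwise load a
# second copy of the base model to serve as the frozen reference.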
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model, _ = load(args.model)

train_dpo(
model=model,
ref_model=reference_model.freeze(),
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
)
else:
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
)

# Train model
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)


def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)
model.eval()

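# DPO evaluation also needs a reference model; without a separate path, the
# policy model itself is reused (frozen) as the reference.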
if args.training_mode == "dpo":
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model = model

test_loss, _, _, test_metrics = evaluate_dpo(
model=model,
ref_model=reference_model.freeze(),
dataset=test_set,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
delta=args.delta,
loss_type=args.dpo_loss_type,
)

test_ppl = math.exp(test_loss)

print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}")
print("DPO Test Metrics:")
for metric_name, metric_value in test_metrics.items():
print(f" {metric_name}: {float(metric_value):.3f}")

else:
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)

test_ppl = math.exp(test_loss)
test_ppl = math.exp(test_loss)

print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def run(args, training_callback: TrainingCallback = None):