Skip to content

Commit 1df8ad4

Browse files
committed
feat: add Direct Preference Optimization (DPO) training support
Add DPO as a new alignment method alongside the existing SFT pipeline. DPO enables fine-tuning LLMs on human preference pairs (chosen vs. rejected responses) without requiring a separate reward model. New components: - PreferenceDataset: dataset class for prompt/chosen/rejected triplets - PreferenceDataCollator: collator that tokenizes and pairs the chosen/rejected sequences - DPOTrainer: PyTorch Lightning trainer with the DPO loss and a frozen reference model - CausalModel.dpo_finetune(): user-facing API for DPO training. All components integrate via the existing registry pattern and work with every model variant (full, LoRA, INT8, LoRA+INT8, LoRA+Kbit). Includes tests and an example script. README.md is intentionally left unchanged by this commit.
1 parent d6fe1de commit 1df8ad4

File tree

12 files changed

+961
-0
lines changed

12 files changed

+961
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Minimal example showing how to align a model using Direct Preference
2+
Optimization (DPO) with xTuring.
3+
4+
DPO fine-tunes a language model using pairs of preferred and dispreferred
5+
responses so that the model learns to produce outputs that match human
6+
preferences without requiring a separate reward model.
7+
"""
8+
9+
from pathlib import Path
10+
11+
from xturing.datasets.preference_dataset import PreferenceDataset
12+
from xturing.models import BaseModel
13+
14+
OUTPUT_DIR = Path(__file__).parent / "dpo_weights"
15+
16+
17+
def main():
18+
# Build a small preference dataset. Each sample needs a prompt, a chosen
19+
# (preferred) response, and a rejected (dispreferred) response.
20+
preference_data = {
21+
"prompt": [
22+
"Explain quantum computing in simple terms.",
23+
"What is the capital of France?",
24+
"How do I make pasta?",
25+
"What causes rain?",
26+
],
27+
"chosen": [
28+
"Quantum computing uses qubits that can be 0, 1, or both at once, "
29+
"letting it solve certain problems much faster than regular computers.",
30+
"The capital of France is Paris.",
31+
"Boil salted water, cook pasta until al dente, then drain and toss "
32+
"with your favorite sauce.",
33+
"Rain forms when water evaporates, rises, cools into clouds, and "
34+
"falls back as droplets when clouds become saturated.",
35+
],
36+
"rejected": [
37+
"Quantum computing is basically magic computers that can do "
38+
"everything instantly.",
39+
"France doesn't have a capital, it's a collective.",
40+
"Just put some noodles in a microwave with ketchup.",
41+
"Rain happens because the sky is sad.",
42+
],
43+
}
44+
45+
dataset = PreferenceDataset(preference_data)
46+
47+
# Initialise a model with a LoRA adapter. DPO works with any model
48+
# variant, but LoRA is recommended to keep memory usage low since DPO
49+
# requires a frozen reference model in addition to the policy model.
50+
model = BaseModel.create("qwen3_0_6b_lora")
51+
52+
# Run DPO fine-tuning. The beta parameter controls how strongly the model
53+
# is penalised for deviating from the reference policy (higher = more
54+
# conservative).
55+
model.dpo_finetune(dataset=dataset, beta=0.1)
56+
57+
# Verify the aligned model generates reasonable output.
58+
output = model.generate(texts=["Explain gravity in simple terms."])
59+
print(f"Generated output: {output}")
60+
61+
# Save the fine-tuned adapter weights.
62+
model.save(str(OUTPUT_DIR))
63+
print(f"Saved DPO fine-tuned weights to {OUTPUT_DIR}")
64+
65+
66+
if __name__ == "__main__":
67+
main()

src/xturing/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
InstructionDataset,
44
InstructionDatasetMeta,
55
)
6+
from xturing.datasets.preference_dataset import PreferenceDataset, PreferenceDatasetMeta
67
from xturing.datasets.text2image_dataset import Text2ImageDataset
78
from xturing.datasets.text_dataset import TextDataset, TextDatasetMeta
89

910
BaseDataset.add_to_registry(TextDataset.config_name, TextDataset)
1011
BaseDataset.add_to_registry(InstructionDataset.config_name, InstructionDataset)
1112
BaseDataset.add_to_registry(Text2ImageDataset.config_name, Text2ImageDataset)
13+
BaseDataset.add_to_registry(PreferenceDataset.config_name, PreferenceDataset)
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import json
2+
from dataclasses import dataclass
3+
from pathlib import Path
4+
from typing import Union
5+
6+
from datasets import Dataset as HFDataset
7+
from datasets import DatasetDict, load_from_disk
8+
9+
from xturing.datasets.base import BaseDataset
10+
11+
12+
@dataclass
class PreferenceDatasetMeta:
    """Metadata for preference datasets used in DPO training.

    Currently an empty placeholder; it exists so preference datasets follow
    the same dataset/meta pairing used by the other dataset types (e.g.
    ``TextDatasetMeta``, ``InstructionDatasetMeta``).
    """
17+
class PreferenceDataset(BaseDataset):
    """Dataset for Direct Preference Optimization (DPO) training.

    Each sample contains a prompt, a chosen (preferred) response, and a
    rejected (dispreferred) response. The dataset must have exactly three
    columns: ``prompt``, ``chosen``, and ``rejected``.

    Args:
        path: A local directory saved with ``datasets.save_to_disk``, a path
            to a ``.jsonl`` file, a HuggingFace ``Dataset``/``DatasetDict``,
            or a plain dictionary with the required keys.

    Raises:
        ValueError: If the path does not exist, the file format is
            unsupported, a jsonl record is missing a required key, or the
            loaded data fails validation.
    """

    config_name: str = "preference_dataset"

    def __init__(self, path: Union[str, Path, HFDataset, DatasetDict, dict]):
        if isinstance(path, HFDataset):
            # Wrap a bare Dataset in a "train" split: _validate (and every
            # accessor below) indexes self.data["train"], which a split-less
            # Dataset cannot satisfy.
            self.data = {"train": path}
        elif isinstance(path, DatasetDict):
            self.data = path
        elif isinstance(path, dict):
            self.data = {"train": HFDataset.from_dict(path)}
        else:
            path = Path(path)
            # Explicit raise instead of assert: asserts are stripped under
            # python -O, and a missing path is a user error, not a bug.
            if not path.exists():
                raise ValueError(f"path does not exist: {path}")
            if path.is_dir():
                self.data = load_from_disk(str(path))
            elif path.suffix == ".jsonl":
                self.data = {"train": HFDataset.from_dict(self._from_jsonl(path))}
            else:
                raise ValueError(
                    f"Unsupported file format: {path.suffix}. Use a directory or .jsonl file."
                )

        self._validate()
        # NOTE(review): consumers read ``dataset.meta`` (see CausalModel);
        # presumably BaseDataset exposes a ``meta`` property backed by
        # ``_meta`` — confirm against xturing.datasets.base.
        self._meta = PreferenceDatasetMeta()

    def _from_jsonl(self, path: Path):
        """Load prompt/chosen/rejected triplets from a JSON-lines file.

        Returns a dict of three parallel lists suitable for
        ``HFDataset.from_dict``.
        """
        data = {
            "prompt": [],
            "chosen": [],
            "rejected": [],
        }
        # Context manager guarantees the file handle is closed even if a
        # record is malformed (the original open() leaked it on error).
        with open(path, encoding="utf-8") as f:
            for line in f:
                try:
                    json_line = json.loads(line)
                    data["prompt"].append(json_line["prompt"])
                    data["chosen"].append(json_line["chosen"])
                    data["rejected"].append(json_line["rejected"])
                except KeyError as err:
                    # Chain the cause so the offending key is visible.
                    raise ValueError(
                        "The jsonl file should have keys: prompt, chosen, and rejected"
                    ) from err
        return data

    def _validate(self):
        """Check the loaded data has a train split with exactly the three
        required columns; raise ValueError otherwise (survives python -O)."""
        if "train" not in self.data:
            raise ValueError("The dataset should have a train split")
        columns = self.data["train"].column_names
        for required in ("prompt", "chosen", "rejected"):
            if required not in columns:
                raise ValueError(
                    f"The dataset should have a column named {required}"
                )
        if len(columns) != 3:
            raise ValueError(
                "The dataset should have only three columns: prompt, chosen, and rejected"
            )

    def __len__(self):
        """Number of preference samples in the train split."""
        return len(self.data["train"])

    def __iter__(self):
        """Iterate over samples of the train split."""
        return iter(self.data["train"])

    def __getitem__(self, idx):
        """Return the sample (or slice) at ``idx`` from the train split."""
        return self.data["train"][idx]

    def save(self, path):
        """Save the train split to disk via ``datasets.save_to_disk``."""
        return self.data["train"].save_to_disk(path)

src/xturing/models/causal.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
from xturing.config.config_data_classes import FinetuningConfig, GenerationConfig
1313
from xturing.config.read_config import load_config
1414
from xturing.datasets.instruction_dataset import InstructionDataset
15+
from xturing.datasets.preference_dataset import PreferenceDataset
1516
from xturing.datasets.text_dataset import TextDataset
1617
from xturing.engines.base import BaseEngine
1718
from xturing.models import BaseModel
1819
from xturing.preprocessors.base import BasePreprocessor
1920
from xturing.trainers.base import BaseTrainer
21+
from xturing.trainers.dpo_trainer import DPOTrainer
2022
from xturing.trainers.lightning_trainer import LightningTrainer
2123
from xturing.utils.logging import configure_logger
2224
from xturing.utils.prompt import OpenAICreateChatPrompt, OpenAICreatePrompt, Prompt
@@ -118,6 +120,54 @@ def finetune(
118120
trainer = self._make_trainer(dataset, logger)
119121
trainer.fit()
120122

123+
def _make_dpo_collate_fn(self, dataset: PreferenceDataset):
    """Build the collator registered for this preference dataset type.

    Looks up the preprocessor by the dataset's config name and wires in the
    engine's tokenizer, the configured max sequence length, and the
    dataset's metadata.
    """
    max_length = int(self.finetuning_args.max_length)
    return BasePreprocessor.create(
        dataset.config_name,
        self.engine.tokenizer,
        max_length,
        dataset.meta,
    )
130+
131+
def _make_dpo_trainer(
    self,
    dataset: PreferenceDataset,
    beta: float = 0.1,
    logger: Union[Logger, Iterable[Logger], bool] = True,
):
    """Instantiate the registered DPO trainer for this engine and dataset.

    Training hyperparameters (epochs, batch size, learning rate, optimizer)
    are taken from ``self.finetuning_args``; ``beta`` is the DPO temperature
    forwarded to the trainer.
    """
    args = self.finetuning_args
    collate_fn = self._make_dpo_collate_fn(dataset)
    return BaseTrainer.create(
        DPOTrainer.config_name,
        self.engine,
        dataset,
        collate_fn,
        int(args.num_train_epochs),
        int(args.batch_size),
        float(args.learning_rate),
        args.optimizer_name,
        beta,
        logger=logger,
    )
149+
150+
def dpo_finetune(
    self,
    dataset: PreferenceDataset,
    beta: float = 0.1,
    logger: Union[Logger, Iterable[Logger], bool] = True,
):
    """Fine-tune the model using Direct Preference Optimization (DPO).

    Args:
        dataset: A :class:`PreferenceDataset` containing prompt, chosen,
            and rejected columns.
        beta: Temperature parameter for DPO. Higher values keep the model
            closer to the reference policy.
        logger: PyTorch Lightning logger(s) for tracking training metrics.

    Raises:
        ValueError: If ``dataset`` is not a preference dataset.
    """
    # Explicit raise instead of assert: asserts are stripped under
    # python -O, which would silently skip this user-input validation.
    if dataset.config_name != "preference_dataset":
        raise ValueError("Please provide a PreferenceDataset for DPO training")
    trainer = self._make_dpo_trainer(dataset, beta, logger)
    trainer.fit()
170+
121171
def _generate_from_iterable(
122172
self, data_iterator: Iterable, do_tokenization=False, show_tqdm_bar=True
123173
):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from xturing.preprocessors.base import BasePreprocessor
22
from xturing.preprocessors.instruction_collator import InstructionDataCollator
3+
from xturing.preprocessors.preference_collator import PreferenceDataCollator
34
from xturing.preprocessors.text_collator import TextDataCollator

src/xturing/preprocessors/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from xturing.preprocessors.instruction_collator import InstructionDataCollator
2+
from xturing.preprocessors.preference_collator import PreferenceDataCollator
23
from xturing.preprocessors.text_collator import TextDataCollator
34
from xturing.registry import BaseParent
45

@@ -11,3 +12,6 @@ class BasePreprocessor(BaseParent):
1112
InstructionDataCollator.config_name, InstructionDataCollator
1213
)
1314
BasePreprocessor.add_to_registry(TextDataCollator.config_name, TextDataCollator)
15+
BasePreprocessor.add_to_registry(
16+
PreferenceDataCollator.config_name, PreferenceDataCollator
17+
)
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from typing import Dict, List, Optional
2+
3+
import torch
4+
import torch.nn.functional as F
5+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
6+
7+
from xturing.datasets.preference_dataset import PreferenceDatasetMeta
8+
9+
10+
class PreferenceDataCollator:
    """Collator for preference datasets used in DPO training.

    For each sample, this collator tokenizes two sequences:
    - ``prompt + chosen`` (the preferred completion)
    - ``prompt + rejected`` (the dispreferred completion)

    The resulting batch contains ``chosen_input_ids``, ``chosen_attention_mask``,
    ``chosen_labels``, and the corresponding ``rejected_*`` tensors. Labels are
    masked so that the loss is only computed over the response tokens (not the
    prompt).
    """

    config_name = "preference_dataset"

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizerBase",
        max_length: Optional[int] = None,
        meta: Optional["PreferenceDatasetMeta"] = None,
    ):
        self.tokenizer = tokenizer
        # May be None: _tokenize_pair only truncates when a limit is set.
        self.max_length = max_length
        # Default to None rather than a shared PreferenceDatasetMeta()
        # instance (mutable-default pitfall); build a fresh one per collator.
        self.meta = meta if meta is not None else PreferenceDatasetMeta()

    def _tokenize_pair(self, prompt: str, response: str):
        """Tokenize a prompt-response pair.

        Returns ``input_ids``/``attention_mask`` tensors plus a boolean
        ``label_mask`` list that is True only on response (and eos) positions.
        """
        prompt_tokens = self.tokenizer(prompt)
        response_tokens = self.tokenizer(response)

        input_ids = prompt_tokens["input_ids"] + response_tokens["input_ids"]
        # True marks response tokens (trainable); False marks prompt tokens,
        # which later become -100 labels so the loss ignores them.
        label_mask = [False] * len(prompt_tokens["input_ids"]) + [True] * len(
            response_tokens["input_ids"]
        )

        # max_length is Optional: only truncate when a limit is configured.
        # (The original sliced unconditionally, raising TypeError on
        # ``None - 1`` when no max_length was given.)
        if self.max_length is not None:
            # Keep max_length - 1 tokens to leave room for the eos token.
            input_ids = input_ids[: self.max_length - 1]
            label_mask = label_mask[: self.max_length - 1]

        # Always terminate with eos; it counts as part of the response.
        input_ids.append(self.tokenizer.eos_token_id)
        label_mask.append(True)
        attention_mask = [1] * len(input_ids)

        return {
            "input_ids": torch.tensor(input_ids).long(),
            "attention_mask": torch.tensor(attention_mask).long(),
            "label_mask": label_mask,
        }

    def _pad_and_stack(self, samples: List[Dict]):
        """Pad a list of tokenized samples and stack into batch tensors."""
        padded = self.tokenizer.pad(
            [
                {"input_ids": s["input_ids"], "attention_mask": s["attention_mask"]}
                for s in samples
            ],
            padding=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Right-pad each label mask with False so it matches the padded width.
        dim = padded["input_ids"].shape[-1]
        label_masks = torch.stack(
            [
                F.pad(
                    torch.tensor(s["label_mask"]),
                    (0, dim - len(s["label_mask"])),
                    value=False,
                )
                for s in samples
            ]
        )

        # Labels are a copy of input_ids with prompt/padding positions set to
        # -100 (ignored by the loss). No token shift happens here; any
        # shifting is the model's/trainer's responsibility.
        labels = padded["input_ids"].clone()
        labels[~label_masks] = -100

        return {
            "input_ids": padded["input_ids"],
            "attention_mask": padded["attention_mask"],
            "labels": labels,
        }

    def __call__(self, batches: List[Dict]):
        """Collate raw preference samples into paired chosen/rejected tensors.

        Args:
            batches: List of dicts, each with ``prompt``, ``chosen``, and
                ``rejected`` strings.

        Returns:
            Dict with ``chosen_*`` and ``rejected_*`` input_ids,
            attention_mask, and labels tensors.
        """
        chosen_samples = []
        rejected_samples = []

        for sample in batches:
            chosen_samples.append(
                self._tokenize_pair(sample["prompt"], sample["chosen"])
            )
            rejected_samples.append(
                self._tokenize_pair(sample["prompt"], sample["rejected"])
            )

        chosen_batch = self._pad_and_stack(chosen_samples)
        rejected_batch = self._pad_and_stack(rejected_samples)

        return {
            "chosen_input_ids": chosen_batch["input_ids"],
            "chosen_attention_mask": chosen_batch["attention_mask"],
            "chosen_labels": chosen_batch["labels"],
            "rejected_input_ids": rejected_batch["input_ids"],
            "rejected_attention_mask": rejected_batch["attention_mask"],
            "rejected_labels": rejected_batch["labels"],
        }

src/xturing/trainers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from xturing.trainers.base import BaseTrainer
2+
from xturing.trainers.dpo_trainer import DPOTrainer
23
from xturing.trainers.lightning_trainer import LightningTrainer

src/xturing/trainers/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from xturing.registry import BaseParent
2+
from xturing.trainers.dpo_trainer import DPOTrainer
23
from xturing.trainers.lightning_trainer import LightningTrainer
34

45

@@ -7,3 +8,4 @@ class BaseTrainer(BaseParent):
78

89

910
BaseTrainer.add_to_registry(LightningTrainer.config_name, LightningTrainer)
11+
BaseTrainer.add_to_registry(DPOTrainer.config_name, DPOTrainer)

0 commit comments

Comments
 (0)