forked from huggingface/trl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdpo_trainer.py
More file actions
2016 lines (1761 loc) · 101 KB
/
dpo_trainer.py
File metadata and controls
2016 lines (1761 loc) · 101 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import random
import textwrap
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import PartialState, logging
from accelerate.utils import tqdm
from datasets import Dataset, IterableDataset
from torch import autocast
from torch.utils.data import DataLoader
from transformers import (
AutoProcessor,
BaseImageProcessor,
DataCollator,
FeatureExtractionMixin,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
TrainerCallback,
)
from transformers.data.data_collator import DataCollatorMixin
from transformers.integrations import (
is_comet_available,
is_mlflow_available,
is_wandb_available,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_liger_kernel_available, is_peft_available
from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_extract_prompt
from ..models import create_reference_model, prepare_deepspeed
from ..models.utils import peft_module_casting_to_bf16, prepare_fsdp
from .base_trainer import BaseTrainer
from .callbacks import SyncRefModelCallback
from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
from .utils import (
RunningMoments,
cap_exp,
create_model_from_path,
disable_dropout_in_model,
empty_cache,
flush_left,
flush_right,
get_config_model_id,
log_table_to_comet_experiment,
pad,
pad_to_length,
selective_log_softmax,
)
if is_peft_available():
from peft import (
PeftConfig,
PeftModel,
get_peft_model,
prepare_model_for_kbit_training,
)
if is_liger_kernel_available():
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
if is_wandb_available():
import wandb
if is_mlflow_available():
import mlflow
logger = logging.get_logger(__name__)
def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
    """
    Shift every row of `input_ids` one position to the right.

    The first column of the result is filled with `decoder_start_token_id` and the last token of each row is dropped,
    so the output has the same shape and dtype as the input. The input tensor is not modified.

    Args:
        input_ids (`torch.Tensor`):
            Token IDs, indexed as `[row, position]`.
        decoder_start_token_id (`int`):
            Token ID written at position 0 of every row.

    Returns:
        `torch.Tensor`: A new tensor holding the shifted token IDs.
    """
    shifted = torch.zeros_like(input_ids)
    # Column 0 receives the start token; the remaining columns are the input moved right by one.
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = input_ids[:, :-1]
    return shifted
@dataclass
class DataCollatorForPreference(DataCollatorMixin):
    """
    Collator for tokenized preference examples. Each batch is padded dynamically to the longest sequence it contains.

    Every example must provide `"prompt_input_ids"`, `"chosen_input_ids"` and `"rejected_input_ids"`. The optional keys
    `"pixel_values"`, `"pixel_attention_mask"`, `"image_sizes"`, `"token_type_ids"` and the pair `"ref_chosen_logps"` /
    `"ref_rejected_logps"` are collated as well when they are present in the first example of the batch.

    Args:
        pad_token_id (`int`):
            Token ID to use for padding.
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            Type of Tensor to return. Only `"pt"` is currently supported.

    Examples:
    ```python
    >>> from trl import DataCollatorForPreference

    >>> collator = DataCollatorForPreference(pad_token_id=0)
    >>> examples = [
    ...     {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]},
    ...     {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]},
    ... ]
    >>> collator(examples)
    {'prompt_input_ids': tensor([[1, 2, 3],
                                 [0, 7, 8]]),
     'prompt_attention_mask': tensor([[1, 1, 1],
                                      [0, 1, 1]]),
     'chosen_input_ids': tensor([[ 4,  5],
                                 [ 9, 10]]),
     'chosen_attention_mask': tensor([[1, 1],
                                      [1, 1]]),
     'rejected_input_ids': tensor([[ 6,  0,  0],
                                   [11, 12, 13]]),
     'rejected_attention_mask': tensor([[1, 0, 0],
                                        [1, 1, 1]])
    }
    ```
    """

    pad_token_id: int
    return_tensors: str = "pt"

    def torch_call(self, examples: list[list[int] | Any | dict[str, Any]]) -> dict[str, Any]:
        """Convert a list of tokenized examples into a dictionary of padded tensors."""
        segments = ("prompt", "chosen", "rejected")

        # The optional keys are detected on the first example only.
        first = examples[0]
        has_pixel_values = "pixel_values" in first
        has_pixel_attention_mask = "pixel_attention_mask" in first
        has_ref_logps = "ref_chosen_logps" in first and "ref_rejected_logps" in first

        # Step 1: turn the per-example lists into tensors; attention masks start as all ones.
        ids_by_segment = {}
        mask_by_segment = {}
        for segment in segments:
            ids_by_segment[segment] = [torch.tensor(example[f"{segment}_input_ids"]) for example in examples]
            mask_by_segment[segment] = [torch.ones_like(ids) for ids in ids_by_segment[segment]]
        if has_pixel_values:
            pixel_values = [torch.tensor(example["pixel_values"]) for example in examples]
        if has_pixel_attention_mask:
            pixel_attention_mask = [torch.tensor(example["pixel_attention_mask"]) for example in examples]
        if has_ref_logps:
            ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples])
            ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples])

        # Step 2: pad. Prompts are padded on the left; completions use the default side of `pad`.
        batch = {}
        for segment in segments:
            side = {"padding_side": "left"} if segment == "prompt" else {}
            batch[f"{segment}_input_ids"] = pad(ids_by_segment[segment], padding_value=self.pad_token_id, **side)
            batch[f"{segment}_attention_mask"] = pad(mask_by_segment[segment], padding_value=0, **side)
        if has_pixel_values:
            batch["pixel_values"] = pad(pixel_values, padding_value=0.0)
        if has_pixel_attention_mask:
            batch["pixel_attention_mask"] = pad(pixel_attention_mask, padding_value=0)
        if "image_sizes" in first:
            batch["image_sizes"] = torch.tensor([example["image_sizes"] for example in examples])
        if has_ref_logps:
            batch["ref_chosen_logps"] = ref_chosen_logps
            batch["ref_rejected_logps"] = ref_rejected_logps
        if "token_type_ids" in first:
            token_type_ids = [torch.tensor(example["token_type_ids"]) for example in examples]
            batch["token_type_ids"] = pad(token_type_ids, padding_value=0, padding_side="left")

        return batch
class DPOTrainer(BaseTrainer):
    """
    Trainer for Direct Preference Optimization (DPO) method.

    This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.

    Args:
        model (`str | PreTrainedModel`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
              path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
              using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
              `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
        ref_model (`str` or [`~transformers.PreTrainedModel`], *optional*):
            Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation
            and loss. A string is loaded with the keyword arguments in `args.ref_model_init_kwargs`. If no reference
            model is provided, the trainer creates one from the model to be optimized, unless the model is a PEFT
            model or `args.precompute_ref_log_probs` is set, in which case no separate reference model is kept.
        args ([`DPOConfig`], *optional*):
            Configuration for this trainer. If `None`, a default configuration is used.
        data_collator ([`~transformers.DataCollator`], *optional*):
            Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
            Will default to [`DataCollatorForPreference`].
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. DPO supports the [preference](#preference) dataset type. The format of the
            samples can be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
            Processing class used to process the data. If `None`, the processing class is loaded from the model's name
            with [`~transformers.AutoProcessor.from_pretrained`]. The resolved object must be a
            [`~transformers.PreTrainedTokenizerBase`] or a [`~transformers.ProcessorMixin`]; any other type raises a
            `TypeError`.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
            a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to
            `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered
            after the last eval batch to signal that the function needs to calculate and return the global summary
            statistics rather than accumulating the batch-level statistics.
        callbacks (list of [`~transformers.TrainerCallback`], *optional*):
            List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
            in [here](https://huggingface.co/docs/transformers/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
            method.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
            A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
            `args`. Incompatible with the `optimizers` argument.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocess the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.
        peft_config ([`~peft.PeftConfig`], *optional*):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
    """

    # Tags passed to `model.add_model_tags` at the end of `__init__` when the model supports it.
    _tag_names = ["trl", "dpo"]
    # Short method name and paper reference; presumably consumed by `BaseTrainer` (not visible here) - TODO confirm.
    _name = "DPO"
    _paper = {
        "title": "Direct Preference Optimization: Your Language Model is Secretly a Reward Model",
        "id": "2305.18290",
        # docstyle-ignore
        "citation": textwrap.dedent("""\
@inproceedings{rafailov2023direct,
title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
year = 2023,
booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
}"""),
    }
def __init__(
    self,
    model: str | nn.Module | PreTrainedModel,
    ref_model: PreTrainedModel | nn.Module | str | None = None,
    args: DPOConfig | None = None,
    data_collator: DataCollator | None = None,  # type: ignore
    train_dataset: Dataset | IterableDataset | None = None,
    eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
    processing_class: PreTrainedTokenizerBase
    | BaseImageProcessor
    | FeatureExtractionMixin
    | ProcessorMixin
    | None = None,
    compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
    callbacks: list[TrainerCallback] | None = None,
    optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
    optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
    preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
    peft_config: "PeftConfig | None" = None,
):
    """
    Build the trainer: resolve the policy and reference models, the processing class and pad token, apply the
    optional PEFT wrapping, preprocess the datasets, and finally initialize the parent trainer.

    The parameters are described in the class docstring.

    Raises:
        ValueError: If `model` and `ref_model` are the same object; if the pad token is not in the vocabulary; if
            `generate_during_eval` is set but none of wandb, mlflow or comet is available; if the Liger kernel is
            requested with an unsupported loss type; if the removed `kto_pair` loss is requested; if
            `precompute_ref_log_probs` is combined with DeepSpeed ZeRO-3 or with `sync_ref_model`; or if no
            reference model is available and the model is neither a PEFT model nor using precomputed reference
            log-probabilities.
        TypeError: If `processing_class` is neither a `PreTrainedTokenizerBase` nor a `ProcessorMixin`.
        ImportError: If `use_liger_kernel` is set but the Liger kernel is not installed.
        AttributeError: If the parent trainer did not create an `accelerator`.
    """
    # Args
    if args is None:
        model_name = model if isinstance(model, str) else get_config_model_id(model.config)
        model_name = model_name.split("/")[-1]
        args = DPOConfig(f"{model_name}-DPO")

    # IterableDataset requires dispatch_batches=False because Accelerate's dispatch mode may try to concatenate
    # batches from multiple processes, leading to mismatch errors.
    if isinstance(train_dataset, IterableDataset):
        if args.accelerator_config.dispatch_batches is True:
            logger.warning(
                "You are using an `IterableDataset` for training with `dispatch_batches=True`. `dispatch_batches` "
                "is forced to `False` when using an `IterableDataset`. To remove this warning, unset "
                "`dispatch_batches` in `DPOConfig` or set it to `False`."
            )
        args.accelerator_config.dispatch_batches = False

    # Model and reference model
    if isinstance(model, str):
        # Work on a copy so that forcing `device_map` below does not silently modify the user's config.
        model_init_kwargs = dict(args.model_init_kwargs or {})
        # Distributed training requires device_map=None ("auto" fails)
        if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
            model_init_kwargs["device_map"] = None
        model = create_model_from_path(model, **model_init_kwargs)
    else:
        if args.model_init_kwargs is not None:
            logger.warning(
                "You passed `model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. "
                "The `model_init_kwargs` will be ignored."
            )
    model_id = get_config_model_id(model.config)
    if isinstance(ref_model, str):
        # Same as above: copy before overriding `device_map`.
        model_init_kwargs = dict(args.ref_model_init_kwargs or {})
        # Distributed training requires device_map=None ("auto" fails)
        if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
            model_init_kwargs["device_map"] = None
        ref_model = create_model_from_path(ref_model, **model_init_kwargs)
    else:
        if args.ref_model_init_kwargs is not None:
            logger.warning(
                "You passed `ref_model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. "
                "The `ref_model_init_kwargs` will be ignored."
            )
    if ref_model is model:
        raise ValueError(
            "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
            "same as `model`, you can simply omit the `ref_model` argument and it will be created for you."
        )

    # Processing class
    if processing_class is None:
        processing_class = AutoProcessor.from_pretrained(model_id)

    # Handle pad token for processors or tokenizers
    if isinstance(processing_class, ProcessorMixin):
        tokenizer = processing_class.tokenizer
        self._is_vlm = True
    elif isinstance(processing_class, PreTrainedTokenizerBase):
        tokenizer = processing_class
        self._is_vlm = False
    else:
        raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")

    # Get the pad token: if not provided, use the one from the processing class or the eos token
    # if the processing class does not have a pad token.
    pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
    self.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
    if self.pad_token_id is None:
        raise ValueError(
            f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
            f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
            "in the vocabulary before using it as a padding token."
        )

    # PEFT configuration and model wrapping
    model = self._prepare_peft_model(model, ref_model, peft_config, args)

    if args.generate_during_eval and not (is_wandb_available() or is_comet_available() or is_mlflow_available()):
        raise ValueError(
            "`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed."
            " Please install `wandb`, `mlflow` or `comet-ml` to resolve."
        )

    self.is_encoder_decoder = model.config.is_encoder_decoder
    self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys()
    self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
    self.model_adapter_name = args.model_adapter_name
    self.ref_adapter_name = args.ref_adapter_name
    self.reference_free = args.reference_free

    # Compare against `None` explicitly: the truth value of a module is not a reliable "was it provided" signal
    # (a module class that defines `__len__` can be falsy).
    if ref_model is not None:
        self.ref_model = ref_model
    elif self.is_peft_model or args.precompute_ref_log_probs:
        # The `model` with adapters turned off will be used as the reference model
        self.ref_model = None
    else:
        self.ref_model = create_reference_model(model)

    # Disable dropout in the model and reference model
    if args.disable_dropout:
        disable_dropout_in_model(model)
        if self.ref_model is not None:
            disable_dropout_in_model(self.ref_model)

    # Liger kernel
    if args.use_liger_kernel:
        if not is_liger_kernel_available():
            raise ImportError(
                "You set `use_liger_kernel=True` but the liger kernel is not available. "
                "Please install liger-kernel first: `pip install liger-kernel`"
            )
        # NOTE(review): `args.loss_type` is compared as a single string here, while it is normalized to a list
        # further down. If `DPOConfig` can hold a list, this check rejects it - confirm against `DPOConfig`.
        if args.loss_type not in ["sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"]:
            raise ValueError(
                "You set `use_liger_kernel=True` but the loss type is not from `[sigmoid, apo_zero, apo_down, sppo_hard, nca_pair]`. "
                "Please set `loss_type='[sigmoid | apo_zero | apo_down | sppo_hard | nca_pair]'` to use the liger kernel."
            )
        self.dpo_loss_fn = LigerFusedLinearDPOLoss(
            ignore_index=args.label_pad_token_id,
            beta=args.beta,
            use_ref_model=not args.reference_free,
            average_log_prob=False,
            loss_type=args.loss_type,
        )

    # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
    # input tensor associated with the key "input_ids". However, in DPO, the sampled data does not include the
    # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
    # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
    # of the input, floating-point operations will not be computed." To suppress this warning, we set the
    # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
    # that the warning has already been issued.
    model.warnings_issued["estimate_tokens"] = True

    # Data collator
    if data_collator is None:
        data_collator = DataCollatorForPreference(pad_token_id=self.pad_token_id)

    self.generate_during_eval = args.generate_during_eval
    self.label_pad_token_id = args.label_pad_token_id
    self.max_prompt_length = args.max_prompt_length
    self.max_completion_length = args.max_completion_length
    self.max_length = args.max_length
    self.truncation_mode = args.truncation_mode
    self.precompute_ref_log_probs = args.precompute_ref_log_probs
    self.use_logits_to_keep = args.use_logits_to_keep

    if args.padding_free:
        if model.config._attn_implementation != "flash_attention_2":
            logger.warning(
                "Padding-free training is enabled, but the attention implementation is not set to "
                "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and "
                "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using "
                "other implementations may lead to unexpected behavior. To ensure compatibility, set "
                "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
                "attention mechanism can handle flattened sequences."
            )
        if args.per_device_train_batch_size == 1:
            logger.warning(
                "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size "
                "of 1 annihilate the benefits of padding-free training. Please consider increasing the batch size "
                "to at least 2."
            )
    self.padding_free = args.padding_free

    # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
    # keep track of first called to avoid computation of future calls
    self._precomputed_train_ref_log_probs = False
    self._precomputed_eval_ref_log_probs = False

    self.beta = args.beta
    self.label_smoothing = args.label_smoothing
    # Always store the loss type(s) as a list so the rest of the trainer can iterate over them.
    self.loss_type = args.loss_type if isinstance(args.loss_type, list) else [args.loss_type]
    self.loss_weights = args.loss_weights
    self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
    self.use_weighting = args.use_weighting
    self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
    if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
        logger.warning(
            "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
            "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
            "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
            "loss.",
        )
    for loss_type in self.loss_type:
        if (
            loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"]
            and args.label_smoothing > 0
        ):
            logger.warning(
                f"You are using the {loss_type} loss type that does not support label smoothing. The "
                "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this "
                "warning.",
            )
        if loss_type == "kto_pair":
            raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.")

    # Two-level mapping of lists of metric values; it starts empty and is filled elsewhere in the trainer.
    self._stored_metrics = defaultdict(lambda: defaultdict(list))
    self.f_divergence_type = args.f_divergence_type
    self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
    self.dataset_num_proc = args.dataset_num_proc

    # Dataset preparation
    train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
    if eval_dataset is not None:
        if isinstance(eval_dataset, dict):
            eval_dataset = {
                key: self._prepare_dataset(dataset, processing_class, args, key)
                for key, dataset in eval_dataset.items()
            }
        else:
            eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
    # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
    # self.model_accepts_loss_kwargs to False to enable scaling.
    self.model_accepts_loss_kwargs = False

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    if not hasattr(self, "accelerator"):
        raise AttributeError(
            "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
        )

    # Deepspeed Zero-3 does not support precompute_ref_log_probs
    if self.is_deepspeed_enabled:
        if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
            raise ValueError(
                "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
            )

    if self.ref_model is None:
        if not (self.is_peft_model or self.precompute_ref_log_probs):
            raise ValueError(
                "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
            )
        if args.sync_ref_model:
            raise ValueError(
                "You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`."
            )
    else:
        # The reference model is not handed to the parent trainer, so it is prepared for the backend here.
        if self.is_deepspeed_enabled:
            self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
        elif self.is_fsdp_enabled:
            self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
        else:
            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

    if args.sync_ref_model:
        if self.precompute_ref_log_probs:
            raise ValueError(
                "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
            )

        self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))

    if "bco_pair" in self.loss_type:
        self.running = RunningMoments(self.accelerator)
def _prepare_peft_model(
    self, model: PreTrainedModel, ref_model: PreTrainedModel, peft_config: Any, args: DPOConfig
) -> PreTrainedModel:
    """
    Wrap `model` with the PEFT adapter described by `peft_config`, or only set up gradient checkpointing when no
    `peft_config` is given.

    Also sets `self._peft_has_been_casted_to_bf16`, which is `True` only when `peft_module_casting_to_bf16` was
    applied to the wrapped model.

    Raises:
        ValueError: If a `peft_config` is passed while PEFT is not installed, if `model` is already a `PeftModel`,
            or if a `ref_model` is passed together with a `peft_config` without `args.force_use_ref_model`.
    """
    # Tracks whether `peft_module_casting_to_bf16` has been called, so that autocast can be used properly later.
    self._peft_has_been_casted_to_bf16 = False

    # Without a PEFT config there is nothing to wrap.
    if peft_config is None:
        return self._prepare_gradient_checkpointing(model, args)

    if not is_peft_available():
        raise ValueError(
            "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
        )

    if isinstance(model, PeftModel):
        raise ValueError(
            "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first "
            "merge and unload the existing adapter, save the resulting base model, and then pass that base "
            "model along with the new `peft_config` to the trainer."
        )

    if ref_model is not None and not args.force_use_ref_model:
        raise ValueError(
            "You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference"
            " model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init."
            " if you want to use a different ref_model."
        )

    is_quantized = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
    if is_quantized:
        kbit_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
        # Forward `gradient_checkpointing_kwargs` only when both the config and the installed PEFT know about it.
        kbit_parameters = inspect.signature(prepare_model_for_kbit_training).parameters
        if hasattr(args, "gradient_checkpointing_kwargs") and "gradient_checkpointing_kwargs" in kbit_parameters:
            kbit_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
        model = prepare_model_for_kbit_training(model, **kbit_kwargs)
    else:
        model = self._prepare_gradient_checkpointing(model, args)

    # Wrap the prepared base model with the requested adapter.
    model = get_peft_model(model, peft_config)

    if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
        peft_module_casting_to_bf16(model)
        # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
        self._peft_has_been_casted_to_bf16 = True

    return model
def _prepare_gradient_checkpointing(self, model: PreTrainedModel, args: DPOConfig):
    """
    Make the embedding outputs require gradients when gradient checkpointing is enabled.

    Returns the same `model` object; it is left untouched when `args.gradient_checkpointing` is falsy.
    """
    if not args.gradient_checkpointing:
        return model

    # With gradient checkpointing, the inputs must explicitly have `requires_grad=True`, otherwise training will
    # either silently fail or completely fail.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
        return model

    # Fallback for backward compatibility with older versions of transformers: install the hook manually.
    def make_inputs_require_grad(module, input, output):
        output.requires_grad_(True)

    embedding_layer = model.get_input_embeddings()
    embedding_layer.register_forward_hook(make_inputs_require_grad)
    return model
def _prepare_dataset(
    self,
    dataset: Dataset | IterableDataset,
    processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin,
    args: DPOConfig,
    dataset_name: str,
) -> Dataset | IterableDataset:
    """
    Run the preprocessing passes over `dataset`: prompt extraction, chat templating and tokenization.

    Args:
        dataset (`Dataset` or `IterableDataset`):
            Dataset to preprocess.
        processing_class:
            Tokenizer or processor forwarded to the chat-template and tokenization steps.
        args ([`DPOConfig`]):
            Configuration; `dataset_num_proc`, `tools`, `max_prompt_length` and `max_completion_length` are read.
        dataset_name (`str`):
            Name inserted into the progress-bar descriptions.

    Returns:
        The mapped dataset, with the `"chosen"` and `"rejected"` columns removed by the tokenization pass.
    """
    # IterableDataset does not support num_proc nor writer_batch_size
    shared_kwargs = {}
    if isinstance(dataset, Dataset):
        shared_kwargs.update(num_proc=args.dataset_num_proc, writer_batch_size=10)

    def kwargs_with_desc(current, description):
        # `IterableDataset.map` does not support `desc`
        if isinstance(current, Dataset):
            shared_kwargs["desc"] = description
        return shared_kwargs

    with PartialState().main_process_first():
        # Extract prompt if needed
        dataset = dataset.map(
            maybe_extract_prompt,
            **kwargs_with_desc(dataset, f"Extracting prompt in {dataset_name} dataset"),
        )

        # Decide from the first row whether the data is conversational
        first_row = next(iter(dataset))
        is_chat = is_conversational(first_row)

        # Apply the chat template if needed
        dataset = dataset.map(
            maybe_apply_chat_template,
            fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
            **kwargs_with_desc(dataset, f"Applying chat template to {dataset_name} dataset"),
        )

        # Tokenize the dataset
        row_fn = self.process_row if self.is_vision_model else self.tokenize_row
        tokenize_kwargs = {
            "processing_class": processing_class,
            "max_prompt_length": args.max_prompt_length,
            "max_completion_length": args.max_completion_length,
            # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
            "add_special_tokens": False,
            "is_chat": is_chat,
        }
        dataset = dataset.map(
            row_fn,
            remove_columns=["chosen", "rejected"],
            fn_kwargs=tokenize_kwargs,
            **kwargs_with_desc(dataset, f"Tokenizing {dataset_name} dataset"),
        )

    return dataset
@staticmethod
def tokenize_row(
features: dict[str, str],
processing_class: PreTrainedTokenizerBase,
max_prompt_length: int | None = None,
max_completion_length: int | None = None,
add_special_tokens: bool = True,
is_chat: bool = False,
) -> dict[str, list[int]]:
"""
Tokenize a row of the dataset.
Args:
features (`dict[str, str]`):
Row of the dataset, should contain the keys `"prompt"`, `"chosen"`, and `"rejected"`.
processing_class ([`~transformers.PreTrainedTokenizerBase`]):
Processing class used to process the data.
max_prompt_length (`int` or `None`):
Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated.
max_completion_length (`int` or `None`):
Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
add_special_tokens (`bool`):
Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`,
the prompt sequence will have a bos token prepended and an eos token appended. In any case, the
completion sequences will have an eos token appended.
is_chat (`bool`):
Whether the data is conversational. If `True`, the completion sequences will not have an eos token
appended.
Returns:
`dict[str, list[int]]`:
Tokenized sequences with the keys `"prompt_input_ids"`, `"chosen_input_ids"`, and
`"rejected_input_ids".
Example:
```python
>>> from transformers import GPT2Tokenizer
>>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
>>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}
>>> DPOTrainer.tokenize_row(
... features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False
... )
{'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]}
```
"""
tokenizer = processing_class # the processing class is a tokenizer
prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"]
rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"]
# Add special tokens (typically for encoder-decoder models)
if add_special_tokens:
if tokenizer.bos_token_id is not None:
prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids
if tokenizer.eos_token_id is not None:
prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id]
# For conversational data, the chat template already includes proper EOS tokens
if not is_chat:
chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]
# Truncate prompt and completion sequences
if max_prompt_length is not None:
prompt_input_ids = prompt_input_ids[-max_prompt_length:]
if max_completion_length is not None:
chosen_input_ids = chosen_input_ids[:max_completion_length]
rejected_input_ids = rejected_input_ids[:max_completion_length]
return {
"prompt_input_ids": prompt_input_ids,
"chosen_input_ids": chosen_input_ids,
"rejected_input_ids": rejected_input_ids,
}
@staticmethod
def process_row(
features: dict[str, str],
processing_class: PreTrainedTokenizerBase,
max_prompt_length: int | None = None,
max_completion_length: int | None = None,
add_special_tokens: bool = True,
is_chat: bool = False,
) -> dict[str, list[int]]:
"""
Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information.
"""
processor, tokenizer = processing_class, processing_class.tokenizer # the processing class is a processor
processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False)
prompt_input_ids = processed_features["input_ids"][0]
pixel_values = processed_features["pixel_values"][0]
chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"]
rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"]
# Add special tokens (typically for encoder-decoder models)
if add_special_tokens:
if tokenizer.bos_token_id is not None:
prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids
if tokenizer.eos_token_id is not None:
prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id]
if not is_chat:
chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]
# Truncate prompt and completion sequences
if max_prompt_length is not None:
prompt_input_ids = prompt_input_ids[-max_prompt_length:]
if max_completion_length is not None:
chosen_input_ids = chosen_input_ids[:max_completion_length]
rejected_input_ids = rejected_input_ids[:max_completion_length]
output = {
"prompt_input_ids": prompt_input_ids,
"pixel_values": pixel_values,
"chosen_input_ids": chosen_input_ids,
"rejected_input_ids": rejected_input_ids,
}
if "pixel_attention_mask" in processed_features:
output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
if "image_sizes" in processed_features:
output["image_sizes"] = processed_features["image_sizes"][0]
if "token_type_ids" in processed_features:
output["token_type_ids"] = processed_features["token_type_ids"][0]
return output
def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
# In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
# Instead, we set them to the columns expected by `DataCollatorForPreference`, hence the override.
if self._signature_columns is None:
self._signature_columns = [
"prompt_input_ids",
"chosen_input_ids",
"rejected_input_ids",
"image_sizes",
"token_type_ids",
"ref_chosen_logps",
"ref_rejected_logps",
]
def get_train_dataloader(self) -> DataLoader:
    """
    Returns the training [`~torch.utils.data.DataLoader`].

    Overrides the parent `get_train_dataloader` to precompute `ref_log_probs`: when
    `self.precompute_ref_log_probs` is set and the precomputation has not run yet, the reference log
    probabilities are computed for the whole training set and stored in the `"ref_chosen_logps"` and
    `"ref_rejected_logps"` columns of `self.train_dataset` before the parent implementation is called.
    """
    # Nothing to precompute (disabled, or already done on an earlier call).
    if not self.precompute_ref_log_probs or self._precomputed_train_ref_log_probs:
        return super().get_train_dataloader()

    loader_kwargs = dict(
        batch_size=self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size,
        collate_fn=self.data_collator,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        # Unshuffled, so the concatenated results line up with the dataset rows.
        shuffle=False,
    )

    # prepare dataloader
    ref_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **loader_kwargs))

    chosen_chunks, rejected_chunks = [], []
    for padded_batch in tqdm(iterable=ref_loader, desc="Train dataset reference log probs"):
        batch_chosen, batch_rejected = self.compute_ref_log_probs(padded_batch)
        batch_chosen, batch_rejected = self.accelerator.gather_for_metrics((batch_chosen, batch_rejected))
        chosen_chunks.append(batch_chosen.cpu())
        rejected_chunks.append(batch_rejected.cpu())

        # Unnecessary cache clearing to avoid OOM
        empty_cache()
        self.accelerator.free_memory()

    new_columns = {
        "ref_chosen_logps": torch.cat(chosen_chunks).float().numpy(),
        "ref_rejected_logps": torch.cat(rejected_chunks).float().numpy(),
    }
    for column_name, values in new_columns.items():
        self.train_dataset = self.train_dataset.add_column(name=column_name, column=values)

    self._precomputed_train_ref_log_probs = True
    return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
    """
    Returns the evaluation [`~torch.utils.data.DataLoader`].

    Overrides the parent `get_eval_dataloader` to precompute `ref_log_probs`: when
    `self.precompute_ref_log_probs` is set and the precomputation has not run yet, the reference log
    probabilities are added to the evaluation dataset as the `"ref_chosen_logps"` and `"ref_rejected_logps"`
    columns before the parent implementation is called.

    Args:
        eval_dataset (`Dataset`, *optional*):
            If provided, will override `self.eval_dataset`.

    Raises:
        ValueError: If neither `eval_dataset` nor `self.eval_dataset` is set.
    """
    if eval_dataset is None:
        eval_dataset = self.eval_dataset
        if eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

    needs_precompute = self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs
    if needs_precompute:
        loader_kwargs = dict(
            batch_size=self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            # Unshuffled, so the concatenated results line up with the dataset rows.
            shuffle=False,
        )

        # prepare dataloader
        ref_loader = self.accelerator.prepare(DataLoader(eval_dataset, **loader_kwargs))

        chosen_chunks, rejected_chunks = [], []
        for padded_batch in tqdm(iterable=ref_loader, desc="Eval dataset reference log probs"):
            batch_chosen, batch_rejected = self.compute_ref_log_probs(padded_batch)
            batch_chosen, batch_rejected = self.accelerator.gather_for_metrics((batch_chosen, batch_rejected))
            chosen_chunks.append(batch_chosen.cpu())
            rejected_chunks.append(batch_rejected.cpu())

        new_columns = {
            "ref_chosen_logps": torch.cat(chosen_chunks).float().numpy(),
            "ref_rejected_logps": torch.cat(rejected_chunks).float().numpy(),
        }
        for column_name, values in new_columns.items():
            eval_dataset = eval_dataset.add_column(name=column_name, column=values)

        # Save calculated ref_chosen_logps and ref_rejected_logps to the eval_dataset for subsequent runs
        if self.eval_dataset is not None:
            self.eval_dataset = eval_dataset
        self._precomputed_eval_ref_log_probs = True

    return super().get_eval_dataloader(eval_dataset=eval_dataset)
@contextmanager
def null_ref_context(self):
    """
    Context manager for handling null reference model (that is, peft adapter manipulation).

    Inside the context:
        - if `self.ref_adapter_name` is set, `self.model` is switched to that adapter and switched back to
          `self.model_adapter_name` (or `"default"`) on exit;
        - otherwise, if `self.is_peft_model` is set, the context returned by `disable_adapter()` on the
          unwrapped model is entered for the duration of the block;
        - otherwise nothing is changed.

    The switch back to the training adapter also runs when the body of the `with` block raises, so the model
    is not left on the reference adapter after an error.
    """
    if self.is_peft_model and not self.ref_adapter_name:
        adapter_context = self.accelerator.unwrap_model(self.model).disable_adapter()
    else:
        adapter_context = nullcontext()

    with adapter_context:
        if self.ref_adapter_name:
            self.model.set_adapter(self.ref_adapter_name)
        try:
            yield
        finally:
            # Restore the training adapter on both the normal and the exceptional exit path.
            if self.ref_adapter_name:
                self.model.set_adapter(self.model_adapter_name or "default")
def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.

    Runs without gradient tracking. When `self.ref_model` is `None`, `self.model` is used inside
    `self.null_ref_context()` instead of a separate reference model.

    Returns:
        The `"chosen_logps"` and `"rejected_logps"` entries of the `concatenated_forward` output.
    """
    # Autocast is only needed when the PEFT model has been cast to bf16.
    if self._peft_has_been_casted_to_bf16:
        amp_context = autocast(self.accelerator.device.type)
    else:
        amp_context = nullcontext()

    with torch.no_grad(), amp_context:
        if self.ref_model is not None:
            outputs = self.concatenated_forward(self.ref_model, batch, is_ref_model=True)
        else:
            with self.null_ref_context():
                outputs = self.concatenated_forward(self.model, batch, is_ref_model=True)

    return outputs["chosen_logps"], outputs["rejected_logps"]
@staticmethod
def concatenated_inputs(
batch: dict[str, list | torch.LongTensor], padding_value: int
) -> dict[str, torch.LongTensor]:
"""
Concatenate the `chosen` and `rejected` inputs from the batch into a single tensor for both the prompt and
completion sequences.
Args:
batch (`dict[str, list | torch.LongTensor]`):
A batch of input data. The batch must contain the following keys:
- `"prompt_input_ids"`: Tensor of shape `(batch_size, prompt_length)` representing the prompt input
IDs.
- `"chosen_input_ids"`: Tensor of shape `(batch_size, chosen_length)` representing the chosen
completion input IDs.
- `"rejected_input_ids"`: Tensor of shape `(batch_size, rejected_length)` representing the rejected
completion input IDs.
- `"prompt_pixel_values"` (optional): Tensor for pixel values, if available.
- `"prompt_pixel_attention_mask"` (optional): Tensor for pixel attention masks, if available.
padding_value (`int`):
The padding value to use for the concatenated completion sequences (`chosen_input_ids` and
`rejected_input_ids`).
Returns:
`dict[str, torch.LongTensor]`: A dictionary containing:
- `"prompt_input_ids"`: Concatenated prompt input IDs of shape `(2 * batch_size, prompt_length)`.
- `"completion_input_ids"`: Concatenated chosen and rejected completion input IDs of shape `(2 *
batch_size, max_completion_length)`.
- `"prompt_attention_mask"`: Concatenated prompt attention masks of shape `(2 * batch_size,
prompt_length)`.
- `"completion_attention_mask"`: Concatenated chosen and rejected attention masks of shape `(2 *
batch_size, max_completion_length)`.
- `"pixel_values"` (optional): Concatenated pixel values if `"prompt_pixel_values"` are present.
- `"pixel_attention_mask"` (optional): Concatenated pixel attention masks if
`"prompt_pixel_attention_mask"` are present.
Notes:
The completion input IDs and attention masks are padded to the maximum completion length of the chosen or
rejected sequences.
"""
output = {}
# For the prompt, the input_ids are the same for both the chosen and rejected responses
output["prompt_input_ids"] = torch.cat([batch["prompt_input_ids"], batch["prompt_input_ids"]], dim=0)
output["prompt_attention_mask"] = torch.cat(