diff --git a/README.md b/README.md
index 45ac2c4605..c30a6ac0dd 100644
--- a/README.md
+++ b/README.md
@@ -145,7 +145,7 @@ Running Environment:
 | modelscope | >=1.23 | | |
 | peft | >=0.11,<0.19 | | |
 | flash_attn | | 2.8.3/3.0.0b1 | |
-| trl | >=0.15,<0.29 | 0.28.0 | RLHF |
+| trl | >=0.15,<0.30 | 0.28.0 | RLHF |
 | deepspeed | >=0.14 | 0.18.8 | Training |
 | vllm | >=0.5.1 | 0.11.0/0.17.1 | Inference/Deployment |
 | sglang | >=0.4.6 | | Inference/Deployment |
diff --git a/README_CN.md b/README_CN.md
index e142e12c73..e73519f7f5 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -141,7 +141,7 @@ uv pip install -e . --torch-backend=auto
 | modelscope | >=1.23 | | |
 | peft | >=0.11,<0.19 | | |
 | flash_attn | | 2.8.3/3.0.0b1 | |
-| trl | >=0.15,<0.29 | 0.28.0 | RLHF |
+| trl | >=0.15,<0.30 | 0.28.0 | RLHF |
 | deepspeed | >=0.14 | 0.18.8 | 训练 |
 | vllm | >=0.5.1 | 0.11.0/0.17.1 | 推理/部署 |
 | sglang | >=0.4.6 | | 推理/部署 |
diff --git a/docs/source/GetStarted/SWIFT-installation.md b/docs/source/GetStarted/SWIFT-installation.md
index f29c7af072..a82d2807a8 100644
--- a/docs/source/GetStarted/SWIFT-installation.md
+++ b/docs/source/GetStarted/SWIFT-installation.md
@@ -144,7 +144,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2
 | modelscope | >=1.23 | | |
 | peft | >=0.11,<0.19 | | |
 | flash_attn | | 2.8.3/3.0.0b1 | |
-| trl | >=0.15,<0.29 | 0.28.0 | RLHF |
+| trl | >=0.15,<0.30 | 0.28.0 | RLHF |
 | deepspeed | >=0.14 | 0.18.8 | 训练 |
 | vllm | >=0.5.1 | 0.11.0/0.17.1 | 推理/部署 |
 | sglang | >=0.4.6 | | 推理/部署 |
diff --git a/docs/source/Megatron-SWIFT/Quick-start.md b/docs/source/Megatron-SWIFT/Quick-start.md
index 8f31b7dca9..ece7ae1484 100644
--- a/docs/source/Megatron-SWIFT/Quick-start.md
+++ b/docs/source/Megatron-SWIFT/Quick-start.md
@@ -73,7 +73,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2
 | transformers | >=4.33 | 4.57.6/5.2.0 | |
 | modelscope | >=1.23 | | |
 | peft | >=0.11,<0.19 | | LoRA |
-| trl | >=0.15,<0.29 | | RLHF |
+| trl | >=0.15,<0.30 | | RLHF |

 ## 快速入门案例
diff --git a/docs/source_en/GetStarted/SWIFT-installation.md b/docs/source_en/GetStarted/SWIFT-installation.md
index 631271da68..b4a534b6b4 100644
--- a/docs/source_en/GetStarted/SWIFT-installation.md
+++ b/docs/source_en/GetStarted/SWIFT-installation.md
@@ -143,7 +143,7 @@ More images can be found [here](https://modelscope.cn/docs/intro/environment-set
 | modelscope | >=1.23 | | |
 | peft | >=0.11,<0.19 | | |
 | flash_attn | | 2.8.3/3.0.0b1 | |
-| trl | >=0.15,<0.29 | 0.28.0 | RLHF |
+| trl | >=0.15,<0.30 | 0.28.0 | RLHF |
 | deepspeed | >=0.14 | 0.18.8 | Training |
 | vllm | >=0.5.1 | 0.11.0/0.17.1 | Inference/Deployment |
 | sglang | >=0.4.6 | | Inference/Deployment |
diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md
index 3740eb4608..201a4e113b 100644
--- a/docs/source_en/Megatron-SWIFT/Quick-start.md
+++ b/docs/source_en/Megatron-SWIFT/Quick-start.md
@@ -73,7 +73,7 @@ Recommended Operating Environment:
 | transformers | >=4.33 | 4.57.6/5.2.0 | |
 | modelscope | >=1.23 | | |
 | peft | >=0.11,<0.19 | | LoRA |
-| trl | >=0.15,<0.29 | | RLHF |
+| trl | >=0.15,<0.30 | | RLHF |

 ## Quick Start Example
diff --git a/requirements/framework.txt b/requirements/framework.txt
index 9ebfbc1ea5..7b62644a65 100644
--- a/requirements/framework.txt
+++ b/requirements/framework.txt
@@ -35,6 +35,6 @@ tiktoken
 tqdm
 transformers>=4.33,<5.4.0
 transformers_stream_generator
-trl>=0.15,<0.29
+trl>=0.15,<0.30
 uvicorn
 zstandard
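Not part of the patch: a minimal sketch for checking that the installed trl actually falls inside the widened bound before training. It assumes `packaging` is available (the trainers below already use `version.parse`):

    from importlib.metadata import version
    from packaging.version import parse

    trl_version = parse(version('trl'))
    # mirrors the new constraint trl>=0.15,<0.30 from requirements/framework.txt
    assert parse('0.15') <= trl_version < parse('0.30'), f'unsupported trl: {trl_version}'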
diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py
index 7375e4ef80..2f1ddbec84 100644
--- a/swift/megatron/trainers/dpo_trainer.py
+++ b/swift/megatron/trainers/dpo_trainer.py
@@ -13,12 +13,11 @@ class DummyDPOTrainer(DPOTrainer):
-    # For reusing the dpo_loss function in TRL.
+    # For reusing the dpo_loss function implemented in Swift's DPOTrainer.

     def __init__(self, args):
-        from trl.trainer import FDivergenceConstants
         self.accelerator = namedtuple('Accelerator', ['device'])(device=get_current_device())
         self.f_alpha_divergence_coef = 1.
-        self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: self.f_alpha_divergence_coef}
+        self.f_divergence_params = {'alpha_divergence_coef': self.f_alpha_divergence_coef}
         self.reference_free = args.reference_free
         self.label_smoothing = args.label_smoothing
         self.f_divergence_type = args.f_divergence_type
diff --git a/swift/rlhf_trainers/arguments.py b/swift/rlhf_trainers/arguments.py
index 1abcf5b184..311bc3edb0 100644
--- a/swift/rlhf_trainers/arguments.py
+++ b/swift/rlhf_trainers/arguments.py
@@ -28,6 +28,10 @@
 @dataclass
 class DPOConfig(TrainArgumentsMixin, HfDPOConfig):
     ld_alpha: Optional[float] = None  # compat trl==0.15
+    # Fields removed in trl 0.29, kept here for backward compatibility
+    rpo_alpha: Optional[float] = None
+    ref_adapter_name: Optional[str] = None
+    reference_free: Optional[bool] = None

     def __post_init__(self):
         TrainArgumentsMixin.__post_init__(self)
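Also not part of the patch: a quick sketch of what these compat fields buy. On trl>=0.29, `rpo_alpha`, `ref_adapter_name`, and `reference_free` no longer exist on trl's `DPOConfig`, so re-declaring them on Swift's subclass keeps older configs parseable:

    from swift.rlhf_trainers.arguments import DPOConfig

    # The names are plain dataclass fields again, so argument parsing and the
    # trainer's getattr(args, ...) lookups keep working on trl>=0.29.
    for field in ('rpo_alpha', 'ref_adapter_name', 'reference_free'):
        assert field in DPOConfig.__dataclass_fields__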
diff --git a/swift/rlhf_trainers/dpo_trainer.py b/swift/rlhf_trainers/dpo_trainer.py
index a43acaf8be..1a190d0af1 100644
--- a/swift/rlhf_trainers/dpo_trainer.py
+++ b/swift/rlhf_trainers/dpo_trainer.py
@@ -1,24 +1,45 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import warnings
 from accelerate.utils import gather_object
+from contextlib import contextmanager, nullcontext
 from peft import PeftModel
 from transformers import PreTrainedModel
 from transformers.utils.versions import require_version
 from trl import DPOTrainer as HFDPOTrainer
 from trl.trainer.dpo_config import DPOConfig
-from trl.trainer.utils import RunningMoments
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 from swift.trainers import DataLoaderMixin, SwiftMixin
 from swift.utils import get_logger, to_device
 from .rlhf_mixin import RLHFTrainerMixin
+
+try:
+    from trl.trainer.utils import RunningMoments
+except ImportError:
+    # trl >= 0.29
+    from trl.experimental.bco.bco_trainer import RunningMoments
+
+_ALPHA_DIVERGENCE_COEF_KEY = 'alpha_divergence_coef'
+_ALPHA_DIVERGENCE_COEF_DEFAULT = 1.0
+
 del HFDPOTrainer.__init__
 logger = get_logger()
+
+
+def _get_exp_cap(value, decimal=4):
+    vdtype_max = torch.zeros([1]).to(value.dtype) + torch.finfo(value.dtype).max
+    vdtype_log_max = torch.log(vdtype_max).to(value.device)
+    return torch.floor(vdtype_log_max * 10**decimal) / 10**decimal if decimal > 0 else vdtype_log_max
+
+
+def _cap_exp(value, cap=-1):
+    cap = _get_exp_cap(value) if cap < 0 else cap
+    return torch.exp(torch.clamp(value, max=cap))
+
+
 def new_gather_function(tensor):
     tensor_list = gather_object([tensor])
     tensor_list = [t[None] if t.ndim == 0 else t for t in tensor_list]
@@ -32,7 +53,6 @@ def __init__(self,
                  ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                  *_args,
                  **kwargs):
-        from trl.trainer import FDivergenceConstants
         args = kwargs['args']
         self.label_smoothing = args.label_smoothing
         if 'loss_weights' in DPOConfig.__dict__:
@@ -56,21 +76,24 @@ def __init__(self,
             raise ValueError('Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.')
         self.precompute_ref_log_probs = args.precompute_ref_log_probs
-        self.f_divergence_type = args.f_divergence_type
-        self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
+        self.f_divergence_type = getattr(args, 'f_divergence_type', 'reverse_kl')
+        self.f_alpha_divergence_coef = getattr(args, 'f_alpha_divergence_coef', 0.5)
+        self.f_divergence_params = {_ALPHA_DIVERGENCE_COEF_KEY: self.f_alpha_divergence_coef}
         self.is_peft_model = isinstance(model, PeftModel)
-        self.ref_adapter_name = args.ref_adapter_name
+        self.ref_adapter_name = getattr(args, 'ref_adapter_name', None)
         self.model_adapter_name = None
-        self.reference_free = args.reference_free
+        self.reference_free = getattr(args, 'reference_free', None) or False
         self.use_weighting = False
         super().__init__(model, ref_model, *_args, **kwargs)
         if 'bco_pair' in loss_types:
             self.running = RunningMoments(self.accelerator)
+
         if self.args.ld_alpha is not None:
             require_version('trl>=0.18', '`ld_alpha` requires that "trl>=0.18".')
+
         if self.template.packing:
             self.accelerator.gather_for_metrics = new_gather_function
@@ -175,11 +198,286 @@ def concatenated_forward(
         output['aux_loss'] = outputs.aux_loss
         return output

+    # Some methods were removed in trl>=0.29; they are overridden here so that
+    # both trl<0.29 and trl>=0.29 remain supported. Consider refactoring them
+    # to follow trl>=0.29 once the lower bound is raised.
+    @contextmanager
+    def null_ref_context(self):
+        with (self.accelerator.unwrap_model(self.model).disable_adapter()
+              if self.is_peft_model and not self.ref_adapter_name else nullcontext()):
+            if self.ref_adapter_name:
+                self.model.set_adapter(self.ref_adapter_name)
+            yield
+            if self.ref_adapter_name:
+                self.model.set_adapter(self.model_adapter_name or 'default')
+
+    def compute_ref_log_probs(self, batch):
+        compute_ref_context_manager = (
+            torch.autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext())
+        with torch.no_grad(), compute_ref_context_manager:
+            if self.ref_model is None:
+                with self.null_ref_context():
+                    ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True)
+            else:
+                ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True)
+        return ref_model_output['chosen_logps'], ref_model_output['rejected_logps']
+
+    def dpo_loss(
+        self,
+        chosen_logps: torch.FloatTensor,
+        rejected_logps: torch.FloatTensor,
+        ref_chosen_logps: torch.FloatTensor,
+        ref_rejected_logps: torch.FloatTensor,
+        loss_type: str = 'sigmoid',
+        model_output: Optional[Dict[str, torch.FloatTensor]] = None,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        device = self.accelerator.device
+
+        chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device)
+        rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device)
+
+        if self.f_divergence_type == 'alpha_divergence':
+            alpha_coef = _ALPHA_DIVERGENCE_COEF_DEFAULT
+            if self.f_divergence_params and _ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params:
+                alpha_coef = float(self.f_divergence_params[_ALPHA_DIVERGENCE_COEF_KEY])
+            logits = (_cap_exp(rejected_logratios * -alpha_coef)
+                      - _cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef
+        else:
+            logratios = chosen_logps - rejected_logps
+            if self.reference_free:
+                ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device)
+            else:
+                ref_logratios = ref_chosen_logps - ref_rejected_logps
+
+            logratios = logratios.to(device)
+            ref_logratios = ref_logratios.to(device)
+            logits = logratios - ref_logratios
+
+            if self.f_divergence_type == 'js_divergence':
+                logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)
+
+        if loss_type == 'sigmoid':
+            losses = (-F.logsigmoid(self.beta * logits) *
+                      (1 - self.label_smoothing) - F.logsigmoid(-self.beta * logits) * self.label_smoothing)
+
+        elif loss_type == 'robust':
+            losses = (-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+                      + F.logsigmoid(-self.beta * logits) * self.label_smoothing) / (1 - 2 * self.label_smoothing)
+
+        elif loss_type == 'exo_pair':
+            import math
+            if self.label_smoothing == 0:
+                self.label_smoothing = 1e-3
+            losses = (self.beta * logits).sigmoid() * (F.logsigmoid(
+                self.beta * logits) - math.log(1 - self.label_smoothing)) + (-self.beta * logits).sigmoid() * (
+                    F.logsigmoid(-self.beta * logits) - math.log(self.label_smoothing))
+
+        elif loss_type == 'hinge':
+            losses = torch.relu(1 - self.beta * logits)
+
+        elif loss_type == 'ipo':
+            losses = (logits - 1 / (2 * self.beta))**2
+
+        elif loss_type == 'bco_pair':
+            chosen_logratios = chosen_logps - ref_chosen_logps
+            rejected_logratios = rejected_logps - ref_rejected_logps
+            chosen_rewards = self.beta * chosen_logratios
+            rejected_rewards = self.beta * rejected_logratios
+            rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
+            self.running.update(rewards)
+            delta = self.running.mean
+            losses = -F.logsigmoid(
+                (self.beta * chosen_logratios) - delta) - F.logsigmoid(-(self.beta * rejected_logratios - delta))
+
+        elif loss_type == 'sppo_hard':
+            a = chosen_logps - ref_chosen_logps
+            b = rejected_logps - ref_rejected_logps
+            losses = (a - 0.5 / self.beta)**2 + (b + 0.5 / self.beta)**2
+
+        elif loss_type == 'nca_pair':
+            chosen_rewards = (chosen_logps - ref_chosen_logps) * self.beta
+            rejected_rewards = (rejected_logps - ref_rejected_logps) * self.beta
+            losses = (-F.logsigmoid(chosen_rewards) - 0.5 * F.logsigmoid(-chosen_rewards)
+                      - 0.5 * F.logsigmoid(-rejected_rewards))
+
+        elif loss_type == 'aot_unpaired':
+            chosen_logratios = chosen_logps - ref_chosen_logps
+            rejected_logratios = rejected_logps - ref_rejected_logps
+            chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0)
+            rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0)
+            delta = chosen_logratios_sorted - rejected_logratios_sorted
+            losses = (-F.logsigmoid(self.beta * delta) *
+                      (1 - self.label_smoothing) - F.logsigmoid(-self.beta * delta) * self.label_smoothing)
+
+        elif loss_type == 'aot':
+            logratios = chosen_logps - rejected_logps
+            ref_logratios = ref_chosen_logps - ref_rejected_logps
+            logratios_sorted, _ = torch.sort(logratios, dim=0)
+            ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0)
+            delta = logratios_sorted - ref_logratios_sorted
+            losses = (-F.logsigmoid(self.beta * delta) *
+                      (1 - self.label_smoothing) - F.logsigmoid(-self.beta * delta) * self.label_smoothing)
+
+        elif loss_type == 'apo_zero':
+            losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios)
+            losses_rejected = F.sigmoid(self.beta * rejected_logratios)
+            losses = losses_chosen + losses_rejected
+
+        elif loss_type == 'apo_down':
+            losses_chosen = F.sigmoid(self.beta * chosen_logratios)
+            losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios))
+            losses = losses_chosen + losses_rejected
+        elif loss_type == 'discopop':
+            logratios = chosen_logps - rejected_logps
+            ref_logratios = ref_chosen_logps - ref_rejected_logps
+            logits = logratios - ref_logratios
+            logits = logits * self.beta
+            log_ratio_modulation = torch.sigmoid(logits / self.args.discopop_tau)
+            logistic_component = -F.logsigmoid(logits)
+            exp_component = torch.exp(-logits)
+            losses = logistic_component * (1 - log_ratio_modulation) + exp_component * log_ratio_modulation
+
+        elif loss_type == 'sft':
+            sft_loss = model_output['nll_loss']
+            batch_size = chosen_logps.shape[0]
+            losses = sft_loss.expand(batch_size)
+            chosen_rewards = torch.zeros_like(chosen_logps)
+            rejected_rewards = torch.zeros_like(rejected_logps)
+
+        else:
+            raise ValueError(
+                f"Unknown loss type: {loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', "
+                "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_unpaired', 'discopop', 'apo_zero', "
+                "'apo_down', 'sft']")
+
+        chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
+        rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
+
+        return losses, chosen_rewards, rejected_rewards
+
+    def get_batch_loss_metrics(
+        self,
+        model: Union[PreTrainedModel, nn.Module],
+        batch: Dict[str, Union[List, torch.LongTensor]],
+        train_eval: Literal['train', 'eval'] = 'train',
+    ) -> Tuple[torch.Tensor, Dict[str, float]]:
+        metrics = {}
+
+        model_output = self.concatenated_forward(model, batch)
+
+        if 'ref_chosen_logps' in batch and 'ref_rejected_logps' in batch:
+            ref_chosen_logps = batch['ref_chosen_logps']
+            ref_rejected_logps = batch['ref_rejected_logps']
+        else:
+            ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
+
+        losses = 0
+        chosen_rewards = 0
+        rejected_rewards = 0
+
+        loss_types = self.loss_type if isinstance(self.loss_type, list) else [self.loss_type]
+        loss_weights = self.loss_weights if hasattr(self, 'loss_weights') and self.loss_weights else None
+        for idx, loss_type in enumerate(loss_types):
+            _losses, _chosen_rewards, _rejected_rewards = self.dpo_loss(
+                model_output['chosen_logps'],
+                model_output['rejected_logps'],
+                ref_chosen_logps,
+                ref_rejected_logps,
+                loss_type,
+                model_output,
+            )
+            weight = loss_weights[idx] if loss_weights else 1.0
+            losses = losses + _losses * weight
+            chosen_rewards = chosen_rewards + _chosen_rewards * weight
+            rejected_rewards = rejected_rewards + _rejected_rewards * weight
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        if self.args.rpo_alpha is not None:
+            losses = losses + self.args.rpo_alpha * model_output['nll_loss']
+
+        if self.use_weighting:
+            losses = losses * model_output['policy_weights']
+
+        if self.aux_loss_enabled:
+            losses = losses + self.aux_loss_coef * model_output['aux_loss']
+
+        prefix = 'eval_' if train_eval == 'eval' else ''
+        metrics[f'{prefix}rewards/chosen'] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
+        metrics[f'{prefix}rewards/rejected'] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
+        metrics[f'{prefix}rewards/accuracies'] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
+        metrics[f'{prefix}rewards/margins'] = (
+            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item())
+        metrics[f'{prefix}logps/chosen'] = (
+            self.accelerator.gather_for_metrics(model_output['chosen_logps']).detach().mean().item())
+        metrics[f'{prefix}logps/rejected'] = (
+            self.accelerator.gather_for_metrics(model_output['rejected_logps']).detach().mean().item())
+        metrics[f'{prefix}logits/chosen'] = (
+            self.accelerator.gather_for_metrics(model_output['mean_chosen_logits']).detach().mean().item())
+        metrics[f'{prefix}logits/rejected'] = (
+            self.accelerator.gather_for_metrics(model_output['mean_rejected_logits']).detach().mean().item())
+        if self.args.rpo_alpha is not None or 'sft' in loss_types:
+            metrics[f'{prefix}nll_loss'] = (
+                self.accelerator.gather_for_metrics(model_output['nll_loss']).detach().mean().item())
+        if self.aux_loss_enabled:
+            metrics[f'{prefix}aux_loss'] = (
+                self.accelerator.gather_for_metrics(model_output['aux_loss']).detach().mean().item())
+
+        return losses.mean(), metrics
+
+    def store_metrics(self, metrics, train_eval='train'):
+        for key, value in metrics.items():
+            self._stored_metrics[train_eval][key].append(value)
+
+    def log(self, logs, start_time=None):
+        from transformers import Trainer
+        train_eval = 'train' if 'loss' in logs else 'eval'
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+        mode = 'train' if self.model.training else 'eval'
+        custom_metrics = self.custom_metrics[mode]
+        prefix = 'eval_' if mode == 'eval' else ''
+        logs.update(self.compute_custom_metrics(custom_metrics, prefix))
+        return Trainer.log(self, logs, start_time)
+
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        compute_loss_context_manager = (
+            torch.autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext())
+        with compute_loss_context_manager:
+            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval='train')
+
+        loss = loss.to(self.args.device)
+        self.store_metrics(metrics, train_eval='train')
+
+        if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
+            loss = loss / self.args.gradient_accumulation_steps
+
+        if return_outputs:
+            return loss, metrics
+        return loss
+
     def training_step(self, model, inputs, *args, **kwargs):
         with self.template.forward_context(self.model, inputs):
             return super().training_step(model, inputs, *args, **kwargs)

-    def prediction_step(self, model, inputs, *args, **kwargs):
+    def prediction_step(self, model, inputs, prediction_loss_only=False, *args, **kwargs):
         with self.template.forward_context(self.model, inputs):
             inputs = self._prepare_inputs(inputs)
-            return super().prediction_step(model, inputs, *args, **kwargs)
+
+            with torch.no_grad():
+                loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval='eval')
+
+        self.store_metrics(metrics, train_eval='eval')
+
+        if prediction_loss_only:
+            return loss.detach(), None, None
+
+        logits_dict = {
+            'eval_logits/chosen': metrics['eval_logits/chosen'],
+            'eval_logits/rejected': metrics['eval_logits/rejected'],
+        }
+        logits = torch.tensor(list(logits_dict.values()), device=self.accelerator.device)
+        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
+
+        return (loss.detach(), logits, labels)
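Not part of the patch: a self-contained numeric check of the `sigmoid` branch of `dpo_loss` above, mirroring `losses = -logsigmoid(beta * logits)` with `label_smoothing = 0`; all log-probs below are made up:

    import torch
    import torch.nn.functional as F

    beta = 0.1
    chosen_logps = torch.tensor([-10.0, -8.0])
    rejected_logps = torch.tensor([-12.0, -7.5])
    ref_chosen_logps = torch.tensor([-11.0, -8.5])
    ref_rejected_logps = torch.tensor([-11.5, -7.0])

    # logits = policy log-ratio minus reference log-ratio, as in dpo_loss
    logits = (chosen_logps - rejected_logps) - (ref_chosen_logps - ref_rejected_logps)
    losses = -F.logsigmoid(beta * logits)
    chosen_rewards = beta * (chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (rejected_logps - ref_rejected_logps)
    print(losses, (chosen_rewards > rejected_rewards).float())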
diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py
index cf4472f63f..152fbb8508 100644
--- a/swift/rlhf_trainers/gkd_trainer.py
+++ b/swift/rlhf_trainers/gkd_trainer.py
@@ -113,14 +113,6 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non
         # Initialize rollout infrastructure for vLLM support
         self.prepare_rollout()

-        # Initialize activation offloading context
-        args.activation_offloading = False  # TODO: remove
-        if args.activation_offloading:
-            from trl.models import get_act_offloading_ctx_manager
-            self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
-        else:
-            self.maybe_activation_offload_context = nullcontext()
-
         # Initialize resample data iterator for truncation_strategy 'raise'('delete')
         if self.template.truncation_strategy == 'raise':
             self._prepare_resample_data_iterator()
diff --git a/swift/rlhf_trainers/reward_trainer.py b/swift/rlhf_trainers/reward_trainer.py
index e099350688..53c03baa10 100644
--- a/swift/rlhf_trainers/reward_trainer.py
+++ b/swift/rlhf_trainers/reward_trainer.py
@@ -29,14 +29,6 @@ class RewardTrainer(RLHFTrainerMixin, SwiftMixin, HFRewardTrainer):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        try:
-            from trl.models import get_act_offloading_ctx_manager
-            if getattr(self.args, 'activation_offloading', False):
-                self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
-            else:
-                self.maybe_activation_offload_context = nullcontext()
-        except ImportError:
-            self.maybe_activation_offload_context = nullcontext()
         self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
         if version.parse(trl.__version__) >= version.parse('0.24'):
             # During evaluation, Trainer calls compute_loss() only if can_return_loss is True and label_names is empty.
diff --git a/swift/rlhf_trainers/rlhf_mixin.py b/swift/rlhf_trainers/rlhf_mixin.py
index d60c6f1944..a3ab2c24ec 100644
--- a/swift/rlhf_trainers/rlhf_mixin.py
+++ b/swift/rlhf_trainers/rlhf_mixin.py
@@ -3,7 +3,7 @@
 import torch
 import torch.nn as nn
 from collections import defaultdict
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import partial
 from torch.utils.data import DataLoader
 from transformers import PreTrainedModel
@@ -42,6 +42,7 @@ def __init__(self,
         self.is_vision_model = False
         self.label_pad_token_id = -100
         self.use_dpo_data_collator = True
+        self.maybe_activation_offload_context = nullcontext()
         super().__init__(model, *_args, **kwargs)
         self.aux_loss_enabled = model.model_info.is_moe_model and args.router_aux_loss_coef > 0
         self.aux_loss_coef = args.router_aux_loss_coef
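End note, not part of the patch: with the activation-offloading blocks deleted from `gkd_trainer.py` and `reward_trainer.py`, `RLHFTrainerMixin` now supplies a `nullcontext()` default for `maybe_activation_offload_context`. A trainer that still wants trl's offloading on versions that ship it could opt back in along these lines (a sketch only; `get_act_offloading_ctx_manager` is the helper the removed code imported):

    from contextlib import nullcontext

    def make_offload_ctx(model, enabled: bool):
        if not enabled:
            return nullcontext()
        try:
            from trl.models import get_act_offloading_ctx_manager
        except ImportError:  # trl builds without the helper
            return nullcontext()
        return get_act_offloading_ctx_manager(model=model)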