Skip to content

Commit 956675c

Browse files
committed
refactor
1 parent 2fbb485 commit 956675c

File tree

8 files changed

+82
-52
lines changed

8 files changed

+82
-52
lines changed

examples/train/grpo/external/vllm_gym.sh

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
# CUDA_VISIBLE_DEVICES=7 \
44
# swift rollout \
55
# --model Qwen/Qwen2.5-3B-Instruct \
6-
# --model_type qwen2_5\
76
# --max_turns 3\
87
# --multi_turn_scheduler gym_scheduler \
98
# --use_gym_env true \

examples/train/grpo/plugin/deepeyes/deepeyes_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -187,8 +187,8 @@ def rule_math_verify(ground_truth, model_answer):
187187

188188
class DeepEyesReward(ORM):
189189

190-
def __init__(self):
191-
super().__init__()
190+
def __init__(self, args, **kwargs):
191+
super().__init__(args)
192192
try:
193193
self.client = OpenAI(
194194
api_key='EMPTY',

examples/train/grpo/plugin/plugin.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -154,7 +154,8 @@ class MultiTurnThinkingTips(ORM):
154154
function **must return an identical reward for every fragment**
155155
"""
156156

157-
def __init__(self):
157+
def __init__(self, args=None, **kwargs):
158+
super().__init__(args)
158159
from swift.rewards.orm import MathAccuracy
159160
self.acc_func = MathAccuracy()
160161

@@ -183,7 +184,8 @@ def __call__(self, completions, **kwargs) -> List[float]:
183184
# ref implementation: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
184185
class CodeReward(ORM):
185186

186-
def __init__(self):
187+
def __init__(self, args=None, **kwargs):
188+
super().__init__(args)
187189
import importlib.util
188190
assert importlib.util.find_spec('e2b') is not None, (
189191
"The e2b package is required but not installed. Please install it using 'pip install e2b-code-interpreter'."
@@ -368,7 +370,8 @@ class CodeRewardByJudge0(ORM):
368370
}
369371
PYTHON_ID = 71
370372

371-
def __init__(self):
373+
def __init__(self, args, **kwargs):
374+
super().__init__(args)
372375
self.endpoint = os.getenv('JUDGE0_ENDPOINT')
373376
assert self.endpoint is not None, (
374377
'Judge0 endpoint is not set. Please set the JUDGE0_ENDPOINT environment variable.')
@@ -488,7 +491,8 @@ class AsyncGenRMReward(AsyncORM):
488491
```
489492
"""
490493

491-
def __init__(self):
494+
def __init__(self, args, **kwargs):
495+
super().__init__(args)
492496
from openai import OpenAI
493497
self.api_base = os.getenv('GENRM_API_BASE', 'http://localhost:8000/v1')
494498
self.temperature = float(os.getenv('GENRM_TEMPERATURE', '0.3'))
@@ -637,7 +641,8 @@ async def __call__(self, completions, messages, **kwargs) -> List[float]:
637641
# COARSEREWARD -> Coarse, INTERMEDIATEREWARD -> Intermediate, REFINEDREWARD -> Finegrained
638642
class ToolUseFormatReward(ORM):
639643

640-
def __init__(self):
644+
def __init__(self, args=None, **kwargs):
645+
super().__init__(args)
641646
self.format_max_possible = 1.0
642647
self.format_min_possible = 0.0
643648

@@ -700,7 +705,8 @@ def __call__(self, completions, solution, **kwargs) -> List[float]:
700705

701706
class ToolUseLengthReward(ORM):
702707

703-
def __init__(self):
708+
def __init__(self, args=None, **kwargs):
709+
super().__init__(args)
704710
self.length_max_possible = 1.0
705711
self.length_min_possible = 0.0
706712

@@ -739,7 +745,8 @@ def __call__(self, completions, solution, **kwargs):
739745

740746
class ToolUseCorrectnessReward(ORM):
741747

742-
def __init__(self):
748+
def __init__(self, args=None, **kwargs):
749+
super().__init__(args)
743750
if str(os.getenv('CORRECTMAX1', 0)) == '1':
744751
self.tool_max_possible = 1.0
745752
self.tool_min_possible = -1.0

swift/rewards/orm.py

Lines changed: 30 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,13 @@
44
import json
55
import os
66
import re
7-
from typing import Dict, List, Union
7+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
88

99
from swift.infer_engine import InferRequest
1010

11+
if TYPE_CHECKING:
12+
from swift.rlhf_trainers import GRPOConfig
13+
1114

1215
class ORM:
1316
"""Base class for synchronous outcome reward models (ORM).
@@ -20,6 +23,9 @@ def __call__(self, completions, **kwargs) -> List[float]:
2023
return [1.0 if len(c) > 100 else 0.0 for c in completions]
2124
"""
2225

26+
def __init__(self, args: Optional['GRPOConfig'] = None, **kwargs):
27+
self.args = args
28+
2329
def __call__(self, **kwargs) -> List[float]:
2430
raise NotImplementedError
2531

@@ -52,13 +58,17 @@ async def score_single(session, text):
5258
return list(rewards)
5359
"""
5460

61+
def __init__(self, args: Optional['GRPOConfig'] = None, **kwargs):
62+
self.args = args
63+
5564
async def __call__(self, **kwargs) -> List[float]:
5665
raise NotImplementedError
5766

5867

5968
class MathAccuracy(ORM):
6069

61-
def __init__(self):
70+
def __init__(self, args=None, **kwargs):
71+
super().__init__(args, **kwargs)
6272
import importlib.util
6373
assert importlib.util.find_spec('math_verify') is not None, (
6474
'The math_verify package is required but not installed. '
@@ -129,18 +139,13 @@ def __call__(self, completions, **kwargs) -> List[float]:
129139

130140
class CosineReward(ORM):
131141
# https://arxiv.org/abs/2502.03373
132-
def __init__(self,
133-
cosine_min_len_value_wrong: float = -0.5,
134-
cosine_max_len_value_wrong: float = 0.0,
135-
cosine_min_len_value_correct: float = 1.0,
136-
cosine_max_len_value_correct: float = 0.5,
137-
cosine_max_len: int = 1000,
138-
accuracy_orm=None):
139-
self.min_len_value_wrong = cosine_min_len_value_wrong
140-
self.max_len_value_wrong = cosine_max_len_value_wrong
141-
self.min_len_value_correct = cosine_min_len_value_correct
142-
self.max_len_value_correct = cosine_max_len_value_correct
143-
self.max_len = cosine_max_len
142+
def __init__(self, args: Optional['GRPOConfig'] = None, accuracy_orm=None):
143+
super().__init__(args)
144+
self.min_len_value_wrong = args.cosine_min_len_value_wrong
145+
self.max_len_value_wrong = args.cosine_max_len_value_wrong
146+
self.min_len_value_correct = args.cosine_min_len_value_correct
147+
self.max_len_value_correct = args.cosine_max_len_value_correct
148+
self.max_len = args.cosine_max_len
144149
self.accuracy_orm = accuracy_orm or MathAccuracy()
145150

146151
@staticmethod
@@ -169,9 +174,10 @@ def __call__(self, completions, solution, **kwargs) -> List[float]:
169174

170175
class RepetitionPenalty(ORM):
171176
# https://arxiv.org/abs/2502.03373
172-
def __init__(self, repetition_n_grams: int = 3, repetition_max_penalty: float = -1.0):
173-
self.ngram_size = repetition_n_grams
174-
self.max_penalty = repetition_max_penalty
177+
def __init__(self, args: Optional['GRPOConfig'] = None, **kwargs):
178+
super().__init__(args)
179+
self.ngram_size = args.repetition_n_grams
180+
self.max_penalty = args.repetition_max_penalty
175181

176182
@staticmethod
177183
def zipngram(text: str, ngram_size: int):
@@ -208,10 +214,11 @@ def __call__(self, completions, **kwargs) -> List[float]:
208214

209215
class SoftOverlong(ORM):
210216

211-
def __init__(self, soft_max_length, soft_cache_length):
212-
assert soft_cache_length < soft_max_length
213-
self.soft_max_length = soft_max_length
214-
self.soft_cache_length = soft_cache_length
217+
def __init__(self, args: Optional['GRPOConfig'] = None, **kwargs):
218+
super().__init__(args)
219+
assert args.soft_cache_length < args.soft_max_length
220+
self.soft_max_length = args.soft_max_length
221+
self.soft_cache_length = args.soft_cache_length
215222

216223
def __call__(self, completions, **kwargs) -> List[float]:
217224
rewards = []
@@ -369,7 +376,8 @@ def evaluate_rougel(cand_list: list, ref_list: list):
369376

370377
class MathORM(ORM):
371378

372-
def __init__(self):
379+
def __init__(self, args=None, **kwargs):
380+
super().__init__(args)
373381
from transformers.utils import strtobool
374382
self.use_opencompass = strtobool(os.environ.get('USE_OPENCOMPASS_EVALUATOR', 'False'))
375383
if self.use_opencompass:

swift/rlhf_trainers/arguments.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,23 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2+
import trl
23
from dataclasses import dataclass
4+
from packaging import version
35
from transformers.utils.versions import require_version
4-
from trl import CPOConfig as HfCPOConfig
6+
7+
if version.parse(trl.__version__) <= version.parse('0.28'):
8+
from trl import CPOConfig as HfCPOConfig
9+
from trl import GKDConfig as HfGKDConfig
10+
from trl import ORPOConfig as HfORPOConfig
11+
from trl import PPOConfig as HfPPOConfig
12+
else:
13+
from trl.experimental.cpo import CPOConfig as HfCPOConfig
14+
from trl.experimental.gkd import GKDConfig as HfGKDConfig
15+
from trl.experimental.orpo import ORPOConfig as HfORPOConfig
16+
from trl.experimental.ppo import PPOConfig as HfPPOConfig
17+
518
from trl import DPOConfig as HfDPOConfig
6-
from trl import GKDConfig as HfGKDConfig
719
from trl import GRPOConfig as HfGRPOConfig
820
from trl import KTOConfig as HfKTOConfig
9-
from trl import ORPOConfig as HfORPOConfig
10-
from trl import PPOConfig as HfPPOConfig
1121
from trl import RewardConfig as HfRewardConfig
1222
from typing import Optional
1323

swift/rlhf_trainers/grpo_trainer.py

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -943,7 +943,10 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]:
943943
# rollout_logprobs is List[List[float]] - nested list where each inner list corresponds to
944944
# one assistant response turn. We need to align these with completion_mask positions.
945945
batch_encoded_inputs['rollout_per_token_logps'] = None
946-
if self.use_fast_infer:
946+
should_compute_rollout_logprobs = (
947+
self.rollout_importance_sampling_mode is not None or self.log_rollout_offpolicy_metrics)
948+
949+
if self.use_fast_infer and should_compute_rollout_logprobs:
947950
rollout_logprobs_list = []
948951
for data in batch:
949952
if 'rollout_logprobs' in data and data['rollout_logprobs']:
@@ -2206,14 +2209,7 @@ def _prepare_rewards(self, reward_funcs, reward_model=None, reward_templates=Non
22062209
for i, reward_func in enumerate(reward_funcs):
22072210
if reward_func in orms:
22082211
reward_func_class = orms[reward_func]
2209-
reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters)
2210-
reward_func_kwargs = {
2211-
key: getattr(args, key)
2212-
for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key)
2213-
}
2214-
if 'tokenizer' in reward_func_args:
2215-
reward_func_kwargs['tokenizer'] = self.processing_class
2216-
reward_funcs[i] = reward_func_class(**reward_func_kwargs)
2212+
reward_funcs[i] = reward_func_class(args=args)
22172213
elif not callable(reward_func):
22182214
raise ValueError(f'reward_function {reward_func} is not implemented in swift.rewards')
22192215

@@ -2247,6 +2243,9 @@ def _prepare_rewards(self, reward_funcs, reward_model=None, reward_templates=Non
22472243
self.reward_funcs.append(rm)
22482244
self.reward_func_names.append(rm.config._name_or_path.split('/')[-1])
22492245

2246+
if self.use_gym_env and not self.reward_func_names:
2247+
self.reward_func_names = ['gym_reward']
2248+
22502249
# Reward weights
22512250
if args.reward_weights is not None:
22522251
if len(args.reward_weights) != len(reward_funcs):

swift/rlhf_trainers/rollout_mixin.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -167,8 +167,6 @@ def _prepare_vllm(self):
167167
self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0]
168168
self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0]
169169
self.rollout_enable_lora = broadcast_object_list(enable_lora, from_process=0)[0]
170-
if self.use_gym_env:
171-
self.reward_func_names = ['gym_reward']
172170

173171
elif self.vllm_mode == 'colocate':
174172
if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0:

swift/rlhf_trainers/utils.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,9 @@ def _patched_stateless_pg_create(
206206
patch_stateless_process_group_for_ipv6()
207207

208208

209-
def nanstd(tensor: torch.Tensor, dim: Optional[int] = None, keepdim: bool = False) -> torch.Tensor:
209+
def nanstd(tensor: torch.Tensor,
210+
dim: Optional[Union[int, tuple[int, ...]]] = None,
211+
keepdim: bool = False) -> torch.Tensor:
210212
"""
211213
Compute the standard deviation of a tensor, ignoring NaNs.
212214
@@ -215,7 +217,7 @@ def nanstd(tensor: torch.Tensor, dim: Optional[int] = None, keepdim: bool = Fals
215217
Args:
216218
tensor (`torch.Tensor`):
217219
Input tensor.
218-
dim (`int`, *optional*):
220+
dim (`int` or `tuple[int, ...]`, *optional*):
219221
Dimension to reduce. Defaults to all dimensions.
220222
keepdim (`bool`, *optional*, defaults to `False`):
221223
Whether to keep reduced dimensions.
@@ -227,13 +229,20 @@ def nanstd(tensor: torch.Tensor, dim: Optional[int] = None, keepdim: bool = Fals
227229
mean = torch.nanmean(tensor, dim=dim, keepdim=True)
228230
variance = torch.nanmean((tensor - mean)**2, dim=dim, keepdim=True)
229231
count = torch.sum(~torch.isnan(tensor), dim=dim, keepdim=True)
230-
correction = torch.where(count > 1, count / (count - 1), torch.full_like(count, float('nan')))
231-
std = torch.sqrt(variance * correction)
232+
correction = count / (count - 1)
233+
correction = torch.where(count > 1, correction, torch.full_like(correction, float('nan')))
234+
variance *= correction # Bessel's correction
235+
std = torch.sqrt(variance)
232236
if keepdim:
233237
return std
234238
if dim is None:
235239
return std.squeeze()
236-
return std.squeeze(dim)
240+
if isinstance(dim, int):
241+
return std.squeeze(dim)
242+
dims = [(d if d >= 0 else d + std.ndim) for d in dim]
243+
for d in sorted(dims, reverse=True):
244+
std = std.squeeze(d)
245+
return std
237246

238247

239248
# code borrowed from verl/verl/utils/memory_utils.py

0 commit comments

Comments (0)