44import json
55import os
66import re
7- from typing import Dict , List , Union
7+ from typing import TYPE_CHECKING , Dict , List , Optional , Union
88
99from swift .infer_engine import InferRequest
1010
11+ if TYPE_CHECKING :
12+ from swift .rlhf_trainers import GRPOConfig
13+
1114
1215class ORM :
1316 """Base class for synchronous outcome reward models (ORM).
@@ -20,6 +23,9 @@ def __call__(self, completions, **kwargs) -> List[float]:
2023 return [1.0 if len(c) > 100 else 0.0 for c in completions]
2124 """
2225
def __init__(self, args: Optional['GRPOConfig'] = None, **kwargs):
    """Keep a reference to the training configuration on the instance.

    Args:
        args: Optional GRPO configuration object; stored unchanged as
            ``self.args`` (``None`` when not supplied).
        **kwargs: Accepted so callers may pass extra keyword arguments;
            nothing in this initializer reads them.
    """
    self.args = args
28+
def __call__(self, **kwargs) -> List[float]:
    """Abstract scoring hook; the base implementation has no behavior.

    Args:
        **kwargs: Keyword arguments supplied by the caller; unused here.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
2531
@@ -52,13 +58,17 @@ async def score_single(session, text):
5258 return list(rewards)
5359 """
5460
def __init__(self, args: Optional['GRPOConfig'] = None, **kwargs):
    """Keep a reference to the training configuration on the instance.

    Args:
        args: Optional GRPO configuration object; stored unchanged as
            ``self.args`` (``None`` when not supplied).
        **kwargs: Accepted so callers may pass extra keyword arguments;
            nothing in this initializer reads them.
    """
    self.args = args
63+
async def __call__(self, **kwargs) -> List[float]:
    """Abstract asynchronous scoring hook; the base implementation has no behavior.

    Args:
        **kwargs: Keyword arguments supplied by the caller; unused here.

    Raises:
        NotImplementedError: Always, once the coroutine is awaited.
    """
    raise NotImplementedError
5766
5867
5968class MathAccuracy (ORM ):
6069
61- def __init__ (self ):
70+ def __init__ (self , args = None , ** kwargs ):
71+ super ().__init__ (args , ** kwargs )
6272 import importlib .util
6373 assert importlib .util .find_spec ('math_verify' ) is not None , (
6474 'The math_verify package is required but not installed. '
@@ -129,18 +139,13 @@ def __call__(self, completions, **kwargs) -> List[float]:
129139
130140class CosineReward (ORM ):
131141 # https://arxiv.org/abs/2502.03373
132- def __init__ (self ,
133- cosine_min_len_value_wrong : float = - 0.5 ,
134- cosine_max_len_value_wrong : float = 0.0 ,
135- cosine_min_len_value_correct : float = 1.0 ,
136- cosine_max_len_value_correct : float = 0.5 ,
137- cosine_max_len : int = 1000 ,
138- accuracy_orm = None ):
139- self .min_len_value_wrong = cosine_min_len_value_wrong
140- self .max_len_value_wrong = cosine_max_len_value_wrong
141- self .min_len_value_correct = cosine_min_len_value_correct
142- self .max_len_value_correct = cosine_max_len_value_correct
143- self .max_len = cosine_max_len
def __init__(self, args: Optional['GRPOConfig'] = None, accuracy_orm=None, **kwargs):
    """Set up the cosine reward from the GRPO configuration.

    Args:
        args: Optional GRPO configuration providing the ``cosine_*`` settings.
            When it is ``None`` (the declared default) or lacks a setting,
            the defaults of the previous keyword-based constructor are used:
            -0.5 / 0.0 for wrong answers, 1.0 / 0.5 for correct answers,
            and a maximum length of 1000.
        accuracy_orm: Object stored as ``self.accuracy_orm``; a new
            ``MathAccuracy`` instance is created when this is falsy.
        **kwargs: Extra keyword arguments, forwarded to the base initializer
            so this signature matches the sibling reward classes.
    """
    super().__init__(args, **kwargs)
    # ``args`` is optional, so it must not be dereferenced directly:
    # ``getattr`` with a default also covers ``args is None``.
    self.min_len_value_wrong = getattr(args, 'cosine_min_len_value_wrong', -0.5)
    self.max_len_value_wrong = getattr(args, 'cosine_max_len_value_wrong', 0.0)
    self.min_len_value_correct = getattr(args, 'cosine_min_len_value_correct', 1.0)
    self.max_len_value_correct = getattr(args, 'cosine_max_len_value_correct', 0.5)
    self.max_len = getattr(args, 'cosine_max_len', 1000)
    self.accuracy_orm = accuracy_orm or MathAccuracy()
145150
146151 @staticmethod
@@ -169,9 +174,10 @@ def __call__(self, completions, solution, **kwargs) -> List[float]:
169174
170175class RepetitionPenalty (ORM ):
171176 # https://arxiv.org/abs/2502.03373
172- def __init__ (self , repetition_n_grams : int = 3 , repetition_max_penalty : float = - 1.0 ):
173- self .ngram_size = repetition_n_grams
174- self .max_penalty = repetition_max_penalty
def __init__(self, args: Optional['GRPOConfig'] = None, **kwargs):
    """Set up the repetition penalty from the GRPO configuration.

    Args:
        args: Optional GRPO configuration providing ``repetition_n_grams``
            and ``repetition_max_penalty``. When it is ``None`` (the declared
            default) or lacks a setting, the defaults of the previous
            keyword-based constructor are used: n-gram size 3 and a maximum
            penalty of -1.0.
        **kwargs: Accepted for signature compatibility; not used here.
    """
    super().__init__(args)
    # ``args`` is optional, so it must not be dereferenced directly:
    # ``getattr`` with a default also covers ``args is None``.
    self.ngram_size = getattr(args, 'repetition_n_grams', 3)
    self.max_penalty = getattr(args, 'repetition_max_penalty', -1.0)
175181
176182 @staticmethod
177183 def zipngram (text : str , ngram_size : int ):
@@ -208,10 +214,11 @@ def __call__(self, completions, **kwargs) -> List[float]:
208214
209215class SoftOverlong (ORM ):
210216
211- def __init__ (self , soft_max_length , soft_cache_length ):
212- assert soft_cache_length < soft_max_length
213- self .soft_max_length = soft_max_length
214- self .soft_cache_length = soft_cache_length
def __init__(self, args: Optional['GRPOConfig'] = None, **kwargs):
    """Set up the soft-overlong reward from the GRPO configuration.

    Args:
        args: GRPO configuration providing ``soft_max_length`` and
            ``soft_cache_length``. Both values are required; there are no
            fallback defaults for them.
        **kwargs: Accepted for signature compatibility; not used here.

    Raises:
        ValueError: If ``args`` is ``None``, if either length is unset, or if
            ``soft_cache_length`` is not strictly smaller than
            ``soft_max_length``.
    """
    super().__init__(args)
    # ``getattr`` with a default tolerates ``args is None`` so the failure
    # below is an explicit message instead of an AttributeError.
    soft_max_length = getattr(args, 'soft_max_length', None)
    soft_cache_length = getattr(args, 'soft_cache_length', None)
    if soft_max_length is None or soft_cache_length is None:
        raise ValueError('SoftOverlong requires both `soft_max_length` and `soft_cache_length` to be set in args.')
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if not soft_cache_length < soft_max_length:
        raise ValueError(f'`soft_cache_length` ({soft_cache_length}) must be smaller than '
                         f'`soft_max_length` ({soft_max_length}).')
    self.soft_max_length = soft_max_length
    self.soft_cache_length = soft_cache_length
215222
216223 def __call__ (self , completions , ** kwargs ) -> List [float ]:
217224 rewards = []
@@ -369,7 +376,8 @@ def evaluate_rougel(cand_list: list, ref_list: list):
369376
370377class MathORM (ORM ):
371378
372- def __init__ (self ):
379+ def __init__ (self , args = None , ** kwargs ):
380+ super ().__init__ (args )
373381 from transformers .utils import strtobool
374382 self .use_opencompass = strtobool (os .environ .get ('USE_OPENCOMPASS_EVALUATOR' , 'False' ))
375383 if self .use_opencompass :
0 commit comments