99from swift .infer_engine import InferRequest
1010
1111if TYPE_CHECKING :
12+ from swift .megatron .arguments import MegatronArguments
1213 from swift .rlhf_trainers import GRPOConfig
1314
1415
@@ -23,7 +24,7 @@ def __call__(self, completions, **kwargs) -> List[float]:
2324 return [1.0 if len(c) > 100 else 0.0 for c in completions]
2425 """
2526
26- def __init__ (self , args : Optional ['GRPOConfig' ] = None , ** kwargs ):
27+ def __init__ (self , args : Optional [Union [ 'GRPOConfig' , 'MegatronArguments' ] ] = None , ** kwargs ):
2728 self .args = args
2829
2930 def __call__ (self , ** kwargs ) -> List [float ]:
@@ -58,7 +59,7 @@ async def score_single(session, text):
5859 return list(rewards)
5960 """
6061
61- def __init__ (self , args : Optional ['GRPOConfig' ] = None , ** kwargs ):
62+ def __init__ (self , args : Optional [Union [ 'GRPOConfig' , 'MegatronArguments' ] ] = None , ** kwargs ):
6263 self .args = args
6364
6465 async def __call__ (self , ** kwargs ) -> List [float ]:
@@ -139,7 +140,7 @@ def __call__(self, completions, **kwargs) -> List[float]:
139140
140141class CosineReward (ORM ):
141142 # https://arxiv.org/abs/2502.03373
142- def __init__ (self , args : Optional ['GRPOConfig' ] = None , accuracy_orm = None ):
143+ def __init__ (self , args : Optional [Union [ 'GRPOConfig' , 'MegatronArguments' ] ] = None , accuracy_orm = None ):
143144 super ().__init__ (args )
144145 self .min_len_value_wrong = args .cosine_min_len_value_wrong
145146 self .max_len_value_wrong = args .cosine_max_len_value_wrong
@@ -174,7 +175,7 @@ def __call__(self, completions, solution, **kwargs) -> List[float]:
174175
175176class RepetitionPenalty (ORM ):
176177 # https://arxiv.org/abs/2502.03373
177- def __init__ (self , args : Optional ['GRPOConfig' ] = None , ** kwargs ):
178+ def __init__ (self , args : Optional [Union [ 'GRPOConfig' , 'MegatronArguments' ] ] = None , ** kwargs ):
178179 super ().__init__ (args )
179180 self .ngram_size = args .repetition_n_grams
180181 self .max_penalty = args .repetition_max_penalty
@@ -214,7 +215,7 @@ def __call__(self, completions, **kwargs) -> List[float]:
214215
215216class SoftOverlong (ORM ):
216217
217- def __init__ (self , args : Optional ['GRPOConfig' ] = None , ** kwargs ):
218+ def __init__ (self , args : Optional [Union [ 'GRPOConfig' , 'MegatronArguments' ] ] = None , ** kwargs ):
218219 super ().__init__ (args )
219220 assert args .soft_cache_length < args .soft_max_length
220221 self .soft_max_length = args .soft_max_length
0 commit comments