Skip to content

Commit 59e789d

Browse files
committed
megatron
1 parent 956675c commit 59e789d

File tree

2 files changed

+8
-14
lines changed

2 files changed

+8
-14
lines changed

swift/megatron/trainers/grpo_trainer.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
class MegatronGRPOTrainer(MegatronRolloutMixin, MegatronRLHFTrainer):
4545

46-
def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs):
46+
def __init__(self, args: MegatronArguments, template: Template, **kwargs):
4747
self.vllm_client = kwargs.pop('vllm_client')
4848
super().__init__(args, template)
4949
self.args = args
@@ -145,14 +145,7 @@ def _prepare_rewards(self):
145145
for i, reward_func in enumerate(reward_funcs):
146146
if reward_func in orms:
147147
reward_func_class = orms[reward_func]
148-
reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters)
149-
reward_func_kwargs = {
150-
key: getattr(args, key)
151-
for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key)
152-
}
153-
if 'tokenizer' in reward_func_args:
154-
reward_func_kwargs['tokenizer'] = self.processing_class
155-
reward_funcs[i] = reward_func_class(**reward_func_kwargs)
148+
reward_funcs[i] = reward_func_class(args=self.args)
156149
elif not callable(reward_func):
157150
raise ValueError(f'reward_function {reward_func} is not implemented in swift.rewards')
158151

swift/rewards/orm.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from swift.infer_engine import InferRequest
1010

1111
if TYPE_CHECKING:
12+
from swift.megatron.arguments import MegatronArguments
1213
from swift.rlhf_trainers import GRPOConfig
1314

1415

@@ -23,7 +24,7 @@ def __call__(self, completions, **kwargs) -> List[float]:
2324
return [1.0 if len(c) > 100 else 0.0 for c in completions]
2425
"""
2526

26-
def __init__(self, args: Optional['GRPOConfig'] = None, **kwargs):
27+
def __init__(self, args: Optional[Union['GRPOConfig', 'MegatronArguments']] = None, **kwargs):
2728
self.args = args
2829

2930
def __call__(self, **kwargs) -> List[float]:
@@ -58,7 +59,7 @@ async def score_single(session, text):
5859
return list(rewards)
5960
"""
6061

61-
def __init__(self, args: Optional['GRPOConfig'] = None, **kwargs):
62+
def __init__(self, args: Optional[Union['GRPOConfig', 'MegatronArguments']] = None, **kwargs):
6263
self.args = args
6364

6465
async def __call__(self, **kwargs) -> List[float]:
@@ -139,7 +140,7 @@ def __call__(self, completions, **kwargs) -> List[float]:
139140

140141
class CosineReward(ORM):
141142
# https://arxiv.org/abs/2502.03373
142-
def __init__(self, args: Optional['GRPOConfig'] = None, accuracy_orm=None):
143+
def __init__(self, args: Optional[Union['GRPOConfig', 'MegatronArguments']] = None, accuracy_orm=None):
143144
super().__init__(args)
144145
self.min_len_value_wrong = args.cosine_min_len_value_wrong
145146
self.max_len_value_wrong = args.cosine_max_len_value_wrong
@@ -174,7 +175,7 @@ def __call__(self, completions, solution, **kwargs) -> List[float]:
174175

175176
class RepetitionPenalty(ORM):
176177
# https://arxiv.org/abs/2502.03373
177-
def __init__(self, args: Optional['GRPOConfig'] = None, **kwargs):
178+
def __init__(self, args: Optional[Union['GRPOConfig', 'MegatronArguments']] = None, **kwargs):
178179
super().__init__(args)
179180
self.ngram_size = args.repetition_n_grams
180181
self.max_penalty = args.repetition_max_penalty
@@ -214,7 +215,7 @@ def __call__(self, completions, **kwargs) -> List[float]:
214215

215216
class SoftOverlong(ORM):
216217

217-
def __init__(self, args: Optional['GRPOConfig'] = None, **kwargs):
218+
def __init__(self, args: Optional[Union['GRPOConfig', 'MegatronArguments']] = None, **kwargs):
218219
super().__init__(args)
219220
assert args.soft_cache_length < args.soft_max_length
220221
self.soft_max_length = args.soft_max_length

0 commit comments

Comments
 (0)