From de8b44fbef82660301f981082fce5fde8a55724e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 9 Jun 2026 15:59:39 +0800 Subject: [PATCH 1/9] megatron remove num_samples --- swift/megatron/trainers/base.py | 4 ++-- swift/megatron/trainers/utils.py | 9 +++------ swift/template/base.py | 9 --------- 3 files changed, 5 insertions(+), 17 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 5990faf10d..cf211b028e 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -984,8 +984,8 @@ def _should_use_npu_generated_attention_mask(self, args) -> bool: return (is_torch_npu_available() and args.task_type == 'causal_lm' and not args.padding_free and getattr(args, 'attention_backend', None) != 'local' and getattr(args, 'use_flash_attn', False)) - def _prepare_batch(self, data, vp_stage=None, num_samples=None): - return prepare_batch(self.args, data, vp_stage=vp_stage, num_samples=num_samples) + def _prepare_batch(self, data, vp_stage=None): + return prepare_batch(self.args, data, vp_stage=vp_stage) def get_batch(self, data_iterator, vp_stage=None): """Generate a batch.""" diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 9dde1fd13e..66f229abef 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -353,15 +353,13 @@ def _should_use_npu_attention_mask(args) -> bool: and getattr(args, 'attention_backend', None) != 'local' and getattr(args, 'use_flash_attn', False)) -def prepare_batch(args, data, vp_stage=None, num_samples=None): +def prepare_batch(args, data, vp_stage=None): """Prepare a micro-batch for Megatron forward: PP slicing, packed_seq_params, CP slicing. Extracted from BaseMegatronTrainer._prepare_batch for reuse in ray workers. """ batch = get_batch_on_this_pp_rank(args, data, vp_stage=vp_stage) - if num_samples is None: - num_samples = batch.pop('num_samples') - seq_lens = batch.pop('seq_lens', None) + seq_lens = batch.pop('seq_lens') text_position_ids = batch.pop('text_position_ids', None) if text_position_ids is None: text_position_ids = batch.get('position_ids') @@ -373,7 +371,6 @@ def prepare_batch(args, data, vp_stage=None, num_samples=None): batch.pop('attention_mask_2d', None) if args.padding_free and text_position_ids is not None: batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) - batch['packed_seq_params'].num_samples = num_samples if seq_lens is not None: batch['packed_seq_params'].seq_lens = torch.tensor(seq_lens, device=text_position_ids.device) batch = get_batch_on_this_cp_rank(args, batch) @@ -426,7 +423,7 @@ def compute_per_token_logps_fn(model, args, data_iterator, temperature=1.0, no_g packed_seq_params = data.get('packed_seq_params') if packed_seq_params is not None: - num_samples = packed_seq_params.num_samples + num_samples = len(packed_seq_params.seq_lens) else: input_ids = data.get('input_ids') num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0] diff --git a/swift/template/base.py b/swift/template/base.py index cfa1e07207..c156f129d7 100644 --- a/swift/template/base.py +++ b/swift/template/base.py @@ -1630,7 +1630,6 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int from swift.dataset import RowPreprocessor if self.packing and isinstance(batch[0], list): batch = sum(batch, start=[]) - num_samples = len(batch) if self.task_type == 'causal_lm': if self.mode in {'transformers', 'train'}: res = self._data_collator(batch, padding_to=padding_to) @@ -1655,10 +1654,6 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int extra_kwargs = [b['_extra_kwargs'] for b in batch if b.get('_extra_kwargs') is not None] extra_kwargs = RowPreprocessor.rows_to_batched(extra_kwargs) res.update({k: v for k, v in extra_kwargs.items() if k not in res}) - if 'num_samples' in res: - num_samples = res.pop('num_samples') - if self.use_megatron: - res['num_samples'] = num_samples return res @staticmethod @@ -1765,9 +1760,7 @@ def _embedding_data_collator(self, for prefix in indexes: new_batch += self._fetch_inputs_startswith([b], prefix) labels.extend(b.get('labels', [])) - num_samples = len(new_batch) res = self._data_collator(new_batch, padding_to=padding_to) - res['num_samples'] = num_samples if labels: res['labels'] = torch.tensor(labels, dtype=torch.float32) return res @@ -1801,9 +1794,7 @@ def _reranker_data_collator(self, for key in b.keys() if isinstance(b[key], list) and b[key][j + positive_num] is not None }) labels_list.append(0) - num_samples = len(new_batch) res = self._data_collator(new_batch, padding_to=padding_to) - res['num_samples'] = num_samples if labels_list: res['labels'] = torch.tensor(labels_list, dtype=torch.long) else: From 2cd141bf9dd93169257aace2c17f95a2a39bb748 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 9 Jun 2026 17:28:38 +0800 Subject: [PATCH 2/9] update --- swift/megatron/trainers/base.py | 9 +++------ swift/megatron/trainers/dpo_trainer.py | 22 +++++++++++----------- swift/megatron/trainers/kto_trainer.py | 10 +++------- swift/megatron/trainers/rlhf_mixin.py | 5 +++-- swift/megatron/trainers/utils.py | 2 +- 5 files changed, 21 insertions(+), 27 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index cf211b028e..4bd98997ad 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -1012,7 +1012,7 @@ def _collect_config_info(self) -> Dict[str, str]: } return {} - def get_last_tokens(self, output_tensor, packed_seq_params=None, attention_mask=None, num_samples=None): + def get_last_tokens(self, output_tensor, packed_seq_params=None, attention_mask=None): if self.args.context_parallel_size > 1: output_tensor = reconstruct_tensor_cp(output_tensor, packed_seq_params, dim=1) if packed_seq_params is None: @@ -1022,10 +1022,7 @@ def get_last_tokens(self, output_tensor, packed_seq_params=None, attention_mask= last_token_idx = get_last_valid_indices(attention_mask.long()) last_tokens = output_tensor[torch.arange(output_tensor.shape[0]), last_token_idx] else: - num_samples = num_samples or packed_seq_params.num_samples - if self.args.context_parallel_size > 1: - last_token_idx = packed_seq_params.cu_seqlens_q[:num_samples] + packed_seq_params.seq_lens - 1 - else: - last_token_idx = packed_seq_params.cu_seqlens_q[1:num_samples + 1] - 1 + num_samples = packed_seq_params.seq_lens.shape[0] + last_token_idx = packed_seq_params.cu_seqlens_q[:num_samples] + packed_seq_params.seq_lens - 1 last_tokens = output_tensor[0, last_token_idx] return last_tokens diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index 56b5222b6f..98584b3c16 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -35,23 +35,23 @@ def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed ref_output_tensor = output_tensor[:output_tensor.shape[0] // 2].detach() output_tensor = output_tensor[output_tensor.shape[0] // 2:] args = self.args - num_samples = labels.shape[0] // 2 if packed_seq_params is None else packed_seq_params.num_samples + num_samples = labels.shape[0] if packed_seq_params is None else packed_seq_params.seq_lens.shape[0] - logps = self.get_logps(output_tensor, labels, packed_seq_params, num_samples * 2) - ref_logps = self.get_logps(ref_output_tensor, labels, packed_seq_params, num_samples * 2) + logps = self.get_logps(output_tensor, labels, packed_seq_params) + ref_logps = self.get_logps(ref_output_tensor, labels, packed_seq_params) loss, chosen_rewards, rejected_rewards = self.dummy_dpo_trainer.dpo_loss( - logps[:num_samples], - logps[num_samples:], - ref_logps[:num_samples], - ref_logps[num_samples:], + logps[:num_samples // 2], + logps[num_samples // 2:], + ref_logps[:num_samples // 2], + ref_logps[num_samples // 2:], ) if args.rpo_alpha: loss_mask = labels != -100 if args.padding_free: - num_tokens = packed_seq_params.cu_seqlens_q[num_samples] // args.context_parallel_size + num_tokens = packed_seq_params.cu_seqlens_q[num_samples // 2] // args.context_parallel_size loss_mask[:, num_tokens:] = 0 else: - loss_mask[num_samples:] = 0 + loss_mask[num_samples // 2:] = 0 nll_loss = torch.concat([torch.sum(output_tensor * loss_mask)[None], loss_mask.sum()[None]]) if args.context_parallel_size > 1: nll_loss = all_reduce(nll_loss, group=mpu.get_context_parallel_group()) @@ -60,8 +60,8 @@ def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed loss = loss.mean() metric = { 'loss': loss.detach().clone(), - 'logps/chosen': logps[:num_samples].mean(), - 'logps/rejected': logps[num_samples:].mean(), + 'logps/chosen': logps[:num_samples // 2].mean(), + 'logps/rejected': logps[num_samples // 2:].mean(), 'rewards/chosen': chosen_rewards.mean(), 'rewards/rejected': rejected_rewards.mean(), 'rewards/accuracies': (chosen_rewards > rejected_rewards).float().mean(), diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index d9799dd0de..f7275af6ff 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -41,11 +41,8 @@ def __init__(self, args, template): self.dummy_kto_trainer = DummyKTOTrainer(args) def _kto_get_logps(self, output_tensor, data, is_KL: bool, is_ref: bool, length: int): - labels = data['labels'] - packed_seq_params = data.get('packed_seq_params') - num_samples = output_tensor.shape[0] if packed_seq_params is None else packed_seq_params.num_samples output = self._get_input_tensor(output_tensor, is_KL, is_ref, length, dim=1) - return self.get_logps(output, labels, packed_seq_params, num_samples) + return self.get_logps(output, data['labels'], data.get('packed_seq_params')) def _get_kto_length(self, data: Dict[str, Any]) -> int: if 'packed_seq_params' in data: @@ -150,15 +147,14 @@ def forward_step(self, data_iterator, model): res = torch.concat([output_tensor, ref_output_tensor], dim=dim) return res, partial(self.loss_func, data=data, kl_data=kl_data, label=label) - def _prepare_batch(self, data, vp_stage=None, num_samples=None): + def _prepare_batch(self, data, vp_stage=None): res = [] - num_samples = data.pop('num_samples') for key in ['completion_', 'KL_completion_']: _data = {k[len(key):]: v for k, v in data.items() if k.startswith(key)} if not self.args.calculate_KL and key == 'KL_completion_': _data = {} else: - _data = super()._prepare_batch(_data, vp_stage, num_samples) + _data = super()._prepare_batch(_data, vp_stage) res.append(_data) res[0]['label'] = data['label'] return res diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index 0388a15c7f..0d0e4f0212 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -70,15 +70,16 @@ def null_ref_context(self): for m in self.peft_models: m.set_adapter('default') - def get_logps(self, output_tensor, labels, packed_seq_params, num_samples, per_token=False): + def get_logps(self, output_tensor, labels, packed_seq_params, per_token=False): args = self.args per_token_logps = -output_tensor loss_mask = labels != -100 per_token_logps = per_token_logps * loss_mask + num_samples = packed_seq_params.seq_lens.shape[0] if per_token: if args.context_parallel_size > 1: per_token_logps = reconstruct_tensor_cp(args.context_parallel_size, per_token_logps, packed_seq_params, - num_samples or packed_seq_params.num_samples) + num_samples) return per_token_logps if args.padding_free: diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 66f229abef..6415f8ea1a 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -423,7 +423,7 @@ def compute_per_token_logps_fn(model, args, data_iterator, temperature=1.0, no_g packed_seq_params = data.get('packed_seq_params') if packed_seq_params is not None: - num_samples = len(packed_seq_params.seq_lens) + num_samples = packed_seq_params.seq_lens.shape[0] else: input_ids = data.get('input_ids') num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0] From 3dc00cd8dfe8e94072b3dcfb69b45fdac1bbbfaf Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 15 Jun 2026 15:52:13 +0800 Subject: [PATCH 3/9] update --- swift/megatron/trainers/grpo_trainer.py | 2 +- swift/megatron/trainers/reward_trainer.py | 4 ++-- swift/megatron/trainers/trainer.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 3ea5478fb3..13ea7a81c9 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1186,7 +1186,7 @@ def forward_step(self, data_iterator, model): logits_packed, labels, compute_entropy=self.compute_entropy) if args.context_parallel_size > 1: - num_samples = packed_seq_params.num_samples if args.padding_free else micro_batch_size + num_samples = packed_seq_params.seq_lens.shape[0] if args.padding_free else micro_batch_size cp_size = args.context_parallel_size per_token_logps_packed = reconstruct_tensor_cp(cp_size, per_token_logps_packed, packed_seq_params, num_samples) diff --git a/swift/megatron/trainers/reward_trainer.py b/swift/megatron/trainers/reward_trainer.py index 97c38f972a..4398ec84fc 100644 --- a/swift/megatron/trainers/reward_trainer.py +++ b/swift/megatron/trainers/reward_trainer.py @@ -14,8 +14,8 @@ class MegatronRewardTrainer(MegatronRLHFTrainer): def loss_func(self, output_tensor, *, data): packed_seq_params = data.get('packed_seq_params') margin = data.pop('margin', None) - num_samples = output_tensor.shape[0] // 2 if packed_seq_params is None else packed_seq_params.num_samples - rewards = self.get_last_tokens(output_tensor, packed_seq_params, data.get('attention_mask'), 2 * num_samples) + num_samples = output_tensor.shape[0] // 2 if packed_seq_params is None else packed_seq_params.seq_lens.shape[0] // 2 + rewards = self.get_last_tokens(output_tensor, packed_seq_params, data.get('attention_mask')) rewards_chosen, rewards_rejected = torch.split(rewards, num_samples, dim=0) if margin is not None: loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - margin).mean() diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py index e044798220..e610e9225d 100644 --- a/swift/megatron/trainers/trainer.py +++ b/swift/megatron/trainers/trainer.py @@ -78,7 +78,7 @@ def _compute_channel_loss(self, losses, loss_mask, channels, packed_seq_params=N args = self.args metrics = defaultdict(lambda: torch.tensor([0.0, 0.0], dtype=torch.float32, device=torch.cuda.current_device())) if args.padding_free: - num_samples = packed_seq_params.num_samples + num_samples = packed_seq_params.seq_lens.shape[0] cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples + 1] // args.context_parallel_size for i in range(cu_seqlens.shape[0] - 1): channel = None if channels is None else channels[i] From d27c40b4c3bc73d55eba8d9dcaabd397b2c9af53 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 15 Jun 2026 15:58:22 +0800 Subject: [PATCH 4/9] update --- swift/megatron/trainers/utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 6415f8ea1a..d0ea1178ba 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -359,7 +359,13 @@ def prepare_batch(args, data, vp_stage=None): Extracted from BaseMegatronTrainer._prepare_batch for reuse in ray workers. """ batch = get_batch_on_this_pp_rank(args, data, vp_stage=vp_stage) - seq_lens = batch.pop('seq_lens') + seq_lens = batch.pop('seq_lens', None) + num_samples = batch.pop('num_samples', None) + if seq_lens is not None: + if num_samples is not None: + assert num_samples == len(seq_lens), ( + f"'num_samples' ({num_samples}) is inconsistent with len(seq_lens) ({len(seq_lens)}).") + num_samples = len(seq_lens) text_position_ids = batch.pop('text_position_ids', None) if text_position_ids is None: text_position_ids = batch.get('position_ids') @@ -373,6 +379,8 @@ def prepare_batch(args, data, vp_stage=None): batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) if seq_lens is not None: batch['packed_seq_params'].seq_lens = torch.tensor(seq_lens, device=text_position_ids.device) + if num_samples is not None: + batch['packed_seq_params'].num_samples = num_samples batch = get_batch_on_this_cp_rank(args, batch) return batch From f6c74dd07d140be00cfcb095bb91cc1dc6fddffd Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 15 Jun 2026 16:12:37 +0800 Subject: [PATCH 5/9] update --- swift/megatron/trainers/gkd_trainer.py | 1 - swift/megatron/trainers/reward_trainer.py | 4 +- swift/megatron/utils/convert_utils.py | 2 +- tests/megatron/test_opsd.py | 45 +++++++++++++++++++++++ 4 files changed, 48 insertions(+), 4 deletions(-) create mode 100644 tests/megatron/test_opsd.py diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index 6ee5a4885c..81c11f8da3 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -171,7 +171,6 @@ def _encode_batch(self, batch: List[Dict]) -> Dict[str, torch.Tensor]: padding_to = get_padding_to(args) encoded_batch = to_device(template.data_collator(encoded_list, padding_to=padding_to), self.device) - encoded_batch['num_samples'] = len(batch) return encoded_batch def _get_random_num(self) -> float: diff --git a/swift/megatron/trainers/reward_trainer.py b/swift/megatron/trainers/reward_trainer.py index 4398ec84fc..1f7d47d438 100644 --- a/swift/megatron/trainers/reward_trainer.py +++ b/swift/megatron/trainers/reward_trainer.py @@ -14,9 +14,9 @@ class MegatronRewardTrainer(MegatronRLHFTrainer): def loss_func(self, output_tensor, *, data): packed_seq_params = data.get('packed_seq_params') margin = data.pop('margin', None) - num_samples = output_tensor.shape[0] // 2 if packed_seq_params is None else packed_seq_params.seq_lens.shape[0] // 2 + num_samples = output_tensor.shape[0] if packed_seq_params is None else packed_seq_params.seq_lens.shape[0] rewards = self.get_last_tokens(output_tensor, packed_seq_params, data.get('attention_mask')) - rewards_chosen, rewards_rejected = torch.split(rewards, num_samples, dim=0) + rewards_chosen, rewards_rejected = torch.split(rewards, num_samples // 2, dim=0) if margin is not None: loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - margin).mean() else: diff --git a/swift/megatron/utils/convert_utils.py b/swift/megatron/utils/convert_utils.py index 19fa5e0c53..c8d98b0b7b 100644 --- a/swift/megatron/utils/convert_utils.py +++ b/swift/megatron/utils/convert_utils.py @@ -244,7 +244,7 @@ def test_convert_precision(args, hf_model, mg_model, template, test_convert_dtyp mg_inputs['packed_seq_params'] = get_packed_seq_params(text_position_ids) mg_language_model.config.fp8 = None # compat fp8 mg_modules = _find_modules(mg_language_model, ignore_modules=['visual']) - for key in ['labels', 'num_samples', 'attention_mask_2d']: + for key in ['labels', 'seq_lens', 'attention_mask_2d']: mg_inputs.pop(key, None) mg_inputs = get_batch_on_this_cp_rank(args, mg_inputs) _param = next(mg_language_model.parameters()) diff --git a/tests/megatron/test_opsd.py b/tests/megatron/test_opsd.py new file mode 100644 index 0000000000..40932f8451 --- /dev/null +++ b/tests/megatron/test_opsd.py @@ -0,0 +1,45 @@ +import os + +os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' +os.environ['ASCEND_RT_VISIBLE_DEVICES'] = '0,1' +os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' + +if __name__ == '__main__': + from swift.megatron import MegatronRLHFArguments, megatron_rlhf_main + megatron_rlhf_main( + MegatronRLHFArguments( + rlhf_type='gkd', + model='Qwen/Qwen3-4B', + teacher_model='Qwen/Qwen3-4B', + external_plugins=['examples/train/rlhf/opsd/opsd_plugin.py'], + dataset=['open-r1/OpenThoughts-114k-math'], + use_vllm=True, + vllm_mode='colocate', + vllm_gpu_memory_utilization=0.6, + vllm_max_model_len=10240, + tuner_type='lora', + lora_rank=64, + lora_alpha=128, + sleep_level=1, + lmbda=1.0, + beta=0.5, + temperature=1.2, + sft_alpha=0, + torch_dtype='bfloat16', + micro_batch_size=2, + global_batch_size=32, + train_iters=1000, + lr=2e-5, + save_steps=100, + save_total_limit=10, + logging_steps=1, + max_length=8192, + max_completion_length=2048, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + attention_backend='flash', + recompute_granularity='selective', + finetune=True, + no_save_optim=True, + no_save_rng=True, + )) From 7e6570d60707fc30d9329e3f9154a29ab0f846ed Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 15 Jun 2026 16:23:27 +0800 Subject: [PATCH 6/9] update --- swift/megatron/trainers/utils.py | 1 + tests/megatron/test_grpo.py | 2 +- tests/megatron/test_kto.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index d0ea1178ba..e17f115f9f 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -360,6 +360,7 @@ def prepare_batch(args, data, vp_stage=None): """ batch = get_batch_on_this_pp_rank(args, data, vp_stage=vp_stage) seq_lens = batch.pop('seq_lens', None) + # Consider compatibility and security. num_samples = batch.pop('num_samples', None) if seq_lens is not None: if num_samples is not None: diff --git a/tests/megatron/test_grpo.py b/tests/megatron/test_grpo.py index e965990112..8a4fd94dbf 100644 --- a/tests/megatron/test_grpo.py +++ b/tests/megatron/test_grpo.py @@ -9,7 +9,7 @@ megatron_rlhf_main( MegatronRLHFArguments( rlhf_type='grpo', - model='Qwen/Qwen2.5-VL-3B-Instruct', + model='Qwen/Qwen3.5-4B', save_safetensors=True, context_parallel_size=1, tuner_type='lora', diff --git a/tests/megatron/test_kto.py b/tests/megatron/test_kto.py index b2508831dc..669cfb488f 100644 --- a/tests/megatron/test_kto.py +++ b/tests/megatron/test_kto.py @@ -8,7 +8,7 @@ def test_kto(): from swift.megatron import MegatronRLHFArguments, megatron_rlhf_main megatron_rlhf_main( MegatronRLHFArguments( - mcore_model='Qwen2.5-7B-Instruct-mcore', + model='Qwen/Qwen2.5-7B-Instruct', rlhf_type='kto', tuner_type='lora', load_from_cache_file=True, From e3f9b877f4f192fd5bb07eab3788360a065656b6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 15 Jun 2026 16:47:34 +0800 Subject: [PATCH 7/9] fix --- tests/megatron/test_grpo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/megatron/test_grpo.py b/tests/megatron/test_grpo.py index 8a4fd94dbf..4da3f68df7 100644 --- a/tests/megatron/test_grpo.py +++ b/tests/megatron/test_grpo.py @@ -17,6 +17,7 @@ dataset=['AI-ModelScope/clevr_cogen_a_train#10000'], num_train_epochs=1, global_batch_size=128, + vllm_mm_processor_cache_gb=0, micro_batch_size=4, steps_per_generation=4, num_generations=8, From 4bc24744c5c3573c751f3502222af9ce71c33a84 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 15 Jun 2026 17:30:22 +0800 Subject: [PATCH 8/9] fix --- swift/megatron/trainers/rlhf_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index 0d0e4f0212..5cbd83a36f 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -75,7 +75,7 @@ def get_logps(self, output_tensor, labels, packed_seq_params, per_token=False): per_token_logps = -output_tensor loss_mask = labels != -100 per_token_logps = per_token_logps * loss_mask - num_samples = packed_seq_params.seq_lens.shape[0] + num_samples = packed_seq_params.seq_lens.shape[0] if packed_seq_params is not None else labels.shape[0] if per_token: if args.context_parallel_size > 1: per_token_logps = reconstruct_tensor_cp(args.context_parallel_size, per_token_logps, packed_seq_params, From a6ac8a92d4d1906470d30c18a33903d3d87d0ccc Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 15 Jun 2026 19:27:24 +0800 Subject: [PATCH 9/9] fix --- swift/megatron/trainers/gkd_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index 81c11f8da3..6ee5a4885c 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -171,6 +171,7 @@ def _encode_batch(self, batch: List[Dict]) -> Dict[str, torch.Tensor]: padding_to = get_padding_to(args) encoded_batch = to_device(template.data_collator(encoded_list, padding_to=padding_to), self.device) + encoded_batch['num_samples'] = len(batch) return encoded_batch def _get_random_num(self) -> float: