3 changes: 2 additions & 1 deletion swift/megatron/arguments/megatron_args.py
@@ -31,7 +31,6 @@ class RLHFMegatronArgumentsMixin:
reference_free: bool = False
label_smoothing: float = 0.
f_divergence_type: str = 'reverse_kl'
loss_type: Optional[str] = None

# kto
desirable_weight: float = 1.
@@ -321,6 +320,8 @@ def __post_init__(self):

@dataclass
class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
loss_type: Optional[str] = None # rlhf / plugins

check_model: bool = True
padded_vocab_size: Optional[int] = None
initialize_embedding: bool = False
2 changes: 2 additions & 0 deletions swift/megatron/model/constant.py
@@ -5,6 +5,8 @@ class LLMMegatronModelType:
glm4 = 'glm4'
minimax_m2 = 'minimax_m2'

qwen3_emb = 'qwen3_emb'


class MLLMMegatronModelType:
qwen2_vl = 'qwen2_vl'
14 changes: 11 additions & 3 deletions swift/megatron/model/gpt_bridge.py
@@ -150,9 +150,16 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]:
elif key in {'linear_fc1'}:
return 1

def _split_tp(self, hf_weight, tp_dim, is_expert):
def _split_tp(self, hf_weight, tp_dim, is_expert, is_embedding: bool):
tp_size = self.etp_size if is_expert else self.tp_size
tp_rank = self.etp_rank if is_expert else self.tp_rank
if is_embedding:
padding_size = math.ceil(hf_weight.shape[0] / tp_size) * tp_size - hf_weight.shape[0]
if padding_size > 0:
new_size = hf_weight.shape[0] + padding_size
logger.warning(
f'Padding embedding from {hf_weight.shape[0]} to {new_size} (padding size: {padding_size})')
hf_weight = F.pad(hf_weight, (0, 0, 0, padding_size))
if tp_dim is not None and tp_size > 1:
tensor = hf_weight.chunk(tp_size, dim=tp_dim)[tp_rank]
else:
@@ -171,12 +178,13 @@ def _set_weight(
):
# tp/etp
tp_dim = self._get_tp_split_dim(mg_key)
tensor = self._split_tp(hf_weight, tp_dim, is_expert)
is_embedding = mg_key in {'embedding.word_embeddings.weight', 'output_layer.weight'}
tensor = self._split_tp(hf_weight, tp_dim, is_expert, is_embedding=is_embedding)
del hf_weight
if not isinstance(mg_param, (list, tuple)):
mg_param = [mg_param]
if hf_scale_inv is not None:
hf_scale_inv = self._split_tp(hf_scale_inv, tp_dim, is_expert)
hf_scale_inv = self._split_tp(hf_scale_inv, tp_dim, is_expert, is_embedding=is_embedding)
hf_scale_inv = hf_scale_inv.chunk(len(mg_param), dim=0)
if offset:
assert hf_scale_inv is None, f'mg_key: {mg_key}'
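
The padding branch added to _split_tp above simply rounds the vocab dimension of the embedding (or output-layer) weight up to a multiple of the tensor-parallel size before sharding, so every rank gets an equal slice. A minimal stand-alone sketch of the same idea in plain PyTorch — the function name and the numbers in the closing comment are illustrative, not part of the bridge's API:

import math

import torch
import torch.nn.functional as F


def split_embedding_for_tp(weight: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    """Pad the vocab dim (dim 0) up to a multiple of tp_size, then take this rank's shard."""
    vocab_size = weight.shape[0]
    padded = math.ceil(vocab_size / tp_size) * tp_size
    if padded > vocab_size:
        # F.pad fills from the last dim backwards, so (0, 0, 0, n) appends n rows along dim 0
        weight = F.pad(weight, (0, 0, 0, padded - vocab_size))
    return weight.chunk(tp_size, dim=0)[tp_rank]


# e.g. a 151_936-row embedding with tp_size=3 is padded to 151_938 rows,
# so each rank holds exactly 50_646 rows.
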
11 changes: 9 additions & 2 deletions swift/megatron/model/gpt_model.py
@@ -139,6 +139,8 @@ def __init__(
parallel_mode=None,
skip_weight_param_allocation=False,
)
elif args.task_type == 'embedding' and self.post_process:
self.output_layer = None

if (self.attention_scaling != 1 or position_embedding_type == 'mrope') and config.apply_rope_fusion:
config.apply_rope_fusion = False
@@ -447,8 +449,13 @@ def _postprocess(
# state ([B, H]) → unsqueeze back to [1, B, H]
# (so that the output layer, which expects S×B×H, receives only the final token)
hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1)

logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
if args.task_type == 'embedding':
logits = hidden_states
if args.sequence_parallel and args.tensor_model_parallel_size > 1:
logits = gather_from_sequence_parallel_region(logits)
else:
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)

# Restore sequence parallel execution to the output layer if necessary.
if sequence_parallel_override:
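
For task_type == 'embedding' the hidden states themselves are returned in place of logits, and under sequence parallelism they first have to be gathered back into a full sequence. Conceptually, gather_from_sequence_parallel_region is an all-gather along the sequence dimension; a rough stand-in using raw torch.distributed, which ignores the autograd handling and the tensor-parallel process group of Megatron's real implementation:

import torch
import torch.distributed as dist


def gather_sequence_shards(local_hidden: torch.Tensor, group=None) -> torch.Tensor:
    """local_hidden: the [s/tp, b, h] shard held by this rank -> full [s, b, h] on every rank."""
    world_size = dist.get_world_size(group=group)
    shards = [torch.empty_like(local_hidden) for _ in range(world_size)]
    dist.all_gather(shards, local_hidden.contiguous(), group=group)
    return torch.cat(shards, dim=0)  # sequence is dim 0 in Megatron's [s, b, h] layout
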
3 changes: 2 additions & 1 deletion swift/megatron/model/gpts/__init__.py
@@ -2,7 +2,7 @@
from swift.model import ModelType
from ..constant import MegatronModelType
from ..register import MegatronModelMeta, register_megatron_model
from . import glm4, minimax_m2, qwen3_next
from . import glm4, minimax_m2, qwen3_emb, qwen3_next

register_megatron_model(
MegatronModelMeta(
@@ -14,6 +14,7 @@
ModelType.yi,
ModelType.openbuddy_llama,
ModelType.qwen3,
ModelType.qwen3_reranker,
ModelType.qwen2_moe,
ModelType.qwen3_moe,
ModelType.internlm3,
30 changes: 30 additions & 0 deletions swift/megatron/model/gpts/qwen3_emb.py
@@ -0,0 +1,30 @@
# Copyright (c) ModelScope Contributors. All rights reserved.

import megatron.core
from packaging import version

from swift.model import ModelType
from ..constant import MegatronModelType
from ..gpt_bridge import GPTBridge
from ..register import MegatronModelMeta, register_megatron_model


class Qwen3EmbBridge(GPTBridge):

def _convert_hf_state_dict(self, hf_state_dict, to_mcore):
res = super()._convert_hf_state_dict(hf_state_dict, to_mcore)
if to_mcore:
res = self._add_prefix(res, 'model.')
elif not to_mcore:
res = self._remove_prefix(res, 'model.')
return res


register_megatron_model(
MegatronModelMeta(
MegatronModelType.qwen3_emb,
[
ModelType.qwen3_emb,
],
bridge_cls=Qwen3EmbBridge,
))
4 changes: 3 additions & 1 deletion swift/megatron/model/model_provider.py
@@ -1,4 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import math
from typing import TYPE_CHECKING, Optional, Union

import megatron.core
@@ -149,7 +150,8 @@ def oom_observer(device, alloc, device_alloc, device_free):
model = megatron_model_meta.model_cls(
config=config,
transformer_layer_spec=transformer_layer_spec,
vocab_size=args.padded_vocab_size,
vocab_size=math.ceil(args.padded_vocab_size / args.tensor_model_parallel_size)
* args.tensor_model_parallel_size,
max_sequence_length=args.max_position_embeddings,
pre_process=pre_process,
post_process=post_process,
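
The math.ceil expression passed as vocab_size above only bumps args.padded_vocab_size up to the next multiple of the tensor-parallel degree so the word embedding can be split evenly across ranks. A tiny worked example with hypothetical numbers:

import math

padded_vocab_size = 151_669            # hypothetical value, for illustration only
tensor_model_parallel_size = 8
vocab_size = math.ceil(padded_vocab_size / tensor_model_parallel_size) * tensor_model_parallel_size
assert vocab_size == 151_672           # divisible by 8, so each rank holds 18_959 rows
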
10 changes: 8 additions & 2 deletions swift/megatron/pipelines/train/sft.py
@@ -8,7 +8,7 @@
from transformers.utils import is_torch_npu_available

from swift.megatron.arguments import MegatronSftArguments
from swift.megatron.trainers import MegatronTrainer
from swift.megatron.trainers import MegatronEmbeddingTrainer, MegatronRerankerTrainer, MegatronTrainer
from swift.megatron.utils import get_padding_to
from swift.pipelines import SwiftSft
from swift.template import TEMPLATE_MAPPING
@@ -29,7 +29,13 @@ class MegatronSft(SwiftSft):
args: args_class

def prepare_trainer(self):
return MegatronTrainer(self.args, self.template)
args = self.args
if args.task_type == 'embedding':
return MegatronEmbeddingTrainer(self.args, self.template)
elif args.task_type in {'reranker', 'generative_reranker'}:
return MegatronRerankerTrainer(self.args, self.template)
else:
return MegatronTrainer(self.args, self.template)

def __init__(self, args: Optional[Union[List[str], MegatronSftArguments]] = None) -> None:
self.train_msg = {}
4 changes: 4 additions & 0 deletions swift/megatron/trainers/__init__.py
@@ -10,6 +10,8 @@
from .kto_trainer import MegatronKTOTrainer
from .reward_trainer import MegatronRewardTrainer
from .rollout_mixin import MegatronRolloutMixin
from .embedding_trainer import MegatronEmbeddingTrainer
from .reranker_trainer import MegatronRerankerTrainer
from .trainer import MegatronTrainer
else:
_import_structure = {
@@ -19,6 +21,8 @@
'kto_trainer': ['MegatronKTOTrainer'],
'reward_trainer': ['MegatronRewardTrainer'],
'rollout_mixin': ['MegatronRolloutMixin'],
'embedding_trainer': ['MegatronEmbeddingTrainer'],
'reranker_trainer': ['MegatronRerankerTrainer'],
'trainer': ['MegatronTrainer'],
}
import sys
46 changes: 46 additions & 0 deletions swift/megatron/trainers/embedding_trainer.py
@@ -0,0 +1,46 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from functools import partial

import torch.nn
from megatron.training import get_args, get_timers

from swift.loss import loss_map
from swift.utils import get_logger
from .base import BaseMegatronTrainer

logger = get_logger()


class MegatronEmbeddingTrainer(BaseMegatronTrainer):

def __init__(self, args, template):
super().__init__(args, template)
if not args.padding_free:
raise ValueError('Currently, task_type embedding only supports padding_free.')
self._loss_func = loss_map[self.args.loss_type](args, self)

def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed_seq_params=None):
args = self.args
last_hidden_state = self.get_last_tokens(output_tensor, packed_seq_params)
loss = self._loss_func({'last_hidden_state': last_hidden_state}, labels)
metric = {'loss': loss.detach().clone()}
metric = self._all_reduce_metric(metric)
return loss, metric

def forward_step(self, data_iterator, model):
timers = get_timers()

# Get the batch.
vp_stage = model.module.module.vp_stage
timers('batch-generator', log_level=2).start()
with self.stimer(bdata=True):
data = self.get_batch(data_iterator, vp_stage)
timers('batch-generator').stop()
labels = data.get('labels')
if self.args.task_type == 'seq_cls':
data.pop('labels', None)
with self.stimer:
output_tensor = model(**data)
packed_seq_params = data.get('packed_seq_params')
loss_func = partial(self.loss_func, labels=labels, packed_seq_params=packed_seq_params)
return output_tensor, loss_func
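
The trainer delegates the objective to swift.loss via loss_map[args.loss_type]; with loss_type='infonce' (as used in the test at the end of this diff) that is an in-batch contrastive loss over the pooled last-token embeddings. A generic InfoNCE sketch for orientation only — the actual swift implementation may differ in temperature, pooling, and how negatives are handled:

import torch
import torch.nn.functional as F


def infonce_loss(anchors: torch.Tensor, positives: torch.Tensor, temperature: float = 0.05) -> torch.Tensor:
    """anchors/positives: [batch, hidden]; every other in-batch positive serves as a negative."""
    anchors = F.normalize(anchors, dim=-1)
    positives = F.normalize(positives, dim=-1)
    logits = anchors @ positives.T / temperature           # [batch, batch] cosine similarities
    targets = torch.arange(anchors.shape[0], device=anchors.device)
    return F.cross_entropy(logits, targets)                # diagonal entries are the positive pairs
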
30 changes: 30 additions & 0 deletions swift/megatron/trainers/reranker_trainer.py
@@ -0,0 +1,30 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch.nn
from megatron.training import get_args, get_timers

from swift.utils import get_logger
from .base import BaseMegatronTrainer

logger = get_logger()


class MegatronRerankerTrainer(BaseMegatronTrainer):

def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed_seq_params=None):
pass

def forward_step(self, data_iterator, model):
timers = get_timers()

# Get the batch.
vp_stage = model.module.module.vp_stage
timers('batch-generator', log_level=2).start()
with self.stimer(bdata=True):
data = self.get_batch(data_iterator, vp_stage)
timers('batch-generator').stop()
labels = data.get('labels')
if self.args.task_type == 'seq_cls':
data.pop('labels', None)
with self.stimer:
output_tensor = model(**data)
packed_seq_params = data.get('packed_seq_params')
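
In this file loss_func is still a stub and forward_step does not yet return an (output, loss_func) pair, so the reranker objective is not visible in the diff. For orientation, a pointwise reranker loss of the kind the test below selects with loss_type='pointwise_reranker' is typically binary cross-entropy on a scalar relevance logit per (query, document) pair; a generic sketch under that assumption, not swift's actual implementation:

import torch
import torch.nn.functional as F


def pointwise_reranker_loss(scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """scores: [batch] raw relevance logits; labels: [batch], 1.0 = relevant, 0.0 = irrelevant."""
    return F.binary_cross_entropy_with_logits(scores, labels.float())
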
8 changes: 4 additions & 4 deletions swift/model/patcher.py
@@ -76,12 +76,12 @@ def lm_head_forward(self, hidden_states):
return hidden_states

lm_heads = ['lm_head', 'output', 'embed_out', 'output_layer']
llm_model = get_lm_head_model(module, model_meta=model_meta, lm_heads=lm_heads)
lm_head_model = get_lm_head_model(module, model_meta=model_meta, lm_heads=lm_heads)

found = False
for lm_head in lm_heads:
if hasattr(llm_model, lm_head):
getattr(llm_model, lm_head).forward = MethodType(lm_head_forward, getattr(llm_model, lm_head))
if hasattr(lm_head_model, lm_head):
getattr(lm_head_model, lm_head).forward = MethodType(lm_head_forward, getattr(lm_head_model, lm_head))
found = True
break

@@ -99,7 +99,7 @@ def _output_embedding_hook(module, args, kwargs, output):
'last_hidden_state': embeddings.contiguous(),
}

llm_model.register_forward_hook(_output_embedding_hook, with_kwargs=True)
lm_head_model.register_forward_hook(_output_embedding_hook, with_kwargs=True)


def patch_output_to_input_device(module: torch.nn.Module):
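
The rename above (llm_model -> lm_head_model) clarifies that the patched object is the module holding the lm head, whose forward is rebound per instance via types.MethodType so the vocab projection is skipped and hidden states flow through to the output-embedding hook. A self-contained toy illustration of that rebinding trick — the Linear layer and names are invented for the example:

from types import MethodType

import torch

head = torch.nn.Linear(16, 100)


def identity_forward(self, hidden_states):
    return hidden_states  # skip the projection; keep hidden states for embedding output


head.forward = MethodType(identity_forward, head)  # only this instance is patched, not the class
x = torch.randn(2, 16)
assert torch.equal(head(x), x)                     # the patched head now returns its input unchanged
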
3 changes: 3 additions & 0 deletions swift/template/base.py
@@ -1538,6 +1538,7 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int]
res = self._seq_cls_data_collator(batch, padding_to=padding_to)
elif self.task_type == 'embedding':
res = self._embedding_data_collator(batch, padding_to=padding_to)
num_samples = res.pop('num_samples')
elif self.task_type in {'reranker', 'generative_reranker'}:
res = self._reranker_data_collator(batch, padding_to=padding_to)
else:
@@ -1650,7 +1651,9 @@ def _embedding_data_collator(self,
for prefix in indexes:
new_batch += self._fetch_inputs_startswith([b], prefix)
labels.extend(b.get('labels', []))
num_samples = len(new_batch)
res = self._data_collator(new_batch, padding_to=padding_to)
res['num_samples'] = num_samples
if labels:
res['labels'] = torch.tensor(labels, dtype=torch.float32)
return res
8 changes: 4 additions & 4 deletions swift/trainers/__init__.py
@@ -5,9 +5,9 @@

if TYPE_CHECKING:
from .arguments import TrainArgumentsMixin, Seq2SeqTrainingArguments, TrainingArguments
from .embedding import EmbeddingTrainer
from .embedding_trainer import EmbeddingTrainer
from .mixin import DataLoaderMixin, SwiftMixin
from .reranker import RerankerTrainer
from .reranker_trainer import RerankerTrainer
from .seq2seq_trainer import Seq2SeqTrainer
from .trainer import Trainer
from .trainer_factory import TrainerFactory
@@ -16,9 +16,9 @@
else:
_import_structure = {
'arguments': ['TrainArgumentsMixin', 'Seq2SeqTrainingArguments', 'TrainingArguments'],
'embedding': ['EmbeddingTrainer'],
'embedding_trainer': ['EmbeddingTrainer'],
'mixin': ['DataLoaderMixin', 'SwiftMixin'],
'reranker': ['RerankerTrainer'],
'reranker_trainer': ['RerankerTrainer'],
'seq2seq_trainer': ['Seq2SeqTrainer'],
'trainer': ['Trainer'],
'trainer_factory': ['TrainerFactory'],
File renamed without changes.
File renamed without changes.
56 changes: 56 additions & 0 deletions tests/megatron/test_embedding.py
@@ -0,0 +1,56 @@
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'


def test_embedding():
from swift.megatron import megatron_sft_main, MegatronSftArguments

Contributor review comment (medium): The import statement from swift.megatron import megatron_sft_main, MegatronSftArguments is repeated in both test_embedding (line 7) and test_reranker (line 31). To follow the DRY (Don't Repeat Yourself) principle and improve code clarity, it's recommended to move this import to the top of the file, outside of the test functions.

megatron_sft_main(
MegatronSftArguments(
model='Qwen/Qwen3-Embedding-0.6B',
task_type='embedding',
dataset=['sentence-transformers/stsb:positive'],
split_dataset_ratio=0.01,
tensor_model_parallel_size=2,
tuner_type='lora',
recompute_granularity='full',
recompute_method='uniform',
recompute_num_layers=1,
loss_type='infonce',
attn_impl='flash_attn',
max_length=2048,
eval_iters=5,
save_interval=5,
no_save_optim=True,
no_save_rng=True,
sequence_parallel=True,
finetune=True))


def test_reranker():
from swift.megatron import megatron_sft_main, MegatronSftArguments
megatron_sft_main(
MegatronSftArguments(
model='Qwen/Qwen3-Reranker-4B',
tuner_type='lora',
load_from_cache_file=True,
task_type='generative_reranker',
dataset=['MTEB/scidocs-reranking#10000'],
loss_type='pointwise_reranker',
split_dataset_ratio=0.01,
tensor_model_parallel_size=2,
recompute_granularity='full',
recompute_method='uniform',
recompute_num_layers=1,
train_iters=100,
eval_iters=5,
save_interval=5,
no_save_optim=True,
no_save_rng=True,
sequence_parallel=True,
finetune=True))


if __name__ == '__main__':
test_embedding()
Contributor review comment on lines +54 to +55 (medium): The test_reranker function is defined but not called when the script is executed directly. To ensure both tests are run, add a call to test_reranker().

Suggested change:
if __name__ == '__main__':
    test_embedding()
    test_reranker()

# test_reranker()