Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 46 additions & 13 deletions src/mcore_bridge/bridge/gpt_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
import transformers
from megatron.core import mpu
from packaging import version
from peft import PeftModel
from peft.utils import ModulesToSaveWrapper
from tqdm import tqdm
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

from mcore_bridge.tuners import LoraParallelLinear
from mcore_bridge.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr,
Expand Down Expand Up @@ -66,7 +65,6 @@ def __init__(self, config):
self.pp_group = mpu.get_pipeline_model_parallel_group()
self.etp_group = mpu.get_expert_tensor_parallel_group()
self.ep_group = mpu.get_expert_model_parallel_group()
self.is_transformers_5 = version.parse(transformers.__version__) >= version.parse('5.0.0.dev')
self.tp_rank = mpu.get_tensor_model_parallel_rank()
self.pp_rank = mpu.get_pipeline_model_parallel_rank()
self.etp_rank = mpu.get_expert_tensor_parallel_rank()
Expand Down Expand Up @@ -1615,7 +1613,14 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx:
hf_state_dict.update(origin_hf_state_dict)
return hf_state_dict

def load_weights(self, mg_models, hf_model_dir: str, peft_format: bool = False, adapter_name: str = 'default'):
def load_weights(
self,
mg_models,
hf_model_dir: str,
peft_format: bool = False,
adapter_name: str = 'default',
converter: Optional[Callable] = None,
):
"""Load weights from safetensors (HuggingFace) format into Megatron model.

Args:
Expand All @@ -1624,24 +1629,38 @@ def load_weights(self, mg_models, hf_model_dir: str, peft_format: bool = False,
peft_format: Whether the weights are in PEFT (LoRA, etc.) format. Defaults to False.
If True, loads LoRA delta weights. If False, loads the full model weights.
adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
converter: Used to perform key-value conversion on the newly loaded state_dict.
"""
self._peft_format = peft_format
self._adapter_name = adapter_name
mg_models = unwrap_model(mg_models)
self._disable_tqdm = False
with torch.no_grad(), SafetensorLazyLoader(hf_model_dir, peft_format=peft_format) as loader:
state_dict = loader.get_state_dict()
if converter:
new_state_dict = {}
for k, v in state_dict.items():
kv = converter(k, v)
if kv is None:
continue
k, v = kv
new_state_dict[k] = v
state_dict = new_state_dict
Comment on lines +1640 to +1648

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The bridge's internal conversion logic (e.g., in `_set_state_dict`) explicitly calls `.load()` on every value in the state dict. If the user-provided `converter` returns a raw `torch.Tensor` instead of a `LazyTensor` (which happens if `.load()` is called inside the converter to modify the weight), the subsequent code will fail with an `AttributeError`. Ensure that any value returned by the converter that does not have a `.load()` method is wrapped appropriately (e.g., in a `LazyTensor`) to maintain compatibility with the rest of the bridge.

Suggested change
if converter:
new_state_dict = {}
for k, v in state_dict.items():
kv = converter(k, v)
if kv is None:
continue
k, v = kv
new_state_dict[k] = v
state_dict = new_state_dict
if converter:
from mcore_bridge.utils.safetensors import LazyTensor
new_state_dict = {}
for k, v in state_dict.items():
kv = converter(k, v)
if kv is None:
continue
k, v = kv
if not hasattr(v, 'load'):
v = LazyTensor(tensor=v)
new_state_dict[k] = v
state_dict = new_state_dict

hf_prefix = 'base_model.model.' if peft_format else ''
for mg_model in mg_models:
list(self._convert([mg_model], state_dict, hf_prefix, True, 'Loading: '))

def export_weights(self,
mg_models,
target_device=None,
only_master_rank: bool = False,
peft_format: bool = False,
tqdm_desc: str = 'Exporting: ',
disable_tqdm: bool = True):
def export_weights(
self,
mg_models,
target_device=None,
only_master_rank: bool = False,
peft_format: bool = False,
adapter_name: str = 'default',
converter: Optional[Callable] = None,
tqdm_desc: str = 'Exporting: ',
disable_tqdm: bool = True,
):
"""Export Megatron model weights to safetensors (HuggingFace) format as a generator.

This method yields weight tensors one by one for streaming save operations or RL weight synchronization,
Expand All @@ -1654,6 +1673,8 @@ def export_weights(self,
peft_format: Whether to export in PEFT (LoRA, etc.) format. Defaults to False.
- If True, exports only LoRA delta weights. If False, exports the complete model weights
(e.g., after merge-lora or full-parameter fine-tuning).
adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
converter: Used to perform key-value conversion on the newly exported state_dict.
tqdm_desc: Description text for the progress bar. Defaults to 'Exporting: '.
disable_tqdm: Whether to disable the tqdm progress bar. Defaults to True.

Expand All @@ -1663,8 +1684,8 @@ def export_weights(self,
self._target_device = target_device
self._only_master_rank = only_master_rank
self._peft_format = peft_format
self._adapter_name = adapter_name
self._disable_tqdm = disable_tqdm
self._adapter_name = 'default'
self._peft_target_modules = set()
self._peft_modules_to_save = set()
hf_prefix = 'base_model.model.' if peft_format else ''
Expand All @@ -1674,13 +1695,21 @@ def export_weights(self,
mg_models[i] = mg_model.model
self.config = mg_models[0].config
with torch.no_grad():
yield from self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc)
for k, v in self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc):
if converter:
kv = converter(k, v)
if kv is None:
continue
k, v = kv
yield k, v

def save_weights(
self,
mg_models,
output_dir: str,
peft_format: bool = False,
adapter_name: str = 'default',
converter: Optional[Callable] = None,
max_shard_size: str = '5GB',
) -> None:
"""Save Megatron model checkpoint in safetensors (HuggingFace) format.
Expand All @@ -1695,6 +1724,8 @@ def save_weights(
peft_format: Whether to save in PEFT (LoRA, etc.) format. Defaults to False.
If True, saves LoRA delta weights. If False, saves the complete model weights
(e.g., after merge-lora or full-parameter fine-tuning).
adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
converter: Used to perform key-value conversion on the newly exported state_dict.
max_shard_size: Maximum size of a single storage file, default is '5GB'.
"""
gc_collect()
Expand All @@ -1705,6 +1736,8 @@ def save_weights(
target_device='cpu',
only_master_rank=True,
peft_format=peft_format,
adapter_name=adapter_name,
converter=converter,
tqdm_desc='Saving: ',
disable_tqdm=False):
saver.add_tensor(k, v)
Expand Down
Loading