55import torch
66import torch .distributed as dist
77import torch .nn .functional as F
8- import transformers
98from megatron .core import mpu
109from packaging import version
1110from peft import PeftModel
1211from peft .utils import ModulesToSaveWrapper
1312from tqdm import tqdm
14- from typing import List , Optional , Union
13+ from typing import Callable , List , Optional , Union
1514
1615from mcore_bridge .tuners import LoraParallelLinear
1716from mcore_bridge .utils import (MxFp4Dequantizer , SafetensorLazyLoader , StreamingSafetensorSaver , deep_getattr ,
@@ -66,7 +65,6 @@ def __init__(self, config):
6665 self .pp_group = mpu .get_pipeline_model_parallel_group ()
6766 self .etp_group = mpu .get_expert_tensor_parallel_group ()
6867 self .ep_group = mpu .get_expert_model_parallel_group ()
69- self .is_transformers_5 = version .parse (transformers .__version__ ) >= version .parse ('5.0.0.dev' )
7068 self .tp_rank = mpu .get_tensor_model_parallel_rank ()
7169 self .pp_rank = mpu .get_pipeline_model_parallel_rank ()
7270 self .etp_rank = mpu .get_expert_tensor_parallel_rank ()
@@ -1615,7 +1613,14 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx:
16151613 hf_state_dict .update (origin_hf_state_dict )
16161614 return hf_state_dict
16171615
1618- def load_weights (self , mg_models , hf_model_dir : str , peft_format : bool = False , adapter_name : str = 'default' ):
1616+ def load_weights (
1617+ self ,
1618+ mg_models ,
1619+ hf_model_dir : str ,
1620+ peft_format : bool = False ,
1621+ adapter_name : str = 'default' ,
1622+ converter : Optional [Callable ] = None ,
1623+ ):
16191624 """Load weights from safetensors (HuggingFace) format into Megatron model.
16201625
16211626 Args:
@@ -1624,24 +1629,38 @@ def load_weights(self, mg_models, hf_model_dir: str, peft_format: bool = False,
16241629 peft_format: Whether the weights are in PEFT (LoRA, etc.) format. Defaults to False.
16251630 If True, loads LoRA delta weights. If False, loads the full model weights.
16261631 adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
1632+ converter: Used to perform key-value conversion on the newly loaded state_dict.
16271633 """
16281634 self ._peft_format = peft_format
16291635 self ._adapter_name = adapter_name
16301636 mg_models = unwrap_model (mg_models )
16311637 self ._disable_tqdm = False
16321638 with torch .no_grad (), SafetensorLazyLoader (hf_model_dir , peft_format = peft_format ) as loader :
16331639 state_dict = loader .get_state_dict ()
1640+ if converter :
1641+ new_state_dict = {}
1642+ for k , v in state_dict .items ():
1643+ kv = converter (k , v )
1644+ if kv is None :
1645+ continue
1646+ k , v = kv
1647+ new_state_dict [k ] = v
1648+ state_dict = new_state_dict
16341649 hf_prefix = 'base_model.model.' if peft_format else ''
16351650 for mg_model in mg_models :
16361651 list (self ._convert ([mg_model ], state_dict , hf_prefix , True , 'Loading: ' ))
16371652
1638- def export_weights (self ,
1639- mg_models ,
1640- target_device = None ,
1641- only_master_rank : bool = False ,
1642- peft_format : bool = False ,
1643- tqdm_desc : str = 'Exporting: ' ,
1644- disable_tqdm : bool = True ):
1653+ def export_weights (
1654+ self ,
1655+ mg_models ,
1656+ target_device = None ,
1657+ only_master_rank : bool = False ,
1658+ peft_format : bool = False ,
1659+ adapter_name : str = 'default' ,
1660+ converter : Optional [Callable ] = None ,
1661+ tqdm_desc : str = 'Exporting: ' ,
1662+ disable_tqdm : bool = True ,
1663+ ):
16451664 """Export Megatron model weights to safetensors (HuggingFace) format as a generator.
16461665
16471666 This method yields weight tensors one by one for streaming save operations or RL weight synchronization,
@@ -1654,6 +1673,8 @@ def export_weights(self,
16541673 peft_format: Whether to export in PEFT (LoRA, etc.) format. Defaults to False.
16551674 - If True, exports only LoRA delta weights. If False, exports the complete model weights
16561675 (e.g., after merge-lora or full-parameter fine-tuning).
1676+ adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
1677+ converter: Used to perform key-value conversion on the newly exported state_dict.
16571678 tqdm_desc: Description text for the progress bar. Defaults to 'Exporting: '.
16581679 disable_tqdm: Whether to disable the tqdm progress bar. Defaults to True.
16591680
@@ -1663,8 +1684,8 @@ def export_weights(self,
16631684 self ._target_device = target_device
16641685 self ._only_master_rank = only_master_rank
16651686 self ._peft_format = peft_format
1687+ self ._adapter_name = adapter_name
16661688 self ._disable_tqdm = disable_tqdm
1667- self ._adapter_name = 'default'
16681689 self ._peft_target_modules = set ()
16691690 self ._peft_modules_to_save = set ()
16701691 hf_prefix = 'base_model.model.' if peft_format else ''
@@ -1674,13 +1695,21 @@ def export_weights(self,
16741695 mg_models [i ] = mg_model .model
16751696 self .config = mg_models [0 ].config
16761697 with torch .no_grad ():
1677- yield from self ._convert (mg_models , {}, hf_prefix , False , tqdm_desc = tqdm_desc )
1698+ for k , v in self ._convert (mg_models , {}, hf_prefix , False , tqdm_desc = tqdm_desc ):
1699+ if converter :
1700+ kv = converter (k , v )
1701+ if kv is None :
1702+ continue
1703+ k , v = kv
1704+ yield k , v
16781705
16791706 def save_weights (
16801707 self ,
16811708 mg_models ,
16821709 output_dir : str ,
16831710 peft_format : bool = False ,
1711+ adapter_name : str = 'default' ,
1712+ converter : Optional [Callable ] = None ,
16841713 max_shard_size : str = '5GB' ,
16851714 ) -> None :
16861715 """Save Megatron model checkpoint in safetensors (HuggingFace) format.
@@ -1695,6 +1724,8 @@ def save_weights(
16951724 peft_format: Whether to save in PEFT (LoRA, etc.) format. Defaults to False.
16961725 If True, saves LoRA delta weights. If False, saves the complete model weights
16971726 (e.g., after merge-lora or full-parameter fine-tuning).
1727+ adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
1728+ converter: Used to perform key-value conversion on the newly exported state_dict.
16981729 max_shard_size: Maximum size of a single storage file, default is '5GB'.
16991730 """
17001731 gc_collect ()
@@ -1705,6 +1736,8 @@ def save_weights(
17051736 target_device = 'cpu' ,
17061737 only_master_rank = True ,
17071738 peft_format = peft_format ,
1739+ adapter_name = adapter_name ,
1740+ converter = converter ,
17081741 tqdm_desc = 'Saving: ' ,
17091742 disable_tqdm = False ):
17101743 saver .add_tensor (k , v )
0 commit comments