diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 136e647..c33a572 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -5,13 +5,12 @@ import torch import torch.distributed as dist import torch.nn.functional as F -import transformers from megatron.core import mpu from packaging import version from peft import PeftModel from peft.utils import ModulesToSaveWrapper from tqdm import tqdm -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union from mcore_bridge.tuners import LoraParallelLinear from mcore_bridge.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr, @@ -66,7 +65,6 @@ def __init__(self, config): self.pp_group = mpu.get_pipeline_model_parallel_group() self.etp_group = mpu.get_expert_tensor_parallel_group() self.ep_group = mpu.get_expert_model_parallel_group() - self.is_transformers_5 = version.parse(transformers.__version__) >= version.parse('5.0.0.dev') self.tp_rank = mpu.get_tensor_model_parallel_rank() self.pp_rank = mpu.get_pipeline_model_parallel_rank() self.etp_rank = mpu.get_expert_tensor_parallel_rank() @@ -1615,7 +1613,14 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: hf_state_dict.update(origin_hf_state_dict) return hf_state_dict - def load_weights(self, mg_models, hf_model_dir: str, peft_format: bool = False, adapter_name: str = 'default'): + def load_weights( + self, + mg_models, + hf_model_dir: str, + peft_format: bool = False, + adapter_name: str = 'default', + converter: Optional[Callable] = None, + ): """Load weights from safetensors (HuggingFace) format into Megatron model. Args: @@ -1624,6 +1629,7 @@ def load_weights(self, mg_models, hf_model_dir: str, peft_format: bool = False, peft_format: Whether the weights are in PEFT (LoRA, etc.) format. Defaults to False. If True, loads LoRA delta weights. If False, loads the full model weights. adapter_name: Name of the adapter for PEFT models. Defaults to 'default'. + converter: Used to perform key-value conversion on the newly loaded state_dict. """ self._peft_format = peft_format self._adapter_name = adapter_name @@ -1631,17 +1637,30 @@ def load_weights(self, mg_models, hf_model_dir: str, peft_format: bool = False, self._disable_tqdm = False with torch.no_grad(), SafetensorLazyLoader(hf_model_dir, peft_format=peft_format) as loader: state_dict = loader.get_state_dict() + if converter: + new_state_dict = {} + for k, v in state_dict.items(): + kv = converter(k, v) + if kv is None: + continue + k, v = kv + new_state_dict[k] = v + state_dict = new_state_dict hf_prefix = 'base_model.model.' if peft_format else '' for mg_model in mg_models: list(self._convert([mg_model], state_dict, hf_prefix, True, 'Loading: ')) - def export_weights(self, - mg_models, - target_device=None, - only_master_rank: bool = False, - peft_format: bool = False, - tqdm_desc: str = 'Exporting: ', - disable_tqdm: bool = True): + def export_weights( + self, + mg_models, + target_device=None, + only_master_rank: bool = False, + peft_format: bool = False, + adapter_name: str = 'default', + converter: Optional[Callable] = None, + tqdm_desc: str = 'Exporting: ', + disable_tqdm: bool = True, + ): """Export Megatron model weights to safetensors (HuggingFace) format as a generator. This method yields weight tensors one by one for streaming save operations or RL weight synchronization, @@ -1654,6 +1673,8 @@ def export_weights(self, peft_format: Whether to export in PEFT (LoRA, etc.) format. Defaults to False. - If True, exports only LoRA delta weights. If False, exports the complete model weights (e.g., after merge-lora or full-parameter fine-tuning). + adapter_name: Name of the adapter for PEFT models. Defaults to 'default'. + converter: Used to perform key-value conversion on the newly exported state_dict. tqdm_desc: Description text for the progress bar. Defaults to 'Exporting: '. disable_tqdm: Whether to disable the tqdm progress bar. Defaults to True. @@ -1663,8 +1684,8 @@ def export_weights(self, self._target_device = target_device self._only_master_rank = only_master_rank self._peft_format = peft_format + self._adapter_name = adapter_name self._disable_tqdm = disable_tqdm - self._adapter_name = 'default' self._peft_target_modules = set() self._peft_modules_to_save = set() hf_prefix = 'base_model.model.' if peft_format else '' @@ -1674,13 +1695,21 @@ def export_weights(self, mg_models[i] = mg_model.model self.config = mg_models[0].config with torch.no_grad(): - yield from self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc) + for k, v in self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc): + if converter: + kv = converter(k, v) + if kv is None: + continue + k, v = kv + yield k, v def save_weights( self, mg_models, output_dir: str, peft_format: bool = False, + adapter_name: str = 'default', + converter: Optional[Callable] = None, max_shard_size: str = '5GB', ) -> None: """Save Megatron model checkpoint in safetensors (HuggingFace) format. @@ -1695,6 +1724,8 @@ def save_weights( peft_format: Whether to save in PEFT (LoRA, etc.) format. Defaults to False. If True, saves LoRA delta weights. If False, saves the complete model weights (e.g., after merge-lora or full-parameter fine-tuning). + adapter_name: Name of the adapter for PEFT models. Defaults to 'default'. + converter: Used to perform key-value conversion on the newly exported state_dict. max_shard_size: Maximum size of a single storage file, default is '5GB'. """ gc_collect() @@ -1705,6 +1736,8 @@ def save_weights( target_device='cpu', only_master_rank=True, peft_format=peft_format, + adapter_name=adapter_name, + converter=converter, tqdm_desc='Saving: ', disable_tqdm=False): saver.add_tensor(k, v)