Skip to content

Commit f15acaf

Browse files
authored
[bridge] Support GPTBridge callback (#8)
1 parent 6f64a27 commit f15acaf

1 file changed

Lines changed: 46 additions & 13 deletions

File tree

src/mcore_bridge/bridge/gpt_bridge.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
import torch
66
import torch.distributed as dist
77
import torch.nn.functional as F
8-
import transformers
98
from megatron.core import mpu
109
from packaging import version
1110
from peft import PeftModel
1211
from peft.utils import ModulesToSaveWrapper
1312
from tqdm import tqdm
14-
from typing import List, Optional, Union
13+
from typing import Callable, List, Optional, Union
1514

1615
from mcore_bridge.tuners import LoraParallelLinear
1716
from mcore_bridge.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr,
@@ -66,7 +65,6 @@ def __init__(self, config):
6665
self.pp_group = mpu.get_pipeline_model_parallel_group()
6766
self.etp_group = mpu.get_expert_tensor_parallel_group()
6867
self.ep_group = mpu.get_expert_model_parallel_group()
69-
self.is_transformers_5 = version.parse(transformers.__version__) >= version.parse('5.0.0.dev')
7068
self.tp_rank = mpu.get_tensor_model_parallel_rank()
7169
self.pp_rank = mpu.get_pipeline_model_parallel_rank()
7270
self.etp_rank = mpu.get_expert_tensor_parallel_rank()
@@ -1615,7 +1613,14 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx:
16151613
hf_state_dict.update(origin_hf_state_dict)
16161614
return hf_state_dict
16171615

1618-
def load_weights(self, mg_models, hf_model_dir: str, peft_format: bool = False, adapter_name: str = 'default'):
1616+
def load_weights(
1617+
self,
1618+
mg_models,
1619+
hf_model_dir: str,
1620+
peft_format: bool = False,
1621+
adapter_name: str = 'default',
1622+
converter: Optional[Callable] = None,
1623+
):
16191624
"""Load weights from safetensors (HuggingFace) format into Megatron model.
16201625
16211626
Args:
@@ -1624,24 +1629,38 @@ def load_weights(self, mg_models, hf_model_dir: str, peft_format: bool = False,
16241629
peft_format: Whether the weights are in PEFT (LoRA, etc.) format. Defaults to False.
16251630
If True, loads LoRA delta weights. If False, loads the full model weights.
16261631
adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
1632+
converter: Used to perform key-value conversion on the newly loaded state_dict.
16271633
"""
16281634
self._peft_format = peft_format
16291635
self._adapter_name = adapter_name
16301636
mg_models = unwrap_model(mg_models)
16311637
self._disable_tqdm = False
16321638
with torch.no_grad(), SafetensorLazyLoader(hf_model_dir, peft_format=peft_format) as loader:
16331639
state_dict = loader.get_state_dict()
1640+
if converter:
1641+
new_state_dict = {}
1642+
for k, v in state_dict.items():
1643+
kv = converter(k, v)
1644+
if kv is None:
1645+
continue
1646+
k, v = kv
1647+
new_state_dict[k] = v
1648+
state_dict = new_state_dict
16341649
hf_prefix = 'base_model.model.' if peft_format else ''
16351650
for mg_model in mg_models:
16361651
list(self._convert([mg_model], state_dict, hf_prefix, True, 'Loading: '))
16371652

1638-
def export_weights(self,
1639-
mg_models,
1640-
target_device=None,
1641-
only_master_rank: bool = False,
1642-
peft_format: bool = False,
1643-
tqdm_desc: str = 'Exporting: ',
1644-
disable_tqdm: bool = True):
1653+
def export_weights(
1654+
self,
1655+
mg_models,
1656+
target_device=None,
1657+
only_master_rank: bool = False,
1658+
peft_format: bool = False,
1659+
adapter_name: str = 'default',
1660+
converter: Optional[Callable] = None,
1661+
tqdm_desc: str = 'Exporting: ',
1662+
disable_tqdm: bool = True,
1663+
):
16451664
"""Export Megatron model weights to safetensors (HuggingFace) format as a generator.
16461665
16471666
This method yields weight tensors one by one for streaming save operations or RL weight synchronization,
@@ -1654,6 +1673,8 @@ def export_weights(self,
16541673
peft_format: Whether to export in PEFT (LoRA, etc.) format. Defaults to False.
16551674
- If True, exports only LoRA delta weights. If False, exports the complete model weights
16561675
(e.g., after merge-lora or full-parameter fine-tuning).
1676+
adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
1677+
converter: Used to perform key-value conversion on the newly exported state_dict.
16571678
tqdm_desc: Description text for the progress bar. Defaults to 'Exporting: '.
16581679
disable_tqdm: Whether to disable the tqdm progress bar. Defaults to True.
16591680
@@ -1663,8 +1684,8 @@ def export_weights(self,
16631684
self._target_device = target_device
16641685
self._only_master_rank = only_master_rank
16651686
self._peft_format = peft_format
1687+
self._adapter_name = adapter_name
16661688
self._disable_tqdm = disable_tqdm
1667-
self._adapter_name = 'default'
16681689
self._peft_target_modules = set()
16691690
self._peft_modules_to_save = set()
16701691
hf_prefix = 'base_model.model.' if peft_format else ''
@@ -1674,13 +1695,21 @@ def export_weights(self,
16741695
mg_models[i] = mg_model.model
16751696
self.config = mg_models[0].config
16761697
with torch.no_grad():
1677-
yield from self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc)
1698+
for k, v in self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc):
1699+
if converter:
1700+
kv = converter(k, v)
1701+
if kv is None:
1702+
continue
1703+
k, v = kv
1704+
yield k, v
16781705

16791706
def save_weights(
16801707
self,
16811708
mg_models,
16821709
output_dir: str,
16831710
peft_format: bool = False,
1711+
adapter_name: str = 'default',
1712+
converter: Optional[Callable] = None,
16841713
max_shard_size: str = '5GB',
16851714
) -> None:
16861715
"""Save Megatron model checkpoint in safetensors (HuggingFace) format.
@@ -1695,6 +1724,8 @@ def save_weights(
16951724
peft_format: Whether to save in PEFT (LoRA, etc.) format. Defaults to False.
16961725
If True, saves LoRA delta weights. If False, saves the complete model weights
16971726
(e.g., after merge-lora or full-parameter fine-tuning).
1727+
adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
1728+
converter: Used to perform key-value conversion on the newly exported state_dict.
16981729
max_shard_size: Maximum size of a single storage file, default is '5GB'.
16991730
"""
17001731
gc_collect()
@@ -1705,6 +1736,8 @@ def save_weights(
17051736
target_device='cpu',
17061737
only_master_rank=True,
17071738
peft_format=peft_format,
1739+
adapter_name=adapter_name,
1740+
converter=converter,
17081741
tqdm_desc='Saving: ',
17091742
disable_tqdm=False):
17101743
saver.add_tensor(k, v)

0 commit comments

Comments
 (0)