-
Notifications
You must be signed in to change notification settings - Fork 384
Expand file tree
/
Copy pathauto_bridge.py
More file actions
1797 lines (1527 loc) · 78.8 KB
/
Copy pathauto_bridge.py
File metadata and controls
1797 lines (1527 loc) · 78.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
import logging
from collections.abc import Callable
from contextlib import nullcontext
from functools import cached_property, partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, Iterable, List, Literal, Optional, Tuple, Type, TypeVar, Union
import torch
import torch.distributed as dist
import transformers
if TYPE_CHECKING:
from megatron.bridge.peft.base import PEFT
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import MLATransformerConfig, TransformerConfig
from modelopt.torch.quantization.utils import is_quantized
from safetensors.torch import save_file
from transformers.configuration_utils import PretrainedConfig
from typing_extensions import Unpack
from megatron.bridge.models.conversion import model_bridge
from megatron.bridge.models.conversion.model_bridge import (
HFWeightTuple,
MegatronModelBridge,
WeightConversionTask,
)
from megatron.bridge.models.conversion.utils import get_causal_lm_class_name_via_auto_map
from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.hf_pretrained.base import PreTrainedBase
from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM, _ConfigOnlyPretrainedShim
from megatron.bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry
from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource
from megatron.bridge.models.model_provider import GetModelKwargs, ModelParallelKwargs, ModelProviderMixin
# Module-level logger.
logger = logging.getLogger(__name__)
# Megatron module type a bridge instance converts to/from; AutoBridge is Generic over it.
MegatronModelT = TypeVar("MegatronModelT", bound=MegatronModule)
# Unbounded TypeVar; the name suggests use with dataclass types -- confirm at its use sites.
DataclassT = TypeVar("DataclassT")
# Supported HuggingFace architecture names for causal generation models.
# AutoBridge.supports() tests each declared architecture with str.endswith(<this tuple>),
# so generic entries act as suffixes (any "*ForCausalLM" matches), while the
# model-specific entries cover architectures that do not follow that convention.
# The tuple order is also the order shown in SUPPORTED_HF_ARCHITECTURES_DISPLAY.
SUPPORTED_HF_ARCHITECTURES: tuple[str, ...] = (
    "ForCausalLM",
    "ForConditionalGeneration",
    "NemotronH_Nano_VL_V2",
    "NemotronH_Nano_Omni_Reasoning_V3",
    "Qwen2_5OmniModel",
    "NemotronLabsDiffusionModel",
)
# Mapping from non-standard HF architecture names to their actual transformers class names.
# Some HF model configs report architecture names that don't follow the standard
# 'ForCausalLM'/'ForConditionalGeneration' convention and don't directly map to a
# transformers class. This dict resolves those aliases (alias -> real class name).
HF_ARCHITECTURE_ALIASES: dict[str, str] = {
    "Qwen2_5OmniModel": "Qwen2_5OmniForConditionalGeneration",
}
# HF config fields that declare the number of multi-token-prediction (MTP) layers.
# _config_disables_mtp() checks them in this order and stops at the first one set.
MTP_CONFIG_FIELDS: tuple[str, ...] = ("num_nextn_predict_layers", "mtp_num_hidden_layers", "mtp_num_layers")
# Sentinel distinguishing "field absent" from "field explicitly set to None".
_MISSING = object()


def _get_config_field(config: Any, field: str) -> Any:
    """Return ``field`` from a config dict or config object, or ``_MISSING`` if absent.

    For non-dict configs only the instance ``__dict__`` is consulted, so class-level
    defaults and properties are not considered; objects without a ``__dict__``
    always yield ``_MISSING``.
    """
    if isinstance(config, dict):
        return config.get(field, _MISSING)
    return getattr(config, "__dict__", {}).get(field, _MISSING)


def _config_disables_mtp(config: Any) -> bool:
    """Return True when a config object or dict explicitly disables MTP layers.

    The first field of ``MTP_CONFIG_FIELDS`` that is present and not None decides
    the result (``int(value) == 0``). When none of them is set, a nested
    ``text_config`` entry (dict or object) is inspected recursively.

    Args:
        config: HF config object, parsed ``config.json`` dict, or None.

    Returns:
        True only if an MTP layer count of zero is found; False for None input or
        when no MTP field is set at any level.
    """
    if config is None:
        return False
    for field in MTP_CONFIG_FIELDS:
        value = _get_config_field(config, field)
        if value is _MISSING or value is None:
            continue
        return int(value) == 0
    text_config = _get_config_field(config, "text_config")
    return text_config is not _MISSING and _config_disables_mtp(text_config)


def _saved_config_disables_mtp(path: str | Path) -> bool:
    """Check whether the final ``config.json`` written for an HF export disables MTP.

    Args:
        path: Export directory expected to contain ``config.json``.

    Returns:
        False when ``config.json`` does not exist; otherwise the result of
        ``_config_disables_mtp`` applied to the parsed JSON.
    """
    import json

    config_path = Path(path) / "config.json"
    if not config_path.exists():
        return False
    # JSON is UTF-8 by specification; do not depend on the platform locale encoding.
    with open(config_path, encoding="utf-8") as f:
        return _config_disables_mtp(json.load(f))
# Preformatted display string for error/help messages: every supported architecture
# name wrapped in single quotes and joined with " or ", in tuple order
# (e.g. "'ForCausalLM' or 'ForConditionalGeneration' or ...").
SUPPORTED_HF_ARCHITECTURES_DISPLAY = " or ".join(f"'{s}'" for s in SUPPORTED_HF_ARCHITECTURES)
def _drop_readonly_config_properties(
    config_dict: dict[str, object], config_type: Type[PretrainedConfig]
) -> dict[str, object]:
    """Remove keys naming read-only properties of ``config_type``.

    Used before constructing a HuggingFace config from a synthesized dict, so the
    constructor is not handed values for setter-less properties (assumed: HF config
    constructors ``setattr`` extra kwargs, and assigning to a property without a
    setter raises AttributeError).

    A name counts as read-only only if its *effective* definition -- the first class
    in ``config_type``'s MRO that defines it -- is a ``property`` without a setter.
    A read-only property in a base class that a subclass overrides (with a settable
    property or a plain attribute) is therefore kept.

    Args:
        config_dict: Candidate constructor kwargs. Never mutated.
        config_type: Config class whose MRO is inspected.

    Returns:
        ``config_dict`` itself when nothing needs removing, otherwise a new dict
        without the read-only keys.
    """
    readonly_properties: set[str] = set()
    seen_names: set[str] = set()
    for klass in config_type.mro():
        for name, attr in vars(klass).items():
            if name in seen_names:
                # Already defined by a more-derived class, whose definition wins.
                continue
            seen_names.add(name)
            if isinstance(attr, property) and attr.fset is None:
                readonly_properties.add(name)
    if not readonly_properties:
        return config_dict
    return {key: value for key, value in config_dict.items() if key not in readonly_properties}
class AutoBridge(Generic[MegatronModelT]):
"""
Automatically select and instantiate the appropriate bridge for a model.
This unified bridge class combines automatic model detection with full bridge
functionality for converting models between HuggingFace and Megatron formats.
It handles the conversion of causal language models (e.g., GPT, Llama, Phi)
between HuggingFace's transformers library format and Megatron-Core's distributed
training format. It manages weight mapping, tensor parallelism distribution, and
configuration translation.
The bridge supports both directions of conversion:
- HuggingFace → Megatron: For training or inference with Megatron
- Megatron → HuggingFace: For saving trained models in HF format
Args:
hf_pretrained: Either a PreTrainedCausalLM instance with loaded model,
or a PretrainedConfig for configuration-only operations
Example:
>>> # Load and convert a model to Megatron format
>>> bridge = AutoBridge.from_hf_pretrained("meta-llama/Meta-Llama-3-8B")
>>> provider = bridge.to_megatron_provider()
>>> megatron_model = provider.provide_distributed_model(wrap_with_ddp=False)
>>> # Export a Megatron model back to HuggingFace format
>>> bridge.save_hf_pretrained(megatron_model, "./exported_model")
>>> # Convert weights with custom settings
>>> for name, weight in bridge.export_hf_weights(
... megatron_model,
... cpu=True
... ):
... print(f"Exported {name}: {weight.shape}")
>>> # Check if a model is supported before loading
>>> if AutoBridge.can_handle("microsoft/phi-2"):
... bridge = AutoBridge.from_hf_pretrained("microsoft/phi-2")
Note:
The bridge automatically detects the model architecture and applies
the appropriate weight mappings. Custom architectures require implementing
a MegatronModelBridge subclass.
"""
def __init__(self, hf_pretrained: PreTrainedCausalLM | PretrainedConfig):
    """
    Wrap a HuggingFace model or configuration for Megatron conversion.

    Args:
        hf_pretrained: A ``PreTrainedCausalLM`` (the bridge can read HF weights) or
            a bare ``PretrainedConfig`` (configuration-only bridge).

    Raises:
        ValueError: If ``hf_pretrained`` is neither a ``PreTrainedCausalLM`` nor a
            ``PretrainedConfig``.
    """
    accepted = (PreTrainedCausalLM, PretrainedConfig)
    if not isinstance(hf_pretrained, accepted):
        raise ValueError("hf_pretrained must be a PreTrainedCausalLM or PretrainedConfig instance")
    self.hf_pretrained: PreTrainedCausalLM | PretrainedConfig = hf_pretrained
    # Dtype used when exporting weights; "fp8" makes export_hf_weights build FP8 conversion tasks.
    self.export_weight_dtype: Literal["bf16", "fp16", "fp8"] = "bf16"
    # Reference HF model id/path; from_auto_config fills this in.
    self.hf_model_id: Optional[str] = None
    # Inherit trust_remote_code from the wrapped object, honouring only a genuine bool.
    inherited_flag = getattr(hf_pretrained, "trust_remote_code", False)
    self.trust_remote_code = isinstance(inherited_flag, bool) and inherited_flag
@classmethod
def list_supported_models(cls) -> list[str]:
    """
    Return the sorted names of every architecture registered with the bridge dispatcher.

    A registration keyed by a string contributes that string. A registration keyed
    by a type contributes the type's ``__name__``. Any other key is skipped.

    Returns:
        Sorted list of registered HuggingFace architecture names; empty when the
        dispatcher exposes no ``_exact_types`` registry.
    """
    dispatcher = model_bridge.get_model_bridge
    if not hasattr(dispatcher, "_exact_types"):
        return []

    names: list[str] = []
    for registered_key in dispatcher._exact_types.keys():
        if isinstance(registered_key, str):
            names.append(registered_key)
        elif hasattr(registered_key, "__name__"):
            names.append(registered_key.__name__)
    names.sort()
    return names
@classmethod
def supports(cls, config: Any) -> bool:
    """
    Report whether ``config`` declares an architecture this bridge can handle.

    The config's ``architectures`` list is scanned. The model is supported when at
    least one entry ends with one of the names in ``SUPPORTED_HF_ARCHITECTURES``.

    Args:
        config: HuggingFace model config object.

    Returns:
        True if some declared architecture matches; False if none matches or the
        config declares no architectures at all.
    """
    declared = getattr(config, "architectures", None) or ()
    return any(name.endswith(SUPPORTED_HF_ARCHITECTURES) for name in declared)
@classmethod
def from_auto_config(cls, megatron_path: str, hf_model_id: str, trust_remote_code: bool = False) -> "AutoBridge":
    """
    Create a config-only AutoBridge by synthesizing an HF config from a Megatron checkpoint.

    This method creates a bridge instance from a Megatron checkpoint and a reference
    ``hf_model_id``, without loading any weights. This enables exporting of:
    - Custom small models of popular architectures
    - Models pruned from a larger teacher model

    The Megatron model config is read from ``run_config.yaml`` in ``megatron_path``.
    If it is not there, the ``iter_<N>`` subdirectory with the largest integer ``N``
    is used. Subdirectories whose suffix is not a plain integer are ignored.

    Args:
        megatron_path: Directory path where the Megatron checkpoint is stored.
        hf_model_id: HuggingFace model ID or path to model directory used as the
            reference config. Examples: "meta-llama/Meta-Llama-3-8B", "./my_model".
        trust_remote_code: Whether to trust remote code when loading config.
            Defaults to False for security. Set to True only for models that
            require custom modeling code from the repository.

    Returns:
        AutoBridge: Config-only bridge whose ``hf_pretrained`` is the synthesized
        config, with ``hf_model_id`` and ``trust_remote_code`` recorded on it.

    Raises:
        FileNotFoundError: If ``megatron_path`` does not exist, or no
            run_config.yaml can be located in it.
    """
    from transformers import AutoConfig

    from megatron.bridge.models.conversion.utils import conform_config_to_reference
    from megatron.bridge.training.model_load_save import load_model_config

    checkpoint_path = Path(megatron_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Megatron checkpoint not found: {checkpoint_path}")

    # Look for configuration files to determine the model type.
    run_config = checkpoint_path / "run_config.yaml"
    # Only search subdirectories when the checkpoint path is a directory; a plain
    # file falls through to the FileNotFoundError below instead of NotADirectoryError.
    if not run_config.exists() and checkpoint_path.is_dir():
        iter_prefix = "iter_"
        # Only numbered iteration dirs are candidates; int() on any other suffix
        # (e.g. a temporary "iter_*" folder) would raise ValueError.
        iter_dirs = [
            d
            for d in checkpoint_path.iterdir()
            if d.is_dir() and d.name.startswith(iter_prefix) and d.name[len(iter_prefix):].isdecimal()
        ]
        if iter_dirs:
            latest_iter = max(iter_dirs, key=lambda d: int(d.name[len(iter_prefix):]))
            run_config = latest_iter / "run_config.yaml"
    if not run_config.exists():
        raise FileNotFoundError(
            f"Could not find run_config.yaml in {checkpoint_path}. Ensure this is a valid Megatron checkpoint."
        )

    # 1. Load config from both sides
    megatron_cfg, _ = load_model_config(str(run_config.parent))
    if trust_remote_code:
        logger.warning(
            "Loading a model with trust_remote_code=True allows arbitrary code execution "
            "from the model repository. Only use this with models you trust."
        )
    hf_cfg = AutoConfig.from_pretrained(hf_model_id, trust_remote_code=trust_remote_code)

    # 2. Translate Megatron config -> HF, conforming to reference config
    reference_bridge = cls.from_hf_config(hf_cfg)
    megatron_hf_cfg_dict = reference_bridge._model_bridge.megatron_to_hf_config(megatron_cfg)
    megatron_hf_cfg_dict = conform_config_to_reference(megatron_hf_cfg_dict, hf_cfg.to_dict())
    megatron_hf_cfg_dict = _drop_readonly_config_properties(megatron_hf_cfg_dict, type(hf_cfg))

    # 3. Build final bridge from the synthesized config
    synthesized_config = type(hf_cfg)(**megatron_hf_cfg_dict)
    bridge = cls.from_hf_config(synthesized_config)
    bridge.hf_model_id = hf_model_id
    bridge.trust_remote_code = trust_remote_code
    return bridge
@classmethod
def from_hf_config(cls, config: PretrainedConfig) -> "AutoBridge":
    """
    Build a weightless AutoBridge from a HuggingFace configuration.

    Only the configuration is used; no checkpoint weights are loaded. Useful for
    randomly initialized Megatron models, architecture exploration, and tests.

    Args:
        config: HuggingFace ``PretrainedConfig`` describing the model architecture.

    Returns:
        AutoBridge: Config-only bridge wrapping ``config``.

    Raises:
        ValueError: If the configuration is not for a supported CausalLM model.

    Example:
        >>> from transformers import AutoConfig
        >>> config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")
        >>> bridge = AutoBridge.from_hf_config(config)
        >>> provider = bridge.to_megatron_provider(load_weights=False)
        >>> model = provider.provide_distributed_model(wrap_with_ddp=False)
        >>> transformer_config = bridge.transformer_config
        >>> print(f"Num layers: {transformer_config.num_layers}")

    See Also:
        from_hf_pretrained: Create bridge with loaded weights
        transformer_config: Access the Megatron TransformerConfig
    """
    cls._validate_config(config)
    bridge = cls(config)
    return bridge
@classmethod
def from_hf_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoBridge":
    """
    Load an AutoBridge from a pretrained model, automatically detecting the model type.

    This method loads a model from HuggingFace Hub or a local directory and
    creates a bridge instance ready for conversion operations. The model
    architecture is validated first to ensure compatibility.

    Args:
        path: HuggingFace model ID or path to model directory.
            Examples: "meta-llama/Meta-Llama-3-8B", "./my_model"
        **kwargs: Additional arguments passed to ``PreTrainedCausalLM.from_pretrained``.
            Common options include:
            - torch_dtype: Model precision (torch.float16, torch.bfloat16)
            - device_map: Device placement strategy ("auto", "cuda:0", etc.)
            - trust_remote_code: Allow custom model code execution
            - attn_implementation: Attention implementation ("flash_attention_2", etc.)
            - rope_scaling: Override dict. Missing keys are filled from the base
              config's ``rope_scaling`` into a *new* dict; the caller's dict is
              never modified, so it can be reused across calls.

    Returns:
        AutoBridge: Bridge instance with loaded model.

    Raises:
        ValueError: If the model architecture is not supported, or loading fails.

    Example:
        >>> bridge = AutoBridge.from_hf_pretrained("gpt2")
        >>> bridge = AutoBridge.from_hf_pretrained(
        ...     "meta-llama/Meta-Llama-3-8B",
        ...     torch_dtype=torch.float16,
        ...     device_map="auto"
        ... )
        >>> bridge = AutoBridge.from_hf_pretrained("/path/to/model")
    """
    # First load just the config to check architecture support.
    # Use thread-safe config loading to prevent race conditions.
    config_kwargs = dict(kwargs)
    trust_remote_code = bool(config_kwargs.pop("trust_remote_code", False))
    if trust_remote_code:
        logger.warning(
            "Loading a model with trust_remote_code=True allows arbitrary code execution "
            "from the model repository. Only use this with models you trust."
        )
    config = safe_load_config_with_retry(path, trust_remote_code=trust_remote_code, **config_kwargs)
    cls._validate_config(config, str(path))

    # Transformers 5.0+ changed `rope_scaling` to a property whose setter
    # does `self.rope_parameters = value`, replacing the entire dict and
    # dropping any fields (e.g. `rope_theta`) that were set during initial
    # construction. When a `rope_scaling` override is passed as a kwarg,
    # `PretrainedConfig.from_dict` applies it via `setattr` *after* the
    # initial construction, so those fields are silently lost and Megatron
    # falls back to defaults (e.g. `rotary_base=10000`). Pre-populate the
    # override dict with all base-config rope fields so the setter
    # preserves them.
    # The merge goes into a copy: mutating the caller's dict in place would leak
    # this model's rope fields into later calls that reuse the same dict.
    # NOTE(review): config_kwargs still carries the rope_scaling override, so
    # `config` may already reflect it; confirm base_rope holds the checkpoint's
    # original rope fields on the targeted transformers versions.
    rope_override = kwargs.get("rope_scaling")
    if isinstance(rope_override, dict):
        base_rope = getattr(config, "rope_scaling", None)
        if isinstance(base_rope, dict):
            merged_rope = dict(rope_override)
            for key, value in base_rope.items():
                merged_rope.setdefault(key, value)
            kwargs["rope_scaling"] = merged_rope

    try:
        return cls(PreTrainedCausalLM.from_pretrained(path, **kwargs))
    except Exception as e:
        raise ValueError(f"Failed to load model with AutoBridge: {e}") from e
@classmethod
def can_handle(cls, path: Union[str, Path], trust_remote_code: bool = False) -> bool:
    """
    Probe whether the model at ``path`` is supported, without loading weights.

    Only the configuration is fetched. Any error while loading or inspecting it is
    reported as "not supported" instead of being raised, which makes this safe for
    validation and UI feedback.

    Args:
        path: Path to model directory or HuggingFace model ID.
            Examples: "meta-llama/Meta-Llama-3-8B", "/models/my_model"
        trust_remote_code: Whether to trust remote code when loading config.
            Set to True for models that use custom modeling code.

    Returns:
        bool: True if the bridge supports the model, False otherwise.

    Example:
        >>> if AutoBridge.can_handle("meta-llama/Meta-Llama-3-8B"):
        ...     print("Model is supported!")
        ... else:
        ...     print("Model requires a custom bridge implementation")
    """
    try:
        return cls.supports(safe_load_config_with_retry(path, trust_remote_code=trust_remote_code))
    except Exception:
        # Best-effort probe: failure to load or inspect the config means "unsupported".
        return False
def load_hf_weights(
    self,
    model: list[MegatronModelT],
    hf_path: str | Path | None = None,
    allowed_mismatched_params: list[str] | None = None,
) -> list[MegatronModelT]:
    """
    Load HuggingFace weights into a Megatron model.

    This method handles the conversion and distribution of weights from
    HuggingFace format to Megatron's distributed format, including proper
    tensor parallel and pipeline parallel distribution.

    Args:
        model: List of Megatron model instances (one per virtual pipeline stage).
        hf_path: Optional path to load weights from. If None, uses the weights
            of the bridge's ``hf_pretrained`` instance. When given, the checkpoint
            is loaded honouring this bridge's ``trust_remote_code`` setting.
        allowed_mismatched_params: Optional list of parameter names or patterns
            to allow mismatch (skip instead of raise error).

    Returns:
        The input ``model`` list, now holding the loaded weights.

    Raises:
        ValueError: If hf_path is None and the bridge was created without weights.

    Example:
        >>> bridge = AutoBridge.from_hf_pretrained("gpt2")
        >>> megatron_model = create_megatron_model()  # Your model creation
        >>> bridge.load_hf_weights(megatron_model)
        >>> bridge.load_hf_weights(megatron_model, "./finetuned_model")
        >>> bridge.load_hf_weights(
        ...     megatron_model,
        ...     allowed_mismatched_params=["*.bias", "decoder.layers.0.*"]
        ... )
    """
    if hf_path is None:
        if not isinstance(self.hf_pretrained, PreTrainedCausalLM):
            raise ValueError("hf_path is required when hf_pretrained is not a PreTrainedCausalLM instance")
        pre_trained = self.hf_pretrained
    else:
        # Use the bridge-level flag: bridges built by from_auto_config record it on
        # the bridge, while their hf_pretrained is a synthesized PretrainedConfig that
        # may not carry a trust_remote_code attribute at all.
        pre_trained = PreTrainedCausalLM.from_pretrained(hf_path, trust_remote_code=self.trust_remote_code)
    bridge = self._model_bridge
    bridge.load_weights_hf_to_megatron(pre_trained, model, allowed_mismatched_params=allowed_mismatched_params)
    # Get unquantized_state_dict from the bridge instance that was used for optimizer reload
    self.unquantized_state_dict = getattr(bridge, "unquantized_state_dict", None)
    return model
def export_hf_weights(
    self,
    model: list[MegatronModelT],
    cpu: bool = False,
    show_progress: bool = True,
    conversion_tasks: Optional[List[WeightConversionTask]] = None,
    merge_adapter_weights: bool = True,
) -> Iterable["HFWeightTuple"]:
    """
    Stream Megatron model weights as HuggingFace-named tensors.

    Distributed tensors are gathered so that every rank receives full tensors.
    By default, LoRA adapters are merged into their base weights so the exported
    tensors reflect the fine-tuned model.

    When no ``conversion_tasks`` are supplied and ``export_weight_dtype == "fp8"``,
    the FP8 export configuration is validated and FP8 (blockwise) conversion tasks
    are built automatically.

    Args:
        model: Megatron model instance or list of instances.
        cpu: Whether to move tensors to CPU before yielding.
        show_progress: Display progress bar during export.
        conversion_tasks: Pre-built conversion tasks. Advanced use only: build them
            with ``get_conversion_tasks`` first and adjust carefully. If omitted,
            tasks are derived from the model.
        merge_adapter_weights: Gather and merge LoRA adapter weights into the base
            tensors (default True). False exports only the base tensors.

    Yields:
        HFWeightTuple: Named tuples of (param_name, weight_tensor).

    Example:
        >>> for name, weight in bridge.export_hf_weights(model):
        ...     print(f"{name}: {weight.shape}")
        >>> weights = list(bridge.export_hf_weights(model, cpu=True))
    """
    tasks = conversion_tasks
    if tasks is None and self.export_weight_dtype == "fp8":
        # FP8 export works on an explicit list of model chunks with its own task set.
        models = model if isinstance(model, list) else [model]
        self._validate_fp8_export_config(models)
        tasks = self._model_bridge.build_export_fp8_tasks(self.hf_pretrained, models)
        model = models
    return self._model_bridge.stream_weights_megatron_to_hf(
        model,
        self.hf_pretrained,
        cpu=cpu,
        show_progress=show_progress,
        conversion_tasks=tasks,
        merge_adapter_weights=merge_adapter_weights,
    )
def export_hf_weights_quant(
    self,
    model: list[MegatronModelT],
    quantization_checker: Callable[[str], bool],
    quant_fn: Callable[..., Tuple[torch.Tensor, torch.Tensor]],
    quant_block_size: Optional[Tuple[int, int]] = None,
    cpu: bool = False,
    show_progress: bool = True,
    conversion_tasks: Optional[List[WeightConversionTask]] = None,
    merge_adapter_weights: bool = False,
) -> Iterable["HFWeightTuple"]:
    """
    Export Megatron model weights to HuggingFace format with quantization.

    Delegates to ``model_bridge.stream_weights_megatron_to_hf_quant``, dispatching
    on this bridge's causal-LM architecture paired with the model instance.

    Args:
        model: Megatron model instance or list of instances.
        quantization_checker: Predicate over a string (presumably the parameter
            name) -- presumably True for weights that should be quantized; confirm
            in the stream helper.
        quant_fn: Quantization callable returning a pair of tensors (presumably the
            quantized weight and its scale -- confirm).
        quant_block_size: Optional ``(int, int)`` block size, forwarded unchanged.
        cpu: Whether to move tensors to CPU before yielding.
        show_progress: Display progress bar during export.
        conversion_tasks: Optional pre-built conversion tasks.
        merge_adapter_weights: Whether to merge LoRA adapter weights into the base
            tensors. Defaults to False here, unlike ``export_hf_weights``.

    Returns:
        Iterable of ``HFWeightTuple`` produced by the quantized stream helper.
    """
    architecture = self._causal_lm_architecture
    dispatch_key = (architecture, self._get_model_instance(model))
    return model_bridge.stream_weights_megatron_to_hf_quant(
        dispatch_key,
        model,
        self.hf_pretrained,
        quantization_checker,
        quant_fn,
        quant_block_size=quant_block_size,
        cpu=cpu,
        show_progress=show_progress,
        conversion_tasks=conversion_tasks,
        merge_adapter_weights=merge_adapter_weights,
    )
def export_adapter_weights(
    self,
    model: list[MegatronModelT],
    cpu: bool = True,
    show_progress: bool = True,
) -> Iterable["HFWeightTuple"]:
    """
    Stream only the adapter (e.g. LoRA) weights of a Megatron model, unmerged.

    Base weights are not touched or merged, which makes this suitable for saving or
    inspecting adapters separately from the pretrained model.

    Args:
        model: Megatron model instance or list of instances.
        cpu: Whether to move tensors to CPU before yielding (default True).
        show_progress: Display progress bar during export.

    Yields:
        HFWeightTuple: Named tuples of (param_name, weight_tensor) for adapter parameters.
    """
    return self._model_bridge.stream_adapter_weights_megatron_to_hf(model, cpu=cpu, show_progress=show_progress)
def save_hf_adapter(
    self,
    model: list[MegatronModelT],
    path: str | Path,
    peft_config: "PEFT",
    base_model_name_or_path: Optional[str] = None,
    show_progress: bool = True,
) -> None:
    """Write trained LoRA/DoRA adapters as a HuggingFace PEFT adapter directory.

    Produces ``adapter_config.json`` and ``adapter_model.safetensors`` under ``path``.
    The result loads with ``peft.PeftModel.from_pretrained(base_model, path)``.

    Args:
        model: Megatron model instance or list of instances carrying adapters.
        path: Output directory (created on rank 0 if missing).
        peft_config: The LoRA / DoRA config used during training. It is forwarded to
            ``build_adapter_config_dict``, and its ``dim`` (32 when absent) is the
            fallback rank for rank-pattern inference.
        base_model_name_or_path: Base model identifier recorded in the adapter config.
            If None, ``hf_pretrained.model_name_or_path`` is used, falling back to
            ``hf_pretrained.name_or_path`` and finally "".
        show_progress: Display progress bar during export.

    Raises:
        RuntimeError: If the model yields no adapter weights.

    Example:
        >>> bridge.save_hf_adapter(
        ...     megatron_model,
        ...     "./my-lora-adapter",
        ...     peft_config=lora,
        ...     base_model_name_or_path="Qwen/Qwen3-4B",
        ... )
        >>> from peft import PeftModel
        >>> from transformers import AutoModelForCausalLM
        >>> base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B")
        >>> model = PeftModel.from_pretrained(base, "./my-lora-adapter")

    Note:
        Collective -- every rank must call it. All ranks drive the export generator
        (needed to gather TP/PP/EP-distributed tensors); only rank 0 writes files.
        When torch.distributed is initialized, barriers run before and after.
    """
    import json

    from safetensors.torch import save_file

    from megatron.bridge.models.conversion.peft_bridge import (
        build_adapter_config_dict,
        convert_adapter_weights_to_peft_state,
        infer_rank_pattern_from_adapter_weights,
        infer_target_modules_from_adapter_weights,
    )

    def _barrier_if_distributed() -> None:
        if dist.is_initialized():
            dist.barrier()

    _barrier_if_distributed()

    # Cloned float32 CPU copies of every exported adapter tensor.
    exported = self.export_adapter_weights(model, cpu=True, show_progress=show_progress)
    raw_adapter_weights = [HFWeightTuple(item.param_name, item.weight.clone().float()) for item in exported]
    if not raw_adapter_weights:
        raise RuntimeError(
            "No adapter weights were found on the model. "
            "Ensure the model has PEFT adapters applied before calling save_hf_adapter()."
        )

    adapter_state, module_adapter_keys, target_parameters = convert_adapter_weights_to_peft_state(
        raw_adapter_weights,
    )
    fallback_rank = getattr(peft_config, "dim", 32)
    rank_pattern = infer_rank_pattern_from_adapter_weights(raw_adapter_weights, default_rank=fallback_rank)

    writes_files = not dist.is_initialized() or dist.get_rank() == 0
    if writes_files:
        out_dir = Path(path)
        out_dir.mkdir(parents=True, exist_ok=True)

        base_name = base_model_name_or_path
        if base_name is None:
            base_name = str(
                getattr(self.hf_pretrained, "model_name_or_path", "")
                or getattr(self.hf_pretrained, "name_or_path", "")
            )

        adapter_config = build_adapter_config_dict(
            peft_config,
            target_modules=infer_target_modules_from_adapter_weights(module_adapter_keys),
            target_parameters=target_parameters,
            base_model_name_or_path=base_name,
            rank_pattern=rank_pattern,
        )
        with open(out_dir / "adapter_config.json", "w") as config_file:
            json.dump(adapter_config, config_file, indent=2)
        save_file(adapter_state, str(out_dir / "adapter_model.safetensors"))

    _barrier_if_distributed()
def save_hf_pretrained(
    self,
    model: list[MegatronModelT],
    path: str | Path,
    show_progress: bool = True,
    source_path: Optional[Union[str, Path]] = None,
    strict: bool = True,
    merge_adapter_weights: bool = True,
    distributed_save: bool = False,
    save_every_n_ranks: int = 1,
) -> None:
    """
    Save a Megatron model in HuggingFace format.

    This method exports the configuration, the other HF artifacts, and the weights
    to a directory that can be loaded with HuggingFace's ``from_pretrained`` methods.
    LoRA adapters are merged into the base weights by default, so the saved model
    contains the full fine-tuned weights.

    Artifacts depend on the bridge kind:
    - Model-backed bridges delegate to ``hf_pretrained.save_artifacts`` (forwarding
      ``source_path`` and any bridge-level ``ADDITIONAL_FILE_PATTERNS``).
    - Config-only bridges (e.g. built by ``from_auto_config``) write ``config.json``.
      When ``trust_remote_code`` is set, they also preserve custom modeling files:
      the top-level ``*.py`` files of a local directory (``source_path`` first, then
      ``hf_model_id`` if it is a local directory) are copied. Otherwise every ``.py``
      file of the ``hf_model_id`` Hub repo is downloaded. Without
      ``trust_remote_code``, ``auto_map`` is removed from the written config.

    Args:
        model: Megatron model instance or list of instances.
        path: Directory path to save the model.
        show_progress: Display progress bar during weight export.
        source_path: Directory containing custom modeling files to be preserved,
            useful when converting from Megatron checkpoints where the original
            HuggingFace model must be referenced. If not specified, it is determined
            from the HuggingFace configuration / ``hf_model_id``.
        strict: Whether to perform strict validation during weight export.
        merge_adapter_weights: Whether to gather/merge LoRA adapter weights into base
            tensors during export.
        distributed_save: Whether each rank saves part of the weights independently.
            When False (default), only rank 0 saves after gathering from all ranks.
        save_every_n_ranks: Interval for saving weights across ranks in distributed
            mode (e.g. 2 -> ranks 0, 2, 4, ... save). Only effective when
            distributed_save=True. Default is 1 (all ranks save).

    Example:
        >>> bridge.save_hf_pretrained(megatron_model, "./my_finetuned_model")
        >>> from transformers import AutoModelForCausalLM
        >>> hf_model = AutoModelForCausalLM.from_pretrained("./my_finetuned_model")

    Note:
        This method is collective -- all ranks must call it. Only rank 0 writes
        the non-weight artifacts; weight saving is coordinated across all ranks.
    """
    if not isinstance(self.hf_pretrained, (PreTrainedCausalLM, PretrainedConfig)):
        raise ValueError("save_hf_pretrained requires a pretrained HuggingFace model or config.")
    is_config_only = isinstance(self.hf_pretrained, PretrainedConfig)

    def _copy_local_modeling_files(src_dir: Path) -> None:
        """Copy the top-level ``*.py`` files of a local model directory into ``path`` (best effort)."""
        import shutil

        dst_dir = Path(path)
        try:
            if src_dir.resolve() == dst_dir.resolve():
                return  # Exporting into the reference directory itself; nothing to copy.
            for py_file in sorted(src_dir.glob("*.py")):
                if py_file.is_file():
                    shutil.copy2(py_file, dst_dir / py_file.name)
        except OSError as exc:
            logger.warning("Could not copy modeling files from %s: %s.", src_dir, exc)

    def _download_hub_modeling_files(hub_repo: str) -> None:
        """Download every ``.py`` file of a Hub repo into ``path`` (best effort)."""
        try:
            from huggingface_hub import hf_hub_download, list_repo_files

            repo_files = list_repo_files(hub_repo)
            for py_file in [f for f in repo_files if f.endswith(".py")]:
                hf_hub_download(
                    repo_id=hub_repo,
                    filename=py_file,
                    local_dir=path,
                )
        except Exception as exc:
            logger.warning(
                "Could not download modeling files from %s: %s. "
                "This is expected for models that use standard transformers "
                "modeling classes and do not define custom .py files.",
                hub_repo,
                exc,
            )
        finally:
            PreTrainedBase._cleanup_hf_local_dir_cache(Path(path))

    def _save_config_only_artifacts() -> None:
        """Write config.json and, for remote-code exports, the custom modeling files."""
        import json

        Path(path).mkdir(parents=True, exist_ok=True)
        config_dict = self.hf_pretrained.to_dict()
        if not self.trust_remote_code:
            # Modeling .py files are only preserved with trust_remote_code, so an
            # auto_map entry would otherwise reference code that is not exported.
            config_dict.pop("auto_map", None)
        with open(Path(path) / "config.json", "w") as _f:
            json.dump(config_dict, _f, indent=2, sort_keys=True, allow_nan=True)

        if not self.trust_remote_code:
            return
        # Prefer a local directory (explicit source_path first, then a local
        # hf_model_id): list_repo_files() only understands Hub repo ids.
        local_dir = next(
            (Path(c) for c in (source_path, self.hf_model_id) if c and Path(c).is_dir()),
            None,
        )
        if local_dir is not None:
            _copy_local_modeling_files(local_dir)
        elif self.hf_model_id:
            _download_hub_modeling_files(self.hf_model_id)

    def _save_artifacts() -> None:
        if is_config_only:
            _save_config_only_artifacts()
        else:
            # Get bridge-level ADDITIONAL_FILE_PATTERNS if configured
            additional_files = getattr(self._model_bridge, "ADDITIONAL_FILE_PATTERNS", None) or None
            self.hf_pretrained.save_artifacts(
                path, original_source_path=source_path, additional_files=additional_files
            )

    # Non-weight artifacts are written by rank 0 only (or by the single process).
    if not dist.is_initialized() or dist.get_rank() == 0:
        _save_artifacts()

    self.save_hf_weights(
        model,
        path,
        show_progress,
        strict,
        merge_adapter_weights=merge_adapter_weights,
        distributed_save=distributed_save,
        save_every_n_ranks=save_every_n_ranks,
    )
def save_hf_weights(
    self,
    model: list[MegatronModelT],
    path: str | Path,
    show_progress: bool = True,
    strict: bool = True,
    merge_adapter_weights: bool = True,
    distributed_save: bool = False,
    save_every_n_ranks: int = 1,
) -> None:
    """
    Save Megatron model weights in HuggingFace safetensors format.

    This method exports only the model weights (not configuration or tokenizer).
    Weights are streamed from the bridge's ``stream_weights_megatron_to_hf``
    with ``cpu=True``.

    If ``merge_adapter_weights`` is True, LoRA adapter weights are merged into
    the base tensors during export.

    Two save paths exist:

    * When the wrapped HF object exposes a ``SafeTensorsStateSource`` (i.e. it
      was loaded from safetensors), saving is delegated to
      ``SafeTensorsStateSource.save_generator``, which receives ``strict``,
      ``distributed_save`` and ``save_every_n_ranks``.
    * Otherwise (config-only bridge), every rank drains the weight generator,
      but only rank 0 keeps the tensors. Rank 0 plans shards with
      ``huggingface_hub.split_torch_state_dict_into_shards`` and writes them,
      plus ``model.safetensors.index.json`` when more than one shard is
      produced. ``strict``, ``distributed_save`` and ``save_every_n_ranks`` have
      no effect on this path.

    For quantized models, tensors whose name contains ``"_quantizer."`` are
    kept out of the safetensors output and saved by rank 0 to
    ``<path>/modelopt_weights.pt`` instead.

    Args:
        model: Megatron model instance or list of instances
        path: Directory path where weight files will be saved
        show_progress: Display progress bar during export
        strict: Whether to perform strict validation during export
            (forwarded to ``SafeTensorsStateSource.save_generator`` only).
        merge_adapter_weights: Whether to gather/merge LoRA adapter weights into base tensors during export.
        distributed_save: Whether to enable distributed saving mode where each rank saves
            part of weights independently (SafeTensorsStateSource path only).
        save_every_n_ranks: Interval for saving weights across ranks in distributed mode.
            For example, if set to 2, only ranks 0, 2, 4, ... will save weights.
            (SafeTensorsStateSource path only.)

    Example:
        >>> # Save just the weights
        >>> bridge.save_hf_weights(megatron_model, "./model_weights")

        >>> # Save without progress bar (useful in scripts)
        >>> bridge.save_hf_weights(megatron_model, "./weights", show_progress=False)

    Note:
        - This method is collective and must be called by all ranks: it
          starts and ends with a barrier when torch.distributed is initialized.
        - Uses safetensors format for efficient loading and security
        - The saved weights can be loaded with HuggingFace's from_pretrained
    """
    is_distributed = dist.is_initialized()
    rank = dist.get_rank() if is_distributed else 0
    if is_distributed:
        dist.barrier()

    bridge = self._model_bridge
    generator = bridge.stream_weights_megatron_to_hf(
        model,
        self.hf_pretrained,
        cpu=True,
        show_progress=show_progress,
        merge_adapter_weights=merge_adapter_weights,
    )

    model_instance = self._get_model_instance(model)
    quant_tensors = None
    if is_quantized(model_instance):
        quant_tensors = {}

        def _filter_quant(gen):
            # Divert quantizer tensors into the sidecar dict so they are not
            # written into the HF safetensors shards.
            for name, tensor in gen:
                if "_quantizer." in name:
                    quant_tensors[name] = tensor
                    continue
                yield name, tensor

        generator = _filter_quant(generator)

    # A single lookup replaces the hasattr/hasattr/isinstance chain; any missing
    # attribute resolves to None and falls through to the config-only path.
    state_source = getattr(getattr(self.hf_pretrained, "state", None), "source", None)

    if isinstance(state_source, SafeTensorsStateSource):
        model_config = getattr(model_instance, "config", None)
        hf_config = getattr(self.hf_pretrained, "config", self.hf_pretrained)
        mtp_disabled = _saved_config_disables_mtp(path) or any(
            _config_disables_mtp(config) for config in (hf_config, model_config)
        )
        ignored_source_key_prefixes = ("mtp.",) if mtp_disabled and state_source.has_glob("mtp.*") else None
        state_source.save_generator(
            generator,
            path,
            strict=strict,
            distributed_save=distributed_save,
            save_every_n_ranks=save_every_n_ranks,
            ignored_source_key_prefixes=ignored_source_key_prefixes,
        )
    else:
        # Config-only path: shard and write safetensors directly on rank 0.
        import json

        from huggingface_hub import split_torch_state_dict_into_shards

        if distributed_save and rank == 0:
            # These options are only honoured by SafeTensorsStateSource; say so
            # instead of silently falling back to a rank-0 gather.
            logger.warning(
                "distributed_save/save_every_n_ranks are not supported for config-only "
                "bridges; rank 0 will gather and write all weights to %s.",
                path,
            )

        if rank != 0:
            # The export is collective, so every rank must still drive the generator
            # to completion even though only rank 0 keeps the tensors.
            for _ in generator:
                pass
        else:
            # NOTE: Collects the full state dict into CPU memory before sharding.
            # For very large models (>70B), this may require significant host RAM.
            state_dict = {name: tensor.contiguous().cpu() for name, tensor in generator}
            plan = split_torch_state_dict_into_shards(state_dict)
            safe_dir = Path(path)
            safe_dir.mkdir(parents=True, exist_ok=True)
            for filename, tensor_names in plan.filename_to_tensors.items():
                shard = {key: state_dict[key] for key in tensor_names}
                # Tag shards as PyTorch tensors, matching transformers' save_pretrained;
                # loaders inspect this "format" metadata entry.
                save_file(shard, safe_dir / filename, metadata={"format": "pt"})
            if plan.is_sharded:
                index = {"metadata": plan.metadata, "weight_map": plan.tensor_to_filename}
                with open(safe_dir / "model.safetensors.index.json", "w") as f:
                    json.dump(index, f, indent=2)

    # Save quantizer/amax sidecar after the main generator is consumed (rank 0 only).
    if quant_tensors and rank == 0:
        sidecar_path = Path(path) / "modelopt_weights.pt"
        sidecar_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(quant_tensors, sidecar_path)

    if is_distributed:
        dist.barrier()
def save_megatron_model(
self,
model: list[MegatronModule],
path: str | Path,
hf_tokenizer_path: Optional[str | Path] = None,
low_memory_save: bool = False,
hf_tokenizer_kwargs: Optional[dict] = None,
) -> None:
"""
Save a Megatron model in native Megatron checkpoint format without optimizer
state.
This method saves the model in Megatron's native checkpoint format, which
can be loaded directly by Megatron for training or inference. The checkpoint
includes the model configuration and weights, NO optimizer state or other
artifacts.
Args:
model: Megatron model instance or list of instances