Commit cbfc744

fix: Fix compare script and add tool to create hf toy (#1689)
Signed-off-by: yaoyu-33 <[email protected]>
1 parent fb85f9d

2 files changed: +181, -3 lines
examples/conversion/compare_hf_and_megatron/compare.py

Lines changed: 2 additions & 3 deletions
@@ -607,7 +607,6 @@ def _load_megatron_model(args):
     model_provider.expert_tensor_parallel_size = etp
     model_provider.pipeline_dtype = torch.bfloat16
     model_provider.finalize()
-    model_provider.initialize_model_parallel(seed=0)
     megatron_model = model_provider.provide_distributed_model(wrap_with_ddp=False)

     model_components = [m.eval() for m in megatron_model]
@@ -716,8 +715,8 @@ def compare_models_one_step(args) -> None:
     )

     del hf_model
-    # Load Megatron model
-    megatron_model = _load_megatron_model(args)
+    # Reload Megatron model to ensure a fresh instance before comparison
+    megatron_model, _ = _load_megatron_model(args)

     # Broadcast HF results to all ranks after Megatron initialization
     # (following the pattern from generate_from_hf.py)
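
Taken together, the two hunks change the loader's contract: `_load_megatron_model` no longer calls `initialize_model_parallel(seed=0)` itself and now returns a tuple, so `compare_models_one_step` can reload a fresh Megatron model after freeing the HF one. A self-contained toy sketch of that contract (every name below is a hypothetical stand-in; only the call shapes mirror the diff):

```python
# Toy illustration only: moving one-time setup out of the loader makes the
# loader safe to call repeatedly, matching `megatron_model, _ = _load_megatron_model(args)`.
_initialized = False

def initialize_once(seed: int) -> None:
    global _initialized
    if _initialized:
        raise RuntimeError("parallel state already initialized")
    _initialized = True

def load_model(args: str):
    # Stands in for provide_distributed_model(wrap_with_ddp=False);
    # the real second tuple element is not shown in this commit.
    return [f"module({args})"], None

initialize_once(seed=0)       # done once, outside the loader
model, _ = load_model("cfg")  # first pass (compared against HF)
del model
model, _ = load_model("cfg")  # reload: fresh instance, no re-init error
```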
examples/conversion/create_hf_toy_model.py

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
#!/usr/bin/env python3
"""
Utility script that materializes a significantly smaller HuggingFace checkpoint
from an existing model configuration. It is primarily intended to help bridge
functional / quantization tests (e.g., the Qwen3 MoE conversion suites) avoid
downloading extremely large public checkpoints.

Example:
```bash
uv run python examples/conversion/create_hf_toy_model.py \
    --hf-model-id Qwen/Qwen3-30B-A3B \
    --output-dir /tmp/qwen3_toy \
    --num-hidden-layers 2 \
    --num-experts 4
```

The script works by:
1. Loading the original configuration via `AutoConfig` so that all model-specific
   attributes (e.g., gating settings, rotary params) stay in sync with the
   upstream release.
2. Overriding a handful of size-related knobs (hidden layers, number of experts,
   etc.) so that the instantiated model is tiny but structurally compatible.
3. Saving the resulting random-weight checkpoint alongside a tokenizer so tests
   can treat it like any other HF model directory.
"""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Optional

import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Create a reduced HuggingFace Causal LM checkpoint for tests.")
    parser.add_argument(
        "--hf-model-id",
        default="Qwen/Qwen3-30B-A3B",
        help="Source HuggingFace model id to pull the base config from.",
    )
    parser.add_argument(
        "--tokenizer-id",
        default=None,
        help="Optional tokenizer model id. Defaults to --hf-model-id.",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="Directory where the toy checkpoint will be saved.",
    )
    parser.add_argument(
        "--num-hidden-layers",
        type=int,
        default=2,
        help="Number of transformer layers to keep in the toy model.",
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        default=4,
        help="Total MoE experts per layer for the toy model.",
    )
    parser.add_argument(
        "--num-experts-per-tok",
        type=int,
        default=None,
help="Experts routed per token. Defaults to --num-experts.",
75+
)
76+
parser.add_argument(
77+
"--moe-intermediate-size",
78+
type=int,
79+
default=None,
80+
help="Optional override for the MoE FFN size.",
81+
)
82+
parser.add_argument(
83+
"--seed",
84+
type=int,
85+
default=1234,
86+
help="Torch seed applied before checkpoint creation.",
87+
)
88+
parser.add_argument(
89+
"--disable-remote-code-trust",
90+
action="store_false",
91+
dest="trust_remote_code",
92+
help="Disable trust_remote_code when loading from HuggingFace.",
93+
)
94+
parser.set_defaults(trust_remote_code=True)
95+
return parser.parse_args()
96+
97+
98+
def _adjust_config(
99+
config,
100+
*,
101+
num_hidden_layers: int,
102+
num_experts: int,
103+
num_experts_per_tok: Optional[int],
104+
moe_intermediate_size: Optional[int],
105+
) -> None:
106+
"""Mutate the config in-place so it matches the requested toy topology."""
107+
108+
config.num_hidden_layers = num_hidden_layers
109+
110+
if hasattr(config, "max_window_layers"):
111+
config.max_window_layers = min(config.max_window_layers, num_hidden_layers)
112+
113+
if hasattr(config, "layer_types"):
114+
config.layer_types = config.layer_types[:num_hidden_layers]
115+
116+
mlp_only_layers = getattr(config, "mlp_only_layers", [])
117+
if isinstance(mlp_only_layers, (list, tuple)):
118+
config.mlp_only_layers = [layer for layer in mlp_only_layers if layer < num_hidden_layers]
119+
120+
config.num_experts = num_experts
121+
config.num_experts_per_tok = (
122+
num_experts_per_tok
123+
if num_experts_per_tok is not None
124+
else min(num_experts, getattr(config, "num_experts_per_tok", num_experts))
125+
)
126+
127+
if hasattr(config, "router_top_k"):
128+
config.router_top_k = min(config.num_experts, config.num_experts_per_tok)
129+
130+
if moe_intermediate_size is not None:
131+
config.moe_intermediate_size = moe_intermediate_size
132+
133+
134+
def _save_tokenizer(output_dir: Path, tokenizer_id: str, *, trust_remote_code: bool) -> None:
135+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=trust_remote_code)
136+
tokenizer.save_pretrained(output_dir)
137+
138+
139+
def main() -> None:
140+
"""Main entry point."""
141+
args = _parse_args()
142+
143+
output_dir = Path(args.output_dir).expanduser().resolve()
144+
output_dir.mkdir(parents=True, exist_ok=True)
145+
146+
tokenizer_id = args.tokenizer_id or args.hf_model_id
147+
trust_remote_code = bool(args.trust_remote_code)
148+
149+
torch.manual_seed(args.seed)
150+
151+
config = AutoConfig.from_pretrained(
152+
args.hf_model_id,
153+
trust_remote_code=trust_remote_code,
154+
)
155+
config.torch_dtype = torch.bfloat16
156+
157+
_adjust_config(
158+
config,
159+
num_hidden_layers=args.num_hidden_layers,
160+
num_experts=args.num_experts,
161+
num_experts_per_tok=args.num_experts_per_tok,
162+
moe_intermediate_size=args.moe_intermediate_size,
163+
)
164+
165+
model = AutoModelForCausalLM.from_config(config, trust_remote_code=trust_remote_code)
166+
model = model.bfloat16()
167+
model.save_pretrained(output_dir, safe_serialization=True)
168+
169+
_save_tokenizer(output_dir, tokenizer_id, trust_remote_code=trust_remote_code)
170+
171+
print(f"Toy HuggingFace checkpoint saved to: {output_dir}")
172+
print(f" hidden_layers={args.num_hidden_layers}")
173+
print(f" num_experts={args.num_experts}")
174+
print(f" num_experts_per_tok={config.num_experts_per_tok}")
175+
print(f" tokenizer_source={tokenizer_id}")
176+
177+
178+
if __name__ == "__main__":
179+
main()
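
As a quick smoke test of the generated directory (a minimal sketch, not part of this commit, assuming the docstring example above was run with `--output-dir /tmp/qwen3_toy`), the toy checkpoint loads like any other HF model:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

toy_dir = "/tmp/qwen3_toy"  # the --output-dir passed to create_hf_toy_model.py

tokenizer = AutoTokenizer.from_pretrained(toy_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    toy_dir, trust_remote_code=True, torch_dtype=torch.bfloat16
)

inputs = tokenizer("hello world", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Weights are random, so only the shapes are meaningful:
# (batch, sequence_length, vocab_size)
print(logits.shape)
```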
