
Commit 3662540

Author: maxtext authors

Merge pull request #1821 from AI-Hypercomputer:yixuannwang-test2
PiperOrigin-RevId: 771593926
2 parents 4d4b6b0 + db07d6b commit 3662540

File tree

8 files changed: +572 -51 lines changed

MaxText/tests/hf_checkpoint_conversion_check.py

Lines changed: 177 additions & 41 deletions
@@ -13,17 +13,18 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-
-from typing import Sequence
+import os
 import torch
+import torch.nn.functional as F
+import argparse
 from transformers import AutoTokenizer, AutoModelForCausalLM
-import os
-from absl import app  # Removed flags
+from tabulate import tabulate

 from MaxText.utils.ckpt_conversion.utils.hf_utils import (
-    check_predicted_tokens_match,
+    # check_predicted_tokens_match,
     check_arrays_match,
 )
+from MaxText import max_logging
 # Read Hugging Face token from environment variable
 hf_token = os.environ.get("HF_AUTH_TOKEN")
@@ -40,6 +41,7 @@
 huggingface_hub
 transformers
 accelerate
+tabulate
 """
@@ -70,46 +72,180 @@ def get_logits(inputs, model, golden_model):
   return logits, golden_logits


-def main(argv: Sequence[str]) -> None:
-  # Parse arguments from argv
-  # Default values
-  parsed_args = {"golden_model_id": "google/gemma-2-2b-it", "hf_checkpoint_path": os.path.expanduser("~/.hf_output/")}
-  for arg in argv[1:]:
-    if "=" in arg:
-      key, value = arg.split("=", 1)
-      if key in parsed_args:
-        parsed_args[key] = value
-      else:
-        print(f"Warning: Unknown argument '{key}' found in argv. Ignoring.")
-
-  golden_model = AutoModelForCausalLM.from_pretrained(parsed_args["golden_model_id"], torch_dtype=torch.float32)
-
-  tokenizer = AutoTokenizer.from_pretrained(parsed_args["hf_checkpoint_path"])
-  model = AutoModelForCausalLM.from_pretrained(parsed_args["hf_checkpoint_path"], torch_dtype=torch.float32)
-
-  # TODO: (@yixuannwang) use 3 prompts to verify
-  input_text = "I love to"
-  inputs = tokenizer(input_text, return_tensors="pt")
-  # --- Generate Output ---
-  with torch.no_grad():
-    outputs = model.generate(**inputs, max_new_tokens=8)
-  # --- Decode and Print ---
-  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
-
-  # Check weights match
-  print("########### check weights match ############### ")
-  check_weights_match(model, golden_model)
+def get_top_k_tokens_scores(logits_tensor, tokenizer_instance, k=10, description=""):
+  """Get the top-k tokens and their scores from a given logits tensor."""
+  max_logging.log(f"\n--- {description} top {k} tokens ---")
+  collected_tokens = []
+  tokens = []
+  # Ensure logits_tensor is on CPU for operations like topk and item()
+  logits_tensor = logits_tensor.cpu()
+  topk_results = torch.topk(logits_tensor[0, -1], k=k)
+  for i in range(k):
+    tok_id = topk_results.indices[i].item()
+    score = topk_results.values[i].item()
+    tok = tokenizer_instance.decode(tok_id)
+    collected_tokens.append({"id": int(tok_id), "token": tok.strip(), "score": float(score)})
+    tokens.append({"id": int(tok_id), "token": tok.strip(), "score": float(score)})
+
+  # Prepare data for tabulate: a list of lists
+  table_data = [[d["id"], d["token"], d["score"]] for d in collected_tokens]
+  max_logging.log(tabulate(table_data, headers=["Token ID", "Token", "Score"], tablefmt="orgtbl"))
+  return tokens
+
+
+def compare_top_tokens(converted_tokens, golden_tokens):
+  """
+  Compares two lists of top tokens and calculates similarity metrics.
+
+  Args:
+    converted_tokens: top tokens from the converted model.
+    golden_tokens: top tokens from the golden model.
+  """
+  # Extract the sets of token IDs for comparison
+  converted_ids = {token["id"] for token in converted_tokens}
+  golden_ids = {token["id"] for token in golden_tokens}
+
+  # --- Metric 1: Overlap Count & Jaccard Similarity ---
+  intersection = converted_ids.intersection(golden_ids)
+  union = converted_ids.union(golden_ids)
+
+  overlap_count = len(intersection)
+  jaccard_similarity = overlap_count / len(union) if union else 0.0
+
+  # --- Metric 2: Rank Agreement ---
+  rank_matches = 0
+  min_len = min(len(converted_tokens), len(golden_tokens))
+  for i in range(min_len):
+    if converted_tokens[i]["id"] == golden_tokens[i]["id"]:
+      rank_matches += 1
+
+  rank_agreement = (rank_matches / min_len) * 100 if min_len > 0 else 0.0
+
+  metrics = {
+      "overlap_count": f"{overlap_count}/{min_len}",
+      "jaccard_similarity": jaccard_similarity,
+      "rank_agreement_percentage": rank_agreement,
+  }
+
+  max_logging.log("\n--- Similarity Metrics of Top Tokens ---")
+  table = [[key, value] for key, value in metrics.items()]
+  max_logging.log(tabulate(table, headers=["Metric", "Value"], tablefmt="orgtbl"))
+
+
+def check_kl_divergence(model_logits, golden_logits, atol=0.02):
+  """
+  Calculates KL divergence D_KL(P_golden || Q_model) over a batch of sequences.
+
+  Args:
+    model_logits: Logits from the converted model (Batch, SeqLen, VocabSize).
+    golden_logits: Logits from the golden model (Batch, SeqLen, VocabSize).
+    atol: Maximum allowed KL divergence for any single token
+      (the assertion threshold).
+  """
+  # 1. Select the relevant vocabulary slice from the logits.
+  token_size = min(model_logits.shape[2], golden_logits.shape[2])
+  model_logits_sliced = model_logits[..., :token_size]
+  golden_logits_sliced = golden_logits[..., :token_size]
+
+  # 2. Reshape
+  b, s, v = model_logits_sliced.shape
+  model_logits_reshaped = model_logits_sliced.view(b * s, v)
+  golden_logits_reshaped = golden_logits_sliced.view(b * s, v)
+
+  # 3. Get the probability distributions.
+  golden_probabilities = F.softmax(golden_logits_reshaped, dim=-1)
+  model_log_probabilities = F.log_softmax(model_logits_reshaped, dim=-1)
+
+  # 4. Calculate avg KL divergence for all token distributions.
+  # 'batchmean' sums the KL divergences for each token in the batch
+  # and then divides by the number of tokens (b * s).
+  kl_div_value = F.kl_div(
+      input=model_log_probabilities,
+      target=golden_probabilities,
+      reduction="batchmean",  # Use 'batchmean' for the average KL per token.
+      log_target=False,
+  )
+
+  max_logging.log(f"\nAverage KL divergence per token (D_KL(P_golden || Q_model)): {kl_div_value.item():.6f}")
+
+  # To find the max KL divergence for any single token in the set,
+  # use reduction='none'.
+  kl_divs_per_token = F.kl_div(
+      input=model_log_probabilities, target=golden_probabilities, reduction="none", log_target=False
+  ).sum(
+      dim=-1
+  )  # Sum over the vocab dim to get a single KL value per token
+
+  max_kl_div = kl_divs_per_token.max()
+  max_logging.log(f"\nMax KL divergence for a single token in the set: {max_kl_div.item():.6f}")
+
+  assert max_kl_div < atol, f"KL divergence value {max_kl_div.item():.6f} exceeds the threshold {atol}"
+
+
+def run_prompts(args: argparse.Namespace) -> None:
+  """
+  Args:
+    - golden_model_id (str): HF model ID for the golden model.
+    - hf_checkpoint_path (str): Path to the converted HF checkpoint.
+    - max_kl_div (float): Maximum allowed KL divergence.
+  """
+  golden_model = AutoModelForCausalLM.from_pretrained(args.golden_model_id, torch_dtype=torch.bfloat16)
+  golden_tokenizer = AutoTokenizer.from_pretrained(args.golden_model_id)
+
+  tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint_path)
+  model, _ = AutoModelForCausalLM.from_pretrained(
+      args.hf_checkpoint_path, trust_remote_code=True, torch_dtype=torch.bfloat16, output_loading_info=True
+  )
+
+  # max_logging.log(loading_info)
+
+  prompts = ["I love to", "Today is a", "What is the"]
+  for input_text in prompts:
+    max_logging.log(f"\n--- Prompt: {input_text} ---")
+    inputs = tokenizer(input_text, return_tensors="pt")
+    # --- Generate Output ---
+    with torch.no_grad():
+      outputs = model.generate(**inputs, max_new_tokens=15, do_sample=False)
+    # --- Decode and Print ---
+    max_logging.log(f"Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
+
+    # --- Compare tokens ---
+    model_logits, golden_model_logits = get_logits(inputs, model, golden_model)
+    tokens = get_top_k_tokens_scores(model_logits, tokenizer, k=10, description="converted model")
+    golden_tokens = get_top_k_tokens_scores(golden_model_logits, golden_tokenizer, k=10, description="golden model")
+    compare_top_tokens(converted_tokens=tokens, golden_tokens=golden_tokens)
+
+    check_kl_divergence(model_logits, golden_model_logits, atol=args.max_kl_div)
+
+  """
+  If the model's structure is exactly the same as the golden model (layers, vocab_size, etc.),
+  you can check more weight details using the following steps:

-  # Run forward pass to get logits
-  logits, golden_logits = get_logits(inputs, model, golden_model)
+  check_weights_match(model, golden_model)

   # Check logits from the first 5 tokens match
-  print("########### check logits match ############### ")
-  check_arrays_match(logits[0, :5, :], golden_logits[0, :5, :], atol=0.2)
+  check_arrays_match(model_logits[0, :5, :], golden_model_logits[0, :5, :], atol=0.2)

-  print("########### check predicted token match ############### ")
-  check_predicted_tokens_match(logits, golden_logits)
+  check_predicted_tokens_match(model_logits, golden_model_logits)
+  """


 if __name__ == "__main__":
-  app.run(main)
+  parser = argparse.ArgumentParser(description="Verify HuggingFace checkpoints converted from MaxText.")
+  parser.add_argument(
+      "--golden_model_id",
+      type=str,
+      default="google/gemma-2-2b-it",
+      help="The HuggingFace model ID for the golden/reference model.",
+  )
+  parser.add_argument(
+      "--hf_checkpoint_path",
+      type=str,
+      default=os.path.expanduser("~/.hf_output/"),
+      help="Path to the converted HuggingFace checkpoint directory.",
+  )
+  parser.add_argument("--max_kl_div", type=float, default=0.02, help="Maximum allowed KL divergence between model logits.")
+
+  parsed_args = parser.parse_args()
+
+  run_prompts(parsed_args)
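
For intuition about the KL check added above, here is a minimal, self-contained sketch (not part of the commit) of how F.kl_div behaves with the "batchmean" and "none" reductions on toy logits; shapes follow the (Batch, SeqLen, Vocab) convention used in check_kl_divergence, and the tensors themselves are made up:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
golden_logits = torch.randn(1, 4, 8)                           # stand-in golden model logits
model_logits = golden_logits + 0.01 * torch.randn(1, 4, 8)     # converted model, tiny perturbation

p_golden = F.softmax(golden_logits.view(-1, 8), dim=-1)        # target distribution P_golden
log_q_model = F.log_softmax(model_logits.view(-1, 8), dim=-1)  # input distribution Q_model (log space)

# Average KL per token: 'batchmean' sums over all tokens, then divides by the number of rows.
avg_kl = F.kl_div(input=log_q_model, target=p_golden, reduction="batchmean", log_target=False)

# Per-token KL: keep per-element values, then sum over the vocab dimension.
per_token_kl = F.kl_div(input=log_q_model, target=p_golden, reduction="none", log_target=False).sum(dim=-1)

print(avg_kl.item(), per_token_kl.max().item())                # both should sit well below atol=0.02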

MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_hf.sh

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@ MAXTEXT_CHECKPOINT_DIR="gs://maxtext-model-checkpoints/gemma2-2b-it/2025-02-20-1
 LOCAL_HF_CHECKPOINT_DIR="/tmp/hf_gemma2-2b_output" # HF requires a local dir
 GOLDEN_MODEL_ID="google/gemma-2-2b-it"

-CONVERT_MODULE="MaxText.ckpt_conversion.to_huggingface"
+CONVERT_MODULE="MaxText.utils.ckpt_conversion.to_huggingface"
 CONVERT_ARGS=(
   "MaxText/configs/base.yml"
   "model_name=gemma2-2b"
@@ -29,7 +29,7 @@ CONVERT_ARGS=(
   "base_output_directory=${HF_CHECKPOINT_GCS_PATH}"
 )

-VERIFY_MODULE="MaxText.tests.huggingface_ckpt_conversion_check"
+VERIFY_MODULE="MaxText.tests.hf_ckpt_conversion_check"

 VERIFY_ARGS=(
   "golden_model_id=${GOLDEN_MODEL_ID}"
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+#!/bin/bash
+
+# Exit immediately if a command exits with a non-zero status.
+set -e
+
+export HF_AUTH_TOKEN=""
+
+DATE=$(date +%Y-%m-%d)
+# Define variables for paths and arguments
+HF_CHECKPOINT_GCS_PATH="gs://maxtext-model-checkpoints/HuggingFace/gemma3-4b/${DATE}" # (optional) GCS path for HF model
+MAXTEXT_CHECKPOINT_DIR="gs://maxtext-model-checkpoints/gemma3-4b/2025-03-18-19-03/unscanned/checkpoints/0/items"
+LOCAL_HF_CHECKPOINT_DIR="/tmp/hf_gemma3-4b_output" # HF requires a local dir
+GOLDEN_MODEL_ID="google/gemma-3-4b-it"
+
+CONVERT_MODULE="MaxText.utils.ckpt_conversion.to_huggingface"
+CONVERT_ARGS=(
+  "MaxText/configs/base.yml",
+  "model_name=gemma3-4b",
+  "tokenizer_path=assets/tokenizer.gemma3",
+  "load_parameters_path=${MAXTEXT_CHECKPOINT_DIR}",
+  "per_device_batch_size=1",
+  "run_name=ht_test",
+  "max_prefill_predict_length=8",
+  "max_target_length=16",
+  "steps=1",
+  "async_checkpointing=false",
+  "prompt='I love to'",
+  "scan_layers=false",
+  "attention='dot_product'",
+  "base_output_directory=${HF_CHECKPOINT_GCS_PATH}"
+)
+
+VERIFY_MODULE="MaxText.tests.hf_ckpt_conversion_check"
+
+VERIFY_ARGS=(
+  "--golden_model_id=${GOLDEN_MODEL_ID}"
+  "--hf_checkpoint_path=${LOCAL_HF_CHECKPOINT_DIR}" # Updated to local path
+)
+
+
+# --- Step 1: Run the Hugging Face Conversion ---
+echo "Starting Hugging Face model conversion for gemma3-4b..."
+cd "$MAXTEXT_PROJECT_DIR"
+
+# Construct the command
+CONVERT_CMD=("python3" -m "$CONVERT_MODULE")
+for arg in "${CONVERT_ARGS[@]}"; do
+  CONVERT_CMD+=("$arg")
+done
+
+# Execute the command
+"${CONVERT_CMD[@]}"
+
+echo "Hugging Face model conversion finished."
+
+# --- Step 2: Run the Verification Script ---
+echo "Starting verification for the converted gemma3-4b model..."
+
+# Create local directory for checkpoints and download from GCS
+echo "Creating local directory for HF checkpoints: ${LOCAL_HF_CHECKPOINT_DIR}"
+mkdir -p "${LOCAL_HF_CHECKPOINT_DIR}"
+echo "Downloading HF checkpoints from ${HF_CHECKPOINT_GCS_PATH} to ${LOCAL_HF_CHECKPOINT_DIR}..."
+gsutil -m cp -r "${HF_CHECKPOINT_GCS_PATH}/*" "${LOCAL_HF_CHECKPOINT_DIR}/"
+echo "Download complete."
+
+# Construct the command
+VERIFY_CMD=("python3" -m "$VERIFY_MODULE")
+if [ ${#VERIFY_ARGS[@]} -ne 0 ]; then
+  for arg in "${VERIFY_ARGS[@]}"; do
+    VERIFY_CMD+=("$arg")
+  done
+fi
+
+# Execute the command
+"${VERIFY_CMD[@]}"
+
+# Optional: Clean up the local checkpoint directory
+echo "Cleaning up local HF checkpoint directory: ${LOCAL_HF_CHECKPOINT_DIR}"
+rm -rf "${LOCAL_HF_CHECKPOINT_DIR}"
+echo "Cleanup complete."
+
+echo "Verification script finished. Please check the generated text above."
+echo "All steps completed."

MaxText/utils/ckpt_conversion/to_huggingface.py

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@
 )
 from MaxText.utils.ckpt_conversion.utils.shape_mapping import SHAPE_MAPPING
 from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
-from MaxText.utils.ckpt_conversion.utils.utils import (process_leaf_param, save_model_files, TOKENIZER_HF_IDS)
+from MaxText.utils.ckpt_conversion.utils.utils import (process_leaf_param, save_model_files, HF_IDS)

 """Convert MaxText unscanned ckpt into HF format"""

@@ -83,9 +83,9 @@ def main(argv: Sequence[str]) -> None:
   hf_config_obj = HF_MODEL_CONFIGS[model_key]

   # 2. Load Tokenizer
-  if model_key not in TOKENIZER_HF_IDS:
+  if model_key not in HF_IDS:
     raise ValueError(f"HF Tokenizer ID not found for model key: {model_key}")
-  hf_tokenizer_id = TOKENIZER_HF_IDS[model_key]
+  hf_tokenizer_id = HF_IDS[model_key]
   tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_id, token=hf_token)

   # 3. Get parameter mappings
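
For readers unfamiliar with the renamed mapping, a minimal sketch of the key-check-then-lookup pattern shown in the hunk above; the HF_IDS contents here are hypothetical placeholders, not the actual table in MaxText.utils.ckpt_conversion.utils.utils:

# Hypothetical stand-in for the real HF_IDS mapping (model key -> HF tokenizer/model ID).
HF_IDS = {"gemma2-2b": "google/gemma-2-2b-it", "gemma3-4b": "google/gemma-3-4b-it"}

model_key = "gemma2-2b"
if model_key not in HF_IDS:
  raise ValueError(f"HF Tokenizer ID not found for model key: {model_key}")
hf_tokenizer_id = HF_IDS[model_key]
print(hf_tokenizer_id)  # -> google/gemma-2-2b-it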
