|
13 | 13 | See the License for the specific language governing permissions and |
14 | 14 | limitations under the License. |
15 | 15 | """ |
16 | | - |
17 | | -from typing import Sequence |
| 16 | +import os |
18 | 17 | import torch |
| 18 | +import torch.nn.functional as F |
| 19 | +import argparse |
19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM |
20 | | -import os |
21 | | -from absl import app # Removed flags |
| 21 | +from tabulate import tabulate |
22 | 22 |
|
23 | 23 | from MaxText.utils.ckpt_conversion.utils.hf_utils import ( |
24 | | - check_predicted_tokens_match, |
| 24 | + # check_predicted_tokens_match, |
25 | 25 | check_arrays_match, |
26 | 26 | ) |
| 27 | +from MaxText import max_logging |
27 | 28 | # Read Hugging Face token from environment variable |
28 | 29 | hf_token = os.environ.get("HF_AUTH_TOKEN") |
29 | 30 |
|
|
40 | 41 | huggingface_hub |
41 | 42 | transformers |
42 | 43 | accelerate |
| 44 | + tabulate |
43 | 45 | """ |
44 | 46 |
|
45 | 47 |
|
@@ -70,46 +72,180 @@ def get_logits(inputs, model, golden_model): |
70 | 72 | return logits, golden_logits |
71 | 73 |
|
72 | 74 |
|
73 | | -def main(argv: Sequence[str]) -> None: |
74 | | - # Parse arguments from argv |
75 | | - # Default values |
76 | | - parsed_args = {"golden_model_id": "google/gemma-2-2b-it", "hf_checkpoint_path": os.path.expanduser("~/.hf_output/")} |
77 | | - for arg in argv[1:]: |
78 | | - if "=" in arg: |
79 | | - key, value = arg.split("=", 1) |
80 | | - if key in parsed_args: |
81 | | - parsed_args[key] = value |
82 | | - else: |
83 | | - print(f"Warning: Unknown argument '{key}' found in argv. Ignoring.") |
84 | | - |
85 | | - golden_model = AutoModelForCausalLM.from_pretrained(parsed_args["golden_model_id"], torch_dtype=torch.float32) |
86 | | - |
87 | | - tokenizer = AutoTokenizer.from_pretrained(parsed_args["hf_checkpoint_path"]) |
88 | | - model = AutoModelForCausalLM.from_pretrained(parsed_args["hf_checkpoint_path"], torch_dtype=torch.float32) |
89 | | - |
90 | | - # TODO: (@yixuannwang) use 3 prompts to verify |
91 | | - input_text = "I love to" |
92 | | - inputs = tokenizer(input_text, return_tensors="pt") |
93 | | - # --- Generate Output --- |
94 | | - with torch.no_grad(): |
95 | | - outputs = model.generate(**inputs, max_new_tokens=8) |
96 | | - # --- Decode and Print --- |
97 | | - print(tokenizer.decode(outputs[0], skip_special_tokens=True)) |
98 | | - |
99 | | - # Check weights match |
100 | | - print("########### check weights match ############### ") |
101 | | - check_weights_match(model, golden_model) |
| 75 | +def get_top_k_tokens_scores(logits_tensor, tokenizer_instance, k=10, description=""): |
| 76 | + """Get the top-k tokens and their scores from a given logits tensor.""" |
| 77 | + max_logging.log(f"\n--- {description} top {k} tokens ---") |
| 78 | + collected_tokens = [] |
| 79 | + tokens = [] |
| 80 | + # Ensure logits_tensor is on CPU for operations like topk and item() |
| 81 | + logits_tensor = logits_tensor.cpu() |
| 82 | + topk_results = torch.topk(logits_tensor[0, -1], k=k) |
| 83 | + for i in range(k): |
| 84 | + tok_id = topk_results.indices[i].item() |
| 85 | + score = topk_results.values[i].item() |
| 86 | + tok = tokenizer_instance.decode(tok_id) |
| 87 | + collected_tokens.append({"id": int(tok_id), "token": tok.strip(), "score": float(score)}) |
| 88 | + tokens.append({"id": int(tok_id), "token": tok.strip(), "score": float(score)}) |
| 89 | + |
| 90 | + # Prepare data for tabulate: a list of lists |
| 91 | + table_data = [[d["id"], d["token"], d["score"]] for d in collected_tokens] |
| 92 | + max_logging.log(tabulate(table_data, headers=["Token ID", "Token", "Score"], tablefmt="orgtbl")) |
| 93 | + return tokens |
| 94 | + |
| 95 | + |
| 96 | +def compare_top_tokens(converted_tokens, golden_tokens): |
| 97 | + """ |
| 98 | + Compares two lists of top tokens and calculates similarity metrics. |
| 99 | +
|
| 100 | + Args: |
| 101 | + converted_tokens: top tokens from the converted model. |
| 102 | + golden_tokens: top tokens from the golden model. |
| 103 | + """ |
| 104 | + # Extract the sets of token IDs for comparison |
| 105 | + converted_ids = {token["id"] for token in converted_tokens} |
| 106 | + golden_ids = {token["id"] for token in golden_tokens} |
| 107 | + |
| 108 | + # --- Metric 1: Overlap Count & Jaccard Similarity --- |
| 109 | + intersection = converted_ids.intersection(golden_ids) |
| 110 | + union = converted_ids.union(golden_ids) |
| 111 | + |
| 112 | + overlap_count = len(intersection) |
| 113 | + jaccard_similarity = overlap_count / len(union) if union else 0.0 |
| 114 | + |
| 115 | + # --- Metric 2: Rank Agreement --- |
| 116 | + rank_matches = 0 |
| 117 | + min_len = min(len(converted_tokens), len(golden_tokens)) |
| 118 | + for i in range(min_len): |
| 119 | + if converted_tokens[i]["id"] == golden_tokens[i]["id"]: |
| 120 | + rank_matches += 1 |
| 121 | + |
| 122 | + rank_agreement = (rank_matches / min_len) * 100 if min_len > 0 else 0.0 |
| 123 | + |
| 124 | + metrics = { |
| 125 | + "overlap_count": f"{overlap_count}/{min_len}", |
| 126 | + "jaccard_similarity": jaccard_similarity, |
| 127 | + "rank_agreement_percentage": rank_agreement, |
| 128 | + } |
| 129 | + |
| 130 | + max_logging.log("\n--- Similarity Metrics of Top Tokens ---") |
| 131 | + table = [[key, value] for key, value in metrics.items()] |
| 132 | + max_logging.log(tabulate(table, headers=["Metric", "Value"], tablefmt="orgtbl")) |
| 133 | + |
| 134 | + |
| 135 | +def check_kl_divergence(model_logits, golden_logits, atol=0.02): |
| 136 | + """ |
| 137 | + Calculates KL divergence D_KL(P_golden || Q_model) over a batch of sequences. |
| 138 | +
|
| 139 | + Args: |
| 140 | + model_logits: Logits from the converted model (Batch, SeqLen, VocabSize). |
| 141 | + golden_logits: Logits from the golden model (Batch, SeqLen, VocabSize). |
| 142 | +      atol: Maximum allowed KL divergence for any single token; the check fails
| 143 | +            if this threshold is exceeded.
| 144 | + """ |
| 145 | + # 1. Select the relevant vocabulary slice from the logits. |
| 146 | + token_size = min(model_logits.shape[2], golden_logits.shape[2]) |
| 147 | + model_logits_sliced = model_logits[..., :token_size] |
| 148 | + golden_logits_sliced = golden_logits[..., :token_size] |
| 149 | + |
| 150 | + # 2. Reshape |
| 151 | + b, s, v = model_logits_sliced.shape |
| 152 | + model_logits_reshaped = model_logits_sliced.view(b * s, v) |
| 153 | + golden_logits_reshaped = golden_logits_sliced.view(b * s, v) |
| 154 | + |
| 155 | + # 3. Get the probability distributions. |
| 156 | + golden_probabilities = F.softmax(golden_logits_reshaped, dim=-1) |
| 157 | + model_log_probabilities = F.log_softmax(model_logits_reshaped, dim=-1) |
| 158 | + |
| 159 | + # 4. Calculate avg KL divergence for all token distributions. |
| 160 | +  # 'batchmean' sums the KL divergence of every token in the batch
| 161 | +  # and then divides by the number of tokens (b * s).
| 162 | + kl_div_value = F.kl_div( |
| 163 | + input=model_log_probabilities, |
| 164 | + target=golden_probabilities, |
| 165 | + reduction="batchmean", # Use 'batchmean' for the average KL per token. |
| 166 | + log_target=False, |
| 167 | + ) |
| 168 | + |
| 169 | + max_logging.log(f"\nAverage KL divergence per token (D_KL(P_golden || Q_model)): {kl_div_value.item():.6f}") |
| 170 | + |
| 171 | + # To find the max KL divergence for any single token in the set |
| 172 | + # use reduction='none'. |
| 173 | + kl_divs_per_token = F.kl_div( |
| 174 | + input=model_log_probabilities, target=golden_probabilities, reduction="none", log_target=False |
| 175 | + ).sum( |
| 176 | + dim=-1 |
| 177 | + ) # Sum over the vocab dim to get a single KL value per token |
| 178 | + |
| 179 | + max_kl_div = kl_divs_per_token.max() |
| 180 | + max_logging.log(f"\nMax KL divergence for a single token in the set: {max_kl_div.item():.6f}") |
| 181 | + |
| 182 | +  assert max_kl_div < atol, f"KL divergence value {max_kl_div.item():.6f} exceeds the threshold {atol}"
| 183 | + |
| 184 | + |
| 185 | +def run_prompts(args: argparse.Namespace) -> None: |
| 186 | +  """Runs verification prompts and compares the converted model against the golden model.
| 187 | +  Args:
| 188 | +    args: Parsed command-line arguments providing golden_model_id (str), the HF model ID
| 189 | +      for the golden model; hf_checkpoint_path (str), the path to the converted HF
| 190 | +      checkpoint; and max_kl_div (float), the maximum allowed per-token KL divergence.
| 191 | + """ |
| 192 | + golden_model = AutoModelForCausalLM.from_pretrained(args.golden_model_id, torch_dtype=torch.bfloat16) |
| 193 | + golden_tokenizer = AutoTokenizer.from_pretrained(args.golden_model_id) |
| 194 | + |
| 195 | + tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint_path) |
| 196 | + model, _ = AutoModelForCausalLM.from_pretrained( |
| 197 | + args.hf_checkpoint_path, trust_remote_code=True, torch_dtype=torch.bfloat16, output_loading_info=True |
| 198 | + ) |
| 199 | + |
| 200 | +  # Capture the second return value (loading_info) above and log it here to debug conversion issues.
| 201 | + |
| 202 | + prompts = ["I love to", "Today is a", "What is the"] |
| 203 | + for input_text in prompts: |
| 204 | + max_logging.log(f"\n--- Prompt: {input_text} ---") |
| 205 | + inputs = tokenizer(input_text, return_tensors="pt") |
| 206 | + # --- Generate Output --- |
| 207 | + with torch.no_grad(): |
| 208 | + outputs = model.generate(**inputs, max_new_tokens=15, do_sample=False) |
| 209 | + # --- Decode and Print --- |
| 210 | + max_logging.log(f"Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}") |
| 211 | + |
| 212 | + # --- Compare tokens --- |
| 213 | + model_logits, golden_model_logits = get_logits(inputs, model, golden_model) |
| 214 | + tokens = get_top_k_tokens_scores(model_logits, tokenizer, k=10, description="converted model") |
| 215 | + golden_tokens = get_top_k_tokens_scores(golden_model_logits, golden_tokenizer, k=10, description="golden model") |
| 216 | + compare_top_tokens(converted_tokens=tokens, golden_tokens=golden_tokens) |
| 217 | + |
| 218 | + check_kl_divergence(model_logits, golden_model_logits, atol=args.max_kl_div) |
| 219 | + |
| 220 | + """ |
| 221 | +    If the converted model's structure exactly matches the golden model (layers, vocab_size, etc.),
| 222 | +    you can run additional weight-level checks with the following steps:
102 | 223 |
|
103 | | - # Run forward pass to get logits |
104 | | - logits, golden_logits = get_logits(inputs, model, golden_model) |
| 224 | + check_weights_match(model, golden_model) |
105 | 225 |
|
106 | 226 | # Check logits from the first 5 tokens match |
107 | | - print("########### check logits match ############### ") |
108 | | - check_arrays_match(logits[0, :5, :], golden_logits[0, :5, :], atol=0.2) |
| 227 | + check_arrays_match(model_logits[0, :5, :], golden_model_logits[0, :5, :], atol=0.2) |
109 | 228 |
|
110 | | - print("########### check predicted token match ############### ") |
111 | | - check_predicted_tokens_match(logits, golden_logits) |
| 229 | + check_predicted_tokens_match(model_logits, golden_model_logits) |
| 230 | + """ |
112 | 231 |
|
113 | 232 |
|
114 | 233 | if __name__ == "__main__": |
115 | | - app.run(main) |
| 234 | + parser = argparse.ArgumentParser(description="Verify HuggingFace checkpoints converted from MaxText.") |
| 235 | + parser.add_argument( |
| 236 | + "--golden_model_id", |
| 237 | + type=str, |
| 238 | + default="google/gemma-2-2b-it", |
| 239 | + help="The HuggingFace model ID for the golden/reference model.", |
| 240 | + ) |
| 241 | + parser.add_argument( |
| 242 | + "--hf_checkpoint_path", |
| 243 | + type=str, |
| 244 | + default=os.path.expanduser("~/.hf_output/"), |
| 245 | + help="Path to the converted HuggingFace checkpoint directory.", |
| 246 | + ) |
| 247 | + parser.add_argument("--max_kl_div", type=float, default=0.02, help="Maximum allowed KL divergence between model logits.") |
| 248 | + |
| 249 | + parsed_args = parser.parse_args() |
| 250 | + |
| 251 | + run_prompts(parsed_args) |
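
Example invocation of the new entry point (the script path below is illustrative, not part of this change; the flags and their defaults come from the argparse block above, and the script reads HF_AUTH_TOKEN from the environment):

    export HF_AUTH_TOKEN=<your Hugging Face token>
    python MaxText/utils/ckpt_conversion/verify_hf_checkpoint.py \
        --golden_model_id google/gemma-2-2b-it \
        --hf_checkpoint_path ~/.hf_output/ \
        --max_kl_div 0.02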