
Commit 9e32bf4

Merge pull request AI-Hypercomputer#2596 from AI-Hypercomputer:shuningjin-ckpt-ds3
PiperOrigin-RevId: 828207772
2 parents: 8d9588a + d0ce195

7 files changed: 464 additions, 81 deletions

src/MaxText/scratch_code/generate_hf_golden_logits.py (38 additions, 4 deletions)

@@ -60,7 +60,16 @@ def upload_blob(bucket_name, source_file_name, destination_blob_name):
 
 
 def save_golden_logits(
-    model_id, output_path, prompt_texts, apply_chat_template, gcs_bucket, hf_model_path, image_paths, output_format
+    model_id,
+    output_path,
+    prompt_texts,
+    apply_chat_template,
+    gcs_bucket,
+    hf_model_path,
+    hf_load_dtype,
+    trust_remote_code,
+    image_paths,
+    output_format,
 ):
   """save golden logits"""
   if hf_model_path is None:
@@ -77,10 +86,18 @@ def save_golden_logits(
 
   tokenizer = AutoTokenizer.from_pretrained(model_id)
   print(f"loading model from {hf_model_path}")
+
+  if hf_load_dtype == "float32":
+    torch_dtype = torch.float32
+  elif hf_load_dtype == "bfloat16":
+    torch_dtype = torch.bfloat16
+  else:
+    raise ValueError
+
   model = model_class.from_pretrained(
       hf_model_path,
-      torch_dtype=torch.float32,
-      trust_remote_code=True,
+      dtype=torch_dtype,
+      trust_remote_code=trust_remote_code,
   )
 
   all_data_to_save = []
@@ -110,7 +127,7 @@ def save_golden_logits(
     # 2. Run inference
     with torch.no_grad():
       outputs = model(**inputs)
-      logits = outputs.logits.cpu().numpy().astype("float32")
+      logits = outputs.logits.cpu().to(torch.float32).numpy()
 
     # 3. Populate final data dictionary with tensors from inputs and logits
     for key, value in inputs.items():
@@ -159,6 +176,21 @@ def main(raw_args=None) -> None:
   parser.add_argument(
       "--hf-model-path", type=str, required=False, default=None, help="local path to checkpoint if exists."
   )
+  parser.add_argument(
+      "--hf-load-dtype",
+      type=str,
+      required=False,
+      choices=["float32", "bfloat16"],
+      default="float32",
+      help="model_class.from_pretrained: dtype",
+  )
+  # `args.trust_remote_code` is True by default; it becomes False only if the flag `--not-trust-remote-code` is passed
+  parser.add_argument(
+      "--not-trust-remote-code",
+      dest="trust_remote_code",
+      action="store_false",
+      help="model_class.from_pretrained: trust_remote_code",
+  )
   parser.add_argument(
       "--image-paths", type=str, required=False, default=None, help="A semicolon-separated list of image_paths."
   )
@@ -185,6 +217,8 @@ def main(raw_args=None) -> None:
       args.apply_chat_template,
       args.gcs_bucket,
       args.hf_model_path,
+      args.hf_load_dtype,
+      args.trust_remote_code,
       image_paths,
       args.output_format,
   )
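The new `--not-trust-remote-code` flag relies on argparse's `store_false` action, which is easy to misread: the destination defaults to True and the flag flips it off. A minimal standalone sketch of that behavior (a hypothetical parser, not the script's full CLI):

    import argparse

    # action="store_false" makes argparse default the destination to True;
    # passing the flag stores False.
    parser = argparse.ArgumentParser()
    parser.add_argument("--not-trust-remote-code", dest="trust_remote_code", action="store_false")

    print(parser.parse_args([]).trust_remote_code)                           # True
    print(parser.parse_args(["--not-trust-remote-code"]).trust_remote_code)  # False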

src/MaxText/utils/ckpt_conversion/to_huggingface.py (18 additions, 5 deletions)

@@ -54,6 +54,8 @@
 import jax
 import os
 from typing import Sequence, Any
+import time
+from tqdm import tqdm
 
 from transformers import AutoTokenizer, AutoProcessor
 
@@ -72,7 +74,8 @@
 from MaxText.utils.ckpt_conversion.utils.utils import (process_leaf_param, save_model_files, HF_IDS)
 
 
-jax.config.update("jax_platform_name", "cpu")
+os.environ["JAX_PLATFORMS"] = "cpu"
+os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"
 
 
 def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
@@ -114,17 +117,22 @@ def main(argv: Sequence[str]) -> None:
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
 
+  # Initialize MaxText config
   config = pyconfig.initialize(argv)
   assert (
       config.load_full_state_path == ""
   ), "This script expects parameters, not a full state. Use generate_param_only_checkpoint first if needed."
   max_utils.print_system_information()
 
+  # Load MaxText checkpoint
+  max_logging.log("\nLoading Orbax checkpoint...")
+  start = time.time()
   engine = maxengine.MaxEngine(config)
   rng = jax.random.PRNGKey(1234)
   rng, rng_load_params = jax.random.split(rng)
   # load params from maxengine
   loaded_params_from_engine = engine.load_params(rng_load_params)
+  max_logging.log(f"Elapsed: {(time.time() - start) / 60:.2f} min")
 
   if not config.base_output_directory:
     output_directory = f"tmp/{config.run_name}"
@@ -165,18 +173,22 @@ def main(argv: Sequence[str]) -> None:
   leaves_with_paths = jax.tree_util.tree_leaves_with_path(actual_weights_dict)
 
   # traverse leaves to build: mt_param_key:mt_weights
+  max_logging.log("\nProcessing weights...")
+  start = time.time()
   processed_params_list = []
-  for path_tuple_iter, leaf_value_iter in leaves_with_paths:
-    processed_params_list.extend(
-        process_leaf_param(path_tuple_iter, leaf_value_iter, param_map, shape_map, hook_fn_map, config)
-    )
+  for path_tuple_iter, leaf_value_iter in tqdm(leaves_with_paths, total=len(leaves_with_paths)):
+    processed_params = process_leaf_param(path_tuple_iter, leaf_value_iter, param_map, shape_map, hook_fn_map, config)
+    processed_params_list.extend(processed_params)
   transformed_hf_weights = dict(processed_params_list)
+  max_logging.log(f"Elapsed: {(time.time() - start) / 60:.2f} min")
 
   # 5. Save in HuggingFace Format
   if not transformed_hf_weights:
     print("Error: No weights were transformed. Check mappings and parameter paths.")
     return
 
+  max_logging.log("\nSaving HuggingFace model...")
+  start = time.time()
   save_model_files(
       weight_arrays=transformed_hf_weights,
       config=hf_config_obj,
@@ -185,6 +197,7 @@ def main(argv: Sequence[str]) -> None:
       output_dir=output_directory,
   )
   max_logging.log(f"✅ MaxText model successfully saved in HuggingFace format at {output_directory}")
+  max_logging.log(f"Elapsed: {(time.time() - start) / 60:.2f} min")
 
 
 if __name__ == "__main__":
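Note that `JAX_PLATFORMS` and `XLA_FLAGS` only take effect if they are set before JAX initializes its backend, which is presumably why the diff swaps the `jax.config.update(...)` call for module-level `os.environ` assignments. A minimal standalone sketch of the intended effect:

    import os

    # Must run before JAX initializes: pin the backend to CPU and split the
    # host into 16 virtual devices.
    os.environ["JAX_PLATFORMS"] = "cpu"
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

    import jax

    print(jax.device_count())         # 16
    print(jax.devices()[0].platform)  # "cpu"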

src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py (60 additions, 0 deletions)

@@ -469,6 +469,65 @@
     vocab_size=151936,
 )
 
+deepseek3_671b_dict = {
+    "architectures": ["DeepseekV3ForCausalLM"],
+    "attention_bias": False,
+    "attention_dropout": 0.0,
+    "auto_map": {
+        "AutoConfig": "configuration_deepseek.DeepseekV3Config",
+        "AutoModel": "modeling_deepseek.DeepseekV3Model",
+        "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM",
+    },
+    "bos_token_id": 0,
+    "eos_token_id": 1,
+    "ep_size": 1,
+    "first_k_dense_replace": 3,
+    "hidden_act": "silu",
+    "hidden_size": 7168,
+    "initializer_range": 0.02,
+    "intermediate_size": 18432,
+    "kv_lora_rank": 512,
+    "max_position_embeddings": 163840,
+    "model_type": "deepseek_v3",
+    "moe_intermediate_size": 2048,
+    "moe_layer_freq": 1,
+    "n_group": 8,
+    "n_routed_experts": 256,
+    "n_shared_experts": 1,
+    "norm_topk_prob": True,
+    "num_attention_heads": 128,
+    "num_experts_per_tok": 8,
+    "num_hidden_layers": 61,
+    "num_key_value_heads": 128,
+    "num_nextn_predict_layers": 1,
+    "q_lora_rank": 1536,
+    "qk_nope_head_dim": 128,
+    "qk_rope_head_dim": 64,
+    "rms_norm_eps": 1e-06,
+    "rope_scaling": {
+        "beta_fast": 32,
+        "beta_slow": 1,
+        "factor": 40,
+        "mscale": 1.0,
+        "mscale_all_dim": 1.0,
+        "original_max_position_embeddings": 4096,
+        "type": "yarn",
+    },
+    "rope_theta": 10000,
+    "routed_scaling_factor": 2.5,
+    "scoring_func": "sigmoid",
+    "tie_word_embeddings": False,
+    "topk_group": 4,
+    "topk_method": "noaux_tc",
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.33.1",
+    "use_cache": True,
+    "v_head_dim": 128,
+    "vocab_size": 129280,
+}
+deepseek3_671b_config = transformers.DeepseekV3Config(**deepseek3_671b_dict)
+
+# {maxtext model name: hf model config}
 qwen3_omni_30b_a3b_config = transformers.Qwen3OmniMoeConfig(
     # TODO(hengtaoguo): Pure-text Omni model, need to fill in visual/audio/code2wav parts
     architectures=["Qwen3OmniMoeForConditionalGeneration"],
@@ -500,5 +559,6 @@
     "qwen3-30b-a3b": qwen3_30b_a3b_thinking_2507_config,
     "qwen3-235b-a22b": qwen3_235b_a22b_thinking_2507_config,
     "qwen3-480b-a35b": qwen3_coder_480b_a35b_config,
+    "deepseek3-671b": deepseek3_671b_config,
     "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
 }
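Since `deepseek3_671b_config` is just the dict above unpacked into `transformers.DeepseekV3Config`, it can be sanity-checked directly. The sketch below assumes a transformers release that ships DeepSeek-V3 support and rebuilds only a few of the keys:

    import transformers

    # Rebuild a trimmed version of the config above; unspecified keys fall
    # back to DeepseekV3Config defaults.
    cfg = transformers.DeepseekV3Config(
        num_hidden_layers=61,
        hidden_size=7168,
        n_routed_experts=256,
        vocab_size=129280,
    )

    print(cfg.model_type)         # "deepseek_v3"
    print(cfg.num_hidden_layers)  # 61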
