
Commit 9d52020

Merge pull request AI-Hypercomputer#2931 from AI-Hypercomputer:nicogrande/enable-dummy-weights-decode
PiperOrigin-RevId: 855441225
2 parents 5bc680b + 7649164

1 file changed: +13 -5


src/MaxText/vllm_decode.py

@@ -40,12 +40,14 @@
 
 from absl import app
 from absl import flags
+from flax.linen import partitioning as nn_partitioning
 import jax
 import transformers
 
 from MaxText import model_creation_utils
 from MaxText import pyconfig
 from MaxText.common_types import Config
+from MaxText.globals import MAXTEXT_PKG_DIR
 from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
 from tunix.rl.rollout import base_rollout
 from tunix.rl.rollout.vllm_rollout import VllmRollout
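The two new imports support the changes further down: nn_partitioning supplies the axis_rules context manager used to scope Flax logical-partitioning rules, and MAXTEXT_PKG_DIR locates MaxText's packaged config files. A minimal sketch of what axis_rules does; the logical axis names and rules below are illustrative assumptions, not values from vllm_decode.py:

# Sketch: flax.linen.partitioning.axis_rules installs a logical->physical
# axis mapping for the enclosed block. The axis names below are illustrative
# assumptions, not values taken from vllm_decode.py.
from flax.linen import partitioning as nn_partitioning

rules = (("batch", "data"), ("embed", None))  # logical axis -> mesh axis
with nn_partitioning.axis_rules(rules):
  # Logical axis annotations resolved inside this context use `rules`:
  # "batch" maps to the "data" mesh axis, "embed" stays replicated.
  spec = nn_partitioning.logical_to_mesh_axes(("batch", "embed"))
  print(spec)  # PartitionSpec('data', None)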
@@ -137,17 +139,17 @@ def decode_with_vllm(
   vllm_args["hf_config_path"] = hf_config_path
   vllm_args["gpu_memory_utilization"] = gpu_memory_utilization
 
-  if load_parameters_path is None:
-    vllm_args["load_format"] = "dummy"
-
   # Prepare MaxText and sharding configs (Parallelism is dynamic)
   vllm_args["additional_config"]["maxtext_config"] = {
       "model_name": model_name,
       "max_target_length": max_target_length,
       "weight_dtype": "bfloat16",
       "allow_split_physical_axes": True,
-      "load_parameters_path": load_parameters_path,
   }
+  if load_parameters_path is not None:
+    vllm_args["additional_config"]["maxtext_config"]["load_parameters_path"] = load_parameters_path
+  else:
+    vllm_args["load_format"] = "dummy"
 
   vllm_args["additional_config"]["sharding"] = {
       "sharding_strategy": {
@@ -173,7 +175,13 @@ def decode_with_vllm(
       f"Initializing LLM with DP={vllm_args['data_parallel_size']}, TP={vllm_args['tensor_parallel_size']} "
       f"and EP={ici_expert_parallelism if enable_expert_parallel else 0}..."
   )
-  llm = LLM(**vllm_args)
+
+  vllm_config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "vllm.yml")
+  argv_list = ["", str(vllm_config_path), "log_config=False"]
+  vllm_config = pyconfig.initialize(argv_list)
+
+  with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
+    llm = LLM(**vllm_args)
 
   print("Generating output...")
   outputs = llm.generate([prompt], sampling_params)
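The LLM(**vllm_args) construction now runs inside an axis_rules context built from MaxText's packaged configs/vllm.yml, so Flax modules created during model instantiation can resolve their logical sharding annotations. The added path isolated for readability; this mirrors the diff lines, with LLM and vllm_args assumed from the surrounding function, and pyconfig.initialize following MaxText's argv convention (program name, config path, then key=value overrides):

# The added initialization path, isolated. Mirrors the diff above; LLM and
# vllm_args come from the surrounding function and are assumed here.
import os

from flax.linen import partitioning as nn_partitioning
from MaxText import pyconfig
from MaxText.globals import MAXTEXT_PKG_DIR

vllm_config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "vllm.yml")
argv_list = ["", str(vllm_config_path), "log_config=False"]  # argv[0] is a dummy program name
vllm_config = pyconfig.initialize(argv_list)

# logical_axis_rules from vllm.yml stay in scope while vLLM builds the
# MaxText model, so logical parameter axes resolve to mesh axes at creation.
with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
  llm = LLM(**vllm_args)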
