4040
4141from absl import app
4242from absl import flags
43+ from flax .linen import partitioning as nn_partitioning
4344import jax
4445import transformers
4546
4647from MaxText import model_creation_utils
4748from MaxText import pyconfig
4849from MaxText .common_types import Config
50+ from MaxText .globals import MAXTEXT_PKG_DIR
4951from MaxText .integration .tunix .tunix_adapter import TunixMaxTextAdapter
5052from tunix .rl .rollout import base_rollout
5153from tunix .rl .rollout .vllm_rollout import VllmRollout
@@ -137,17 +139,17 @@ def decode_with_vllm(
137139 vllm_args ["hf_config_path" ] = hf_config_path
138140 vllm_args ["gpu_memory_utilization" ] = gpu_memory_utilization
139141
140- if load_parameters_path is None :
141- vllm_args ["load_format" ] = "dummy"
142-
143142 # Prepare MaxText and sharding configs (Parallelism is dynamic)
144143 vllm_args ["additional_config" ]["maxtext_config" ] = {
145144 "model_name" : model_name ,
146145 "max_target_length" : max_target_length ,
147146 "weight_dtype" : "bfloat16" ,
148147 "allow_split_physical_axes" : True ,
149- "load_parameters_path" : load_parameters_path ,
150148 }
149+ if load_parameters_path is not None :
150+ vllm_args ["additional_config" ]["maxtext_config" ]["load_parameters_path" ] = load_parameters_path
151+ else :
152+ vllm_args ["load_format" ] = "dummy"
151153
152154 vllm_args ["additional_config" ]["sharding" ] = {
153155 "sharding_strategy" : {
@@ -173,7 +175,13 @@ def decode_with_vllm(
173175 f"Initializing LLM with DP={ vllm_args ['data_parallel_size' ]} , TP={ vllm_args ['tensor_parallel_size' ]} "
174176 f"and EP={ ici_expert_parallelism if enable_expert_parallel else 0 } ..."
175177 )
176- llm = LLM (** vllm_args )
178+
179+ vllm_config_path = os .path .join (MAXTEXT_PKG_DIR , "configs" , "vllm.yml" )
180+ argv_list = ["" , str (vllm_config_path ), "log_config=False" ]
181+ vllm_config = pyconfig .initialize (argv_list )
182+
183+ with nn_partitioning .axis_rules (vllm_config .logical_axis_rules ):
184+ llm = LLM (** vllm_args )
177185
178186 print ("Generating output..." )
179187 outputs = llm .generate ([prompt ], sampling_params )
0 commit comments