INFO 04-04 21:49:56 [__init__.py:59] TPU info: node_name=tpu-0 | tpu_type=v5litepod-16 | worker_id=3 | num_chips=4 | num_cores_per_chip=1
INFO 04-04 21:50:03 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 04-04 21:50:03 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 04-04 21:50:03 [interface.py:226] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
Check failed with unknown exit code: -6.
WARNING 04-04 21:50:05 [__init__.py:80] The quantization method 'awq' already exists and will be overwritten by the quantization config <class 'tpu_inference.layers.vllm.quantization.awq.VllmAWQConfig'>.
WARNING 04-04 21:50:05 [__init__.py:80] The quantization method 'compressed-tensors' already exists and will be overwritten by the quantization config <class 'tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors.VllmCompressedTensorsConfig'>.
WARNING 04-04 21:50:05 [__init__.py:80] The quantization method 'fp8' already exists and will be overwritten by the quantization config <class 'tpu_inference.layers.vllm.quantization.fp8.VllmFp8Config'>.
WARNING 04-04 21:50:05 [__init__.py:80] The quantization method 'mxfp4' already exists and will be overwritten by the quantization config <class 'tpu_inference.layers.vllm.quantization.mxfp4.VllmMxfp4Config'>.
INFO 04-04 21:50:06 [__init__.py:31] Registering MaxTextForCausalLM model with tpu_inference and vllm.
INFO 04-04 21:50:06 [model_loader.py:588] Registered JAX model MaxTextForCausalLM with tpu_inference and vLLM registries.
INFO 04-04 21:50:06 [__init__.py:33] Successfully registered MaxTextForCausalLM model.
(APIServer pid=1355885) INFO 04-04 21:50:06 [utils.py:292]
(APIServer pid=1355885) INFO 04-04 21:50:06 [utils.py:292] █ █ █▄ ▄█
(APIServer pid=1355885) INFO 04-04 21:50:06 [utils.py:292] ▄▄ ▄█ █ █ █ ▀▄▀ █ version 0.17.0rc1.dev136+gee8a29511
(APIServer pid=1355885) INFO 04-04 21:50:06 [utils.py:292] █▄█▀ █ █ █ █ model google/gemma-4-26B-A4B-it
(APIServer pid=1355885) INFO 04-04 21:50:06 [utils.py:292] ▀▀ ▀▀▀▀▀ ▀▀▀▀▀ ▀ ▀
(APIServer pid=1355885) INFO 04-04 21:50:06 [utils.py:292]
(APIServer pid=1355885) INFO 04-04 21:50:06 [utils.py:228] non-default args: {'model_tag': 'google/gemma-4-26B-A4B-it', 'model': 'google/gemma-4-26B-A4B-it', 'seed': 42, 'max_model_len': 5120, 'hf_overrides': {'architectures': ['MaxTextForCausalLM']}, 'tensor_parallel_size': 4, 'gpu_memory_utilization': 0.97, 'enable_prefix_caching': False, 'max_num_batched_tokens': 4096, 'max_num_seqs': 128, 'additional_config': {'maxtext_config': {'model_name': 'gemma4-26b', 'log_config': True, 'enable_dp_attention': True, 'load_parameters_path': '/dev/shm/gemma4-26b'}}}
(APIServer pid=1355885) INFO 04-04 21:50:07 [model.py:531] Resolved architecture: MaxTextForCausalLM
(APIServer pid=1355885) INFO 04-04 21:50:07 [model.py:1554] Using max model len 5120
(APIServer pid=1355885) INFO 04-04 21:50:07 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=4096.
(APIServer pid=1355885) INFO 04-04 21:50:07 [vllm.py:753] Asynchronous scheduling is enabled.
(APIServer pid=1355885) INFO 04-04 21:50:07 [tpu_platform.py:141] Initialized sharding configuration: ShardingConfigManager(total_devices=4, sharding_strategy=ShardingStrategy(tensor_parallelism=4, expert_parallelism=1, sequence_parallelism=1, data_parallelism=1, attention_data_parallelism=1, attention_data_expert_parallelism=1), device_indexes=None)
(APIServer pid=1355885) INFO 04-04 21:50:07 [__init__.py:108] Registered model loader `<class 'tpu_inference.models.vllm.vllm_model_loader.IncrementalModelLoader'>` with load format `tpu_streaming_loader`
(APIServer pid=1355885) WARNING 04-04 21:50:07 [__init__.py:97] Load format `runai_streamer` is already registered, and will be overwritten by the new loader class `<class 'tpu_inference.models.vllm.vllm_model_loader.RunaiIncrementalModelLoader'>`.
(APIServer pid=1355885) INFO 04-04 21:50:07 [__init__.py:108] Registered model loader `<class 'tpu_inference.models.vllm.vllm_model_loader.RunaiIncrementalModelLoader'>` with load format `runai_streamer`
(APIServer pid=1355885) INFO 04-04 21:50:07 [tpu_platform.py:182] Using KV cache block size: 256
(APIServer pid=1355885) INFO 04-04 21:50:07 [tpu_platform.py:193] Force using UniProcExecutor for JAX on single host without pipeline parallelism.
(APIServer pid=1355885) INFO 04-04 21:50:07 [compilation.py:286] Enabled custom fusions: norm_quant, act_quant
(APIServer pid=1355885) WARNING 04-04 21:50:10 [input_processor.py:80] The signature of Platform.validate_request has changed from `(cls, prompt, params, processed_inputs) -> None` to `(cls, processed_inputs, params) -> None`. The old signature will no longer be supported starting from v0.18.
(APIServer pid=1355885) WARNING 04-04 21:50:10 [tpu_platform.py:231] Pin memory is not supported on TPU.
INFO 04-04 21:50:14 [__init__.py:59] TPU info: node_name=tpu-0 | tpu_type=v5litepod-16 | worker_id=3 | num_chips=4 | num_cores_per_chip=1
INFO 04-04 21:50:22 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 04-04 21:50:22 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 04-04 21:50:22 [interface.py:226] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
Check failed with unknown exit code: -6.
(EngineCore_DP0 pid=1356450) WARNING 04-04 21:50:24 [__init__.py:80] The quantization method 'awq' already exists and will be overwritten by the quantization config <class 'tpu_inference.layers.vllm.quantization.awq.VllmAWQConfig'>.
(EngineCore_DP0 pid=1356450) WARNING 04-04 21:50:24 [__init__.py:80] The quantization method 'compressed-tensors' already exists and will be overwritten by the quantization config <class 'tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors.VllmCompressedTensorsConfig'>.
(EngineCore_DP0 pid=1356450) WARNING 04-04 21:50:24 [__init__.py:80] The quantization method 'fp8' already exists and will be overwritten by the quantization config <class 'tpu_inference.layers.vllm.quantization.fp8.VllmFp8Config'>.
(EngineCore_DP0 pid=1356450) WARNING 04-04 21:50:24 [__init__.py:80] The quantization method 'mxfp4' already exists and will be overwritten by the quantization config <class 'tpu_inference.layers.vllm.quantization.mxfp4.VllmMxfp4Config'>.
(EngineCore_DP0 pid=1356450) INFO 04-04 21:50:25 [__init__.py:31] Registering MaxTextForCausalLM model with tpu_inference and vllm.
(EngineCore_DP0 pid=1356450) INFO 04-04 21:50:25 [model_loader.py:588] Registered JAX model MaxTextForCausalLM with tpu_inference and vLLM registries.
(EngineCore_DP0 pid=1356450) INFO 04-04 21:50:25 [__init__.py:33] Successfully registered MaxTextForCausalLM model.
(EngineCore_DP0 pid=1356450) INFO 04-04 21:50:25 [core.py:103] Initializing a V1 LLM engine (v0.17.0rc1.dev136+gee8a29511) with config: model='google/gemma-4-26B-A4B-it', speculative_config=None, tokenizer='google/gemma-4-26B-A4B-it', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=5120, download_dir=None, load_format=auto, tensor_parallel_size=4, pipeline_parallel_size=1, data_parallel_size=1, decode_context_parallel_size=1, dcp_comm_backend=ag_rs, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, enable_return_routed_experts=False, kv_cache_dtype=auto, device_config=None, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, kv_cache_metrics=False, kv_cache_metrics_sample=0.01, cudagraph_metrics=False, enable_layerwise_nvtx_tracing=False, enable_mfu_metrics=False, enable_mm_processor_stats=False, enable_logging_iteration_details=False), seed=42, served_model_name=google/gemma-4-26B-A4B-it, enable_prefix_caching=False, enable_chunked_prefill=True, pooler_config=None, compilation_config={'mode': <CompilationMode.DYNAMO_TRACE_ONCE: 2>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'openxla', 'custom_ops': ['all'], 'splitting_ops': [], 'compile_mm_encoder': False, 'compile_sizes': None, 'compile_ranges_split_points': [4096], 'inductor_compile_config': {'enable_auto_functionalized_v2': False, 'combo_kernels': True, 'benchmark_combo_kernel': True}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.NONE: 0>, 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': None, 
'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {'fuse_norm_quant': True, 'fuse_act_quant': True, 'fuse_attn_quant': False, 'enable_sp': False, 'fuse_gemm_comms': False, 'fuse_allreduce_rms': False}, 'max_cudagraph_capture_size': None, 'dynamic_shapes_config': {'type': <DynamicShapesType.BACKED: 'backed'>, 'evaluate_guards': False, 'assume_32_bit_indexing': False}, 'local_cache_dir': None, 'fast_moe_cold_start': True, 'static_all_moe_layers': []}
(EngineCore_DP0 pid=1356450) WARNING 04-04 21:50:25 [tpu_platform.py:231] Pin memory is not supported on TPU.
(EngineCore_DP0 pid=1356450) INFO 04-04 21:50:29 [parallel_state.py:1395] world_size=1 rank=0 local_rank=0 distributed_init_method=file:///tmp/tmpuk8sb45z backend=gloo
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(EngineCore_DP0 pid=1356450) INFO 04-04 21:50:29 [parallel_state.py:1717] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 0, EP rank 0, EPLB rank N/A
(EngineCore_DP0 pid=1356450) INFO 04-04 21:50:30 [tpu_runner.py:302] Init mesh | mesh=Mesh('data': 1, 'model': 4, axis_types=(Auto, Auto))
(EngineCore_DP0 pid=1356450) INFO 04-04 21:50:30 [utils.py:94] Prepared token paddings: [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
(EngineCore_DP0 pid=1356450) INFO 04-04 21:50:30 [utils.py:60] Prepared request paddings: [8, 16, 32, 64, 128]
(EngineCore_DP0 pid=1356450) INFO 04-04 21:50:30 [compilation_manager.py:52] Enabling JAX compile cache.
(EngineCore_DP0 pid=1356450) INFO 04-04 21:50:30 [tpu_worker.py:279] Init worker | rank=0 | is_first_rank=True | is_last_rank=True | topology_order_id=0 | is_driver_worker=True | hbm=[(0.0, 15.75), (0.0, 15.75), (0.0, 15.75), (0.0, 15.75)]GiB |self.devices=[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)] | total devices=[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)] | local_devices=[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
(EngineCore_DP0 pid=1356450) INFO 04-04 21:50:30 [model_loader.py:381] Loading model with MODEL_IMPL_TYPE=auto
(EngineCore_DP0 pid=1356450) INFO 04-04 21:50:30 [model_loader.py:384] Resolved MODEL_IMPL_TYPE 'auto' to 'flax_nnx'
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] EngineCore failed to start.
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] Traceback (most recent call last):
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 1085, in run_engine_core
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] return func(*args, **kwargs)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 843, in __init__
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] super().__init__(
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 112, in __init__
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] self.model_executor = executor_class(vllm_config)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] return func(*args, **kwargs)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/vllm/v1/executor/abstract.py", line 103, in __init__
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] self._init_executor()
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/vllm/v1/executor/uniproc_executor.py", line 49, in _init_executor
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] self.driver_worker.load_model()
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/tpu_inference/worker/tpu_worker.py", line 391, in load_model
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] self.model_runner.load_model()
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/tpu_inference/runner/tpu_runner.py", line 536, in load_model
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] self.model_fn, self.compute_logits_fn, self.pooler_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model = get_model(
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/tpu_inference/models/common/model_loader.py", line 398, in get_model
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] return get_flax_model(vllm_config, rng, mesh,
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/tpu_inference/models/common/model_loader.py", line 267, in get_flax_model
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/tpu_inference/models/common/model_loader.py", line 235, in _get_nnx_model
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] model.load_weights(rng)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/maxtext_vllm_adapter/adapter.py", line 254, in load_weights
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] model, _ = model_creation_utils.create_nnx_model(
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/src/maxtext/utils/model_creation_utils.py", line 246, in create_nnx_model
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] abstract_model = nnx.eval_shape(_create_model_partial)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/flax/nnx/transforms/transforms.py", line 272, in eval_shape
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] out = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/flax/nnx/transforms/transforms.py", line 269, in _eval_shape_fn
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] out = f_call(*args, **kwargs)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/src/maxtext/utils/model_creation_utils.py", line 241, in _create_model
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/src/maxtext/utils/model_creation_utils.py", line 206, in from_config
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] model = create_model(config, mesh, model_mode=model_mode, rngs=rngs)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/src/maxtext/utils/model_creation_utils.py", line 224, in create_model
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] model = get_transformer_model(config, mesh, quant, model_mode=model_mode, rngs=rngs)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/src/maxtext/utils/model_creation_utils.py", line 215, in get_transformer_model
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] return models.Transformer(config, mesh, quant=quant, rngs=rngs, model_mode=model_mode)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/flax/nnx/pytreelib.py", line 400, in __call__
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] return _graph_node_meta_call(cls, *args, **kwargs)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/flax/nnx/pytreelib.py", line 411, in _graph_node_meta_call
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] cls._pytree_meta_construct(node, *args, **kwargs)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/flax/nnx/pytreelib.py", line 403, in _pytree_meta_construct
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] self.__init__(*args, **kwargs)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/src/maxtext/models/models.py", line 372, in __init__
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] self.decoder.lazy_init(
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/src/maxtext/layers/nnx_wrappers.py", line 220, in lazy_init
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] return lazy_init(self, *args, **kwargs)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/src/maxtext/layers/nnx_wrappers.py", line 162, in lazy_init
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] _set_initializing(module, False)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/src/maxtext/layers/nnx_wrappers.py", line 253, in __call__
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] out, updates = self.to_nnx__module.init_with_output(_rngs, *args, method=method, **kwargs)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/src/maxtext/layers/decoders.py", line 1061, in __call__
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] y, returned_cache = layer(
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/src/maxtext/layers/nnx_wrappers.py", line 437, in __call__
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] out = method_fn(module, *args, **kwargs)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/src/maxtext/models/gemma4.py", line 327, in __call__
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/flax/linen/spmd.py", line 259, in with_logical_constraint
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] return jax.tree_util.tree_map(
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/flax/linen/spmd.py", line 226, in _with_sharding_constraint_one_fallback
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] return _with_sharding_constraint(
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/flax/linen/spmd.py", line 203, in _with_sharding_constraint
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] return lax.with_sharding_constraint(x, axis_resources)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/jax/_src/sharding_impls.py", line 1056, in cached_named_sharding
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] return NamedSharding(mesh, pspec, memory_kind=memory_kind)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/jax/_src/named_sharding.py", line 483, in check_pspec
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] _check_mesh_resource_axis(mesh, spec)
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] File "/home/nmilosev/maxtext/maxtext_venv/lib/python3.12/site-packages/jax/_src/named_sharding.py", line 538, in _check_mesh_resource_axis
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] raise ValueError(
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] ValueError: Resource axis: attn_dp of PartitionSpec('data', None, ('model', 'attn_dp')) is not found in mesh: ('data', 'model').
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] --------------------
(EngineCore_DP0 pid=1356450) ERROR 04-04 21:50:30 [core.py:1111] For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Bug report
Hi, I am having an issue running Gemma 4 on a TPU VM:
I am following this guide and trying to run online inference: https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma4/Run_Gemma4.md
Converted weights with:
Startup script:
I realize that `enable_dp_attention` should be set to `true`, and I have done so, but vLLM doesn't seem to pick it up: the mesh is still initialized with only the 'data' and 'model' axes. I also tried adding `--data-parallel-size 2`, but nothing changed.
MaxText was installed from source (with uv).
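For context, the ValueError comes from JAX validating that every resource axis named in a `PartitionSpec` exists in the mesh: the gemma4 activation constraint asks for `('model', 'attn_dp')`, but the runner built the mesh as `('data': 1, 'model': 4)` with no `attn_dp` axis. A minimal plain-Python sketch of that validation (hypothetical helper name, not the actual JAX code):

```python
# Sketch of the mesh/PartitionSpec axis check that raises the ValueError
# above. check_pspec_axes is a hypothetical stand-in for JAX's internal
# _check_mesh_resource_axis, not the real implementation.

def check_pspec_axes(mesh_axes, pspec):
    """Verify every resource axis named in the PartitionSpec exists in the mesh."""
    for entry in pspec:
        if entry is None:
            continue
        # An entry may be a single axis name or a tuple of axis names.
        axes = entry if isinstance(entry, tuple) else (entry,)
        for axis in axes:
            if axis not in mesh_axes:
                raise ValueError(
                    f"Resource axis: {axis} of PartitionSpec{pspec} "
                    f"is not found in mesh: {tuple(mesh_axes)}."
                )

# Mesh as initialized in the log vs. the sharding gemma4 requests:
mesh_axes = ("data", "model")                      # no 'attn_dp' axis
pspec = ("data", None, ("model", "attn_dp"))       # activation constraint

try:
    check_pspec_axes(mesh_axes, pspec)
except ValueError as e:
    print(e)
```

So the failure is consistent with `enable_dp_attention` influencing the layer's sharding spec but not the mesh construction (the log still shows `attention_data_parallelism=1`).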
Any clues?
Thanks!
Logs/Output
Full log
Environment Information
TPU v5e (single host, 4 accelerators)
maxtext commit: 9777a4cf9574f3d10c591e25450cea1b1dde7e01
Full env
Additional Context
No response