[BUG: MistralTokenizer] AttributeError: 'MistralTokenizer' object has no attribute 'init_kwargs' #109

@mratsim

Description

Python -VV

See the output of vLLM's collect_env.py below.

<details>
<summary>The output of <code>python collect_env.py</code></summary>


==============================
        System Info
==============================
OS                           : Ubuntu 24.04.1 LTS (x86_64)
GCC version                  : (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version                : Could not collect
CMake version                : version 3.28.3
Libc version                 : glibc-2.39

==============================
       PyTorch Info
==============================
PyTorch version              : 2.7.0+cu128
Is debug build               : False
CUDA used to build PyTorch   : 12.8
ROCM used to build PyTorch   : N/A

==============================
      Python Environment
==============================
Python version               : 3.12.3 (main, Feb  4 2025, 14:48:35) [GCC 13.3.0] (64-bit runtime)
Python platform              : Linux-6.15.2-arch1-1-x86_64-with-glibc2.39

==============================
       CUDA / GPU Info
==============================
Is CUDA available            : False
CUDA runtime version         : 12.8.93
CUDA_MODULE_LOADING set to   : N/A
GPU models and configuration : Could not collect
Nvidia driver version        : Could not collect
cuDNN version                : Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.8.0
HIP runtime version          : N/A
MIOpen runtime version       : N/A
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           48 bits physical, 48 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  32
On-line CPU(s) list:                     0-31
Vendor ID:                               AuthenticAMD
Model name:                              AMD Ryzen 9 9950X 16-Core Processor
CPU family:                              26
Model:                                   68
Thread(s) per core:                      2
Core(s) per socket:                      16
Socket(s):                               1
Stepping:                                0
Frequency boost:                         enabled
CPU(s) scaling MHz:                      71%
CPU max MHz:                             5756.4521
CPU min MHz:                             624.1940
BogoMIPS:                                8583.31
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx_vnni avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid bus_lock_detect movdiri movdir64b overflow_recov succor smca fsrm avx512_vp2intersect flush_l1d amd_lbr_pmc_freeze
Virtualization:                          AMD-V
L1d cache:                               768 KiB (16 instances)
L1i cache:                               512 KiB (16 instances)
L2 cache:                                16 MiB (16 instances)
L3 cache:                                64 MiB (2 instances)
NUMA node(s):                            1
NUMA node0 CPU(s):                       0-31
Vulnerability Gather data sampling:      Not affected
Vulnerability Ghostwrite:                Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Mitigation; IBPB on VMEXIT only
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                     Not affected
Vulnerability Tsx async abort:           Not affected

==============================
Versions of relevant libraries
==============================
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.7.1.26
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-cufile-cu12==1.13.0.11
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-ml-py==12.575.51
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvshmem-cu12==3.2.5
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] pynvml==12.0.0
[pip3] pyzmq==27.0.0
[pip3] torch==2.7.0+cu128
[pip3] torchaudio==2.7.0+cu128
[pip3] torchvision==0.22.0+cu128
[pip3] transformers==4.52.4
[pip3] triton==3.3.0
[conda] Could not collect

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
Neuron SDK Version           : N/A
vLLM Version                 : 0.9.2.dev247+g0f9e7354f.d20250625 (git sha: 0f9e7354f, date: 20250625)
vLLM Build Flags:
  CUDA Archs: 12.0; ROCm: Disabled; Neuron: Disabled
GPU Topology:
  Could not collect

==============================
     Environment Variables
==============================
NVIDIA_VISIBLE_DEVICES=all
NVIDIA_REQUIRE_CUDA=cuda>=12.8 brand=unknown,driver>=470,driver<471 brand=grid,driver>=470,driver<471 brand=tesla,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 brand=vapps,driver>=470,driver<471 brand=vpc,driver>=470,driver<471 brand=vcs,driver>=470,driver<471 brand=vws,driver>=470,driver<471 brand=cloudgaming,driver>=470,driver<471 brand=unknown,driver>=535,driver<536 brand=grid,driver>=535,driver<536 brand=tesla,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=vapps,driver>=535,driver<536 brand=vpc,driver>=535,driver<536 brand=vcs,driver>=535,driver<536 brand=vws,driver>=535,driver<536 brand=cloudgaming,driver>=535,driver<536 brand=unknown,driver>=550,driver<551 brand=grid,driver>=550,driver<551 brand=tesla,driver>=550,driver<551 brand=nvidia,driver>=550,driver<551 brand=quadro,driver>=550,driver<551 brand=quadrortx,driver>=550,driver<551 brand=nvidiartx,driver>=550,driver<551 brand=vapps,driver>=550,driver<551 brand=vpc,driver>=550,driver<551 brand=vcs,driver>=550,driver<551 brand=vws,driver>=550,driver<551 brand=cloudgaming,driver>=550,driver<551 brand=unknown,driver>=560,driver<561 brand=grid,driver>=560,driver<561 brand=tesla,driver>=560,driver<561 brand=nvidia,driver>=560,driver<561 brand=quadro,driver>=560,driver<561 brand=quadrortx,driver>=560,driver<561 brand=nvidiartx,driver>=560,driver<561 brand=vapps,driver>=560,driver<561 brand=vpc,driver>=560,driver<561 brand=vcs,driver>=560,driver<561 brand=vws,driver>=560,driver<561 brand=cloudgaming,driver>=560,driver<561 brand=unknown,driver>=565,driver<566 brand=grid,driver>=565,driver<566 brand=tesla,driver>=565,driver<566 brand=nvidia,driver>=565,driver<566 brand=quadro,driver>=565,driver<566 brand=quadrortx,driver>=565,driver<566 
brand=nvidiartx,driver>=565,driver<566 brand=vapps,driver>=565,driver<566 brand=vpc,driver>=565,driver<566 brand=vcs,driver>=565,driver<566 brand=vws,driver>=565,driver<566 brand=cloudgaming,driver>=565,driver<566
TORCH_CUDA_ARCH_LIST=12.0
NCCL_VERSION=2.25.1-1
CMAKE_BUILD_TYPE=Release
NVIDIA_DRIVER_CAPABILITIES=compute,utility
NVCC_THREADS=8
NVIDIA_PRODUCT_NAME=CUDA
NVIDIA_CPU_ONLY=1
CUDA_VERSION=12.8.1
MAX_JOBS=8
VLLM_FLASH_ATTN_VERSION=2
LD_LIBRARY_PATH=/usr/local/cuda/lib64
CUDA_HOME=/usr/local/cuda
CUDA_HOME=/usr/local/cuda
NCCL_CUMEM_ENABLE=0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1


</details>

Pip Freeze

Not provided; the relevant package versions are listed in the collect_env.py output above.

Reproduction Steps

        vllm serve "${MODEL}" \
            --port "${VLLM_PORT}" \
            --trust-remote-code \
            --gpu-memory-utilization 0.95 \
            --served-model-name "${MODELNAME}" \
            --enable-prefix-caching \
            --enable-chunked-prefill \
            --max-model-len "${MODEL_LEN}" \
            --max_num_seqs 64 \
            --tokenizer_mode mistral \
            --generation-config "${MODEL}" \
            --enable-auto-tool-choice \
            --tool-call-parser mistral

Expected Behavior

The command crashes because of an incompatibility with huggingface/transformers:

AttributeError: 'MistralTokenizer' object has no attribute 'init_kwargs'

The attribute is accessed here: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/pixtral/processing_pixtral.py#L156

This is with mistral_common==1.6.2 and vllm commit 0f9e7354f508af3fe314cfb709babaaa668f1b04 built from source on 2025-06-25.
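When reproducing, it helps to report the exact installed versions of the packages involved. Below is a minimal sketch using the standard-library `importlib.metadata`; the helper name `report_versions` is hypothetical. For a source build, the vLLM entry will likely be a dev version string like the one in the collect_env.py output, which embeds the short git SHA.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def report_versions(packages: list[str]) -> dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Distributions involved in this crash (PyPI names).
    for pkg, ver in report_versions(["mistral_common", "transformers", "vllm"]).items():
        print(f"{pkg}=={ver}" if ver else f"{pkg}: not installed")
```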

Tekkenizer v11 is used, but the image-handling path (transformers' PixtralProcessor) seems to expect a different kind of tokenizer.
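The failure looks like a duck-typing mismatch: `PixtralProcessor.__call__` in transformers 4.52.4 reads `self.tokenizer.init_kwargs`, an attribute Hugging Face tokenizers provide, while the `MistralTokenizer` object used with `--tokenizer_mode mistral` does not have it. The sketch below reproduces the pattern with hypothetical stand-in classes (`HFStyleTokenizer`, `MistralStyleTokenizer`, `ToyProcessor`) that are not the real library types, and contrasts direct attribute access with a defensive `getattr` fallback. It only illustrates the mechanism; it is not the upstream fix.

```python
from typing import Any


class HFStyleTokenizer:
    """Stand-in for a Hugging Face tokenizer: exposes init_kwargs."""

    def __init__(self, **kwargs: Any) -> None:
        self.init_kwargs = dict(kwargs)


class MistralStyleTokenizer:
    """Stand-in for a tokenizer that has no init_kwargs attribute."""


class ToyProcessor:
    """Mimics a processor that assumes an HF-style tokenizer interface."""

    def __init__(self, tokenizer: Any) -> None:
        self.tokenizer = tokenizer

    def strict_call(self) -> dict:
        # Direct attribute access, like the failing line in processing_pixtral.py.
        return self.tokenizer.init_kwargs

    def tolerant_call(self) -> dict:
        # Hypothetical defensive variant: fall back to empty kwargs.
        return getattr(self.tokenizer, "init_kwargs", {})


if __name__ == "__main__":
    print(ToyProcessor(HFStyleTokenizer(padding_side="left")).strict_call())
    try:
        ToyProcessor(MistralStyleTokenizer()).strict_call()
    except AttributeError as err:
        print(f"AttributeError: {err}")
    print(ToyProcessor(MistralStyleTokenizer()).tolerant_call())
```

Silently substituting empty kwargs can hide real configuration differences, so whether the fix belongs in transformers, vLLM, or mistral_common is a separate question.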

Full backtrace:

.2-24B.w4a16-gptq', 'tokenizer_mode': 'mistral', 'trust_remote_code': True, 'max_model_len': 92500, 'served_model_name': ['mistral3.2-24b'], 'generation_config': '/workspace/local_models/Mistral-3.2-24B.w4a16-gptq', 'gpu_memory_utilization': 0.95, 'enable_prefix_caching': True, 'max_num_seqs': 64, 'enable_chunked_prefill': True}
INFO 06-25 15:47:53 [config.py:839] This model supports multiple tasks: {'reward', 'score', 'generate', 'embed', 'classify'}. Defaulting to 'generate'.
INFO 06-25 15:47:53 [config.py:1453] Using max model len 92500
INFO 06-25 15:47:54 [config.py:2197] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 06-25 15:47:56 [__init__.py:244] Automatically detected platform cuda.
INFO 06-25 15:47:57 [core.py:459] Waiting for init message from front-end.
INFO 06-25 15:47:57 [core.py:69] Initializing a V1 LLM engine (v0.9.2.dev247+g0f9e7354f.d20250625) with config: model='/workspace/local_models/Mistral-3.2-24B.w4a16-gptq', speculative_config=None, tokenizer='/workspace/local_models/Mistral-3.2-24B.w4a16-gptq', skip_tokenizer_init=False, tokenizer_mode=mistral, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=92500, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=compressed-tensors, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=mistral3.2-24b, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}
WARNING 06-25 15:47:57 [utils.py:2753] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7fdf7dde3c50>
INFO 06-25 15:47:58 [parallel_state.py:1072] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
ERROR 06-25 15:47:59 [core.py:519] EngineCore failed to start.
ERROR 06-25 15:47:59 [core.py:519] Traceback (most recent call last):
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/inputs/registry.py", line 169, in call_hf_processor
ERROR 06-25 15:47:59 [core.py:519]     output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
ERROR 06-25 15:47:59 [core.py:519]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/usr/local/lib/python3.12/dist-packages/transformers/models/pixtral/processing_pixtral.py", line 156, in __call__
ERROR 06-25 15:47:59 [core.py:519]     tokenizer_init_kwargs=self.tokenizer.init_kwargs,
ERROR 06-25 15:47:59 [core.py:519]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519] AttributeError: 'MistralTokenizer' object has no attribute 'init_kwargs'
ERROR 06-25 15:47:59 [core.py:519] 
ERROR 06-25 15:47:59 [core.py:519] The above exception was the direct cause of the following exception:
ERROR 06-25 15:47:59 [core.py:519] 
ERROR 06-25 15:47:59 [core.py:519] Traceback (most recent call last):
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/v1/engine/core.py", line 510, in run_engine_core
ERROR 06-25 15:47:59 [core.py:519]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 06-25 15:47:59 [core.py:519]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/v1/engine/core.py", line 394, in __init__
ERROR 06-25 15:47:59 [core.py:519]     super().__init__(vllm_config, executor_class, log_stats,
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/v1/engine/core.py", line 75, in __init__
ERROR 06-25 15:47:59 [core.py:519]     self.model_executor = executor_class(vllm_config)
ERROR 06-25 15:47:59 [core.py:519]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/executor/executor_base.py", line 53, in __init__
ERROR 06-25 15:47:59 [core.py:519]     self._init_executor()
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/executor/uniproc_executor.py", line 47, in _init_executor
ERROR 06-25 15:47:59 [core.py:519]     self.collective_rpc("init_device")
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
ERROR 06-25 15:47:59 [core.py:519]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 06-25 15:47:59 [core.py:519]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/utils.py", line 2687, in run_method
ERROR 06-25 15:47:59 [core.py:519]     return func(*args, **kwargs)
ERROR 06-25 15:47:59 [core.py:519]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/worker/worker_base.py", line 606, in init_device
ERROR 06-25 15:47:59 [core.py:519]     self.worker.init_device()  # type: ignore
ERROR 06-25 15:47:59 [core.py:519]     ^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/v1/worker/gpu_worker.py", line 165, in init_device
ERROR 06-25 15:47:59 [core.py:519]     self.model_runner: GPUModelRunner = GPUModelRunner(
ERROR 06-25 15:47:59 [core.py:519]                                         ^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/v1/worker/gpu_model_runner.py", line 142, in __init__
ERROR 06-25 15:47:59 [core.py:519]     encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
ERROR 06-25 15:47:59 [core.py:519]                                                  ^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/v1/core/encoder_cache_manager.py", line 199, in compute_encoder_budget
ERROR 06-25 15:47:59 [core.py:519]     ) = _compute_encoder_budget_multimodal(
ERROR 06-25 15:47:59 [core.py:519]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/v1/core/encoder_cache_manager.py", line 229, in _compute_encoder_budget_multimodal
ERROR 06-25 15:47:59 [core.py:519]     .get_max_tokens_per_item_by_nonzero_modality(model_config)
ERROR 06-25 15:47:59 [core.py:519]      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/multimodal/registry.py", line 158, in get_max_tokens_per_item_by_nonzero_modality
ERROR 06-25 15:47:59 [core.py:519]     self.get_max_tokens_per_item_by_modality(model_config).items()
ERROR 06-25 15:47:59 [core.py:519]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/multimodal/registry.py", line 132, in get_max_tokens_per_item_by_modality
ERROR 06-25 15:47:59 [core.py:519]     return profiler.get_mm_max_tokens(
ERROR 06-25 15:47:59 [core.py:519]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/multimodal/profiling.py", line 277, in get_mm_max_tokens
ERROR 06-25 15:47:59 [core.py:519]     mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
ERROR 06-25 15:47:59 [core.py:519]                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/multimodal/profiling.py", line 169, in _get_dummy_mm_inputs
ERROR 06-25 15:47:59 [core.py:519]     return self.processor.apply(
ERROR 06-25 15:47:59 [core.py:519]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/multimodal/processing.py", line 1808, in apply
ERROR 06-25 15:47:59 [core.py:519]     ) = self._cached_apply_hf_processor(
ERROR 06-25 15:47:59 [core.py:519]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/multimodal/processing.py", line 1574, in _cached_apply_hf_processor
ERROR 06-25 15:47:59 [core.py:519]     ) = self._apply_hf_processor_main(
ERROR 06-25 15:47:59 [core.py:519]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/multimodal/processing.py", line 1417, in _apply_hf_processor_main
ERROR 06-25 15:47:59 [core.py:519]     prompt_ids = self._apply_hf_processor_text_only(prompt)
ERROR 06-25 15:47:59 [core.py:519]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/multimodal/processing.py", line 1342, in _apply_hf_processor_text_only
ERROR 06-25 15:47:59 [core.py:519]     prompt_ids, _, _ = self._apply_hf_processor_text_mm(
ERROR 06-25 15:47:59 [core.py:519]                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/multimodal/processing.py", line 1312, in _apply_hf_processor_text_mm
ERROR 06-25 15:47:59 [core.py:519]     processed_data = self._call_hf_processor(
ERROR 06-25 15:47:59 [core.py:519]                      ^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/model_executor/models/mistral3.py", line 232, in _call_hf_processor
ERROR 06-25 15:47:59 [core.py:519]     processed_outputs = super()._call_hf_processor(
ERROR 06-25 15:47:59 [core.py:519]                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/multimodal/processing.py", line 1275, in _call_hf_processor
ERROR 06-25 15:47:59 [core.py:519]     return self.info.ctx.call_hf_processor(
ERROR 06-25 15:47:59 [core.py:519]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-25 15:47:59 [core.py:519]   File "/workspace/vllm/vllm/inputs/registry.py", line 187, in call_hf_processor
ERROR 06-25 15:47:59 [core.py:519]     raise ValueError(msg) from exc
ERROR 06-25 15:47:59 [core.py:519] ValueError: Failed to apply PixtralProcessor on data={'text': '[IMG]'} with kwargs={}
Process EngineCore_0:
Traceback (most recent call last):
  File "/workspace/vllm/vllm/inputs/registry.py", line 169, in call_hf_processor
    output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/pixtral/processing_pixtral.py", line 156, in __call__
    tokenizer_init_kwargs=self.tokenizer.init_kwargs,
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'MistralTokenizer' object has no attribute 'init_kwargs'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/workspace/vllm/vllm/v1/engine/core.py", line 523, in run_engine_core
    raise e
  File "/workspace/vllm/vllm/v1/engine/core.py", line 510, in run_engine_core
    engine_core = EngineCoreProc(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/engine/core.py", line 394, in __init__
    super().__init__(vllm_config, executor_class, log_stats,
  File "/workspace/vllm/vllm/v1/engine/core.py", line 75, in __init__
    self.model_executor = executor_class(vllm_config)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/executor/executor_base.py", line 53, in __init__
    self._init_executor()
  File "/workspace/vllm/vllm/executor/uniproc_executor.py", line 47, in _init_executor
    self.collective_rpc("init_device")
  File "/workspace/vllm/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/utils.py", line 2687, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/worker/worker_base.py", line 606, in init_device
    self.worker.init_device()  # type: ignore
    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/worker/gpu_worker.py", line 165, in init_device
    self.model_runner: GPUModelRunner = GPUModelRunner(
                                        ^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/worker/gpu_model_runner.py", line 142, in __init__
    encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
                                                 ^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/core/encoder_cache_manager.py", line 199, in compute_encoder_budget
    ) = _compute_encoder_budget_multimodal(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/core/encoder_cache_manager.py", line 229, in _compute_encoder_budget_multimodal
    .get_max_tokens_per_item_by_nonzero_modality(model_config)
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/multimodal/registry.py", line 158, in get_max_tokens_per_item_by_nonzero_modality
    self.get_max_tokens_per_item_by_modality(model_config).items()
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/multimodal/registry.py", line 132, in get_max_tokens_per_item_by_modality
    return profiler.get_mm_max_tokens(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/multimodal/profiling.py", line 277, in get_mm_max_tokens
    mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/multimodal/profiling.py", line 169, in _get_dummy_mm_inputs
    return self.processor.apply(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/multimodal/processing.py", line 1808, in apply
    ) = self._cached_apply_hf_processor(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/multimodal/processing.py", line 1574, in _cached_apply_hf_processor
    ) = self._apply_hf_processor_main(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/multimodal/processing.py", line 1417, in _apply_hf_processor_main
    prompt_ids = self._apply_hf_processor_text_only(prompt)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/multimodal/processing.py", line 1342, in _apply_hf_processor_text_only
    prompt_ids, _, _ = self._apply_hf_processor_text_mm(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/multimodal/processing.py", line 1312, in _apply_hf_processor_text_mm
    processed_data = self._call_hf_processor(
                     ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/model_executor/models/mistral3.py", line 232, in _call_hf_processor
    processed_outputs = super()._call_hf_processor(
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/multimodal/processing.py", line 1275, in _call_hf_processor
    return self.info.ctx.call_hf_processor(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/inputs/registry.py", line 187, in call_hf_processor
    raise ValueError(msg) from exc
ValueError: Failed to apply PixtralProcessor on data={'text': '[IMG]'} with kwargs={}
[rank0]:[W625 15:47:59.697775105 ProcessGroupNCCL.cpp:1476] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
Traceback (most recent call last):
  File "/usr/local/bin/vllm", line 10, in <module>
    sys.exit(main())
             ^^^^^^
  File "/workspace/vllm/vllm/entrypoints/cli/main.py", line 65, in main
    args.dispatch_function(args)
  File "/workspace/vllm/vllm/entrypoints/cli/serve.py", line 55, in cmd
    uvloop.run(run_server(args))
  File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 109, in run
    return __asyncio.run(
           ^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/runners.py", line 194, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 61, in wrapper
    return await main
           ^^^^^^^^^^
  File "/workspace/vllm/vllm/entrypoints/openai/api_server.py", line 1325, in run_server
    await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
  File "/workspace/vllm/vllm/entrypoints/openai/api_server.py", line 1345, in run_server_worker
    async with build_async_engine_client(args, client_config) as engine_client:
  File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/entrypoints/openai/api_server.py", line 155, in build_async_engine_client
    async with build_async_engine_client_from_engine_args(
  File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/entrypoints/openai/api_server.py", line 191, in build_async_engine_client_from_engine_args
    async_llm = AsyncLLM.from_vllm_config(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/engine/async_llm.py", line 162, in from_vllm_config
    return cls(
           ^^^^
  File "/workspace/vllm/vllm/v1/engine/async_llm.py", line 124, in __init__
    self.engine_core = EngineCoreClient.make_async_mp_client(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/engine/core_client.py", line 93, in make_async_mp_client
    return AsyncMPClient(vllm_config, executor_class, log_stats,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/v1/engine/core_client.py", line 735, in __init__
    super().__init__(
  File "/workspace/vllm/vllm/v1/engine/core_client.py", line 433, in __init__
    self._init_engines_direct(vllm_config, local_only,
  File "/workspace/vllm/vllm/v1/engine/core_client.py", line 502, in _init_engines_direct
    self._wait_for_engine_startup(handshake_socket, input_address,
  File "/workspace/vllm/vllm/v1/engine/core_client.py", line 522, in _wait_for_engine_startup
    wait_for_engine_startup(
  File "/workspace/vllm/vllm/v1/utils.py", line 494, in wait_for_engine_startup
    raise RuntimeError("Engine core initialization failed. "
RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
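Inside the engine-core process, vLLM's `call_hf_processor` re-raises the original `AttributeError` as a `ValueError` with `raise ValueError(msg) from exc`, which is why Python prints "The above exception was the direct cause of the following exception". The API server then reports a separate `RuntimeError` that only points back at it ("See root cause above"), so the actionable error is the innermost `AttributeError`. A small sketch of walking an exception chain through `__cause__` / `__context__` to find it; `call_processor` and `root_cause` are hypothetical names, and the raise only mimics the pattern seen in the log.

```python
def call_processor() -> None:
    """Mimic the wrapping pattern seen in the log (hypothetical)."""
    try:
        # Stand-in for the failing attribute access inside the HF processor.
        raise AttributeError("'MistralTokenizer' object has no attribute 'init_kwargs'")
    except AttributeError as exc:
        raise ValueError("Failed to apply PixtralProcessor") from exc


def root_cause(err: BaseException) -> BaseException:
    """Follow explicit (__cause__) and implicit (__context__) chaining inward."""
    seen: set[int] = set()
    while id(err) not in seen:
        seen.add(id(err))
        nxt = err.__cause__ or err.__context__
        if nxt is None:
            return err
        err = nxt
    return err  # a cycle was detected; return where the walk stopped


if __name__ == "__main__":
    try:
        call_processor()
    except ValueError as top:
        cause = root_cause(top)
        print(f"{type(top).__name__} <- {type(cause).__name__}: {cause}")
```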

Additional Context

No response

Suggested Solutions

No response
