Skip to content

Commit 6646c0c

Browse files
authored
[Opt] Optimize deepstack buffer handling for multimodal Qwen3 models (vllm-project#40145)
Signed-off-by: xiaoming <1259730330@qq.com>
1 parent 95995bb commit 6646c0c

2 files changed

Lines changed: 40 additions & 0 deletions

File tree

vllm/model_executor/models/qwen3_omni_moe_thinker.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,6 +1753,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
17531753
)
17541754
for _ in range(self.deepstack_num_level)
17551755
]
1756+
# Tracks the valid token span currently stored in the buffer.
1757+
# Zero means there is no active deepstack payload to consume.
1758+
self.deepstack_input_embeds_num_tokens = 0
17561759

17571760
with self._mark_language_model(vllm_config):
17581761
self.language_model = Qwen3MoeLLMForCausalLM(
@@ -1773,6 +1776,13 @@ def _get_deepstack_input_embeds(
17731776
) -> IntermediateTensors | None:
17741777
if not getattr(self, "deepstack_input_embeds", None):
17751778
return None # If vision tower is skipped
1779+
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
1780+
return None
1781+
if num_tokens > self.deepstack_input_embeds_num_tokens:
1782+
raise ValueError(
1783+
"Requested more deepstack tokens than available in buffer: "
1784+
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
1785+
)
17761786

17771787
# get deepstack_input_embeds from buffer, and clear the buffer
17781788
return IntermediateTensors(
@@ -1804,15 +1814,25 @@ def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> N
18041814
self.deepstack_input_embeds[idx][:num_tokens].copy_(
18051815
deepstack_input_embeds[idx]
18061816
)
1817+
self.deepstack_input_embeds_num_tokens = num_tokens
18071818

18081819
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
18091820
if not getattr(self, "deepstack_input_embeds", None):
18101821
return
1822+
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
1823+
return
18111824

18121825
# clear deepstack_input_embeds in buffer
18131826
if num_tokens > 0:
1827+
if num_tokens > self.deepstack_input_embeds_num_tokens:
1828+
raise ValueError(
1829+
"Requested to clear more deepstack tokens than available in "
1830+
"buffer: "
1831+
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
1832+
)
18141833
for idx in range(self.deepstack_num_level):
18151834
self.deepstack_input_embeds[idx][:num_tokens].zero_()
1835+
self.deepstack_input_embeds_num_tokens = 0
18161836

18171837
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
18181838
mm_input_by_modality = {}

vllm/model_executor/models/qwen3_vl.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,6 +1675,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
16751675
)
16761676
for _ in range(self.deepstack_num_level)
16771677
]
1678+
# Tracks the valid token span currently stored in the buffer.
1679+
# Zero means there is no active deepstack payload to consume.
1680+
self.deepstack_input_embeds_num_tokens = 0
16781681

16791682
with self._mark_language_model(vllm_config):
16801683
self.language_model = Qwen3LLMForCausalLM(
@@ -1702,6 +1705,13 @@ def _get_deepstack_input_embeds(
17021705
) -> IntermediateTensors | None:
17031706
if not getattr(self, "deepstack_input_embeds", None):
17041707
return None # If vision tower is skipped
1708+
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
1709+
return None
1710+
if num_tokens > self.deepstack_input_embeds_num_tokens:
1711+
raise ValueError(
1712+
"Requested more deepstack tokens than available in buffer: "
1713+
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
1714+
)
17051715

17061716
# get deepstack_input_embeds from buffer, and clear the buffer
17071717
return IntermediateTensors(
@@ -1733,15 +1743,25 @@ def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> N
17331743
self.deepstack_input_embeds[idx][:num_tokens].copy_(
17341744
deepstack_input_embeds[idx]
17351745
)
1746+
self.deepstack_input_embeds_num_tokens = num_tokens
17361747

17371748
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
17381749
if not getattr(self, "deepstack_input_embeds", None):
17391750
return
1751+
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
1752+
return
17401753

17411754
# clear deepstack_input_embeds in buffer
17421755
if num_tokens > 0:
1756+
if num_tokens > self.deepstack_input_embeds_num_tokens:
1757+
raise ValueError(
1758+
"Requested to clear more deepstack tokens than available in "
1759+
"buffer: "
1760+
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
1761+
)
17431762
for idx in range(self.deepstack_num_level):
17441763
self.deepstack_input_embeds[idx][:num_tokens].zero_()
1764+
self.deepstack_input_embeds_num_tokens = 0
17451765

17461766
# -- SupportsEncoderCudaGraph protocol methods --
17471767

0 commit comments

Comments (0)