
Commit ed49406

Merge pull request #589 from Butanium/feat/vllm-token-input-compat
feat(vllm): accept token lists and HuggingFace tokenizer results
2 parents c00f98d + 0bc6c4b · commit ed49406

File tree

3 files changed: +154 −6 lines changed

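For orientation, here is a minimal sketch of the input forms this change accepts, adapted from the tests added in tests/test_vllm.py below. `model` is a hypothetical placeholder for an nnsight VLLM wrapper (the tests use a vllm_gpt2 fixture), and the sampling kwargs mirror the tests:

# Sketch only: `model` stands in for an nnsight VLLM wrapper.
prompt = "The Eiffel Tower is located in the city of"

# 1. A plain list of token ids (a batch of such lists also works).
token_ids = model.tokenizer.encode(prompt)
with model.trace(token_ids, temperature=0.0, top_p=1):
    logits = model.logits.output.save()

# 2. A HuggingFace tokenizer result (input_ids + attention_mask);
#    padded positions are dropped via the attention mask.
hf_output = model.tokenizer([prompt, "Hello"], return_tensors="pt", padding=True)
with model.trace(dict(hf_output), temperature=0.0, top_p=1):
    logits = model.logits.output.save()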

src/nnsight/modeling/vllm/model_runners/GPUModelRunner.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.outputs import RequestOutput
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
+from vllm.tokenizers import cached_tokenizer_from_config
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
@@ -186,7 +186,7 @@ def load_model(self, *args, **kwargs) -> None:
 
         self.nnsight_model = VLLM(self.model)
 
-        self.nnsight_model.tokenizer = init_tokenizer_from_configs(self.model_config)
+        self.nnsight_model.tokenizer = cached_tokenizer_from_config(self.model_config)
 
         self.nnsight_model._interleaver.mediators = []
 

src/nnsight/modeling/vllm/vllm.py

Lines changed: 44 additions & 4 deletions
@@ -5,7 +5,8 @@
 
 from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union
-from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
+from vllm.tokenizers import cached_tokenizer_from_config
+from vllm.inputs import TokensPrompt
 
 from vllm import LLM, envs
 from vllm.distributed import (
@@ -118,7 +119,9 @@ def _load_meta(self, repo_id: str, **kwargs) -> "Module":
 
         _ROPE_DICT.clear()
 
-        self.tokenizer = init_tokenizer_from_configs(vllm_config.model_config)
+        self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
+        if getattr(self.tokenizer, "pad_token", None) is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
 
         return model
 
@@ -150,8 +153,42 @@ def _prepare_input(
         params = []
 
         for arg in args:
-
-            if not type(arg) is list:
+            if arg == []:
+                raise ValueError("Empty list of prompts is not allowed")
+
+            if type(arg) is dict:
+                keys = set(arg.keys())
+                if "input_ids" in keys and keys.issubset(
+                    {"input_ids", "attention_mask"}
+                ):
+                    # is hf tokenizer result
+                    batch_input_ids = arg["input_ids"]
+                    batch_attention_mask = arg.get("attention_mask", None)
+                    if isinstance(batch_input_ids, torch.Tensor):
+                        batch_input_ids = batch_input_ids.tolist()
+                    if isinstance(batch_attention_mask, torch.Tensor):
+                        batch_attention_mask = batch_attention_mask.tolist()
+                    if batch_input_ids == []:
+                        raise ValueError("Empty list of token ids is not allowed")
+                    if isinstance(batch_input_ids[0], int):
+                        # list of token ids
+                        batch_input_ids = [batch_input_ids]
+                        batch_attention_mask = [batch_attention_mask]
+
+                    for input_ids, attention_mask in zip(
+                        batch_input_ids, batch_attention_mask
+                    ):
+                        prompt = TokensPrompt(
+                            prompt_token_ids=[
+                                t for t, m in zip(input_ids, attention_mask) if m != 0
+                            ]
+                        )
+                        prompts.append(prompt)
+                        params.append(NNsightSamplingParams(**kwargs))
+                    continue
+
+            if type(arg) is not list or isinstance(arg[0], int):
+                # if arg is a list of ints (token ids), we also need to wrap it in a list
                 arg = [arg]
 
             for i, prompt in enumerate(arg):
@@ -163,6 +200,9 @@ def _prepare_input(
             if kwargs != {}:
                 param.is_default_param = False
 
+            if type(prompt) is list and isinstance(prompt[0], int):
+                prompt = TokensPrompt(prompt_token_ids=prompt)
+
             prompts.append(prompt)
             params.append(param)
 
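The heart of the new dict branch is the attention-mask filter. A standalone sketch of that step, with illustrative placeholder token ids rather than real vocabulary ids:

# Illustrative values only; real ids come from the tokenizer.
input_ids = [50256, 50256, 15496, 995]   # a left-padded sequence
attention_mask = [0, 0, 1, 1]            # 0 marks padding positions
prompt_token_ids = [t for t, m in zip(input_ids, attention_mask) if m != 0]
assert prompt_token_ids == [15496, 995]
# Each filtered sequence is then wrapped as TokensPrompt(prompt_token_ids=...).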

tests/test_vllm.py

Lines changed: 108 additions & 0 deletions
@@ -393,3 +393,111 @@ def test_tensor_parallelism(self, tp, vllm_gpt2, ET_prompt: str):
         assert next_token != " Paris"
         assert hs.shape == torch.Size([11, 3072])
         assert torch.all(hs[:, 2000:] == 0)
+
+
+# =============================================================================
+# Token Input Compatibility
+# =============================================================================
+
+
+class TestTokenInputs:
+    """Tests for token ID and HuggingFace tokenizer input compatibility."""
+
+    @torch.no_grad()
+    def test_single_token_list(self, vllm_gpt2, ET_prompt: str):
+        """Test passing a single list of token IDs."""
+        token_ids = vllm_gpt2.tokenizer.encode(ET_prompt)
+
+        with vllm_gpt2.trace(token_ids, temperature=0.0, top_p=1):
+            logits = vllm_gpt2.logits.output.save()
+
+        next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
+        assert next_token == " Paris"
+
+    @torch.no_grad()
+    def test_batched_token_lists(self, vllm_gpt2, ET_prompt: str, MSG_prompt: str):
+        """Test passing multiple lists of token IDs."""
+        et_tokens = vllm_gpt2.tokenizer.encode(ET_prompt)
+        msg_tokens = vllm_gpt2.tokenizer.encode(MSG_prompt)
+
+        with vllm_gpt2.trace([et_tokens, msg_tokens], temperature=0.0, top_p=1):
+            logits = vllm_gpt2.logits.output.save()
+
+        assert logits.shape[0] == 2
+        tokens = vllm_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1))
+        assert tokens == [" Paris", " New"]
+
+    @torch.no_grad()
+    def test_hf_tokenizer_dict_single(self, vllm_gpt2, ET_prompt: str):
+        """Test passing HuggingFace tokenizer output dict for single prompt."""
+        hf_output = vllm_gpt2.tokenizer(ET_prompt, return_tensors="pt")
+
+        with vllm_gpt2.trace(dict(hf_output), temperature=0.0, top_p=1):
+            logits = vllm_gpt2.logits.output.save()
+
+        next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
+        assert next_token == " Paris"
+
+    @torch.no_grad()
+    def test_hf_tokenizer_dict_batched(
+        self, vllm_gpt2, ET_prompt: str, MSG_prompt: str
+    ):
+        """Test passing HuggingFace tokenizer output dict for batched prompts."""
+        hf_output = vllm_gpt2.tokenizer(
+            [ET_prompt, MSG_prompt], return_tensors="pt", padding=True
+        )
+
+        with vllm_gpt2.trace(dict(hf_output), temperature=0.0, top_p=1):
+            logits = vllm_gpt2.logits.output.save()
+
+        assert logits.shape[0] == 2
+        tokens = vllm_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1))
+        assert tokens == [" Paris", " New"]
+
+    @torch.no_grad()
+    def test_hf_tokenizer_with_padding_mask(self, vllm_gpt2):
+        """Test that padding tokens are correctly filtered via attention_mask."""
+        short_prompt = "Hello"
+        long_prompt = "The Eiffel Tower is located in the city of"
+
+        hf_output = vllm_gpt2.tokenizer(
+            [short_prompt, long_prompt], return_tensors="pt", padding=True
+        )
+
+        with vllm_gpt2.trace(dict(hf_output), temperature=0.0, top_p=1):
+            logits = vllm_gpt2.logits.output.save()
+
+        assert logits.shape[0] == 2
+        tokens = vllm_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1))
+        assert tokens[1] == " Paris"
+
+    @torch.no_grad()
+    def test_token_list_in_invoker(self, vllm_gpt2, ET_prompt: str):
+        """Test token list input within an invoker."""
+        token_ids = vllm_gpt2.tokenizer.encode(ET_prompt)
+
+        with vllm_gpt2.trace(temperature=0.0, top_p=1) as tracer:
+            with tracer.invoke(token_ids):
+                logits = vllm_gpt2.logits.output.save()
+
+        next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
+        assert next_token == " Paris"
+
+    @torch.no_grad()
+    def test_mixed_string_and_token_invokers(
+        self, vllm_gpt2, ET_prompt: str, MSG_prompt: str
+    ):
+        """Test mixing string and token list inputs across invokers."""
+        et_tokens = vllm_gpt2.tokenizer.encode(ET_prompt)
+
+        with vllm_gpt2.trace(temperature=0.0, top_p=1) as tracer:
+            with tracer.invoke(et_tokens):
+                et_logits = vllm_gpt2.logits.output.save()
+
+            with tracer.invoke(MSG_prompt):
+                msg_logits = vllm_gpt2.logits.output.save()
+
+        et_token = vllm_gpt2.tokenizer.decode(et_logits.argmax(dim=-1))
+        msg_token = vllm_gpt2.tokenizer.decode(msg_logits.argmax(dim=-1))
+        assert et_token == " Paris"
+        assert msg_token == " New"
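To exercise just these tests, something like pytest tests/test_vllm.py -k TestTokenInputs should select the class above, assuming the repository's usual pytest setup.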
