Commit eef1026

Merge branch 'main' into dev
2 parents 995a319 + 7ab8470 commit eef1026

6 files changed: +202 -20 lines changed


llms.md

Lines changed: 1 addition & 1 deletion
@@ -1026,7 +1026,7 @@ with model.trace("Hello"):
     # Layer 1's output now equals layer 0's output
     layer1_out = model.transformer.h[1].output.save()
 
-assert torch.equal(layer0_out.value[0], layer1_out[0])
+assert torch.equal(layer0_out[0], layer1_out[0])
 ```
 
 ### Skipping Constraints
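
The fix reflects that objects returned by `.save()` are used directly once the trace exits, with no `.value` accessor. A minimal sketch of the corrected access pattern (the `layer0_out` save line is assumed from the surrounding llms.md example and is not part of this hunk):

```python
with model.trace("Hello"):
    # Assumed from the surrounding example: save layer 0's output.
    layer0_out = model.transformer.h[0].output.save()

# After the trace, the saved object is the module output itself:
hidden_states = layer0_out[0]  # previously written as layer0_out.value[0]
```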

src/nnsight/modeling/language.py

Lines changed: 11 additions & 13 deletions
@@ -207,26 +207,24 @@ def _tokenize(
 
     def _prepare_input(
         self,
-        *inputs: Tuple[
-            Union[
-                str,
-                List[str],
-                List[List[str]],
-                List[int],
-                List[List[int]],
-                torch.Tensor,
-                List[torch.Tensor],
-                Dict[str, Any],
-                BatchEncoding,
-            ]
+        *inputs: Union[
+            str,
+            List[str],
+            List[List[str]],
+            List[int],
+            List[List[int]],
+            torch.Tensor,
+            List[torch.Tensor],
+            Dict[str, Any],
+            BatchEncoding,
         ],
         input_ids: Union[
             List[int], List[List[int]], torch.Tensor, List[torch.Tensor]
         ] = None,
         labels: Any = None,
         attention_mask: Any = None,
         **kwargs,
-    ) -> Tuple[BatchEncoding, int]:
+    ) -> Tuple[Tuple[()], Dict[str, Any]]:
 
         if input_ids is not None:

src/nnsight/modeling/vllm/model_runners/GPUModelRunner.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.outputs import RequestOutput
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
+from vllm.tokenizers import cached_tokenizer_from_config
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
@@ -186,7 +186,7 @@ def load_model(self, *args, **kwargs) -> None:
 
         self.nnsight_model = VLLM(self.model)
 
-        self.nnsight_model.tokenizer = init_tokenizer_from_configs(self.model_config)
+        self.nnsight_model.tokenizer = cached_tokenizer_from_config(self.model_config)
 
         self.nnsight_model._interleaver.mediators = []
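
Both call sites pass a single model config, so the rename is a drop-in swap. A hypothetical compatibility shim (not part of this commit) could keep older vLLM releases working by falling back to the previous import path:

```python
# Hypothetical compatibility shim, not part of this commit: prefer the new
# vLLM helper and alias the older one when it is not available.
try:
    from vllm.tokenizers import cached_tokenizer_from_config
except ImportError:  # older vLLM releases
    from vllm.transformers_utils.tokenizer import (
        init_tokenizer_from_configs as cached_tokenizer_from_config,
    )
```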

src/nnsight/modeling/vllm/vllm.py

Lines changed: 48 additions & 4 deletions
@@ -5,7 +5,8 @@
 
 from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union
-from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
+from vllm.tokenizers import cached_tokenizer_from_config
+from vllm.inputs import TokensPrompt
 
 from vllm import LLM, envs
 from vllm.distributed import (
@@ -18,6 +19,7 @@
 from vllm.entrypoints.llm import LLM
 
 from ...intervention.envoy import Envoy
+from ...intervention.tracing.tracer import ScanningTracer
 from ...intervention.tracing.util import push_variables
 from ...util import WrapperModule
 from ..mixins import RemoteableMixin
@@ -118,7 +120,9 @@ def _load_meta(self, repo_id: str, **kwargs) -> "Module":
 
         _ROPE_DICT.clear()
 
-        self.tokenizer = init_tokenizer_from_configs(vllm_config.model_config)
+        self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
+        if getattr(self.tokenizer, "pad_token", None) is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
 
         return model
 
@@ -150,8 +154,42 @@ def _prepare_input(
         params = []
 
         for arg in args:
-
-            if not type(arg) is list:
+            if arg == []:
+                raise ValueError("Empty list of prompts is not allowed")
+
+            if type(arg) is dict:
+                keys = set(arg.keys())
+                if "input_ids" in keys and keys.issubset(
+                    {"input_ids", "attention_mask"}
+                ):
+                    # is hf tokenizer result
+                    batch_input_ids = arg["input_ids"]
+                    batch_attention_mask = arg.get("attention_mask", None)
+                    if isinstance(batch_input_ids, torch.Tensor):
+                        batch_input_ids = batch_input_ids.tolist()
+                    if isinstance(batch_attention_mask, torch.Tensor):
+                        batch_attention_mask = batch_attention_mask.tolist()
+                    if batch_input_ids == []:
+                        raise ValueError("Empty list of token ids is not allowed")
+                    if isinstance(batch_input_ids[0], int):
+                        # list of token ids
+                        batch_input_ids = [batch_input_ids]
+                        batch_attention_mask = [batch_attention_mask]
+
+                    for input_ids, attention_mask in zip(
+                        batch_input_ids, batch_attention_mask
+                    ):
+                        prompt = TokensPrompt(
+                            prompt_token_ids=[
+                                t for t, m in zip(input_ids, attention_mask) if m != 0
+                            ]
+                        )
+                        prompts.append(prompt)
+                        params.append(NNsightSamplingParams(**kwargs))
+                    continue
+
+            if type(arg) is not list or isinstance(arg[0], int):
+                # if arg is a list of ints (token ids), we also need to wrap it in a list
                 arg = [arg]
 
             for i, prompt in enumerate(arg):
@@ -163,6 +201,9 @@ def _prepare_input(
                 if kwargs != {}:
                     param.is_default_param = False
 
+                if type(prompt) is list and isinstance(prompt[0], int):
+                    prompt = TokensPrompt(prompt_token_ids=prompt)
+
                 prompts.append(prompt)
                 params.append(param)
 
@@ -248,6 +289,9 @@ def __call__(
         push_variables(self._interleaver.mediators[0].info.frame, saves)
 
     def interleave(self, fn: Callable, *args, **kwargs):
+        """Execute the traced function with vLLM, dispatching the engine if needed."""
+        if not self.dispatched and not isinstance(self._interleaver.tracer, ScanningTracer):
+            self.dispatch()
 
         try:
             fn(*args, **kwargs)
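
Taken together, these changes let a trace consume raw token IDs or a HuggingFace tokenizer dict, and the engine is dispatched lazily on the first trace. A minimal usage sketch mirroring the tests below (init kwargs borrowed from tests/test_vllm_dispatch_bug.py; the tokenizer is assumed to be available from the meta load at init, as in `_load_meta` above):

```python
from nnsight.modeling.vllm import VLLM

# No dispatch at init; the first trace dispatches the engine via interleave().
model = VLLM("gpt2", tensor_parallel_size=1, gpu_memory_utilization=0.1)

prompt = "The Eiffel Tower is located in the city of"

# Raw token IDs are wrapped into a vllm TokensPrompt by _prepare_input.
token_ids = model.tokenizer.encode(prompt)
with model.trace(token_ids, temperature=0.0, top_p=1):
    logits = model.logits.output.save()

# A HuggingFace tokenizer dict is also accepted; padded positions are dropped
# using the attention_mask before building TokensPrompt objects.
hf_batch = model.tokenizer([prompt, "Hello"], return_tensors="pt", padding=True)
with model.trace(dict(hf_batch), temperature=0.0, top_p=1):
    batched_logits = model.logits.output.save()

print(model.tokenizer.decode(logits.argmax(dim=-1)))  # " Paris" per the tests
```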

tests/test_vllm.py

Lines changed: 108 additions & 0 deletions
@@ -393,3 +393,111 @@ def test_tensor_parallelism(self, tp, vllm_gpt2, ET_prompt: str):
         assert next_token != " Paris"
         assert hs.shape == torch.Size([11, 3072])
         assert torch.all(hs[:, 2000:] == 0)
+
+
+# =============================================================================
+# Token Input Compatibility
+# =============================================================================
+
+
+class TestTokenInputs:
+    """Tests for token ID and HuggingFace tokenizer input compatibility."""
+
+    @torch.no_grad()
+    def test_single_token_list(self, vllm_gpt2, ET_prompt: str):
+        """Test passing a single list of token IDs."""
+        token_ids = vllm_gpt2.tokenizer.encode(ET_prompt)
+
+        with vllm_gpt2.trace(token_ids, temperature=0.0, top_p=1):
+            logits = vllm_gpt2.logits.output.save()
+
+        next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
+        assert next_token == " Paris"
+
+    @torch.no_grad()
+    def test_batched_token_lists(self, vllm_gpt2, ET_prompt: str, MSG_prompt: str):
+        """Test passing multiple lists of token IDs."""
+        et_tokens = vllm_gpt2.tokenizer.encode(ET_prompt)
+        msg_tokens = vllm_gpt2.tokenizer.encode(MSG_prompt)
+
+        with vllm_gpt2.trace([et_tokens, msg_tokens], temperature=0.0, top_p=1):
+            logits = vllm_gpt2.logits.output.save()
+
+        assert logits.shape[0] == 2
+        tokens = vllm_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1))
+        assert tokens == [" Paris", " New"]
+
+    @torch.no_grad()
+    def test_hf_tokenizer_dict_single(self, vllm_gpt2, ET_prompt: str):
+        """Test passing HuggingFace tokenizer output dict for single prompt."""
+        hf_output = vllm_gpt2.tokenizer(ET_prompt, return_tensors="pt")
+
+        with vllm_gpt2.trace(dict(hf_output), temperature=0.0, top_p=1):
+            logits = vllm_gpt2.logits.output.save()
+
+        next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
+        assert next_token == " Paris"
+
+    @torch.no_grad()
+    def test_hf_tokenizer_dict_batched(
+        self, vllm_gpt2, ET_prompt: str, MSG_prompt: str
+    ):
+        """Test passing HuggingFace tokenizer output dict for batched prompts."""
+        hf_output = vllm_gpt2.tokenizer(
+            [ET_prompt, MSG_prompt], return_tensors="pt", padding=True
+        )
+
+        with vllm_gpt2.trace(dict(hf_output), temperature=0.0, top_p=1):
+            logits = vllm_gpt2.logits.output.save()
+
+        assert logits.shape[0] == 2
+        tokens = vllm_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1))
+        assert tokens == [" Paris", " New"]
+
+    @torch.no_grad()
+    def test_hf_tokenizer_with_padding_mask(self, vllm_gpt2):
+        """Test that padding tokens are correctly filtered via attention_mask."""
+        short_prompt = "Hello"
+        long_prompt = "The Eiffel Tower is located in the city of"
+
+        hf_output = vllm_gpt2.tokenizer(
+            [short_prompt, long_prompt], return_tensors="pt", padding=True
+        )
+
+        with vllm_gpt2.trace(dict(hf_output), temperature=0.0, top_p=1):
+            logits = vllm_gpt2.logits.output.save()
+
+        assert logits.shape[0] == 2
+        tokens = vllm_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1))
+        assert tokens[1] == " Paris"
+
+    @torch.no_grad()
+    def test_token_list_in_invoker(self, vllm_gpt2, ET_prompt: str):
+        """Test token list input within an invoker."""
+        token_ids = vllm_gpt2.tokenizer.encode(ET_prompt)
+
+        with vllm_gpt2.trace(temperature=0.0, top_p=1) as tracer:
+            with tracer.invoke(token_ids):
+                logits = vllm_gpt2.logits.output.save()
+
+        next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
+        assert next_token == " Paris"
+
+    @torch.no_grad()
+    def test_mixed_string_and_token_invokers(
+        self, vllm_gpt2, ET_prompt: str, MSG_prompt: str
+    ):
+        """Test mixing string and token list inputs across invokers."""
+        et_tokens = vllm_gpt2.tokenizer.encode(ET_prompt)
+
+        with vllm_gpt2.trace(temperature=0.0, top_p=1) as tracer:
+            with tracer.invoke(et_tokens):
+                et_logits = vllm_gpt2.logits.output.save()
+
+            with tracer.invoke(MSG_prompt):
+                msg_logits = vllm_gpt2.logits.output.save()
+
+        et_token = vllm_gpt2.tokenizer.decode(et_logits.argmax(dim=-1))
+        msg_token = vllm_gpt2.tokenizer.decode(msg_logits.argmax(dim=-1))
+        assert et_token == " Paris"
+        assert msg_token == " New"

tests/test_vllm_dispatch_bug.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+"""Test for VLLM dispatch=False tracing bug."""
+import pytest
+import torch
+
+try:
+    from nnsight.modeling.vllm import VLLM
+except Exception as e:
+    pytest.skip(f"Skipping VLLM tests: \n{e}", allow_module_level=True)
+
+
+@pytest.fixture(scope="module")
+def vllm_gpt2_no_dispatch():
+    """VLLM model initialized without dispatch=True."""
+    return VLLM("gpt2", tensor_parallel_size=1, gpu_memory_utilization=0.1)
+
+
+@torch.no_grad()
+def test_trace_without_dispatch(vllm_gpt2_no_dispatch):
+    """Tracing should work even when dispatch=False at init time."""
+    model = vllm_gpt2_no_dispatch
+
+    assert not model.dispatched, "Model should not be dispatched initially"
+    assert model.vllm_entrypoint is None, "vllm_entrypoint should be None initially"
+
+    with model.trace("The Eiffel Tower is located in the city of", temperature=0.0, top_p=1):
+        logits = model.logits.output.save()
+
+    assert model.dispatched, "Model should be dispatched after trace"
+    assert model.vllm_entrypoint is not None, "vllm_entrypoint should exist after trace"
+
+    next_token = model.tokenizer.decode(logits.argmax(dim=-1))
+    assert next_token == " Paris"
