@@ -393,3 +393,111 @@ def test_tensor_parallelism(self, tp, vllm_gpt2, ET_prompt: str):
393393 assert next_token != " Paris"
394394 assert hs .shape == torch .Size ([11 , 3072 ])
395395 assert torch .all (hs [:, 2000 :] == 0 )
396+
397+
398+ # =============================================================================
399+ # Token Input Compatibility
400+ # =============================================================================
401+
402+
class TestTokenInputs:
    """Verify that trace() accepts token-ID lists and HuggingFace tokenizer dicts.

    Every test drives the model through the same greedy-decoding settings
    (temperature=0.0, top_p=1) so the argmax of the saved logits is the
    deterministic next token, which we compare against a known completion.
    """

    @torch.no_grad()
    def test_single_token_list(self, vllm_gpt2, ET_prompt: str):
        """A bare list of token IDs is accepted as a single-prompt input."""
        ids = vllm_gpt2.tokenizer.encode(ET_prompt)

        with vllm_gpt2.trace(ids, temperature=0.0, top_p=1):
            logits = vllm_gpt2.logits.output.save()

        predicted = logits.argmax(dim=-1)
        assert vllm_gpt2.tokenizer.decode(predicted) == " Paris"

    @torch.no_grad()
    def test_batched_token_lists(self, vllm_gpt2, ET_prompt: str, MSG_prompt: str):
        """A list of token-ID lists is treated as a batch of prompts."""
        batch = [
            vllm_gpt2.tokenizer.encode(ET_prompt),
            vllm_gpt2.tokenizer.encode(MSG_prompt),
        ]

        with vllm_gpt2.trace(batch, temperature=0.0, top_p=1):
            logits = vllm_gpt2.logits.output.save()

        assert logits.shape[0] == 2
        decoded = vllm_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1))
        assert decoded == [" Paris", " New"]

    @torch.no_grad()
    def test_hf_tokenizer_dict_single(self, vllm_gpt2, ET_prompt: str):
        """A HuggingFace tokenizer output dict works for a single prompt."""
        encoded = vllm_gpt2.tokenizer(ET_prompt, return_tensors="pt")

        with vllm_gpt2.trace(dict(encoded), temperature=0.0, top_p=1):
            logits = vllm_gpt2.logits.output.save()

        predicted = logits.argmax(dim=-1)
        assert vllm_gpt2.tokenizer.decode(predicted) == " Paris"

    @torch.no_grad()
    def test_hf_tokenizer_dict_batched(
        self, vllm_gpt2, ET_prompt: str, MSG_prompt: str
    ):
        """A padded HuggingFace tokenizer output dict works for a batch."""
        encoded = vllm_gpt2.tokenizer(
            [ET_prompt, MSG_prompt], return_tensors="pt", padding=True
        )

        with vllm_gpt2.trace(dict(encoded), temperature=0.0, top_p=1):
            logits = vllm_gpt2.logits.output.save()

        assert logits.shape[0] == 2
        decoded = vllm_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1))
        assert decoded == [" Paris", " New"]

    @torch.no_grad()
    def test_hf_tokenizer_with_padding_mask(self, vllm_gpt2):
        """Padding tokens are filtered out via the attention_mask.

        The short prompt forces padding on one row; the long prompt's
        prediction must still be correct, showing the mask was honored.
        """
        prompts = ["Hello", "The Eiffel Tower is located in the city of"]

        encoded = vllm_gpt2.tokenizer(prompts, return_tensors="pt", padding=True)

        with vllm_gpt2.trace(dict(encoded), temperature=0.0, top_p=1):
            logits = vllm_gpt2.logits.output.save()

        assert logits.shape[0] == 2
        decoded = vllm_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1))
        assert decoded[1] == " Paris"

    @torch.no_grad()
    def test_token_list_in_invoker(self, vllm_gpt2, ET_prompt: str):
        """A token-ID list is also accepted inside tracer.invoke()."""
        ids = vllm_gpt2.tokenizer.encode(ET_prompt)

        with vllm_gpt2.trace(temperature=0.0, top_p=1) as tracer:
            with tracer.invoke(ids):
                logits = vllm_gpt2.logits.output.save()

        predicted = logits.argmax(dim=-1)
        assert vllm_gpt2.tokenizer.decode(predicted) == " Paris"

    @torch.no_grad()
    def test_mixed_string_and_token_invokers(
        self, vllm_gpt2, ET_prompt: str, MSG_prompt: str
    ):
        """String and token-ID inputs can be mixed across invokers in one trace."""
        eiffel_ids = vllm_gpt2.tokenizer.encode(ET_prompt)

        with vllm_gpt2.trace(temperature=0.0, top_p=1) as tracer:
            with tracer.invoke(eiffel_ids):
                et_logits = vllm_gpt2.logits.output.save()

            with tracer.invoke(MSG_prompt):
                msg_logits = vllm_gpt2.logits.output.save()

        assert vllm_gpt2.tokenizer.decode(et_logits.argmax(dim=-1)) == " Paris"
        assert vllm_gpt2.tokenizer.decode(msg_logits.argmax(dim=-1)) == " New"
0 commit comments