Skip to content
Open
4 changes: 2 additions & 2 deletions tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
]


# Test that a nested mixed-type list of lists raises a TypeError.
@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]])
# Per-prompt validation: mixed batched raw-string & tokenized prompts raise TypeError.
@pytest.mark.parametrize("invalid_input", [[["key", "cat"], [2, 3]]])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just use the original test case? Why does it need to be changed?

def test_invalid_input_raise_type_error(invalid_input):
with pytest.raises(TypeError):
parse_raw_prompts(invalid_input)
Expand Down
21 changes: 11 additions & 10 deletions vllm/inputs/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,17 @@ def parse_raw_prompts(

# case 4: array of token arrays
if is_list_of(prompt, list):
first = prompt[0]
if not isinstance(first, list):
raise ValueError("prompt expected to be a list of lists")

if len(first) == 0:
raise ValueError("Please provide at least one prompt")

# strict validation: every nested list must be list[int]
if not all(is_list_of(elem, int) for elem in prompt):
raise TypeError("Nested lists must contain only integers")
if len(prompt) == 1 and isinstance(prompt[0], list) and len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
for elem in prompt:
if not isinstance(elem, list):
raise TypeError(
"prompt must be a list of lists, but found a non-list element."
)
if not is_list_of(elem, int):
raise TypeError(
"Nested lists of tokens must contain only integers."
)

prompt = cast(list[list[int]], prompt)
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
Expand Down