Skip to content

Empty invokes do not work in vLLM #590

@JadenFiotto-Kaufman

Description

@JadenFiotto-Kaufman

For example, this test in tests/test_vllm.py uses an empty (prompt-less) `tracer.invoke()`:

 @torch.no_grad()
    def test_invoker_group_batching(self, vllm_gpt2, ET_prompt: str, MSG_prompt: str):
        """Test complex invoker group batching.

        Runs four invokes inside a single trace -- ``MSG_prompt``, a
        prompt-less (empty) invoke, the pair ``[ET_prompt, MSG_prompt]`` and
        ``ET_prompt`` -- each prompted invoke with its own ``max_tokens``,
        and records ``vllm_gpt2.logits.output`` from every one of them.

        The assertions then check that each prompted invoke recorded one
        entry per requested token, and that the entries recorded by the
        prompt-less invoke contain the rows of all the other invokes, in
        invoke order.
        """
        # Distinct generation lengths for the three prompted invokes; the
        # empty invoke below is given no max_tokens of its own.
        max_tokens_1 = 1
        max_tokens_2 = 2
        max_tokens_3 = 3

        # One collector per invoke, filled inside the trace below.
        MSG_logits = list()
        ET_logits = list()
        two_prompts_logits = list()
        all_logits = list()

        with vllm_gpt2.trace() as tracer:
            # Invoke 1: single prompt, 1 token.
            with tracer.invoke(MSG_prompt, max_tokens=max_tokens_1):
                with tracer.iter[:]:
                    MSG_logits.append(vllm_gpt2.logits.output)

            # Invoke 2: empty invoke (no prompt, no max_tokens).
            # NOTE(review): presumably this sees the combined batch of the
            # other invokes at every step -- that is what the assertions
            # below expect; confirm against the tracer implementation.
            with tracer.invoke():
                with tracer.all():
                    all_logits.append(vllm_gpt2.logits.output)

            # Invoke 3: two prompts at once, 3 tokens.
            with tracer.invoke([ET_prompt, MSG_prompt], max_tokens=max_tokens_3):
                with tracer.all():
                    two_prompts_logits.append(vllm_gpt2.logits.output)

            # Invoke 4: single prompt, 2 tokens.
            with tracer.invoke(ET_prompt, max_tokens=max_tokens_2):
                with tracer.iter[:]:
                    ET_logits.append(vllm_gpt2.logits.output)

        # Each invoker has the correct number of logits
        assert len(MSG_logits) == max_tokens_1
        assert len(ET_logits) == max_tokens_2
        assert len(two_prompts_logits) == max_tokens_3
        # The empty invoke is expected to record as many entries as the
        # longest prompted invoke (max_tokens_3).
        assert len(all_logits) == max_tokens_3

        # Check correctness of prompt-less invoker
        # Expected rows per step: 4, then 3, then 2.
        # NOTE(review): presumably a prompt drops out of the batch once it
        # has produced its max_tokens (1 for invoke 1, 2 for invoke 4) -- confirm.
        assert (
            all_logits[0].shape[0] == 4
            and all_logits[1].shape[0] == 3
            and all_logits[2].shape[0] == 2
        )

        # iter 0
        # Row 0 <- invoke 1, rows 1-2 <- invoke 3 (both prompts), row 3 <- invoke 4.
        assert torch.equal(all_logits[0][0], MSG_logits[0][0])
        assert torch.equal(all_logits[0][1:3], two_prompts_logits[0][:2])
        assert torch.equal(all_logits[0][3], ET_logits[0][0])

        # iter 1
        # Invoke 1 no longer contributes: rows 0-1 <- invoke 3, row 2 <- invoke 4.
        assert torch.equal(all_logits[1][0:2], two_prompts_logits[1])
        assert torch.equal(all_logits[1][2], ET_logits[1][0])

        # iter 2
        # Only invoke 3 is left, so the whole tensor must match its output.
        assert torch.equal(all_logits[2], two_prompts_logits[2])

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions