-
Notifications
You must be signed in to change notification settings - Fork 66
Open
Labels
Description
For example, consider this test in `tests/test_vllm.py`:
@torch.no_grad()
def test_invoker_group_batching(self, vllm_gpt2, ET_prompt: str, MSG_prompt: str):
    """Test complex invoker group batching.

    Runs one trace with four invokers that request different numbers of
    generated tokens (1, 3, and 2), plus one prompt-less invoker that
    observes every prompt in the batch, then checks that the logits
    captured by the prompt-less invoker line up with the per-invoker
    captures at each generation step.

    Args:
        vllm_gpt2: nnsight-wrapped vLLM GPT-2 model fixture.
        ET_prompt: first test prompt.
        MSG_prompt: second test prompt.
    """
    max_tokens_1 = 1
    max_tokens_2 = 2
    max_tokens_3 = 3

    MSG_logits = []
    ET_logits = []
    two_prompts_logits = []
    all_logits = []

    with vllm_gpt2.trace() as tracer:
        # One-token invoker for MSG_prompt; iter[:] visits every step.
        with tracer.invoke(MSG_prompt, max_tokens=max_tokens_1):
            with tracer.iter[:]:
                MSG_logits.append(vllm_gpt2.logits.output)
        # Prompt-less invoker: observes the whole batch at every step.
        with tracer.invoke():
            with tracer.all():
                all_logits.append(vllm_gpt2.logits.output)
        # Two-prompt invoker generating three tokens.
        with tracer.invoke([ET_prompt, MSG_prompt], max_tokens=max_tokens_3):
            with tracer.all():
                two_prompts_logits.append(vllm_gpt2.logits.output)
        # Two-token invoker for ET_prompt.
        with tracer.invoke(ET_prompt, max_tokens=max_tokens_2):
            with tracer.iter[:]:
                ET_logits.append(vllm_gpt2.logits.output)

    # Each invoker has the correct number of logits: one capture per
    # generation step it was active for. The prompt-less invoker runs for
    # as many steps as the longest request (max_tokens_3).
    assert len(MSG_logits) == max_tokens_1
    assert len(ET_logits) == max_tokens_2
    assert len(two_prompts_logits) == max_tokens_3
    assert len(all_logits) == max_tokens_3

    # Check correctness of the prompt-less invoker: the batch shrinks as
    # requests finish (4 prompts at step 0, then 3, then 2).
    assert (
        all_logits[0].shape[0] == 4
        and all_logits[1].shape[0] == 3
        and all_logits[2].shape[0] == 2
    )

    # iter 0: batch order is MSG, [ET, MSG], ET.
    assert torch.equal(all_logits[0][0], MSG_logits[0][0])
    assert torch.equal(all_logits[0][1:3], two_prompts_logits[0][:2])
    assert torch.equal(all_logits[0][3], ET_logits[0][0])
    # iter 1: the one-token MSG request has finished.
    assert torch.equal(all_logits[1][0:2], two_prompts_logits[1])
    assert torch.equal(all_logits[1][2], ET_logits[1][0])
    # iter 2: only the two-prompt (three-token) request remains.
    assert torch.equal(all_logits[2], two_prompts_logits[2])