Commit f1d878d

Merge pull request #27 from SachinVarghese/vllm
Adding vLLM Client to inference perf runner
2 parents fc024d9 + f2c66e6 commit f1d878d

File tree

11 files changed: +131 -38 lines changed


.github/workflows/format.yml

Lines changed: 2 additions & 4 deletions
@@ -13,10 +13,8 @@ jobs:
     steps:
       - name: Checkout Code
         uses: actions/checkout@v4
-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: '3.13'
+      - name: Set up Python PDM
+        uses: pdm-project/setup-pdm@v4
       - name: Do Linting and Type Checks
         run: |
           make check

inference_perf/client/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -13,6 +13,7 @@
 # limitations under the License.
 from .base import ModelServerClient
 from .mock_client import MockModelServerClient
+from .vllm_client import vLLMModelServerClient


-__all__ = ["ModelServerClient", "MockModelServerClient"]
+__all__ = ["ModelServerClient", "MockModelServerClient", "vLLMModelServerClient"]

inference_perf/client/base.py

Lines changed: 1 addition & 1 deletion
@@ -27,5 +27,5 @@ def set_report_generator(self, reportgen: ReportGenerator) -> None:
         self.reportgen = reportgen

     @abstractmethod
-    def process_request(self, data: InferenceData) -> None:
+    async def process_request(self, data: InferenceData) -> None:
         raise NotImplementedError
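
With this contract change, every client must implement process_request as a coroutine and callers must await it or schedule it as a task. A minimal conforming sketch, assuming only the signatures in this commit (EchoClient is a hypothetical name, not part of this change):

from inference_perf.client import ModelServerClient
from inference_perf.datagen import InferenceData
from inference_perf.reportgen import ReportGenerator, RequestMetric


class EchoClient(ModelServerClient):
    def set_report_generator(self, reportgen: ReportGenerator) -> None:
        self.reportgen = reportgen

    async def process_request(self, data: InferenceData) -> None:
        # A real client would await a network call here before reporting.
        self.reportgen.collect_request_metrics(
            RequestMetric(prompt_tokens=0, output_tokens=0, time_per_request=0.0)
        )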

inference_perf/client/mock_client.py

Lines changed: 13 additions & 5 deletions
@@ -12,17 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from inference_perf.datagen import InferenceData
-from inference_perf.reportgen import ReportGenerator, Metric
+from inference_perf.reportgen import ReportGenerator, RequestMetric
 from .base import ModelServerClient
+import asyncio


 class MockModelServerClient(ModelServerClient):
-    def __init__(self, uri: str) -> None:
-        self.uri = uri
+    def __init__(self) -> None:
+        pass

     def set_report_generator(self, reportgen: ReportGenerator) -> None:
         self.reportgen = reportgen

-    def process_request(self, data: InferenceData) -> None:
+    async def process_request(self, data: InferenceData) -> None:
         print("Processing request - " + data.system_prompt)
-        self.reportgen.collect_metrics(Metric(name=data.system_prompt))
+        await asyncio.sleep(3)
+        self.reportgen.collect_request_metrics(
+            RequestMetric(
+                prompt_tokens=0,
+                output_tokens=0,
+                time_per_request=3,
+            )
+        )
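
A quick way to exercise the reworked mock client end to end. This is a hypothetical smoke test, not part of this commit; it assumes InferenceData can be constructed with just a system_prompt and that MockReportGenerator takes no arguments:

import asyncio

from inference_perf.client import MockModelServerClient
from inference_perf.datagen import InferenceData
from inference_perf.reportgen import MockReportGenerator


async def smoke_test() -> None:
    client = MockModelServerClient()
    client.set_report_generator(MockReportGenerator())
    # Each request now takes ~3s, simulating model latency.
    await client.process_request(InferenceData(system_prompt="hello"))


asyncio.run(smoke_test())
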
inference_perf/client/vllm_client.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+# Copyright 2025 The Kubernetes Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from inference_perf.datagen import InferenceData
+from inference_perf.reportgen import ReportGenerator, RequestMetric
+from .base import ModelServerClient
+from typing import Any
+import aiohttp
+import json
+import time
+
+
+class vLLMModelServerClient(ModelServerClient):
+    def __init__(self, uri: str, model_name: str) -> None:
+        self.model_name = model_name
+        self.uri = uri + "/v1/completions"
+        self.max_completion_tokens = 30
+
+    def set_report_generator(self, reportgen: ReportGenerator) -> None:
+        self.reportgen = reportgen
+
+    def _createPayload(self, data: InferenceData) -> dict[str, Any]:
+        return {"model": self.model_name, "prompt": data.system_prompt, "max_tokens": self.max_completion_tokens}
+
+    async def process_request(self, data: InferenceData) -> None:
+        payload = self._createPayload(data)
+        headers = {"Content-Type": "application/json"}
+        async with aiohttp.ClientSession() as session:
+            start = time.monotonic()
+            try:
+                async with session.post(self.uri, headers=headers, data=json.dumps(payload)) as response:
+                    if response.status == 200:
+                        content = await response.json()
+                        end = time.monotonic()
+                        usage = content["usage"]
+                        self.reportgen.collect_request_metrics(
+                            RequestMetric(
+                                prompt_tokens=usage["prompt_tokens"],
+                                output_tokens=usage["completion_tokens"],
+                                time_per_request=end - start,
+                            )
+                        )
+                    else:
+                        print(await response.text())
+            except aiohttp.ClientConnectorError as e:
+                print("vLLM Server connection error:\n", str(e))

inference_perf/loadgen/load_generator.py

Lines changed: 16 additions & 8 deletions
@@ -15,6 +15,7 @@
 from .load_timer import LoadTimer, ConstantLoadTimer, PoissonLoadTimer
 from inference_perf.datagen import DataGenerator
 from inference_perf.client import ModelServerClient
+from asyncio import TaskGroup, sleep
 import time


@@ -35,13 +36,20 @@ def __init__(self, datagen: DataGenerator, load_type: LoadType, rate: float, dur
         else:
             raise

-    def run(self, client: ModelServerClient) -> None:
-        print("Run started")
+    async def run(self, client: ModelServerClient) -> None:
         start_time = time.time()
         end_time = start_time + self.duration
-        for _, (data, time_index) in enumerate(zip(self.datagen.get_data(), self.timer.start_timer(start_time), strict=True)):
-            if time_index < end_time:
-                client.process_request(data)
-            else:
-                print("Run complete")
-                break
+        print("Run started")
+        async with TaskGroup() as tg:
+            for _, (data, time_index) in enumerate(
+                zip(self.datagen.get_data(), self.timer.start_timer(start_time), strict=True)
+            ):
+                now = time.time()
+                if time_index < end_time and now < end_time:
+                    if time_index > now:
+                        await sleep(time_index - time.time())
+                    tg.create_task(client.process_request(data))
+                    continue
+                else:
+                    break
+        print("Run completed")

inference_perf/loadgen/load_timer.py

Lines changed: 2 additions & 3 deletions
@@ -48,7 +48,7 @@ def start_timer(self, initial: Optional[float] = None) -> Generator[float, None,

         # Given a rate, yield a time to wait before the next request
         while True:
-            next_time += self._rand.exponential(1 / self._rate)
+            next_time += self._rand.uniform(0, 1 / self._rate)
             yield next_time


@@ -73,7 +73,6 @@ def start_timer(self, initial: Optional[float] = None) -> Generator[float, None,

         # Schedule the requests over the next second
         timer = ConstantLoadTimer(req_count)
-        times = timer.start_timer(next_time)
         for _ in range(req_count):
-            next_time = next(times)
+            next_time = next(timer.start_timer(next_time))
             yield next_time
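
Note the distributional effect of the first hunk: exponential gaps with scale 1/rate average 1/rate seconds, while uniform(0, 1/rate) gaps are bounded above but average only 1/(2 * rate), so requests end up scheduled about twice as densely on average. A sketch to verify, assuming self._rand is a numpy Generator (which the call signatures suggest):

import numpy as np

rate = 2.0
rng = np.random.default_rng(0)
print(rng.exponential(1 / rate, size=100_000).mean())  # ~0.50 = 1 / rate
print(rng.uniform(0, 1 / rate, size=100_000).mean())   # ~0.25 = 1 / (2 * rate)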

inference_perf/main.py

Lines changed: 5 additions & 4 deletions
@@ -13,8 +13,9 @@
 # limitations under the License.
 from inference_perf.loadgen import LoadGenerator, LoadType
 from inference_perf.datagen import MockDataGenerator
-from inference_perf.client import ModelServerClient, MockModelServerClient
+from inference_perf.client import ModelServerClient, vLLMModelServerClient
 from inference_perf.reportgen import ReportGenerator, MockReportGenerator
+import asyncio


 class InferencePerfRunner:
@@ -25,15 +26,15 @@ def __init__(self, client: ModelServerClient, loadgen: LoadGenerator, reportgen:
         self.client.set_report_generator(self.reportgen)

     def run(self) -> None:
-        self.loadgen.run(self.client)
+        asyncio.run(self.loadgen.run(self.client))

     def generate_report(self) -> None:
-        self.reportgen.generate_report()
+        asyncio.run(self.reportgen.generate_report())


 def main_cli() -> None:
     # Define Model Server Client
-    client = MockModelServerClient(uri="0.0.0.0:0")
+    client = vLLMModelServerClient(uri="http://0.0.0.0:8000", model_name="openai-community/gpt2")

     # Define LoadGenerator
     loadgen = LoadGenerator(MockDataGenerator(), LoadType.CONSTANT, rate=2, duration=5)
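
Putting the pieces together, the wiring now looks like this from a caller's perspective. A sketch based on the imports and calls above; the MockReportGenerator constructor arguments are assumed:

import asyncio

from inference_perf.client import vLLMModelServerClient
from inference_perf.datagen import MockDataGenerator
from inference_perf.loadgen import LoadGenerator, LoadType
from inference_perf.reportgen import MockReportGenerator

client = vLLMModelServerClient(uri="http://0.0.0.0:8000", model_name="openai-community/gpt2")
reportgen = MockReportGenerator()
client.set_report_generator(reportgen)

loadgen = LoadGenerator(MockDataGenerator(), LoadType.CONSTANT, rate=2, duration=5)

# run() and generate_report() are coroutines now, so the runner drives
# them with asyncio.run().
asyncio.run(loadgen.run(client))
asyncio.run(reportgen.generate_report())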

inference_perf/reportgen/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .base import ReportGenerator, Metric
+from .base import ReportGenerator, RequestMetric
 from .mock_reportgen import MockReportGenerator


-__all__ = ["ReportGenerator", "Metric", "MockReportGenerator"]
+__all__ = ["ReportGenerator", "RequestMetric", "MockReportGenerator"]

inference_perf/reportgen/base.py

Lines changed: 13 additions & 4 deletions
@@ -16,8 +16,17 @@
 from typing import Tuple


-class Metric(BaseModel):
-    name: str
+class MetricsSummary(BaseModel):
+    total_requests: int
+    avg_prompt_tokens: float
+    avg_output_tokens: float
+    avg_time_per_request: float
+
+
+class RequestMetric(BaseModel):
+    prompt_tokens: int
+    output_tokens: int
+    time_per_request: float


 class ReportGenerator(ABC):
@@ -26,9 +35,9 @@ def __init__(self, *args: Tuple[int, ...]) -> None:
         pass

     @abstractmethod
-    def collect_metrics(self, metric: Metric) -> None:
+    def collect_request_metrics(self, metric: RequestMetric) -> None:
         raise NotImplementedError

     @abstractmethod
-    def generate_report(self) -> None:
+    async def generate_report(self) -> None:
         raise NotImplementedError
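
The split into RequestMetric (per request) and MetricsSummary (aggregate) suggests the reduction a concrete report generator would perform. One plausible shape, as a sketch rather than this repo's MockReportGenerator implementation:

from inference_perf.reportgen import RequestMetric
from inference_perf.reportgen.base import MetricsSummary


def summarize(metrics: list[RequestMetric]) -> MetricsSummary:
    # Assumes at least one request was collected.
    n = len(metrics)
    return MetricsSummary(
        total_requests=n,
        avg_prompt_tokens=sum(m.prompt_tokens for m in metrics) / n,
        avg_output_tokens=sum(m.output_tokens for m in metrics) / n,
        avg_time_per_request=sum(m.time_per_request for m in metrics) / n,
    )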
