
Commit 32dafa8

test(loadgen): add unit tests for worker concurrency distribution
Add a new test suite to validate how LoadGenerator splits concurrency across workers. The tests cover evenly divisible cases, remainder handling, and scenarios where concurrency is lower than the number of workers.

This also fixes Python 3.9 compatibility by replacing Type | None with Optional[Type], and adds lightweight test-time mocks for asyncio.TaskGroup and typing.TypeAlias so the suite runs cleanly on Python versions below 3.11.

Signed-off-by: Sathvik <Sathvik.S@ibm.com>
1 parent 02d3c1a commit 32dafa8
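
For reference, the distribution rule the new tests pin down can be sketched in a few lines: each worker gets the floor of concurrency / num_workers, and the first concurrency % num_workers workers take one extra slot. The helper below is an illustrative stand-in for the _set_worker_concurrency logic under test, not the repository's exact implementation:

from typing import List

def split_concurrency(concurrency_level: int, num_workers: int) -> List[int]:
    # Floor division gives each worker its base share; the remainder is
    # handed out one slot at a time to the first workers.
    base, remainder = divmod(concurrency_level, num_workers)
    return [base + (1 if i < remainder else 0) for i in range(num_workers)]

# The three cases the test suite asserts:
# split_concurrency(8, 4)  -> [2, 2, 2, 2]   evenly divisible
# split_concurrency(10, 4) -> [3, 3, 2, 2]   remainder goes to the first workers
# split_concurrency(3, 4)  -> [1, 1, 1, 0]   fewer slots than workers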

9 files changed: +125 -19 lines changed

e2e/utils/llm_d_inference_sim.py

Lines changed: 3 additions & 2 deletions

@@ -5,6 +5,7 @@
 import textwrap
 import shutil
 from contextlib import AsyncContextDecorator
+from typing import Optional


 logger = logging.getLogger(__name__)
@@ -24,7 +25,7 @@ def is_available(executable: str = "llm-d-inference-sim") -> bool:

     _host = "127.0.0.1"
     _port: int
-    _proc: asyncio.subprocess.Process | None = None
+    _proc: "Optional[asyncio.subprocess.Process]" = None
     _wait_until_ready: bool

     def __init__(
@@ -91,7 +92,7 @@ async def __aexit__(self, *exc):
     async def wait_until_ready(
         self,
         polling_sec: float = 0.5,
-        timeout_sec: float | None = 10,
+        timeout_sec: Optional[float] = 10,
     ) -> None:
         """Waits until the server is ready to serve requests."""
        assert self._proc
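
A side note on the Optional changes throughout this commit: the X | None syntax from PEP 604 is only valid in annotations that are evaluated on Python 3.10+, which is why the 3.9-compatible spelling is needed. A minimal illustration (not part of the diff):

from typing import Optional

class Server:
    # Fine on Python 3.9+: Optional predates PEP 604.
    proc: Optional[int] = None

    # TypeError on 3.9, since class-level annotations are evaluated
    # when the class body runs:
    # proc2: int | None = None

    # A quoted annotation is stored as a string and never evaluated here,
    # which also sidesteps the 3.9 limitation:
    proc3: "Optional[int]" = None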

inference_perf/analysis/analyze.py

Lines changed: 3 additions & 3 deletions

@@ -16,12 +16,12 @@
 import logging
 import operator
 from pathlib import Path
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 logger = logging.getLogger(__name__)


-def _extract_latency_metric(latency_data: Dict[str, Any], metric_name: str, convert_to_ms: bool = False) -> float | None:
+def _extract_latency_metric(latency_data: Dict[str, Any], metric_name: str, convert_to_ms: bool = False) -> Optional[float]:
     """Helper to extract a metric's mean value from latency data."""
     metric_data = latency_data.get(metric_name)
     if isinstance(metric_data, dict):
@@ -31,7 +31,7 @@ def _extract_latency_metric(latency_data: Dict[str, Any], metric_name: str, conv
     return None


-def _extract_throughput_metric(throughput_data: Dict[str, Any], metric_name: str) -> float | None:
+def _extract_throughput_metric(throughput_data: Dict[str, Any], metric_name: str) -> Optional[float]:
     """Helper to extract a throughput metric's value."""
     metric_value = throughput_data.get(metric_name)
     if isinstance(metric_value, (int, float)):

inference_perf/client/metricsclient/base.py

Lines changed: 2 additions & 2 deletions

@@ -97,11 +97,11 @@ def __init__(self) -> None:
         pass

     @abstractmethod
-    def collect_metrics_summary(self, runtime_parameters: PerfRuntimeParameters) -> ModelServerMetrics | None:
+    def collect_metrics_summary(self, runtime_parameters: PerfRuntimeParameters) -> Optional[ModelServerMetrics]:
         raise NotImplementedError

     @abstractmethod
-    def collect_metrics_for_stage(self, runtime_parameters: PerfRuntimeParameters, stage_id: int) -> ModelServerMetrics | None:
+    def collect_metrics_for_stage(self, runtime_parameters: PerfRuntimeParameters, stage_id: int) -> Optional[ModelServerMetrics]:
         raise NotImplementedError

     @abstractmethod

inference_perf/client/metricsclient/mock_client.py

Lines changed: 3 additions & 2 deletions

@@ -12,16 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .base import MetricsClient, PerfRuntimeParameters, ModelServerMetrics
+from typing import Optional


 class MockMetricsClient(MetricsClient):
     def __init__(self) -> None:
         pass

-    def collect_metrics_summary(self, runtime_parameters: PerfRuntimeParameters) -> ModelServerMetrics | None:
+    def collect_metrics_summary(self, runtime_parameters: PerfRuntimeParameters) -> Optional[ModelServerMetrics]:
         return None

-    def collect_metrics_for_stage(self, runtime_parameters: PerfRuntimeParameters, stage_id: int) -> ModelServerMetrics | None:
+    def collect_metrics_for_stage(self, runtime_parameters: PerfRuntimeParameters, stage_id: int) -> Optional[ModelServerMetrics]:
         return None

     def wait(self) -> None:

inference_perf/client/metricsclient/prometheus_client/base.py

Lines changed: 4 additions & 4 deletions

@@ -14,7 +14,7 @@
 from abc import abstractmethod
 import logging
 import time
-from typing import List, cast, Any
+from typing import List, cast, Any, Optional
 import requests
 from inference_perf.client.modelserver.base import ModelServerPrometheusMetric
 from inference_perf.config import PrometheusClientConfig
@@ -184,7 +184,7 @@ def wait(self) -> None:
         wait_time = self.scrape_interval + PROMETHEUS_SCRAPE_BUFFER_SEC
         time.sleep(wait_time)

-    def collect_metrics_summary(self, runtime_parameters: PerfRuntimeParameters) -> ModelServerMetrics | None:
+    def collect_metrics_summary(self, runtime_parameters: PerfRuntimeParameters) -> Optional[ModelServerMetrics]:
         """
         Collects the summary metrics for the given Perf Benchmark run.
@@ -204,7 +204,7 @@ def collect_metrics_summary(self, runtime_parameters: PerfRuntimeParameters) ->

         return self.get_model_server_metrics(runtime_parameters.model_server_metrics, query_duration, query_eval_time)

-    def collect_metrics_for_stage(self, runtime_parameters: PerfRuntimeParameters, stage_id: int) -> ModelServerMetrics | None:
+    def collect_metrics_for_stage(self, runtime_parameters: PerfRuntimeParameters, stage_id: int) -> Optional[ModelServerMetrics]:
         """
         Collects the summary metrics for a specific stage.
@@ -235,7 +235,7 @@ def collect_metrics_for_stage(self, runtime_parameters: PerfRuntimeParameters, s

     def get_model_server_metrics(
         self, metrics_metadata: MetricsMetadata, query_duration: float, query_eval_time: float
-    ) -> ModelServerMetrics | None:
+    ) -> Optional[ModelServerMetrics]:
         """
         Collects the summary metrics for the given Model Server Client and query duration.
inference_perf/client/modelserver/openai_client.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@


 class openAIModelServerClient(ModelServerClient):
-    _session: "openAIModelServerClientSession | None" = None
+    _session: "Optional[openAIModelServerClientSession]" = None
     _session_lock = asyncio.Lock()

     def __init__(

inference_perf/client/requestdatacollector/multiprocess.py

Lines changed: 2 additions & 2 deletions

@@ -16,7 +16,7 @@

 from asyncio import get_event_loop, create_task
 from contextlib import asynccontextmanager
-from typing import AsyncIterator
+from typing import AsyncIterator, Optional
 from functools import partial
 import logging
 from inference_perf.client.requestdatacollector import RequestDataCollector
@@ -30,7 +30,7 @@ class MultiprocessRequestDataCollector(RequestDataCollector):
     """Responsible for accumulating client request metrics"""

     def __init__(self) -> None:
-        self.queue: "mp.JoinableQueue[RequestLifecycleMetric | None]" = mp.JoinableQueue()
+        self.queue: "mp.JoinableQueue[Optional[RequestLifecycleMetric]]" = mp.JoinableQueue()

     def record_metric(self, metric: RequestLifecycleMetric) -> None:
         self.queue.put(metric)

inference_perf/loadgen/load_generator.py

Lines changed: 17 additions & 3 deletions

@@ -31,15 +31,29 @@
 from asyncio import (
     CancelledError,
     Semaphore,
-    TaskGroup,
     create_task,
     gather,
     run,
     sleep,
     set_event_loop_policy,
     get_event_loop,
 )
-from typing import List, Tuple, TypeAlias, Optional
+
+try:
+    from asyncio import TaskGroup
+except ImportError:
+    # Python 3.9 compatibility: TaskGroup was added in 3.11.
+    # This is a dummy for import-time compatibility.
+    # Runtime usage will still require Python 3.11+.
+    TaskGroup = object
+
+from typing import List, Tuple, Optional, Union
+try:
+    from typing import TypeAlias
+except ImportError:
+    # Python 3.9 compatibility: TypeAlias was added in 3.10
+    from typing import Any
+    TypeAlias = Any
 from types import FrameType
 import time
 import multiprocessing as mp
@@ -55,7 +69,7 @@

 logger = logging.getLogger(__name__)

-RequestQueueData: TypeAlias = Tuple[int, InferenceAPIData | int, float, Optional[str]]
+RequestQueueData: TypeAlias = Tuple[int, Union[InferenceAPIData, int], float, Optional[str]]


 class Worker(mp.Process):
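
For context on the fallback above: on Python 3.11+ asyncio.TaskGroup is used as an async context manager, which is why a bare TaskGroup = object satisfies the import but not runtime use. A minimal sketch of the 3.11+ pattern (illustrative, not code from this diff):

import asyncio

async def main() -> None:
    # TaskGroup awaits all child tasks on exit and cancels the
    # remaining ones if any task raises (Python 3.11+ only).
    async with asyncio.TaskGroup() as tg:
        tg.create_task(asyncio.sleep(0.1))
        tg.create_task(asyncio.sleep(0.2))

asyncio.run(main())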
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+import unittest
+from unittest.mock import MagicMock
+import multiprocessing as mp
+import asyncio
+import typing
+
+# Patch asyncio.TaskGroup for Python < 3.11
+if not hasattr(asyncio, 'TaskGroup'):
+    class MockTaskGroup:
+        async def __aenter__(self):
+            return self
+        async def __aexit__(self, exc_type, exc_val, exc_tb):
+            pass
+        def create_task(self, coro):
+            return asyncio.create_task(coro)
+    asyncio.TaskGroup = MockTaskGroup
+
+# Patch typing.TypeAlias for Python < 3.10
+if not hasattr(typing, 'TypeAlias'):
+    typing.TypeAlias = typing.Any
+
+from inference_perf.loadgen.load_generator import LoadGenerator
+from inference_perf.config import LoadConfig, LoadType
+
+class MockWorker:
+    def __init__(self, id, shared_max_concurrency):
+        self.id = id
+        self.shared_max_concurrency = shared_max_concurrency
+
+class TestLoadGeneratorConcurrency(unittest.TestCase):
+    def setUp(self):
+        self.mock_datagen = MagicMock()
+        self.load_config = LoadConfig(
+            type=LoadType.CONCURRENT,
+            num_workers=4,
+            worker_max_concurrency=100
+        )
+        # Mock get_circuit_breaker since LoadGenerator.__init__ calls it
+        with unittest.mock.patch('inference_perf.loadgen.load_generator.get_circuit_breaker'):
+            self.load_generator = LoadGenerator(self.mock_datagen, self.load_config)
+
+    def test_set_worker_concurrency_divisible(self):
+        # Set up workers
+        self.load_generator.workers = []
+        for i in range(4):
+            shared_val = mp.Value('i', 0)
+            self.load_generator.workers.append(MockWorker(i, shared_val))
+
+        # concurrency_level = 8 (8 / 4 = 2 per worker)
+        self.load_generator._set_worker_concurrency(8)
+
+        for worker in self.load_generator.workers:
+            self.assertEqual(worker.shared_max_concurrency.value, 2, f"Worker {worker.id} should have concurrency 2")
+
+    def test_set_worker_concurrency_remainder(self):
+        # Set up workers
+        self.load_generator.workers = []
+        for i in range(4):
+            shared_val = mp.Value('i', 0)
+            self.load_generator.workers.append(MockWorker(i, shared_val))
+
+        # concurrency_level = 10 (10 // 4 = 2, remainder 2):
+        # workers 0 and 1 should get 3; workers 2 and 3 should get 2
+        self.load_generator._set_worker_concurrency(10)
+
+        self.assertEqual(self.load_generator.workers[0].shared_max_concurrency.value, 3)
+        self.assertEqual(self.load_generator.workers[1].shared_max_concurrency.value, 3)
+        self.assertEqual(self.load_generator.workers[2].shared_max_concurrency.value, 2)
+        self.assertEqual(self.load_generator.workers[3].shared_max_concurrency.value, 2)
+
+    def test_set_worker_concurrency_less_than_workers(self):
+        # Set up workers
+        self.load_generator.workers = []
+        for i in range(4):
+            shared_val = mp.Value('i', 0)
+            self.load_generator.workers.append(MockWorker(i, shared_val))
+
+        # concurrency_level = 3:
+        # workers 0, 1, and 2 should get 1; worker 3 should get 0
+        self.load_generator._set_worker_concurrency(3)
+
+        self.assertEqual(self.load_generator.workers[0].shared_max_concurrency.value, 1)
+        self.assertEqual(self.load_generator.workers[1].shared_max_concurrency.value, 1)
+        self.assertEqual(self.load_generator.workers[2].shared_max_concurrency.value, 1)
+        self.assertEqual(self.load_generator.workers[3].shared_max_concurrency.value, 0)
+
+if __name__ == '__main__':
+    unittest.main()
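
Since the new file ends with a standard unittest.main() hook, the suite can be run directly or picked up by discovery; the paths below are hypothetical, as the diff header does not show the new file's name:

# Run the file directly (path is hypothetical):
python tests/test_load_generator_concurrency.py

# Or let unittest discover it:
python -m unittest discover -s tests -p "test_*.py" -v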
