Skip to content

Commit 1d579f6

Browse files
authored
Pass mp context to strategy (#651)
## Summary Fixes spawn and forkserver multi-process contexts. ## Details I was hoping that after #647 we could switch to `forkserver` by default. However, it turns out that `forkserver` and `spawn` will import the calling process's entrypoint (e.g. `__main__.py`) so we run into the same blocker as #641. However, I was able to confirm that stripping every heavy import out of `__main__.py` solves the issue. So we should be good to switch in v0.7.0. On my machine there is about a ~10s overhead for `forkserver` and slightly more for `spawn`, which is not the worst for a default. However, the overhead may be more on other systems: ### `time guidellm benchmark run --profile poisson --rate 5 --data prompt_tokens=128,output_tokens=128 --max-seconds 30 --outputs json` | Context | real | user | sys | | ---------- | --------- | --------- | -------- | | Fork | 0m37.874s | 0m17.356s | 0m1.883s | | Forkserver | 0m47.344s | 0m14.862s | 0m0.860s | | Spawn | 0m49.515s | 1m51.230s | 0m8.915s | ### `time guidellm benchmark run --profile concurrent --rate 400 --data prompt_tokens=128,output_tokens=128 --max-seconds 30 --outputs json` | Context | real | user | sys | | ---------- | --------- | --------- | --------- | | Fork | 0m39.324s | 0m37.602s | 0m5.623s | | Forkserver | 0m49.609s | 0m19.710s | 0m1.311s | | Spawn | 0m50.399s | 2m9.724s | 0m11.374s | ### `time guidellm benchmark run --profile concurrent --rate 400 --data prompt_tokens=128,output_tokens=128 --max-seconds 120 --outputs json` | Context | real | user | sys | | ---------- | --------- | --------- | --------- | | Fork | 2m15.309s | 1m42.911s | 0m15.957s | | Forkserver | 2m25.964s | 0m38.891s | 0m2.802s | | Spawn | 2m27.454s | 3m24.325s | 0m22.531s | ## Test Plan Set `GUIDELLM__MP_CONTEXT_TYPE=forkserver` and confirm benchmarks run. --- - [x] "I certify that all code in this PR is my own, except as noted below." 
## Use of AI - [x] Includes AI-assisted code completion - [ ] Includes code generated by an AI application - [ ] Includes AI-generated tests (NOTE: AI written tests should have a docstring that includes `## WRITTEN BY AI ##`)
2 parents 5e1f06b + 3acacc5 commit 1d579f6

File tree

3 files changed

+36
-14
lines changed

3 files changed

+36
-14
lines changed

src/guidellm/scheduler/strategies.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
import math
2020
import random
2121
from abc import abstractmethod
22-
from multiprocessing import Event, Value, synchronize
22+
from multiprocessing import synchronize
23+
from multiprocessing.context import BaseContext
2324
from multiprocessing.sharedctypes import Synchronized
2425
from typing import Annotated, ClassVar, Literal, TypeVar
2526

@@ -103,7 +104,10 @@ def requests_limit(self) -> PositiveInt | None:
103104
return None
104105

105106
def init_processes_timings(
106-
self, worker_count: PositiveInt, max_concurrency: PositiveInt
107+
self,
108+
worker_count: PositiveInt,
109+
max_concurrency: PositiveInt,
110+
mp_context: BaseContext,
107111
):
108112
"""
109113
Initialize shared timing state for multi-process coordination.
@@ -117,9 +121,9 @@ def init_processes_timings(
117121
self.worker_count = worker_count
118122
self.max_concurrency = max_concurrency
119123

120-
self._processes_init_event = Event()
121-
self._processes_request_index = Value("i", 0)
122-
self._processes_start_time = Value("d", -1.0)
124+
self._processes_init_event = mp_context.Event()
125+
self._processes_request_index = mp_context.Value("i", 0)
126+
self._processes_start_time = mp_context.Value("d", -1.0)
123127

124128
def init_processes_start(self, start_time: float):
125129
"""
@@ -593,7 +597,12 @@ def requests_limit(self) -> PositiveInt | None:
593597
"""
594598
return self.max_concurrency
595599

596-
def init_processes_timings(self, worker_count: int, max_concurrency: int):
600+
def init_processes_timings(
601+
self,
602+
worker_count: PositiveInt,
603+
max_concurrency: PositiveInt,
604+
mp_context: BaseContext,
605+
):
597606
"""
598607
Initialize Poisson-specific timing state.
599608
@@ -603,10 +612,10 @@ def init_processes_timings(self, worker_count: int, max_concurrency: int):
603612
:param worker_count: Number of worker processes to coordinate
604613
:param max_concurrency: Maximum number of concurrent requests allowed
605614
"""
606-
self._offset = Value("d", -1.0)
615+
self._offset = mp_context.Value("d", -1.0)
607616
# Call base implementation last to avoid
608617
# setting Event before offset is ready
609-
super().init_processes_timings(worker_count, max_concurrency)
618+
super().init_processes_timings(worker_count, max_concurrency, mp_context)
610619

611620
def init_processes_start(self, start_time: float):
612621
"""

src/guidellm/scheduler/worker_group.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,9 @@ async def create_processes(self):
221221
# Initialize worker processes
222222
self.processes = []
223223
self.strategy.init_processes_timings(
224-
worker_count=num_processes, max_concurrency=max_conc
224+
worker_count=num_processes,
225+
max_concurrency=max_conc,
226+
mp_context=self.mp_context,
225227
)
226228
for rank in range(num_processes):
227229
# Distribute any remainder across the first N ranks

tests/unit/scheduler/test_strategies.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import math
44
import time
5+
from multiprocessing import get_context
56
from typing import Literal, TypeVar
67

78
import pytest
@@ -502,7 +503,9 @@ async def test_timing_without_rampup(self):
502503
### WRITTEN BY AI ###
503504
"""
504505
strategy = AsyncConstantStrategy(rate=10.0, rampup_duration=0.0)
505-
strategy.init_processes_timings(worker_count=1, max_concurrency=100)
506+
strategy.init_processes_timings(
507+
worker_count=1, max_concurrency=100, mp_context=get_context()
508+
)
506509
start_time = 1000.0
507510
strategy.init_processes_start(start_time)
508511

@@ -525,7 +528,9 @@ async def test_timing_with_rampup(self):
525528
rate = 10.0
526529
rampup_duration = 2.0
527530
strategy = AsyncConstantStrategy(rate=rate, rampup_duration=rampup_duration)
528-
strategy.init_processes_timings(worker_count=1, max_concurrency=100)
531+
strategy.init_processes_timings(
532+
worker_count=1, max_concurrency=100, mp_context=get_context()
533+
)
529534
start_time = 1000.0
530535
strategy.init_processes_start(start_time)
531536

@@ -574,7 +579,9 @@ async def test_timing_with_rampup_edge_cases(self):
574579

575580
# Test with very short rampup_duration
576581
strategy = AsyncConstantStrategy(rate=100.0, rampup_duration=0.01)
577-
strategy.init_processes_timings(worker_count=1, max_concurrency=100)
582+
strategy.init_processes_timings(
583+
worker_count=1, max_concurrency=100, mp_context=get_context()
584+
)
578585
start_time = 2000.0
579586
strategy.init_processes_start(start_time)
580587

@@ -584,7 +591,9 @@ async def test_timing_with_rampup_edge_cases(self):
584591

585592
# Test with very long rampup_duration
586593
strategy2 = AsyncConstantStrategy(rate=1.0, rampup_duration=100.0)
587-
strategy2.init_processes_timings(worker_count=1, max_concurrency=100)
594+
strategy2.init_processes_timings(
595+
worker_count=1, max_concurrency=100, mp_context=get_context()
596+
)
588597
start_time2 = 3000.0
589598
strategy2.init_processes_start(start_time2)
590599

@@ -613,7 +622,9 @@ async def test_timing_rampup_transition(self):
613622
rate = 10.0
614623
rampup_duration = 2.0
615624
strategy = AsyncConstantStrategy(rate=rate, rampup_duration=rampup_duration)
616-
strategy.init_processes_timings(worker_count=1, max_concurrency=100)
625+
strategy.init_processes_timings(
626+
worker_count=1, max_concurrency=100, mp_context=get_context()
627+
)
617628
start_time = 5000.0
618629
strategy.init_processes_start(start_time)
619630

0 commit comments

Comments
 (0)