Skip to content

Commit 3acacc5

Browse files
committed
Fix unit tests
Signed-off-by: Samuel Monson <smonson@redhat.com>
1 parent 9972dbb commit 3acacc5

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

tests/unit/scheduler/test_strategies.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import math
44
import time
5+
from multiprocessing import get_context
56
from typing import Literal, TypeVar
67

78
import pytest
@@ -502,7 +503,9 @@ async def test_timing_without_rampup(self):
502503
### WRITTEN BY AI ###
503504
"""
504505
strategy = AsyncConstantStrategy(rate=10.0, rampup_duration=0.0)
505-
strategy.init_processes_timings(worker_count=1, max_concurrency=100)
506+
strategy.init_processes_timings(
507+
worker_count=1, max_concurrency=100, mp_context=get_context()
508+
)
506509
start_time = 1000.0
507510
strategy.init_processes_start(start_time)
508511

@@ -525,7 +528,9 @@ async def test_timing_with_rampup(self):
525528
rate = 10.0
526529
rampup_duration = 2.0
527530
strategy = AsyncConstantStrategy(rate=rate, rampup_duration=rampup_duration)
528-
strategy.init_processes_timings(worker_count=1, max_concurrency=100)
531+
strategy.init_processes_timings(
532+
worker_count=1, max_concurrency=100, mp_context=get_context()
533+
)
529534
start_time = 1000.0
530535
strategy.init_processes_start(start_time)
531536

@@ -574,7 +579,9 @@ async def test_timing_with_rampup_edge_cases(self):
574579

575580
# Test with very short rampup_duration
576581
strategy = AsyncConstantStrategy(rate=100.0, rampup_duration=0.01)
577-
strategy.init_processes_timings(worker_count=1, max_concurrency=100)
582+
strategy.init_processes_timings(
583+
worker_count=1, max_concurrency=100, mp_context=get_context()
584+
)
578585
start_time = 2000.0
579586
strategy.init_processes_start(start_time)
580587

@@ -584,7 +591,9 @@ async def test_timing_with_rampup_edge_cases(self):
584591

585592
# Test with very long rampup_duration
586593
strategy2 = AsyncConstantStrategy(rate=1.0, rampup_duration=100.0)
587-
strategy2.init_processes_timings(worker_count=1, max_concurrency=100)
594+
strategy2.init_processes_timings(
595+
worker_count=1, max_concurrency=100, mp_context=get_context()
596+
)
588597
start_time2 = 3000.0
589598
strategy2.init_processes_start(start_time2)
590599

@@ -613,7 +622,9 @@ async def test_timing_rampup_transition(self):
613622
rate = 10.0
614623
rampup_duration = 2.0
615624
strategy = AsyncConstantStrategy(rate=rate, rampup_duration=rampup_duration)
616-
strategy.init_processes_timings(worker_count=1, max_concurrency=100)
625+
strategy.init_processes_timings(
626+
worker_count=1, max_concurrency=100, mp_context=get_context()
627+
)
617628
start_time = 5000.0
618629
strategy.init_processes_start(start_time)
619630

0 commit comments

Comments
 (0)