Skip to content

Commit cef44a0

Browse files
committed
added ramp down feature
Signed-off-by: Rishabh Singh <sngri@amazon.com>
1 parent 94f93db commit cef44a0

File tree

5 files changed

+366
-10
lines changed

5 files changed

+366
-10
lines changed

osbenchmark/worker_coordinator/worker_coordinator.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2836,11 +2836,17 @@ def schedule_for(task_allocation, parameter_source):
28362836

28372837
if requires_time_period_schedule(task, runner_for_op, params_for_op):
28382838
warmup_time_period = task.warmup_time_period if task.warmup_time_period else 0
2839+
ramp_down_time_period = task.ramp_down_time_period if task.ramp_down_time_period else 0
28392840
if client_index == 0:
28402841
logger.info("Creating time-period based schedule with [%s] distribution for [%s] with a warmup period of [%s] "
28412842
"seconds and a time period of [%s] seconds.", task.schedule, task.name,
28422843
str(warmup_time_period), str(task.time_period))
2843-
loop_control = TimePeriodBased(warmup_time_period, task.time_period)
2844+
loop_control = TimePeriodBased(warmup_time_period, task.time_period, ramp_down_time_period,
2845+
client_index, task.clients)
2846+
# Announce ramp-down once (from client 0 only) when a ramp-down period is configured
2847+
if ramp_down_time_period > 0 and client_index == 0:
2848+
logger.info("Ramp-down enabled: clients will stop in reverse order over [%s] seconds",
2849+
str(ramp_down_time_period))
28442850
else:
28452851
warmup_iterations = task.warmup_iterations if task.warmup_iterations else 0
28462852
if task.iterations:
@@ -2950,13 +2956,32 @@ async def __call__(self):
29502956

29512957

29522958
class TimePeriodBased:
2953-
def __init__(self, warmup_time_period, time_period):
2959+
def __init__(self, warmup_time_period, time_period, ramp_down_time_period=None,
2960+
client_index=None, total_clients=None):
29542961
self._warmup_time_period = warmup_time_period
29552962
self._time_period = time_period
2963+
self._ramp_down_time_period = ramp_down_time_period or 0
2964+
self._client_index = client_index
2965+
self._total_clients = total_clients
2966+
self.logger = logging.getLogger(__name__)
2967+
29562968
if warmup_time_period is not None and time_period is not None:
2957-
self._duration = self._warmup_time_period + self._time_period
2969+
self._base_duration = self._warmup_time_period + self._time_period
2970+
2971+
# Calculate how early this client should stop during ramp-down
2972+
# Stops are staggered by client index: client 0 stops earliest, client (N-1) runs the full base duration
2973+
if self._ramp_down_time_period > 0 and client_index is not None and total_clients is not None:
2974+
reverse_index = (total_clients - 1) - client_index
2975+
client_early_stop = self._ramp_down_time_period * (reverse_index / total_clients)
2976+
self._duration = self._base_duration - client_early_stop
2977+
self.logger.info("Client [%d/%d] will run for %.2f seconds (base: %.2f, early stop: %.2f due to ramp-down)",
2978+
client_index, total_clients, self._duration, self._base_duration, client_early_stop)
2979+
else:
2980+
self._duration = self._base_duration
29582981
else:
29592982
self._duration = None
2983+
self._base_duration = None
2984+
29602985
self._start = None
29612986
self._now = None
29622987

osbenchmark/workload/loader.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1816,14 +1816,16 @@ def parse_parallel(self, ops_spec, ops, test_procedure_name):
18161816
default_warmup_time_period = self._r(ops_spec, "warmup-time-period", error_ctx="parallel", mandatory=False)
18171817
default_time_period = self._r(ops_spec, "time-period", error_ctx="parallel", mandatory=False)
18181818
default_ramp_up_time_period = self._r(ops_spec, "ramp-up-time-period", error_ctx="parallel", mandatory=False)
1819+
default_ramp_down_time_period = self._r(ops_spec, "ramp-down-time-period", error_ctx="parallel", mandatory=False)
18191820
clients = self._r(ops_spec, "clients", error_ctx="parallel", mandatory=False)
18201821
completed_by = self._r(ops_spec, "completed-by", error_ctx="parallel", mandatory=False)
18211822

18221823
# now descend into each operation
18231824
tasks = []
18241825
for task in self._r(ops_spec, "tasks", error_ctx="parallel"):
18251826
tasks.append(self.parse_task(task, ops, test_procedure_name, default_warmup_iterations, default_iterations,
1826-
default_warmup_time_period, default_time_period, default_ramp_up_time_period, completed_by))
1827+
default_warmup_time_period, default_time_period, default_ramp_up_time_period,
1828+
default_ramp_down_time_period, completed_by))
18271829

18281830
for task in tasks:
18291831
if task.ramp_up_time_period != default_ramp_up_time_period:
@@ -1833,6 +1835,13 @@ def parse_parallel(self, ops_spec, ops, test_procedure_name):
18331835
else:
18341836
self._error(f"task '{task.name}' specifies a different ramp-up-time-period than its enclosing "
18351837
f"'parallel' element in test-procedure '{test_procedure_name}'.")
1838+
if task.ramp_down_time_period != default_ramp_down_time_period:
1839+
if default_ramp_down_time_period is None:
1840+
self._error(f"task '{task.name}' in 'parallel' element of test-procedure '{test_procedure_name}' specifies "
1841+
f"a ramp-down-time-period but it is only allowed on the 'parallel' element.")
1842+
else:
1843+
self._error(f"task '{task.name}' specifies a different ramp-down-time-period than its enclosing "
1844+
f"'parallel' element in test-procedure '{test_procedure_name}'.")
18361845
if completed_by:
18371846
completion_task = None
18381847
for task in tasks:
@@ -1848,7 +1857,7 @@ def parse_parallel(self, ops_spec, ops, test_procedure_name):
18481857
return workload.Parallel(tasks, clients)
18491858

18501859
def parse_task(self, task_spec, ops, test_procedure_name, default_warmup_iterations=None, default_iterations=None,
1851-
default_warmup_time_period=None, default_time_period=None, default_ramp_up_time_period=None,
1860+
default_warmup_time_period=None, default_time_period=None, default_ramp_up_time_period=None, default_ramp_down_time_period=None,
18521861
completed_by_name=None):
18531862

18541863
op_spec = task_spec["operation"]
@@ -1874,6 +1883,8 @@ def parse_task(self, task_spec, ops, test_procedure_name, default_warmup_iterati
18741883
default_value=default_time_period),
18751884
ramp_up_time_period=self._r(task_spec, "ramp-up-time-period", error_ctx=op.name,
18761885
mandatory=False, default_value=default_ramp_up_time_period),
1886+
ramp_down_time_period=self._r(task_spec, "ramp-down-time-period", error_ctx=op.name,
1887+
mandatory=False, default_value=default_ramp_down_time_period),
18771888
clients=self._r(task_spec, "clients", error_ctx=op.name, mandatory=False, default_value=1),
18781889
completes_parent=(task_name == completed_by_name),
18791890
schedule=schedule,
@@ -1901,6 +1912,19 @@ def parse_task(self, task_spec, ops, test_procedure_name, default_warmup_iterati
19011912
self._error(f"The warmup-time-period of operation '{op.name}' in test_procedure '{test_procedure_name}' is "
19021913
f"{task.warmup_time_period} seconds but must be greater than or equal to the "
19031914
f"ramp-up-time-period of {task.ramp_up_time_period} seconds.")
1915+
if task.ramp_down_time_period is not None:
1916+
if task.time_period is None:
1917+
self._error(f"Operation '{op.name}' in test_procedure '{test_procedure_name}' defines a ramp-down time period of "
1918+
f"{task.ramp_down_time_period} seconds but no time-period.")
1919+
elif task.time_period < task.ramp_down_time_period:
1920+
self._error(f"The time-period of operation '{op.name}' in test_procedure '{test_procedure_name}' is "
1921+
f"{task.time_period} seconds but must be greater than or equal to the "
1922+
f"ramp-down-time-period of {task.ramp_down_time_period} seconds.")
1923+
1924+
if (task.warmup_iterations is not None or task.iterations is not None) and task.ramp_down_time_period is not None:
1925+
self._error(f"Operation '{op.name}' in test_procedure '{test_procedure_name}' defines a ramp-down time period of "
1926+
f"{task.ramp_down_time_period} seconds as well as {task.warmup_iterations} warmup iterations and "
1927+
f"{task.iterations} iterations but mixing time periods and iterations is not allowed.")
19041928

19051929
return task
19061930

osbenchmark/workload/workload.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -937,7 +937,8 @@ class Task:
937937
IGNORE_RESPONSE_ERROR_LEVEL_WHITELIST = ["non-fatal"]
938938

939939
def __init__(self, name, operation, tags=None, meta_data=None, warmup_iterations=None, iterations=None,
940-
warmup_time_period=None, time_period=None, ramp_up_time_period=None, clients=1, completes_parent=False,
940+
warmup_time_period=None, time_period=None, ramp_up_time_period=None, ramp_down_time_period=None,
941+
clients=1, completes_parent=False,
941942
schedule=None, params=None):
942943
self.name = name
943944
self.operation = operation
@@ -953,6 +954,7 @@ def __init__(self, name, operation, tags=None, meta_data=None, warmup_iterations
953954
self.warmup_time_period = warmup_time_period
954955
self.time_period = time_period
955956
self.ramp_up_time_period = ramp_up_time_period
957+
self.ramp_down_time_period = ramp_down_time_period
956958
self.clients = clients
957959
self.completes_parent = completes_parent
958960
self.schedule = schedule
@@ -1034,16 +1036,17 @@ def __hash__(self):
10341036
# Note that we do not include `params` in __hash__ and __eq__ (the other attributes suffice to uniquely define a task)
10351037
return hash(self.name) ^ hash(self.operation) ^ hash(self.warmup_iterations) ^ hash(self.iterations) ^ \
10361038
hash(self.warmup_time_period) ^ hash(self.time_period) ^ hash(self.ramp_up_time_period) ^ \
1037-
hash(self.clients) ^ hash(self.schedule) ^ hash(self.completes_parent)
1039+
hash(self.ramp_down_time_period) ^ hash(self.clients) ^ hash(self.schedule) ^ hash(self.completes_parent)
10381040

10391041
def __eq__(self, other):
10401042
# Note that we do not include `params` in __hash__ and __eq__ (the other attributes suffice to uniquely define a task)
10411043
return isinstance(other, type(self)) and (self.name, self.operation, self.warmup_iterations, self.iterations,
10421044
self.warmup_time_period, self.time_period, self.ramp_up_time_period,
1043-
self.clients, self.schedule,self.completes_parent) == (other.name,
1044-
other.operation, other.warmup_iterations,
1045+
self.ramp_down_time_period, self.clients, self.schedule,
1046+
self.completes_parent) == (other.name, other.operation, other.warmup_iterations,
10451047
other.iterations, other.warmup_time_period, other.time_period,
1046-
self.ramp_up_time_period, other.clients, other.schedule,
1048+
other.ramp_up_time_period, other.ramp_down_time_period,
1049+
other.clients, other.schedule,
10471050
other.completes_parent)
10481051

10491052
def __iter__(self):

tests/worker_coordinator/worker_coordinator_test.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2313,3 +2313,166 @@ def test_check_cpu_usage_drops_error_when_queue_full(self):
23132313
# Should not raise
23142314
self.actor._check_cpu_usage() # pylint: disable=protected-access
23152315
assert full_queue.qsize() == 1 # Still full; error dropped
2316+
2317+
class TimePeriodBasedTests(TestCase):
    """Checks the run durations that ``worker_coordinator.TimePeriodBased`` derives per client."""
    # pylint: disable=protected-access

    def test_time_period_based_without_ramp_down(self):
        # With no ramp-down configured, nothing is subtracted from warmup + time period.
        schedule = worker_coordinator.TimePeriodBased(
            warmup_time_period=10,
            time_period=100,
            ramp_down_time_period=None,
            client_index=0,
            total_clients=4,
        )

        self.assertEqual(110, schedule._duration)
        self.assertEqual(110, schedule._base_duration)
        self.assertEqual(10, schedule._warmup_time_period)
        self.assertEqual(100, schedule._time_period)
        self.assertFalse(schedule.infinite)

    def test_time_period_based_with_ramp_down_client_0(self):
        # Client 0 of 4 gets the largest early stop: 20 * (3 / 4) = 15 seconds,
        # so it runs for 110 - 15 = 95 seconds.
        schedule = worker_coordinator.TimePeriodBased(
            warmup_time_period=10,
            time_period=100,
            ramp_down_time_period=20,
            client_index=0,
            total_clients=4,
        )

        self.assertEqual(110, schedule._base_duration)
        self.assertEqual(95, schedule._duration)
        self.assertEqual(20, schedule._ramp_down_time_period)

    def test_time_period_based_with_ramp_down_client_3(self):
        # Client 3 of 4 is the last index: early stop is 20 * (0 / 4) = 0,
        # so it keeps the full 110 seconds.
        schedule = worker_coordinator.TimePeriodBased(
            warmup_time_period=10,
            time_period=100,
            ramp_down_time_period=20,
            client_index=3,
            total_clients=4,
        )

        self.assertEqual(110, schedule._base_duration)
        self.assertEqual(110, schedule._duration)

    def test_time_period_based_with_ramp_down_all_clients(self):
        # Durations must grow with the client index in equal steps.
        warmup = 10
        time_period = 100
        ramp_down = 20
        total_clients = 4

        durations = [
            worker_coordinator.TimePeriodBased(
                warmup_time_period=warmup,
                time_period=time_period,
                ramp_down_time_period=ramp_down,
                client_index=index,
                total_clients=total_clients,
            )._duration
            for index in range(total_clients)
        ]

        # Stops are 5 seconds apart: 95, 100, 105, 110.
        self.assertEqual([95, 100, 105, 110], durations)

        # Every neighbouring pair is separated by ramp_down / total_clients seconds.
        step = ramp_down / total_clients
        for earlier, later in zip(durations, durations[1:]):
            self.assertAlmostEqual(step, later - earlier, places=2)
2393+
2394+
class TaskRampDownTests(TestCase):
    """Covers how ``workload.Task`` stores, compares and hashes ``ramp_down_time_period``."""

    @staticmethod
    def _bulk_operation():
        # One operation instance is shared by all tasks within a single test.
        return workload.Operation("test-op", workload.OperationType.Bulk.to_hyphenated_string(), {})

    @staticmethod
    def _task_with_ramp_down(operation, ramp_down):
        # Tasks built here differ only in their ramp-down period.
        return workload.Task(
            name="test-task",
            operation=operation,
            time_period=100,
            ramp_down_time_period=ramp_down,
            clients=4,
        )

    def test_task_with_ramp_down_time_period(self):
        # The constructor accepts ramp_down_time_period and stores it next to the other periods.
        task = workload.Task(
            name="test-task",
            operation=self._bulk_operation(),
            warmup_time_period=10,
            time_period=100,
            ramp_up_time_period=20,
            ramp_down_time_period=30,
            clients=4,
        )

        self.assertEqual(10, task.warmup_time_period)
        self.assertEqual(100, task.time_period)
        self.assertEqual(20, task.ramp_up_time_period)
        self.assertEqual(30, task.ramp_down_time_period)
        self.assertEqual(4, task.clients)

    def test_task_without_ramp_down_defaults_to_none(self):
        # Leaving the argument out must yield None.
        task = workload.Task(
            name="test-task",
            operation=self._bulk_operation(),
            time_period=100,
            clients=4,
        )

        self.assertIsNone(task.ramp_down_time_period)

    def test_task_equality_with_ramp_down(self):
        # Equality takes ramp_down_time_period into account.
        operation = self._bulk_operation()

        first = self._task_with_ramp_down(operation, 20)
        different = self._task_with_ramp_down(operation, 30)
        same_as_first = self._task_with_ramp_down(operation, 20)

        self.assertNotEqual(first, different)
        self.assertEqual(first, same_as_first)

    def test_task_hash_includes_ramp_down(self):
        # Tasks that differ only in ramp_down_time_period hash differently.
        operation = self._bulk_operation()

        first = self._task_with_ramp_down(operation, 20)
        different = self._task_with_ramp_down(operation, 30)

        self.assertNotEqual(hash(first), hash(different))

0 commit comments

Comments
 (0)