
Commit 8c2e1e5

Merge pull request #304 from punch-mission/speedster
Let's do things quickly, shall we?
2 parents: f074d63 + 3fefe1a

4 files changed: 158 additions & 6 deletions


changelog/304.feature.rst (1 addition, 0 deletions)

```diff
@@ -0,0 +1 @@
+Adds `speedster`, a `launcher` alternative that can achieve ~5x throughput for short-running flows
```
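Where the speedup comes from, as far as the code below shows: rather than asking Prefect to orchestrate each run, speedster fetches planned flows straight from the database, looks up each flow's undecorated function via its `.fn` attribute, and calls it directly in a `multiprocessing.Pool` with run logging disabled, presumably avoiding the per-run orchestration overhead that dominates short-running flows.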

punchpipe/control/launcher.py (1 addition, 1 deletion)

```diff
@@ -237,7 +237,7 @@ def load_flow_data(pipeline_config):
     flow_enabled = dict()
     flow_batch_size = dict()
     for flow_type in pipeline_config["flows"]:
-        flow_enabled[flow_type] = pipeline_config["flows"][flow_type].get("enabled", True)
+        flow_enabled[flow_type] = pipeline_config["flows"][flow_type].get("enabled", True) is True
         flow_weights[flow_type] = pipeline_config["flows"][flow_type].get("launch_weight", 1)
         flow_batch_size[flow_type] = pipeline_config["flows"][flow_type].get("batch_size", 1)
     return flow_weights, flow_enabled, flow_batch_size
```
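The `is True` guard means a flow counts as enabled for the regular launcher only when its `enabled` setting is the boolean `True`; any other value, such as the string `speedy` that `load_enabled_flows` looks for in the new module below, now leaves the flow to speedster. A minimal config sketch, with a hypothetical flow name:

```yaml
flows:
  level1_example:    # hypothetical flow name
    enabled: speedy  # skipped by the regular launcher, picked up by speedster
```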

punchpipe/control/processor.py (9 additions, 5 deletions)

```diff
@@ -4,7 +4,7 @@
 
 from dateutil.parser import parse as parse_datetime_str
 from prefect import get_run_logger, tags
-from prefect.context import get_run_context
+from prefect.context import MissingContextError, get_run_context
 
 from punchpipe.control.db import File, Flow
 from punchpipe.control.util import (
@@ -16,7 +16,7 @@
 
 
 def generic_process_flow_logic(flow_id: int | list[int], core_flow_to_launch, pipeline_config_path: str, session=None,
-                               call_data_processor=None):
+                               call_data_processor=None, ):
     if session is None:
         session = get_database_session()
     if isinstance(flow_id, int):
@@ -44,9 +44,13 @@ def generic_process_flow_logic(flow_id: int | list[int], core_flow_to_launch, pi
                     f"{flow_db_entry.creation_time} and launched at {flow_db_entry.launch_time}.")
 
         # update the processing flow name with the flow run name from Prefect
-        flow_run_context = get_run_context()
-        flow_db_entry.flow_run_name = flow_run_context.flow_run.name
-        flow_db_entry.flow_run_id = flow_run_context.flow_run.id
+        try:
+            flow_run_context = get_run_context()
+            flow_db_entry.flow_run_name = flow_run_context.flow_run.name
+            flow_db_entry.flow_run_id = flow_run_context.flow_run.id
+        except MissingContextError:
+            # We're not in a flow context---probably we're running under speedster
+            pass
         flow_db_entry.state = "running"
         flow_db_entry.start_time = datetime.now()
 
```
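The new `except MissingContextError` branch matters because speedster (below) invokes each Prefect flow's undecorated function through its `.fn` attribute, outside any Prefect flow run. A minimal sketch of the two call paths, using a hypothetical flow:

```python
from prefect import flow
from prefect.context import MissingContextError, get_run_context


@flow
def example_process_flow():  # hypothetical flow, for illustration only
    try:
        ctx = get_run_context()
        print(f"flow run: {ctx.flow_run.name}")  # available under Prefect orchestration
    except MissingContextError:
        print("no run context; running as a plain function")


example_process_flow()     # normal path: Prefect creates a flow run and a context
example_process_flow.fn()  # speedster-style path: get_run_context() raises MissingContextError
```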

punchpipe/speedster.py (new file, 147 additions)

```python
import os
import time
import argparse
import warnings
import traceback
import multiprocessing
from datetime import datetime
from collections import defaultdict

import yaml
from prefect.logging import disable_run_logger
from sqlalchemy import update
from tqdm.auto import tqdm
from yaml.loader import FullLoader

from punchpipe.cli import find_flow
from punchpipe.control.db import Flow
from punchpipe.control.util import get_database_session


def load_pipeline_configuration(path: str = None) -> dict:
    with open(path) as f:
        config = yaml.load(f, Loader=FullLoader)
    # TODO: add validation
    return config


def load_enabled_flows(pipeline_config):
    enabled_flows = []
    for flow_type in pipeline_config["flows"]:
        if pipeline_config["flows"][flow_type].get("enabled", True) == "speedy":
            enabled_flows.append(flow_type)
    return enabled_flows


def gather_planned_flows(session, enabled_flows, max_n=None):
    flows = (session.query(Flow)
             .where(Flow.state == "planned")
             .where(Flow.flow_type.in_(enabled_flows))
             .order_by(Flow.is_backprocessing.asc(), Flow.priority.desc(), Flow.creation_time.desc())
             .limit(max_n).all())
    count_per_type = defaultdict(lambda: 0)
    flow_ids = []
    types = []
    for flow in flows:
        types.append(flow.flow_type)
        count_per_type[flow.flow_type] += 1
        flow_ids.append(flow.flow_id)

    return flow_ids, types, count_per_type


def worker_init(config_path):
    global session, flow_type_to_runner, path_to_config
    with disable_run_logger(), warnings.catch_warnings():
        # Otherwise warning spam will hide any progress messages
        warnings.simplefilter('ignore')
        session = get_database_session()
    flow_type_to_runner = dict()
    path_to_config = config_path


def worker_run_flow(inputs):
    flow_id, flow_type, delay = inputs
    global flow_type_to_runner, session, path_to_config
    if flow_type not in flow_type_to_runner:
        runner = find_flow(flow_type + "_process_flow").fn
        flow_type_to_runner[flow_type] = runner
    else:
        runner = flow_type_to_runner[flow_type]

    session.execute(update(Flow).where(Flow.flow_id == flow_id).values(
        state='launched', flow_run_name='speedster', launch_time=datetime.now()))

    with disable_run_logger(), warnings.catch_warnings():
        # Otherwise warning spam will hide any progress messages
        warnings.simplefilter('ignore')
        try:
            time.sleep(delay)
            runner(flow_id, path_to_config, session)
        except KeyboardInterrupt:
            session.execute(
                update(Flow).where(Flow.flow_id == flow_id).values(state='revivable'))
            session.commit()
            print(f"Keyboard interrupt in flow {flow_id}; marked as revivable")
        except:  # noqa: E722
            print(f"Exception in flow {flow_id}")
            traceback.print_exc()


if __name__ == "__main__":
    multiprocessing.set_start_method('forkserver')
    parser = argparse.ArgumentParser(prog='speedster')
    parser.add_argument("config", type=str, help="Path to config.")
    parser.add_argument("-f", "--flows-per-batch", type=int, help="Max number of flows per batch.")
    parser.add_argument("-b", "--n-batches", type=int, help="Number of batches.")
    parser.add_argument("-w", "--n-workers", type=int, help="Number of workers")
    args = parser.parse_args()
    config_path = args.config

    pipeline_config = load_pipeline_configuration(config_path)
    enabled_flows = load_enabled_flows(pipeline_config)
    session = get_database_session(engine_kwargs=dict(isolation_level="READ COMMITTED"))

    if args.n_workers is None:
        args.n_workers = os.cpu_count()

    if args.flows_per_batch is None:
        n_cores = args.n_workers
    else:
        n_cores = min(args.n_workers, args.flows_per_batch)

    n_batches_run = 0
    with multiprocessing.Pool(n_cores, initializer=worker_init, initargs=(config_path,)) as p:
        print("Beginning fetch-run loop; press Ctrl-C to exit and allow time for cleanup")
        if args.flows_per_batch:
            print(f"Will cap at {args.flows_per_batch} flows per batch")
        if args.n_batches:
            print(f"Will stop after {args.n_batches} batches")
        while True:
            batch_of_flows, batch_types, count_per_type = gather_planned_flows(
                session, enabled_flows, args.flows_per_batch)

            if len(batch_of_flows) == 0:
                print("No pending flows found---will wait two minutes and try again")
                try:
                    time.sleep(60*2)
                except KeyboardInterrupt:
                    break
            else:
                print("Batch contents: ", end='')
                count_report = []
                for type in sorted(count_per_type.keys()):
                    print(f"{count_per_type[type]} of {type}, ", end='')
                print()
                with tqdm(total=len(batch_of_flows)) as pbar:
                    # Stagger the launches which may give less DB and IO contention
                    delays = [i / 6 if i < n_cores else 0 for i in range(len(batch_of_flows))]
                    try:
                        for _ in p.imap_unordered(worker_run_flow, zip(batch_of_flows, batch_types, delays)):
                            pbar.update()
                    except KeyboardInterrupt:
                        print("Halting")
                        break
                n_batches_run += 1
                if args.n_batches and n_batches_run >= args.n_batches:
                    break
```
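For reference, a hypothetical invocation (path and counts illustrative): `python punchpipe/speedster.py config.yaml -w 16 -f 64 -b 10`. Note the `delays` list: the first `n_cores` launches of each batch are staggered by 1/6 s apiece, which may reduce database and I/O contention when all workers start simultaneously; later launches get no delay, presumably because the pool's work is already spread out by then.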
