Skip to content

Commit ee36957

Browse files
committed
Refactor observations workflow based on pattern in temporalio/samples-python#113
1 parent 1035bd3 commit ee36957

File tree

6 files changed

+52
-40
lines changed

6 files changed

+52
-40
lines changed

oonipipeline/src/oonipipeline/temporal/activities/analysis.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from temporalio import workflow, activity
1010

11+
from .common import optimize_all_tables
12+
1113
with workflow.unsafe.imports_passed_through():
1214
import clickhouse_driver
1315

@@ -29,7 +31,6 @@
2931
get_prev_range,
3032
make_db_rows,
3133
maybe_delete_prev_range,
32-
optimize_all_tables,
3334
)
3435

3536
log = logging.getLogger("oonidata.processing")
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from dataclasses import dataclass
2+
from oonipipeline.db.connections import ClickhouseConnection
3+
from oonipipeline.db.create_tables import create_queries
4+
5+
from temporalio import activity
6+
7+
8+
@dataclass
9+
class ClickhouseParams:
10+
clickhouse_url: str
11+
12+
13+
@activity.defn
14+
def optimize_all_tables(params: ClickhouseParams):
15+
with ClickhouseConnection(params.clickhouse_url) as db:
16+
for _, table_name in create_queries:
17+
db.execute(f"OPTIMIZE TABLE {table_name}")

oonipipeline/src/oonipipeline/temporal/activities/observations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def make_observation_in_day(params: MakeObservationsParams) -> dict:
141141
),
142142
)
143143
)
144+
log.info(f"prev_ranges: {prev_ranges}")
144145

145146
t = PerfTimer()
146147
total_t = PerfTimer()
@@ -175,6 +176,7 @@ def make_observation_in_day(params: MakeObservationsParams) -> dict:
175176
if len(prev_ranges) > 0:
176177
with ClickhouseConnection(params.clickhouse, row_buffer_size=10_000) as db:
177178
for table_name, pr in prev_ranges:
179+
log.info(f"deleting previous range of {pr}")
178180
maybe_delete_prev_range(db=db, prev_range=pr)
179181

180182
return {"size": total_size, "measurement_count": total_msmt_count}

oonipipeline/src/oonipipeline/temporal/common.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
MeasurementListProgress,
2222
)
2323
from ..db.connections import ClickhouseConnection
24-
from ..db.create_tables import create_queries
2524

2625
log = logging.getLogger("oonidata.processing")
2726

@@ -165,12 +164,6 @@ def get_prev_range(
165164
return prev_range
166165

167166

168-
def optimize_all_tables(clickhouse):
169-
with ClickhouseConnection(clickhouse) as db:
170-
for _, table_name in create_queries:
171-
db.execute(f"OPTIMIZE TABLE {table_name}")
172-
173-
174167
def get_obs_count_by_cc(
175168
db: ClickhouseConnection,
176169
test_name: List[str],

oonipipeline/src/oonipipeline/temporal/workflows.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import List
2+
from typing import List, Optional
33

44
import logging
55
import multiprocessing
@@ -8,6 +8,7 @@
88
from datetime import datetime, timedelta, timezone
99

1010
from temporalio import workflow
11+
from temporalio.common import SearchAttributeKey
1112
from temporalio.worker import Worker, SharedStateManager
1213
from temporalio.client import (
1314
Client as TemporalClient,
@@ -18,6 +19,11 @@
1819
ScheduleState,
1920
)
2021

22+
from oonipipeline.temporal.activities.common import (
23+
optimize_all_tables,
24+
ClickhouseParams,
25+
)
26+
2127

2228
# Handle temporal sandbox violations related to calls to self.processName =
2329
# mp.current_process().name in logger, see:
@@ -36,7 +42,7 @@
3642
make_analysis_in_a_day,
3743
make_cc_batches,
3844
)
39-
from oonipipeline.temporal.common import get_obs_count_by_cc, optimize_all_tables
45+
from oonipipeline.temporal.common import get_obs_count_by_cc
4046
from oonipipeline.temporal.activities.observations import (
4147
MakeObservationsParams,
4248
make_observation_in_day,
@@ -67,6 +73,7 @@ def make_worker(client: TemporalClient, parallelism: int) -> Worker:
6773
make_observation_in_day,
6874
make_ground_truths_in_day,
6975
make_analysis_in_a_day,
76+
optimize_all_tables,
7077
],
7178
activity_executor=concurrent.futures.ProcessPoolExecutor(parallelism + 2),
7279
max_concurrent_activities=parallelism,
@@ -76,6 +83,14 @@ def make_worker(client: TemporalClient, parallelism: int) -> Worker:
7683
)
7784

7885

86+
def get_workflow_start_time() -> datetime:
87+
workflow_start_time = workflow.info().typed_search_attributes.get(
88+
SearchAttributeKey.for_datetime("TemporalScheduledStartTime")
89+
)
90+
assert workflow_start_time is not None, "TemporalScheduledStartTime not set"
91+
return workflow_start_time
92+
93+
7994
@dataclass
8095
class ObservationsWorkflowParams:
8196
probe_cc: List[str]
@@ -84,37 +99,27 @@ class ObservationsWorkflowParams:
8499
data_dir: str
85100
fast_fail: bool
86101
log_level: int = logging.INFO
102+
bucket_date: Optional[str] = None
87103

88104

89105
@workflow.defn
90106
class ObservationsWorkflow:
91107
@workflow.run
92108
async def run(self, params: ObservationsWorkflowParams) -> dict:
93-
# TODO(art): wrap this a coroutine call
94-
optimize_all_tables(params.clickhouse)
95-
96-
workflow_id = workflow.info().workflow_id
97-
98-
# TODO(art): this is quite sketchy. Waiting on temporal slack question:
99-
# https://temporalio.slack.com/archives/CTT84RS0P/p1714040382186429
100-
run_ts = datetime.strptime(
101-
"-".join(workflow_id.split("-")[-3:]),
102-
"%Y-%m-%dT%H:%M:%SZ",
109+
if params.bucket_date is None:
110+
params.bucket_date = (
111+
get_workflow_start_time() - timedelta(days=1)
112+
).strftime("%Y-%m-%d")
113+
114+
await workflow.execute_activity(
115+
optimize_all_tables,
116+
ClickhouseParams(clickhouse_url=params.clickhouse),
117+
start_to_close_timeout=timedelta(minutes=5),
103118
)
104-
bucket_date = (run_ts - timedelta(days=1)).strftime("%Y-%m-%d")
105-
106-
# read_time = workflow_info.start_time - timedelta(days=1)
107-
# log.info(f"workflow.info().start_time={workflow.info().start_time} ")
108-
# log.info(f"workflow.info().cron_schedule={workflow.info().cron_schedule} ")
109-
# log.info(f"workflow_info.workflow_id={workflow_info.workflow_id} ")
110-
# log.info(f"workflow_info.run_id={workflow_info.run_id} ")
111-
# log.info(f"workflow.now()={workflow.now()}")
112-
# print(workflow)
113-
# bucket_date = f"{read_time.year}-{read_time.month:02}-{read_time.day:02}"
114119

115120
t = PerfTimer()
116121
log.info(
117-
f"Starting observation making with probe_cc={params.probe_cc},test_name={params.test_name} bucket_date={bucket_date}"
122+
f"Starting observation making with probe_cc={params.probe_cc},test_name={params.test_name} bucket_date={params.bucket_date}"
118123
)
119124

120125
res = await workflow.execute_activity(
@@ -125,19 +130,17 @@ async def run(self, params: ObservationsWorkflowParams) -> dict:
125130
clickhouse=params.clickhouse,
126131
data_dir=params.data_dir,
127132
fast_fail=params.fast_fail,
128-
bucket_date=bucket_date,
133+
bucket_date=params.bucket_date,
129134
),
130135
start_to_close_timeout=timedelta(minutes=30),
131136
)
132137

133138
total_size = res["size"]
134139
total_measurement_count = res["measurement_count"]
135-
136-
# This needs to be adjusted once we get the the per entry concurrency working
137140
mb_per_sec = round(total_size / t.s / 10**6, 1)
138141
msmt_per_sec = round(total_measurement_count / t.s)
139142
log.info(
140-
f"finished processing {bucket_date} speed: {mb_per_sec}MB/s ({msmt_per_sec}msmt/s)"
143+
f"finished processing {params.bucket_date} speed: {mb_per_sec}MB/s ({msmt_per_sec}msmt/s)"
141144
)
142145

143146
# with ClickhouseConnection(params.clickhouse) as db:

oonipipeline/tests/test_cli.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,7 @@ def test_full_workflow(
5454
bucket_dict = dict(res)
5555
assert "2022-10-20" in bucket_dict, bucket_dict
5656
assert bucket_dict["2022-10-20"] == 200, bucket_dict
57-
58-
res = db.execute(
59-
"SELECT COUNT() FROM obs_web WHERE bucket_date = '2022-10-20' AND probe_cc = 'BA'"
60-
)
61-
obs_count = res[0][0] # type: ignore
57+
obs_count = bucket_dict["2022-10-20"]
6258

6359
result = cli_runner.invoke(
6460
cli,

0 commit comments

Comments
 (0)