Skip to content

Commit ad1d814

Browse files
committed
perf(upgrade): reduce latency for atomic layer chunks
1 parent 01a3646 commit ad1d814

File tree

1 file changed

+59
-25
lines changed

pychunkedgraph/ingest/upgrade/atomic_layer.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# pylint: disable=invalid-name, missing-docstring, c-extension-no-member
22

33
from collections import defaultdict
4-
from concurrent.futures import ThreadPoolExecutor, as_completed
54
from datetime import datetime, timedelta, timezone
5+
import multiprocessing as mp
66
import logging, math, time
77
from copy import copy
88

@@ -19,13 +19,28 @@
1919
CHILDREN = {}
2020

2121

22+
def _get_parents_at_timestamp(nodes, parents_ts_map, time_stamp):
23+
"""
24+
Search for the first parent with ts <= `time_stamp`.
25+
`parents_ts_map[node]` is a map of ts:parent with sorted timestamps (desc).
26+
"""
27+
parents = []
28+
for node in nodes:
29+
for ts, parent in parents_ts_map[node].items():
30+
if time_stamp >= ts:
31+
parents.append(parent)
32+
break
33+
return parents
34+
35+
2236
def update_cross_edges(
2337
cg: ChunkedGraph,
2438
node,
2539
cx_edges_d: dict,
2640
node_ts,
2741
node_end_ts,
28-
timestamps_d: defaultdict[int, set],
42+
timestamps_map: defaultdict[int, set],
43+
parents_ts_map: defaultdict[int, dict],
2944
) -> list:
3045
"""
3146
Helper function to update a single L2 ID.
@@ -35,9 +50,9 @@ def update_cross_edges(
3550
edges = np.concatenate(list(cx_edges_d.values()))
3651
partners = np.unique(edges[:, 1])
3752

38-
timestamps = copy(timestamps_d[node])
53+
timestamps = copy(timestamps_map[node])
3954
for partner in partners:
40-
timestamps.update(timestamps_d[partner])
55+
timestamps.update(timestamps_map[partner])
4156

4257
node_end_ts = node_end_ts or datetime.now(timezone.utc)
4358
for ts in sorted(timestamps):
@@ -47,7 +62,7 @@ def update_cross_edges(
4762
break
4863

4964
val_dict = {}
50-
parents = cg.get_parents(partners, time_stamp=ts)
65+
parents = _get_parents_at_timestamp(partners, parents_ts_map, ts)
5166
edge_parents_d = dict(zip(partners, parents))
5267
for layer, layer_edges in cx_edges_d.items():
5368
layer_edges = fastremap.remap(
@@ -75,28 +90,42 @@ def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list:
7590
all_partners = np.unique(np.concatenate(all_cx_edges)[:, 1])
7691
timestamps_d = get_parent_timestamps(cg, np.concatenate([nodes, all_partners]))
7792

93+
parents_ts_map = defaultdict(dict)
94+
all_parents = cg.get_parents(all_partners, current=False)
95+
for partner, parents in zip(all_partners, all_parents):
96+
for parent, ts in parents:
97+
parents_ts_map[partner][ts] = parent
98+
7899
rows = []
100+
skipped = []
79101
for node, node_ts, end_ts in zip(nodes, nodes_ts, end_timestamps):
80102
is_stale = end_ts is not None
81103
_cx_edges_d = cx_edges_d.get(node, {})
82104
if not _cx_edges_d:
105+
skipped.append(node)
83106
continue
84107
if is_stale:
85108
end_ts -= timedelta(milliseconds=1)
86109

87-
_rows = update_cross_edges(cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d)
110+
_rows = update_cross_edges(
111+
cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d, parents_ts_map
112+
)
88113
if is_stale:
89114
row_id = serializers.serialize_uint64(node)
90115
val_dict = {Hierarchy.StaleTimeStamp: 0}
91116
_rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=end_ts))
92117
rows.extend(_rows)
93-
118+
parents = cg.get_roots(skipped)
119+
layers = cg.get_chunk_layers(parents)
120+
assert np.all(layers == cg.meta.layer_count)
94121
return rows
95122

96123

97124
def _update_nodes_helper(args):
    """
    Multiprocessing worker entry point.

    Unpacks one task tuple, reconstructs a ChunkedGraph from its
    serialized info (the graph object itself is built fresh in each
    worker process), updates the given nodes, and writes the resulting
    rows from within the worker. Returns None.
    """
    cg_info, nodes, nodes_ts = args
    graph = ChunkedGraph(**cg_info)
    graph.client.write(update_nodes(graph, nodes, nodes_ts))
100129

101130

102131
def update_chunk(cg: ChunkedGraph, chunk_coords: list[int], debug: bool = False):
@@ -134,21 +163,26 @@ def update_chunk(cg: ChunkedGraph, chunk_coords: list[int], debug: bool = False)
134163

135164
if debug:
136165
rows = update_nodes(cg, nodes, nodes_ts)
137-
else:
138-
task_size = int(math.ceil(len(nodes) / 16))
139-
chunked_nodes = chunked(nodes, task_size)
140-
chunked_nodes_ts = chunked(nodes_ts, task_size)
141-
tasks = []
142-
for chunk, ts_chunk in zip(chunked_nodes, chunked_nodes_ts):
143-
args = (cg, chunk, ts_chunk)
144-
tasks.append(args)
145-
logging.info(f"task size {task_size}, count {len(tasks)}.")
146-
147-
rows = []
148-
with ThreadPoolExecutor(max_workers=8) as executor:
149-
futures = [executor.submit(_update_nodes_helper, task) for task in tasks]
150-
for future in tqdm(as_completed(futures), total=len(futures)):
151-
rows.extend(future.result())
166+
cg.client.write(rows)
167+
return
152168

153-
cg.client.write(rows)
169+
task_size = int(math.ceil(len(nodes) / mp.cpu_count() / 2))
170+
chunked_nodes = chunked(nodes, task_size)
171+
chunked_nodes_ts = chunked(nodes_ts, task_size)
172+
cg_info = cg.get_serialized_info()
173+
174+
tasks = []
175+
for chunk, ts_chunk in zip(chunked_nodes, chunked_nodes_ts):
176+
args = (cg_info, chunk, ts_chunk)
177+
tasks.append(args)
178+
179+
processes = min(mp.cpu_count() * 2, len(tasks))
180+
logging.info(f"processing {len(nodes)} nodes with {processes} workers.")
181+
with mp.Pool(processes) as pool:
182+
_ = list(
183+
tqdm(
184+
pool.imap_unordered(_update_nodes_helper, tasks),
185+
total=len(tasks),
186+
)
187+
)
154188
logging.info(f"total elaspsed time: {time.time() - start}")

0 commit comments

Comments
 (0)