Skip to content

Commit a64d08e

Browse files
committed
perf(migration): cache parent information with timestamps to reduce I/O latency
1 parent 6c90cc1 commit a64d08e

File tree

7 files changed

+60
-31
lines changed

7 files changed

+60
-31
lines changed

pychunkedgraph/graph/edges/__init__.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import tensorstore as ts
1313
import zstandard as zstd
1414
from graph_tool import Graph
15+
from cachetools import LRUCache
1516

1617
from pychunkedgraph.graph import types
1718
from pychunkedgraph.graph.chunks.utils import (
@@ -21,6 +22,7 @@
2122
from pychunkedgraph.graph.utils import basetypes
2223

2324
from ..utils import basetypes
25+
from ..utils.generic import get_parents_at_timestamp
2426

2527

2628
_edge_type_fileds = ("in_chunk", "between_chunk", "cross_chunk")
@@ -39,6 +41,7 @@
3941
]
4042
)
4143
ZSTD_EDGE_COMPRESSION = 17
44+
PARENTS_CACHE = LRUCache(64 * 1024)
4245

4346

4447
class Edges:
@@ -341,6 +344,23 @@ def _filter(node):
341344
chunks_map[node_b] = np.concatenate(chunks_map[node_b])
342345
return int(mlayer), _filter(node_a), _filter(node_b)
343346

347+
def _populate_parents_cache(children: np.ndarray):
348+
global PARENTS_CACHE
349+
350+
not_cached = []
351+
for child in children:
352+
try:
353+
# reset lru index, these will be needed soon
354+
_ = PARENTS_CACHE[child]
355+
except KeyError:
356+
not_cached.append(child)
357+
358+
all_parents = cg.get_parents(not_cached, current=False)
359+
for child, parents in zip(not_cached, all_parents):
360+
PARENTS_CACHE[child] = {}
361+
for parent, ts in parents:
362+
PARENTS_CACHE[child][ts] = parent
363+
344364
def _get_new_edge(edge, parent_ts, padding):
345365
"""
346366
Attempts to find new edge(s) for the stale `edge`.
@@ -371,7 +391,13 @@ def _get_new_edge(edge, parent_ts, padding):
371391
if np.any(mask):
372392
parents_a = _edges[mask][:, 0]
373393
children_b = cg.get_children(_edges[mask][:, 1], flatten=True)
374-
parents_b = np.unique(cg.get_parents(children_b, time_stamp=parent_ts))
394+
# parents_b = np.unique(cg.get_parents(children_b, time_stamp=parent_ts))
395+
_populate_parents_cache(children_b)
396+
_parents_b, missing = get_parents_at_timestamp(
397+
children_b, PARENTS_CACHE, time_stamp=parent_ts, unique=True
398+
)
399+
_parents_b_missing = np.unique(cg.get_parents(missing, time_stamp=parent_ts))
400+
parents_b = np.concatenate([_parents_b, _parents_b_missing])
375401
_cx_edges_d = cg.get_cross_chunk_edges(parents_b, time_stamp=parent_ts)
376402
parents_b = []
377403
for _node, _edges_d in _cx_edges_d.items():

pychunkedgraph/graph/utils/generic.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
TODO categorize properly
44
"""
55

6-
76
import datetime
87
from typing import Dict
98
from typing import Iterable
@@ -173,14 +172,30 @@ def mask_nodes_by_bounding_box(
173172
adapt_layers = layers - 2
174173
adapt_layers[adapt_layers < 0] = 0
175174
fanout = meta.graph_config.FANOUT
176-
bounding_box_layer = (
177-
bounding_box[None] / (fanout ** adapt_layers)[:, None, None]
178-
)
175+
bounding_box_layer = bounding_box[None] / (fanout**adapt_layers)[:, None, None]
179176
bound_check = np.array(
180177
[
181178
np.all(chunk_coordinates < bounding_box_layer[:, 1], axis=1),
182179
np.all(chunk_coordinates + 1 > bounding_box_layer[:, 0], axis=1),
183180
]
184181
).T
185182

186-
return np.all(bound_check, axis=1)
183+
return np.all(bound_check, axis=1)
184+
185+
def get_parents_at_timestamp(nodes, parents_ts_map, time_stamp, unique: bool = False):
    """
    Look up the parent of each node effective at `time_stamp` from a
    prefetched cache.

    `parents_ts_map[node]` is a map of ts -> parent with timestamps sorted in
    descending order, so the first entry with ts <= `time_stamp` is the parent
    effective at that time.

    Returns a tuple (parents, skipped_nodes):
      parents: parents found in the cache; deduplicated when `unique` is True
        (list order is then unspecified, since it comes from a set).
      skipped_nodes: nodes absent from `parents_ts_map`; callers must fetch
        their parents separately.

    NOTE(review): a node present in the map whose timestamps are all newer
    than `time_stamp` yields no parent and is NOT reported in
    `skipped_nodes` — confirm callers expect this silent drop.
    """
    skipped_nodes = []
    parents = set() if unique else []
    # bind the collector once instead of branching on `unique` per item
    record = parents.add if unique else parents.append
    for node in nodes:
        try:
            ts_parent_d = parents_ts_map[node]
        except KeyError:
            # keep the try-body minimal: only the cache lookup can
            # legitimately raise KeyError here
            skipped_nodes.append(node)
            continue
        for ts, parent in ts_parent_d.items():
            if time_stamp >= ts:
                record(parent)
                break
    return list(parents), skipped_nodes

pychunkedgraph/ingest/cluster.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Callable, Dict, Iterable, Tuple, Sequence
1111

1212
import numpy as np
13-
from rq import Queue as RQueue
13+
from rq import Queue as RQueue, Retry
1414

1515

1616
from .utils import chunk_id_str, get_chunks_not_done, randomize_grid_points
@@ -209,6 +209,7 @@ def _queue_tasks(imanager: IngestionManager, chunk_fn: Callable, coords: Iterabl
209209
timeout=environ.get("L2JOB_TIMEOUT", "3m"),
210210
result_ttl=0,
211211
job_id=chunk_id_str(2, chunk_coord),
212+
retry=Retry(2, 10),
212213
)
213214
)
214215
q.enqueue_many(job_datas)

pychunkedgraph/ingest/upgrade/atomic_layer.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,13 @@
1010
from pychunkedgraph.graph import ChunkedGraph, types
1111
from pychunkedgraph.graph.attributes import Connectivity, Hierarchy
1212
from pychunkedgraph.graph.utils import serializers
13+
from pychunkedgraph.graph.utils.generic import get_parents_at_timestamp
1314

1415
from .utils import fix_corrupt_nodes, get_end_timestamps, get_parent_timestamps
1516

1617
CHILDREN = {}
1718

1819

19-
def _get_parents_at_timestamp(nodes, parents_ts_map, time_stamp):
20-
"""
21-
Search for the first parent with ts <= `time_stamp`.
22-
`parents_ts_map[node]` is a map of ts:parent with sorted timestamps (desc).
23-
"""
24-
parents = []
25-
for node in nodes:
26-
for ts, parent in parents_ts_map[node].items():
27-
if time_stamp >= ts:
28-
parents.append(parent)
29-
break
30-
return parents
31-
32-
3320
def update_cross_edges(
3421
cg: ChunkedGraph,
3522
node,
@@ -59,7 +46,7 @@ def update_cross_edges(
5946
break
6047

6148
val_dict = {}
62-
parents = _get_parents_at_timestamp(partners, parents_ts_map, ts)
49+
parents, _ = get_parents_at_timestamp(partners, parents_ts_map, ts)
6350
edge_parents_d = dict(zip(partners, parents))
6451
for layer, layer_edges in cx_edges_d.items():
6552
layer_edges = fastremap.remap(

pychunkedgraph/ingest/upgrade/parent_layer.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
import numpy as np
1111
from tqdm import tqdm
1212

13-
from pychunkedgraph.graph import ChunkedGraph
13+
from pychunkedgraph.graph import ChunkedGraph, edges
1414
from pychunkedgraph.graph.attributes import Connectivity, Hierarchy
15-
from pychunkedgraph.graph.edges import get_latest_edges_wrapper
1615
from pychunkedgraph.graph.utils import serializers
1716
from pychunkedgraph.graph.types import empty_2d
1817
from pychunkedgraph.utils.general import chunked
@@ -105,7 +104,6 @@ def _populate_cx_edges_with_timestamps(
105104
row_id = serializers.serialize_uint64(node)
106105
val_dict = {Hierarchy.StaleTimeStamp: 0}
107106
rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=node_end_ts))
108-
109107
cg.client.write(rows)
110108

111109

@@ -119,7 +117,7 @@ def update_cross_edges(cg: ChunkedGraph, layer, node, node_ts) -> list:
119117
for ts, cx_edges_d in CX_EDGES[node].items():
120118
if ts < node_ts:
121119
continue
122-
cx_edges_d, edge_nodes = get_latest_edges_wrapper(cg, cx_edges_d, parent_ts=ts)
120+
cx_edges_d, edge_nodes = edges.get_latest_edges_wrapper(cg, cx_edges_d, parent_ts=ts)
123121
if edge_nodes.size == 0:
124122
continue
125123

@@ -204,12 +202,13 @@ def update_chunk(
204202

205203
if debug:
206204
rows = []
205+
logging.info(f"processing {len(nodes)} nodes with 1 worker.")
207206
for node, node_ts in zip(nodes, nodes_ts):
208207
rows.extend(update_cross_edges(cg, layer, node, node_ts))
209208
logging.info(f"total elaspsed time: {time.time() - start}")
210209
return
211210

212-
task_size = int(math.ceil(len(nodes) / mp.cpu_count() / 2))
211+
task_size = int(math.ceil(len(nodes) / mp.cpu_count()))
213212
chunked_nodes = chunked(nodes, task_size)
214213
chunked_nodes_ts = chunked(nodes_ts, task_size)
215214
cg_info = cg.get_serialized_info()
@@ -219,7 +218,7 @@ def update_chunk(
219218
args = (cg_info, layer, chunk, ts_chunk)
220219
tasks.append(args)
221220

222-
processes = min(mp.cpu_count() * 2, len(tasks))
221+
processes = min(mp.cpu_count(), len(tasks))
223222
logging.info(f"processing {len(nodes)} nodes with {processes} workers.")
224223
with mp.Pool(processes) as pool:
225224
_ = list(

pychunkedgraph/ingest/upgrade/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ def get_parent_timestamps(
108108

109109
def fix_corrupt_nodes(cg: ChunkedGraph, nodes: list, children_d: dict):
110110
"""
111-
Iteratively removes a node from parent column of its children.
112-
Then removes the node iteself, effectively erasing it.
111+
For each node: delete it from parent column of its children.
112+
Then deletes the node itself, effectively erasing it from hierarchy.
113113
"""
114114
table = cg.client._table
115115
batcher = table.mutations_batcher(flush_count=500)

pychunkedgraph/ingest/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import numpy as np
1111
import tensorstore as ts
12-
from rq import Queue, Worker
12+
from rq import Queue, Retry, Worker
1313
from rq.worker import WorkerStatus
1414

1515
from . import IngestConfig
@@ -199,6 +199,7 @@ def queue_layer_helper(parent_layer: int, imanager: IngestionManager, fn):
199199
result_ttl=0,
200200
job_id=chunk_id_str(parent_layer, chunk_coord),
201201
timeout=f"{timeout_scale * int(parent_layer * parent_layer)}m",
202+
retry=Retry(2, 10),
202203
)
203204
)
204205
q.enqueue_many(job_datas)

0 commit comments

Comments
 (0)