11# pylint: disable=invalid-name, missing-docstring, c-extension-no-member
22
33from collections import defaultdict
4- from concurrent .futures import ThreadPoolExecutor , as_completed
54from datetime import datetime , timedelta , timezone
5+ import multiprocessing as mp
66import logging , math , time
77from copy import copy
88
# NOTE(review): module-level dict, not referenced in the visible portion of this
# file — presumably a node -> children cache populated elsewhere; confirm usage.
CHILDREN = {}
2020
2121
22+ def _get_parents_at_timestamp (nodes , parents_ts_map , time_stamp ):
23+ """
24+ Search for the first parent with ts <= `time_stamp`.
25+ `parents_ts_map[node]` is a map of ts:parent with sorted timestamps (desc).
26+ """
27+ parents = []
28+ for node in nodes :
29+ for ts , parent in parents_ts_map [node ].items ():
30+ if time_stamp >= ts :
31+ parents .append (parent )
32+ break
33+ return parents
34+
35+
2236def update_cross_edges (
2337 cg : ChunkedGraph ,
2438 node ,
2539 cx_edges_d : dict ,
2640 node_ts ,
2741 node_end_ts ,
28- timestamps_d : defaultdict [int , set ],
42+ timestamps_map : defaultdict [int , set ],
43+ parents_ts_map : defaultdict [int , dict ],
2944) -> list :
3045 """
3146 Helper function to update a single L2 ID.
@@ -35,9 +50,9 @@ def update_cross_edges(
3550 edges = np .concatenate (list (cx_edges_d .values ()))
3651 partners = np .unique (edges [:, 1 ])
3752
38- timestamps = copy (timestamps_d [node ])
53+ timestamps = copy (timestamps_map [node ])
3954 for partner in partners :
40- timestamps .update (timestamps_d [partner ])
55+ timestamps .update (timestamps_map [partner ])
4156
4257 node_end_ts = node_end_ts or datetime .now (timezone .utc )
4358 for ts in sorted (timestamps ):
@@ -47,7 +62,7 @@ def update_cross_edges(
4762 break
4863
4964 val_dict = {}
50- parents = cg . get_parents (partners , time_stamp = ts )
65+ parents = _get_parents_at_timestamp (partners , parents_ts_map , ts )
5166 edge_parents_d = dict (zip (partners , parents ))
5267 for layer , layer_edges in cx_edges_d .items ():
5368 layer_edges = fastremap .remap (
@@ -75,28 +90,42 @@ def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list:
7590 all_partners = np .unique (np .concatenate (all_cx_edges )[:, 1 ])
7691 timestamps_d = get_parent_timestamps (cg , np .concatenate ([nodes , all_partners ]))
7792
93+ parents_ts_map = defaultdict (dict )
94+ all_parents = cg .get_parents (all_partners , current = False )
95+ for partner , parents in zip (all_partners , all_parents ):
96+ for parent , ts in parents :
97+ parents_ts_map [partner ][ts ] = parent
98+
7899 rows = []
100+ skipped = []
79101 for node , node_ts , end_ts in zip (nodes , nodes_ts , end_timestamps ):
80102 is_stale = end_ts is not None
81103 _cx_edges_d = cx_edges_d .get (node , {})
82104 if not _cx_edges_d :
105+ skipped .append (node )
83106 continue
84107 if is_stale :
85108 end_ts -= timedelta (milliseconds = 1 )
86109
87- _rows = update_cross_edges (cg , node , _cx_edges_d , node_ts , end_ts , timestamps_d )
110+ _rows = update_cross_edges (
111+ cg , node , _cx_edges_d , node_ts , end_ts , timestamps_d , parents_ts_map
112+ )
88113 if is_stale :
89114 row_id = serializers .serialize_uint64 (node )
90115 val_dict = {Hierarchy .StaleTimeStamp : 0 }
91116 _rows .append (cg .client .mutate_row (row_id , val_dict , time_stamp = end_ts ))
92117 rows .extend (_rows )
93-
118+ parents = cg .get_roots (skipped )
119+ layers = cg .get_chunk_layers (parents )
120+ assert np .all (layers == cg .meta .layer_count )
94121 return rows
95122
96123
def _update_nodes_helper(args):
    """Multiprocessing worker entry point.

    Rebuilds a ChunkedGraph from its serialized info, updates the given
    nodes, and writes the resulting rows back through the graph's client.
    """
    graph_info, worker_nodes, worker_nodes_ts = args
    graph = ChunkedGraph(**graph_info)
    graph.client.write(update_nodes(graph, worker_nodes, worker_nodes_ts))
100129
101130
102131def update_chunk (cg : ChunkedGraph , chunk_coords : list [int ], debug : bool = False ):
@@ -134,21 +163,26 @@ def update_chunk(cg: ChunkedGraph, chunk_coords: list[int], debug: bool = False)
134163
135164 if debug :
136165 rows = update_nodes (cg , nodes , nodes_ts )
137- else :
138- task_size = int (math .ceil (len (nodes ) / 16 ))
139- chunked_nodes = chunked (nodes , task_size )
140- chunked_nodes_ts = chunked (nodes_ts , task_size )
141- tasks = []
142- for chunk , ts_chunk in zip (chunked_nodes , chunked_nodes_ts ):
143- args = (cg , chunk , ts_chunk )
144- tasks .append (args )
145- logging .info (f"task size { task_size } , count { len (tasks )} ." )
146-
147- rows = []
148- with ThreadPoolExecutor (max_workers = 8 ) as executor :
149- futures = [executor .submit (_update_nodes_helper , task ) for task in tasks ]
150- for future in tqdm (as_completed (futures ), total = len (futures )):
151- rows .extend (future .result ())
166+ cg .client .write (rows )
167+ return
152168
153- cg .client .write (rows )
169+ task_size = int (math .ceil (len (nodes ) / mp .cpu_count () / 2 ))
170+ chunked_nodes = chunked (nodes , task_size )
171+ chunked_nodes_ts = chunked (nodes_ts , task_size )
172+ cg_info = cg .get_serialized_info ()
173+
174+ tasks = []
175+ for chunk , ts_chunk in zip (chunked_nodes , chunked_nodes_ts ):
176+ args = (cg_info , chunk , ts_chunk )
177+ tasks .append (args )
178+
179+ processes = min (mp .cpu_count () * 2 , len (tasks ))
180+ logging .info (f"processing { len (nodes )} nodes with { processes } workers." )
181+ with mp .Pool (processes ) as pool :
182+ _ = list (
183+ tqdm (
184+ pool .imap_unordered (_update_nodes_helper , tasks ),
185+ total = len (tasks ),
186+ )
187+ )
154188 logging .info (f"total elaspsed time: { time .time () - start } " )
0 commit comments