import sys
import threading
-from collections import namedtuple
+import random
+from collections import namedtuple, defaultdict
from enum import Enum
from typing import List

@@ -70,11 +71,12 @@ def h(*args, **kwargs):
    # Meta relative
    # ---------------

-    def get_chunk_metas(self, chunk_keys):
+    def get_chunk_metas(self, chunk_keys, filter_fields=None):
        """
        Get chunk metas according to the given chunk keys.

        :param chunk_keys: chunk keys
+        :param filter_fields: filter the fields in meta
        :return: List of chunk metas
        """
        raise NotImplementedError
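The new `filter_fields` argument lets callers request only a subset of each chunk meta instead of the full record. A minimal usage sketch, assuming `ctx` is a context implementation that supports filtering and using `'workers'` as the field name, as the later hunks in this diff do; the chunk keys are placeholders:

```python
# Hypothetical sketch; `ctx` and the chunk keys are placeholders.
chunk_keys = ['chunk_key_1', 'chunk_key_2']

# Full metas: one meta record per chunk key.
metas = ctx.get_chunk_metas(chunk_keys)

# Filtered: each entry carries only the requested fields,
# e.g. the worker endpoints that hold each chunk.
workers_only = ctx.get_chunk_metas(chunk_keys, filter_fields=['workers'])
```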
@@ -188,7 +190,9 @@ def get_local_address(self):
    def get_ncores(self):
        return self._ncores

-    def get_chunk_metas(self, chunk_keys):
+    def get_chunk_metas(self, chunk_keys, filter_fields=None):
+        if filter_fields is not None:  # pragma: no cover
+            raise NotImplementedError("Local context doesn't support filter fields now")
        metas = []
        for chunk_key in chunk_keys:
            chunk_data = self.get(chunk_key)
@@ -219,17 +223,29 @@ def get_chunk_results(self, chunk_keys: List[str]) -> List:


class DistributedContext(ContextBase):
-    def __init__(self, cluster_info, session_id, addr, chunk_meta_client,
-                 resource_actor_ref, actor_ctx, **kw):
-        self._cluster_info = cluster_info
-        is_distributed = cluster_info.is_distributed()
+    def __init__(self, scheduler_address, session_id, actor_ctx=None, **kw):
+        from .worker.api import WorkerAPI
+        from .scheduler.api import MetaAPI
+        from .scheduler.resource import ResourceActor
+        from .scheduler.utils import SchedulerClusterInfoActor
+        from .actors import new_client
+
+        self._session_id = session_id
+        self._scheduler_address = scheduler_address
+        self._worker_api = WorkerAPI()
+        self._meta_api = MetaAPI(actor_ctx=actor_ctx, scheduler_endpoint=scheduler_address)
+
+        self._running_mode = None
+        self._actor_ctx = actor_ctx or new_client()
+        self._cluster_info = self._actor_ctx.actor_ref(
+            SchedulerClusterInfoActor.default_uid(), address=scheduler_address)
+        is_distributed = self._cluster_info.is_distributed()
        self._running_mode = RunningMode.local_cluster \
            if not is_distributed else RunningMode.distributed
-        self._session_id = session_id
-        self._address = addr
-        self._chunk_meta_client = chunk_meta_client
-        self._resource_actor_ref = resource_actor_ref
-        self._actor_ctx = actor_ctx
+        self._address = kw.pop('address', None)
+        self._resource_actor_ref = self._actor_ctx.actor_ref(
+            ResourceActor.default_uid(), address=scheduler_address)
+
        self._extra_info = kw

    @property
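With this change the context resolves the worker/meta APIs and actor refs itself, so callers only pass the scheduler endpoint and session id. A hypothetical construction sketch, assuming the class stays in `mars.context` (as the relative imports suggest) and that a scheduler is reachable at the given endpoint; the address and session id are placeholders:

```python
# Hypothetical sketch; endpoint and session id are placeholders.
from mars.context import DistributedContext

ctx = DistributedContext(
    scheduler_address='127.0.0.1:7103',  # scheduler endpoint, assumed reachable
    session_id='my_session_id')
# cluster_info, chunk_meta_client, resource_actor_ref, etc. are no longer
# passed in; they are looked up through the actor context internally.
```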
@@ -252,10 +268,6 @@ def get_local_address(self):
    def get_ncores(self):
        return self._extra_info.get('n_cpu')

-    def get_chunk_metas(self, chunk_keys):
-        return self._chunk_meta_client.batch_get_chunk_meta(
-            self._session_id, chunk_keys)
-
    def get_chunk_results(self, chunk_keys: List[str]) -> List:
        from .serialize import dataserializer
        from .worker.transfer import ResultSenderActor
@@ -269,6 +281,93 @@ def get_chunk_results(self, chunk_keys: List[str]) -> List:
                dataserializer.loads(sender_ref.fetch_data(self._session_id, chunk_key)))
        return results

+    # Meta API
+    def get_tileable_metas(self, tileable_keys, filter_fields: List[str] = None) -> List:
+        return self._meta_api.get_tileable_metas(self._session_id, tileable_keys, filter_fields)
+
+    def get_chunk_metas(self, chunk_keys, filter_fields: List[str] = None) -> List:
+        return self._meta_api.get_chunk_metas(self._session_id, chunk_keys, filter_fields)
+
+    def get_tileable_key_by_name(self, name: str):
+        return self._meta_api.get_tileable_key_by_name(self._session_id, name)
+
+    # Worker API
+    def get_chunks_data(self, worker: str, chunk_keys: List[str], indexes: List = None,
+                        compression_types: List[str] = None):
+        return self._worker_api.get_chunks_data(self._session_id, worker, chunk_keys, indexes=indexes,
+                                                compression_types=compression_types)
+
+    # Fetch tileable data by tileable keys and indexes.
+    def get_tileable_data(self, tileable_key: str, indexes: List = None,
+                          compression_types: List[str] = None):
+        from .serialize import dataserializer
+        from .utils import merge_chunks
+        from .tensor.core import TENSOR_TYPE
+        from .tensor.datasource import empty
+        from .tensor.indexing.getitem import TensorIndexTilesHandler
+
+        nsplits, chunk_keys, chunk_indexes = self.get_tileable_metas([tileable_key])[0]
+        chunk_idx_to_keys = dict(zip(chunk_indexes, chunk_keys))
+        chunk_keys_to_idx = dict(zip(chunk_keys, chunk_indexes))
+        endpoints = self.get_chunk_metas(chunk_keys, filter_fields=['workers'])
+        chunk_keys_to_worker = dict((chunk_key, random.choice(es[0])) for es, chunk_key in zip(endpoints, chunk_keys))
+
+        chunk_workers = defaultdict(list)
+        [chunk_workers[e].append(chunk_key) for chunk_key, e in chunk_keys_to_worker.items()]
+
+        chunk_results = dict()
+        if not indexes:
+            datas = []
+            for endpoint, chunks in chunk_workers.items():
+                datas.append(self.get_chunks_data(endpoint, chunks, compression_types=compression_types))
+            datas = [d.result() for d in datas]
+            for (endpoint, chunks), d in zip(chunk_workers.items(), datas):
+                d = [dataserializer.loads(db) for db in d]
+                chunk_results.update(dict(zip([chunk_keys_to_idx[k] for k in chunks], d)))
+        else:
+            # TODO: make a common util to handle indexes
+            if any(isinstance(ind, TENSOR_TYPE) for ind in indexes):
+                raise TypeError("Doesn't support indexing by tensors")
+            # Reuse the getitem logic to get each chunk's indexes
+            tileable_shape = tuple(sum(s) for s in nsplits)
+            empty_tileable = empty(tileable_shape, chunk_size=nsplits)._inplace_tile()
+            indexed = empty_tileable[tuple(indexes)]
+            index_handler = TensorIndexTilesHandler(indexed.op)
+            index_handler._extract_indexes_info()
+            index_handler._preprocess_fancy_indexes()
+            index_handler._process_fancy_indexes()
+            index_handler._process_in_tensor()
+
+            result_chunks = dict()
+            for c in index_handler._out_chunks:
+                result_chunks[chunk_idx_to_keys[c.inputs[0].index]] = [c.index, c.op.indexes]
+
+            chunk_datas = dict()
+            for endpoint, chunks in chunk_workers.items():
+                to_fetch_keys = []
+                to_fetch_indexes = []
+                to_fetch_idx = []
+                for r_chunk, (chunk_index, index_obj) in result_chunks.items():
+                    if r_chunk in chunks:
+                        to_fetch_keys.append(r_chunk)
+                        to_fetch_indexes.append(index_obj)
+                        to_fetch_idx.append(chunk_index)
+                if to_fetch_keys:
+                    datas = self.get_chunks_data(endpoint, to_fetch_keys, indexes=to_fetch_indexes,
+                                                 compression_types=compression_types)
+                    chunk_datas[tuple(to_fetch_idx)] = datas
+            chunk_datas = dict((k, v.result()) for k, v in chunk_datas.items())
+            for idx, d in chunk_datas.items():
+                d = [dataserializer.loads(db) for db in d]
+                chunk_results.update(dict(zip(idx, d)))
+
+        chunk_results = [(k, v) for k, v in chunk_results.items()]
+        if len(chunk_results) == 1:
+            ret = chunk_results[0][1]
+        else:
+            ret = merge_chunks(chunk_results)
+        return ret
+

class DistributedDictContext(DistributedContext, dict):
    def __init__(self, *args, **kwargs):
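For reference, a rough sketch of how the new `get_tileable_data` path could be exercised from client code. This is only an illustration: `ctx` and `tileable_key` are placeholders, and in practice the key would come from a tileable that has already been executed in the session.

```python
# Hypothetical sketch; `ctx` and `tileable_key` are placeholders.
# Fetch the whole tileable: chunks are pulled from their workers and merged.
full_value = ctx.get_tileable_data(tileable_key)

# Fetch only a slice: the getitem tiling logic above selects the chunks
# (and per-chunk indexes) that cover the requested region.
part_value = ctx.get_tileable_data(tileable_key,
                                   indexes=[slice(0, 10), slice(None)])
```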