34
34
from .communication .socket_bytestream import SocketByteStream
35
35
36
36
from .data_transferer import DataTransferer , ObjReference
37
- from .exception_transferer import load_exception
38
- from .override_decorators import LocalAttrOverride , LocalException , LocalOverride
37
+ from .exception_transferer import ExceptionMetaClass , load_exception
38
+ from .override_decorators import (
39
+ LocalAttrOverride ,
40
+ LocalExceptionDeserializer ,
41
+ LocalOverride ,
42
+ )
39
43
from .stub import create_class
40
44
from .utils import get_canonical_name
41
45
@@ -193,28 +197,41 @@ def inner_init(self, python_executable, pythonpath, max_pickle_version, config_d
193
197
self ._proxied_classes = {
194
198
k : None
195
199
for k in itertools .chain (
196
- response [FIELD_CONTENT ]["classes" ], response [FIELD_CONTENT ]["proxied" ]
200
+ response [FIELD_CONTENT ]["classes" ],
201
+ response [FIELD_CONTENT ]["proxied" ],
202
+ (e [0 ] for e in response [FIELD_CONTENT ]["exceptions" ]),
197
203
)
198
204
}
199
205
206
+ self ._exception_hierarchy = dict (response [FIELD_CONTENT ]["exceptions" ])
207
+ self ._proxied_classnames = set (response [FIELD_CONTENT ]["classes" ]).union (
208
+ response [FIELD_CONTENT ]["proxied" ]
209
+ )
210
+ self ._aliases = response [FIELD_CONTENT ]["aliases" ]
211
+
200
212
# Determine all overrides
201
213
self ._overrides = {}
202
214
self ._getattr_overrides = {}
203
215
self ._setattr_overrides = {}
204
- self ._exception_overrides = {}
216
+ self ._exception_deserializers = {}
205
217
for override in override_values :
206
218
if isinstance (override , (LocalOverride , LocalAttrOverride )):
207
219
for obj_name , obj_funcs in override .obj_mapping .items ():
208
- if obj_name not in self ._proxied_classes :
220
+ canonical_name = get_canonical_name (obj_name , self ._aliases )
221
+ if canonical_name not in self ._proxied_classes :
209
222
raise ValueError (
210
223
"%s does not refer to a proxied or override type" % obj_name
211
224
)
212
225
if isinstance (override , LocalOverride ):
213
- override_dict = self ._overrides .setdefault (obj_name , {})
226
+ override_dict = self ._overrides .setdefault (canonical_name , {})
214
227
elif override .is_setattr :
215
- override_dict = self ._setattr_overrides .setdefault (obj_name , {})
228
+ override_dict = self ._setattr_overrides .setdefault (
229
+ canonical_name , {}
230
+ )
216
231
else :
217
- override_dict = self ._getattr_overrides .setdefault (obj_name , {})
232
+ override_dict = self ._getattr_overrides .setdefault (
233
+ canonical_name , {}
234
+ )
218
235
if isinstance (obj_funcs , str ):
219
236
obj_funcs = (obj_funcs ,)
220
237
for name in obj_funcs :
@@ -223,11 +240,18 @@ def inner_init(self, python_executable, pythonpath, max_pickle_version, config_d
223
240
"%s was already overridden for %s" % (name , obj_name )
224
241
)
225
242
override_dict [name ] = override .func
226
- if isinstance (override , LocalException ):
227
- cur_ex = self ._exception_overrides .get (override .class_path , None )
228
- if cur_ex is not None :
229
- raise ValueError ("Exception %s redefined" % override .class_path )
230
- self ._exception_overrides [override .class_path ] = override .wrapped_class
243
+ if isinstance (override , LocalExceptionDeserializer ):
244
+ canonical_name = get_canonical_name (override .class_path , self ._aliases )
245
+ if canonical_name not in self ._exception_hierarchy :
246
+ raise ValueError (
247
+ "%s does not refer to an exception type" % override .class_path
248
+ )
249
+ cur_des = self ._exception_deserializers .get (canonical_name , None )
250
+ if cur_des is not None :
251
+ raise ValueError (
252
+ "Exception %s has multiple deserializers" % override .class_path
253
+ )
254
+ self ._exception_deserializers [canonical_name ] = override .deserializer
231
255
232
256
# Proxied standalone functions are functions that are proxied
233
257
# as part of other objects like defaultdict for which we create a
@@ -243,8 +267,6 @@ def inner_init(self, python_executable, pythonpath, max_pickle_version, config_d
243
267
"aliases" : response [FIELD_CONTENT ]["aliases" ],
244
268
}
245
269
246
- self ._aliases = response [FIELD_CONTENT ]["aliases" ]
247
-
248
270
def __del__ (self ):
249
271
self .cleanup ()
250
272
@@ -288,8 +310,9 @@ def name(self):
288
310
def get_exports (self ):
289
311
return self ._export_info
290
312
291
- def get_local_exception_overrides (self ):
292
- return self ._exception_overrides
313
+ def get_exception_deserializer (self , name ):
314
+ cannonical_name = get_canonical_name (name , self ._aliases )
315
+ return self ._exception_deserializers .get (cannonical_name )
293
316
294
317
def stub_request (self , stub , request_type , * args , ** kwargs ):
295
318
# Encode the operation to send over the wire and wait for the response
@@ -313,7 +336,7 @@ def stub_request(self, stub, request_type, *args, **kwargs):
313
336
if response_type == MSG_REPLY :
314
337
return self .decode (response [FIELD_CONTENT ])
315
338
elif response_type == MSG_EXCEPTION :
316
- raise load_exception (self . _datatransferer , response [FIELD_CONTENT ])
339
+ raise load_exception (self , response [FIELD_CONTENT ])
317
340
elif response_type == MSG_INTERNAL_ERROR :
318
341
raise RuntimeError (
319
342
"Error in the server runtime:\n \n ===== SERVER TRACEBACK =====\n %s"
@@ -334,10 +357,27 @@ def decode(self, json_obj):
334
357
# this connection will be converted to a local stub.
335
358
return self ._datatransferer .load (json_obj )
336
359
337
- def get_local_class (self , name , obj_id = None ):
360
+ def get_local_class (
361
+ self , name , obj_id = None , is_returned_exception = False , is_parent = False
362
+ ):
338
363
# Gets (and creates if needed), the class mapping to the remote
339
364
# class of name 'name'.
365
+
366
+ # We actually deal with four types of classes:
367
+ # - proxied functions
368
+ # - classes that are proxied regular classes AND proxied exceptions
369
+ # - classes that are proxied regular classes AND NOT proxied exceptions
370
+ # - classes that are NOT proxied regular classes AND are proxied exceptions
340
371
name = get_canonical_name (name , self ._aliases )
372
+
373
+ def name_to_parent_name (name ):
374
+ return "parent:%s" % name
375
+
376
+ if is_parent :
377
+ lookup_name = name_to_parent_name (name )
378
+ else :
379
+ lookup_name = name
380
+
341
381
if name == "function" :
342
382
# Special handling of pickled functions. We create a new class that
343
383
# simply has a __call__ method that will forward things back to
@@ -346,27 +386,108 @@ def get_local_class(self, name, obj_id=None):
346
386
raise RuntimeError ("Local function unpickling without an object ID" )
347
387
if obj_id not in self ._proxied_standalone_functions :
348
388
self ._proxied_standalone_functions [obj_id ] = create_class (
349
- self , "__function_%s" % obj_id , {}, {}, {}, {"__call__" : "" }
389
+ self , "__function_%s" % obj_id , {}, {}, {}, {"__call__" : "" }, []
350
390
)
351
391
return self ._proxied_standalone_functions [obj_id ]
392
+ local_class = self ._proxied_classes .get (lookup_name , None )
393
+ if local_class is not None :
394
+ return local_class
395
+
396
+ is_proxied_exception = name in self ._exception_hierarchy
397
+ is_proxied_non_exception = name in self ._proxied_classnames
398
+
399
+ if not is_proxied_exception and not is_proxied_non_exception :
400
+ if is_returned_exception or is_parent :
401
+ # In this case, it may be a local exception that we need to
402
+ # recreate
403
+ try :
404
+ ex_module , ex_name = name .rsplit ("." , 1 )
405
+ __import__ (ex_module , None , None , "*" )
406
+ except Exception :
407
+ pass
408
+ if ex_module in sys .modules and issubclass (
409
+ getattr (sys .modules [ex_module ], ex_name ), BaseException
410
+ ):
411
+ # This is a local exception that we can recreate
412
+ local_exception = getattr (sys .modules [ex_module ], ex_name )
413
+ wrapped_exception = ExceptionMetaClass (
414
+ ex_name ,
415
+ (local_exception ,),
416
+ dict (getattr (local_exception , "__dict__" , {})),
417
+ )
418
+ wrapped_exception .__module__ = ex_module
419
+ self ._proxied_classes [lookup_name ] = wrapped_exception
420
+ return wrapped_exception
352
421
353
- if name not in self ._proxied_classes :
354
422
raise ValueError ("Class '%s' is not known" % name )
355
- local_class = self ._proxied_classes [name ]
356
- if local_class is None :
357
- # We need to build up this class. To do so, we take everything that the
358
- # remote class has and remove UNSUPPORTED things and overridden things
423
+
424
+ # At this stage:
425
+ # - we don't have a local_class for this
426
+ # - it is not an inbuilt exception so it is either a proxied exception, a
427
+ # proxied class or a proxied object that is both an exception and a class.
428
+
429
+ parents = []
430
+ if is_proxied_exception :
431
+ # If exception, we need to get the parents from the exception
432
+ ex_parents = self ._exception_hierarchy [name ]
433
+ for parent in ex_parents :
434
+ # We always consider it to be an exception so that we wrap even non
435
+ # proxied builtins exceptions
436
+ parents .append (self .get_local_class (parent , is_parent = True ))
437
+ # For regular classes, we get what it exposes from the server
438
+ if is_proxied_non_exception :
359
439
remote_methods = self .stub_request (None , OP_GETMETHODS , name )
440
+ else :
441
+ remote_methods = {}
442
+
443
+ parent_local_class = None
444
+ local_class = None
445
+ if is_proxied_exception :
446
+ # If we are a proxied exception AND a proxied class, we create two classes:
447
+ # actually:
448
+ # - the class itself (which is a stub)
449
+ # - the class in the capacity of a parent class (to another exception
450
+ # presumably). The reason for this is that if we have an exception/proxied
451
+ # class A and another B and B inherits from A, the MRO order would be all
452
+ # wrong since both A and B would also inherit from `Stub`. Here what we
453
+ # do is:
454
+ # - A_parent inherits from the actual parents of A (let's assume a
455
+ # builtin exception)
456
+ # - A inherits from (Stub, A_parent)
457
+ # - B_parent inherits from A_parent and the builtin Exception
458
+ # - B inherits from (Stub, B_parent)
459
+ ex_module , ex_name = name .rsplit ("." , 1 )
460
+ parent_local_class = ExceptionMetaClass (ex_name , (* parents ,), {})
461
+ parent_local_class .__module__ = ex_module
462
+
463
+ if is_proxied_non_exception :
360
464
local_class = create_class (
361
465
self ,
362
466
name ,
363
467
self ._overrides .get (name , {}),
364
468
self ._getattr_overrides .get (name , {}),
365
469
self ._setattr_overrides .get (name , {}),
366
470
remote_methods ,
471
+ (parent_local_class ,) if parent_local_class else None ,
367
472
)
473
+ if parent_local_class :
474
+ self ._proxied_classes [name_to_parent_name (name )] = parent_local_class
475
+ if local_class :
368
476
self ._proxied_classes [name ] = local_class
369
- return local_class
477
+ else :
478
+ # This is for the case of pure proxied exceptions -- we want the lookup of
479
+ # foo.MyException to be the same class as looking of foo.MyException as a parent
480
+ # of another exception so `isinstance` works properly
481
+ self ._proxied_classes [name ] = parent_local_class
482
+
483
+ if is_parent :
484
+ # This should never happen but making sure
485
+ if not parent_local_class :
486
+ raise RuntimeError (
487
+ "Exception parent class %s is not a proxied exception" % name
488
+ )
489
+ return parent_local_class
490
+ return self ._proxied_classes [name ]
370
491
371
492
def can_pickle (self , obj ):
372
493
return getattr (obj , "___connection___" , None ) == self
@@ -395,7 +516,7 @@ def unpickle_object(self, obj):
395
516
obj_id = obj .identifier
396
517
local_instance = self ._proxied_objects .get (obj_id )
397
518
if not local_instance :
398
- local_class = self .get_local_class (remote_class_name , obj_id )
519
+ local_class = self .get_local_class (remote_class_name , obj_id = obj_id )
399
520
local_instance = local_class (self , remote_class_name , obj_id )
400
521
return local_instance
401
522
0 commit comments