Skip to content

Commit 79bb1d4

Browse files
committed
Better support for subclasses and fix issues with >1 exception depth
1 parent dfce3b7 commit 79bb1d4

File tree

9 files changed

+289
-20
lines changed

9 files changed

+289
-20
lines changed

metaflow/plugins/env_escape/client.py

+55-14
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,9 @@ def decode(self, json_obj):
357357
# this connection will be converted to a local stub.
358358
return self._datatransferer.load(json_obj)
359359

360-
def get_local_class(self, name, obj_id=None, is_returned_exception=False):
360+
def get_local_class(
361+
self, name, obj_id=None, is_returned_exception=False, is_parent=False
362+
):
361363
# Gets (and creates if needed), the class mapping to the remote
362364
# class of name 'name'.
363365

@@ -367,6 +369,15 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False):
367369
# - classes that are proxied regular classes AND NOT proxied exceptions
368370
# - clases that are NOT proxied regular classes AND are proxied exceptions
369371
name = get_canonical_name(name, self._aliases)
372+
373+
def name_to_parent_name(name):
374+
return "parent:%s" % name
375+
376+
if is_parent:
377+
lookup_name = name_to_parent_name(name)
378+
else:
379+
lookup_name = name
380+
370381
if name == "function":
371382
# Special handling of pickled functions. We create a new class that
372383
# simply has a __call__ method that will forward things back to
@@ -378,15 +389,15 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False):
378389
self, "__function_%s" % obj_id, {}, {}, {}, {"__call__": ""}, []
379390
)
380391
return self._proxied_standalone_functions[obj_id]
381-
local_class = self._proxied_classes.get(name, None)
392+
local_class = self._proxied_classes.get(lookup_name, None)
382393
if local_class is not None:
383394
return local_class
384395

385396
is_proxied_exception = name in self._exception_hierarchy
386397
is_proxied_non_exception = name in self._proxied_classnames
387398

388399
if not is_proxied_exception and not is_proxied_non_exception:
389-
if is_returned_exception:
400+
if is_returned_exception or is_parent:
390401
# In this case, it may be a local exception that we need to
391402
# recreate
392403
try:
@@ -405,7 +416,7 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False):
405416
dict(getattr(local_exception, "__dict__", {})),
406417
)
407418
wrapped_exception.__module__ = ex_module
408-
self._proxied_classes[name] = wrapped_exception
419+
self._proxied_classes[lookup_name] = wrapped_exception
409420
return wrapped_exception
410421

411422
raise ValueError("Class '%s' is not known" % name)
@@ -422,31 +433,61 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False):
422433
for parent in ex_parents:
423434
# We always consider it to be an exception so that we wrap even non
424435
# proxied builtins exceptions
425-
parents.append(self.get_local_class(parent, is_returned_exception=True))
436+
parents.append(self.get_local_class(parent, is_parent=True))
426437
# For regular classes, we get what it exposes from the server
427438
if is_proxied_non_exception:
428439
remote_methods = self.stub_request(None, OP_GETMETHODS, name)
429440
else:
430441
remote_methods = {}
431442

432-
if is_proxied_exception and not is_proxied_non_exception:
433-
# This is a pure exception
443+
parent_local_class = None
444+
local_class = None
445+
if is_proxied_exception:
446+
# If we are a proxied exception AND a proxied class, we create two classes:
447+
# actually:
448+
# - the class itself (which is a stub)
449+
# - the class in the capacity of a parent class (to another exception
450+
# presumably). The reason for this is that if we have a exception/proxied
451+
# class A and another B and B inherits from A, the MRO order would be all
452+
# wrong since both A and B would also inherit from `Stub`. Here what we
453+
# do is:
454+
# - A_parent inherits from the actual parents of A (let's assume a
455+
# builtin exception)
456+
# - A inherits from (Stub, A_parent)
457+
# - B_parent inherints from A_parent and the builtin Exception
458+
# - B inherits from (Stub, B_parent)
434459
ex_module, ex_name = name.rsplit(".", 1)
435-
local_class = ExceptionMetaClass(ex_name, (*parents,), {})
436-
local_class.__module__ = ex_module
437-
else:
438-
# This method creates either a pure stub or a stub that is also an exception
460+
parent_local_class = ExceptionMetaClass(ex_name, (*parents,), {})
461+
parent_local_class.__module__ = ex_module
462+
463+
if is_proxied_non_exception:
439464
local_class = create_class(
440465
self,
441466
name,
442467
self._overrides.get(name, {}),
443468
self._getattr_overrides.get(name, {}),
444469
self._setattr_overrides.get(name, {}),
445470
remote_methods,
446-
parents,
471+
(parent_local_class,) if parent_local_class else None,
447472
)
448-
self._proxied_classes[name] = local_class
449-
return local_class
473+
if parent_local_class:
474+
self._proxied_classes[name_to_parent_name(name)] = parent_local_class
475+
if local_class:
476+
self._proxied_classes[name] = local_class
477+
else:
478+
# This is for the case of pure proxied exceptions -- we want the lookup of
479+
# foo.MyException to be the same class as looking of foo.MyException as a parent
480+
# of another exception so `isinstance` works properly
481+
self._proxied_classes[name] = parent_local_class
482+
483+
if is_parent:
484+
# This should never happen but making sure
485+
if not parent_local_class:
486+
raise RuntimeError(
487+
"Exception parent class %s is not a proxied exception" % name
488+
)
489+
return parent_local_class
490+
return self._proxied_classes[name]
450491

451492
def can_pickle(self, obj):
452493
return getattr(obj, "___connection___", None) == self

metaflow/plugins/env_escape/configurations/emulate_test_lib/overrides.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,30 @@ def unsupported_method(stub, func, *args, **kwargs):
5252

5353

5454
@local_exception_deserialize("test_lib.SomeException")
55-
def deserialize_user(ex, json_obj):
55+
def some_exception_deserialize(ex, json_obj):
5656
ex.user_value = json_obj
5757

5858

5959
@remote_exception_serialize("test_lib.SomeException")
6060
def some_exception_serialize(ex):
6161
return 42
62+
63+
64+
@local_exception_deserialize("test_lib.ExceptionAndClass")
65+
def exception_and_class_deserialize(ex, json_obj):
66+
ex.user_value = json_obj
67+
68+
69+
@remote_exception_serialize("test_lib.ExceptionAndClass")
70+
def exception_and_class_serialize(ex):
71+
return 43
72+
73+
74+
@local_exception_deserialize("test_lib.ExceptionAndClassChild")
75+
def exception_and_class_child_deserialize(ex, json_obj):
76+
ex.user_value = json_obj
77+
78+
79+
@remote_exception_serialize("test_lib.ExceptionAndClassChild")
80+
def exception_and_class_child_serialize(ex):
81+
return 44

metaflow/plugins/env_escape/configurations/emulate_test_lib/server_mappings.py

+4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
("test_lib", "test_lib.alias"): {
1414
"TestClass1": lib.TestClass1,
1515
"TestClass2": lib.TestClass2,
16+
"BaseClass": lib.BaseClass,
17+
"ChildClass": lib.ChildClass,
1618
"ExceptionAndClass": lib.ExceptionAndClass,
19+
"ExceptionAndClassChild": lib.ExceptionAndClassChild,
1720
}
1821
}
1922

@@ -22,6 +25,7 @@
2225
"SomeException": lib.SomeException,
2326
"MyBaseException": lib.MyBaseException,
2427
"ExceptionAndClass": lib.ExceptionAndClass,
28+
"ExceptionAndClassChild": lib.ExceptionAndClassChild,
2529
}
2630
}
2731

metaflow/plugins/env_escape/configurations/test_lib_impl/test_lib.py

+39
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
from html.parser import HTMLParser
23

34

45
class MyBaseException(Exception):
@@ -20,6 +21,36 @@ def __str__(self):
2021
return "ExceptionAndClass Str: %s" % super().__str__()
2122

2223

24+
class ExceptionAndClassChild(ExceptionAndClass):
25+
def __init__(self, *args):
26+
super().__init__(*args)
27+
28+
def method_on_child_exception(self):
29+
return "method_on_child_exception"
30+
31+
def __str__(self):
32+
return "ExceptionAndClassChild Str: %s" % super().__str__()
33+
34+
35+
class BaseClass(HTMLParser):
36+
def __init__(self, *args, **kwargs):
37+
super().__init__(*args, **kwargs)
38+
self._output = []
39+
40+
def handle_starttag(self, tag, attrs):
41+
self._output.append(tag)
42+
return super().handle_starttag(tag, attrs)
43+
44+
def get_output(self):
45+
return self._output
46+
47+
48+
class ChildClass(BaseClass):
49+
def handle_endtag(self, tag):
50+
self._output.append(tag)
51+
return super().handle_endtag(tag)
52+
53+
2354
class TestClass1(object):
2455
cls_object = 25
2556

@@ -72,6 +103,9 @@ def __hidden(self, name, value):
72103
def weird_indirection(self, name):
73104
return functools.partial(self.__hidden, name)
74105

106+
def returnChild(self):
107+
return ChildClass()
108+
75109
def raiseOrReturnValueError(self, doRaise=False):
76110
if doRaise:
77111
raise ValueError("I raised")
@@ -87,6 +121,11 @@ def raiseOrReturnExceptionAndClass(self, doRaise=False):
87121
raise ExceptionAndClass("I raised")
88122
return ExceptionAndClass("I returned")
89123

124+
def raiseOrReturnExceptionAndClassChild(self, doRaise=False):
125+
if doRaise:
126+
raise ExceptionAndClassChild("I raised")
127+
return ExceptionAndClassChild("I returned")
128+
90129

91130
class TestClass2(object):
92131
def __init__(self, value, stride, count):

metaflow/plugins/env_escape/consts.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
OP_SETVAL = 16
4040
OP_INIT = 17
4141
OP_CALLONCLASS = 18
42+
OP_SUBCLASSCHECK = 19
4243

4344
# Control messages
4445
CONTROL_SHUTDOWN = 1

metaflow/plugins/env_escape/exception_transferer.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ def dump_exception(data_transferer, exception_type, exception_val, tb, user_data
8585

8686

8787
def load_exception(client, json_obj):
88+
from .stub import Stub
89+
8890
json_obj = client.decode(json_obj)
8991

9092
if json_obj.get(FIELD_EXC_SI) is not None:
@@ -93,11 +95,16 @@ def load_exception(client, json_obj):
9395
exception_module = json_obj.get(FIELD_EXC_MODULE)
9496
exception_name = json_obj.get(FIELD_EXC_NAME)
9597
exception_class = None
98+
# This name is already cannonical since we cannonicalize it on the server side
9699
full_name = "%s.%s" % (exception_module, exception_name)
97100

98101
exception_class = client.get_local_class(full_name, is_returned_exception=True)
99102

100-
raised_exception = exception_class(*json_obj.get(FIELD_EXC_ARGS))
103+
if issubclass(exception_class, Stub):
104+
raised_exception = exception_class(_is_returned_exception=True)
105+
raised_exception.args = tuple(json_obj.get(FIELD_EXC_ARGS))
106+
else:
107+
raised_exception = exception_class(*json_obj.get(FIELD_EXC_ARGS))
101108
raised_exception._exception_str = json_obj.get(FIELD_EXC_STR, None)
102109
raised_exception._exception_repr = json_obj.get(FIELD_EXC_REPR, None)
103110
raised_exception._exception_tb = json_obj.get(FIELD_EXC_TB, None)

metaflow/plugins/env_escape/server.py

+18
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
OP_GETVAL,
3737
OP_SETVAL,
3838
OP_INIT,
39+
OP_SUBCLASSCHECK,
3940
VALUE_LOCAL,
4041
VALUE_REMOTE,
4142
CONTROL_GETEXPORTS,
@@ -255,6 +256,7 @@ def __init__(self, config_dir, max_pickle_version):
255256
OP_GETVAL: self._handle_getval,
256257
OP_SETVAL: self._handle_setval,
257258
OP_INIT: self._handle_init,
259+
OP_SUBCLASSCHECK: self._handle_subclasscheck,
258260
}
259261

260262
self._local_objects = {}
@@ -292,6 +294,7 @@ def encode(self, obj):
292294
def encode_exception(self, ex_type, ex, trace_back):
293295
try:
294296
full_name = "%s.%s" % (ex_type.__module__, ex_type.__name__)
297+
get_canonical_name(full_name, self._aliases)
295298
serializer = self._exception_serializers.get(full_name)
296299
except AttributeError:
297300
# Ignore if no __module__ for example -- definitely not something we built
@@ -502,6 +505,21 @@ def _handle_init(self, target, class_name, *args, **kwargs):
502505
raise ValueError("Unknown class %s" % class_name)
503506
return class_type(*args, **kwargs)
504507

508+
def _handle_subclasscheck(self, target, class_name, otherclass_name, reverse=False):
509+
class_type = self._known_classes.get(class_name)
510+
if class_type is None:
511+
raise ValueError("Unknown class %s" % class_name)
512+
try:
513+
sub_module, sub_name = otherclass_name.rsplit(".", 1)
514+
__import__(sub_module, None, None, "*")
515+
except Exception:
516+
sub_module = None
517+
if sub_module is None:
518+
return False
519+
if reverse:
520+
return issubclass(class_type, getattr(sys.modules[sub_module], sub_name))
521+
return issubclass(getattr(sys.modules[sub_module], sub_name), class_type)
522+
505523

506524
if __name__ == "__main__":
507525
max_pickle_version = int(sys.argv[1])

0 commit comments

Comments
 (0)