From dfce3b723563e1b3faf8e0069542931bc53a0ec4 Mon Sep 17 00:00:00 2001 From: Romain Cledat Date: Thu, 15 Feb 2024 00:56:29 -0800 Subject: [PATCH 1/5] Various fixes in the env escape code We now handle exceptions that are also proxied classes better We canonicalize more names to better deal with aliases in proxied libraries. Other cleanup --- metaflow/plugins/env_escape/client.py | 134 ++++++++++++---- metaflow/plugins/env_escape/client_modules.py | 62 ++----- .../emulate_test_lib/overrides.py | 57 ++----- .../emulate_test_lib/server_mappings.py | 7 +- .../configurations/test_lib_impl/test_lib.py | 57 ++++--- .../env_escape/exception_transferer.py | 151 +++++------------- .../plugins/env_escape/override_decorators.py | 16 +- metaflow/plugins/env_escape/server.py | 29 +++- metaflow/plugins/env_escape/stub.py | 105 ++++++++++-- metaflow/plugins/env_escape/utils.py | 6 +- test/env_escape/example.py | 124 +++++++++----- 11 files changed, 426 insertions(+), 322 deletions(-) diff --git a/metaflow/plugins/env_escape/client.py b/metaflow/plugins/env_escape/client.py index 60ffdf09b68..a5b23f892ad 100644 --- a/metaflow/plugins/env_escape/client.py +++ b/metaflow/plugins/env_escape/client.py @@ -34,8 +34,12 @@ from .communication.socket_bytestream import SocketByteStream from .data_transferer import DataTransferer, ObjReference -from .exception_transferer import load_exception -from .override_decorators import LocalAttrOverride, LocalException, LocalOverride +from .exception_transferer import ExceptionMetaClass, load_exception +from .override_decorators import ( + LocalAttrOverride, + LocalExceptionDeserializer, + LocalOverride, +) from .stub import create_class from .utils import get_canonical_name @@ -193,28 +197,41 @@ def inner_init(self, python_executable, pythonpath, max_pickle_version, config_d self._proxied_classes = { k: None for k in itertools.chain( - response[FIELD_CONTENT]["classes"], response[FIELD_CONTENT]["proxied"] + response[FIELD_CONTENT]["classes"], + 
response[FIELD_CONTENT]["proxied"], + (e[0] for e in response[FIELD_CONTENT]["exceptions"]), ) } + self._exception_hierarchy = dict(response[FIELD_CONTENT]["exceptions"]) + self._proxied_classnames = set(response[FIELD_CONTENT]["classes"]).union( + response[FIELD_CONTENT]["proxied"] + ) + self._aliases = response[FIELD_CONTENT]["aliases"] + # Determine all overrides self._overrides = {} self._getattr_overrides = {} self._setattr_overrides = {} - self._exception_overrides = {} + self._exception_deserializers = {} for override in override_values: if isinstance(override, (LocalOverride, LocalAttrOverride)): for obj_name, obj_funcs in override.obj_mapping.items(): - if obj_name not in self._proxied_classes: + canonical_name = get_canonical_name(obj_name, self._aliases) + if canonical_name not in self._proxied_classes: raise ValueError( "%s does not refer to a proxied or override type" % obj_name ) if isinstance(override, LocalOverride): - override_dict = self._overrides.setdefault(obj_name, {}) + override_dict = self._overrides.setdefault(canonical_name, {}) elif override.is_setattr: - override_dict = self._setattr_overrides.setdefault(obj_name, {}) + override_dict = self._setattr_overrides.setdefault( + canonical_name, {} + ) else: - override_dict = self._getattr_overrides.setdefault(obj_name, {}) + override_dict = self._getattr_overrides.setdefault( + canonical_name, {} + ) if isinstance(obj_funcs, str): obj_funcs = (obj_funcs,) for name in obj_funcs: @@ -223,11 +240,18 @@ def inner_init(self, python_executable, pythonpath, max_pickle_version, config_d "%s was already overridden for %s" % (name, obj_name) ) override_dict[name] = override.func - if isinstance(override, LocalException): - cur_ex = self._exception_overrides.get(override.class_path, None) - if cur_ex is not None: - raise ValueError("Exception %s redefined" % override.class_path) - self._exception_overrides[override.class_path] = override.wrapped_class + if isinstance(override, 
LocalExceptionDeserializer): + canonical_name = get_canonical_name(override.class_path, self._aliases) + if canonical_name not in self._exception_hierarchy: + raise ValueError( + "%s does not refer to an exception type" % override.class_path + ) + cur_des = self._exception_deserializers.get(canonical_name, None) + if cur_des is not None: + raise ValueError( + "Exception %s has multiple deserializers" % override.class_path + ) + self._exception_deserializers[canonical_name] = override.deserializer # Proxied standalone functions are functions that are proxied # as part of other objects like defaultdict for which we create a @@ -243,8 +267,6 @@ def inner_init(self, python_executable, pythonpath, max_pickle_version, config_d "aliases": response[FIELD_CONTENT]["aliases"], } - self._aliases = response[FIELD_CONTENT]["aliases"] - def __del__(self): self.cleanup() @@ -288,8 +310,9 @@ def name(self): def get_exports(self): return self._export_info - def get_local_exception_overrides(self): - return self._exception_overrides + def get_exception_deserializer(self, name): + cannonical_name = get_canonical_name(name, self._aliases) + return self._exception_deserializers.get(cannonical_name) def stub_request(self, stub, request_type, *args, **kwargs): # Encode the operation to send over the wire and wait for the response @@ -313,7 +336,7 @@ def stub_request(self, stub, request_type, *args, **kwargs): if response_type == MSG_REPLY: return self.decode(response[FIELD_CONTENT]) elif response_type == MSG_EXCEPTION: - raise load_exception(self._datatransferer, response[FIELD_CONTENT]) + raise load_exception(self, response[FIELD_CONTENT]) elif response_type == MSG_INTERNAL_ERROR: raise RuntimeError( "Error in the server runtime:\n\n===== SERVER TRACEBACK =====\n%s" @@ -334,9 +357,15 @@ def decode(self, json_obj): # this connection will be converted to a local stub. 
return self._datatransferer.load(json_obj) - def get_local_class(self, name, obj_id=None): + def get_local_class(self, name, obj_id=None, is_returned_exception=False): # Gets (and creates if needed), the class mapping to the remote # class of name 'name'. + + # We actually deal with four types of classes: + # - proxied functions + # - classes that are proxied regular classes AND proxied exceptions + # - classes that are proxied regular classes AND NOT proxied exceptions + # - classes that are NOT proxied regular classes AND are proxied exceptions name = get_canonical_name(name, self._aliases) if name == "function": # Special handling of pickled functions. We create a new class that @@ -346,17 +375,67 @@ def get_local_class(self, name, obj_id=None): raise RuntimeError("Local function unpickling without an object ID") if obj_id not in self._proxied_standalone_functions: self._proxied_standalone_functions[obj_id] = create_class( - self, "__function_%s" % obj_id, {}, {}, {}, {"__call__": ""} + self, "__function_%s" % obj_id, {}, {}, {}, {"__call__": ""}, [] ) return self._proxied_standalone_functions[obj_id] + local_class = self._proxied_classes.get(name, None) + if local_class is not None: + return local_class + + is_proxied_exception = name in self._exception_hierarchy + is_proxied_non_exception = name in self._proxied_classnames + + if not is_proxied_exception and not is_proxied_non_exception: + if is_returned_exception: + # In this case, it may be a local exception that we need to + # recreate + try: + ex_module, ex_name = name.rsplit(".", 1) + __import__(ex_module, None, None, "*") + except Exception: + pass + if ex_module in sys.modules and issubclass( + getattr(sys.modules[ex_module], ex_name), BaseException + ): + # This is a local exception that we can recreate + local_exception = getattr(sys.modules[ex_module], ex_name) + wrapped_exception = ExceptionMetaClass( + ex_name, + (local_exception,), + dict(getattr(local_exception, "__dict__", {})), + ) + 
wrapped_exception.__module__ = ex_module + self._proxied_classes[name] = wrapped_exception + return wrapped_exception - if name not in self._proxied_classes: raise ValueError("Class '%s' is not known" % name) - local_class = self._proxied_classes[name] - if local_class is None: - # We need to build up this class. To do so, we take everything that the - # remote class has and remove UNSUPPORTED things and overridden things + + # At this stage: + # - we don't have a local_class for this + # - it is not an inbuilt exception so it is either a proxied exception, a + # proxied class or a proxied object that is both an exception and a class. + + parents = [] + if is_proxied_exception: + # If exception, we need to get the parents from the exception + ex_parents = self._exception_hierarchy[name] + for parent in ex_parents: + # We always consider it to be an exception so that we wrap even non + # proxied builtins exceptions + parents.append(self.get_local_class(parent, is_returned_exception=True)) + # For regular classes, we get what it exposes from the server + if is_proxied_non_exception: remote_methods = self.stub_request(None, OP_GETMETHODS, name) + else: + remote_methods = {} + + if is_proxied_exception and not is_proxied_non_exception: + # This is a pure exception + ex_module, ex_name = name.rsplit(".", 1) + local_class = ExceptionMetaClass(ex_name, (*parents,), {}) + local_class.__module__ = ex_module + else: + # This method creates either a pure stub or a stub that is also an exception local_class = create_class( self, name, @@ -364,8 +443,9 @@ def get_local_class(self, name, obj_id=None): self._getattr_overrides.get(name, {}), self._setattr_overrides.get(name, {}), remote_methods, + parents, ) - self._proxied_classes[name] = local_class + self._proxied_classes[name] = local_class return local_class def can_pickle(self, obj): @@ -395,7 +475,7 @@ def unpickle_object(self, obj): obj_id = obj.identifier local_instance = self._proxied_objects.get(obj_id) if not 
local_instance: - local_class = self.get_local_class(remote_class_name, obj_id) + local_class = self.get_local_class(remote_class_name, obj_id=obj_id) local_instance = local_class(self, remote_class_name, obj_id) return local_instance diff --git a/metaflow/plugins/env_escape/client_modules.py b/metaflow/plugins/env_escape/client_modules.py index 2896f2c5156..aa20766414a 100644 --- a/metaflow/plugins/env_escape/client_modules.py +++ b/metaflow/plugins/env_escape/client_modules.py @@ -7,7 +7,6 @@ from .consts import OP_CALLFUNC, OP_GETVAL, OP_SETVAL from .client import Client -from .override_decorators import LocalException from .utils import get_canonical_name @@ -16,7 +15,7 @@ def _clean_client(client): class _WrappedModule(object): - def __init__(self, loader, prefix, exports, exception_classes, client): + def __init__(self, loader, prefix, exports, client): self._loader = loader self._prefix = prefix self._client = client @@ -24,19 +23,20 @@ def __init__(self, loader, prefix, exports, exception_classes, client): r"^%s\.([a-zA-Z_][a-zA-Z0-9_]*)$" % prefix.replace(".", r"\.") # noqa W605 ) self._exports = {} - self._aliases = exports["aliases"] + self._aliases = exports.get("aliases", []) for k in ("classes", "functions", "values"): result = [] - for item in exports[k]: + for item in exports.get(k, []): m = is_match.match(item) if m: result.append(m.group(1)) self._exports[k] = result - self._exception_classes = {} - for k, v in exception_classes.items(): - m = is_match.match(k) + result = [] + for item, _ in exports.get("exceptions", []): + m = is_match.match(item) if m: - self._exception_classes[m.group(1)] = v + result.append(m.group(1)) + self._exports["exceptions"] = result def __getattr__(self, name): if name == "__loader__": @@ -50,8 +50,8 @@ def __getattr__(self, name): name = get_canonical_name(self._prefix + "." 
+ name, self._aliases)[ len(self._prefix) + 1 : ] - if name in self._exports["classes"]: - # We load classes lazily + if name in self._exports["classes"] or name in self._exports["exceptions"]: + # We load classes and exceptions lazily return self._client.get_local_class("%s.%s" % (self._prefix, name)) elif name in self._exports["functions"]: # TODO: Grab doc back from the remote side like in _make_method @@ -67,8 +67,6 @@ def func(*args, **kwargs): return self._client.stub_request( None, OP_GETVAL, "%s.%s" % (self._prefix, name) ) - elif name in self._exception_classes: - return self._exception_classes[name] else: # Try to see if this is a submodule that we can load m = None @@ -173,7 +171,6 @@ def load_module(self, fullname): # Get information about overrides and what the server knows about exports = self._client.get_exports() - ex_overrides = self._client.get_local_exception_overrides() prefixes = set() export_classes = exports.get("classes", []) @@ -182,42 +179,13 @@ def load_module(self, fullname): export_exceptions = exports.get("exceptions", []) self._aliases = exports.get("aliases", {}) for name in itertools.chain( - export_classes, export_functions, export_values + export_classes, + export_functions, + export_values, + (e[0] for e in export_exceptions), ): splits = name.rsplit(".", 1) prefixes.add(splits[0]) - - # Now look at the exceptions coming from the server - formed_exception_classes = {} - for ex_name, ex_parents in export_exceptions: - # Exception is a tuple (name, (parents,)) - # Exceptions are also given in order of instantiation (ie: the - # server already topologically sorted them) - ex_class_dict = ex_overrides.get(ex_name, None) - if ex_class_dict is None: - ex_class_dict = {} - else: - ex_class_dict = dict(ex_class_dict.__dict__) - parents = [] - for fake_base in ex_parents: - if fake_base.startswith("builtins."): - # This is something we know of here - parents.append(eval(fake_base[9:])) - else: - # It's in formed_classes - 
parents.append(formed_exception_classes[fake_base]) - splits = ex_name.rsplit(".", 1) - ex_class_dict["__user_defined__"] = set(ex_class_dict.keys()) - new_class = type(splits[1], tuple(parents), ex_class_dict) - new_class.__module__ = splits[0] - new_class.__name__ = splits[1] - formed_exception_classes[ex_name] = new_class - - # Now update prefixes as needed - for name in formed_exception_classes: - splits = name.rsplit(".", 1) - prefixes.add(splits[0]) - # We will make sure that we create modules even for "empty" prefixes # because packages are always loaded hierarchically so if we have # something in `a.b.c` but nothing directly in `a`, we still need to @@ -235,7 +203,7 @@ def load_module(self, fullname): self._handled_modules = {} for prefix in prefixes: self._handled_modules[prefix] = _WrappedModule( - self, prefix, exports, formed_exception_classes, self._client + self, prefix, exports, self._client ) canonical_fullname = get_canonical_name(fullname, self._aliases) # Modules are created canonically but we need to return something for any diff --git a/metaflow/plugins/env_escape/configurations/emulate_test_lib/overrides.py b/metaflow/plugins/env_escape/configurations/emulate_test_lib/overrides.py index 7418644e8f3..092455660f5 100644 --- a/metaflow/plugins/env_escape/configurations/emulate_test_lib/overrides.py +++ b/metaflow/plugins/env_escape/configurations/emulate_test_lib/overrides.py @@ -5,56 +5,44 @@ remote_override, remote_getattr_override, remote_setattr_override, - local_exception, + local_exception_deserialize, remote_exception_serialize, ) @local_override({"test_lib.TestClass1": "print_value"}) def local_print_value(stub, func): - print("Encoding before sending to server") v = func() - print("Adding 5") return v + 5 @remote_override({"test_lib.TestClass1": "print_value"}) def remote_print_value(obj, func): - print("Decoding from client") v = func() - print("Encoding for client") - return v + return v + 3 
@local_getattr_override({"test_lib.TestClass1": "override_value"}) def local_get_value2(stub, name, func): - print("In local getattr override for %s" % name) r = func() - print("In local getattr override, got %s" % r) - return r - - -@local_setattr_override({"test_lib.TestClass1": "override_value"}) -def local_set_value2(stub, name, func, v): - print("In local setattr override for %s" % name) - r = func(v) - print("In local setattr override, got %s" % r) - return r + return r + 5 @remote_getattr_override({"test_lib.TestClass1": "override_value"}) def remote_get_value2(obj, name): - print("In remote getattr override for %s" % name) r = getattr(obj, name) - print("In remote getattr override, got %s" % r) + return r + 3 + + +@local_setattr_override({"test_lib.TestClass1": "override_value"}) +def local_set_value2(stub, name, func, v): + r = func(v + 5) return r @remote_setattr_override({"test_lib.TestClass1": "override_value"}) def remote_set_value2(obj, name, v): - print("In remote setattr override for %s" % name) - r = setattr(obj, name, v) - print("In remote setattr override, got %s" % r) + r = setattr(obj, name, v + 3) return r @@ -63,28 +51,9 @@ def unsupported_method(stub, func, *args, **kwargs): return NotImplementedError("Just because") -@local_override({"test_lib.package.TestClass3": "thirdfunction"}) -def iamthelocalthird(stub, func, val): - print("Locally the Third") - v = func(val) - return v - - -@remote_override({"test_lib.package.TestClass3": "thirdfunction"}) -def iamtheremotethird(obj, func, val): - print("Remotely the Third") - v = func(val) - return v - - -@local_exception("test_lib.SomeException") -class SomeException: - def __str__(self): - parent_val = super(self.__realclass__, self).__str__() - return parent_val + " In SomeException str override: %s" % self.user_value - - def _deserialize_user(self, json_obj): - self.user_value = json_obj +@local_exception_deserialize("test_lib.SomeException") +def deserialize_user(ex, json_obj): + ex.user_value 
= json_obj @remote_exception_serialize("test_lib.SomeException") diff --git a/metaflow/plugins/env_escape/configurations/emulate_test_lib/server_mappings.py b/metaflow/plugins/env_escape/configurations/emulate_test_lib/server_mappings.py index 63de6f41369..847c263954b 100644 --- a/metaflow/plugins/env_escape/configurations/emulate_test_lib/server_mappings.py +++ b/metaflow/plugins/env_escape/configurations/emulate_test_lib/server_mappings.py @@ -10,17 +10,18 @@ import test_lib as lib EXPORTED_CLASSES = { - "test_lib": { + ("test_lib", "test_lib.alias"): { "TestClass1": lib.TestClass1, "TestClass2": lib.TestClass2, - "package.TestClass3": lib.TestClass3, + "ExceptionAndClass": lib.ExceptionAndClass, } } EXPORTED_EXCEPTIONS = { - "test_lib": { + ("test_lib", "test_lib.alias"): { "SomeException": lib.SomeException, "MyBaseException": lib.MyBaseException, + "ExceptionAndClass": lib.ExceptionAndClass, } } diff --git a/metaflow/plugins/env_escape/configurations/test_lib_impl/test_lib.py b/metaflow/plugins/env_escape/configurations/test_lib_impl/test_lib.py index 420d148bc46..ead203bbcbe 100644 --- a/metaflow/plugins/env_escape/configurations/test_lib_impl/test_lib.py +++ b/metaflow/plugins/env_escape/configurations/test_lib_impl/test_lib.py @@ -9,8 +9,18 @@ class SomeException(MyBaseException): pass -class TestClass1(object): +class ExceptionAndClass(MyBaseException): + def __init__(self, *args): + super().__init__(*args) + + def method_on_exception(self): + return "method_on_exception" + + def __str__(self): + return "ExceptionAndClass Str: %s" % super().__str__() + +class TestClass1(object): cls_object = 25 def __init__(self, value): @@ -41,11 +51,11 @@ def to_class2(self, count, stride=1): return TestClass2(self._value, stride, count) @staticmethod - def somethingstatic(val): + def static_method(val): return val + 42 @classmethod - def somethingclass(cls): + def class_method(cls): return cls.cls_object @property @@ -56,13 +66,34 @@ def override_value(self): def 
override_value(self, value): self._value2 = value + def __hidden(self, name, value): + setattr(self, name, value) + + def weird_indirection(self, name): + return functools.partial(self.__hidden, name) + + def raiseOrReturnValueError(self, doRaise=False): + if doRaise: + raise ValueError("I raised") + return ValueError("I returned") + + def raiseOrReturnSomeException(self, doRaise=False): + if doRaise: + raise SomeException("I raised") + return SomeException("I returned") + + def raiseOrReturnExceptionAndClass(self, doRaise=False): + if doRaise: + raise ExceptionAndClass("I raised") + return ExceptionAndClass("I returned") + class TestClass2(object): def __init__(self, value, stride, count): self._mylist = [value + stride * i for i in range(count)] def something(self, val): - return "In Test2 with %s" % val + return "Test2:Something:%s" % val def __iter__(self): self._pos = 0 @@ -75,24 +106,6 @@ def __next__(self): raise StopIteration -class TestClass3(object): - def __init__(self): - print("I am Class3") - - def thirdfunction(self, val): - print("Got value: %s" % val) - # raise AttributeError("Some weird error") - - def raiseSomething(self): - raise SomeException("Something went wrong") - - def __hidden(self, name, value): - setattr(self, name, value) - - def weird_indirection(self, name): - return functools.partial(self.__hidden, name) - - def test_func(*args, **kwargs): return "In test func" diff --git a/metaflow/plugins/env_escape/exception_transferer.py b/metaflow/plugins/env_escape/exception_transferer.py index c504fd13c7e..97479e9fe62 100644 --- a/metaflow/plugins/env_escape/exception_transferer.py +++ b/metaflow/plugins/env_escape/exception_transferer.py @@ -4,11 +4,9 @@ try: # Import from client from .data_transferer import DataTransferer - from .stub import Stub except ImportError: # Import from server from data_transferer import DataTransferer - from stub import Stub # This file is heavily inspired from the RPYC project @@ -39,7 +37,6 @@ FIELD_EXC_MODULE 
= "m" FIELD_EXC_NAME = "n" FIELD_EXC_ARGS = "arg" -FIELD_EXC_ATTR = "atr" FIELD_EXC_TB = "tb" FIELD_EXC_USER = "u" FIELD_EXC_SI = "si" @@ -54,7 +51,6 @@ def dump_exception(data_transferer, exception_type, exception_val, tb, user_data traceback.format_exception(exception_type, exception_val, tb) ) exception_args = [] - exception_attrs = [] str_repr = None repr_repr = None for name in dir(exception_val): @@ -72,20 +68,10 @@ def dump_exception(data_transferer, exception_type, exception_val, tb, user_data repr_repr = repr(exception_val) elif name.startswith("_") or name == "with_traceback": continue - else: - try: - attr = getattr(exception_val, name) - except AttributeError: - continue - if DataTransferer.can_simple_dump(attr): - exception_attrs.append((name, attr)) - else: - exception_attrs.append((name, repr(attr))) to_return = { FIELD_EXC_MODULE: exception_type.__module__, FIELD_EXC_NAME: exception_type.__name__, FIELD_EXC_ARGS: exception_args, - FIELD_EXC_ATTR: exception_attrs, FIELD_EXC_TB: local_formatted_exception, FIELD_EXC_STR: str_repr, FIELD_EXC_REPR: repr_repr, @@ -98,121 +84,62 @@ def dump_exception(data_transferer, exception_type, exception_val, tb, user_data return data_transferer.dump(to_return) -def load_exception(data_transferer, json_obj): - json_obj = data_transferer.load(json_obj) +def load_exception(client, json_obj): + json_obj = client.decode(json_obj) + if json_obj.get(FIELD_EXC_SI) is not None: return StopIteration exception_module = json_obj.get(FIELD_EXC_MODULE) exception_name = json_obj.get(FIELD_EXC_NAME) exception_class = None - if exception_module not in sys.modules: - # Try to import the module - try: - # Use __import__ so that the user can access this exception - __import__(exception_module, None, None, "*") - except Exception: - pass - # Try again (will succeed if the __import__ worked) - if exception_module in sys.modules: - exception_class = getattr(sys.modules[exception_module], exception_name, None) - if exception_class is None 
or issubclass(exception_class, Stub): - # Best effort to "recreate" an exception. Note that in some cases, exceptions - # may actually be both exceptions we can transfer as well as classes we - # can transfer (stubs) but for exceptions, we don't want to use the stub - # otherwise it will just ping pong. - name = "%s.%s" % (exception_module, exception_name) - exception_class = _remote_exceptions_class.setdefault( - name, - type( - name, - (RemoteInterpreterException,), - {"__module__": "%s/%s" % (__name__, exception_module)}, - ), - ) - exception_class = _wrap_exception(exception_class) - raised_exception = exception_class.__new__(exception_class) - raised_exception.args = json_obj.get(FIELD_EXC_ARGS) - for name, attr in json_obj.get(FIELD_EXC_ATTR): - try: - if name in raised_exception.__user_defined__: - setattr(raised_exception, "_original_%s" % name, attr) - else: - setattr(raised_exception, name, attr) - except AttributeError: - # In case some things are read only - pass - s = json_obj.get(FIELD_EXC_STR) - if s: - try: - if "__str__" in raised_exception.__user_defined__: - setattr(raised_exception, "_original___str__", s) - else: - setattr(raised_exception, "__str__", lambda x, s=s: s) - except AttributeError: - raised_exception._missing_str = True - s = json_obj.get(FIELD_EXC_REPR) - if s: - try: - if "__repr__" in raised_exception.__user_defined__: - setattr(raised_exception, "_original___repr__", s) - else: - setattr(raised_exception, "__repr__", lambda x, s=s: s) - except AttributeError: - raised_exception._missing_repr = True + full_name = "%s.%s" % (exception_module, exception_name) + + exception_class = client.get_local_class(full_name, is_returned_exception=True) + + raised_exception = exception_class(*json_obj.get(FIELD_EXC_ARGS)) + raised_exception._exception_str = json_obj.get(FIELD_EXC_STR, None) + raised_exception._exception_repr = json_obj.get(FIELD_EXC_REPR, None) + raised_exception._exception_tb = json_obj.get(FIELD_EXC_TB, None) + user_args = 
json_obj.get(FIELD_EXC_USER) if user_args is not None: - try: - deserializer = getattr(raised_exception, "_deserialize_user") - except AttributeError: - raised_exception._missing_deserializer = True - else: - deserializer(user_args) - raised_exception._remote_tb = json_obj[FIELD_EXC_TB] + deserializer = client.get_exception_deserializer(full_name) + if deserializer is not None: + deserializer(raised_exception, user_args) return raised_exception -def _wrap_exception(exception_class): - to_return = _derived_exceptions.get(exception_class) - if to_return is not None: - return to_return - - class WithPrettyPrinting(exception_class): - def __str__(self): - try: - text = super(WithPrettyPrinting, self).__str__() - except: # noqa E722 - text = "" - # if getattr(self, "_missing_deserializer", False): - # text += ( - # "\n\n===== WARNING: User data from the exception was not deserialized " - # "-- possible missing information =====\n" - # ) - # if getattr(self, "_missing_str", False): - # text += "\n\n===== WARNING: Could not set class specific __str__ " - # "-- possible missing information =====\n" - # if getattr(self, "_missing_repr", False): - # text += "\n\n===== WARNING: Could not set class specific __repr__ " - # "-- possible missing information =====\n" - remote_tb = getattr(self, "_remote_tb", "No remote traceback available") +class ExceptionMetaClass(type): + def __init__(cls, class_name, base_classes, class_dict): + super(ExceptionMetaClass, cls).__init__(class_name, base_classes, class_dict) + cls.__orig_str__ = cls.__str__ + cls.__orig_repr__ = cls.__repr__ + for n in ("_exception_str", "_exception_repr", "_exception_tb"): + setattr( + cls, + n, + property( + lambda self, n=n: getattr(self, "%s_val" % n, ""), + lambda self, v, n=n: setattr(self, "%s_val" % n, v), + ), + ) + + def _do_str(self): + text = self._exception_str text += "\n\n===== Remote (on server) traceback =====\n" - text += remote_tb + text += self._exception_tb text += 
"========================================\n" return text - WithPrettyPrinting.__name__ = exception_class.__name__ - WithPrettyPrinting.__module__ = exception_class.__module__ - WithPrettyPrinting.__realclass__ = exception_class - _derived_exceptions[exception_class] = WithPrettyPrinting - return WithPrettyPrinting + cls.__str__ = _do_str + cls.__repr__ = lambda self: self._exception_repr class RemoteInterpreterException(Exception): - """A 'generic exception' that is raised when the exception the gotten from - the remote server cannot be instantiated locally""" + """ + A 'generic' exception that was raised on the server side for which we have no + equivalent exception on this side + """ pass - - -_remote_exceptions_class = {} # Exception name -> type of that exception -_derived_exceptions = {} diff --git a/metaflow/plugins/env_escape/override_decorators.py b/metaflow/plugins/env_escape/override_decorators.py index cb9fd9cc099..095e1eb1786 100644 --- a/metaflow/plugins/env_escape/override_decorators.py +++ b/metaflow/plugins/env_escape/override_decorators.py @@ -110,18 +110,18 @@ def _wrapped(func): return _wrapped -class LocalException(object): - def __init__(self, class_path, wrapped_class): +class LocalExceptionDeserializer(object): + def __init__(self, class_path, deserializer): self._class_path = class_path - self._class = wrapped_class + self._deserializer = deserializer @property def class_path(self): return self._class_path @property - def wrapped_class(self): - return self._class + def deserializer(self): + return self._deserializer class RemoteExceptionSerializer(object): @@ -138,9 +138,9 @@ def serializer(self): return self._serializer -def local_exception(class_path): - def _wrapped(cls): - return LocalException(class_path, cls) +def local_exception_deserialize(class_path): + def _wrapped(func): + return LocalExceptionDeserializer(class_path, func) return _wrapped diff --git a/metaflow/plugins/env_escape/server.py b/metaflow/plugins/env_escape/server.py 
index a78e870c883..e7cb64f341d 100644 --- a/metaflow/plugins/env_escape/server.py +++ b/metaflow/plugins/env_escape/server.py @@ -113,9 +113,21 @@ def __init__(self, config_dir, max_pickle_version): # this by listing aliases in the same order so we don't support # it for now. raise ValueError( - "%s is an alias to both %s and %s" % (alias, base_name, a) + "%s is an alias to both %s and %s -- make sure all aliases " + "are listed in the same order" % (alias, base_name, a) ) + # Detect circular aliases. If a user lists ("a", "b") and then ("b", "a"), we + # will have an entry in aliases saying b is an alias for a and a is an alias + # for b which is a recipe for disaster since we no longer have a canonical name + # for things. + for alias, base_name in self._aliases.items(): + if base_name in self._aliases: + raise ValueError( + "%s and %s are circular aliases -- make sure all aliases " + "are listed in the same order" % (alias, base_name) + ) + # Determine if we have any overrides self._overrides = {} self._getattr_overrides = {} @@ -124,8 +136,9 @@ def __init__(self, config_dir, max_pickle_version): for override in override_values: if isinstance(override, (RemoteAttrOverride, RemoteOverride)): for obj_name, obj_funcs in override.obj_mapping.items(): + canonical_name = get_canonical_name(obj_name, self._aliases) obj_type = self._known_classes.get( - obj_name, self._proxied_types.get(obj_name) + canonical_name, self._proxied_types.get(obj_name) ) if obj_type is None: raise ValueError( @@ -146,11 +159,17 @@ def __init__(self, config_dir, max_pickle_version): ) override_dict[name] = override.func elif isinstance(override, RemoteExceptionSerializer): + canonical_name = get_canonical_name(override.class_path, self._aliases) + if canonical_name not in self._known_exceptions: + raise ValueError( + "%s does not refer to an exported exception" + % override.class_path + ) if override.class_path in self._exception_serializers: raise ValueError( "%s exception serializer already 
defined" % override.class_path ) - self._exception_serializers[override.class_path] = override.serializer + self._exception_serializers[canonical_name] = override.serializer # Process the exceptions making sure we have all the ones we need and building a # topologically sorted list for the client to instantiate @@ -181,8 +200,8 @@ def __init__(self, config_dir, max_pickle_version): else: raise ValueError( "Exported exception %s has non exported and non builtin parent " - "exception: %s. Known exceptions: %s" - % (ex_name, fqn, str(self._known_exceptions)) + "exception: %s (%s). Known exceptions: %s." + % (ex_name, fqn, canonical_fqn, str(self._known_exceptions)) ) name_to_parent_count[ex_name_canonical] = len(parents) - 1 name_to_parents[ex_name_canonical] = parents diff --git a/metaflow/plugins/env_escape/stub.py b/metaflow/plugins/env_escape/stub.py index 1e341c52d4a..1ed322a8b5a 100644 --- a/metaflow/plugins/env_escape/stub.py +++ b/metaflow/plugins/env_escape/stub.py @@ -17,6 +17,8 @@ OP_INIT, ) +from .exception_transferer import ExceptionMetaClass + DELETED_ATTRS = frozenset(["__array_struct__", "__array_interface__"]) # These attributes are accessed directly on the stub (not directly forwarded) @@ -26,7 +28,10 @@ "___remote_class_name___", "___identifier___", "___connection___", - "___local_overrides___" "__class__", + "___local_overrides___", + "___is_returned_exception___", + "___exception_attributes___", + "__class__", "__init__", "__del__", "__delattr__", @@ -66,8 +71,6 @@ def fwd_request(stub, request_type, *args, **kwargs): class StubMetaClass(type): - __slots__ = () - def __repr__(self): if self.__module__: return "" % (self.__module__, self.__name__) @@ -94,24 +97,25 @@ class Stub(with_metaclass(StubMetaClass, object)): happen on the remote side (server). """ - __slots__ = [ - "___remote_class_name___", - "___identifier___", - "___connection___", - "__weakref__", - ] - + __slots__ = () # def __iter__(self): # FIXME: Keep debugger QUIET!! 
# raise AttributeError - def __init__(self, connection, remote_class_name, identifier): + def __init__( + self, connection, remote_class_name, identifier, is_returned_exception=False + ): self.___remote_class_name___ = remote_class_name self.___identifier___ = identifier self.___connection___ = connection + # If it is a returned exception (ie: it was raised by the server), it behaves + # a bit differently for methods like __str__ and __repr__ (we try not to get + # stuff from the server) + self.___is_returned_exception___ = is_returned_exception def __del__(self): try: - fwd_request(self, OP_DEL) + if not self.___is_returned_exception___: + fwd_request(self, OP_DEL) except Exception: # raised in a destructor, most likely on program termination, # when the connection might have already been closed. @@ -144,10 +148,16 @@ def __delattr__(self, name): if name in LOCAL_ATTRS: object.__delattr__(self, name) else: + if self.___is_returned_exception___: + raise AttributeError() return fwd_request(self, OP_DELATTR, name) def __setattr__(self, name, value): - if name in LOCAL_ATTRS or name in self.___local_overrides___: + if ( + name in LOCAL_ATTRS + or name in self.___local_overrides___ + or self.___is_returned_exception___ + ): object.__setattr__(self, name, value) else: fwd_request(self, OP_SETATTR, name, value) @@ -159,9 +169,13 @@ def __hash__(self): return fwd_request(self, OP_HASH) def __repr__(self): + if self.___is_returned_exception___: + return self.__exception_repr__() return fwd_request(self, OP_REPR) def __str__(self): + if self.___is_returned_exception___: + return self.__exception_str__() return fwd_request(self, OP_STR) def __exit__(self, exc, typ, tb): @@ -249,6 +263,33 @@ def __call__(cls, *args, **kwargs): ) +class MetaExceptionWithConnection(StubMetaClass, ExceptionMetaClass): + def __new__(cls, class_name, base_classes, class_dict, connection): + return type.__new__(cls, class_name, base_classes, class_dict) + + def __init__(cls, class_name, 
base_classes, class_dict, connection): + cls.___class_remote_class_name___ = class_name + cls.___class_connection___ = connection + + # We call the one on ExceptionMetaClass which does everything needed (StubMetaClass + # does not do anything special for init) + ExceptionMetaClass.__init__(cls, class_name, base_classes, class_dict) + + # Restore __str__ and __repr__ to the original ones because we need to determine + # if we call them depending on whether or not the object is a returned exception + # or not + cls.__str__ = cls.__orig_str__ + cls.__repr__ = cls.__orig_repr__ + + def __call__(cls, *args, **kwargs): + if len(args) > 0 and id(args[0]) == id(cls.___class_connection___): + return super(MetaExceptionWithConnection, cls).__call__(*args, **kwargs) + else: + return cls.___class_connection___.stub_request( + None, OP_INIT, cls.___class_remote_class_name___, *args, **kwargs + ) + + def create_class( connection, class_name, @@ -256,8 +297,16 @@ def create_class( getattr_overrides, setattr_overrides, class_methods, + parents, ): - class_dict = {"__slots__": ()} + class_dict = { + "__slots__": [ + "___remote_class_name___", + "___identifier___", + "___connection___", + "___is_returned_exception___", + ] + } for name, doc in class_methods.items(): method_type = NORMAL_METHOD if name.startswith("___s___"): @@ -318,5 +367,33 @@ def create_class( ) overriden_attrs.add(attr) class_dict[attr] = property(getter, setter) + if parents: + # This means this is also an exception so we add a few more things to it + # so that it + # This is copied from ExceptionMetaClass in exception_transferer.py + for n in ("_exception_str", "_exception_repr", "_exception_tb"): + class_dict[n] = property( + lambda self, n=n: getattr(self, "%s_val" % n, ""), + lambda self, v, n=n: setattr(self, "%s_val" % n, v), + ) + + def _do_str(self): + text = self._exception_str + text += "\n\n===== Remote (on server) traceback =====\n" + text += self._exception_tb + text += 
"========================================\n" + return text + + class_dict["__exception_str__"] = _do_str + class_dict["__exception_repr__"] = lambda self: self._exception_repr + else: + # If we are based on an exception, we already have __weakref__ so we don't add + # it but not the case if we are not. + class_dict["__slots__"].append("__weakref__") + class_dict["___local_overrides___"] = overriden_attrs + if parents: + return MetaExceptionWithConnection( + class_name, (Stub, *parents), class_dict, connection + ) return MetaWithConnection(class_name, (Stub,), class_dict, connection) diff --git a/metaflow/plugins/env_escape/utils.py b/metaflow/plugins/env_escape/utils.py index f9fe1aacf0f..c0e00c3a0f6 100644 --- a/metaflow/plugins/env_escape/utils.py +++ b/metaflow/plugins/env_escape/utils.py @@ -13,12 +13,12 @@ def get_methods(class_object): for base_class in mros: all_attributes.update(base_class.__dict__) for name, attribute in all_attributes.items(): - if hasattr(attribute, "__call__"): - all_methods[name] = inspect.getdoc(attribute) - elif isinstance(attribute, staticmethod): + if isinstance(attribute, staticmethod): all_methods["___s___%s" % name] = inspect.getdoc(attribute) elif isinstance(attribute, classmethod): all_methods["___c___%s" % name] = inspect.getdoc(attribute) + elif hasattr(attribute, "__call__"): + all_methods[name] = inspect.getdoc(attribute) return all_methods diff --git a/test/env_escape/example.py b/test/env_escape/example.py index bef9d0871f7..bfdf7279a60 100644 --- a/test/env_escape/example.py +++ b/test/env_escape/example.py @@ -28,41 +28,47 @@ def run_test(through_escape=False): import test_lib as test + print("-- Test aliasing --") if through_escape: # This tests package aliasing - from test_lib.package import TestClass3 - else: - from test_lib import TestClass3 + from test_lib.alias import TestClass1 - o1 = test.TestClass1(123) - print("-- Test print_value --") + o1 = test.TestClass1(10) + print("-- Test normal method with overrides 
--") if through_escape: - # The server_mapping should add 5 here - assert o1.print_value() == 128 + expected_value = 10 + 8 else: - assert o1.print_value() == 123 - print("-- Test property --") - assert o1.value == 123 - print("-- Test value override (get) --") - assert o1.override_value == 123 - print("-- Test value override (set) --") - o1.override_value = 456 - assert o1.override_value == 456 + expected_value = 10 + assert o1.print_value() == expected_value - print("-- Test static method --") - assert test.TestClass1.somethingstatic(5) == 47 - assert o1.somethingstatic(5) == 47 - print("-- Test class method --") - assert test.TestClass1.somethingclass() == 25 - assert o1.somethingclass() == 25 + print("-- Test property (no override) --") + assert o1.value == 10 + o1.value = 15 + assert o1.value == 15 + if through_escape: + expected_value = 15 + 8 + else: + expected_value = 15 + assert o1.print_value() == expected_value - print("-- Test set and get --") - o1.value = 2 + print("-- Test property (with override) --") if through_escape: - # The server_mapping should add 5 here - assert o1.print_value() == 7 + expected_value = 123 + 8 + expected_value2 = 200 + 16 else: - assert o1.print_value() == 2 + expected_value = 123 + expected_value2 = 200 + assert o1.override_value == expected_value + o1.override_value = 200 + assert o1.override_value == expected_value2 + + print("-- Test static method --") + assert test.TestClass1.static_method(5) == 47 + assert o1.static_method(5) == 47 + + print("-- Test class method --") + assert test.TestClass1.class_method() == 25 + assert o1.class_method() == 25 print("-- Test function --") assert test.test_func() == "In test func" @@ -74,21 +80,65 @@ def run_test(through_escape=False): print("-- Test chaining of exported classes --") o2 = o1.to_class2(5) - assert o2.something("foo") == "In Test2 with foo" + assert o2.something("foo") == "Test2:Something:foo" + print("-- Test Iterating --") - for i in o2: - print("Got %d" % i) + for idx, 
i in enumerate(o2): + assert idx == i - 15 + assert i == 19 + + print("-- Test weird indirection --") + o1.weird_indirection("foo")(10) + assert o1.foo == 10 + o1.weird_indirection("_value")(20) + assert o1.value == 20 - print("-- Test exception --") - o3 = TestClass3() + print("-- Test exceptions --") + + # Non proxied exceptions can't be returned as objects try: - o3.raiseSomething() - except test.SomeException as e: - print("Caught the local exception: %s" % str(e)) + vexc = o1.raiseOrReturnValueError() + assert not through_escape, "Should have raised through escape" + assert isinstance(vexc, ValueError) + except RuntimeError as e: + assert ( + through_escape + and "Cannot proxy value of type " in str(e) + ) + + try: + excclass = o1.raiseOrReturnSomeException() + assert not through_escape, "Should have raised through escape" + assert isinstance(excclass, test.SomeException) + except RuntimeError as e: + assert ( + through_escape + and "Cannot proxy value of type " in str(e) + ) - print("-- Test returning proxied object --") - o3.weird_indirection("foo")(10) - assert o3.foo == 10 + exception_and_class = o1.raiseOrReturnExceptionAndClass() + assert isinstance(exception_and_class, test.ExceptionAndClass) + assert exception_and_class.method_on_exception() == "method_on_exception" + assert str(exception_and_class).startswith("ExceptionAndClass Str:") + + try: + o1.raiseOrReturnValueError(True) + assert False, "Should have raised" + except ValueError as e: + assert True + except Exception as e: + assert False, "Should have been ValueError" + + try: + o1.raiseOrReturnSomeException(True) + assert False, "Should have raised" + except test.SomeException as e: + assert True + if through_escape: + assert e.user_value == 42 + assert "Remote (on server) traceback" in str(e) + except Exception as e: + assert False, "Should have been SomeException" class EscapeTest(FlowSpec): From 79bb1d455561772b590e25c942f90e81cfee89d0 Mon Sep 17 00:00:00 2001 From: Romain Cledat Date: Fri, 
16 Feb 2024 11:34:25 -0800 Subject: [PATCH 2/5] Better support for subclasses and fix issues with >1 exception depth --- metaflow/plugins/env_escape/client.py | 69 +++++++++++++---- .../emulate_test_lib/overrides.py | 22 +++++- .../emulate_test_lib/server_mappings.py | 4 + .../configurations/test_lib_impl/test_lib.py | 39 ++++++++++ metaflow/plugins/env_escape/consts.py | 1 + .../env_escape/exception_transferer.py | 9 ++- metaflow/plugins/env_escape/server.py | 18 +++++ metaflow/plugins/env_escape/stub.py | 70 ++++++++++++++++- test/env_escape/example.py | 77 ++++++++++++++++++- 9 files changed, 289 insertions(+), 20 deletions(-) diff --git a/metaflow/plugins/env_escape/client.py b/metaflow/plugins/env_escape/client.py index a5b23f892ad..65637ecbf72 100644 --- a/metaflow/plugins/env_escape/client.py +++ b/metaflow/plugins/env_escape/client.py @@ -357,7 +357,9 @@ def decode(self, json_obj): # this connection will be converted to a local stub. return self._datatransferer.load(json_obj) - def get_local_class(self, name, obj_id=None, is_returned_exception=False): + def get_local_class( + self, name, obj_id=None, is_returned_exception=False, is_parent=False + ): # Gets (and creates if needed), the class mapping to the remote # class of name 'name'. @@ -367,6 +369,15 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False): # - classes that are proxied regular classes AND NOT proxied exceptions # - clases that are NOT proxied regular classes AND are proxied exceptions name = get_canonical_name(name, self._aliases) + + def name_to_parent_name(name): + return "parent:%s" % name + + if is_parent: + lookup_name = name_to_parent_name(name) + else: + lookup_name = name + if name == "function": # Special handling of pickled functions. 
We create a new class that # simply has a __call__ method that will forward things back to @@ -378,7 +389,7 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False): self, "__function_%s" % obj_id, {}, {}, {}, {"__call__": ""}, [] ) return self._proxied_standalone_functions[obj_id] - local_class = self._proxied_classes.get(name, None) + local_class = self._proxied_classes.get(lookup_name, None) if local_class is not None: return local_class @@ -386,7 +397,7 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False): is_proxied_non_exception = name in self._proxied_classnames if not is_proxied_exception and not is_proxied_non_exception: - if is_returned_exception: + if is_returned_exception or is_parent: # In this case, it may be a local exception that we need to # recreate try: @@ -405,7 +416,7 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False): dict(getattr(local_exception, "__dict__", {})), ) wrapped_exception.__module__ = ex_module - self._proxied_classes[name] = wrapped_exception + self._proxied_classes[lookup_name] = wrapped_exception return wrapped_exception raise ValueError("Class '%s' is not known" % name) @@ -422,20 +433,34 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False): for parent in ex_parents: # We always consider it to be an exception so that we wrap even non # proxied builtins exceptions - parents.append(self.get_local_class(parent, is_returned_exception=True)) + parents.append(self.get_local_class(parent, is_parent=True)) # For regular classes, we get what it exposes from the server if is_proxied_non_exception: remote_methods = self.stub_request(None, OP_GETMETHODS, name) else: remote_methods = {} - if is_proxied_exception and not is_proxied_non_exception: - # This is a pure exception + parent_local_class = None + local_class = None + if is_proxied_exception: + # If we are a proxied exception AND a proxied class, we create two classes: + # actually: + # - 
the class itself (which is a stub) + # - the class in the capacity of a parent class (to another exception + # presumably). The reason for this is that if we have an exception/proxied + # class A and another B and B inherits from A, the MRO order would be all + # wrong since both A and B would also inherit from `Stub`. Here what we + # do is: + # - A_parent inherits from the actual parents of A (let's assume a + # builtin exception) + # - A inherits from (Stub, A_parent) + # - B_parent inherits from A_parent and the builtin Exception + # - B inherits from (Stub, B_parent) ex_module, ex_name = name.rsplit(".", 1) - local_class = ExceptionMetaClass(ex_name, (*parents,), {}) - local_class.__module__ = ex_module - else: - # This method creates either a pure stub or a stub that is also an exception + parent_local_class = ExceptionMetaClass(ex_name, (*parents,), {}) + parent_local_class.__module__ = ex_module + + if is_proxied_non_exception: local_class = create_class( self, name, @@ -443,10 +468,26 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False): self._getattr_overrides.get(name, {}), self._setattr_overrides.get(name, {}), remote_methods, - parents, + (parent_local_class,) if parent_local_class else None, ) - self._proxied_classes[name] = local_class - return local_class + if parent_local_class: + self._proxied_classes[name_to_parent_name(name)] = parent_local_class + if local_class: + self._proxied_classes[name] = local_class + else: + # This is for the case of pure proxied exceptions -- we want the lookup of + # foo.MyException to be the same class as looking up foo.MyException as a parent + # of another exception so `isinstance` works properly + self._proxied_classes[name] = parent_local_class + + if is_parent: + # This should never happen but making sure + if not parent_local_class: + raise RuntimeError( + "Exception parent class %s is not a proxied exception" % name + ) + return parent_local_class + return self._proxied_classes[name]
def can_pickle(self, obj): return getattr(obj, "___connection___", None) == self diff --git a/metaflow/plugins/env_escape/configurations/emulate_test_lib/overrides.py b/metaflow/plugins/env_escape/configurations/emulate_test_lib/overrides.py index 092455660f5..120b32e31df 100644 --- a/metaflow/plugins/env_escape/configurations/emulate_test_lib/overrides.py +++ b/metaflow/plugins/env_escape/configurations/emulate_test_lib/overrides.py @@ -52,10 +52,30 @@ def unsupported_method(stub, func, *args, **kwargs): @local_exception_deserialize("test_lib.SomeException") -def deserialize_user(ex, json_obj): +def some_exception_deserialize(ex, json_obj): ex.user_value = json_obj @remote_exception_serialize("test_lib.SomeException") def some_exception_serialize(ex): return 42 + + +@local_exception_deserialize("test_lib.ExceptionAndClass") +def exception_and_class_deserialize(ex, json_obj): + ex.user_value = json_obj + + +@remote_exception_serialize("test_lib.ExceptionAndClass") +def exception_and_class_serialize(ex): + return 43 + + +@local_exception_deserialize("test_lib.ExceptionAndClassChild") +def exception_and_class_child_deserialize(ex, json_obj): + ex.user_value = json_obj + + +@remote_exception_serialize("test_lib.ExceptionAndClassChild") +def exception_and_class_child_serialize(ex): + return 44 diff --git a/metaflow/plugins/env_escape/configurations/emulate_test_lib/server_mappings.py b/metaflow/plugins/env_escape/configurations/emulate_test_lib/server_mappings.py index 847c263954b..b849a7f9146 100644 --- a/metaflow/plugins/env_escape/configurations/emulate_test_lib/server_mappings.py +++ b/metaflow/plugins/env_escape/configurations/emulate_test_lib/server_mappings.py @@ -13,7 +13,10 @@ ("test_lib", "test_lib.alias"): { "TestClass1": lib.TestClass1, "TestClass2": lib.TestClass2, + "BaseClass": lib.BaseClass, + "ChildClass": lib.ChildClass, "ExceptionAndClass": lib.ExceptionAndClass, + "ExceptionAndClassChild": lib.ExceptionAndClassChild, } } @@ -22,6 +25,7 @@ 
"SomeException": lib.SomeException, "MyBaseException": lib.MyBaseException, "ExceptionAndClass": lib.ExceptionAndClass, + "ExceptionAndClassChild": lib.ExceptionAndClassChild, } } diff --git a/metaflow/plugins/env_escape/configurations/test_lib_impl/test_lib.py b/metaflow/plugins/env_escape/configurations/test_lib_impl/test_lib.py index ead203bbcbe..b15256b4011 100644 --- a/metaflow/plugins/env_escape/configurations/test_lib_impl/test_lib.py +++ b/metaflow/plugins/env_escape/configurations/test_lib_impl/test_lib.py @@ -1,4 +1,5 @@ import functools +from html.parser import HTMLParser class MyBaseException(Exception): @@ -20,6 +21,36 @@ def __str__(self): return "ExceptionAndClass Str: %s" % super().__str__() +class ExceptionAndClassChild(ExceptionAndClass): + def __init__(self, *args): + super().__init__(*args) + + def method_on_child_exception(self): + return "method_on_child_exception" + + def __str__(self): + return "ExceptionAndClassChild Str: %s" % super().__str__() + + +class BaseClass(HTMLParser): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._output = [] + + def handle_starttag(self, tag, attrs): + self._output.append(tag) + return super().handle_starttag(tag, attrs) + + def get_output(self): + return self._output + + +class ChildClass(BaseClass): + def handle_endtag(self, tag): + self._output.append(tag) + return super().handle_endtag(tag) + + class TestClass1(object): cls_object = 25 @@ -72,6 +103,9 @@ def __hidden(self, name, value): def weird_indirection(self, name): return functools.partial(self.__hidden, name) + def returnChild(self): + return ChildClass() + def raiseOrReturnValueError(self, doRaise=False): if doRaise: raise ValueError("I raised") @@ -87,6 +121,11 @@ def raiseOrReturnExceptionAndClass(self, doRaise=False): raise ExceptionAndClass("I raised") return ExceptionAndClass("I returned") + def raiseOrReturnExceptionAndClassChild(self, doRaise=False): + if doRaise: + raise ExceptionAndClassChild("I raised") + 
return ExceptionAndClassChild("I returned") + class TestClass2(object): def __init__(self, value, stride, count): diff --git a/metaflow/plugins/env_escape/consts.py b/metaflow/plugins/env_escape/consts.py index bc07abc81b4..dd770ac9e7c 100644 --- a/metaflow/plugins/env_escape/consts.py +++ b/metaflow/plugins/env_escape/consts.py @@ -39,6 +39,7 @@ OP_SETVAL = 16 OP_INIT = 17 OP_CALLONCLASS = 18 +OP_SUBCLASSCHECK = 19 # Control messages CONTROL_SHUTDOWN = 1 diff --git a/metaflow/plugins/env_escape/exception_transferer.py b/metaflow/plugins/env_escape/exception_transferer.py index 97479e9fe62..7ea692f7fd5 100644 --- a/metaflow/plugins/env_escape/exception_transferer.py +++ b/metaflow/plugins/env_escape/exception_transferer.py @@ -85,6 +85,8 @@ def dump_exception(data_transferer, exception_type, exception_val, tb, user_data def load_exception(client, json_obj): + from .stub import Stub + json_obj = client.decode(json_obj) if json_obj.get(FIELD_EXC_SI) is not None: @@ -93,11 +95,16 @@ def load_exception(client, json_obj): exception_module = json_obj.get(FIELD_EXC_MODULE) exception_name = json_obj.get(FIELD_EXC_NAME) exception_class = None + # This name is already canonical since we canonicalize it on the server side full_name = "%s.%s" % (exception_module, exception_name) exception_class = client.get_local_class(full_name, is_returned_exception=True) - raised_exception = exception_class(*json_obj.get(FIELD_EXC_ARGS)) + if issubclass(exception_class, Stub): + raised_exception = exception_class(_is_returned_exception=True) + raised_exception.args = tuple(json_obj.get(FIELD_EXC_ARGS)) + else: + raised_exception = exception_class(*json_obj.get(FIELD_EXC_ARGS)) raised_exception._exception_str = json_obj.get(FIELD_EXC_STR, None) raised_exception._exception_repr = json_obj.get(FIELD_EXC_REPR, None) raised_exception._exception_tb = json_obj.get(FIELD_EXC_TB, None) diff --git a/metaflow/plugins/env_escape/server.py b/metaflow/plugins/env_escape/server.py index
e7cb64f341d..ed35e56040f 100644 --- a/metaflow/plugins/env_escape/server.py +++ b/metaflow/plugins/env_escape/server.py @@ -36,6 +36,7 @@ OP_GETVAL, OP_SETVAL, OP_INIT, + OP_SUBCLASSCHECK, VALUE_LOCAL, VALUE_REMOTE, CONTROL_GETEXPORTS, @@ -255,6 +256,7 @@ def __init__(self, config_dir, max_pickle_version): OP_GETVAL: self._handle_getval, OP_SETVAL: self._handle_setval, OP_INIT: self._handle_init, + OP_SUBCLASSCHECK: self._handle_subclasscheck, } self._local_objects = {} @@ -292,6 +294,7 @@ def encode(self, obj): def encode_exception(self, ex_type, ex, trace_back): try: full_name = "%s.%s" % (ex_type.__module__, ex_type.__name__) + get_canonical_name(full_name, self._aliases) serializer = self._exception_serializers.get(full_name) except AttributeError: # Ignore if no __module__ for example -- definitely not something we built @@ -502,6 +505,21 @@ def _handle_init(self, target, class_name, *args, **kwargs): raise ValueError("Unknown class %s" % class_name) return class_type(*args, **kwargs) + def _handle_subclasscheck(self, target, class_name, otherclass_name, reverse=False): + class_type = self._known_classes.get(class_name) + if class_type is None: + raise ValueError("Unknown class %s" % class_name) + try: + sub_module, sub_name = otherclass_name.rsplit(".", 1) + __import__(sub_module, None, None, "*") + except Exception: + sub_module = None + if sub_module is None: + return False + if reverse: + return issubclass(class_type, getattr(sys.modules[sub_module], sub_name)) + return issubclass(getattr(sys.modules[sub_module], sub_name), class_type) + if __name__ == "__main__": max_pickle_version = int(sys.argv[1]) diff --git a/metaflow/plugins/env_escape/stub.py b/metaflow/plugins/env_escape/stub.py index 1ed322a8b5a..12131694eed 100644 --- a/metaflow/plugins/env_escape/stub.py +++ b/metaflow/plugins/env_escape/stub.py @@ -1,5 +1,6 @@ import functools import pickle +from typing import Any from .consts import ( OP_GETATTR, @@ -15,6 +16,7 @@ OP_PICKLE, OP_DIR, OP_INIT, + 
OP_SUBCLASSCHECK, ) from .exception_transferer import ExceptionMetaClass @@ -41,6 +43,7 @@ "__getattribute__", "__hash__", "__instancecheck__", + "__subclasscheck__", "__init__", "__metaclass__", "__module__", @@ -67,7 +70,11 @@ def fwd_request(stub, request_type, *args, **kwargs): connection = object.__getattribute__(stub, "___connection___") - return connection.stub_request(stub, request_type, *args, **kwargs) + if connection: + return connection.stub_request(stub, request_type, *args, **kwargs) + raise RuntimeError( + "Returned exception stub cannot be used to make further remote requests" + ) class StubMetaClass(type): @@ -102,7 +109,7 @@ class Stub(with_metaclass(StubMetaClass, object)): # raise AttributeError def __init__( - self, connection, remote_class_name, identifier, is_returned_exception=False + self, connection, remote_class_name, identifier, _is_returned_exception=False ): self.___remote_class_name___ = remote_class_name self.___identifier___ = identifier @@ -110,7 +117,7 @@ def __init__( # If it is a returned exception (ie: it was raised by the server), it behaves # a bit differently for methods like __str__ and __repr__ (we try not to get # stuff from the server) - self.___is_returned_exception___ = is_returned_exception + self.___is_returned_exception___ = _is_returned_exception def __del__(self): try: @@ -187,6 +194,16 @@ def __reduce_ex__(self, proto): # support for pickling return pickle.loads, (fwd_request(self, OP_PICKLE, proto),) + @classmethod + def __subclasshook__(cls, parent): + if parent.__bases__[0] == Stub: + raise NotImplementedError # Follow the usual mechanism + # If this is not a stub, we go over to the other side + parent_name = "%s.%s" % (parent.__module__, parent.__name__) + return cls.___class_connection___.stub_request( + None, OP_SUBCLASSCHECK, cls.___class_remote_class_name___, parent_name, True + ) + def _make_method(method_type, connection, class_name, name, doc): if name == "__call__": @@ -262,6 +279,24 @@ def 
__call__(cls, *args, **kwargs): None, OP_INIT, cls.___class_remote_class_name___, *args, **kwargs ) + def __subclasscheck__(cls, subclass): + subclass_name = "%s.%s" % (subclass.__module__, subclass.__name__) + if subclass.__bases__[0] == Stub: + subclass_name = subclass.___class_remote_class_name___ + return cls.___class_connection___.stub_request( + None, + OP_SUBCLASSCHECK, + cls.___class_remote_class_name___, + subclass_name, + ) + + def __instancecheck__(cls, instance): + if type(instance) == cls: + # Fast path if it's just an object of this class + return True + # Goes to __subclasscheck__ above + return cls.__subclasscheck__(type(instance)) + class MetaExceptionWithConnection(StubMetaClass, ExceptionMetaClass): def __new__(cls, class_name, base_classes, class_dict, connection): @@ -278,17 +313,46 @@ def __init__(cls, class_name, base_classes, class_dict, connection): # Restore __str__ and __repr__ to the original ones because we need to determine # if we call them depending on whether or not the object is a returned exception # or not + cls.__exception_str__ = cls.__str__ + cls.__exception_repr__ = cls.__repr__ cls.__str__ = cls.__orig_str__ cls.__repr__ = cls.__orig_repr__ def __call__(cls, *args, **kwargs): + # Very similar to the other case but we also need to be able to detect + # local instantiation of an exception so that we can set the __is_returned_exception__ if len(args) > 0 and id(args[0]) == id(cls.___class_connection___): return super(MetaExceptionWithConnection, cls).__call__(*args, **kwargs) + elif kwargs and kwargs.get("_is_returned_exception", False): + return super(MetaExceptionWithConnection, cls).__call__( + None, None, None, _is_returned_exception=True + ) else: return cls.___class_connection___.stub_request( None, OP_INIT, cls.___class_remote_class_name___, *args, **kwargs ) + # The issue is that for a proxied object that is also an exception, we now have + # two classes representing it, one that includes the Stub class and one that 
doesn't + # Concretely: + # - test.MyException would return a class that derives from Stub + # - test.MySubException would return a class that derives from Stub and test.MyException + # but WITHOUT the Stub portion (see get_local_class). + # - we want issubclass(test.MySubException, test.MyException) to return True and + # the same with instance checks. + def __instancecheck__(cls, instance): + return cls.__subclasscheck__(type(instance)) + + def __subclasscheck__(cls, subclass): + # __mro__[0] is this class itself + # __mro__[1] is the stub so we start checking at 2 + return any( + [ + subclass.__mro__[i] in cls.__mro__[2:] + for i in range(2, len(subclass.__mro__)) + ] + ) + def create_class( connection, diff --git a/test/env_escape/example.py b/test/env_escape/example.py index bfdf7279a60..50d6c75bad3 100644 --- a/test/env_escape/example.py +++ b/test/env_escape/example.py @@ -1,6 +1,8 @@ import os import sys +from html.parser import HTMLParser + from metaflow import FlowSpec, step, conda @@ -93,8 +95,45 @@ def run_test(through_escape=False): o1.weird_indirection("_value")(20) assert o1.value == 20 - print("-- Test exceptions --") + print("-- Test subclasses --") + child_obj = test.ChildClass() + child_obj_returned = o1.returnChild() + for o in (child_obj, child_obj_returned): + o.feed("<html>Hello<p>World!</p></html>")
+ assert o.get_output() == ["html", "p", "p", "html"] + + print("-- Test isinstance/issubclass --") + ex_child = test.ExceptionAndClassChild("I am a child") + assert isinstance(ex_child, test.ExceptionAndClassChild) + assert isinstance(ex_child, test.ExceptionAndClass) + assert isinstance(ex_child, Exception) + assert isinstance(ex_child, object) + + assert issubclass(type(ex_child), test.ExceptionAndClass) + assert issubclass(test.ExceptionAndClassChild, test.ExceptionAndClass) + assert issubclass(type(ex_child), Exception) + assert issubclass(test.ExceptionAndClassChild, Exception) + assert issubclass(type(ex_child), object) + assert issubclass(test.ExceptionAndClassChild, object) + + child_obj = test.ChildClass() + child_obj_returned = o1.returnChild() + + # I can't find an easy way (yet) to test support for subclasses based on non + # proxied types. It seems more minor for now so ignoring. + for o in (child_obj, child_obj_returned): + assert isinstance(o, test.ChildClass) + assert isinstance(o, test.BaseClass) + # assert isinstance(o, HTMLParser) + assert isinstance(o, object) + assert issubclass(type(o), test.BaseClass) + # assert issubclass(type(o), HTMLParser) + assert issubclass(type(o), object) + assert issubclass(test.ChildClass, test.BaseClass) + # assert issubclass(test.ChildClass, HTMLParser) + assert issubclass(test.ChildClass, object) + print("-- Test exceptions --") # Non proxied exceptions can't be returned as objects try: vexc = o1.raiseOrReturnValueError() @@ -118,9 +157,23 @@ def run_test(through_escape=False): exception_and_class = o1.raiseOrReturnExceptionAndClass() assert isinstance(exception_and_class, test.ExceptionAndClass) + assert isinstance(exception_and_class, test.MyBaseException) + assert isinstance(exception_and_class, Exception) assert exception_and_class.method_on_exception() == "method_on_exception" assert str(exception_and_class).startswith("ExceptionAndClass Str:") + exception_and_class_child =
o1.raiseOrReturnExceptionAndClassChild() + assert isinstance(exception_and_class_child, test.ExceptionAndClassChild) + assert isinstance(exception_and_class_child, test.ExceptionAndClass) + assert isinstance(exception_and_class_child, test.MyBaseException) + assert isinstance(exception_and_class_child, Exception) + assert exception_and_class_child.method_on_exception() == "method_on_exception" + assert ( + exception_and_class_child.method_on_child_exception() + == "method_on_child_exception" + ) + assert str(exception_and_class_child).startswith("ExceptionAndClassChild Str:") + try: o1.raiseOrReturnValueError(True) assert False, "Should have raised" @@ -140,6 +193,28 @@ def run_test(through_escape=False): except Exception as e: assert False, "Should have been SomeException" + try: + o1.raiseOrReturnExceptionAndClass(True) + assert False, "Should have raised" + except test.ExceptionAndClass as e: + assert True + if through_escape: + assert e.user_value == 43 + assert "Remote (on server) traceback" in str(e) + except Exception as e: + assert False, "Should have been ExceptionAndClass" + + try: + o1.raiseOrReturnExceptionAndClassChild(True) + assert False, "Should have raised" + except test.ExceptionAndClassChild as e: + assert True + if through_escape: + assert e.user_value == 44 + assert "Remote (on server) traceback" in str(e) + except Exception as e: + assert False, "Should have been ExceptionAndClassChild" + class EscapeTest(FlowSpec): @conda(disabled=True) From 674f9862b051dae1ac893b2dd0ca3eca50cc6435 Mon Sep 17 00:00:00 2001 From: Romain Cledat Date: Sat, 17 Feb 2024 00:51:18 -0800 Subject: [PATCH 3/5] Add support for __class__ in escape hatch --- metaflow/plugins/env_escape/stub.py | 22 +++++++++++++--------- test/env_escape/example.py | 6 ++++++ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/metaflow/plugins/env_escape/stub.py b/metaflow/plugins/env_escape/stub.py index 12131694eed..bf101633f22 100644 --- a/metaflow/plugins/env_escape/stub.py 
+++ b/metaflow/plugins/env_escape/stub.py @@ -47,6 +47,7 @@ "__init__", "__metaclass__", "__module__", + "__name__", "__new__", "__reduce__", "__reduce_ex__", @@ -78,11 +79,11 @@ def fwd_request(stub, request_type, *args, **kwargs): class StubMetaClass(type): - def __repr__(self): - if self.__module__: - return "" % (self.__module__, self.__name__) + def __repr__(cls): + if cls.__module__: + return "" % (cls.__module__, cls.__name__) else: - return "" % (self.__name__,) + return "" % (cls.__name__,) def with_metaclass(meta, *bases): @@ -131,9 +132,7 @@ def __del__(self): def __getattribute__(self, name): if name in LOCAL_ATTRS: - if name == "__class__": - return None - elif name == "__doc__": + if name == "__doc__": return self.__getattr__("__doc__") elif name in DELETED_ATTRS: raise AttributeError() @@ -455,9 +454,14 @@ def _do_str(self): # it but not the case if we are not. class_dict["__slots__"].append("__weakref__") + class_module, class_name_only = class_name.rsplit(".", 1) class_dict["___local_overrides___"] = overriden_attrs + class_dict["__module__"] = class_module if parents: - return MetaExceptionWithConnection( + to_return = MetaExceptionWithConnection( class_name, (Stub, *parents), class_dict, connection ) - return MetaWithConnection(class_name, (Stub,), class_dict, connection) + else: + to_return = MetaWithConnection(class_name, (Stub,), class_dict, connection) + to_return.__name__ = class_name_only + return to_return diff --git a/test/env_escape/example.py b/test/env_escape/example.py index 50d6c75bad3..128556de4b8 100644 --- a/test/env_escape/example.py +++ b/test/env_escape/example.py @@ -83,6 +83,8 @@ def run_test(through_escape=False): print("-- Test chaining of exported classes --") o2 = o1.to_class2(5) assert o2.something("foo") == "Test2:Something:foo" + assert o2.__class__.__name__ == "TestClass2" + assert o2.__class__.__module__ == "test_lib" print("-- Test Iterating --") for idx, i in enumerate(o2): @@ -108,6 +110,8 @@ def 
run_test(through_escape=False): assert isinstance(ex_child, test.ExceptionAndClass) assert isinstance(ex_child, Exception) assert isinstance(ex_child, object) + assert ex_child.__class__.__name__ == "ExceptionAndClassChild" + assert ex_child.__class__.__module__ == "test_lib" assert issubclass(type(ex_child), test.ExceptionAndClass) assert issubclass(test.ExceptionAndClassChild, test.ExceptionAndClass) @@ -149,6 +153,8 @@ def run_test(through_escape=False): excclass = o1.raiseOrReturnSomeException() assert not through_escape, "Should have raised through escape" assert isinstance(excclass, test.SomeException) + assert excclass.__class__.__name__ == "SomeException" + assert excclass.__class__.__module__ == "test_lib" except RuntimeError as e: assert ( through_escape From c13b33d19bdd8d084c38241b71cf4b930516de9e Mon Sep 17 00:00:00 2001 From: Romain Cledat Date: Tue, 12 Mar 2024 01:27:49 -0700 Subject: [PATCH 4/5] Fix mypy issue --- .github/workflows/test-stubs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-stubs.yml b/.github/workflows/test-stubs.yml index 5775c588357..fa5a6f5012f 100644 --- a/.github/workflows/test-stubs.yml +++ b/.github/workflows/test-stubs.yml @@ -31,7 +31,7 @@ jobs: - name: Install Python ${{ matrix.ver }} dependencies run: | python3 -m pip install --upgrade pip setuptools - python3 -m pip install pytest build mypy pytest-mypy-plugins + python3 -m pip install pytest build "mypy<1.9" pytest-mypy-plugins - name: Install metaflow run: pip install . 
From 2d00a2ff68b00229a6edf8f3d732a330f76f031b Mon Sep 17 00:00:00 2001 From: Chaoying Wang Date: Wed, 13 Mar 2024 17:23:27 -0700 Subject: [PATCH 5/5] fix typos --- metaflow/plugins/env_escape/client.py | 6 +++--- metaflow/plugins/env_escape/server.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/metaflow/plugins/env_escape/client.py b/metaflow/plugins/env_escape/client.py index 65637ecbf72..cb004fdf2f8 100644 --- a/metaflow/plugins/env_escape/client.py +++ b/metaflow/plugins/env_escape/client.py @@ -367,7 +367,7 @@ def get_local_class( # - proxied functions # - classes that are proxied regular classes AND proxied exceptions # - classes that are proxied regular classes AND NOT proxied exceptions - # - clases that are NOT proxied regular classes AND are proxied exceptions + # - classes that are NOT proxied regular classes AND are proxied exceptions name = get_canonical_name(name, self._aliases) def name_to_parent_name(name): @@ -447,14 +447,14 @@ def name_to_parent_name(name): # actually: # - the class itself (which is a stub) # - the class in the capacity of a parent class (to another exception - # presumably). The reason for this is that if we have a exception/proxied + # presumably). The reason for this is that if we have an exception/proxied # class A and another B and B inherits from A, the MRO order would be all # wrong since both A and B would also inherit from `Stub`. 
Here what we # do is: # - A_parent inherits from the actual parents of A (let's assume a # builtin exception) # - A inherits from (Stub, A_parent) - # - B_parent inherints from A_parent and the builtin Exception + # - B_parent inherits from A_parent and the builtin Exception # - B inherits from (Stub, B_parent) ex_module, ex_name = name.rsplit(".", 1) parent_local_class = ExceptionMetaClass(ex_name, (*parents,), {}) diff --git a/metaflow/plugins/env_escape/server.py b/metaflow/plugins/env_escape/server.py index ed35e56040f..dd960bca3e8 100644 --- a/metaflow/plugins/env_escape/server.py +++ b/metaflow/plugins/env_escape/server.py @@ -118,7 +118,7 @@ def __init__(self, config_dir, max_pickle_version): "are listed in the same order" % (alias, base_name, a) ) - # Detect circular aliaes. If a user lists ("a", "b") and then ("b", "a"), we + # Detect circular aliases. If a user lists ("a", "b") and then ("b", "a"), we # will have an entry in aliases saying b is an alias for a and a is an alias # for b which is a recipe for disaster since we no longer have a cannonical name # for things.