diff --git a/.github/workflows/test-stubs.yml b/.github/workflows/test-stubs.yml index 5775c588357..fa5a6f5012f 100644 --- a/.github/workflows/test-stubs.yml +++ b/.github/workflows/test-stubs.yml @@ -31,7 +31,7 @@ jobs: - name: Install Python ${{ matrix.ver }} dependencies run: | python3 -m pip install --upgrade pip setuptools - python3 -m pip install pytest build mypy pytest-mypy-plugins + python3 -m pip install pytest build "mypy<1.9" pytest-mypy-plugins - name: Install metaflow run: pip install . diff --git a/metaflow/plugins/env_escape/client.py b/metaflow/plugins/env_escape/client.py index 60ffdf09b68..cb004fdf2f8 100644 --- a/metaflow/plugins/env_escape/client.py +++ b/metaflow/plugins/env_escape/client.py @@ -34,8 +34,12 @@ from .communication.socket_bytestream import SocketByteStream from .data_transferer import DataTransferer, ObjReference -from .exception_transferer import load_exception -from .override_decorators import LocalAttrOverride, LocalException, LocalOverride +from .exception_transferer import ExceptionMetaClass, load_exception +from .override_decorators import ( + LocalAttrOverride, + LocalExceptionDeserializer, + LocalOverride, +) from .stub import create_class from .utils import get_canonical_name @@ -193,28 +197,41 @@ def inner_init(self, python_executable, pythonpath, max_pickle_version, config_d self._proxied_classes = { k: None for k in itertools.chain( - response[FIELD_CONTENT]["classes"], response[FIELD_CONTENT]["proxied"] + response[FIELD_CONTENT]["classes"], + response[FIELD_CONTENT]["proxied"], + (e[0] for e in response[FIELD_CONTENT]["exceptions"]), ) } + self._exception_hierarchy = dict(response[FIELD_CONTENT]["exceptions"]) + self._proxied_classnames = set(response[FIELD_CONTENT]["classes"]).union( + response[FIELD_CONTENT]["proxied"] + ) + self._aliases = response[FIELD_CONTENT]["aliases"] + # Determine all overrides self._overrides = {} self._getattr_overrides = {} self._setattr_overrides = {} - self._exception_overrides = {} + self._exception_deserializers = {} for override in override_values: if isinstance(override, (LocalOverride, LocalAttrOverride)): for obj_name, obj_funcs in override.obj_mapping.items(): - if obj_name not in self._proxied_classes: + canonical_name = get_canonical_name(obj_name, self._aliases) + if canonical_name not in self._proxied_classes: raise ValueError( "%s does not refer to a proxied or override type" % obj_name ) if isinstance(override, LocalOverride): - override_dict = self._overrides.setdefault(obj_name, {}) + override_dict = self._overrides.setdefault(canonical_name, {}) elif override.is_setattr: - override_dict = self._setattr_overrides.setdefault(obj_name, {}) + override_dict = self._setattr_overrides.setdefault( + canonical_name, {} + ) else: - override_dict = self._getattr_overrides.setdefault(obj_name, {}) + override_dict = self._getattr_overrides.setdefault( + canonical_name, {} + ) if isinstance(obj_funcs, str): obj_funcs = (obj_funcs,) for name in obj_funcs: @@ -223,11 +240,18 @@ def inner_init(self, python_executable, pythonpath, max_pickle_version, config_d "%s was already overridden for %s" % (name, obj_name) ) override_dict[name] = override.func - if isinstance(override, LocalException): - cur_ex = self._exception_overrides.get(override.class_path, None) - if cur_ex is not None: - raise ValueError("Exception %s redefined" % override.class_path) - self._exception_overrides[override.class_path] = override.wrapped_class + if isinstance(override, LocalExceptionDeserializer): + canonical_name = get_canonical_name(override.class_path, self._aliases) + if canonical_name not in self._exception_hierarchy: + raise ValueError( + "%s does not refer to an exception type" % override.class_path + ) + cur_des = self._exception_deserializers.get(canonical_name, None) + if cur_des is not None: + raise ValueError( + "Exception %s has multiple deserializers" % override.class_path + ) + self._exception_deserializers[canonical_name] = override.deserializer # Proxied standalone functions are functions that are proxied # as part of other objects like defaultdict for which we create a @@ -243,8 +267,6 @@ def inner_init(self, python_executable, pythonpath, max_pickle_version, config_d "aliases": response[FIELD_CONTENT]["aliases"], } - self._aliases = response[FIELD_CONTENT]["aliases"] - def __del__(self): self.cleanup() @@ -288,8 +310,9 @@ def name(self): def get_exports(self): return self._export_info - def get_local_exception_overrides(self): - return self._exception_overrides + def get_exception_deserializer(self, name): + cannonical_name = get_canonical_name(name, self._aliases) + return self._exception_deserializers.get(cannonical_name) def stub_request(self, stub, request_type, *args, **kwargs): # Encode the operation to send over the wire and wait for the response @@ -313,7 +336,7 @@ def stub_request(self, stub, request_type, *args, **kwargs): if response_type == MSG_REPLY: return self.decode(response[FIELD_CONTENT]) elif response_type == MSG_EXCEPTION: - raise load_exception(self._datatransferer, response[FIELD_CONTENT]) + raise load_exception(self, response[FIELD_CONTENT]) elif response_type == MSG_INTERNAL_ERROR: raise RuntimeError( "Error in the server runtime:\n\n===== SERVER TRACEBACK =====\n%s" @@ -334,10 +357,27 @@ def decode(self, json_obj): # this connection will be converted to a local stub. return self._datatransferer.load(json_obj) - def get_local_class(self, name, obj_id=None): + def get_local_class( + self, name, obj_id=None, is_returned_exception=False, is_parent=False + ): # Gets (and creates if needed), the class mapping to the remote # class of name 'name'. + + # We actually deal with four types of classes: + # - proxied functions + # - classes that are proxied regular classes AND proxied exceptions + # - classes that are proxied regular classes AND NOT proxied exceptions + # - classes that are NOT proxied regular classes AND are proxied exceptions name = get_canonical_name(name, self._aliases) + + def name_to_parent_name(name): + return "parent:%s" % name + + if is_parent: + lookup_name = name_to_parent_name(name) + else: + lookup_name = name + if name == "function": # Special handling of pickled functions. We create a new class that # simply has a __call__ method that will forward things back to @@ -346,17 +386,81 @@ def get_local_class(self, name, obj_id=None): raise RuntimeError("Local function unpickling without an object ID") if obj_id not in self._proxied_standalone_functions: self._proxied_standalone_functions[obj_id] = create_class( - self, "__function_%s" % obj_id, {}, {}, {}, {"__call__": ""} + self, "__function_%s" % obj_id, {}, {}, {}, {"__call__": ""}, [] ) return self._proxied_standalone_functions[obj_id] + local_class = self._proxied_classes.get(lookup_name, None) + if local_class is not None: + return local_class + + is_proxied_exception = name in self._exception_hierarchy + is_proxied_non_exception = name in self._proxied_classnames + + if not is_proxied_exception and not is_proxied_non_exception: + if is_returned_exception or is_parent: + # In this case, it may be a local exception that we need to + # recreate + try: + ex_module, ex_name = name.rsplit(".", 1) + __import__(ex_module, None, None, "*") + except Exception: + pass + if ex_module in sys.modules and issubclass( + getattr(sys.modules[ex_module], ex_name), BaseException + ): + # This is a local exception that we can recreate + local_exception = getattr(sys.modules[ex_module], ex_name) + wrapped_exception = ExceptionMetaClass( + ex_name, + (local_exception,), + dict(getattr(local_exception, "__dict__", {})), + ) + wrapped_exception.__module__ = ex_module + self._proxied_classes[lookup_name] = wrapped_exception + return wrapped_exception - if name not in self._proxied_classes: raise ValueError("Class '%s' is not known" % name) - local_class = self._proxied_classes[name] - if local_class is None: - # We need to build up this class. To do so, we take everything that the - # remote class has and remove UNSUPPORTED things and overridden things + + # At this stage: + # - we don't have a local_class for this + # - it is not an inbuilt exception so it is either a proxied exception, a + # proxied class or a proxied object that is both an exception and a class. + + parents = [] + if is_proxied_exception: + # If exception, we need to get the parents from the exception + ex_parents = self._exception_hierarchy[name] + for parent in ex_parents: + # We always consider it to be an exception so that we wrap even non + # proxied builtins exceptions + parents.append(self.get_local_class(parent, is_parent=True)) + # For regular classes, we get what it exposes from the server + if is_proxied_non_exception: remote_methods = self.stub_request(None, OP_GETMETHODS, name) + else: + remote_methods = {} + + parent_local_class = None + local_class = None + if is_proxied_exception: + # If we are a proxied exception AND a proxied class, we create two classes: + # actually: + # - the class itself (which is a stub) + # - the class in the capacity of a parent class (to another exception + # presumably). The reason for this is that if we have an exception/proxied + # class A and another B and B inherits from A, the MRO order would be all + # wrong since both A and B would also inherit from `Stub`. Here what we + # do is: + # - A_parent inherits from the actual parents of A (let's assume a + # builtin exception) + # - A inherits from (Stub, A_parent) + # - B_parent inherits from A_parent and the builtin Exception + # - B inherits from (Stub, B_parent) + ex_module, ex_name = name.rsplit(".", 1) + parent_local_class = ExceptionMetaClass(ex_name, (*parents,), {}) + parent_local_class.__module__ = ex_module + + if is_proxied_non_exception: local_class = create_class( self, name, @@ -364,9 +468,26 @@ def get_local_class(self, name, obj_id=None): self._getattr_overrides.get(name, {}), self._setattr_overrides.get(name, {}), remote_methods, + (parent_local_class,) if parent_local_class else None, ) + if parent_local_class: + self._proxied_classes[name_to_parent_name(name)] = parent_local_class + if local_class: self._proxied_classes[name] = local_class - return local_class + else: + # This is for the case of pure proxied exceptions -- we want the lookup of + # foo.MyException to be the same class as looking of foo.MyException as a parent + # of another exception so `isinstance` works properly + self._proxied_classes[name] = parent_local_class + + if is_parent: + # This should never happen but making sure + if not parent_local_class: + raise RuntimeError( + "Exception parent class %s is not a proxied exception" % name + ) + return parent_local_class + return self._proxied_classes[name] def can_pickle(self, obj): return getattr(obj, "___connection___", None) == self @@ -395,7 +516,7 @@ def unpickle_object(self, obj): obj_id = obj.identifier local_instance = self._proxied_objects.get(obj_id) if not local_instance: - local_class = self.get_local_class(remote_class_name, obj_id) + local_class = self.get_local_class(remote_class_name, obj_id=obj_id) local_instance = local_class(self, remote_class_name, obj_id) return local_instance diff --git a/metaflow/plugins/env_escape/client_modules.py b/metaflow/plugins/env_escape/client_modules.py index 2896f2c5156..aa20766414a 100644 --- a/metaflow/plugins/env_escape/client_modules.py +++ b/metaflow/plugins/env_escape/client_modules.py @@ -7,7 +7,6 @@ from .consts import OP_CALLFUNC, OP_GETVAL, OP_SETVAL from .client import Client -from .override_decorators import LocalException from .utils import get_canonical_name @@ -16,7 +15,7 @@ def _clean_client(client): class _WrappedModule(object): - def __init__(self, loader, prefix, exports, exception_classes, client): + def __init__(self, loader, prefix, exports, client): self._loader = loader self._prefix = prefix self._client = client @@ -24,19 +23,20 @@ def __init__(self, loader, prefix, exports, exception_classes, client): r"^%s\.([a-zA-Z_][a-zA-Z0-9_]*)$" % prefix.replace(".", r"\.") # noqa W605 ) self._exports = {} - self._aliases = exports["aliases"] + self._aliases = exports.get("aliases", []) for k in ("classes", "functions", "values"): result = [] - for item in exports[k]: + for item in exports.get(k, []): m = is_match.match(item) if m: result.append(m.group(1)) self._exports[k] = result - self._exception_classes = {} - for k, v in exception_classes.items(): - m = is_match.match(k) + result = [] + for item, _ in exports.get("exceptions", []): + m = is_match.match(item) if m: - self._exception_classes[m.group(1)] = v + result.append(m.group(1)) + self._exports["exceptions"] = result def __getattr__(self, name): if name == "__loader__": @@ -50,8 +50,8 @@ def __getattr__(self, name): name = get_canonical_name(self._prefix + "." + name, self._aliases)[ len(self._prefix) + 1 : ] - if name in self._exports["classes"]: - # We load classes lazily + if name in self._exports["classes"] or name in self._exports["exceptions"]: + # We load classes and exceptions lazily return self._client.get_local_class("%s.%s" % (self._prefix, name)) elif name in self._exports["functions"]: # TODO: Grab doc back from the remote side like in _make_method @@ -67,8 +67,6 @@ def func(*args, **kwargs): return self._client.stub_request( None, OP_GETVAL, "%s.%s" % (self._prefix, name) ) - elif name in self._exception_classes: - return self._exception_classes[name] else: # Try to see if this is a submodule that we can load m = None @@ -173,7 +171,6 @@ def load_module(self, fullname): # Get information about overrides and what the server knows about exports = self._client.get_exports() - ex_overrides = self._client.get_local_exception_overrides() prefixes = set() export_classes = exports.get("classes", []) @@ -182,42 +179,13 @@ def load_module(self, fullname): export_exceptions = exports.get("exceptions", []) self._aliases = exports.get("aliases", {}) for name in itertools.chain( - export_classes, export_functions, export_values + export_classes, + export_functions, + export_values, + (e[0] for e in export_exceptions), ): splits = name.rsplit(".", 1) prefixes.add(splits[0]) - - # Now look at the exceptions coming from the server - formed_exception_classes = {} - for ex_name, ex_parents in export_exceptions: - # Exception is a tuple (name, (parents,)) - # Exceptions are also given in order of instantiation (ie: the - # server already topologically sorted them) - ex_class_dict = ex_overrides.get(ex_name, None) - if ex_class_dict is None: - ex_class_dict = {} - else: - ex_class_dict = dict(ex_class_dict.__dict__) - parents = [] - for fake_base in ex_parents: - if fake_base.startswith("builtins."): - # This is something we know of here - parents.append(eval(fake_base[9:])) - else: - # It's in formed_classes - parents.append(formed_exception_classes[fake_base]) - splits = ex_name.rsplit(".", 1) - ex_class_dict["__user_defined__"] = set(ex_class_dict.keys()) - new_class = type(splits[1], tuple(parents), ex_class_dict) - new_class.__module__ = splits[0] - new_class.__name__ = splits[1] - formed_exception_classes[ex_name] = new_class - - # Now update prefixes as needed - for name in formed_exception_classes: - splits = name.rsplit(".", 1) - prefixes.add(splits[0]) - # We will make sure that we create modules even for "empty" prefixes # because packages are always loaded hierarchically so if we have # something in `a.b.c` but nothing directly in `a`, we still need to @@ -235,7 +203,7 @@ def load_module(self, fullname): self._handled_modules = {} for prefix in prefixes: self._handled_modules[prefix] = _WrappedModule( - self, prefix, exports, formed_exception_classes, self._client + self, prefix, exports, self._client ) canonical_fullname = get_canonical_name(fullname, self._aliases) # Modules are created canonically but we need to return something for any diff --git a/metaflow/plugins/env_escape/configurations/emulate_test_lib/overrides.py b/metaflow/plugins/env_escape/configurations/emulate_test_lib/overrides.py index 7418644e8f3..120b32e31df 100644 --- a/metaflow/plugins/env_escape/configurations/emulate_test_lib/overrides.py +++ b/metaflow/plugins/env_escape/configurations/emulate_test_lib/overrides.py @@ -5,56 +5,44 @@ remote_override, remote_getattr_override, remote_setattr_override, - local_exception, + local_exception_deserialize, remote_exception_serialize, ) @local_override({"test_lib.TestClass1": "print_value"}) def local_print_value(stub, func): - print("Encoding before sending to server") v = func() - print("Adding 5") return v + 5 @remote_override({"test_lib.TestClass1": "print_value"}) def remote_print_value(obj, func): - print("Decoding from client") v = func() - print("Encoding for client") - return v + return v + 3 @local_getattr_override({"test_lib.TestClass1": "override_value"}) def local_get_value2(stub, name, func): - print("In local getattr override for %s" % name) r = func() - print("In local getattr override, got %s" % r) - return r - - -@local_setattr_override({"test_lib.TestClass1": "override_value"}) -def local_set_value2(stub, name, func, v): - print("In local setattr override for %s" % name) - r = func(v) - print("In local setattr override, got %s" % r) - return r + return r + 5 @remote_getattr_override({"test_lib.TestClass1": "override_value"}) def remote_get_value2(obj, name): - print("In remote getattr override for %s" % name) r = getattr(obj, name) - print("In remote getattr override, got %s" % r) + return r + 3 + + +@local_setattr_override({"test_lib.TestClass1": "override_value"}) +def local_set_value2(stub, name, func, v): + r = func(v + 5) return r @remote_setattr_override({"test_lib.TestClass1": "override_value"}) def remote_set_value2(obj, name, v): - print("In remote setattr override for %s" % name) - r = setattr(obj, name, v) - print("In remote setattr override, got %s" % r) + r = setattr(obj, name, v + 3) return r @@ -63,30 +51,31 @@ def unsupported_method(stub, func, *args, **kwargs): return NotImplementedError("Just because") -@local_override({"test_lib.package.TestClass3": "thirdfunction"}) -def iamthelocalthird(stub, func, val): - print("Locally the Third") - v = func(val) - return v +@local_exception_deserialize("test_lib.SomeException") +def some_exception_deserialize(ex, json_obj): + ex.user_value = json_obj -@remote_override({"test_lib.package.TestClass3": "thirdfunction"}) -def iamtheremotethird(obj, func, val): - print("Remotely the Third") - v = func(val) - return v +@remote_exception_serialize("test_lib.SomeException") +def some_exception_serialize(ex): + return 42 -@local_exception("test_lib.SomeException") -class SomeException: - def __str__(self): - parent_val = super(self.__realclass__, self).__str__() - return parent_val + " In SomeException str override: %s" % self.user_value +@local_exception_deserialize("test_lib.ExceptionAndClass") +def exception_and_class_deserialize(ex, json_obj): + ex.user_value = json_obj - def _deserialize_user(self, json_obj): - self.user_value = json_obj +@remote_exception_serialize("test_lib.ExceptionAndClass") +def exception_and_class_serialize(ex): + return 43 -@remote_exception_serialize("test_lib.SomeException") -def some_exception_serialize(ex): - return 42 + +@local_exception_deserialize("test_lib.ExceptionAndClassChild") +def exception_and_class_child_deserialize(ex, json_obj): + ex.user_value = json_obj + + +@remote_exception_serialize("test_lib.ExceptionAndClassChild") +def exception_and_class_child_serialize(ex): + return 44 diff --git a/metaflow/plugins/env_escape/configurations/emulate_test_lib/server_mappings.py b/metaflow/plugins/env_escape/configurations/emulate_test_lib/server_mappings.py index 63de6f41369..b849a7f9146 100644 --- a/metaflow/plugins/env_escape/configurations/emulate_test_lib/server_mappings.py +++ b/metaflow/plugins/env_escape/configurations/emulate_test_lib/server_mappings.py @@ -10,17 +10,22 @@ import test_lib as lib EXPORTED_CLASSES = { - "test_lib": { + ("test_lib", "test_lib.alias"): { "TestClass1": lib.TestClass1, "TestClass2": lib.TestClass2, - "package.TestClass3": lib.TestClass3, + "BaseClass": lib.BaseClass, + "ChildClass": lib.ChildClass, + "ExceptionAndClass": lib.ExceptionAndClass, + "ExceptionAndClassChild": lib.ExceptionAndClassChild, } } EXPORTED_EXCEPTIONS = { - "test_lib": { + ("test_lib", "test_lib.alias"): { "SomeException": lib.SomeException, "MyBaseException": lib.MyBaseException, + "ExceptionAndClass": lib.ExceptionAndClass, + "ExceptionAndClassChild": lib.ExceptionAndClassChild, } } diff --git a/metaflow/plugins/env_escape/configurations/test_lib_impl/test_lib.py b/metaflow/plugins/env_escape/configurations/test_lib_impl/test_lib.py index 420d148bc46..b15256b4011 100644 --- a/metaflow/plugins/env_escape/configurations/test_lib_impl/test_lib.py +++ b/metaflow/plugins/env_escape/configurations/test_lib_impl/test_lib.py @@ -1,4 +1,5 @@ import functools +from html.parser import HTMLParser class MyBaseException(Exception): @@ -9,8 +10,48 @@ class SomeException(MyBaseException): pass -class TestClass1(object): +class ExceptionAndClass(MyBaseException): + def __init__(self, *args): + super().__init__(*args) + + def method_on_exception(self): + return "method_on_exception" + + def __str__(self): + return "ExceptionAndClass Str: %s" % super().__str__() + + +class ExceptionAndClassChild(ExceptionAndClass): + def __init__(self, *args): + super().__init__(*args) + + def method_on_child_exception(self): + return "method_on_child_exception" + + def __str__(self): + return "ExceptionAndClassChild Str: %s" % super().__str__() + +class BaseClass(HTMLParser): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._output = [] + + def handle_starttag(self, tag, attrs): + self._output.append(tag) + return super().handle_starttag(tag, attrs) + + def get_output(self): + return self._output + + +class ChildClass(BaseClass): + def handle_endtag(self, tag): + self._output.append(tag) + return super().handle_endtag(tag) + + +class TestClass1(object): cls_object = 25 def __init__(self, value): @@ -41,11 +82,11 @@ def to_class2(self, count, stride=1): return TestClass2(self._value, stride, count) @staticmethod - def somethingstatic(val): + def static_method(val): return val + 42 @classmethod - def somethingclass(cls): + def class_method(cls): return cls.cls_object @property @@ -56,13 +97,42 @@ def override_value(self): def override_value(self, value): self._value2 = value + def __hidden(self, name, value): + setattr(self, name, value) + + def weird_indirection(self, name): + return functools.partial(self.__hidden, name) + + def returnChild(self): + return ChildClass() + + def raiseOrReturnValueError(self, doRaise=False): + if doRaise: + raise ValueError("I raised") + return ValueError("I returned") + + def raiseOrReturnSomeException(self, doRaise=False): + if doRaise: + raise SomeException("I raised") + return SomeException("I returned") + + def raiseOrReturnExceptionAndClass(self, doRaise=False): + if doRaise: + raise ExceptionAndClass("I raised") + return ExceptionAndClass("I returned") + + def raiseOrReturnExceptionAndClassChild(self, doRaise=False): + if doRaise: + raise ExceptionAndClassChild("I raised") + return ExceptionAndClassChild("I returned") + class TestClass2(object): def __init__(self, value, stride, count): self._mylist = [value + stride * i for i in range(count)] def something(self, val): - return "In Test2 with %s" % val + return "Test2:Something:%s" % val def __iter__(self): self._pos = 0 @@ -75,24 +145,6 @@ def __next__(self): raise StopIteration -class TestClass3(object): - def __init__(self): - print("I am Class3") - - def thirdfunction(self, val): - print("Got value: %s" % val) - # raise AttributeError("Some weird error") - - def raiseSomething(self): - raise SomeException("Something went wrong") - - def __hidden(self, name, value): - setattr(self, name, value) - - def weird_indirection(self, name): - return functools.partial(self.__hidden, name) - - def test_func(*args, **kwargs): return "In test func" diff --git a/metaflow/plugins/env_escape/consts.py b/metaflow/plugins/env_escape/consts.py index bc07abc81b4..dd770ac9e7c 100644 --- a/metaflow/plugins/env_escape/consts.py +++ b/metaflow/plugins/env_escape/consts.py @@ -39,6 +39,7 @@ OP_SETVAL = 16 OP_INIT = 17 OP_CALLONCLASS = 18 +OP_SUBCLASSCHECK = 19 # Control messages CONTROL_SHUTDOWN = 1 diff --git a/metaflow/plugins/env_escape/exception_transferer.py b/metaflow/plugins/env_escape/exception_transferer.py index c504fd13c7e..7ea692f7fd5 100644 --- a/metaflow/plugins/env_escape/exception_transferer.py +++ b/metaflow/plugins/env_escape/exception_transferer.py @@ -4,11 +4,9 @@ try: # Import from client from .data_transferer import DataTransferer - from .stub import Stub except ImportError: # Import from server from data_transferer import DataTransferer - from stub import Stub # This file is heavily inspired from the RPYC project @@ -39,7 +37,6 @@ FIELD_EXC_MODULE = "m" FIELD_EXC_NAME = "n" FIELD_EXC_ARGS = "arg" -FIELD_EXC_ATTR = "atr" FIELD_EXC_TB = "tb" FIELD_EXC_USER = "u" FIELD_EXC_SI = "si" @@ -54,7 +51,6 @@ def dump_exception(data_transferer, exception_type, exception_val, tb, user_data traceback.format_exception(exception_type, exception_val, tb) ) exception_args = [] - exception_attrs = [] str_repr = None repr_repr = None for name in dir(exception_val): @@ -72,20 +68,10 @@ def dump_exception(data_transferer, exception_type, exception_val, tb, user_data repr_repr = repr(exception_val) elif name.startswith("_") or name == "with_traceback": continue - else: - try: - attr = getattr(exception_val, name) - except AttributeError: - continue - if DataTransferer.can_simple_dump(attr): - exception_attrs.append((name, attr)) - else: - exception_attrs.append((name, repr(attr))) to_return = { FIELD_EXC_MODULE: exception_type.__module__, FIELD_EXC_NAME: exception_type.__name__, FIELD_EXC_ARGS: exception_args, - FIELD_EXC_ATTR: exception_attrs, FIELD_EXC_TB: local_formatted_exception, FIELD_EXC_STR: str_repr, FIELD_EXC_REPR: repr_repr, @@ -98,121 +84,69 @@ def dump_exception(data_transferer, exception_type, exception_val, tb, user_data return data_transferer.dump(to_return) -def load_exception(data_transferer, json_obj): - json_obj = data_transferer.load(json_obj) +def load_exception(client, json_obj): + from .stub import Stub + + json_obj = client.decode(json_obj) + if json_obj.get(FIELD_EXC_SI) is not None: return StopIteration exception_module = json_obj.get(FIELD_EXC_MODULE) exception_name = json_obj.get(FIELD_EXC_NAME) exception_class = None - if exception_module not in sys.modules: - # Try to import the module - try: - # Use __import__ so that the user can access this exception - __import__(exception_module, None, None, "*") - except Exception: - pass - # Try again (will succeed if the __import__ worked) - if exception_module in sys.modules: - exception_class = getattr(sys.modules[exception_module], exception_name, None) - if exception_class is None or issubclass(exception_class, Stub): - # Best effort to "recreate" an exception. Note that in some cases, exceptions - # may actually be both exceptions we can transfer as well as classes we - # can transfer (stubs) but for exceptions, we don't want to use the stub - # otherwise it will just ping pong. - name = "%s.%s" % (exception_module, exception_name) - exception_class = _remote_exceptions_class.setdefault( - name, - type( - name, - (RemoteInterpreterException,), - {"__module__": "%s/%s" % (__name__, exception_module)}, - ), - ) - exception_class = _wrap_exception(exception_class) - raised_exception = exception_class.__new__(exception_class) - raised_exception.args = json_obj.get(FIELD_EXC_ARGS) - for name, attr in json_obj.get(FIELD_EXC_ATTR): - try: - if name in raised_exception.__user_defined__: - setattr(raised_exception, "_original_%s" % name, attr) - else: - setattr(raised_exception, name, attr) - except AttributeError: - # In case some things are read only - pass - s = json_obj.get(FIELD_EXC_STR) - if s: - try: - if "__str__" in raised_exception.__user_defined__: - setattr(raised_exception, "_original___str__", s) - else: - setattr(raised_exception, "__str__", lambda x, s=s: s) - except AttributeError: - raised_exception._missing_str = True - s = json_obj.get(FIELD_EXC_REPR) - if s: - try: - if "__repr__" in raised_exception.__user_defined__: - setattr(raised_exception, "_original___repr__", s) - else: - setattr(raised_exception, "__repr__", lambda x, s=s: s) - except AttributeError: - raised_exception._missing_repr = True + # This name is already cannonical since we cannonicalize it on the server side + full_name = "%s.%s" % (exception_module, exception_name) + + exception_class = client.get_local_class(full_name, is_returned_exception=True) + + if issubclass(exception_class, Stub): + raised_exception = exception_class(_is_returned_exception=True) + raised_exception.args = tuple(json_obj.get(FIELD_EXC_ARGS)) + else: + raised_exception = exception_class(*json_obj.get(FIELD_EXC_ARGS)) + raised_exception._exception_str = json_obj.get(FIELD_EXC_STR, None) + raised_exception._exception_repr = json_obj.get(FIELD_EXC_REPR, None) + raised_exception._exception_tb = json_obj.get(FIELD_EXC_TB, None) + user_args = json_obj.get(FIELD_EXC_USER) if user_args is not None: - try: - deserializer = getattr(raised_exception, "_deserialize_user") - except AttributeError: - raised_exception._missing_deserializer = True - else: - deserializer(user_args) - raised_exception._remote_tb = json_obj[FIELD_EXC_TB] + deserializer = client.get_exception_deserializer(full_name) + if deserializer is not None: + deserializer(raised_exception, user_args) return raised_exception -def _wrap_exception(exception_class): - to_return = _derived_exceptions.get(exception_class) - if to_return is not None: - return to_return - - class WithPrettyPrinting(exception_class): - def __str__(self): - try: - text = super(WithPrettyPrinting, self).__str__() - except: # noqa E722 - text = "" - # if getattr(self, "_missing_deserializer", False): - # text += ( - # "\n\n===== WARNING: User data from the exception was not deserialized " - # "-- possible missing information =====\n" - # ) - # if getattr(self, "_missing_str", False): - # text += "\n\n===== WARNING: Could not set class specific __str__ " - # "-- possible missing information =====\n" - # if getattr(self, "_missing_repr", False): - # text += "\n\n===== WARNING: Could not set class specific __repr__ " - # "-- possible missing information =====\n" - remote_tb = getattr(self, "_remote_tb", "No remote traceback available") +class ExceptionMetaClass(type): + def __init__(cls, class_name, base_classes, class_dict): + super(ExceptionMetaClass, cls).__init__(class_name, base_classes, class_dict) + cls.__orig_str__ = cls.__str__ + cls.__orig_repr__ = cls.__repr__ + for n in ("_exception_str", "_exception_repr", "_exception_tb"): + setattr( + cls, + n, + property( + lambda self, n=n: getattr(self, "%s_val" % n, ""), + lambda self, v, n=n: setattr(self, "%s_val" % n, v), + ), + ) + + def _do_str(self): + text = self._exception_str text += "\n\n===== Remote (on server) traceback =====\n" - text += remote_tb + text += self._exception_tb text += "========================================\n" return text - WithPrettyPrinting.__name__ = exception_class.__name__ - WithPrettyPrinting.__module__ = exception_class.__module__ - WithPrettyPrinting.__realclass__ = exception_class - _derived_exceptions[exception_class] = WithPrettyPrinting - return WithPrettyPrinting + cls.__str__ = _do_str + cls.__repr__ = lambda self: self._exception_repr class RemoteInterpreterException(Exception): - """A 'generic exception' that is raised when the exception the gotten from - the remote server cannot be instantiated locally""" + """ + A 'generic' exception that was raised on the server side for which we have no + equivalent exception on this side + """ pass - - -_remote_exceptions_class = {} # Exception name -> type of that exception -_derived_exceptions = {} diff --git a/metaflow/plugins/env_escape/override_decorators.py b/metaflow/plugins/env_escape/override_decorators.py index cb9fd9cc099..095e1eb1786 100644 --- a/metaflow/plugins/env_escape/override_decorators.py +++ b/metaflow/plugins/env_escape/override_decorators.py @@ -110,18 +110,18 @@ def _wrapped(func): return _wrapped -class LocalException(object): - def __init__(self, class_path, wrapped_class): +class LocalExceptionDeserializer(object): + def __init__(self, class_path, deserializer): self._class_path = class_path - self._class = wrapped_class + self._deserializer = deserializer @property def class_path(self): return self._class_path @property - def wrapped_class(self): - return self._class + def deserializer(self): + return self._deserializer class RemoteExceptionSerializer(object): @@ -138,9 +138,9 @@ def serializer(self): return self._serializer -def local_exception(class_path): - def _wrapped(cls): - return LocalException(class_path, cls) +def local_exception_deserialize(class_path): + def _wrapped(func): + return LocalExceptionDeserializer(class_path, func) return _wrapped diff --git a/metaflow/plugins/env_escape/server.py b/metaflow/plugins/env_escape/server.py index a78e870c883..dd960bca3e8 100644 --- a/metaflow/plugins/env_escape/server.py +++ b/metaflow/plugins/env_escape/server.py @@ -36,6 +36,7 @@ OP_GETVAL, OP_SETVAL, OP_INIT, + OP_SUBCLASSCHECK, VALUE_LOCAL, VALUE_REMOTE, CONTROL_GETEXPORTS, @@ -113,9 +114,21 @@ def __init__(self, config_dir, max_pickle_version): # this by listing aliases in the same order so we don't support # it for now. raise ValueError( - "%s is an alias to both %s and %s" % (alias, base_name, a) + "%s is an alias to both %s and %s -- make sure all aliases " + "are listed in the same order" % (alias, base_name, a) ) + # Detect circular aliases. If a user lists ("a", "b") and then ("b", "a"), we + # will have an entry in aliases saying b is an alias for a and a is an alias + # for b which is a recipe for disaster since we no longer have a cannonical name + # for things. + for alias, base_name in self._aliases.items(): + if base_name in self._aliases: + raise ValueError( + "%s and %s are circular aliases -- make sure all aliases " + "are listed in the same order" % (alias, base_name) + ) + # Determine if we have any overrides self._overrides = {} self._getattr_overrides = {} @@ -124,8 +137,9 @@ def __init__(self, config_dir, max_pickle_version): for override in override_values: if isinstance(override, (RemoteAttrOverride, RemoteOverride)): for obj_name, obj_funcs in override.obj_mapping.items(): + canonical_name = get_canonical_name(obj_name, self._aliases) obj_type = self._known_classes.get( - obj_name, self._proxied_types.get(obj_name) + canonical_name, self._proxied_types.get(obj_name) ) if obj_type is None: raise ValueError( @@ -146,11 +160,17 @@ def __init__(self, config_dir, max_pickle_version): ) override_dict[name] = override.func elif isinstance(override, RemoteExceptionSerializer): + canonical_name = get_canonical_name(override.class_path, self._aliases) + if canonical_name not in self._known_exceptions: + raise ValueError( + "%s does not refer to an exported exception" + % override.class_path + ) if override.class_path in self._exception_serializers: raise ValueError( "%s exception serializer already defined" % override.class_path ) - self._exception_serializers[override.class_path] = override.serializer + self._exception_serializers[canonical_name] = override.serializer # Process the exceptions making sure we have all the ones we need and building a # topologically sorted list for the client to instantiate @@ -181,8 +201,8 @@ def __init__(self, config_dir, max_pickle_version): else: raise ValueError( "Exported exception %s has non exported and non builtin parent " - "exception: %s. Known exceptions: %s" - % (ex_name, fqn, str(self._known_exceptions)) + "exception: %s (%s). Known exceptions: %s." + % (ex_name, fqn, canonical_fqn, str(self._known_exceptions)) ) name_to_parent_count[ex_name_canonical] = len(parents) - 1 name_to_parents[ex_name_canonical] = parents @@ -236,6 +256,7 @@ def __init__(self, config_dir, max_pickle_version): OP_GETVAL: self._handle_getval, OP_SETVAL: self._handle_setval, OP_INIT: self._handle_init, + OP_SUBCLASSCHECK: self._handle_subclasscheck, } self._local_objects = {} @@ -273,6 +294,7 @@ def encode(self, obj): def encode_exception(self, ex_type, ex, trace_back): try: full_name = "%s.%s" % (ex_type.__module__, ex_type.__name__) + get_canonical_name(full_name, self._aliases) serializer = self._exception_serializers.get(full_name) except AttributeError: # Ignore if no __module__ for example -- definitely not something we built @@ -483,6 +505,21 @@ def _handle_init(self, target, class_name, *args, **kwargs): raise ValueError("Unknown class %s" % class_name) return class_type(*args, **kwargs) + def _handle_subclasscheck(self, target, class_name, otherclass_name, reverse=False): + class_type = self._known_classes.get(class_name) + if class_type is None: + raise ValueError("Unknown class %s" % class_name) + try: + sub_module, sub_name = otherclass_name.rsplit(".", 1) + __import__(sub_module, None, None, "*") + except Exception: + sub_module = None + if sub_module is None: + return False + if reverse: + return issubclass(class_type, getattr(sys.modules[sub_module], sub_name)) + return issubclass(getattr(sys.modules[sub_module], sub_name), class_type) + if __name__ == "__main__": max_pickle_version = int(sys.argv[1]) diff --git a/metaflow/plugins/env_escape/stub.py b/metaflow/plugins/env_escape/stub.py index 1e341c52d4a..bf101633f22 100644 --- a/metaflow/plugins/env_escape/stub.py +++ b/metaflow/plugins/env_escape/stub.py @@ -1,5 +1,6 @@ import functools import pickle +from typing import Any from .consts import ( OP_GETATTR, @@ -15,8 +16,11 @@ OP_PICKLE, OP_DIR, OP_INIT, + OP_SUBCLASSCHECK, ) +from .exception_transferer import ExceptionMetaClass + DELETED_ATTRS = frozenset(["__array_struct__", "__array_interface__"]) # These attributes are accessed directly on the stub (not directly forwarded) @@ -26,7 +30,10 @@ "___remote_class_name___", "___identifier___", "___connection___", - "___local_overrides___" "__class__", + "___local_overrides___", + "___is_returned_exception___", + "___exception_attributes___", + "__class__", "__init__", "__del__", "__delattr__", @@ -36,9 +43,11 @@ "__getattribute__", "__hash__", "__instancecheck__", + "__subclasscheck__", "__init__", "__metaclass__", "__module__", + "__name__", "__new__", "__reduce__", "__reduce_ex__", @@ -62,17 +71,19 @@ def fwd_request(stub, request_type, *args, **kwargs): connection = object.__getattribute__(stub, "___connection___") - return connection.stub_request(stub, request_type, *args, **kwargs) + if connection: + return connection.stub_request(stub, request_type, *args, **kwargs) + raise RuntimeError( + "Returned exception stub cannot be used to make further remote requests" + ) class StubMetaClass(type): - __slots__ = () - - def __repr__(self): - if self.__module__: - return "" % (self.__module__, self.__name__) + def __repr__(cls): + if cls.__module__: + return "" % (cls.__module__, cls.__name__) else: - return "" % (self.__name__,) + return "" % (cls.__name__,) def with_metaclass(meta, *bases): @@ -94,24 +105,25 @@ class Stub(with_metaclass(StubMetaClass, object)): happen on the remote side (server). """ - __slots__ = [ - "___remote_class_name___", - "___identifier___", - "___connection___", - "__weakref__", - ] - + __slots__ = () # def __iter__(self): # FIXME: Keep debugger QUIET!! # raise AttributeError - def __init__(self, connection, remote_class_name, identifier): + def __init__( + self, connection, remote_class_name, identifier, _is_returned_exception=False + ): self.___remote_class_name___ = remote_class_name self.___identifier___ = identifier self.___connection___ = connection + # If it is a returned exception (ie: it was raised by the server), it behaves + # a bit differently for methods like __str__ and __repr__ (we try not to get + # stuff from the server) + self.___is_returned_exception___ = _is_returned_exception def __del__(self): try: - fwd_request(self, OP_DEL) + if not self.___is_returned_exception___: + fwd_request(self, OP_DEL) except Exception: # raised in a destructor, most likely on program termination, # when the connection might have already been closed. @@ -120,9 +132,7 @@ def __del__(self): def __getattribute__(self, name): if name in LOCAL_ATTRS: - if name == "__class__": - return None - elif name == "__doc__": + if name == "__doc__": return self.__getattr__("__doc__") elif name in DELETED_ATTRS: raise AttributeError() @@ -144,10 +154,16 @@ def __delattr__(self, name): if name in LOCAL_ATTRS: object.__delattr__(self, name) else: + if self.___is_returned_exception___: + raise AttributeError() return fwd_request(self, OP_DELATTR, name) def __setattr__(self, name, value): - if name in LOCAL_ATTRS or name in self.___local_overrides___: + if ( + name in LOCAL_ATTRS + or name in self.___local_overrides___ + or self.___is_returned_exception___ + ): object.__setattr__(self, name, value) else: fwd_request(self, OP_SETATTR, name, value) @@ -159,9 +175,13 @@ def __hash__(self): return fwd_request(self, OP_HASH) def __repr__(self): + if self.___is_returned_exception___: + return self.__exception_repr__() return fwd_request(self, OP_REPR) def __str__(self): + if self.___is_returned_exception___: + return self.__exception_str__() return fwd_request(self, OP_STR) def __exit__(self, exc, typ, tb): @@ -173,6 +193,16 @@ def __reduce_ex__(self, proto): # support for pickling return pickle.loads, (fwd_request(self, OP_PICKLE, proto),) + @classmethod + def __subclasshook__(cls, parent): + if parent.__bases__[0] == Stub: + raise NotImplementedError # Follow the usual mechanism + # If this is not a stub, we go over to the other side + parent_name = "%s.%s" % (parent.__module__, parent.__name__) + return cls.___class_connection___.stub_request( + None, OP_SUBCLASSCHECK, cls.___class_remote_class_name___, parent_name, True + ) + def _make_method(method_type, connection, class_name, name, doc): if name == "__call__": @@ -248,6 +278,80 @@ def __call__(cls, *args, **kwargs): None, OP_INIT, cls.___class_remote_class_name___, *args, **kwargs ) + def __subclasscheck__(cls, subclass): + subclass_name = "%s.%s" % (subclass.__module__, subclass.__name__) + if subclass.__bases__[0] == Stub: + subclass_name = subclass.___class_remote_class_name___ + return cls.___class_connection___.stub_request( + None, + OP_SUBCLASSCHECK, + cls.___class_remote_class_name___, + subclass_name, + ) + + def __instancecheck__(cls, instance): + if type(instance) == cls: + # Fast path if it's just an object of this class + return True + # Goes to __subclasscheck__ above + return cls.__subclasscheck__(type(instance)) + + +class MetaExceptionWithConnection(StubMetaClass, ExceptionMetaClass): + def __new__(cls, class_name, base_classes, class_dict, connection): + return type.__new__(cls, class_name, base_classes, class_dict) + + def __init__(cls, class_name, base_classes, class_dict, connection): + cls.___class_remote_class_name___ = class_name + cls.___class_connection___ = connection + + # We call the one on ExceptionMetaClass which does everything needed (StubMetaClass + # does not do anything special for init) + ExceptionMetaClass.__init__(cls, class_name, base_classes, class_dict) + + # Restore __str__ and __repr__ to the original ones because we need to determine + # if we call them depending on whether or not the object is a returned exception + # or not + cls.__exception_str__ = cls.__str__ + cls.__exception_repr__ = cls.__repr__ + cls.__str__ = cls.__orig_str__ + cls.__repr__ = cls.__orig_repr__ + + def __call__(cls, *args, **kwargs): + # Very similar to the other case but we also need to be able to detect + # local instantiation of an exception so that we can set the __is_returned_exception__ + if len(args) > 0 and id(args[0]) == id(cls.___class_connection___): + return super(MetaExceptionWithConnection, cls).__call__(*args, **kwargs) + elif kwargs and kwargs.get("_is_returned_exception", False): + return super(MetaExceptionWithConnection, cls).__call__( + None, None, None, _is_returned_exception=True + ) + else: + return cls.___class_connection___.stub_request( + None, OP_INIT, cls.___class_remote_class_name___, *args, **kwargs + ) + + # The issue is that for a proxied object that is also an exception, we now have + # two classes representing it, one that includes the Stub class and one that doesn't + # Concretely: + # - test.MyException would return a class that derives from Stub + # - test.MySubException would return a class that derives from Stub and test.MyException + # but WITHOUT the Stub portion (see get_local_class). + # - we want issubclass(test.MySubException, test.MyException) to return True and + # the same with instance checks. + def __instancecheck__(cls, instance): + return cls.__subclasscheck__(type(instance)) + + def __subclasscheck__(cls, subclass): + # __mro__[0] is this class itself + # __mro__[1] is the stub so we start checking at 2 + return any( + [ + subclass.__mro__[i] in cls.__mro__[2:] + for i in range(2, len(subclass.__mro__)) + ] + ) + def create_class( connection, @@ -256,8 +360,16 @@ def create_class( getattr_overrides, setattr_overrides, class_methods, + parents, ): - class_dict = {"__slots__": ()} + class_dict = { + "__slots__": [ + "___remote_class_name___", + "___identifier___", + "___connection___", + "___is_returned_exception___", + ] + } for name, doc in class_methods.items(): method_type = NORMAL_METHOD if name.startswith("___s___"): @@ -318,5 +430,38 @@ def create_class( ) overriden_attrs.add(attr) class_dict[attr] = property(getter, setter) + if parents: + # This means this is also an exception so we add a few more things to it + # so that it + # This is copied from ExceptionMetaClass in exception_transferer.py + for n in ("_exception_str", "_exception_repr", "_exception_tb"): + class_dict[n] = property( + lambda self, n=n: getattr(self, "%s_val" % n, ""), + lambda self, v, n=n: setattr(self, "%s_val" % n, v), + ) + + def _do_str(self): + text = self._exception_str + text += "\n\n===== Remote (on server) traceback =====\n" + text += self._exception_tb + text += "========================================\n" + return text + + class_dict["__exception_str__"] = _do_str + class_dict["__exception_repr__"] = lambda self: self._exception_repr + else: + # If we are based on an exception, we already have __weakref__ so we don't add + # it but not the case if we are not. + class_dict["__slots__"].append("__weakref__") + + class_module, class_name_only = class_name.rsplit(".", 1) class_dict["___local_overrides___"] = overriden_attrs - return MetaWithConnection(class_name, (Stub,), class_dict, connection) + class_dict["__module__"] = class_module + if parents: + to_return = MetaExceptionWithConnection( + class_name, (Stub, *parents), class_dict, connection + ) + else: + to_return = MetaWithConnection(class_name, (Stub,), class_dict, connection) + to_return.__name__ = class_name_only + return to_return diff --git a/metaflow/plugins/env_escape/utils.py b/metaflow/plugins/env_escape/utils.py index f9fe1aacf0f..c0e00c3a0f6 100644 --- a/metaflow/plugins/env_escape/utils.py +++ b/metaflow/plugins/env_escape/utils.py @@ -13,12 +13,12 @@ def get_methods(class_object): for base_class in mros: all_attributes.update(base_class.__dict__) for name, attribute in all_attributes.items(): - if hasattr(attribute, "__call__"): - all_methods[name] = inspect.getdoc(attribute) - elif isinstance(attribute, staticmethod): + if isinstance(attribute, staticmethod): all_methods["___s___%s" % name] = inspect.getdoc(attribute) elif isinstance(attribute, classmethod): all_methods["___c___%s" % name] = inspect.getdoc(attribute) + elif hasattr(attribute, "__call__"): + all_methods[name] = inspect.getdoc(attribute) return all_methods diff --git a/test/env_escape/example.py b/test/env_escape/example.py index bef9d0871f7..128556de4b8 100644 --- a/test/env_escape/example.py +++ b/test/env_escape/example.py @@ -1,6 +1,8 @@ import os import sys +from html.parser import HTMLParser + from metaflow import FlowSpec, step, conda @@ -28,41 +30,47 @@ def run_test(through_escape=False): import test_lib as test + print("-- Test aliasing --") if through_escape: # This tests package aliasing - from test_lib.package import TestClass3 - else: - from test_lib import TestClass3 + from test_lib.alias import TestClass1 - o1 = test.TestClass1(123) - print("-- Test print_value --") + o1 = test.TestClass1(10) + print("-- Test normal method with overrides --") if through_escape: - # The server_mapping should add 5 here - assert o1.print_value() == 128 + expected_value = 10 + 8 else: - assert o1.print_value() == 123 - print("-- Test property --") - assert o1.value == 123 - print("-- Test value override (get) --") - assert o1.override_value == 123 - print("-- Test value override (set) --") - o1.override_value = 456 - assert o1.override_value == 456 + expected_value = 10 + assert o1.print_value() == expected_value - print("-- Test static method --") - assert test.TestClass1.somethingstatic(5) == 47 - assert o1.somethingstatic(5) == 47 - print("-- Test class method --") - assert test.TestClass1.somethingclass() == 25 - assert o1.somethingclass() == 25 + print("-- Test property (no override) --") + assert o1.value == 10 + o1.value = 15 + assert o1.value == 15 + if through_escape: + expected_value = 15 + 8 + else: + expected_value = 15 + assert o1.print_value() == expected_value - print("-- Test set and get --") - o1.value = 2 + print("-- Test property (with override) --") if through_escape: - # The server_mapping should add 5 here - assert o1.print_value() == 7 + expected_value = 123 + 8 + expected_value2 = 200 + 16 else: - assert o1.print_value() == 2 + expected_value = 123 + expected_value2 = 200 + assert o1.override_value == expected_value + o1.override_value = 200 + assert o1.override_value == expected_value2 + + print("-- Test static method --") + assert test.TestClass1.static_method(5) == 47 + assert o1.static_method(5) == 47 + + print("-- Test class method --") + assert test.TestClass1.class_method() == 25 + assert o1.class_method() == 25 print("-- Test function --") assert test.test_func() == "In test func" @@ -74,21 +82,144 @@ def run_test(through_escape=False): print("-- Test chaining of exported classes --") o2 = o1.to_class2(5) - assert o2.something("foo") == "In Test2 with foo" + assert o2.something("foo") == "Test2:Something:foo" + assert o2.__class__.__name__ == "TestClass2" + assert o2.__class__.__module__ == "test_lib" + print("-- Test Iterating --") - for i in o2: - print("Got %d" % i) + for idx, i in enumerate(o2): + assert idx == i - 15 + assert i == 19 + + print("-- Test weird indirection --") + o1.weird_indirection("foo")(10) + assert o1.foo == 10 + o1.weird_indirection("_value")(20) + assert o1.value == 20 + + print("-- Test subclasses --") + child_obj = test.ChildClass() + child_obj_returned = o1.returnChild() + for o in (child_obj, child_obj_returned): + o.feed("Hello

World!

") + assert o.get_output() == ["html", "p", "p", "html"] + + print("-- Test isinstance/issubclass --") + ex_child = test.ExceptionAndClassChild("I am a child") + assert isinstance(ex_child, test.ExceptionAndClassChild) + assert isinstance(ex_child, test.ExceptionAndClass) + assert isinstance(ex_child, Exception) + assert isinstance(ex_child, object) + assert ex_child.__class__.__name__ == "ExceptionAndClassChild" + assert ex_child.__class__.__module__ == "test_lib" + + assert issubclass(type(ex_child), test.ExceptionAndClass) + assert issubclass(test.ExceptionAndClassChild, test.ExceptionAndClass) + assert issubclass(type(ex_child), Exception) + assert issubclass(test.ExceptionAndClassChild, Exception) + assert issubclass(type(ex_child), object) + assert issubclass(test.ExceptionAndClassChild, object) + + child_obj = test.ChildClass() + child_obj_returned = o1.returnChild() + + # I can't find an easy way (yet) to test support for subclasses based on non + # proxied types. It seems more minor for now so ignoring. + for o in (child_obj, child_obj_returned): + assert isinstance(o, test.ChildClass) + assert isinstance(o, test.BaseClass) + # assert isinstance(o, HTMLParser) + assert isinstance(o, object) + assert issubclass(type(o), test.BaseClass) + # assert issubclass(type(o), HTMLParser) + assert issubclass(type(o), object) + assert issubclass(test.ChildClass, test.BaseClass) + # assert issubclass(test.ChildClass, HTMLParser) + assert issubclass(test.ChildClass, object) + + print("-- Test exceptions --") + # Non proxied exceptions can't be returned as objects + try: + vexc = o1.raiseOrReturnValueError() + assert not through_escape, "Should have raised through escape" + assert isinstance(vexc, ValueError) + except RuntimeError as e: + assert ( + through_escape + and "Cannot proxy value of type " in str(e) + ) - print("-- Test exception --") - o3 = TestClass3() try: - o3.raiseSomething() + excclass = o1.raiseOrReturnSomeException() + assert not through_escape, "Should have raised through escape" + assert isinstance(excclass, test.SomeException) + assert excclass.__class__.__name__ == "SomeException" + assert excclass.__class__.__module__ == "test_lib" + except RuntimeError as e: + assert ( + through_escape + and "Cannot proxy value of type " in str(e) + ) + + exception_and_class = o1.raiseOrReturnExceptionAndClass() + assert isinstance(exception_and_class, test.ExceptionAndClass) + assert isinstance(exception_and_class, test.MyBaseException) + assert isinstance(exception_and_class, Exception) + assert exception_and_class.method_on_exception() == "method_on_exception" + assert str(exception_and_class).startswith("ExceptionAndClass Str:") + + exception_and_class_child = o1.raiseOrReturnExceptionAndClassChild() + assert isinstance(exception_and_class_child, test.ExceptionAndClassChild) + assert isinstance(exception_and_class_child, test.ExceptionAndClass) + assert isinstance(exception_and_class_child, test.MyBaseException) + assert isinstance(exception_and_class_child, Exception) + assert exception_and_class_child.method_on_exception() == "method_on_exception" + assert ( + exception_and_class_child.method_on_child_exception() + == "method_on_child_exception" + ) + assert str(exception_and_class_child).startswith("ExceptionAndClassChild Str:") + + try: + o1.raiseOrReturnValueError(True) + assert False, "Should have raised" + except ValueError as e: + assert True + except Exception as e: + assert False, "Should have been ValueError" + + try: + o1.raiseOrReturnSomeException(True) + assert False, "Should have raised" except test.SomeException as e: - print("Caught the local exception: %s" % str(e)) + assert True + if through_escape: + assert e.user_value == 42 + assert "Remote (on server) traceback" in str(e) + except Exception as e: + assert False, "Should have been SomeException" - print("-- Test returning proxied object --") - o3.weird_indirection("foo")(10) - assert o3.foo == 10 + try: + o1.raiseOrReturnExceptionAndClass(True) + assert False, "Should have raised" + except test.ExceptionAndClass as e: + assert True + if through_escape: + assert e.user_value == 43 + assert "Remote (on server) traceback" in str(e) + except Exception as e: + assert False, "Should have been ExceptionAndClass" + + try: + o1.raiseOrReturnExceptionAndClassChild(True) + assert False, "Should have raised" + except test.ExceptionAndClassChild as e: + assert True + if through_escape: + assert e.user_value == 44 + assert "Remote (on server) traceback" in str(e) + except Exception as e: + assert False, "Should have been ExceptionAndClassChild" class EscapeTest(FlowSpec):