Skip to content

Commit bc8dead

Browse files
Various fixes in the env escape code (#1734)
* Various fixes in the env escape code We now handle exceptions that are also proxied classes better We canonicalize more names to better deal with aliases in proxied libraries. Other cleanup * Better support for subclasses and fix issues with >1 exception depth * Add support for __class__ in escape hatch * Fix mypy issue * fix typos --------- Co-authored-by: Chaoying Wang <[email protected]>
1 parent 91de0b6 commit bc8dead

File tree

12 files changed

+711
-328
lines changed

12 files changed

+711
-328
lines changed

metaflow/plugins/env_escape/client.py

+148-27
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,12 @@
3434
from .communication.socket_bytestream import SocketByteStream
3535

3636
from .data_transferer import DataTransferer, ObjReference
37-
from .exception_transferer import load_exception
38-
from .override_decorators import LocalAttrOverride, LocalException, LocalOverride
37+
from .exception_transferer import ExceptionMetaClass, load_exception
38+
from .override_decorators import (
39+
LocalAttrOverride,
40+
LocalExceptionDeserializer,
41+
LocalOverride,
42+
)
3943
from .stub import create_class
4044
from .utils import get_canonical_name
4145

@@ -193,28 +197,41 @@ def inner_init(self, python_executable, pythonpath, max_pickle_version, config_d
193197
self._proxied_classes = {
194198
k: None
195199
for k in itertools.chain(
196-
response[FIELD_CONTENT]["classes"], response[FIELD_CONTENT]["proxied"]
200+
response[FIELD_CONTENT]["classes"],
201+
response[FIELD_CONTENT]["proxied"],
202+
(e[0] for e in response[FIELD_CONTENT]["exceptions"]),
197203
)
198204
}
199205

206+
self._exception_hierarchy = dict(response[FIELD_CONTENT]["exceptions"])
207+
self._proxied_classnames = set(response[FIELD_CONTENT]["classes"]).union(
208+
response[FIELD_CONTENT]["proxied"]
209+
)
210+
self._aliases = response[FIELD_CONTENT]["aliases"]
211+
200212
# Determine all overrides
201213
self._overrides = {}
202214
self._getattr_overrides = {}
203215
self._setattr_overrides = {}
204-
self._exception_overrides = {}
216+
self._exception_deserializers = {}
205217
for override in override_values:
206218
if isinstance(override, (LocalOverride, LocalAttrOverride)):
207219
for obj_name, obj_funcs in override.obj_mapping.items():
208-
if obj_name not in self._proxied_classes:
220+
canonical_name = get_canonical_name(obj_name, self._aliases)
221+
if canonical_name not in self._proxied_classes:
209222
raise ValueError(
210223
"%s does not refer to a proxied or override type" % obj_name
211224
)
212225
if isinstance(override, LocalOverride):
213-
override_dict = self._overrides.setdefault(obj_name, {})
226+
override_dict = self._overrides.setdefault(canonical_name, {})
214227
elif override.is_setattr:
215-
override_dict = self._setattr_overrides.setdefault(obj_name, {})
228+
override_dict = self._setattr_overrides.setdefault(
229+
canonical_name, {}
230+
)
216231
else:
217-
override_dict = self._getattr_overrides.setdefault(obj_name, {})
232+
override_dict = self._getattr_overrides.setdefault(
233+
canonical_name, {}
234+
)
218235
if isinstance(obj_funcs, str):
219236
obj_funcs = (obj_funcs,)
220237
for name in obj_funcs:
@@ -223,11 +240,18 @@ def inner_init(self, python_executable, pythonpath, max_pickle_version, config_d
223240
"%s was already overridden for %s" % (name, obj_name)
224241
)
225242
override_dict[name] = override.func
226-
if isinstance(override, LocalException):
227-
cur_ex = self._exception_overrides.get(override.class_path, None)
228-
if cur_ex is not None:
229-
raise ValueError("Exception %s redefined" % override.class_path)
230-
self._exception_overrides[override.class_path] = override.wrapped_class
243+
if isinstance(override, LocalExceptionDeserializer):
244+
canonical_name = get_canonical_name(override.class_path, self._aliases)
245+
if canonical_name not in self._exception_hierarchy:
246+
raise ValueError(
247+
"%s does not refer to an exception type" % override.class_path
248+
)
249+
cur_des = self._exception_deserializers.get(canonical_name, None)
250+
if cur_des is not None:
251+
raise ValueError(
252+
"Exception %s has multiple deserializers" % override.class_path
253+
)
254+
self._exception_deserializers[canonical_name] = override.deserializer
231255

232256
# Proxied standalone functions are functions that are proxied
233257
# as part of other objects like defaultdict for which we create a
@@ -243,8 +267,6 @@ def inner_init(self, python_executable, pythonpath, max_pickle_version, config_d
243267
"aliases": response[FIELD_CONTENT]["aliases"],
244268
}
245269

246-
self._aliases = response[FIELD_CONTENT]["aliases"]
247-
248270
def __del__(self):
249271
self.cleanup()
250272

@@ -288,8 +310,9 @@ def name(self):
288310
def get_exports(self):
289311
return self._export_info
290312

291-
def get_local_exception_overrides(self):
292-
return self._exception_overrides
313+
def get_exception_deserializer(self, name):
314+
cannonical_name = get_canonical_name(name, self._aliases)
315+
return self._exception_deserializers.get(cannonical_name)
293316

294317
def stub_request(self, stub, request_type, *args, **kwargs):
295318
# Encode the operation to send over the wire and wait for the response
@@ -313,7 +336,7 @@ def stub_request(self, stub, request_type, *args, **kwargs):
313336
if response_type == MSG_REPLY:
314337
return self.decode(response[FIELD_CONTENT])
315338
elif response_type == MSG_EXCEPTION:
316-
raise load_exception(self._datatransferer, response[FIELD_CONTENT])
339+
raise load_exception(self, response[FIELD_CONTENT])
317340
elif response_type == MSG_INTERNAL_ERROR:
318341
raise RuntimeError(
319342
"Error in the server runtime:\n\n===== SERVER TRACEBACK =====\n%s"
@@ -334,10 +357,27 @@ def decode(self, json_obj):
334357
# this connection will be converted to a local stub.
335358
return self._datatransferer.load(json_obj)
336359

337-
def get_local_class(self, name, obj_id=None):
360+
def get_local_class(
361+
self, name, obj_id=None, is_returned_exception=False, is_parent=False
362+
):
338363
# Gets (and creates if needed), the class mapping to the remote
339364
# class of name 'name'.
365+
366+
# We actually deal with four types of classes:
367+
# - proxied functions
368+
# - classes that are proxied regular classes AND proxied exceptions
369+
# - classes that are proxied regular classes AND NOT proxied exceptions
370+
# - classes that are NOT proxied regular classes AND are proxied exceptions
340371
name = get_canonical_name(name, self._aliases)
372+
373+
def name_to_parent_name(name):
374+
return "parent:%s" % name
375+
376+
if is_parent:
377+
lookup_name = name_to_parent_name(name)
378+
else:
379+
lookup_name = name
380+
341381
if name == "function":
342382
# Special handling of pickled functions. We create a new class that
343383
# simply has a __call__ method that will forward things back to
@@ -346,27 +386,108 @@ def get_local_class(self, name, obj_id=None):
346386
raise RuntimeError("Local function unpickling without an object ID")
347387
if obj_id not in self._proxied_standalone_functions:
348388
self._proxied_standalone_functions[obj_id] = create_class(
349-
self, "__function_%s" % obj_id, {}, {}, {}, {"__call__": ""}
389+
self, "__function_%s" % obj_id, {}, {}, {}, {"__call__": ""}, []
350390
)
351391
return self._proxied_standalone_functions[obj_id]
392+
local_class = self._proxied_classes.get(lookup_name, None)
393+
if local_class is not None:
394+
return local_class
395+
396+
is_proxied_exception = name in self._exception_hierarchy
397+
is_proxied_non_exception = name in self._proxied_classnames
398+
399+
if not is_proxied_exception and not is_proxied_non_exception:
400+
if is_returned_exception or is_parent:
401+
# In this case, it may be a local exception that we need to
402+
# recreate
403+
try:
404+
ex_module, ex_name = name.rsplit(".", 1)
405+
__import__(ex_module, None, None, "*")
406+
except Exception:
407+
pass
408+
if ex_module in sys.modules and issubclass(
409+
getattr(sys.modules[ex_module], ex_name), BaseException
410+
):
411+
# This is a local exception that we can recreate
412+
local_exception = getattr(sys.modules[ex_module], ex_name)
413+
wrapped_exception = ExceptionMetaClass(
414+
ex_name,
415+
(local_exception,),
416+
dict(getattr(local_exception, "__dict__", {})),
417+
)
418+
wrapped_exception.__module__ = ex_module
419+
self._proxied_classes[lookup_name] = wrapped_exception
420+
return wrapped_exception
352421

353-
if name not in self._proxied_classes:
354422
raise ValueError("Class '%s' is not known" % name)
355-
local_class = self._proxied_classes[name]
356-
if local_class is None:
357-
# We need to build up this class. To do so, we take everything that the
358-
# remote class has and remove UNSUPPORTED things and overridden things
423+
424+
# At this stage:
425+
# - we don't have a local_class for this
426+
# - it is not an inbuilt exception so it is either a proxied exception, a
427+
# proxied class or a proxied object that is both an exception and a class.
428+
429+
parents = []
430+
if is_proxied_exception:
431+
# If exception, we need to get the parents from the exception
432+
ex_parents = self._exception_hierarchy[name]
433+
for parent in ex_parents:
434+
# We always consider it to be an exception so that we wrap even non
435+
# proxied builtins exceptions
436+
parents.append(self.get_local_class(parent, is_parent=True))
437+
# For regular classes, we get what it exposes from the server
438+
if is_proxied_non_exception:
359439
remote_methods = self.stub_request(None, OP_GETMETHODS, name)
440+
else:
441+
remote_methods = {}
442+
443+
parent_local_class = None
444+
local_class = None
445+
if is_proxied_exception:
446+
# If we are a proxied exception AND a proxied class, we create two classes:
447+
# actually:
448+
# - the class itself (which is a stub)
449+
# - the class in the capacity of a parent class (to another exception
450+
# presumably). The reason for this is that if we have an exception/proxied
451+
# class A and another B and B inherits from A, the MRO order would be all
452+
# wrong since both A and B would also inherit from `Stub`. Here what we
453+
# do is:
454+
# - A_parent inherits from the actual parents of A (let's assume a
455+
# builtin exception)
456+
# - A inherits from (Stub, A_parent)
457+
# - B_parent inherits from A_parent and the builtin Exception
458+
# - B inherits from (Stub, B_parent)
459+
ex_module, ex_name = name.rsplit(".", 1)
460+
parent_local_class = ExceptionMetaClass(ex_name, (*parents,), {})
461+
parent_local_class.__module__ = ex_module
462+
463+
if is_proxied_non_exception:
360464
local_class = create_class(
361465
self,
362466
name,
363467
self._overrides.get(name, {}),
364468
self._getattr_overrides.get(name, {}),
365469
self._setattr_overrides.get(name, {}),
366470
remote_methods,
471+
(parent_local_class,) if parent_local_class else None,
367472
)
473+
if parent_local_class:
474+
self._proxied_classes[name_to_parent_name(name)] = parent_local_class
475+
if local_class:
368476
self._proxied_classes[name] = local_class
369-
return local_class
477+
else:
478+
# This is for the case of pure proxied exceptions -- we want the lookup of
479+
# foo.MyException to be the same class as looking of foo.MyException as a parent
480+
# of another exception so `isinstance` works properly
481+
self._proxied_classes[name] = parent_local_class
482+
483+
if is_parent:
484+
# This should never happen but making sure
485+
if not parent_local_class:
486+
raise RuntimeError(
487+
"Exception parent class %s is not a proxied exception" % name
488+
)
489+
return parent_local_class
490+
return self._proxied_classes[name]
370491

371492
def can_pickle(self, obj):
372493
return getattr(obj, "___connection___", None) == self
@@ -395,7 +516,7 @@ def unpickle_object(self, obj):
395516
obj_id = obj.identifier
396517
local_instance = self._proxied_objects.get(obj_id)
397518
if not local_instance:
398-
local_class = self.get_local_class(remote_class_name, obj_id)
519+
local_class = self.get_local_class(remote_class_name, obj_id=obj_id)
399520
local_instance = local_class(self, remote_class_name, obj_id)
400521
return local_instance
401522

metaflow/plugins/env_escape/client_modules.py

+15-47
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from .consts import OP_CALLFUNC, OP_GETVAL, OP_SETVAL
99
from .client import Client
10-
from .override_decorators import LocalException
1110
from .utils import get_canonical_name
1211

1312

@@ -16,27 +15,28 @@ def _clean_client(client):
1615

1716

1817
class _WrappedModule(object):
19-
def __init__(self, loader, prefix, exports, exception_classes, client):
18+
def __init__(self, loader, prefix, exports, client):
2019
self._loader = loader
2120
self._prefix = prefix
2221
self._client = client
2322
is_match = re.compile(
2423
r"^%s\.([a-zA-Z_][a-zA-Z0-9_]*)$" % prefix.replace(".", r"\.") # noqa W605
2524
)
2625
self._exports = {}
27-
self._aliases = exports["aliases"]
26+
self._aliases = exports.get("aliases", [])
2827
for k in ("classes", "functions", "values"):
2928
result = []
30-
for item in exports[k]:
29+
for item in exports.get(k, []):
3130
m = is_match.match(item)
3231
if m:
3332
result.append(m.group(1))
3433
self._exports[k] = result
35-
self._exception_classes = {}
36-
for k, v in exception_classes.items():
37-
m = is_match.match(k)
34+
result = []
35+
for item, _ in exports.get("exceptions", []):
36+
m = is_match.match(item)
3837
if m:
39-
self._exception_classes[m.group(1)] = v
38+
result.append(m.group(1))
39+
self._exports["exceptions"] = result
4040

4141
def __getattr__(self, name):
4242
if name == "__loader__":
@@ -50,8 +50,8 @@ def __getattr__(self, name):
5050
name = get_canonical_name(self._prefix + "." + name, self._aliases)[
5151
len(self._prefix) + 1 :
5252
]
53-
if name in self._exports["classes"]:
54-
# We load classes lazily
53+
if name in self._exports["classes"] or name in self._exports["exceptions"]:
54+
# We load classes and exceptions lazily
5555
return self._client.get_local_class("%s.%s" % (self._prefix, name))
5656
elif name in self._exports["functions"]:
5757
# TODO: Grab doc back from the remote side like in _make_method
@@ -67,8 +67,6 @@ def func(*args, **kwargs):
6767
return self._client.stub_request(
6868
None, OP_GETVAL, "%s.%s" % (self._prefix, name)
6969
)
70-
elif name in self._exception_classes:
71-
return self._exception_classes[name]
7270
else:
7371
# Try to see if this is a submodule that we can load
7472
m = None
@@ -173,7 +171,6 @@ def load_module(self, fullname):
173171

174172
# Get information about overrides and what the server knows about
175173
exports = self._client.get_exports()
176-
ex_overrides = self._client.get_local_exception_overrides()
177174

178175
prefixes = set()
179176
export_classes = exports.get("classes", [])
@@ -182,42 +179,13 @@ def load_module(self, fullname):
182179
export_exceptions = exports.get("exceptions", [])
183180
self._aliases = exports.get("aliases", {})
184181
for name in itertools.chain(
185-
export_classes, export_functions, export_values
182+
export_classes,
183+
export_functions,
184+
export_values,
185+
(e[0] for e in export_exceptions),
186186
):
187187
splits = name.rsplit(".", 1)
188188
prefixes.add(splits[0])
189-
190-
# Now look at the exceptions coming from the server
191-
formed_exception_classes = {}
192-
for ex_name, ex_parents in export_exceptions:
193-
# Exception is a tuple (name, (parents,))
194-
# Exceptions are also given in order of instantiation (ie: the
195-
# server already topologically sorted them)
196-
ex_class_dict = ex_overrides.get(ex_name, None)
197-
if ex_class_dict is None:
198-
ex_class_dict = {}
199-
else:
200-
ex_class_dict = dict(ex_class_dict.__dict__)
201-
parents = []
202-
for fake_base in ex_parents:
203-
if fake_base.startswith("builtins."):
204-
# This is something we know of here
205-
parents.append(eval(fake_base[9:]))
206-
else:
207-
# It's in formed_classes
208-
parents.append(formed_exception_classes[fake_base])
209-
splits = ex_name.rsplit(".", 1)
210-
ex_class_dict["__user_defined__"] = set(ex_class_dict.keys())
211-
new_class = type(splits[1], tuple(parents), ex_class_dict)
212-
new_class.__module__ = splits[0]
213-
new_class.__name__ = splits[1]
214-
formed_exception_classes[ex_name] = new_class
215-
216-
# Now update prefixes as needed
217-
for name in formed_exception_classes:
218-
splits = name.rsplit(".", 1)
219-
prefixes.add(splits[0])
220-
221189
# We will make sure that we create modules even for "empty" prefixes
222190
# because packages are always loaded hierarchically so if we have
223191
# something in `a.b.c` but nothing directly in `a`, we still need to
@@ -235,7 +203,7 @@ def load_module(self, fullname):
235203
self._handled_modules = {}
236204
for prefix in prefixes:
237205
self._handled_modules[prefix] = _WrappedModule(
238-
self, prefix, exports, formed_exception_classes, self._client
206+
self, prefix, exports, self._client
239207
)
240208
canonical_fullname = get_canonical_name(fullname, self._aliases)
241209
# Modules are created canonically but we need to return something for any

0 commit comments

Comments
 (0)