|
37 | 37 | OP_SETVAL, |
38 | 38 | OP_INIT, |
39 | 39 | OP_SUBCLASSCHECK, |
| 40 | + OP_GETCLASSATTR, |
40 | 41 | VALUE_LOCAL, |
41 | 42 | VALUE_REMOTE, |
42 | 43 | CONTROL_GETEXPORTS, |
@@ -96,11 +97,19 @@ def __init__(self, config_dir, max_pickle_version): |
96 | 97 | {v: k for k, v in self._proxied_types.items()} |
97 | 98 | ) |
98 | 99 |
|
99 | | - # We will also proxy functions from objects as needed. This is useful |
| 100 | + # We will also proxy functions and methods from objects as needed. This is useful |
100 | 101 | # for defaultdict for example since the `default_factory` function is a |
101 | | - # lambda that needs to be transferred. |
| 102 | + # lambda that needs to be transferred. Methods are also proxied to support |
| 103 | + # bound methods returned from functions (e.g., evaluator patterns). |
102 | 104 | self._class_types_to_names[type(lambda x: x)] = "function" |
103 | 105 |
|
| 106 | + # Register method type for bound methods |
| 107 | + class _TempClass: |
| 108 | + def _temp_method(self): |
| 109 | + pass |
| 110 | + |
| 111 | + self._class_types_to_names[type(_TempClass()._temp_method)] = "method" |
| 112 | + |
104 | 113 | # Update all alias information |
105 | 114 | for base_name, aliases in itertools.chain( |
106 | 115 | a1.items(), a2.items(), a3.items(), a4.items() |
@@ -257,6 +266,7 @@ def __init__(self, config_dir, max_pickle_version): |
257 | 266 | OP_SETVAL: self._handle_setval, |
258 | 267 | OP_INIT: self._handle_init, |
259 | 268 | OP_SUBCLASSCHECK: self._handle_subclasscheck, |
| 269 | + OP_GETCLASSATTR: self._handle_getclassattr, |
260 | 270 | } |
261 | 271 |
|
262 | 272 | self._local_objects = {} |
@@ -391,9 +401,9 @@ def pickle_object(self, obj): |
391 | 401 | def unpickle_object(self, obj): |
392 | 402 | if (not isinstance(obj, ObjReference)) or obj.value_type != VALUE_LOCAL: |
393 | 403 | raise ValueError("Invalid transferred object: %s" % str(obj)) |
394 | | - obj = self._local_objects.get(obj.identifier) |
395 | | - if obj: |
396 | | - return obj |
| 404 | + result = self._local_objects.get(obj.identifier) |
| 405 | + if result is not None: |
| 406 | + return result |
397 | 407 | raise ValueError("Invalid object -- id %s not known" % obj.identifier) |
398 | 408 |
|
399 | 409 | @staticmethod |
@@ -527,6 +537,15 @@ def _handle_subclasscheck(self, target, class_name, otherclass_name, reverse=Fal |
527 | 537 | return issubclass(class_type, getattr(sys.modules[sub_module], sub_name)) |
528 | 538 | return issubclass(getattr(sys.modules[sub_module], sub_name), class_type) |
529 | 539 |
|
| 540 | + def _handle_getclassattr(self, target, class_name, attr_name): |
| 541 | + # Handle class-level attribute access like EnumClass.MEMBER |
| 542 | + class_type = self._known_classes.get(class_name) |
| 543 | + if class_type is None: |
| 544 | + class_type = self._proxied_types.get(class_name) |
| 545 | + if class_type is None: |
| 546 | + raise ValueError("Unknown class %s" % class_name) |
| 547 | + return getattr(class_type, attr_name) |
| 548 | + |
530 | 549 |
|
531 | 550 | if __name__ == "__main__": |
532 | 551 | max_pickle_version = int(sys.argv[1]) |
|
0 commit comments