diff --git a/billiard/einfo.py b/billiard/einfo.py index 23ea4d2..7eb351b 100644 --- a/billiard/einfo.py +++ b/billiard/einfo.py @@ -1,5 +1,6 @@ import sys import traceback +import types __all__ = ['ExceptionInfo', 'Traceback'] @@ -26,6 +27,13 @@ def __init__(self, code): self.co_qualname = code.co_qualname self._co_positions = list(code.co_positions()) + @property + def __class__(self): + return types.CodeType + + def __reduce__(self): + return _Code.__new__, (_Code,), self.__dict__ + if sys.version_info >= (3, 11): @property def co_positions(self): @@ -58,6 +66,13 @@ def __init__(self, frame): # don't want to hit https://bugs.python.org/issue21967 self.f_restricted = False + @property + def __class__(self): + return types.FrameType + + def __reduce__(self): + return _Frame.__new__, (_Frame,), self.__dict__ + if sys.version_info >= (3, 11): @property def co_positions(self): @@ -100,6 +115,13 @@ def __init__(self): self.tb_next = None self.tb_lasti = 0 + @property + def __class__(self): + return types.TracebackType + + def __reduce__(self): + return _Truncated.__new__, (_Truncated,), self.__dict__ + if sys.version_info >= (3, 11): @property def co_positions(self): @@ -120,6 +142,13 @@ def __init__(self, tb, max_frames=DEFAULT_MAX_FRAMES, depth=0): else: self.tb_next = _Truncated() + @property + def __class__(self): + return types.TracebackType + + def __reduce__(self): + return Traceback.__new__, (Traceback,), self.__dict__ + class RemoteTraceback(Exception): def __init__(self, tb): diff --git a/t/unit/test_einfo.py b/t/unit/test_einfo.py index c2c126b..13489b0 100644 --- a/t/unit/test_einfo.py +++ b/t/unit/test_einfo.py @@ -1,6 +1,8 @@ +import inspect import logging import pickle import sys +import types from billiard.einfo import _Code # noqa from billiard.einfo import _Frame # noqa @@ -86,7 +88,7 @@ def test_code(): assert isinstance(code.co_argcount, int) if sys.version_info >= (3, 11): assert callable(code.co_positions) - assert next(code.co_positions()) == (77, 77, 0, 0) + assert next(code.co_positions()) == (79, 79, 0, 0) def test_object_init(): @@ -116,3 +118,94 @@ def test_truncated_co_positions(): assert list(iter(truncated.co_positions())) == list( iter(truncated.tb_frame.co_positions()) ) + + +def make_python_tb(): + tb = None + depth = 0 + while True: + try: + frame = sys._getframe(depth) + except ValueError: + break + else: + depth += 1 + + tb = types.TracebackType(tb, frame, frame.f_lasti, frame.f_lineno) + + assert tb is not None, "Failed to create a traceback object" + + return tb + + +def test_isinstance(): + tb = Traceback(tb=make_python_tb()) + frame = tb.tb_frame + code = frame.f_code + + assert isinstance(tb, types.TracebackType) + assert isinstance(tb, Traceback) + assert isinstance(frame, types.FrameType) + assert isinstance(frame, _Frame) + assert isinstance(code, types.CodeType) + assert isinstance(code, _Code) + + +def repickle(obj): + """Round-trip an object through pickle.""" + return pickle.loads(pickle.dumps(obj)) + + +def test_pickle(): + """ + While `__class__` is overridden to return the built-in types, + this would break unpickling in Python versions prior to 3.10 + """ + tb = Traceback(tb=make_python_tb()) + tb2 = repickle(tb) + + assert type(tb2) == type(tb) + assert tb2.tb_lineno == tb.tb_lineno + + frame = tb.tb_frame + frame2 = repickle(frame) + + assert type(frame2) == type(frame) + assert frame2.f_lineno == frame.f_lineno + + code = frame.f_code + code2 = repickle(code) + + assert type(code2) == type(code) + assert code2.co_name == code.co_name + + +class TestInspect: + def test_istraceback(self): + tb = Traceback(tb=make_python_tb()) + assert inspect.istraceback(tb) + assert inspect.istraceback(repickle(tb)) + + def test_isframe(self): + frame = _Frame(make_python_tb().tb_frame) + assert inspect.isframe(frame) + assert inspect.isframe(repickle(frame)) + + def test_iscode(self): + code = _Code(make_python_tb().tb_frame.f_code) + assert inspect.iscode(code) + assert inspect.iscode(repickle(code)) + + def test_getframeinfo(self): + tb = Traceback(make_python_tb()) + assert inspect.getframeinfo(tb) + assert inspect.getframeinfo(repickle(tb)) + + frame = tb.tb_frame + assert inspect.getframeinfo(frame) + assert inspect.getframeinfo(repickle(frame)) + + def test_getinnerframes(self): + tb = Traceback(tb=make_python_tb()) + assert inspect.getinnerframes(tb) + assert inspect.getinnerframes(repickle(tb))