Skip to content

[SOT][Exception][3.10-][3.13] Add exception handler for Py3.10- #72559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions python/paddle/jit/sot/opcode_translator/executor/exception_stack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import dataclasses

from ...utils import InnerError
from .variables import ConstantVariable, ExceptionVariable


@dataclasses.dataclass
class ExceptionStack:

# This data structure manages exceptions as in CPython, primarily handling
# the __context__ attribute of SotCapturedException.

_exception_stack: list[ExceptionVariable | None] = dataclasses.field(
default_factory=list
)
_current_exception: ExceptionVariable | None = dataclasses.field(
default=None
)

def clear_current_exception(self):
self._current_exception = None

def set_current_exception(self, val: ExceptionVariable, graph):
self._set_context_and_break_context_reference_cycle(val, graph)
self._current_exception = val

def move_current_exception_to_stack(self):
self.push(self.get_current_exception())
self.clear_current_exception()

def get_current_exception(self):
if self._current_exception is None:
raise InnerError("Current exception should not be None")
return self._current_exception

def _set_context_recursive(self, val: ExceptionVariable, prev_idx):
# Recursively sets the __context__ attribute for ExceptionVariable objects
# in self._exception_stack. Ensures that __context__ is properly linked
# to the previous exception in the stack.
if (ctx := val.__context__) and type(ctx) is not ConstantVariable:
return val
if (
len(self._exception_stack) + prev_idx > 0
): # Prevent invalid negative indexing
prev = self._exception_stack[prev_idx]
self._set_context_recursive(prev, prev_idx - 1)
val.setattr("__context__", prev)
return val

def _break_context_reference_cycle(self, val: ExceptionVariable, graph):
# Detects and breaks cycles in exception __context__ chains using Floyd's algorithm,
# following CPython's implementation.

fast = slow = val
slow_update_toggle = False
while True:
context = fast.__context__
if (
type(context) is ConstantVariable
): # End of the chain; no context set
break

if context is val:
# The chain loops back to the original exception; break the cycle.
fast.setattr(
"__context__", ConstantVariable.wrap_literal(None, graph)
)
break

fast = context
if fast is slow:
# Cycle detected; all exceptions on the path have been visited and checked.
break

if slow_update_toggle:
slow = slow.__context__
slow_update_toggle = not slow_update_toggle

def _set_context_and_break_context_reference_cycle(
self, val: ExceptionVariable, graph
):
# set Exception.__context__
self._set_context_recursive(val, len(self._exception_stack) - 1)
self._break_context_reference_cycle(val, graph)

def pop(self):
return self._exception_stack.pop()

def push(self, val):
self._exception_stack.append(val)

def __len__(self):
return len(self._exception_stack)

def __str__(self):
return f"{self._exception_stack}"

def __getitem__(self, idx):
return self._exception_stack[idx]

def cleanup(self):
self._exception_stack[:] = []
self._current_exception = None
Loading
Loading