Skip to content

Commit efcd65c

Browse files
gh-122313: Clean up deep recursion guarding code in the compiler (GH-122640)
Add ENTER_RECURSIVE and LEAVE_RECURSIVE macros in ast.c, ast_opt.c and symtable.c. Remove VISIT_QUIT macro in symtable.c. The current recursion depth counter only needs to be updated during normal execution -- all functions should just return an error code if an error occurs.
1 parent fe0a28d commit efcd65c

File tree

3 files changed

+164
-162
lines changed

3 files changed

+164
-162
lines changed

Python/ast.c

+22-24
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,20 @@ struct validator {
1414
int recursion_limit; /* recursion limit */
1515
};
1616

17+
#define ENTER_RECURSIVE(ST) \
18+
do { \
19+
if (++(ST)->recursion_depth > (ST)->recursion_limit) { \
20+
PyErr_SetString(PyExc_RecursionError, \
21+
"maximum recursion depth exceeded during compilation"); \
22+
return 0; \
23+
} \
24+
} while(0)
25+
26+
#define LEAVE_RECURSIVE(ST) \
27+
do { \
28+
--(ST)->recursion_depth; \
29+
} while(0)
30+
1731
static int validate_stmts(struct validator *, asdl_stmt_seq *);
1832
static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int);
1933
static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
@@ -166,11 +180,7 @@ validate_constant(struct validator *state, PyObject *value)
166180
return 1;
167181

168182
if (PyTuple_CheckExact(value) || PyFrozenSet_CheckExact(value)) {
169-
if (++state->recursion_depth > state->recursion_limit) {
170-
PyErr_SetString(PyExc_RecursionError,
171-
"maximum recursion depth exceeded during compilation");
172-
return 0;
173-
}
183+
ENTER_RECURSIVE(state);
174184

175185
PyObject *it = PyObject_GetIter(value);
176186
if (it == NULL)
@@ -195,7 +205,7 @@ validate_constant(struct validator *state, PyObject *value)
195205
}
196206

197207
Py_DECREF(it);
198-
--state->recursion_depth;
208+
LEAVE_RECURSIVE(state);
199209
return 1;
200210
}
201211

@@ -213,11 +223,7 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
213223
assert(!PyErr_Occurred());
214224
VALIDATE_POSITIONS(exp);
215225
int ret = -1;
216-
if (++state->recursion_depth > state->recursion_limit) {
217-
PyErr_SetString(PyExc_RecursionError,
218-
"maximum recursion depth exceeded during compilation");
219-
return 0;
220-
}
226+
ENTER_RECURSIVE(state);
221227
int check_ctx = 1;
222228
expr_context_ty actual_ctx;
223229

@@ -398,7 +404,7 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
398404
PyErr_SetString(PyExc_SystemError, "unexpected expression");
399405
ret = 0;
400406
}
401-
state->recursion_depth--;
407+
LEAVE_RECURSIVE(state);
402408
return ret;
403409
}
404410

@@ -544,11 +550,7 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
544550
assert(!PyErr_Occurred());
545551
VALIDATE_POSITIONS(p);
546552
int ret = -1;
547-
if (++state->recursion_depth > state->recursion_limit) {
548-
PyErr_SetString(PyExc_RecursionError,
549-
"maximum recursion depth exceeded during compilation");
550-
return 0;
551-
}
553+
ENTER_RECURSIVE(state);
552554
switch (p->kind) {
553555
case MatchValue_kind:
554556
ret = validate_pattern_match_value(state, p->v.MatchValue.value);
@@ -690,7 +692,7 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
690692
PyErr_SetString(PyExc_SystemError, "unexpected pattern");
691693
ret = 0;
692694
}
693-
state->recursion_depth--;
695+
LEAVE_RECURSIVE(state);
694696
return ret;
695697
}
696698

@@ -725,11 +727,7 @@ validate_stmt(struct validator *state, stmt_ty stmt)
725727
assert(!PyErr_Occurred());
726728
VALIDATE_POSITIONS(stmt);
727729
int ret = -1;
728-
if (++state->recursion_depth > state->recursion_limit) {
729-
PyErr_SetString(PyExc_RecursionError,
730-
"maximum recursion depth exceeded during compilation");
731-
return 0;
732-
}
730+
ENTER_RECURSIVE(state);
733731
switch (stmt->kind) {
734732
case FunctionDef_kind:
735733
ret = validate_body(state, stmt->v.FunctionDef.body, "FunctionDef") &&
@@ -946,7 +944,7 @@ validate_stmt(struct validator *state, stmt_ty stmt)
946944
PyErr_SetString(PyExc_SystemError, "unexpected statement");
947945
ret = 0;
948946
}
949-
state->recursion_depth--;
947+
LEAVE_RECURSIVE(state);
950948
return ret;
951949
}
952950

Python/ast_opt.c

+20-19
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,19 @@ typedef struct {
1515
int recursion_limit; /* recursion limit */
1616
} _PyASTOptimizeState;
1717

18+
#define ENTER_RECURSIVE(ST) \
19+
do { \
20+
if (++(ST)->recursion_depth > (ST)->recursion_limit) { \
21+
PyErr_SetString(PyExc_RecursionError, \
22+
"maximum recursion depth exceeded during compilation"); \
23+
return 0; \
24+
} \
25+
} while(0)
26+
27+
#define LEAVE_RECURSIVE(ST) \
28+
do { \
29+
--(ST)->recursion_depth; \
30+
} while(0)
1831

1932
static int
2033
make_const(expr_ty node, PyObject *val, PyArena *arena)
@@ -708,11 +721,7 @@ astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
708721
static int
709722
astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
710723
{
711-
if (++state->recursion_depth > state->recursion_limit) {
712-
PyErr_SetString(PyExc_RecursionError,
713-
"maximum recursion depth exceeded during compilation");
714-
return 0;
715-
}
724+
ENTER_RECURSIVE(state);
716725
switch (node_->kind) {
717726
case BoolOp_kind:
718727
CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
@@ -811,7 +820,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
811820
case Name_kind:
812821
if (node_->v.Name.ctx == Load &&
813822
_PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
814-
state->recursion_depth--;
823+
LEAVE_RECURSIVE(state);
815824
return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
816825
}
817826
break;
@@ -824,7 +833,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
824833
// No default case, so the compiler will emit a warning if new expression
825834
// kinds are added without being handled here
826835
}
827-
state->recursion_depth--;
836+
LEAVE_RECURSIVE(state);;
828837
return 1;
829838
}
830839

@@ -871,11 +880,7 @@ astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
871880
static int
872881
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
873882
{
874-
if (++state->recursion_depth > state->recursion_limit) {
875-
PyErr_SetString(PyExc_RecursionError,
876-
"maximum recursion depth exceeded during compilation");
877-
return 0;
878-
}
883+
ENTER_RECURSIVE(state);
879884
switch (node_->kind) {
880885
case FunctionDef_kind:
881886
CALL_SEQ(astfold_type_param, type_param, node_->v.FunctionDef.type_params);
@@ -999,7 +1004,7 @@ astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
9991004
// No default case, so the compiler will emit a warning if new statement
10001005
// kinds are added without being handled here
10011006
}
1002-
state->recursion_depth--;
1007+
LEAVE_RECURSIVE(state);
10031008
return 1;
10041009
}
10051010

@@ -1031,11 +1036,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
10311036
// Currently, this is really only used to form complex/negative numeric
10321037
// constants in MatchValue and MatchMapping nodes
10331038
// We still recurse into all subexpressions and subpatterns anyway
1034-
if (++state->recursion_depth > state->recursion_limit) {
1035-
PyErr_SetString(PyExc_RecursionError,
1036-
"maximum recursion depth exceeded during compilation");
1037-
return 0;
1038-
}
1039+
ENTER_RECURSIVE(state);
10391040
switch (node_->kind) {
10401041
case MatchValue_kind:
10411042
CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
@@ -1067,7 +1068,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
10671068
// No default case, so the compiler will emit a warning if new pattern
10681069
// kinds are added without being handled here
10691070
}
1070-
state->recursion_depth--;
1071+
LEAVE_RECURSIVE(state);
10711072
return 1;
10721073
}
10731074

0 commit comments

Comments
 (0)