Skip to content

[SOT][FasterGuard] add ExprNodeBase cleanup func #72552

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions paddle/fluid/pybind/sot/guards.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ static inline PyObject* PyObject_CallOneArg(PyObject* func, PyObject* arg) {
} \
}

// ExprNodeBase delayed cleaning
#define DELAYED_CLEAN(clean_py_obj_, value) \
{ \
std::set<PyObject*>::iterator iter; \
if ((iter = clean_py_obj_.find(value)) != clean_py_obj_.end()) { \
Py_DECREF(value); \
} \
}

static inline bool PyObject_Equal(PyObject* a, PyObject* b) {
if (a == b) {
return true;
Expand Down Expand Up @@ -252,25 +261,37 @@ std::string GlobalVarExprNode::stringify(int indent) {

PyObject* AttributeExprNode::eval(FrameProxy* frame) {
PyObject* var = var_expr_->eval(frame);
return PyObject_GetAttrString(var, attr_name_.c_str());
auto res = PyObject_GetAttrString(var, attr_name_.c_str());
if (res != NULL) clean_py_obj_.insert(res);
return res;
}
std::string AttributeExprNode::stringify(int indent) {
std::stringstream ss;
ss << var_expr_->stringify() << "." << attr_name_;
return ss.str();
}

void AttributeExprNode::cleanup(PyObject* value) {
DELAYED_CLEAN(clean_py_obj_, value);
}

PyObject* ItemExprNode::eval(FrameProxy* frame) {
PyObject* var = var_expr_->eval(frame);
PyObject* key = key_expr_->eval(frame);
return PyObject_GetItem(var, key);
auto res = PyObject_GetItem(var, key);
if (res != NULL) clean_py_obj_.insert(res);
return res;
}
std::string ItemExprNode::stringify(int indent) {
std::stringstream ss;
ss << var_expr_->stringify() << "[" << key_expr_->stringify() << "]";
return ss.str();
}

void ItemExprNode::cleanup(PyObject* value) {
DELAYED_CLEAN(clean_py_obj_, value);
}

PyObject* BinaryExprNode::eval(FrameProxy* frame) {
PyObject* lhs = lhs_->eval(frame);
PyObject* rhs = rhs_->eval(frame);
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/pybind/sot/guards.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ class ExprNodeBase : public GuardTreeNodeBase,
public:
virtual PyObject* eval(FrameProxy* frame) = 0;
virtual ~ExprNodeBase() = default;
virtual void cleanup(PyObject* value) {}
};

class ConstantExprNode : public ExprNodeBase {
Expand Down Expand Up @@ -345,10 +346,12 @@ class AttributeExprNode : public ExprNodeBase {

PyObject* eval(FrameProxy* frame) override;
std::string stringify(int indent = 0) override;
void cleanup(PyObject* value) override;

private:
std::shared_ptr<ExprNodeBase> var_expr_;
std::string attr_name_;
std::set<PyObject*> clean_py_obj_;
};

class ItemExprNode : public ExprNodeBase {
Expand All @@ -359,10 +362,12 @@ class ItemExprNode : public ExprNodeBase {

PyObject* eval(FrameProxy* frame) override;
std::string stringify(int indent = 0) override;
void cleanup(PyObject* value) override;

private:
std::shared_ptr<ExprNodeBase> var_expr_;
std::shared_ptr<ExprNodeBase> key_expr_;
std::set<PyObject*> clean_py_obj_;
};

class BinaryExprNode : public ExprNodeBase {
Expand Down Expand Up @@ -544,6 +549,7 @@ class CheckGuardNode : public GuardNodeBase {
ret = lookup_next(frame);
}
for (size_t i = 0; i < N; ++i) {
exprs[i]->cleanup(values[i]);
if (values[i]) {
Py_DECREF(values[i]);
}
Expand Down
Loading