Skip to content

Commit 38b87ed

Browse files
authored
[SOT][FasterGuard] ShapeMatchGuard support dynamic shape (#72564)
1 parent 28eae79 commit 38b87ed

File tree

2 files changed

+20
-26
lines changed

2 files changed

+20
-26
lines changed

paddle/fluid/pybind/sot/guards.cc

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,23 @@ static inline PyObject* PyObject_CallOneArg(PyObject* func, PyObject* arg) {
6060
} \
6161
}
6262

63+
template <typename T>
64+
static inline bool check_shape(
65+
const std::vector<std::optional<int64_t>>& expected,
66+
int ndim,
67+
const T& actual_shape) {
68+
if (expected.size() != static_cast<size_t>(ndim)) {
69+
return false;
70+
}
71+
for (size_t i = 0; i < expected.size(); ++i) {
72+
if (!expected[i] || actual_shape[i] < 1) continue;
73+
if (actual_shape[i] != expected[i].value()) {
74+
return false;
75+
}
76+
}
77+
return true;
78+
}
79+
6380
static inline bool PyObject_Equal(PyObject* a, PyObject* b) {
6481
if (a == b) {
6582
return true;
@@ -137,15 +154,7 @@ bool ShapeMatchGuard::check(PyObject* value) {
137154
auto tensor = GetTensorFromPyObject(value);
138155
HANDLE_NULL_TENSOR(tensor);
139156
auto shape = tensor->shape();
140-
if (shape.size() != expected_.size()) {
141-
return false;
142-
}
143-
for (size_t i = 0; i < shape.size(); ++i) {
144-
if (expected_[i] && shape[i] != *expected_[i]) {
145-
return false;
146-
}
147-
}
148-
return true;
157+
return check_shape<std::vector<int64_t>>(expected_, shape.size(), shape);
149158
}
150159

151160
bool AttributeMatchGuard::check(PyObject* value) {
@@ -199,16 +208,8 @@ bool NumPyArrayShapeMatchGuard::check(PyObject* value) {
199208
return false;
200209
}
201210
int ndim = array.ndim();
202-
auto shape = array.shape();
203-
if (ndim != static_cast<int>(expected_.size())) {
204-
return false;
205-
}
206-
for (int i = 0; i < ndim; ++i) {
207-
if (expected_[i].has_value() && shape[i] != expected_[i].value()) {
208-
return false;
209-
}
210-
}
211-
return true;
211+
const Py_ssize_t* shape = array.shape();
212+
return check_shape<const Py_ssize_t*>(expected_, ndim, shape);
212213
}
213214

214215
bool WeakRefMatchGuard::check(PyObject* value) {

paddle/fluid/pybind/sot/guards.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,6 @@ class DtypeMatchGuard : public GuardBase {
156156

157157
class ShapeMatchGuard : public GuardBase {
158158
public:
159-
explicit ShapeMatchGuard(const std::vector<std::optional<int64_t>>& shape)
160-
: expected_(shape) {}
161-
162159
explicit ShapeMatchGuard(const std::vector<py::object>& shape) {
163160
expected_.resize(shape.size());
164161
for (size_t i = 0; i < shape.size(); ++i) {
@@ -255,10 +252,6 @@ class NumPyArrayValueMatchGuard : public GuardBase {
255252

256253
class NumPyArrayShapeMatchGuard : public GuardBase {
257254
public:
258-
explicit NumPyArrayShapeMatchGuard(
259-
const std::vector<std::optional<int64_t>>& shape)
260-
: expected_(shape) {}
261-
262255
explicit NumPyArrayShapeMatchGuard(const std::vector<py::object>& shape) {
263256
expected_.resize(shape.size());
264257
for (size_t i = 0; i < shape.size(); ++i) {

0 commit comments

Comments
 (0)