@@ -60,6 +60,23 @@ static inline PyObject* PyObject_CallOneArg(PyObject* func, PyObject* arg) {
60
60
} \
61
61
}
62
62
63
+ template <typename T>
64
+ static inline bool check_shape (
65
+ const std::vector<std::optional<int64_t >>& expected,
66
+ int ndim,
67
+ const T& actual_shape) {
68
+ if (expected.size () != static_cast <size_t >(ndim)) {
69
+ return false ;
70
+ }
71
+ for (size_t i = 0 ; i < expected.size (); ++i) {
72
+ if (!expected[i] || actual_shape[i] < 1 ) continue ;
73
+ if (actual_shape[i] != expected[i].value ()) {
74
+ return false ;
75
+ }
76
+ }
77
+ return true ;
78
+ }
79
+
63
80
static inline bool PyObject_Equal (PyObject* a, PyObject* b) {
64
81
if (a == b) {
65
82
return true ;
@@ -137,15 +154,7 @@ bool ShapeMatchGuard::check(PyObject* value) {
137
154
auto tensor = GetTensorFromPyObject (value);
138
155
HANDLE_NULL_TENSOR (tensor);
139
156
auto shape = tensor->shape ();
140
- if (shape.size () != expected_.size ()) {
141
- return false ;
142
- }
143
- for (size_t i = 0 ; i < shape.size (); ++i) {
144
- if (expected_[i] && shape[i] != *expected_[i]) {
145
- return false ;
146
- }
147
- }
148
- return true ;
157
+ return check_shape<std::vector<int64_t >>(expected_, shape.size (), shape);
149
158
}
150
159
151
160
bool AttributeMatchGuard::check (PyObject* value) {
@@ -199,16 +208,8 @@ bool NumPyArrayShapeMatchGuard::check(PyObject* value) {
199
208
return false ;
200
209
}
201
210
int ndim = array.ndim ();
202
- auto shape = array.shape ();
203
- if (ndim != static_cast <int >(expected_.size ())) {
204
- return false ;
205
- }
206
- for (int i = 0 ; i < ndim; ++i) {
207
- if (expected_[i].has_value () && shape[i] != expected_[i].value ()) {
208
- return false ;
209
- }
210
- }
211
- return true ;
211
+ const Py_ssize_t* shape = array.shape ();
212
+ return check_shape<const Py_ssize_t*>(expected_, ndim, shape);
212
213
}
213
214
214
215
bool WeakRefMatchGuard::check (PyObject* value) {
0 commit comments