Skip to content

Commit cba43e4

Browse files
[Unittest] Add NCNN tensorslice unittest and fix tensorslice.cpp bugs. (open-mmlab#115)
* add tensorslice unittest * reply code review * fix lint * fix typo
1 parent f56a300 commit cba43e4

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

backend_ops/ncnn/ops/tensorslice/tensorslice.cpp

+13-2
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,17 @@ int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob,
5353
size_t elemsize = bottom_blob.elemsize;
5454
const int* start_ptr = starts;
5555
const int* end_ptr = ends;
56-
const float* axes_ptr = axes;
56+
const int* axes_ptr = axes;
5757
const int* step_ptr = steps;
5858
if (starts.w > dims || ends.w > dims) {
5959
fprintf(stderr, "start/end attributes shape error!\n");
6060
return -100;
6161
}
62+
if (axes.w != 1) {
63+
fprintf(stderr,
64+
"axes.w must be 1 because any of multiaxes slice is regarded as "
65+
"multi-staged onnx slice in pytorch2onnx.");
66+
}
6267
if (dims == 1) {
6368
for (int i = 0; i < axes.w; i++) {
6469
int positive_axis = axes_ptr[i] < 0 ? dims + axes_ptr[i] : axes_ptr[i];
@@ -106,6 +111,8 @@ int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob,
106111
int start = start_ptr[i];
107112
int end = end_ptr[i];
108113
int dim_shape = get_shape_by_axes(bottom_blob, positive_axis, dims);
114+
int dim_shape_test =
115+
get_shape_by_axes(bottom_blob, positive_axis, dims - 1);
109116
if (dim_shape < 0) {
110117
return -1;
111118
}
@@ -127,6 +134,7 @@ int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob,
127134
return -100;
128135
}
129136
active_indice[positive_axis - 1] = temp_indice;
137+
active_indice[positive_axis - 1].resize(temp_indice.size());
130138
}
131139
top_blob.create((int)active_indice[1].size(), (int)active_indice[0].size(),
132140
elemsize, opt.blob_allocator);
@@ -138,6 +146,7 @@ int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob,
138146
}
139147
return 0;
140148
}
149+
141150
if (dims == 3) {
142151
std::vector<std::vector<int> > active_indice;
143152
std::vector<int> indices;
@@ -177,7 +186,8 @@ int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob,
177186
fprintf(stderr, "step should not be 0!\n");
178187
return -100;
179188
}
180-
active_indice[positive_axis] = temp_indice;
189+
active_indice[positive_axis - 1] = temp_indice;
190+
active_indice[positive_axis - 1].resize(temp_indice.size());
181191
}
182192
top_blob.create((int)active_indice[2].size(), (int)active_indice[1].size(),
183193
(int)active_indice[0].size(), elemsize, opt.blob_allocator);
@@ -192,6 +202,7 @@ int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob,
192202
}
193203
return 0;
194204
}
205+
195206
return 0;
196207
}
197208

tests/test_ops/test_ops.py

+35
Original file line numberDiff line numberDiff line change
@@ -545,3 +545,38 @@ def test_constantofshape(backend,
545545
ncnn_outputs = ncnn_model(dict(zip(input_names, [input.float()])))
546546
ncnn_outputs = [ncnn_outputs[name] for name in output_names]
547547
assert_allclose(model_outputs, ncnn_outputs, tolerate_small_mismatch)
548+
549+
550+
@pytest.mark.parametrize('backend', [TEST_NCNN])
551+
@pytest.mark.parametrize('dim', [1, 2, 3])
552+
def test_tensorslice(backend, dim, input_list=None, save_dir=None):
553+
backend.check_env()
554+
555+
if input_list is None:
556+
input = torch.rand((8, 12, 17)[-dim:]).unsqueeze(0)
557+
else:
558+
input = input_list[0]
559+
assert input.dim() == dim + 1, f'input.dim() must equal to \
560+
dim + 1, expected: {dim + 1}, got: {input.dim()}'
561+
562+
assert input.shape[0] == 1, (f'ncnn batch must be 1, \
563+
but got {input.shape[0]}')
564+
cfg = dict()
565+
register_extra_symbolics(cfg=cfg, backend=backend.backend_name, opset=11)
566+
567+
def tensorslice_function(inputs):
568+
if dim == 1:
569+
return inputs[:, 2:17:7]
570+
if dim == 2:
571+
return inputs[:, 3:12:4, 2:15:3]
572+
if dim == 3:
573+
return inputs[:, 0:8:2, 2:12:4, 2:17:7]
574+
575+
wrapped_model = WrapFunction(tensorslice_function)
576+
577+
backend.run_and_validate(
578+
wrapped_model, [input.float()],
579+
'tensorslice',
580+
input_names=['inputs'],
581+
output_names=['outputs'],
582+
save_dir=save_dir)

0 commit comments

Comments
 (0)