forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_conv.cpp
129 lines (109 loc) · 4.07 KB
/
test_conv.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include <gtest/gtest.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/operators/conv2d.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>
namespace torch {
namespace jit {
namespace te = torch::jit::tensorexpr;
namespace F = torch::nn::functional;
#ifdef TORCH_ENABLE_LLVM
// Generate test data with few bits of precision, to minimize error
// accumulation from floating-point reordering.
//
// Only the LLVM-backed test(s) inside this TORCH_ENABLE_LLVM guard call this
// helper, so it is defined inside the guard as well: defining it
// unconditionally leaves an unused file-static function (and a
// -Wunused-function diagnostic) in builds without LLVM.
static at::Tensor genTestData(c10::IntArrayRef args) {
  // Scale standard-normal samples by 256, truncate toward zero, and scale
  // back, so every element is a multiple of 1/256.
  return at::trunc(at::randn(args) * 256.0f) / 256.0f;
}

// Lowers a depthwise conv2d (kGroups == C, i.e. one input channel per group)
// through te::conv2d_depthwise, compiles it with the LLVM backend, and checks
// the output against at::conv2d on the same inputs.
TEST(Conv, DepthwiseConv2D) {
  te::KernelScope kernel_scope;
  constexpr int N = 1, C = 72, H = 56, W = 56;
  constexpr int K = 72, R = 3, S = 3;
  constexpr int kPad = 1, kStride = 2, kGroups = C;
  // Input channels per group; evaluates to 1 because kGroups == C.
  constexpr int CperG = C / kGroups;

  te::Placeholder input("input", te::kFloat, {N, C, H, W});
  te::Placeholder weight("weight", te::kFloat, {K, CperG, R, S});
  te::Placeholder bias("bias", te::kFloat, {K});
  te::Tensor* output = te::conv2d_depthwise(
      input.handle(), weight.handle(), bias.handle(), kStride, kPad, kGroups);

  te::LoopNest loop({output});
  loop.simplify();
  loop.prepareForCodegen();
  // cg.call() below receives data pointers in the same order as these
  // buffer arguments: input, weight, bias, output.
  te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, bias, output});

  auto it = genTestData({N, C, H, W});
  auto wt = genTestData({K, CperG, R, S});
  auto bt = genTestData({K});
  auto ref = at::conv2d(it, wt, bt, kStride, kPad, /*dilation=*/1, kGroups);
  // Output buffer shaped like the reference; filled by the generated kernel.
  auto ot = at::zeros_like(ref);
  cg.call(
      {it.data_ptr<float>(),
       wt.data_ptr<float>(),
       bt.data_ptr<float>(),
       ot.data_ptr<float>()});

  ASSERT_TRUE(at::allclose(ref, ot));
}
#endif
// Builds a plain 2D convolution (F::conv2d defaults: no padding, stride 1,
// one group) directly out of te::Reduce, runs it through the
// SimpleIREvaluator, and compares the result to F::conv2d.
TEST(Conv, Conv2D) {
  te::KernelScope kernel_scope;

  // Input dimensions.
  constexpr int N = 1;
  constexpr int C = 3;
  constexpr int H = 11;
  constexpr int W = 11;

  // Filter dimensions.
  constexpr int K = 8;
  constexpr int R = 3;
  constexpr int S = 3;

  // Output dims for an unpadded, stride-1 convolution.
  constexpr int OH = H - R + 1;
  constexpr int OW = W - S + 1;

  // Compute reference result.
  at::Tensor input = torch::randn({N, C, H, W});
  at::Tensor filter = torch::randn({K, C, R, S});
  at::Tensor ref = F::conv2d(input, filter);

  // Double check the output size is as expected.
  ASSERT_EQ(ref.size(0), N);
  ASSERT_EQ(ref.size(1), K);
  ASSERT_EQ(ref.size(2), OH);
  ASSERT_EQ(ref.size(3), OW);

  te::Placeholder inputB(te::BufHandle("input", {N, C, H, W}, te::kFloat));
  te::Placeholder filterB(te::BufHandle("filter", {K, C, R, S}, te::kFloat));

  te::Tensor* conv = te::Reduce(
      "conv",
      {{N, "n"}, {K, "k"}, {OH, "oh"}, {OW, "ow"}},
      te::Sum(),
      // FIXME: We have to use a `std::vector` parameter here and then unpack
      // it, because we don't have an overload allowing for an arbitrary number
      // of ExprHandle/VarHandle parameters.
      [&](const std::vector<te::VarHandle>& v) {
        // The first four entries are the output axes, in the order listed
        // above; the last three are the reduction axes (c, r, s).
        auto const& n = v[0];
        auto const& k = v[1];
        auto const& oh = v[2];
        auto const& ow = v[3];
        auto const& c = v[4];
        auto const& r = v[5];
        auto const& s = v[6];
        // Each output element sums input[n, c, oh + r, ow + s] *
        // filter[k, c, r, s] over the (c, r, s) reduction axes.
        return inputB.load(n, c, oh + r, ow + s) * filterB.load(k, c, r, s);
      },
      // FIXME: If you forget one of the reduction dims, you get a segfault.
      // Could that be caught by a verifier?
      {{C, "c"}, {R, "r"}, {S, "s"}});

  // FIXME: It'd be nice to have a single header that pulls in things like
  // LoopNest, IRSimplifier, etc.
  te::LoopNest loop({conv});
  loop.prepareForCodegen();
  // Named `stmt` (not `s`) to avoid confusion with the filter width `S` and
  // the reduction variable `s` used in the lambda above.
  te::Stmt* stmt = te::IRSimplifier::simplify(loop.root_stmt());

  at::Tensor result = at::empty_like(ref);
  // cg.call() below receives data pointers in the same order as these
  // buffer arguments: input, filter, conv output.
  te::SimpleIREvaluator cg(stmt, {inputB, filterB, conv});
  cg.call(
      {input.data_ptr<float>(),
       filter.data_ptr<float>(),
       result.data_ptr<float>()});

  // Inputs here are full-precision randn samples, so the comparison uses an
  // explicit 1e-3 rtol/atol instead of allclose's defaults.
  ASSERT_TRUE(at::allclose(ref, result, 1e-3, 1e-3));
}
} // namespace jit
} // namespace torch