-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathtransitive_in.cpp
More file actions
181 lines (157 loc) · 5.1 KB
/
transitive_in.cpp
File metadata and controls
181 lines (157 loc) · 5.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#include "Halide.h"
#include "check_call_graphs.h"
#include <cstdio>
using namespace Halide;
namespace {
// Build a small pipeline with anonymous intermediate Funcs, similar in shape
// to local_laplacian's pyramid: we want to call clone_in / in on a non-direct
// caller and have the wrapper be inserted along all paths from that caller
// down to the wrapped Func.
//
// Returns 0 on success, 1 on the first failed check (a diagnostic is printed).
int transitive_clone_in_test() {
    Var x("x"), y("y");
    Func base("base");
    base(x, y) = x + y;
    // Two anonymous helpers that each directly call base.
    Func helper_a, helper_b;
    helper_a(x, y) = base(x, y) + 1;
    helper_b(x, y) = base(x, y) * 2;
    // top transitively calls base via the helpers, but does not directly.
    Func top("top");
    top(x, y) = helper_a(x, y) + helper_b(x, y);
    // sibling also uses base directly but is *not* on the path from top.
    Func sibling("sibling");
    sibling(x, y) = base(x, y) - 1;
    // Cloning base into top should expand to {helper_a, helper_b}, but must
    // leave sibling's call to base untouched.
    Func cloned = base.clone_in(top);
    Func out("out");
    out(x, y) = top(x, y) + sibling(x, y);
    base.compute_root();
    helper_a.compute_root();
    helper_b.compute_root();
    cloned.compute_root();
    sibling.compute_root();
    top.compute_root();

    // First check: numerical correctness.
    Pipeline p(out);
    Buffer<int> result = p.realize({16, 16});
    auto check = [](int xv, int yv) {
        int b = xv + yv;
        int ha = b + 1;
        int hb = b * 2;
        int t = ha + hb;
        int s = b - 1;
        return t + s;
    };
    if (check_image2(result, check) != 0) {
        return 1;
    }

    // Second check: helper_a and helper_b should load from the clone, not
    // base; sibling should still load from base.
    // NOTE(review): `checker` is never deleted here; presumably the pipeline
    // takes ownership in add_custom_lowering_pass — confirm in Pipeline.h.
    CheckCalls *checker = new CheckCalls;
    Pipeline p2(out);
    p2.add_custom_lowering_pass(checker);
    p2.compile_to_module(p2.infer_arguments(), "");
    const auto &calls = checker->calls;

    // Every producer we inspect must be present in the call graph. Otherwise
    // a negative check such as "helper_a should not call base" would pass
    // vacuously, because loads_from() reports false for unknown producers.
    auto require_producer = [&](const std::string &producer) {
        if (calls.find(producer) == calls.end()) {
            printf("Producer %s not found\n", producer.c_str());
            return false;
        }
        return true;
    };
    if (!require_producer(helper_a.name()) ||
        !require_producer(helper_b.name()) ||
        !require_producer(sibling.name())) {
        return 1;
    }

    // True iff `producer` is in the call graph and loads from `callee`.
    auto loads_from = [&](const std::string &producer, const std::string &callee) {
        auto it = calls.find(producer);
        if (it == calls.end()) {
            return false;
        }
        for (const std::string &c : it->second) {
            if (c == callee) {
                return true;
            }
        }
        return false;
    };

    if (loads_from(helper_a.name(), base.name())) {
        printf("helper_a should not directly call base after clone_in\n");
        return 1;
    }
    if (loads_from(helper_b.name(), base.name())) {
        printf("helper_b should not directly call base after clone_in\n");
        return 1;
    }
    if (!loads_from(helper_a.name(), cloned.name())) {
        printf("helper_a should call the clone\n");
        return 1;
    }
    if (!loads_from(helper_b.name(), cloned.name())) {
        printf("helper_b should call the clone\n");
        return 1;
    }
    if (!loads_from(sibling.name(), base.name())) {
        printf("sibling should still call base\n");
        return 1;
    }
    return 0;
}
// Direct callers passed to clone_in should still work (no expansion needed).
int direct_clone_in_still_works_test() {
Var x("x"), y("y");
Func f("f"), g("g");
f(x, y) = x + y;
g(x, y) = f(x, y) + 7;
Func cloned = f.clone_in(g);
f.compute_root();
cloned.compute_root();
Buffer<int> r = g.realize({8, 8});
return check_image2(r, [](int xv, int yv) { return xv + yv + 7; });
}
// in() is also transitive: wrapping base "in" top, which only reaches base
// through mid, should behave like base.in(mid).
//
// Returns 0 on success, 1 on the first failed check (a diagnostic is printed).
int transitive_in_test() {
    Var x("x"), y("y");
    Func base("base");
    base(x, y) = x + y;
    Func mid;
    mid(x, y) = base(x, y) + 3;
    Func top("top");
    top(x, y) = mid(x, y) * 2;
    // base.in(top) should resolve to base.in(mid).
    Func wrapper = base.in(top);
    base.compute_root();
    mid.compute_root();
    wrapper.compute_root();
    top.compute_root();

    // First check: numerical correctness.
    Buffer<int> r = top.realize({8, 8});
    if (check_image2(r, [](int xv, int yv) { return (xv + yv + 3) * 2; }) != 0) {
        return 1;
    }

    // Second check: the call graph must show mid -> wrapper -> base. Checking
    // only that mid no longer loads base does not prove the wrapper was
    // inserted on that path.
    // NOTE(review): `checker` is never deleted here; presumably the pipeline
    // takes ownership in add_custom_lowering_pass — confirm in Pipeline.h.
    CheckCalls *checker = new CheckCalls;
    Pipeline p(top);
    p.add_custom_lowering_pass(checker);
    p.compile_to_module(p.infer_arguments(), "");
    const auto &calls = checker->calls;

    if (calls.find(mid.name()) == calls.end()) {
        printf("mid not found in call graph\n");
        return 1;
    }
    if (calls.find(wrapper.name()) == calls.end()) {
        printf("wrapper not found in call graph\n");
        return 1;
    }

    // True iff `producer` is in the call graph and loads from `callee`.
    auto loads_from = [&](const std::string &producer, const std::string &callee) {
        auto it = calls.find(producer);
        if (it == calls.end()) {
            return false;
        }
        for (const auto &c : it->second) {
            if (c == callee) {
                return true;
            }
        }
        return false;
    };

    if (loads_from(mid.name(), base.name())) {
        printf("mid should not directly call base after in()\n");
        return 1;
    }
    if (!loads_from(mid.name(), wrapper.name())) {
        printf("mid should call the wrapper after in()\n");
        return 1;
    }
    if (!loads_from(wrapper.name(), base.name())) {
        printf("wrapper should call base\n");
        return 1;
    }
    return 0;
}
} // namespace
int main(int argc, char **argv) {
    // Tests run in declaration order; the first failure stops the run with
    // exit status 1. `name` is used verbatim in the progress/failure lines.
    struct TestCase {
        const char *name;
        int (*run)();
    };
    const TestCase tests[] = {
        {"transitive_clone_in_test", transitive_clone_in_test},
        {"direct_clone_in_still_works_test", direct_clone_in_still_works_test},
        {"transitive_in_test", transitive_in_test},
    };

    for (const TestCase &test : tests) {
        printf("Running %s\n", test.name);
        if (test.run() != 0) {
            printf("%s failed\n", test.name);
            return 1;
        }
    }

    printf("Success!\n");
    return 0;
}