Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion apps/local_laplacian/local_laplacian_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ class LocalLaplacian : public Halide::Generator<LocalLaplacian> {
.compute_root()
.reorder_storage(x, k, y)
.reorder(k, y)
.parallel(y, 8)
.split(y, yo, y, 8)
.parallel(yo)
.vectorize(x, 8);
outGPyramid[j]
.store_at(output, yo)
Expand All @@ -180,6 +181,12 @@ class LocalLaplacian : public Halide::Generator<LocalLaplacian> {
gPyramid[j].never_partition_all();
}
}
gPyramid[0]
.clone_in(gPyramid[1])
.store_at(gPyramid[1], yo)
.compute_at(gPyramid[1], y)
.vectorize(x, 8);

outGPyramid[0]
.compute_at(output, y)
.hoist_storage(output, yo)
Expand Down
56 changes: 55 additions & 1 deletion src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "CodeGen_LLVM.h"
#include "Debug.h"
#include "ExprUsesVar.h"
#include "FindCalls.h"
#include "Func.h"
#include "Function.h"
#include "IR.h"
Expand Down Expand Up @@ -2176,7 +2177,60 @@ Func create_clone_wrapper(Function wrapped_fn, const string &wrapper_name) {
return wrapper;
}

Func get_wrapper(Function wrapped_fn, string wrapper_name, const vector<Func> &fs, bool clone) {
// Walk down the call graph from 'start'. Whenever we find a Func that directly
// calls 'target', record it and stop descending that branch — we don't want to
// pick up unrelated direct callers that happen to live deeper in the subtree.
void collect_direct_callers_of(const Function &target,
const Function &start,
std::set<std::string> &visited,
std::map<std::string, Function> &result) {
if (start.name() == target.name()) {
return;
}
if (!visited.insert(start.name()).second) {
return;
}
std::map<std::string, Function> direct = find_direct_calls(start);
if (direct.count(target.name())) {
result.emplace(start.name(), start);
return;
}
for (const auto &kv : direct) {
collect_direct_callers_of(target, kv.second, visited, result);
}
}

// Turn the caller list handed to in()/clone_in() into the list of Funcs that
// should actually receive the wrapper.
//
// For each entry of 'fs', the call graph below it is searched for direct
// callers of 'target' (see collect_direct_callers_of), and those callers take
// the entry's place. An entry that already calls 'target' directly therefore
// maps to itself. An entry for which no direct caller is found is also kept
// as-is: the IR may not yet reflect a wrapper rewrite made by an earlier
// in()/clone_in(), and registering a wrapper for such a Func is permitted.
//
// Each Func appears at most once in the returned list, at the position where
// it was first produced.
vector<Func> resolve_transitive_callers(const Function &target, const vector<Func> &fs) {
    vector<Func> resolved;
    std::set<std::string> already_added;

    // Append 'fn' unless a Func of the same name was appended before.
    const auto add_unique = [&resolved, &already_added](const Function &fn) {
        const bool is_new = already_added.insert(fn.name()).second;
        if (is_new) {
            resolved.emplace_back(fn);
        }
    };

    for (const Func &requested : fs) {
        // Every requested caller gets its own search state.
        std::set<std::string> visited;
        std::map<std::string, Function> found;
        collect_direct_callers_of(target, requested.function(), visited, found);

        if (!found.empty()) {
            for (auto it = found.begin(); it != found.end(); ++it) {
                add_unique(it->second);
            }
            continue;
        }
        // No direct caller below this entry: pass it through unchanged.
        add_unique(requested.function());
    }
    return resolved;
}

Func get_wrapper(Function wrapped_fn, string wrapper_name, const vector<Func> &fs_in, bool clone) {
vector<Func> fs = fs_in.empty() ? fs_in : resolve_transitive_callers(wrapped_fn, fs_in);
// Either all Funcs in 'fs' have the same wrapper or they don't already
// have any wrappers. Otherwise, throw an error. If 'fs' is empty, then
// it is a global wrapper.
Expand Down
9 changes: 9 additions & 0 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1347,6 +1347,12 @@ class Func {
for x:
g(x, y) = f(x, y)
\endcode
* If a Func passed to in() does not directly call this Func, in() acts
* transitively: the Func graph is searched downward from that argument,
* and each direct caller of this Func that the search reaches is wrapped
* in its place. The search does not continue below a direct caller, and
* an argument from which no direct caller is reachable is used as given.
* This is useful when intermediate Funcs are anonymous and not held by
* the user (e.g. a pyramid built via helper functions).
*
* using Func::in(), we can write:
\code
f(x, y) = x + y;
Expand Down Expand Up @@ -1398,6 +1404,9 @@ class Func {
h(x, y) = f(x, y) - 3;
\endcode
*
* As with Func::in(), clone_in() acts transitively: any Func in 'f'/'fs'
* that does not directly call this Func is replaced by the set of direct
* callers reachable from it along paths that lead to this Func. If there
* is no such path, the Func is used as given.
*/
//@{
Func clone_in(const Func &f);
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ tests(GROUPS correctness multithreaded
ring_buffer.cpp
stream_compaction.cpp
thread_safety.cpp
transitive_in.cpp
truncated_pyramid.cpp
tuple_vector_reduce.cpp
vector_cast.cpp
Expand Down
181 changes: 181 additions & 0 deletions test/correctness/transitive_in.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#include "Halide.h"
#include "check_call_graphs.h"

#include <cstdio>

using namespace Halide;

namespace {

// Build a small pipeline with anonymous intermediate Funcs, similar in shape
// to local_laplacian's pyramid: we want to call clone_in / in on a non-direct
// caller and have the wrapper be inserted along all paths from that caller
// down to the wrapped Func.
int transitive_clone_in_test() {
Var x("x"), y("y");

Func base("base");
base(x, y) = x + y;

// Two anonymous helpers that each directly call base.
Func helper_a, helper_b;
helper_a(x, y) = base(x, y) + 1;
helper_b(x, y) = base(x, y) * 2;

// top transitively calls base via the helpers, but does not directly.
Func top("top");
top(x, y) = helper_a(x, y) + helper_b(x, y);

// sibling also uses base directly but is *not* on the path from top.
Func sibling("sibling");
sibling(x, y) = base(x, y) - 1;

// Cloning base into top should expand to {helper_a, helper_b}, but must
// leave sibling's call to base untouched.
Func cloned = base.clone_in(top);

Func out("out");
out(x, y) = top(x, y) + sibling(x, y);

base.compute_root();
helper_a.compute_root();
helper_b.compute_root();
cloned.compute_root();
sibling.compute_root();
top.compute_root();

// First check: numerical correctness.
Pipeline p(out);
Buffer<int> result = p.realize({16, 16});
auto check = [](int xv, int yv) {
int b = xv + yv;
int ha = b + 1;
int hb = b * 2;
int t = ha + hb;
int s = b - 1;
return t + s;
};
if (check_image2(result, check) != 0) {
return 1;
}

// Second check: helper_a and helper_b should load from the clone, not
// base; sibling should still load from base.
CheckCalls *checker = new CheckCalls;
Pipeline p2(out);
p2.add_custom_lowering_pass(checker);
p2.compile_to_module(p2.infer_arguments(), "");
const auto &calls = checker->calls;

auto loads_from = [&](const std::string &producer, const std::string &callee) {
auto it = calls.find(producer);
if (it == calls.end()) {
printf("Producer %s not found\n", producer.c_str());
return false;
}
for (const std::string &c : it->second) {
if (c == callee) return true;
}
return false;
};

if (loads_from(helper_a.name(), base.name())) {
printf("helper_a should not directly call base after clone_in\n");
return 1;
}
if (loads_from(helper_b.name(), base.name())) {
printf("helper_b should not directly call base after clone_in\n");
return 1;
}
if (!loads_from(helper_a.name(), cloned.name())) {
printf("helper_a should call the clone\n");
return 1;
}
if (!loads_from(helper_b.name(), cloned.name())) {
printf("helper_b should call the clone\n");
return 1;
}
if (!loads_from(sibling.name(), base.name())) {
printf("sibling should still call base\n");
return 1;
}

return 0;
}

// Direct callers passed to clone_in should still work (no expansion needed).
int direct_clone_in_still_works_test() {
Var x("x"), y("y");
Func f("f"), g("g");
f(x, y) = x + y;
g(x, y) = f(x, y) + 7;
Func cloned = f.clone_in(g);
f.compute_root();
cloned.compute_root();
Buffer<int> r = g.realize({8, 8});
return check_image2(r, [](int xv, int yv) { return xv + yv + 7; });
}

// in() is also transitive.
int transitive_in_test() {
Var x("x"), y("y");
Func base("base");
base(x, y) = x + y;
Func mid;
mid(x, y) = base(x, y) + 3;
Func top("top");
top(x, y) = mid(x, y) * 2;

// base.in(top) should resolve to base.in(mid).
Func wrapper = base.in(top);

base.compute_root();
mid.compute_root();
wrapper.compute_root();
top.compute_root();

Buffer<int> r = top.realize({8, 8});
if (check_image2(r, [](int xv, int yv) { return (xv + yv + 3) * 2; }) != 0) {
return 1;
}

CheckCalls *checker = new CheckCalls;
Pipeline p(top);
p.add_custom_lowering_pass(checker);
p.compile_to_module(p.infer_arguments(), "");
const auto &calls = checker->calls;
auto it = calls.find(mid.name());
if (it == calls.end()) {
printf("mid not found in call graph\n");
return 1;
}
for (const auto &c : it->second) {
if (c == base.name()) {
printf("mid should not directly call base after in()\n");
return 1;
}
}
return 0;
}

} // namespace

// Runs the three tests in order and stops at the first one that reports a
// non-zero status. Prints "Success!" and returns 0 only if all of them pass.
int main(int argc, char **argv) {
    printf("Running transitive_clone_in_test\n");
    const bool clone_in_ok = (transitive_clone_in_test() == 0);
    if (!clone_in_ok) {
        printf("transitive_clone_in_test failed\n");
        return 1;
    }

    printf("Running direct_clone_in_still_works_test\n");
    const bool direct_ok = (direct_clone_in_still_works_test() == 0);
    if (!direct_ok) {
        printf("direct_clone_in_still_works_test failed\n");
        return 1;
    }

    printf("Running transitive_in_test\n");
    const bool in_ok = (transitive_in_test() == 0);
    if (!in_ok) {
        printf("transitive_in_test failed\n");
        return 1;
    }

    printf("Success!\n");
    return 0;
}
Loading