
Commit 75689cf

Indicate vectorize() outer dimensions
When the outer dimension "x_vo" of the vectorize() split is assigned as gpu_block(), inform the main Grouping algorithm of that outer dimension. Avoid compute_at at "y_o", the outer gpu_block() level; use the inner gpu_block() level instead. When 4 or more GPU block levels are requested, schedule only the innermost levels.
1 parent e5d29e9 commit 75689cf
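
For orientation, a minimal hand-written Halide sketch of the schedule shape the message describes (hypothetical Funcs "prod" and "cons" and a hypothetical tile size of 16; this is not the autoscheduler's generated code): the vectorized outer dimension "x_vo" becomes a gpu_block() level, and the producer is computed at that inner block level rather than at the outer level "y_o".

#include "Halide.h"
using namespace Halide;

int main() {
    // Hypothetical pipeline, not taken from AutoSchedule.cpp.
    Func prod("prod"), cons("cons");
    Var x("x"), y("y");

    prod(x, y) = x + y;
    cons(x, y) = prod(x, y) * 2;

    // The vectorize() split of x produces "x_vo"/"x_vi"; y is tiled into "y_o"/"y_i".
    Var x_vo("x_vo"), x_vi("x_vi"), y_o("y_o"), y_i("y_i");
    cons.split(x, x_vo, x_vi, 16)
        .split(y, y_o, y_i, 16)
        .reorder(x_vi, y_i, x_vo, y_o)
        // "x_vo" (inner) and "y_o" (outer) both become gpu_block() levels.
        .gpu_blocks(x_vo, y_o)
        .gpu_threads(x_vi, y_i);

    // Place the producer at the inner gpu_block() level "x_vo" rather than at the
    // outer level "y_o", which is the compute_at choice the commit message asks for.
    prod.compute_at(cons, x_vo).gpu_threads(x, y);

    // Lower only (no GPU needed at build time) to inspect the resulting loop nest.
    cons.compile_to_lowered_stmt("cons.stmt", {}, Text, Target("host-cuda"));
    return 0;
}

With this placement the producer is computed once per "x_vo" block instead of once per outermost "y_o" iteration, matching the placement the commit message asks for.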


src/autoschedulers/mullapudi2016/AutoSchedule.cpp

Lines changed: 35 additions & 15 deletions
@@ -945,10 +945,10 @@ class GPUTileHelper {
             return;
         }
 
-        internal_assert(vars.size() <= 3);
-
         std::stringstream oss;
         switch (vars.size()) {
+        case 0:
+            return;
         case 1: {
             const auto &[v, outer, inner, factor, strategy] = vars.front();
             f.split(v, outer, inner, factor, strategy);
@@ -996,7 +996,7 @@ class GPUTileHelper {
 
             break;
         }
-        case 3: {
+        default: {
             const auto &x = vars[0];
             const auto &y = vars[1];
             const auto &z = vars[2];
@@ -1086,6 +1086,7 @@ class GPUTilingDedup {
     bool is_initial_order = true;
     std::vector<VarOrRVar> ordering;
 
+    std::set<std::string> is_split;
     std::set<std::string> outer_vars;
     std::set<std::string> inner_vars;
 
@@ -1114,12 +1115,17 @@ class GPUTilingDedup {
                 continue;
             }
 
+            if (is_compute_at) {
+                continue;
+            }
+
             // Skip all gpu_blocks if the current Stage is "compute_at" another
             // stage, in which the gpu_blocks are already specified.
-            if (!is_compute_at && is_outer(v_name)) {
+            if (is_outer(v_name)) {
                 // Mark as gpu blocks;
                 f.gpu_blocks(v);
                 sched.push_schedule(f.name(), stage_num, "gpu_blocks(" + v_name + ")", {v_name});
+                continue;
             }
         }
     }
@@ -1193,6 +1199,7 @@ class GPUTilingDedup {
      */
     void has_split(const VarOrRVar &v, const VarOrRVar &vo, const VarOrRVar &vi, const Expr &factor, TailStrategy strategy) {
         debug(2) << f.name() << ".split(" << v.name() << "," << factor << ")\n";
+        is_split.emplace(v.name());
         outer_vars.emplace(vo.name());
         inner_vars.emplace(vi.name());
 
@@ -1247,7 +1254,7 @@ class GPUTilingDedup {
             sched.push_schedule(f.name(), stage_num, oss.str(), var_list);
         }
 
-        const bool is_already_split = (inner_vars.size() + outer_vars.size() > 0);
+        const bool is_already_split = (!is_split.empty());
         if (is_already_split) {
             // If the Mullapudi's auto-splitting algorithm already computes the
             // tile size, we simply mark the inner dims as gpu_threads();
@@ -1634,7 +1641,7 @@ struct Partitioner {
 
     // Loop over the dimensions of function stage 'f_handle' starting from innermost
     // and vectorize the first pure dimension encountered.
-    void vectorize_stage(
+    std::optional<pair<VarOrRVar, VarOrRVar>> vectorize_stage(
         const Group &g, Stage f_handle, int stage_num, Definition def,
         const Function &func, bool is_group_output, const Target &t, set<string> &rvars,
         map<string, Expr> &estimates, AutoSchedule &sched, GPUTilingDedup &gpu_tiling);
@@ -2824,11 +2831,11 @@ pair<VarOrRVar, VarOrRVar> Partitioner::split_dim(
     return make_pair(inner, outer);
 }
 
-void Partitioner::vectorize_stage(const Group &g, Stage f_handle, int stage_num,
-                                  Definition def, const Function &func, bool is_group_output,
-                                  const Target &t, set<string> &rvars,
-                                  map<string, Expr> &estimates, AutoSchedule &sched,
-                                  GPUTilingDedup &gpu_tiling) {
+std::optional<pair<VarOrRVar, VarOrRVar>> Partitioner::vectorize_stage(const Group &g, Stage f_handle, int stage_num,
+                                                                       Definition def, const Function &func, bool is_group_output,
+                                                                       const Target &t, set<string> &rvars,
+                                                                       map<string, Expr> &estimates, AutoSchedule &sched,
+                                                                       GPUTilingDedup &gpu_tiling) {
     vector<Dim> &dims = def.schedule().dims();
     int vec_dim_index = -1;
 
@@ -2902,7 +2909,11 @@ void Partitioner::vectorize_stage(const Group &g, Stage f_handle, int stage_num,
             debug(1) << "Outer dim vectorization of var \"" << vec_dim_name
                      << "\" in function \"" << f_handle.name() << "\"\n";
         }
+
+        return make_pair(inner, outer);
     }
+
+    return std::nullopt;
 }
 
 // Return true if the vars/rvars in 'ordering' are in the same order as the
@@ -3184,8 +3195,16 @@ void Partitioner::generate_group_cpu_schedule(
         }
     }
 
-    vectorize_stage(g, f_handle, g.output.stage_num, def, g_out, true, t,
-                    rvars, stg_estimates, sched, gpu_tiling);
+    {
+        auto vectorized_split = vectorize_stage(g, f_handle, g.output.stage_num, def, g_out, true, t,
+                                                rvars, stg_estimates, sched, gpu_tiling);
+
+        if (t.has_gpu_feature() && vectorized_split) {
+            auto [v_i, v_o] = *vectorized_split;
+            inner_dims.emplace_back(std::move(v_i));
+            outer_dims.emplace_back(std::move(v_o));
+        }
+    }
 
     // Parallelize definition
     Expr def_par = 1;
@@ -3296,14 +3315,15 @@ void Partitioner::generate_group_cpu_schedule(
             mem_handle = Func(mem.func).update(mem.stage_num - 1);
         } else {
             if (!outer_dims.empty()) {
+                string sanitized_g_out = get_sanitized_name(g_out.name());
                 if (tile_inner_var.is_rvar) {
                     Func(mem.func).compute_at(Func(g_out), tile_inner_var.rvar);
+                    debug(2) << mem_handle.name() << ".compute_at(" << sanitized_g_out << ", " << tile_inner_var.rvar << ")\n";
                 } else {
                     Func(mem.func).compute_at(Func(g_out), tile_inner_var.var);
+                    debug(2) << mem_handle.name() << ".compute_at(" << sanitized_g_out << ", " << tile_inner_var.var << ")\n";
                 }
 
-                string sanitized_g_out = get_sanitized_name(g_out.name());
-                debug(2) << mem_handle.name() << ".compute_at(" << sanitized_g_out << ")\n";
                 sched.push_schedule(mem_handle.name(), mem.stage_num,
                                     "compute_at(" + sanitized_g_out + ", " + tile_inner_var.name() + ")",
                                     {sanitized_g_out, tile_inner_var.name()});
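
For the new "default:" branch above and the commit message's note that only the innermost levels are scheduled when 4 or more GPU block levels are requested, here is a minimal hand-written sketch (hypothetical 4-D Func "f" and hypothetical tile sizes; not the autoscheduler's output): at most three dimensions can be mapped to gpu_blocks()/gpu_threads(), since GPU grids expose only three block and three thread dimensions, so the extra outer dimension stays a plain serial loop.

#include "Halide.h"
using namespace Halide;

int main() {
    // Hypothetical 4-D pipeline, not taken from AutoSchedule.cpp.
    Func f("f");
    Var x("x"), y("y"), z("z"), w("w");
    f(x, y, z, w) = x + y + z + w;

    Var xo("xo"), xi("xi"), yo("yo"), yi("yi"), zo("zo"), zi("zi"), wo("wo"), wi("wi");
    f.tile(x, y, xo, yo, xi, yi, 8, 8)
        .split(z, zo, zi, 4)
        .split(w, wo, wi, 4)
        .reorder(xi, yi, zi, wi, xo, yo, zo, wo)
        // Only the three innermost outer levels become GPU block dimensions ...
        .gpu_blocks(xo, yo, zo)
        // ... and the three innermost inner levels become GPU thread dimensions.
        .gpu_threads(xi, yi, zi);
    // "wo" (outside the blocks) and "wi" (between blocks and threads) remain serial loops.

    // Lower only (no GPU needed at build time) to inspect the resulting loop nest.
    f.compile_to_lowered_stmt("f.stmt", {}, Text, Target("host-cuda"));
    return 0;
}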
