@@ -945,10 +945,10 @@ class GPUTileHelper {
945
945
return ;
946
946
}
947
947
948
- internal_assert (vars.size () <= 3 );
949
-
950
948
std::stringstream oss;
951
949
switch (vars.size ()) {
950
+ case 0 :
951
+ return ;
952
952
case 1 : {
953
953
const auto &[v, outer, inner, factor, strategy] = vars.front ();
954
954
f.split (v, outer, inner, factor, strategy);
@@ -996,7 +996,7 @@ class GPUTileHelper {
996
996
997
997
break ;
998
998
}
999
- case 3 : {
999
+ default : {
1000
1000
const auto &x = vars[0 ];
1001
1001
const auto &y = vars[1 ];
1002
1002
const auto &z = vars[2 ];
@@ -1086,6 +1086,7 @@ class GPUTilingDedup {
1086
1086
bool is_initial_order = true ;
1087
1087
std::vector<VarOrRVar> ordering;
1088
1088
1089
+ std::set<std::string> is_split;
1089
1090
std::set<std::string> outer_vars;
1090
1091
std::set<std::string> inner_vars;
1091
1092
@@ -1114,12 +1115,17 @@ class GPUTilingDedup {
1114
1115
continue ;
1115
1116
}
1116
1117
1118
+ if (is_compute_at) {
1119
+ continue ;
1120
+ }
1121
+
1117
1122
// Skip all gpu_blocks if the current Stage is "compute_at" another
1118
1123
// stage, in which the gpu_blocks are already specified.
1119
- if (!is_compute_at && is_outer (v_name)) {
1124
+ if (is_outer (v_name)) {
1120
1125
// Mark as gpu blocks;
1121
1126
f.gpu_blocks (v);
1122
1127
sched.push_schedule (f.name (), stage_num, " gpu_blocks(" + v_name + " )" , {v_name});
1128
+ continue ;
1123
1129
}
1124
1130
}
1125
1131
}
@@ -1193,6 +1199,7 @@ class GPUTilingDedup {
1193
1199
*/
1194
1200
void has_split (const VarOrRVar &v, const VarOrRVar &vo, const VarOrRVar &vi, const Expr &factor, TailStrategy strategy) {
1195
1201
debug (2 ) << f.name () << " .split(" << v.name () << " ," << factor << " )\n " ;
1202
+ is_split.emplace (v.name ());
1196
1203
outer_vars.emplace (vo.name ());
1197
1204
inner_vars.emplace (vi.name ());
1198
1205
@@ -1247,7 +1254,7 @@ class GPUTilingDedup {
1247
1254
sched.push_schedule (f.name (), stage_num, oss.str (), var_list);
1248
1255
}
1249
1256
1250
- const bool is_already_split = (inner_vars. size () + outer_vars. size () > 0 );
1257
+ const bool is_already_split = (!is_split. empty () );
1251
1258
if (is_already_split) {
1252
1259
// If the Mullapudi's auto-splitting algorithm already computes the
1253
1260
// tile size, we simply mark the inner dims as gpu_threads();
@@ -1634,7 +1641,7 @@ struct Partitioner {
1634
1641
1635
1642
// Loop over the dimensions of function stage 'f_handle' starting from innermost
1636
1643
// and vectorize the first pure dimension encountered.
1637
- void vectorize_stage (
1644
+ std::optional<pair<VarOrRVar, VarOrRVar>> vectorize_stage (
1638
1645
const Group &g, Stage f_handle, int stage_num, Definition def,
1639
1646
const Function &func, bool is_group_output, const Target &t, set<string> &rvars,
1640
1647
map<string, Expr> &estimates, AutoSchedule &sched, GPUTilingDedup &gpu_tiling);
@@ -2824,11 +2831,11 @@ pair<VarOrRVar, VarOrRVar> Partitioner::split_dim(
2824
2831
return make_pair (inner, outer);
2825
2832
}
2826
2833
2827
- void Partitioner::vectorize_stage (const Group &g, Stage f_handle, int stage_num,
2828
- Definition def, const Function &func, bool is_group_output,
2829
- const Target &t, set<string> &rvars,
2830
- map<string, Expr> &estimates, AutoSchedule &sched,
2831
- GPUTilingDedup &gpu_tiling) {
2834
+ std::optional<pair<VarOrRVar, VarOrRVar>> Partitioner::vectorize_stage (const Group &g, Stage f_handle, int stage_num,
2835
+ Definition def, const Function &func, bool is_group_output,
2836
+ const Target &t, set<string> &rvars,
2837
+ map<string, Expr> &estimates, AutoSchedule &sched,
2838
+ GPUTilingDedup &gpu_tiling) {
2832
2839
vector<Dim> &dims = def.schedule ().dims ();
2833
2840
int vec_dim_index = -1 ;
2834
2841
@@ -2902,7 +2909,11 @@ void Partitioner::vectorize_stage(const Group &g, Stage f_handle, int stage_num,
2902
2909
debug (1 ) << " Outer dim vectorization of var \" " << vec_dim_name
2903
2910
<< " \" in function \" " << f_handle.name () << " \"\n " ;
2904
2911
}
2912
+
2913
+ return make_pair (inner, outer);
2905
2914
}
2915
+
2916
+ return std::nullopt;
2906
2917
}
2907
2918
2908
2919
// Return true if the vars/rvars in 'ordering' are in the same order as the
@@ -3184,8 +3195,16 @@ void Partitioner::generate_group_cpu_schedule(
3184
3195
}
3185
3196
}
3186
3197
3187
- vectorize_stage (g, f_handle, g.output .stage_num , def, g_out, true , t,
3188
- rvars, stg_estimates, sched, gpu_tiling);
3198
+ {
3199
+ auto vectorized_split = vectorize_stage (g, f_handle, g.output .stage_num , def, g_out, true , t,
3200
+ rvars, stg_estimates, sched, gpu_tiling);
3201
+
3202
+ if (t.has_gpu_feature () && vectorized_split) {
3203
+ auto [v_i, v_o] = *vectorized_split;
3204
+ inner_dims.emplace_back (std::move (v_i));
3205
+ outer_dims.emplace_back (std::move (v_o));
3206
+ }
3207
+ }
3189
3208
3190
3209
// Parallelize definition
3191
3210
Expr def_par = 1 ;
@@ -3296,14 +3315,15 @@ void Partitioner::generate_group_cpu_schedule(
3296
3315
mem_handle = Func (mem.func ).update (mem.stage_num - 1 );
3297
3316
} else {
3298
3317
if (!outer_dims.empty ()) {
3318
+ string sanitized_g_out = get_sanitized_name (g_out.name ());
3299
3319
if (tile_inner_var.is_rvar ) {
3300
3320
Func (mem.func ).compute_at (Func (g_out), tile_inner_var.rvar );
3321
+ debug (2 ) << mem_handle.name () << " .compute_at(" << sanitized_g_out << " , " << tile_inner_var.rvar << " )\n " ;
3301
3322
} else {
3302
3323
Func (mem.func ).compute_at (Func (g_out), tile_inner_var.var );
3324
+ debug (2 ) << mem_handle.name () << " .compute_at(" << sanitized_g_out << " , " << tile_inner_var.var << " )\n " ;
3303
3325
}
3304
3326
3305
- string sanitized_g_out = get_sanitized_name (g_out.name ());
3306
- debug (2 ) << mem_handle.name () << " .compute_at(" << sanitized_g_out << " )\n " ;
3307
3327
sched.push_schedule (mem_handle.name (), mem.stage_num ,
3308
3328
" compute_at(" + sanitized_g_out + " , " + tile_inner_var.name () + " )" ,
3309
3329
{sanitized_g_out, tile_inner_var.name ()});
0 commit comments