@@ -24,14 +24,27 @@ namespace {
24
24
25
25
struct ArchParams {
    /** Maximum level of parallelism available. */
    int parallelism{};

    /** Size of the last-level cache (in bytes). */
    uint64_t last_level_cache_size{};

    /** Indicates how much more expensive is the cost of a load compared to
     * the cost of an arithmetic operation at last level cache. */
    float balance{};

    /** If a GPU target is detected, but machine parameters are not specified,
     * make a realistic estimate based on consumer-grade GPUs (Nvidia GTX
     * 1660/Turing), or low-cost scientific-grade GPUs (Nvidia K40/Tesla).
     *
     * Section 5.4 of the Mullapudi2016 article: "We configure the
     * auto-scheduler to target the GPU by setting the PARALLELISM_THRESHOLD
     * to 128, ..., and CACHE_SIZE to 48 KB."
     *
     * `explicit` prevents an accidental implicit bool -> ArchParams
     * conversion; the known call site uses direct brace initialization,
     * which remains valid.
     */
    explicit constexpr ArchParams(bool has_gpu_feature)
        : parallelism(has_gpu_feature ? 128 : 16),
          last_level_cache_size(has_gpu_feature ? 48 * 1024 : 16 * 1024 * 1024),
          balance(has_gpu_feature ? 20 : 40) {
    }
};
36
49
37
50
// Substitute parameter estimates into the exprs describing the box bounds.
@@ -2823,6 +2836,10 @@ void Partitioner::vectorize_stage(const Group &g, Stage f_handle, int stage_num,
2823
2836
// values produced by the function.
2824
2837
const auto vec_len = [&]() -> int {
2825
2838
if (t.has_gpu_feature ()) {
2839
+ /* * Section 5.4 of the Mullapudi2016 article: We configure the
2840
+ * auto-scheduler to target the GPU by setting the ...,
2841
+ * VECTOR_WIDTH to 32.
2842
+ */
2826
2843
return GPUTilingDedup::min_n_threads;
2827
2844
}
2828
2845
@@ -3851,7 +3868,7 @@ struct Mullapudi2016 {
3851
3868
pipeline_outputs.push_back (f.function ());
3852
3869
}
3853
3870
3854
- ArchParams arch_params;
3871
+ ArchParams arch_params{target. has_gpu_feature ()} ;
3855
3872
{
3856
3873
ParamParser parser (params_in.extra );
3857
3874
parser.parse (" parallelism" , &arch_params.parallelism );
0 commit comments