@@ -110,8 +110,7 @@ TritonDotFusionSearchSpace::TritonDotFusionSearchSpace(
110110 should_optimize_for_occupancy_(ShouldOptimizeForOccupancy()),
111111 min_out_tile_(GetMinOutputTile()),
112112 min_warps_per_cta_(GetMinWarpsPerCta()),
113- min_contracting_tile_size_(GetMinContractingTileSize()),
114- max_contracting_split_(GetMaxContractingSplit(max_out_tile_)) {
113+ min_contracting_tile_size_(GetMinContractingTileSize()) {
115114 // Make sure that the range of output tile sizes is not empty
116115 // (min_output_tile_ is a hard limit, while max_output_tile_ is a soft one).
117116 max_out_tile_.lhs_dim =
@@ -125,29 +124,9 @@ TritonDotFusionSearchSpace::TritonDotFusionSearchSpace(
125124}
126125
127126std::vector<TritonGemmConfig> TritonDotFusionSearchSpace::GenerateConfigs (
128- std::optional<int64_t > force_contracting_split,
129127 bool autotune_warp_specialization) const {
130- std::vector<ConfigWithNotes> configs;
131- if (force_contracting_split.has_value ()) {
132- ConfigWithNotes config;
133- const int split = force_contracting_split.value ();
134- config.config .split_k = split;
135- // It is possible that the user manually forced a huge contracting split
136- // that is outside of the search space. In that case, we would end up
137- // discarding all configs, and use the smallest possible tile size further
138- // down, which is likely not what the user had in mind.
139- config.keep_large_split = GetMaxContractingSplit (max_out_tile_) < split;
140- VLOG (5 ) << " Forcing split_k, config = " << config.ToString ();
141- if (config.keep_large_split ) {
142- LOG (WARNING)
143- << " split_k is larger than what we would have found automatically. "
144- " Skipping split and output tile compatibility checks. Should we "
145- " expand the split_k search space?" ;
146- }
147- configs.push_back (config);
148- } else {
149- configs = GenerateContractingSplitFactors ();
150- }
128+ std::vector<ConfigWithNotes> configs (1 );
129+ configs[0 ].config .split_k = 1 ;
151130
152131 ExtendConfigs (configs, &TritonDotFusionSearchSpace::AddOutputTilings);
153132 EliminateLowOccupancyConfigs (configs);
@@ -183,31 +162,6 @@ std::vector<TritonGemmConfig> TritonDotFusionSearchSpace::OptimizeConfigSet(
183162 return configs;
184163 }
185164
186- absl::flat_hash_map<std::pair<int , int >, std::pair<int , int >>
187- m_n_to_split_limits;
188- // Init with first config vals, otherwise they would be 0 and min comparison
189- // won't update properly.
190- std::pair<int , int > global_split_limits{configs.front ().split_k ,
191- configs.front ().split_k };
192- auto update_split_limits = [](auto & limits, int value) {
193- limits = std::minmax ({limits.first , limits.second , value});
194- };
195- for (const TritonGemmConfig& config : configs) {
196- auto m_n_key = std::make_pair (config.block_m , config.block_n );
197- auto & split_limits =
198- m_n_to_split_limits.try_emplace (m_n_key, config.split_k , config.split_k )
199- .first ->second ;
200- update_split_limits (split_limits, config.split_k );
201- update_split_limits (global_split_limits, config.split_k );
202- }
203-
204- auto get_split_limits = [&](int block_m, int block_n) {
205- auto m_n_key = std::make_pair (block_m, block_n);
206- auto split_limits_it = m_n_to_split_limits.find (m_n_key);
207- return split_limits_it == m_n_to_split_limits.end ()
208- ? global_split_limits
209- : split_limits_it->second ;
210- };
211165 absl::flat_hash_set<TritonGemmConfig> filter;
212166 for (TritonGemmConfig config : hints) {
213167 // Our default config set does not take problem size into account, so we
@@ -220,11 +174,8 @@ std::vector<TritonGemmConfig> TritonDotFusionSearchSpace::OptimizeConfigSet(
220174 max_out_tile_.rhs_dim );
221175 config.block_k =
222176 std::clamp (config.block_k , min_contracting_tile_size_,
223- GetMaxContractingTileSize ({config.block_m , config.block_n },
224- /* contracting_split=*/ 1 ));
225- const auto & split_limits = get_split_limits (config.block_m , config.block_n );
226- config.split_k =
227- std::clamp (config.split_k , split_limits.first , split_limits.second );
177+ GetMaxContractingTileSize ({config.block_m , config.block_n }));
178+
228179 VLOG (10 ) << " Adding config to hint filter: " << config.ToString ();
229180 filter.insert (config);
230181 }
@@ -252,13 +203,13 @@ std::vector<TritonGemmConfig> TritonDotFusionSearchSpace::OptimizeConfigSet(
252203std::string TritonDotFusionSearchSpace::ToString () const {
253204 return absl::StrFormat (
254205 " problem_size_BxMxNxKxE: %dx%dx%dx%dx(%d->%d) "
255- " tile_range_SxMxNxK: [1-%d]x [%d-%d]x[%d-%d]x[%d-?] "
206+ " tile_range_MxNxK: [%d-%d]x[%d-%d]x[%d-?] "
256207 " desired_total_warps: %d occupancy_optimization: %d "
257208 " warps_per_cta: [%d-?]" ,
258209 batch_size_, lhs_parallel_size_, rhs_parallel_size_, contracting_size_,
259- operand_bitwidth_, compute_bitwidth_, max_contracting_split_ ,
260- min_out_tile_ .lhs_dim , max_out_tile_. lhs_dim , min_out_tile_ .rhs_dim ,
261- max_out_tile_. rhs_dim , min_contracting_tile_size_, desired_total_warps_,
210+ operand_bitwidth_, compute_bitwidth_, min_out_tile_. lhs_dim ,
211+ max_out_tile_ .lhs_dim , min_out_tile_. rhs_dim , max_out_tile_ .rhs_dim ,
212+ min_contracting_tile_size_, desired_total_warps_,
262213 should_optimize_for_occupancy_, min_warps_per_cta_);
263214}
264215
@@ -414,43 +365,6 @@ int TritonDotFusionSearchSpace::GetMinContractingTileSize() const {
414365 return min_contracting_tile_size;
415366}
416367
417- int TritonDotFusionSearchSpace::GetMaxContractingSplit (
418- OutputTile output_tile) const {
419- const int64_t desired_num_ctas = desired_total_warps_ / min_warps_per_cta_;
420- VLOG (5 ) << " Computing split_k: Considering output tile "
421- << output_tile.lhs_dim << " x" << output_tile.rhs_dim ;
422- VLOG (5 ) << " Computing split_k: Want up to " << desired_num_ctas
423- << " CTAs to occupy all cores." ;
424-
425- const int64_t min_result_tiles = GetNumResultTiles (output_tile);
426- VLOG (5 ) << " Computing split_k: Without split_k have " << min_result_tiles
427- << " tiles." ;
428-
429- const int64_t split_for_occupancy =
430- NextPowerOfTwo (CeilOfRatio (desired_num_ctas, min_result_tiles));
431- VLOG (5 ) << " Computing split_k: Want split_k of up to " << split_for_occupancy
432- << " for sufficient occupancy." ;
433-
434- // Calculate the maximum split_k that is valid with the smallest block_k.
435- // The validation in MakeSplitKOperand requires:
436- // split_k <= ceil(contracting_size / block_k)
437- // Using min_contracting_tile_size_ (smallest block_k) gives largest valid
438- // split.
439- const int64_t max_valid_split = CeilOfRatio (
440- contracting_size_, static_cast <int64_t >(min_contracting_tile_size_));
441- const int64_t split_for_contracting_size =
442- PreviousPowerOfTwo (max_valid_split);
443- VLOG (5 ) << " Computing split_k: Can't have split_k more than "
444- << split_for_contracting_size
445- << " to have sufficiently large contracting dimension (max_valid="
446- << max_valid_split << " )." ;
447-
448- const int64_t split =
449- std::min (split_for_occupancy, split_for_contracting_size);
450- VLOG (5 ) << " Computing split_k: max_split_k = " << split;
451- return split;
452- }
453-
454368int TritonDotFusionSearchSpace::GetContractingSizeLimitToFitSharedMemory (
455369 OutputTile output_tile) const {
456370 const int64_t shared_memory_budget =
@@ -462,54 +376,38 @@ int TritonDotFusionSearchSpace::GetContractingSizeLimitToFitSharedMemory(
462376}
463377
464378int TritonDotFusionSearchSpace::GetMaxContractingTileSize (
465- OutputTile output_tile, int contracting_split ) const {
466- const int64_t available_size = contracting_size_ / contracting_split ;
379+ OutputTile output_tile) const {
380+ const int64_t available_size = contracting_size_;
467381 const int size_limit = GetContractingSizeLimitToFitSharedMemory (output_tile);
468382 const int max_size =
469383 std::min (NextPowerOfTwo (available_size), PreviousPowerOfTwo (size_limit));
470- VLOG (5 ) << " Computing max_contracting_tile_size for tiling BxMxN = "
471- << contracting_split << " x" << output_tile.lhs_dim << " x"
472- << output_tile. rhs_dim << " : limit based on problem is "
473- << available_size << " , limit based on available shared memory is "
474- << size_limit << " , max_contracting_tile_size = " << max_size;
384+ VLOG (5 ) << " Computing max_contracting_tile_size for tiling BxMxN = " << 1
385+ << " x" << output_tile.lhs_dim << " x" << output_tile. rhs_dim
386+ << " : limit based on problem is " << available_size
387+ << " , limit based on available shared memory is " << size_limit
388+ << " , max_contracting_tile_size = " << max_size;
475389 return std::max (min_contracting_tile_size_, max_size);
476390}
477391
478- int TritonDotFusionSearchSpace::GetMaxNumStages (OutputTile output_tile,
479- int contracting_tile_size,
480- int contracting_split) const {
481- const int64_t available_stages = CeilOfRatio<int64_t >(
482- contracting_size_, contracting_split * contracting_tile_size);
392+ int TritonDotFusionSearchSpace::GetMaxNumStages (
393+ OutputTile output_tile, int contracting_tile_size) const {
394+ const int64_t available_stages =
395+ CeilOfRatio<int64_t >(contracting_size_, contracting_tile_size);
483396 const int64_t stage_limit = std::max (
484397 1 , CeilOfRatio (GetContractingSizeLimitToFitSharedMemory (output_tile),
485398 contracting_tile_size));
486399 // Number of stages is basically a replacement for oversubscription, so
487400 // the maximum number we want is also limited by kMaxWarpsPerScheduler.
488401 const int stages = std::min ({available_stages, stage_limit,
489402 static_cast <int64_t >(kMaxWarpsPerScheduler )});
490- VLOG (5 ) << " Computing max_num_stages for tiling BxMxNxK = "
491- << contracting_split << " x" << output_tile.lhs_dim << " x"
492- << output_tile.rhs_dim << " x" << contracting_tile_size
493- << " : limit based on problem is " << available_stages
494- << " , limit based on available shared memory is " << stage_limit
495- << " , max_num_stages = " << stages;
403+ VLOG (5 ) << " Computing max_num_stages for tiling BxMxNxK = " << 1 << " x"
404+ << output_tile.lhs_dim << " x" << output_tile.rhs_dim << " x"
405+ << contracting_tile_size << " : limit based on problem is "
406+ << available_stages << " , limit based on available shared memory is "
407+ << stage_limit << " , max_num_stages = " << stages;
496408 return stages;
497409}
498410
499- std::vector<TritonDotFusionSearchSpace::ConfigWithNotes>
500- TritonDotFusionSearchSpace::GenerateContractingSplitFactors () const {
501- CHECK_GE (max_contracting_split_, 1 );
502- std::vector<ConfigWithNotes> configs;
503- ConfigWithNotes config;
504- for (int split = 1 ; split <= max_contracting_split_; split *= 2 ) {
505- config.config .split_k = split;
506- VLOG (10 ) << " Generating contracting split factors: config = "
507- << config.ToString ();
508- configs.push_back (config);
509- }
510- return configs;
511- }
512-
513411void TritonDotFusionSearchSpace::ExtendConfigs (
514412 std::vector<ConfigWithNotes>& configs,
515413 ExtendConfigCallback extend_config) const {
@@ -525,9 +423,6 @@ void TritonDotFusionSearchSpace::ExtendConfigs(
525423void TritonDotFusionSearchSpace::AddOutputTilings (
526424 const ConfigWithNotes& config,
527425 std::vector<ConfigWithNotes>& updated_configs) const {
528- CHECK_GT (config.config .split_k , 0 )
529- << " Need config with contracting split already set." ;
530- const int split = config.config .split_k ;
531426 ConfigWithNotes new_config = config;
532427 for (int m = min_out_tile_.lhs_dim ; m <= max_out_tile_.lhs_dim ; m *= 2 ) {
533428 int min_n = min_out_tile_.rhs_dim ;
@@ -561,17 +456,8 @@ void TritonDotFusionSearchSpace::AddOutputTilings(
561456 }
562457 for (int n = min_n; n <= max_n; n *= 2 ) {
563458 OutputTile tile = {m, n};
564- // We could make the tile size limits depend on split_k, but then we
565- // need to implement the "inverse" of `GetMaxContractingSplit`.
566- // Simpler is to just verify that the given combination of tiling and
567- // split_k is compatible.
568- if (!config.keep_large_split && GetMaxContractingSplit (tile) < split) {
569- VLOG (10 ) << " Skipping due to too large split_k, config = "
570- << new_config.ToString ();
571- continue ;
572- }
573459 new_config.not_enough_tiles =
574- GetNumResultTiles (tile) * split < device_description_.core_count ();
460+ GetNumResultTiles (tile) < device_description_.core_count ();
575461 new_config.config .block_m = m;
576462 new_config.config .block_n = n;
577463 VLOG (10 ) << " Adding output tiling: config = " << new_config.ToString ();
@@ -604,29 +490,13 @@ void TritonDotFusionSearchSpace::AddContractingTiling(
604490 std::vector<ConfigWithNotes>& updated_configs) const {
605491 const int tile_rows = config.config .block_m ;
606492 const int tile_cols = config.config .block_n ;
607- const int split = config.config .split_k ;
608493 CHECK_GT (tile_rows * tile_cols, 0 )
609494 << " Need configs with output tilings determined." ;
610- CHECK_GT (split, 0 ) << " Need config with contracting split determined." ;
611495 int max_tile_size =
612- std::max (GetMaxContractingTileSize ({tile_rows, tile_cols}, split ),
496+ std::max (GetMaxContractingTileSize ({tile_rows, tile_cols}),
613497 min_contracting_tile_size_);
614498 ConfigWithNotes new_config = config;
615499 for (int k = min_contracting_tile_size_; k <= max_tile_size; k *= 2 ) {
616- // Safety check: skip block_k values that are incompatible with split_k.
617- // The validation in MakeSplitKOperand checks:
618- // split_k > ceil(contracting_size / block_k) → error
619- // So we need: split_k <= ceil(contracting_size / block_k)
620- // Skip this check if keep_large_split is true (user forced a large split).
621- if (!config.keep_large_split ) {
622- const int64_t max_split_for_this_k =
623- CeilOfRatio (contracting_size_, static_cast <int64_t >(k));
624- if (split > max_split_for_this_k) {
625- VLOG (10 ) << " Skipping block_k=" << k << " for split_k=" << split
626- << " (max_split=" << max_split_for_this_k << " )" ;
627- continue ;
628- }
629- }
630500 new_config.config .block_k = k;
631501 VLOG (10 ) << " Adding contracting tiling: config = " << new_config.ToString ();
632502 updated_configs.push_back (new_config);
@@ -639,14 +509,11 @@ void TritonDotFusionSearchSpace::AddPipeliningParameter(
639509 const int tile_rows = config.config .block_m ;
640510 const int tile_cols = config.config .block_n ;
641511 const int tile_contracting = config.config .block_k ;
642- const int split = config.config .split_k ;
643512 CHECK_GT (tile_rows * tile_cols, 0 )
644513 << " Need config with output tilings determined." ;
645514 CHECK_GT (tile_contracting, 0 )
646515 << " Need config with contracting tiling determined." ;
647- CHECK_GT (split, 0 ) << " Need config with contracting split determined." ;
648- int max_stages =
649- GetMaxNumStages ({tile_rows, tile_cols}, tile_contracting, split);
516+ int max_stages = GetMaxNumStages ({tile_rows, tile_cols}, tile_contracting);
650517 ConfigWithNotes new_config = config;
651518 for (int num_stages = 1 ; num_stages <= max_stages; ++num_stages) {
652519 new_config.config .num_stages = num_stages;
@@ -712,14 +579,6 @@ void TritonDotFusionSearchSpace::EliminateLowOccupancyConfigs(
712579
713580 ConfigWithNotes last_config = configs.back (); // Largest split.
714581 auto has_too_few_tiles = [](const ConfigWithNotes& config) {
715- // Small dots frequently lead to large split_k values that are not
716- // compatible with codegen. We skip occupancy optimization for them to be
717- // able to consider smaller splits in non-exhaustive mode.
718- // The value of 4 was found by running exhaustive autotuning and noting that
719- // the majority of optimal configs with block_n == 8 had split_k <= 4.
720- if (config.config .block_n == 8 && config.config .split_k <= 4 ) {
721- return false ;
722- }
723582 if (config.not_enough_tiles ) {
724583 VLOG (10 ) << " Skipping due to fewer tiles than cores, config = "
725584 << config.ToString ();
0 commit comments