Skip to content

Commit e7b07ad

Browse files
ezhulenev authored and Google-ML-Automation committed
PR #40920: [xla] Add a separate pass to propagate metadata across kCall instructions
Imported from GitHub PR #40920

- New `PropagateCallMetadata` HLO pass that propagates metadata (op_name prefix and stack_frame_id) from `kCall` instructions into their called computations, recursing through nested control-flow (while, conditional) but not into embedded computations (reduce's to_apply, etc.)
- Removes metadata propagation from `CallInliner` — the `RecursivelyUpdateMetadata` helper and per-instruction metadata update during inlining are no longer needed since the new pass handles this as a separate concern
- Wired into GPU compiler pre-SPMD pipeline before `CallInliner`, so metadata is propagated while `kCall` ops still exist (including non-inlinable calls)
- Test coverage for op_name propagation, stack frame concatenation, overflow protection, redundant prefix detection, nested calls, and idempotency

#### Motivation

The `CallInliner` only updates metadata for calls it actually inlines. Non-inlinable calls (e.g. calls with inlineable="false") were skipped entirely, leaving their callee instructions with incomplete metadata context. Extracting this into a standalone pass ensures all calls get metadata propagation regardless of inlining decisions.

Copybara import of the project:

-- 3315b79 by Eugene Zhulenev <[email protected]>:

[xla] Add a separate pass to propagate metadata across kCall instructions

Merging this change closes #40920

FUTURE_COPYBARA_INTEGRATE_REVIEW=#40920 from ezhulenev:propagate-metadata-pass 3315b79
PiperOrigin-RevId: 900707821
1 parent ca4f5fe commit e7b07ad

6 files changed

Lines changed: 42 additions & 553 deletions

File tree

xla/backends/gpu/autotuner/BUILD

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -402,7 +402,6 @@ cc_library(
402402
"//xla/service/gpu:gpu_float_support",
403403
"//xla/service/gpu:ir_emission_utils",
404404
"//xla/service/gpu:matmul_utils",
405-
"//xla/service/gpu:split_k_gemm_rewriter",
406405
"//xla/service/gpu/model:triton_emitter_constraints",
407406
"//xla/stream_executor:device_description",
408407
"//xla/stream_executor:stream_executor_h",

xla/backends/gpu/autotuner/triton.cc

Lines changed: 7 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,6 @@ limitations under the License.
5050
#include "xla/service/gpu/ir_emission_utils.h"
5151
#include "xla/service/gpu/matmul_utils.h"
5252
#include "xla/service/gpu/model/triton_emitter_constraints.h"
53-
#include "xla/service/gpu/split_k_gemm_rewriter.h"
5453
#include "xla/service/hlo_cost_analysis.h"
5554
#include "xla/service/instruction_fusion.h"
5655
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
@@ -132,13 +131,6 @@ TritonBackend::GetSupportedConfigsForDot(const HloInstruction* instr) {
132131
const HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
133132
TritonDotFusionSearchSpace search_space(target_config().device_description,
134133
dot);
135-
bool supports_contracting_split =
136-
HloBfsFindAll({dot}, [&](const HloInstruction* node) {
137-
return node->opcode() == HloOpcode::kSlice;
138-
}).empty();
139-
bool autotune_contracting_split =
140-
supports_contracting_split &&
141-
debug_options().xla_gpu_enable_split_k_autotuning();
142134
bool autotune_warp_specialization =
143135
debug_options()
144136
.xla_gpu_experimental_enable_triton_warp_specialization() &&
@@ -151,9 +143,6 @@ TritonBackend::GetSupportedConfigsForDot(const HloInstruction* instr) {
151143
// We don't need to consider small_dot here. The new search space will
152144
// already generate a unique config for small problems.
153145
std::vector<TritonGemmConfig> gemm_configs = search_space.GenerateConfigs(
154-
/*force_contracting_split=*/autotune_contracting_split
155-
? std::nullopt
156-
: std::make_optional(1),
157146
/*autotune_warp_specialization=*/autotune_warp_specialization);
158147

159148
if (!debug_options().xla_gpu_exhaustive_tiling_search()) {
@@ -231,15 +220,7 @@ absl::StatusOr<std::unique_ptr<BackendConfig>> TritonBackend::GetDefaultConfig(
231220
const HloInstruction& instr) {
232221
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<BackendConfig>> configs,
233222
GetSupportedConfigs(instr));
234-
// Filter split_k>1 configs. Split_k>1 is not guaranteed to be supported.
235-
configs.erase(
236-
std::remove_if(configs.begin(), configs.end(),
237-
[](const std::unique_ptr<BackendConfig>& config) {
238-
AutotuneResult::TritonGemmKey triton_config_proto;
239-
config->UnpackTo(&triton_config_proto);
240-
return triton_config_proto.split_k() > 1;
241-
}),
242-
configs.end());
223+
243224
if (configs.empty()) {
244225
return absl::InvalidArgumentError(
245226
"TritonBackend does not support this instruction.");
@@ -268,11 +249,13 @@ absl::Status TritonBackend::ApplyConfig(HloInstruction& instr,
268249
*backend_config.mutable_triton_gemm_config() = triton_config_proto;
269250
TF_RETURN_IF_ERROR(instr.set_backend_config(gpu_config));
270251

271-
TF_ASSIGN_OR_RETURN(TritonGemmConfig triton_config,
272-
TritonGemmConfig::FromProto(triton_config_proto));
273-
if (triton_config.split_k > 1) {
274-
TF_RETURN_IF_ERROR(MakeDotSplitKBatch(&instr, triton_config));
252+
// FromProto has validation checks, that's why we call it here.
253+
TF_RETURN_IF_ERROR(TritonGemmConfig::FromProto(triton_config_proto).status());
254+
if (triton_config_proto.split_k() != 1) {
255+
return absl::InvalidArgumentError(
256+
"TritonBackend no longer supports split-k (split_k > 1).");
275257
}
258+
276259
return absl::OkStatus();
277260
}
278261

xla/backends/gpu/autotuner/triton/dot_search_space.cc

Lines changed: 28 additions & 169 deletions
Original file line number | Diff line number | Diff line change
@@ -110,8 +110,7 @@ TritonDotFusionSearchSpace::TritonDotFusionSearchSpace(
110110
should_optimize_for_occupancy_(ShouldOptimizeForOccupancy()),
111111
min_out_tile_(GetMinOutputTile()),
112112
min_warps_per_cta_(GetMinWarpsPerCta()),
113-
min_contracting_tile_size_(GetMinContractingTileSize()),
114-
max_contracting_split_(GetMaxContractingSplit(max_out_tile_)) {
113+
min_contracting_tile_size_(GetMinContractingTileSize()) {
115114
// Make sure that the range of output tile sizes is not empty
116115
// (min_output_tile_ is a hard limit, while max_output_tile_ is a soft one).
117116
max_out_tile_.lhs_dim =
@@ -125,29 +124,9 @@ TritonDotFusionSearchSpace::TritonDotFusionSearchSpace(
125124
}
126125

127126
std::vector<TritonGemmConfig> TritonDotFusionSearchSpace::GenerateConfigs(
128-
std::optional<int64_t> force_contracting_split,
129127
bool autotune_warp_specialization) const {
130-
std::vector<ConfigWithNotes> configs;
131-
if (force_contracting_split.has_value()) {
132-
ConfigWithNotes config;
133-
const int split = force_contracting_split.value();
134-
config.config.split_k = split;
135-
// It is possible that the user manually forced a huge contracting split
136-
// that is outside of the search space. In that case, we would end up
137-
// discarding all configs, and use the smallest possible tile size further
138-
// down, which is likely not what the user had in mind.
139-
config.keep_large_split = GetMaxContractingSplit(max_out_tile_) < split;
140-
VLOG(5) << "Forcing split_k, config = " << config.ToString();
141-
if (config.keep_large_split) {
142-
LOG(WARNING)
143-
<< "split_k is larger than what we would have found automatically. "
144-
"Skipping split and output tile compatibility checks. Should we "
145-
"expand the split_k search space?";
146-
}
147-
configs.push_back(config);
148-
} else {
149-
configs = GenerateContractingSplitFactors();
150-
}
128+
std::vector<ConfigWithNotes> configs(1);
129+
configs[0].config.split_k = 1;
151130

152131
ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddOutputTilings);
153132
EliminateLowOccupancyConfigs(configs);
@@ -183,31 +162,6 @@ std::vector<TritonGemmConfig> TritonDotFusionSearchSpace::OptimizeConfigSet(
183162
return configs;
184163
}
185164

186-
absl::flat_hash_map<std::pair<int, int>, std::pair<int, int>>
187-
m_n_to_split_limits;
188-
// Init with first config vals, otherwise they would be 0 and min comparison
189-
// won't update properly.
190-
std::pair<int, int> global_split_limits{configs.front().split_k,
191-
configs.front().split_k};
192-
auto update_split_limits = [](auto& limits, int value) {
193-
limits = std::minmax({limits.first, limits.second, value});
194-
};
195-
for (const TritonGemmConfig& config : configs) {
196-
auto m_n_key = std::make_pair(config.block_m, config.block_n);
197-
auto& split_limits =
198-
m_n_to_split_limits.try_emplace(m_n_key, config.split_k, config.split_k)
199-
.first->second;
200-
update_split_limits(split_limits, config.split_k);
201-
update_split_limits(global_split_limits, config.split_k);
202-
}
203-
204-
auto get_split_limits = [&](int block_m, int block_n) {
205-
auto m_n_key = std::make_pair(block_m, block_n);
206-
auto split_limits_it = m_n_to_split_limits.find(m_n_key);
207-
return split_limits_it == m_n_to_split_limits.end()
208-
? global_split_limits
209-
: split_limits_it->second;
210-
};
211165
absl::flat_hash_set<TritonGemmConfig> filter;
212166
for (TritonGemmConfig config : hints) {
213167
// Our default config set does not take problem size into account, so we
@@ -220,11 +174,8 @@ std::vector<TritonGemmConfig> TritonDotFusionSearchSpace::OptimizeConfigSet(
220174
max_out_tile_.rhs_dim);
221175
config.block_k =
222176
std::clamp(config.block_k, min_contracting_tile_size_,
223-
GetMaxContractingTileSize({config.block_m, config.block_n},
224-
/*contracting_split=*/1));
225-
const auto& split_limits = get_split_limits(config.block_m, config.block_n);
226-
config.split_k =
227-
std::clamp(config.split_k, split_limits.first, split_limits.second);
177+
GetMaxContractingTileSize({config.block_m, config.block_n}));
178+
228179
VLOG(10) << "Adding config to hint filter: " << config.ToString();
229180
filter.insert(config);
230181
}
@@ -252,13 +203,13 @@ std::vector<TritonGemmConfig> TritonDotFusionSearchSpace::OptimizeConfigSet(
252203
std::string TritonDotFusionSearchSpace::ToString() const {
253204
return absl::StrFormat(
254205
"problem_size_BxMxNxKxE: %dx%dx%dx%dx(%d->%d) "
255-
"tile_range_SxMxNxK: [1-%d]x[%d-%d]x[%d-%d]x[%d-?] "
206+
"tile_range_MxNxK: [%d-%d]x[%d-%d]x[%d-?] "
256207
"desired_total_warps: %d occupancy_optimization: %d "
257208
"warps_per_cta: [%d-?]",
258209
batch_size_, lhs_parallel_size_, rhs_parallel_size_, contracting_size_,
259-
operand_bitwidth_, compute_bitwidth_, max_contracting_split_,
260-
min_out_tile_.lhs_dim, max_out_tile_.lhs_dim, min_out_tile_.rhs_dim,
261-
max_out_tile_.rhs_dim, min_contracting_tile_size_, desired_total_warps_,
210+
operand_bitwidth_, compute_bitwidth_, min_out_tile_.lhs_dim,
211+
max_out_tile_.lhs_dim, min_out_tile_.rhs_dim, max_out_tile_.rhs_dim,
212+
min_contracting_tile_size_, desired_total_warps_,
262213
should_optimize_for_occupancy_, min_warps_per_cta_);
263214
}
264215

@@ -414,43 +365,6 @@ int TritonDotFusionSearchSpace::GetMinContractingTileSize() const {
414365
return min_contracting_tile_size;
415366
}
416367

417-
int TritonDotFusionSearchSpace::GetMaxContractingSplit(
418-
OutputTile output_tile) const {
419-
const int64_t desired_num_ctas = desired_total_warps_ / min_warps_per_cta_;
420-
VLOG(5) << "Computing split_k: Considering output tile "
421-
<< output_tile.lhs_dim << "x" << output_tile.rhs_dim;
422-
VLOG(5) << "Computing split_k: Want up to " << desired_num_ctas
423-
<< " CTAs to occupy all cores.";
424-
425-
const int64_t min_result_tiles = GetNumResultTiles(output_tile);
426-
VLOG(5) << "Computing split_k: Without split_k have " << min_result_tiles
427-
<< " tiles.";
428-
429-
const int64_t split_for_occupancy =
430-
NextPowerOfTwo(CeilOfRatio(desired_num_ctas, min_result_tiles));
431-
VLOG(5) << "Computing split_k: Want split_k of up to " << split_for_occupancy
432-
<< " for sufficient occupancy.";
433-
434-
// Calculate the maximum split_k that is valid with the smallest block_k.
435-
// The validation in MakeSplitKOperand requires:
436-
// split_k <= ceil(contracting_size / block_k)
437-
// Using min_contracting_tile_size_ (smallest block_k) gives largest valid
438-
// split.
439-
const int64_t max_valid_split = CeilOfRatio(
440-
contracting_size_, static_cast<int64_t>(min_contracting_tile_size_));
441-
const int64_t split_for_contracting_size =
442-
PreviousPowerOfTwo(max_valid_split);
443-
VLOG(5) << "Computing split_k: Can't have split_k more than "
444-
<< split_for_contracting_size
445-
<< " to have sufficiently large contracting dimension (max_valid="
446-
<< max_valid_split << ").";
447-
448-
const int64_t split =
449-
std::min(split_for_occupancy, split_for_contracting_size);
450-
VLOG(5) << "Computing split_k: max_split_k = " << split;
451-
return split;
452-
}
453-
454368
int TritonDotFusionSearchSpace::GetContractingSizeLimitToFitSharedMemory(
455369
OutputTile output_tile) const {
456370
const int64_t shared_memory_budget =
@@ -462,54 +376,38 @@ int TritonDotFusionSearchSpace::GetContractingSizeLimitToFitSharedMemory(
462376
}
463377

464378
int TritonDotFusionSearchSpace::GetMaxContractingTileSize(
465-
OutputTile output_tile, int contracting_split) const {
466-
const int64_t available_size = contracting_size_ / contracting_split;
379+
OutputTile output_tile) const {
380+
const int64_t available_size = contracting_size_;
467381
const int size_limit = GetContractingSizeLimitToFitSharedMemory(output_tile);
468382
const int max_size =
469383
std::min(NextPowerOfTwo(available_size), PreviousPowerOfTwo(size_limit));
470-
VLOG(5) << "Computing max_contracting_tile_size for tiling BxMxN = "
471-
<< contracting_split << "x" << output_tile.lhs_dim << "x"
472-
<< output_tile.rhs_dim << ": limit based on problem is "
473-
<< available_size << ", limit based on available shared memory is "
474-
<< size_limit << ", max_contracting_tile_size = " << max_size;
384+
VLOG(5) << "Computing max_contracting_tile_size for tiling BxMxN = " << 1
385+
<< "x" << output_tile.lhs_dim << "x" << output_tile.rhs_dim
386+
<< ": limit based on problem is " << available_size
387+
<< ", limit based on available shared memory is " << size_limit
388+
<< ", max_contracting_tile_size = " << max_size;
475389
return std::max(min_contracting_tile_size_, max_size);
476390
}
477391

478-
int TritonDotFusionSearchSpace::GetMaxNumStages(OutputTile output_tile,
479-
int contracting_tile_size,
480-
int contracting_split) const {
481-
const int64_t available_stages = CeilOfRatio<int64_t>(
482-
contracting_size_, contracting_split * contracting_tile_size);
392+
int TritonDotFusionSearchSpace::GetMaxNumStages(
393+
OutputTile output_tile, int contracting_tile_size) const {
394+
const int64_t available_stages =
395+
CeilOfRatio<int64_t>(contracting_size_, contracting_tile_size);
483396
const int64_t stage_limit = std::max(
484397
1, CeilOfRatio(GetContractingSizeLimitToFitSharedMemory(output_tile),
485398
contracting_tile_size));
486399
// Number of stages is basically a replacement for oversubscription, so
487400
// the maximum number we want is also limited by kMaxWarpsPerScheduler.
488401
const int stages = std::min({available_stages, stage_limit,
489402
static_cast<int64_t>(kMaxWarpsPerScheduler)});
490-
VLOG(5) << "Computing max_num_stages for tiling BxMxNxK = "
491-
<< contracting_split << "x" << output_tile.lhs_dim << "x"
492-
<< output_tile.rhs_dim << "x" << contracting_tile_size
493-
<< ": limit based on problem is " << available_stages
494-
<< ", limit based on available shared memory is " << stage_limit
495-
<< ", max_num_stages = " << stages;
403+
VLOG(5) << "Computing max_num_stages for tiling BxMxNxK = " << 1 << "x"
404+
<< output_tile.lhs_dim << "x" << output_tile.rhs_dim << "x"
405+
<< contracting_tile_size << ": limit based on problem is "
406+
<< available_stages << ", limit based on available shared memory is "
407+
<< stage_limit << ", max_num_stages = " << stages;
496408
return stages;
497409
}
498410

499-
std::vector<TritonDotFusionSearchSpace::ConfigWithNotes>
500-
TritonDotFusionSearchSpace::GenerateContractingSplitFactors() const {
501-
CHECK_GE(max_contracting_split_, 1);
502-
std::vector<ConfigWithNotes> configs;
503-
ConfigWithNotes config;
504-
for (int split = 1; split <= max_contracting_split_; split *= 2) {
505-
config.config.split_k = split;
506-
VLOG(10) << "Generating contracting split factors: config = "
507-
<< config.ToString();
508-
configs.push_back(config);
509-
}
510-
return configs;
511-
}
512-
513411
void TritonDotFusionSearchSpace::ExtendConfigs(
514412
std::vector<ConfigWithNotes>& configs,
515413
ExtendConfigCallback extend_config) const {
@@ -525,9 +423,6 @@ void TritonDotFusionSearchSpace::ExtendConfigs(
525423
void TritonDotFusionSearchSpace::AddOutputTilings(
526424
const ConfigWithNotes& config,
527425
std::vector<ConfigWithNotes>& updated_configs) const {
528-
CHECK_GT(config.config.split_k, 0)
529-
<< "Need config with contracting split already set.";
530-
const int split = config.config.split_k;
531426
ConfigWithNotes new_config = config;
532427
for (int m = min_out_tile_.lhs_dim; m <= max_out_tile_.lhs_dim; m *= 2) {
533428
int min_n = min_out_tile_.rhs_dim;
@@ -561,17 +456,8 @@ void TritonDotFusionSearchSpace::AddOutputTilings(
561456
}
562457
for (int n = min_n; n <= max_n; n *= 2) {
563458
OutputTile tile = {m, n};
564-
// We could make the tile size limits depend on split_k, but then we
565-
// need to implement the "inverse" of `GetMaxContractingSplit`.
566-
// Simpler is to just verify that the given combination of tiling and
567-
// split_k is compatible.
568-
if (!config.keep_large_split && GetMaxContractingSplit(tile) < split) {
569-
VLOG(10) << "Skipping due to too large split_k, config = "
570-
<< new_config.ToString();
571-
continue;
572-
}
573459
new_config.not_enough_tiles =
574-
GetNumResultTiles(tile) * split < device_description_.core_count();
460+
GetNumResultTiles(tile) < device_description_.core_count();
575461
new_config.config.block_m = m;
576462
new_config.config.block_n = n;
577463
VLOG(10) << "Adding output tiling: config = " << new_config.ToString();
@@ -604,29 +490,13 @@ void TritonDotFusionSearchSpace::AddContractingTiling(
604490
std::vector<ConfigWithNotes>& updated_configs) const {
605491
const int tile_rows = config.config.block_m;
606492
const int tile_cols = config.config.block_n;
607-
const int split = config.config.split_k;
608493
CHECK_GT(tile_rows * tile_cols, 0)
609494
<< "Need configs with output tilings determined.";
610-
CHECK_GT(split, 0) << "Need config with contracting split determined.";
611495
int max_tile_size =
612-
std::max(GetMaxContractingTileSize({tile_rows, tile_cols}, split),
496+
std::max(GetMaxContractingTileSize({tile_rows, tile_cols}),
613497
min_contracting_tile_size_);
614498
ConfigWithNotes new_config = config;
615499
for (int k = min_contracting_tile_size_; k <= max_tile_size; k *= 2) {
616-
// Safety check: skip block_k values that are incompatible with split_k.
617-
// The validation in MakeSplitKOperand checks:
618-
// split_k > ceil(contracting_size / block_k) → error
619-
// So we need: split_k <= ceil(contracting_size / block_k)
620-
// Skip this check if keep_large_split is true (user forced a large split).
621-
if (!config.keep_large_split) {
622-
const int64_t max_split_for_this_k =
623-
CeilOfRatio(contracting_size_, static_cast<int64_t>(k));
624-
if (split > max_split_for_this_k) {
625-
VLOG(10) << "Skipping block_k=" << k << " for split_k=" << split
626-
<< " (max_split=" << max_split_for_this_k << ")";
627-
continue;
628-
}
629-
}
630500
new_config.config.block_k = k;
631501
VLOG(10) << "Adding contracting tiling: config = " << new_config.ToString();
632502
updated_configs.push_back(new_config);
@@ -639,14 +509,11 @@ void TritonDotFusionSearchSpace::AddPipeliningParameter(
639509
const int tile_rows = config.config.block_m;
640510
const int tile_cols = config.config.block_n;
641511
const int tile_contracting = config.config.block_k;
642-
const int split = config.config.split_k;
643512
CHECK_GT(tile_rows * tile_cols, 0)
644513
<< "Need config with output tilings determined.";
645514
CHECK_GT(tile_contracting, 0)
646515
<< "Need config with contracting tiling determined.";
647-
CHECK_GT(split, 0) << "Need config with contracting split determined.";
648-
int max_stages =
649-
GetMaxNumStages({tile_rows, tile_cols}, tile_contracting, split);
516+
int max_stages = GetMaxNumStages({tile_rows, tile_cols}, tile_contracting);
650517
ConfigWithNotes new_config = config;
651518
for (int num_stages = 1; num_stages <= max_stages; ++num_stages) {
652519
new_config.config.num_stages = num_stages;
@@ -712,14 +579,6 @@ void TritonDotFusionSearchSpace::EliminateLowOccupancyConfigs(
712579

713580
ConfigWithNotes last_config = configs.back(); // Largest split.
714581
auto has_too_few_tiles = [](const ConfigWithNotes& config) {
715-
// Small dots frequently lead to large split_k values that are not
716-
// compatible with codegen. We skip occupancy optimization for them to be
717-
// able to consider smaller splits in non-exhaustive mode.
718-
// The value of 4 was found by running exhaustive autotuning and noting that
719-
// the majority of optimal configs with block_n == 8 had split_k <= 4.
720-
if (config.config.block_n == 8 && config.config.split_k <= 4) {
721-
return false;
722-
}
723582
if (config.not_enough_tiles) {
724583
VLOG(10) << "Skipping due to fewer tiles than cores, config = "
725584
<< config.ToString();

0 commit comments

Comments
 (0)