Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions xla/backends/gpu/autotuner/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,14 @@ absl::StatusOr<std::unique_ptr<HloModule>> TritonBackend::RunHloPasses(
FusionWrapper fusion_wrapper(gpu_device_info);
TF_RETURN_IF_ERROR(fusion_wrapper.Run(hlo_module.get()).status());
TF_RETURN_IF_ERROR(HoistFusedBitcasts().Run(hlo_module.get()).status());
ConvertTritonGemmConfig convert_triton_gemm_config(gpu_device_info,
mlir_context_);
RETURN_IF_ERROR(convert_triton_gemm_config.Run(hlo_module.get()).status());
NestGemmFusion nest_gemm_fusion(gpu_device_info, mlir_context_);
RETURN_IF_ERROR(nest_gemm_fusion.Run(hlo_module.get()).status());
if (debug_options().xla_gpu_unsupported_disable_nested_gemm_fusions()) {
ConvertTritonGemmConfig convert_triton_gemm_config(gpu_device_info,
mlir_context_);
RETURN_IF_ERROR(convert_triton_gemm_config.Run(hlo_module.get()).status());
} else {
NestGemmFusion nest_gemm_fusion(gpu_device_info, mlir_context_);
RETURN_IF_ERROR(nest_gemm_fusion.Run(hlo_module.get()).status());
}
return hlo_module;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ constexpr ErrorSpec kExactMatch{/*aabs=*/0, /*arel=*/0};

class TritonEmitterTest : public GpuCodegenTest, public XTileTestBase {
public:
DebugOptions GetDebugOptionsForTest() const override {
DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
debug_options.set_xla_gpu_unsupported_disable_nested_gemm_fusions(true);
return debug_options;
}
const stream_executor::GpuComputeCapability& GpuComputeCapability() {
return backend()
.default_stream_executor()
Expand Down Expand Up @@ -4849,9 +4854,6 @@ ENTRY e {

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(kHloTextTemplate));
module->mutable_config()
.mutable_debug_options()
.set_xla_gpu_unsupported_disable_nested_gemm_fusions(true);
TF_ASSERT_OK_AND_ASSIGN(auto optimized_module,
GetOptimizedModule(std::move(module)));
constexpr absl::string_view kExpectedOptimizedHLO = R"(
Expand Down
12 changes: 7 additions & 5 deletions xla/backends/gpu/codegen/triton/support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,10 +638,12 @@ CodegenDecision IsTritonSupportedFusion(
" is not supported: ", decision.Explain()));
}

bool AnyOperandIsFusion(const HloInstruction& hlo) {
return absl::c_any_of(hlo.operands(), [](const HloInstruction* operand) {
return operand->opcode() == HloOpcode::kFusion;
});
// Returns whether a control-flow regions should be created at the tile level.
bool TilingControlFlowIsEnabled(const HloInstruction& hlo) {
return hlo.GetModule()
->config()
.debug_options()
.xla_gpu_unsupported_disable_nested_gemm_fusions();
}

CodegenDecision IsTritonSupportedConcatenate(const HloInstruction& hlo) {
Expand All @@ -653,7 +655,7 @@ CodegenDecision IsTritonSupportedConcatenate(const HloInstruction& hlo) {
return CodegenDecision::Forbid(
"Only concatenates in nested GEMM fusions are supported.");
}
if (AnyOperandIsFusion(hlo)) {
if (!TilingControlFlowIsEnabled(hlo)) {
// TODO(b/393299275): remove this operand filter once migration is
// complete and priority fusion can produce nests.
if (absl::c_any_of(hlo.operands(), [](const HloInstruction* operand) {
Expand Down
7 changes: 4 additions & 3 deletions xla/backends/gpu/codegen/triton/xtile_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,9 +499,10 @@ absl::StatusOr<absl::InlinedVector<int64_t, 4>> DotTilingParameters(
const HloInstruction* hlo,
const SymbolicTileAnalysis& symbolic_tile_analysis,
const BlockLevelParameters& block_level_parameters) {
if (absl::c_all_of(hlo->operands(), [](const HloInstruction* operand) {
return operand->opcode() != HloOpcode::kFusion;
})) {
if (hlo->GetModule()
->config()
.debug_options()
.xla_gpu_unsupported_disable_nested_gemm_fusions()) {
ASSIGN_OR_RETURN(Tile tile_config, hlo->backend_config<Tile>());
return FlatTiling(tile_config.sizes().begin(), tile_config.sizes().end());
}
Expand Down
21 changes: 12 additions & 9 deletions xla/codegen/tiling/symbolic_tile_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -421,20 +421,23 @@ using UnsafeSymbolicTiledHloInstructionOrderedSet =
UnsafeSymbolicTiledHloInstructionOperandAgnosticHash,
UnsafeSymbolicTiledHloInstructionOperandAgnosticEq>;

bool AnyOperandIsFusion(const HloInstruction& hlo) {
return absl::c_any_of(hlo.operands(), [](const HloInstruction* operand) {
return operand->opcode() == HloOpcode::kFusion;
});
// Returns whether a control-flow regions should be created at the tile level.
bool TilingControlFlowIsEnabled(const HloInstruction& hlo) {
return ABSL_DIE_IF_NULL(hlo.GetModule())
->config()
.debug_options()
.xla_gpu_unsupported_disable_nested_gemm_fusions();
}

// Returns whether the instruction is a conditional block for tiling.
bool IsControlFlowCondition(const HloInstruction& hlo) {
return hlo.opcode() == HloOpcode::kConcatenate && !AnyOperandIsFusion(hlo);
return hlo.opcode() == HloOpcode::kConcatenate &&
TilingControlFlowIsEnabled(hlo);
}

// Returns whether the instruction is a loop block for tiling.
bool IsControlFlowLoop(const HloInstruction& hlo) {
return IsSomeDot(hlo) && !AnyOperandIsFusion(hlo);
return IsSomeDot(hlo) && TilingControlFlowIsEnabled(hlo);
}

// Detects pathological cases on which symbolic tile derivation should bail out.
Expand All @@ -444,7 +447,6 @@ FusionDecision ShouldProceedWithSymbolicTileDerivation(
const SymbolicTiledHloInstruction& tiled_hlo_instruction) {
const HloInstruction* hlo = tiled_hlo_instruction.hlo();
const IndexingMap& indexing_map = tiled_hlo_instruction.indexing_map();
// TODO(b/446827313): update comment after disabling nested fusions.
// Bail out on concatenates in the general path for now, but allow a
// restricted form of concatenates for the nested GEMM fusion path.
//
Expand All @@ -453,7 +455,7 @@ FusionDecision ShouldProceedWithSymbolicTileDerivation(
// for concatenates.
if ((hlo->opcode() == HloOpcode::kConcatenate ||
hlo->opcode() == HloOpcode::kPad) &&
!(IsWithinNestedGemmFusion(*hlo) || IsControlFlowCondition(*hlo))) {
!(IsWithinNestedGemmFusion(*hlo) || TilingControlFlowIsEnabled(*hlo))) {
return FusionDecision::Forbid("Bailing out on ") << hlo->ToString();
}

Expand Down Expand Up @@ -1927,7 +1929,8 @@ absl::StatusOr<std::unique_ptr<TiledHloInstruction>> ComputeTiledHloInstruction(
std::move(tile_sizes), std::move(tile_strides),
std::move(tile_offset_indexing));
}
if (!symbolic_tiled_hlo->regions().empty()) {
if (TilingControlFlowIsEnabled(*hlo) &&
!symbolic_tiled_hlo->regions().empty()) {
// Copy instruction mapping to avoid polluting with instructions from
// sub-regions.
absl::flat_hash_map<const SymbolicTiledHloInstruction*,
Expand Down
112 changes: 111 additions & 1 deletion xla/codegen/tiling/symbolic_tile_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,14 @@ class SymbolicTileAnalysisTest : public HloHardwareIndependentTestBase {
// SymbolicTileAnalysisTest with updated expectations. Original tests should be
// updated after the flag is on by default. Test names should also be updated
// as at the moment they usually match the original name.
class SymbolicTileAnalysisRegionsTest : public SymbolicTileAnalysisTest {};
class SymbolicTileAnalysisRegionsTest : public SymbolicTileAnalysisTest {
DebugOptions GetDebugOptionsForTest() const override {
DebugOptions debug_options =
SymbolicTileAnalysisTest::GetDebugOptionsForTest();
debug_options.set_xla_gpu_unsupported_disable_nested_gemm_fusions(true);
return debug_options;
}
};

TEST_F(SymbolicTileAnalysisTest, SimpleNormalizationDiamondIsSupported) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
Expand Down Expand Up @@ -953,6 +960,70 @@ ENTRY main {
)"));
}

TEST_F(SymbolicTileAnalysisTest, ScaledDotOffsetIndexingIsCorrect) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(R"(
fusion {
lhs = f8e4m3fn[128,64] parameter(0)
rhs = f8e4m3fn[64,128] parameter(1)
lhs_scale = f8e8m0fnu[128,2] parameter(2)
rhs_scale = f8e8m0fnu[2,128] parameter(3)
ROOT dot = f32[128,128] scaled-dot(lhs, rhs, lhs_scale, rhs_scale),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

ENTRY main {
lhs = f8e4m3fn[128,64] parameter(0)
rhs = f8e4m3fn[64,128] parameter(1)
lhs_scale = f8e8m0fnu[128,2] parameter(2)
rhs_scale = f8e8m0fnu[2,128] parameter(3)
ROOT fusion = f32[128,128] fusion(lhs, rhs, lhs_scale, rhs_scale),
kind=kLoop, calls=fusion
})"));
std::optional<SymbolicTileAnalysis> analysis = TryAnalyzeModule(module.get());
ASSERT_TRUE(analysis.has_value());
const HloInstruction* dot_hlo =
module->entry_computation()->root_instruction()->fused_expression_root();
constexpr int64_t kContractingTileSize = 32;
constexpr int64_t kLhsTileSize = 16;
constexpr int64_t kRhsTileSize = 16;
Tiling tiling(Tiling::TileMapping{
{dot_hlo, {kContractingTileSize, kLhsTileSize, kRhsTileSize}}});
TF_ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation,
analysis->ComputeTiledComputation(
tiling, default_schedule_builder_,
/*constraints_are_known_satisfied=*/false,
/*compute_all_tile_offset_indexing_maps=*/true));

const TiledHloInstruction* dot = GetFirstRoot(tiled_hlo_computation);
EXPECT_THAT(*dot, MatchTiledHloInstruction(
/*tile_sizes=*/{16, 16}, /*tile_strides=*/{1, 1},
/*tile_offsets_indexing=*/R"(
(pid_0) -> ((pid_0 floordiv 8) * 16, (pid_0 mod 8) * 16),
domain:
pid_0 in [0, 63]
)"));

ASSERT_THAT(dot->operands(), SizeIs(4));
const TiledHloInstruction* lhs = dot->operand(0);
EXPECT_THAT(*lhs, MatchTiledHloInstruction(
/*tile_sizes=*/{16, 32}, /*tile_strides=*/{1, 1},
/*tile_offsets_indexing=*/R"(
(pid_0) -> ((pid_0 floordiv 8) * 16, 0),
domain:
pid_0 in [0, 63]
)"));

const TiledHloInstruction* rhs = dot->operand(1);
EXPECT_THAT(*rhs, MatchTiledHloInstruction(
/*tile_sizes=*/{32, 16}, /*tile_strides=*/{1, 1},
/*tile_offsets_indexing=*/R"(
(pid_0) -> (0, (pid_0 mod 8) * 16),
domain:
pid_0 in [0, 63]
)"));
}

TEST_F(SymbolicTileAnalysisRegionsTest, ScaledDotOffsetIndexingIsCorrect) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(R"(
Expand Down Expand Up @@ -1057,6 +1128,23 @@ ENTRY main {
"2 mod d0 in [0, 0] || d0 mod 2 in [0, 0]"));
}

TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedConcatenate) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(R"(
fusion {
p0 = f32[1,3]{1,0} parameter(0)
p1 = f32[1,3]{1,0} parameter(1)
ROOT concatenate = f32[2,3] concatenate(p0, p1), dimensions={0}
}

ENTRY main {
p0 = f32[1,3]{1,0} parameter(0)
p1 = f32[1,3]{1,0} parameter(1)
ROOT fusion = f32[2,3] fusion(p0, p1), kind=kLoop, calls=fusion
})"));
EXPECT_FALSE(TryAnalyzeModule(module.get()).has_value());
}

TEST_F(SymbolicTileAnalysisRegionsTest, ConcatenateIsSupported) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(R"(
Expand Down Expand Up @@ -2416,6 +2504,28 @@ ENTRY main {
}
}

TEST_F(SymbolicTileAnalysisTest,
ConcatenatesOutsideOfNestedGemmFusionsForbidSymbolicTileDerivation) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(R"(
concatenate {
p0 = bf16[6] parameter(0)
p1 = bf16[6] parameter(1)
p2 = bf16[6] parameter(2)
ROOT concatenate = bf16[18] concatenate(p0, p1, p2), dimensions={0}
}

ENTRY main {
p0 = bf16[6] parameter(0)
p1 = bf16[6] parameter(1)
p2 = bf16[6] parameter(2)
ROOT fusion = bf16[18] fusion(p0, p1, p2),
kind=kCustom, calls=concatenate
})"));
std::optional<SymbolicTileAnalysis> analysis = TryAnalyzeModule(module.get());
EXPECT_FALSE(analysis.has_value());
}

TEST_F(SymbolicTileAnalysisTest,
ConcatenateOperandsInNestedGemmFusionsAreProvidedCorrectTilingBounds) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
Expand Down
16 changes: 9 additions & 7 deletions xla/codegen/xtile/codegen/fusion_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,12 @@ TensorValue Iota(mlir::ImplicitLocOpBuilder& b, int32_t limit) {
return stablehlo::IotaOp::create(b, type, /*iota_dimension=*/0);
}

bool AnyOperandIsFusion(const HloInstruction& hlo) {
return absl::c_any_of(hlo.operands(), [](const HloInstruction* operand) {
return operand->opcode() == HloOpcode::kFusion;
});
// Returns whether we expect to see sub-regions defined for TiledHloInstruction.
bool TilingControlFlowIsEnabled(const HloInstruction& hlo) {
return hlo.GetModule()
->config()
.debug_options()
.xla_gpu_unsupported_disable_nested_gemm_fusions();
}

absl::Status EmitReduceComputation(mlir::ImplicitLocOpBuilder& b,
Expand Down Expand Up @@ -1614,7 +1616,7 @@ absl::StatusOr<TensorValue> EmitTiledHloInstruction(
}

if (hlo->opcode() == HloOpcode::kConcatenate) {
if (!AnyOperandIsFusion(*hlo)) {
if (TilingControlFlowIsEnabled(*hlo)) {
return EmitUnnestedConcatenate(b, fusion, tiled_hlo, fn, pid, values);
}
return EmitConcatenate(b, fusion, tiled_hlo, fn, pid, values);
Expand All @@ -1625,14 +1627,14 @@ absl::StatusOr<TensorValue> EmitTiledHloInstruction(
}

if (hlo->opcode() == HloOpcode::kDot) {
if (!AnyOperandIsFusion(*hlo)) {
if (TilingControlFlowIsEnabled(*hlo)) {
return EmitUnnestedDot(b, fusion, tiled_hlo, fn, pid, values);
}
return EmitDot(b, fusion, tiled_hlo, fn, pid, values);
}

if (hlo->opcode() == HloOpcode::kScaledDot) {
if (!AnyOperandIsFusion(*hlo)) {
if (TilingControlFlowIsEnabled(*hlo)) {
return EmitUnnestedScaledDot(b, fusion, tiled_hlo, fn, pid, values);
}
return EmitScaledDot(b, fusion, tiled_hlo, fn, pid, values);
Expand Down
13 changes: 8 additions & 5 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,14 @@ absl::StatusOr<std::unique_ptr<HloModule>> TritonGemmAutotuneExtractor(

HoistFusedBitcasts hoist_fused_bitcasts;
TF_RETURN_IF_ERROR(hoist_fused_bitcasts.Run(new_module.get()).status());
ConvertTritonGemmConfig convert_triton_gemm_config(gpu_device_info,
mlir_context);
RETURN_IF_ERROR(convert_triton_gemm_config.Run(new_module.get()).status());
NestGemmFusion nest_gemm_fusion(gpu_device_info, mlir_context);
RETURN_IF_ERROR(nest_gemm_fusion.Run(new_module.get()).status());
if (debug_opts.xla_gpu_unsupported_disable_nested_gemm_fusions()) {
ConvertTritonGemmConfig convert_triton_gemm_config(gpu_device_info,
mlir_context);
RETURN_IF_ERROR(convert_triton_gemm_config.Run(new_module.get()).status());
} else {
NestGemmFusion nest_gemm_fusion(gpu_device_info, mlir_context);
RETURN_IF_ERROR(nest_gemm_fusion.Run(new_module.get()).status());
}
return new_module;
}

Expand Down
11 changes: 7 additions & 4 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1948,10 +1948,13 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
// GemmFusionAutotuner runs hoist-fused-bitcasts and nest-gemm-fusion,
// matching its behavior here.
pipeline.AddPass<HoistFusedBitcasts>();
pipeline.AddPass<ConvertTritonGemmConfig>(
gpu_target_config.device_description, &mlir_context_);
pipeline.AddPass<NestGemmFusion>(gpu_target_config.device_description,
&mlir_context_);
if (debug_options.xla_gpu_unsupported_disable_nested_gemm_fusions()) {
pipeline.AddPass<ConvertTritonGemmConfig>(
gpu_target_config.device_description, &mlir_context_);
} else {
pipeline.AddPass<NestGemmFusion>(gpu_target_config.device_description,
&mlir_context_);
}

// Clean up new_tuple described above.
pipeline.AddPass<TupleSimplifier>();
Expand Down
13 changes: 7 additions & 6 deletions xla/service/gpu/model/gpu_indexing_performance_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -599,12 +599,13 @@ ENTRY main {
auto result = indexing_cost_model_.EstimateRunTimeForTiledFusion(
*fusion_adaptor, launch_dimensions, /*output_tile_sizes=*/{{1, 128}});

EXPECT_THAT(
result,
absl_testing::StatusIs(
absl::StatusCode::kFailedPrecondition,
HasSubstr(
"Concatenate is not supported by the indexing cost model")));
// Currently SymbolicTileAnalysis fails for concatenate. Once the analysis
// gets support of concatenate, this test should fail with an error from
// `EstimateRunTimeForTiledHloComputation` that propagation of the number of
// blocks is not supported (b/351342921).
EXPECT_THAT(result,
absl_testing::StatusIs(absl::StatusCode::kFailedPrecondition,
HasSubstr("SymbolicTileAnalysis failed")));
}

TEST_F(GpuIndexingPerformanceModelTest,
Expand Down
Loading