Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions xla/backends/gpu/codegen/triton/xtile_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> TileAndEmitXTileModule(
const HloComputation* computation = fusion.fused_instructions_computation();

if (use_experimental_tiling) {
using experimental::TileAnalysisOrError;
using experimental::TiledHloComputation;
using experimental::TilingSpace;

Expand All @@ -230,14 +229,9 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> TileAndEmitXTileModule(
GetTilingSpaceConcreteSizes(*tiling_space, block_level_parameters));
tiling_space->AssignTileSizes(xtile::GetPaddedTileSizes(tile_sizes));

TileAnalysisOrError tiled_computation_or =
TiledHloComputation::Tile(*fusion_adaptor, std::move(tiling_space));
if (std::holds_alternative<FusionDecision>(tiled_computation_or)) {
return Internal("Unsupported fusion in CreateTritonModule: %s",
std::get<FusionDecision>(tiled_computation_or).Explain());
}
const auto& tiled_computation =
std::get<TiledHloComputation>(tiled_computation_or);
ASSIGN_OR_RETURN(
TiledHloComputation tiled_computation,
TiledHloComputation::Tile(*fusion_adaptor, std::move(tiling_space)));
VLOG(6) << "tiled computation: " << tiled_computation.ToString();
return xtile::EmitXTileModule(
fn_name, fusion, tiled_computation, mlir_context,
Expand Down
2 changes: 2 additions & 0 deletions xla/codegen/tiling/experimental/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,15 @@ cc_library(
"//xla/service:instruction_fusion",
"//xla/service:name_uniquer",
"//xla/service/gpu:backend_configs_cc",
"//xla/tsl/platform:status_macros",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
Expand Down
20 changes: 9 additions & 11 deletions xla/codegen/tiling/experimental/scheduling_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,26 +65,24 @@ class SchedulingTest : public HloHardwareIndependentTestBase {
return module_->entry_computation()->root_instruction();
}

TiledHloComputation ParseAndTile(absl::string_view hlo_string,
absl::Span<const int64_t> tile_sizes = {}) {
absl::StatusOr<TiledHloComputation> ParseAndTile(
absl::string_view hlo_string, absl::Span<const int64_t> tile_sizes = {}) {
HloInstruction* root = ParseAndGetRoot(hlo_string);
auto fusion_adaptor = HloFusionAdaptor::ForInstruction(root);
auto tiling_space = TilingSpace::Create(*fusion_adaptor, &mlir_context_);
if (!tile_sizes.empty()) {
tiling_space->AssignTileSizes(tile_sizes);
}
auto tiled_computation_or =
TiledHloComputation::Tile(*fusion_adaptor, std::move(tiling_space));
CHECK(std::holds_alternative<TiledHloComputation>(tiled_computation_or));
return std::get<TiledHloComputation>(std::move(tiled_computation_or));
return TiledHloComputation::Tile(*fusion_adaptor, std::move(tiling_space));
}

mlir::MLIRContext mlir_context_;
std::unique_ptr<VerifiedHloModule> module_;
};

TEST_F(SchedulingTest, OnlyParallelDimensions) {
const TiledHloComputation tiled_computation = ParseAndTile(R"(
ASSERT_OK_AND_ASSIGN(const TiledHloComputation tiled_computation,
ParseAndTile(R"(
fusion {
p0 = f32[2,97]{1,0} parameter(0)
p1 = f32[2,97]{1,0} parameter(1)
Expand All @@ -95,16 +93,16 @@ TEST_F(SchedulingTest, OnlyParallelDimensions) {
p1 = f32[2,97]{1,0} parameter(1)
ROOT fusion = f32[2,97]{1,0} fusion(p0, p1), kind=kLoop, calls=fusion
})",
{1, 32});
{1, 32}));
auto scheduling = GetSchedule(tiled_computation);
EXPECT_THAT(scheduling,
IsOkAndHolds(MatchSchedule(
"d0 -> pid floordiv 4, d1 -> pid mod 4, pid_bounds=[0, 7]")));
}

TEST_F(SchedulingTest, ReductionsAndContractionsAreNotSupported) {
const TiledHloComputation tiled_computation =
ParseAndTile(R"(
ASSERT_OK_AND_ASSIGN(const TiledHloComputation tiled_computation,
ParseAndTile(R"(
max {
p1 = f32[] parameter(1)
p0 = f32[] parameter(0)
Expand All @@ -121,7 +119,7 @@ TEST_F(SchedulingTest, ReductionsAndContractionsAreNotSupported) {
p0 = f32[2,97]{1,0} parameter(0)
ROOT fusion = f32[2,97]{1,0} fusion(p0), kind=kLoop, calls=fusion
})",
{1, 32, /*reduction_tile_size=*/8});
{1, 32, /*reduction_tile_size=*/8}));
auto scheduling = GetSchedule(tiled_computation);
EXPECT_THAT(scheduling,
IsOkAndHolds(MatchSchedule(
Expand Down
82 changes: 34 additions & 48 deletions xla/codegen/tiling/experimental/tiled_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/tsl/platform/status_macros.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
Expand Down Expand Up @@ -211,7 +212,7 @@ void PrepopulateTileNames(
if (!inserted) {
return;
}
for (const auto& region : tiled_hlo->regions()) {
for (const auto& region : tiled_hlo->hlo_regions()) {
for (const auto& region_instruction : region) {
PrepopulateTileNames(region_instruction.get(), name_uniquer, tile_names);
}
Expand All @@ -236,14 +237,12 @@ void PrintTiledHloInstruction(
<< absl::StrJoin(operand_names, ", ") << ") "
<< tiled_hlo->tile().ToString(false) << "\n";

if (!tiled_hlo->regions().empty()) {
for (auto const& [i, region] : llvm::enumerate(tiled_hlo->regions())) {
ss << indentation << "region #" << i << " {\n";
for (const auto& instruction : region) {
PrintTiledHloInstruction(instruction.get(), tile_names, ss, indent + 2);
}
ss << indentation << "}\n";
for (auto const& [i, region] : llvm::enumerate(tiled_hlo->hlo_regions())) {
ss << indentation << "region #" << i << " {\n";
for (const auto& instruction : region) {
PrintTiledHloInstruction(instruction.get(), tile_names, ss, indent + 2);
}
ss << indentation << "}\n";
}
}

Expand All @@ -266,7 +265,7 @@ absl::InlinedVector<const HloInstruction*, 2> ToInstructions(
// * If tiled_root is a concat,
// returns {tiled_root}, and tiled_root has one region per operand.
// * Otherwise, returns a region including tiled_root and all dependencies.
/*static*/ TiledHloRegionOrError TiledHloComputation::CreateRegion(
/*static*/ absl::StatusOr<TiledHloRegion> TiledHloComputation::CreateHloRegion(
std::unique_ptr<TiledHloInstruction> tiled_root,
const HloFusionAdaptor& fusion, const TilingSpace& tiling_space,
absl::flat_hash_map<int64_t,
Expand All @@ -283,51 +282,42 @@ absl::InlinedVector<const HloInstruction*, 2> ToInstructions(
continue;
}

auto operands_tiles =
PropagateTileToInput(tiling_space, *hlo, tiled_hlo->tile(), 0);
if (!operands_tiles.ok()) {
return FusionDecision::Forbid("Couldn't propagate tile ")
<< tiled_hlo->tile().ToString() << " to the input of "
<< hlo->ToString()
<< " with error: " << operands_tiles.status().ToString();
}
ASSIGN_OR_RETURN(
auto operands_tiles,
PropagateTileToInput(tiling_space, *hlo, tiled_hlo->tile(), 0));

HloInstructionAdaptor instruction_adaptor(*hlo, &fusion);
const bool hlo_is_condition = IsControlFlowCondition(*tiled_hlo);
for (const auto& [operand_id, tile_and_operand] : llvm::enumerate(
llvm::zip(*operands_tiles, instruction_adaptor.GetOperands()))) {
llvm::zip(operands_tiles, instruction_adaptor.GetOperands()))) {
auto& [tile, operand] = tile_and_operand;
const HloInstruction* operand_hlo = &operand.instruction();
auto tiled_operand =
std::make_unique<TiledHloInstruction>(operand_hlo, tile);
const bool operand_is_loop = IsControlFlowLoop(*tiled_operand);

if (hlo_is_condition || operand_is_loop) {
auto region_or_error =
CreateRegion(std::move(tiled_operand), fusion, tiling_space,
rt_symbol_to_tiled_hlo);
if (auto* decision = std::get_if<FusionDecision>(&region_or_error)) {
return *decision;
}
auto region =
std::get<TiledHloInstruction::Region>(std::move(region_or_error));
ASSIGN_OR_RETURN(auto region,
CreateHloRegion(std::move(tiled_operand), fusion,
tiling_space, rt_symbol_to_tiled_hlo));

if (hlo_is_condition) {
// Case 1: HLO is a condition (e.g., concat).
// Each operand introduces a new branch/sub-region in `tiled_hlo`.
CHECK(!region.empty()) << "CreateRegion: returned empty region for "
<< operand_hlo->ToString();
tiled_hlo->AppendOperand(region.back().get());
tiled_hlo->AddRegion(std::move(region));
CHECK(!region.empty())
<< "CreateHloRegion: returned empty region for "
<< operand_hlo->ToString();
tiled_hlo->AddOperand(region.back().get());
tiled_hlo->AddHloRegion(std::move(region));

} else {
// Case 2: Operand is a loop (e.g., dot/scaled_dot/reduce).
// Operand has its loop-body as a region. Operand itself is added as a
// node to the current flat list.
CHECK(region.size() == 1)
<< "CreateRegion: expected exactly 1 region for "
<< "CreateHloRegion: expected exactly 1 region for "
<< operand_hlo->ToString() << " but got " << region.size();
tiled_hlo->AppendOperand(region.back().get());
tiled_hlo->AddOperand(region.back().get());
tiled_hlo_instructions_set.Insert(std::move(region.back()));
}

Expand All @@ -338,7 +328,7 @@ absl::InlinedVector<const HloInstruction*, 2> ToInstructions(
if (inserted) {
worklist.push_back(operand_tiled_hlo);
}
tiled_hlo->AppendOperand(operand_tiled_hlo);
tiled_hlo->AddOperand(operand_tiled_hlo);
// If the operand is a runtime variable, add it to the
// `rt_symbol_to_tiled_hlo` map.
std::optional<const TilingSpace::RTVarInfo*> rt_var_info =
Expand All @@ -351,17 +341,18 @@ absl::InlinedVector<const HloInstruction*, 2> ToInstructions(
}
}
}
auto tiled_hlo_instructions = tiled_hlo_instructions_set.ExtractData();
TiledHloRegion tiled_hlo_instructions{
tiled_hlo_instructions_set.ExtractData()};
SortTiledHloInstructionsInPostOrder(tiled_hlo_instructions, tiled_root.get());
if (IsControlFlowLoop(*tiled_root)) {
tiled_root->AddRegion(std::move(tiled_hlo_instructions));
tiled_root->AddHloRegion(std::move(tiled_hlo_instructions));
tiled_hlo_instructions.clear();
}
tiled_hlo_instructions.push_back(std::move(tiled_root));
return tiled_hlo_instructions;
}

/*static*/ TileAnalysisOrError TiledHloComputation::Tile(
/*static*/ absl::StatusOr<TiledHloComputation> TiledHloComputation::Tile(
const HloFusionAdaptor& fusion, std::unique_ptr<TilingSpace> tiling_space) {
SmallVector<const TiledHloInstruction*> roots;
SmallVector<const TiledHloInstruction*> roots_with_no_users;
Expand All @@ -378,15 +369,9 @@ absl::InlinedVector<const HloInstruction*, 2> ToInstructions(
roots_with_no_users.push_back(root_tiled_hlo.get());
}

TiledHloRegionOrError region_or_error =
CreateRegion(std::move(root_tiled_hlo), fusion, *tiling_space,
rt_symbol_to_tiled_hlo);
if (FusionDecision* decision =
std::get_if<FusionDecision>(&region_or_error)) {
return *decision;
}
TiledHloInstruction::Region region =
std::get<TiledHloInstruction::Region>(std::move(region_or_error));
ASSIGN_OR_RETURN(TiledHloRegion region,
CreateHloRegion(std::move(root_tiled_hlo), fusion,
*tiling_space, rt_symbol_to_tiled_hlo));
for (std::unique_ptr<TiledHloInstruction>& tiled_hlo : region) {
tiled_hlo_instructions_set.Insert(std::move(tiled_hlo));
}
Expand All @@ -399,9 +384,10 @@ absl::InlinedVector<const HloInstruction*, 2> ToInstructions(
SortTiledHloInstructionsInPostOrder(tiled_hlo_instructions,
roots_with_no_users);

return TiledHloComputation(
std::move(tiling_space), std::move(tiled_hlo_instructions),
std::move(roots), std::move(rt_symbol_to_tiled_hlo));
return TiledHloComputation(std::move(tiling_space),
TiledHloRegion{std::move(tiled_hlo_instructions)},
std::move(roots),
std::move(rt_symbol_to_tiled_hlo));
}

std::string TiledHloComputation::ToString() const {
Expand Down
Loading
Loading