Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 69 additions & 15 deletions xla/service/topk_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -462,31 +462,77 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor {
return absl::OkStatus();
}
HloComputation* comparator = call->to_apply();
return DecomposeTopK(call, comparator);
bool skip_iota = IndicesAreUnused(inst);
comparator = CreateWrapperComparator(comparator, skip_iota);
return DecomposeTopK(call, comparator, skip_iota);
}

HloComputation* CreateWrapperComparator(HloComputation* old_comp,
bool skip_iota) {
if ((skip_iota ? 2 : 4) == old_comp->num_parameters()) {
return old_comp;
}
HloModule* module = old_comp->parent();
HloComputation::Builder builder("wrapper_comparator");

auto p0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, old_comp->parameter_instruction(0)->shape(), "p0"));
auto p1 = builder.AddInstruction(HloInstruction::CreateParameter(
1, old_comp->parameter_instruction(1)->shape(), "p1"));

if (!skip_iota) {
builder.AddInstruction(HloInstruction::CreateParameter(
2, ShapeUtil::MakeShape(S32, {}), "p2"));
builder.AddInstruction(HloInstruction::CreateParameter(
3, ShapeUtil::MakeShape(S32, {}), "p3"));
}

if (old_comp->num_parameters() == 2) {
auto call = builder.AddInstruction(HloInstruction::CreateCall(
ShapeUtil::MakeShape(PRED, {}), {p0, p1}, old_comp));
return module->AddEmbeddedComputation(builder.Build(call));
} else {
CHECK_EQ(old_comp->num_parameters(), 4);
HloInstruction* d = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
auto call = builder.AddInstruction(HloInstruction::CreateCall(
ShapeUtil::MakeShape(PRED, {}), {p0, p1, d, d}, old_comp));
return module->AddEmbeddedComputation(builder.Build(call));
}
}

absl::Status HandleTopK(HloInstruction* topk) override {
if (should_decompose_ && !should_decompose_(topk)) {
return absl::OkStatus();
}
bool skip_iota = IndicesAreUnused(topk);
TF_ASSIGN_OR_RETURN(HloComputation * comparator,
CreateVariadicComparator(topk));
return DecomposeTopK(topk, comparator);
CreateVariadicComparator(topk, skip_iota));
return DecomposeTopK(topk, comparator, skip_iota);
}

private:
bool HasSingleUserReadingOnlyTheValueOutput(HloInstruction* inst) {
return inst->user_count() == 1 && inst->users().front()->tuple_index() == 0;
bool IndicesAreUnused(HloInstruction* inst) {
if (inst == inst->parent()->root_instruction()) {
return false;
}
for (auto user : inst->users()) {
if (user->opcode() != HloOpcode::kGetTupleElement ||
user->tuple_index() != 0) {
return false;
}
}
return true;
}

absl::StatusOr<HloComputation*> CreateVariadicComparator(
HloInstruction* inst) {
absl::StatusOr<HloComputation*> CreateVariadicComparator(HloInstruction* inst,
bool skip_iota) {
HloTopKInstruction* topk = DynCast<HloTopKInstruction>(inst);
XlaBuilder b(absl::StrCat("comparator_", topk->name()));
std::vector<PrimitiveType> ptypes = {
topk->operand(0)->shape().element_type()};

if (!HasSingleUserReadingOnlyTheValueOutput(inst)) {
if (!skip_iota) {
ptypes.emplace_back(PrimitiveType::S32);
}

Expand All @@ -500,7 +546,8 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor {
}

absl::Status DecomposeTopK(HloInstruction* call,
HloComputation* variadic_comparator) {
HloComputation* variadic_comparator,
bool skip_iota) {
HloInstruction* input = call->mutable_operand(0);
Shape iota_shape = input->shape();
iota_shape.set_element_type(S32);
Expand All @@ -509,17 +556,24 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor {
std::vector<int64_t> ones(iota_shape.dimensions().size(), 1);
CHECK_NE(variadic_comparator, nullptr);
// If only the topk values are necessary, skip the iota.
if (HasSingleUserReadingOnlyTheValueOutput(call) &&
variadic_comparator->num_parameters() == 2) {
if (skip_iota) {
CHECK_EQ(variadic_comparator->num_parameters(), 2);
HloInstruction* sort = call->AddInstruction(HloInstruction::CreateSort(
input->shape(), sort_dimension, {input}, variadic_comparator,
/*is_stable=*/true));
HloInstruction* slice = call->AddInstruction(HloInstruction::CreateSlice(
call->shape().tuple_shapes(0), sort, zeroes,
call->shape().tuple_shapes(0).dimensions(), ones));
// Create a dummy value for indices since they are unused.
HloInstruction* dummy_indices = call->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
dummy_indices = call->AddInstruction(HloInstruction::CreateBroadcast(
call->shape().tuple_shapes(1), dummy_indices, {}));
TF_RETURN_IF_ERROR(ReplaceInstruction(
call->users().front(),
call->AddInstruction(HloInstruction::CreateSlice(
call->shape().tuple_shapes(0), sort, zeroes,
call->shape().tuple_shapes(0).dimensions(), ones))));
call, call->AddInstruction(
HloInstruction::CreateTuple({slice, dummy_indices}))));
} else {
CHECK_EQ(variadic_comparator->num_parameters(), 4);
HloInstruction* iota = call->AddInstruction(HloInstruction::CreateIota(
iota_shape, iota_shape.dimensions().size() - 1));
HloInstruction* sort = call->AddInstruction(HloInstruction::CreateSort(
Expand Down
1 change: 1 addition & 0 deletions xla/service/topk_rewriter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ ENTRY cluster {
TF_ASSERT_OK_AND_ASSIGN(bool decomposer_changed,
TopkDecomposer().Run(module.get()));
EXPECT_TRUE(decomposer_changed);
TF_ASSERT_OK(TupleSimplifier().Run(module.get()).status());
TF_ASSERT_OK(HloDCE().Run(module.get()).status());
auto root = module->entry_computation()->root_instruction();
HloInstruction* sort;
Expand Down
Loading