Skip to content

Commit 7760474

Browse files
mattjj authored and Google-ML-Automation committed
fix topk decomposer crash
reported in jax-ml/jax#36703

PiperOrigin-RevId: 900980821
1 parent bf6c157 commit 7760474

File tree

1 file changed

+52
-16
lines changed

1 file changed

+52
-16
lines changed

xla/service/topk_rewriter.cc

Lines changed: 52 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -462,31 +462,61 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor {
462462
return absl::OkStatus();
463463
}
464464
HloComputation* comparator = call->to_apply();
465-
return DecomposeTopK(call, comparator);
465+
bool skip_iota = IndicesAreUnused(inst);
466+
if (!skip_iota) {
467+
comparator = CreateWrapperComparator(comparator);
468+
}
469+
return DecomposeTopK(call, comparator, skip_iota);
470+
}
471+
472+
HloComputation* CreateWrapperComparator(HloComputation* old_comp) {
473+
HloModule* module = old_comp->parent();
474+
HloComputation::Builder builder("wrapper_comparator");
475+
476+
auto p0 = builder.AddInstruction(HloInstruction::CreateParameter(
477+
0, old_comp->parameter_instruction(0)->shape(), "p0"));
478+
auto p1 = builder.AddInstruction(HloInstruction::CreateParameter(
479+
1, old_comp->parameter_instruction(1)->shape(), "p1"));
480+
builder.AddInstruction(HloInstruction::CreateParameter(
481+
2, ShapeUtil::MakeShape(S32, {}), "p2"));
482+
builder.AddInstruction(HloInstruction::CreateParameter(
483+
3, ShapeUtil::MakeShape(S32, {}), "p3"));
484+
485+
auto call = builder.AddInstruction(HloInstruction::CreateCall(
486+
ShapeUtil::MakeShape(PRED, {}), {p0, p1}, old_comp));
487+
488+
return module->AddEmbeddedComputation(builder.Build(call));
466489
}
467490

468491
absl::Status HandleTopK(HloInstruction* topk) override {
469492
if (should_decompose_ && !should_decompose_(topk)) {
470493
return absl::OkStatus();
471494
}
495+
bool skip_iota = IndicesAreUnused(topk);
472496
TF_ASSIGN_OR_RETURN(HloComputation * comparator,
473-
CreateVariadicComparator(topk));
474-
return DecomposeTopK(topk, comparator);
497+
CreateVariadicComparator(topk, skip_iota));
498+
return DecomposeTopK(topk, comparator, skip_iota);
475499
}
476500

477501
private:
478-
bool HasSingleUserReadingOnlyTheValueOutput(HloInstruction* inst) {
479-
return inst->user_count() == 1 && inst->users().front()->tuple_index() == 0;
502+
bool IndicesAreUnused(HloInstruction* inst) {
503+
for (auto user : inst->users()) {
504+
if (user->opcode() != HloOpcode::kGetTupleElement ||
505+
user->tuple_index() != 0) {
506+
return false;
507+
}
508+
}
509+
return true;
480510
}
481511

482-
absl::StatusOr<HloComputation*> CreateVariadicComparator(
483-
HloInstruction* inst) {
512+
absl::StatusOr<HloComputation*> CreateVariadicComparator(HloInstruction* inst,
513+
bool skip_iota) {
484514
HloTopKInstruction* topk = DynCast<HloTopKInstruction>(inst);
485515
XlaBuilder b(absl::StrCat("comparator_", topk->name()));
486516
std::vector<PrimitiveType> ptypes = {
487517
topk->operand(0)->shape().element_type()};
488518

489-
if (!HasSingleUserReadingOnlyTheValueOutput(inst)) {
519+
if (!skip_iota) {
490520
ptypes.emplace_back(PrimitiveType::S32);
491521
}
492522

@@ -500,7 +530,8 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor {
500530
}
501531

502532
absl::Status DecomposeTopK(HloInstruction* call,
503-
HloComputation* variadic_comparator) {
533+
HloComputation* variadic_comparator,
534+
bool skip_iota) {
504535
HloInstruction* input = call->mutable_operand(0);
505536
Shape iota_shape = input->shape();
506537
iota_shape.set_element_type(S32);
@@ -509,17 +540,22 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor {
509540
std::vector<int64_t> ones(iota_shape.dimensions().size(), 1);
510541
CHECK_NE(variadic_comparator, nullptr);
511542
// If only the topk values are necessary, skip the iota.
512-
if (HasSingleUserReadingOnlyTheValueOutput(call) &&
513-
variadic_comparator->num_parameters() == 2) {
543+
if (skip_iota) {
544+
CHECK_EQ(variadic_comparator->num_parameters(), 2);
514545
HloInstruction* sort = call->AddInstruction(HloInstruction::CreateSort(
515546
input->shape(), sort_dimension, {input}, variadic_comparator,
516547
/*is_stable=*/true));
517-
TF_RETURN_IF_ERROR(ReplaceInstruction(
518-
call->users().front(),
519-
call->AddInstruction(HloInstruction::CreateSlice(
520-
call->shape().tuple_shapes(0), sort, zeroes,
521-
call->shape().tuple_shapes(0).dimensions(), ones))));
548+
HloInstruction* slice = call->AddInstruction(HloInstruction::CreateSlice(
549+
call->shape().tuple_shapes(0), sort, zeroes,
550+
call->shape().tuple_shapes(0).dimensions(), ones));
551+
std::vector<HloInstruction*> users = call->users();
552+
for (auto user : users) {
553+
CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement);
554+
CHECK_EQ(user->tuple_index(), 0);
555+
TF_RETURN_IF_ERROR(ReplaceInstruction(user, slice));
556+
}
522557
} else {
558+
CHECK_EQ(variadic_comparator->num_parameters(), 4);
523559
HloInstruction* iota = call->AddInstruction(HloInstruction::CreateIota(
524560
iota_shape, iota_shape.dimensions().size() - 1));
525561
HloInstruction* sort = call->AddInstruction(HloInstruction::CreateSort(

0 commit comments

Comments
 (0)