@@ -462,31 +462,61 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor {
462462 return absl::OkStatus ();
463463 }
464464 HloComputation* comparator = call->to_apply ();
465- return DecomposeTopK (call, comparator);
465+ bool skip_iota = IndicesAreUnused (inst);
466+ if (!skip_iota) {
467+ comparator = CreateWrapperComparator (comparator);
468+ }
469+ return DecomposeTopK (call, comparator, skip_iota);
470+ }
471+
472+ HloComputation* CreateWrapperComparator (HloComputation* old_comp) {
473+ HloModule* module = old_comp->parent ();
474+ HloComputation::Builder builder (" wrapper_comparator" );
475+
476+ auto p0 = builder.AddInstruction (HloInstruction::CreateParameter (
477+ 0 , old_comp->parameter_instruction (0 )->shape (), " p0" ));
478+ auto p1 = builder.AddInstruction (HloInstruction::CreateParameter (
479+ 1 , old_comp->parameter_instruction (1 )->shape (), " p1" ));
480+ builder.AddInstruction (HloInstruction::CreateParameter (
481+ 2 , ShapeUtil::MakeShape (S32, {}), " p2" ));
482+ builder.AddInstruction (HloInstruction::CreateParameter (
483+ 3 , ShapeUtil::MakeShape (S32, {}), " p3" ));
484+
485+ auto call = builder.AddInstruction (HloInstruction::CreateCall (
486+ ShapeUtil::MakeShape (PRED, {}), {p0, p1}, old_comp));
487+
488+ return module ->AddEmbeddedComputation (builder.Build (call));
466489 }
467490
468491 absl::Status HandleTopK (HloInstruction* topk) override {
469492 if (should_decompose_ && !should_decompose_ (topk)) {
470493 return absl::OkStatus ();
471494 }
495+ bool skip_iota = IndicesAreUnused (topk);
472496 TF_ASSIGN_OR_RETURN (HloComputation * comparator,
473- CreateVariadicComparator (topk));
474- return DecomposeTopK (topk, comparator);
497+ CreateVariadicComparator (topk, skip_iota ));
498+ return DecomposeTopK (topk, comparator, skip_iota );
475499 }
476500
477501 private:
478- bool HasSingleUserReadingOnlyTheValueOutput (HloInstruction* inst) {
479- return inst->user_count () == 1 && inst->users ().front ()->tuple_index () == 0 ;
502+ bool IndicesAreUnused (HloInstruction* inst) {
503+ for (auto user : inst->users ()) {
504+ if (user->opcode () != HloOpcode::kGetTupleElement ||
505+ user->tuple_index () != 0 ) {
506+ return false ;
507+ }
508+ }
509+ return true ;
480510 }
481511
482- absl::StatusOr<HloComputation*> CreateVariadicComparator (
483- HloInstruction* inst ) {
512+ absl::StatusOr<HloComputation*> CreateVariadicComparator (HloInstruction* inst,
513+ bool skip_iota ) {
484514 HloTopKInstruction* topk = DynCast<HloTopKInstruction>(inst);
485515 XlaBuilder b (absl::StrCat (" comparator_" , topk->name ()));
486516 std::vector<PrimitiveType> ptypes = {
487517 topk->operand (0 )->shape ().element_type ()};
488518
489- if (!HasSingleUserReadingOnlyTheValueOutput (inst) ) {
519+ if (!skip_iota ) {
490520 ptypes.emplace_back (PrimitiveType::S32);
491521 }
492522
@@ -500,7 +530,8 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor {
500530 }
501531
502532 absl::Status DecomposeTopK (HloInstruction* call,
503- HloComputation* variadic_comparator) {
533+ HloComputation* variadic_comparator,
534+ bool skip_iota) {
504535 HloInstruction* input = call->mutable_operand (0 );
505536 Shape iota_shape = input->shape ();
506537 iota_shape.set_element_type (S32);
@@ -509,17 +540,22 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor {
509540 std::vector<int64_t > ones (iota_shape.dimensions ().size (), 1 );
510541 CHECK_NE (variadic_comparator, nullptr );
511542 // If only the topk values are necessary, skip the iota.
512- if (HasSingleUserReadingOnlyTheValueOutput (call) &&
513- variadic_comparator->num_parameters () == 2 ) {
543+ if (skip_iota) {
544+ CHECK_EQ ( variadic_comparator->num_parameters (), 2 );
514545 HloInstruction* sort = call->AddInstruction (HloInstruction::CreateSort (
515546 input->shape (), sort_dimension, {input}, variadic_comparator,
516547 /* is_stable=*/ true ));
517- TF_RETURN_IF_ERROR (ReplaceInstruction (
518- call->users ().front (),
519- call->AddInstruction (HloInstruction::CreateSlice (
520- call->shape ().tuple_shapes (0 ), sort, zeroes,
521- call->shape ().tuple_shapes (0 ).dimensions (), ones))));
548+ HloInstruction* slice = call->AddInstruction (HloInstruction::CreateSlice (
549+ call->shape ().tuple_shapes (0 ), sort, zeroes,
550+ call->shape ().tuple_shapes (0 ).dimensions (), ones));
551+ std::vector<HloInstruction*> users = call->users ();
552+ for (auto user : users) {
553+ CHECK_EQ (user->opcode (), HloOpcode::kGetTupleElement );
554+ CHECK_EQ (user->tuple_index (), 0 );
555+ TF_RETURN_IF_ERROR (ReplaceInstruction (user, slice));
556+ }
522557 } else {
558+ CHECK_EQ (variadic_comparator->num_parameters (), 4 );
523559 HloInstruction* iota = call->AddInstruction (HloInstruction::CreateIota (
524560 iota_shape, iota_shape.dimensions ().size () - 1 ));
525561 HloInstruction* sort = call->AddInstruction (HloInstruction::CreateSort (
0 commit comments