Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions xla/service/spmd/spmd_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -746,12 +746,6 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
// Common handle for HLOs that runs on a single device.
absl::Status HandleSingleDevice(const HloInstruction* hlo);

// CustomCall handlers per call target.
absl::Status HandleCustomCallTopK(HloInstruction* hlo);
// Convenient custom ops defined by the partitioner itself.
absl::Status HandleCustomCallSPMDInternal_RotateRight(HloInstruction* hlo);
absl::Status HandleCustomCallSPMDInternal_MultiRotate(HloInstruction* hlo);

virtual std::unique_ptr<SpmdPartitioningVisitor> Clone() const;

// Returns the PartitionedHlo that corresponds to the original hlo.
Expand Down Expand Up @@ -906,6 +900,13 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
absl::Status HandleDUSAllPartitionedSliceDimsHaveConstantIndices(
HloInstruction* hlo, const HloInstruction* input_tensor,
const HloInstruction* update_tensor);

// Handlers for specific custom call targets.
// go/keep-sorted start
absl::Status HandleCustomCallSPMDInternal_MultiRotate(HloInstruction* hlo);
absl::Status HandleCustomCallSPMDInternal_RotateRight(HloInstruction* hlo);
absl::Status HandleCustomCallTopK(HloInstruction* hlo);
// go/keep-sorted end
};

} // namespace spmd
Expand Down
Loading