Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions xla/service/spmd/custom_call_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ limitations under the License.
#include "xla/hlo/utils/hlo_sharding_util.h"
#include "xla/literal_util.h"
#include "xla/service/custom_call_sharding_helper.h"
#include "xla/service/dot_as_convolution_util.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/memory_annotations.h"
#include "xla/service/spmd/spmd_partitioner.h"
#include "xla/service/spmd/spmd_partitioner_util.h"
Expand Down Expand Up @@ -99,11 +99,11 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallTopK(
sharding.ReplicateOnLastTileDim()) {
return DefaultAction(hlo);
}
TF_RET_CHECK(sharding.IsTiled());

const int64_t batch_dim = 0;
const int64_t sort_dim = 1;

CHECK(sharding.IsTiled());
const int64_t shard_count = sharding.dimension(sort_dim);
const int64_t batch_dim_partition = sharding.dimension(batch_dim);

Expand Down Expand Up @@ -397,7 +397,7 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_MultiRotate(
<< "No right_amount attribute in SPMD multi rotate op";
int64_t right_amount = right_amount_it->second;

int32_t totalResults = left_amount + right_amount + 1;
int64_t total_results = left_amount + right_amount + 1;

PartitionedHlo input = GetPartitionedHlo(hlo->operand(0));
HloSharding element_sharding = hlo->sharding().IsTuple()
Expand Down Expand Up @@ -434,15 +434,17 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_MultiRotate(
HloInstruction* padded_local_input = local_input;

if (right_padding > 0) {
auto paddingConfig = MakeNoPaddingConfig(full_shape.dimensions_size());
paddingConfig.mutable_dimensions(dim)->set_edge_padding_high(right_padding);
PaddingConfig padding_config =
MakeNoPaddingConfig(full_shape.dimensions_size());
padding_config.mutable_dimensions(dim)->set_edge_padding_high(
right_padding);
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(full_shape.element_type())));
Shape padded_shape = local_input->shape();
padded_shape.set_dimensions(dim,
padded_shape.dimensions(dim) + right_padding);
padded_local_input = b_.AddInstruction(HloInstruction::CreatePad(
padded_shape, local_input, zero, paddingConfig));
padded_shape, local_input, zero, padding_config));
}

HloInstruction* left_halo = nullptr;
Expand Down Expand Up @@ -564,7 +566,7 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_MultiRotate(

int64_t R = right_halo ? right_amount : 0;

for (int i = 0; i < totalResults; ++i) {
for (int64_t i = 0; i < total_results; ++i) {
int64_t amount = left_amount - i;
int64_t slice_start = amount + R;

Expand Down Expand Up @@ -598,8 +600,9 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
CreateR0WithType(hlo->shape().element_type(), 0, &b_));
}
auto input = input_partitioned.hlo();
CHECK(hlo->sharding().IsManual() || hlo->sharding().IsManualSubgroup());
CHECK(ShapeUtil::Compatible(
TF_RET_CHECK(hlo->sharding().IsManual() ||
hlo->sharding().IsManualSubgroup());
TF_RET_CHECK(ShapeUtil::Compatible(
input->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
auto copy = b_.AddInstruction(
HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
Expand All @@ -609,10 +612,11 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
if (hlo->custom_call_target() == "SPMDShardToFullShape") {
// This op switches from manual partitioning to auto partitioning.
auto input = GetPartitionedHlo(hlo->operand(0)).hlo();
CHECK(input->sharding().IsManual() || input->sharding().IsManualSubgroup());
TF_RET_CHECK(input->sharding().IsManual() ||
input->sharding().IsManualSubgroup());
auto copy = b_.AddInstruction(
HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
CHECK(ShapeUtil::Compatible(
TF_RET_CHECK(ShapeUtil::Compatible(
copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
SetPartitionedHlo(hlo, copy);
return absl::OkStatus();
Expand Down
13 changes: 7 additions & 6 deletions xla/service/spmd/spmd_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -746,12 +746,6 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
// Common handle for HLOs that runs on a single device.
absl::Status HandleSingleDevice(const HloInstruction* hlo);

// CustomCall handlers per call target.
absl::Status HandleCustomCallTopK(HloInstruction* hlo);
// Convenient custom ops defined by the partitioner itself.
absl::Status HandleCustomCallSPMDInternal_RotateRight(HloInstruction* hlo);
absl::Status HandleCustomCallSPMDInternal_MultiRotate(HloInstruction* hlo);

virtual std::unique_ptr<SpmdPartitioningVisitor> Clone() const;

// Returns the PartitionedHlo that corresponds to the original hlo.
Expand Down Expand Up @@ -906,6 +900,13 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
absl::Status HandleDUSAllPartitionedSliceDimsHaveConstantIndices(
HloInstruction* hlo, const HloInstruction* input_tensor,
const HloInstruction* update_tensor);

// Handlers for specific custom call targets.
// go/keep-sorted start
absl::Status HandleCustomCallSPMDInternal_MultiRotate(HloInstruction* hlo);
absl::Status HandleCustomCallSPMDInternal_RotateRight(HloInstruction* hlo);
absl::Status HandleCustomCallTopK(HloInstruction* hlo);
// go/keep-sorted end
};

} // namespace spmd
Expand Down
Loading