Skip to content

Commit 69c2319

Browse files
wsmoses authored and Google-ML-Automation committed
[SPMD] Add Wrap custom call
PiperOrigin-RevId: 873355682
1 parent 964a0a4 commit 69c2319

File tree

2 files changed

+217
-0
lines changed

2 files changed

+217
-0
lines changed

xla/service/spmd/custom_call_handler.cc

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ ParseOpaqueAsAttributes(const HloInstruction* hlo) {
8484

8585
// Custom-call target name of the partitioner-internal RotateRight op.
// NOTE(review): presumably matched in HandleCustomCall like the other
// _SPMDInternalOp_* targets; that dispatch is not visible in this excerpt.
constexpr char kSPMDOpRotateRight[] = "_SPMDInternalOp_RotateRight";
8686
// Custom-call target name that SpmdPartitioningVisitor::HandleCustomCall
// routes to HandleCustomCallSPMDInternal_MultiRotate.
constexpr char kSPMDOpMultiRotate[] = "_SPMDInternalOp_MultiRotate";
87+
// Custom-call target name that SpmdPartitioningVisitor::HandleCustomCall
// routes to HandleCustomCallSPMDInternal_Wrap.
constexpr char kSPMDOpWrap[] = "_SPMDInternalOp_Wrap";
8788

8889
} // namespace
8990

@@ -579,6 +580,218 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_MultiRotate(
579580
return absl::OkStatus();
580581
}
581582

583+
// Partitions a `_SPMDInternalOp_Wrap` custom call.
//
// Reads the integer attributes `dimension`, `left_amount` and `right_amount`
// from the op's opaque string, reshards operand 0 to the sharding of `hlo`,
// and then builds the per-partition result in three steps:
//   1. Optionally zero-pads the local shard at the high end of `dimension`.
//   2. Exchanges halo slices with neighboring shards along `dimension` using
//      collective-permutes that wrap around the participating shards.
//   3. Concatenates [left halo, local shard, right halo] and dynamic-slices
//      the output shard out of that concatenation.
//
// Returns an error status if an attribute is missing or out of range, if
// operand 0 is not an array, or if the sharding of `hlo` is replicated or
// tile-maximal. On success the result is registered via SetPartitionedHlo.
absl::Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_Wrap(
    HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(auto attrs, ParseOpaqueAsAttributes(hlo));

  auto dim_it = attrs.find("dimension");
  TF_RET_CHECK(dim_it != attrs.end())
      << "No dimension attribute in SPMD wrap op";
  const int64_t dim = dim_it->second;

  auto left_amount_it = attrs.find("left_amount");
  TF_RET_CHECK(left_amount_it != attrs.end())
      << "No left_amount attribute in SPMD wrap op";
  const int64_t left_amount = left_amount_it->second;

  auto right_amount_it = attrs.find("right_amount");
  TF_RET_CHECK(right_amount_it != attrs.end())
      << "No right_amount attribute in SPMD wrap op";
  const int64_t right_amount = right_amount_it->second;

  // The operand is used as an array below (it is padded, sliced and
  // concatenated), so take its shape directly rather than a tuple element.
  const Shape& pre_wrap_shape = hlo->operand(0)->shape();
  TF_RET_CHECK(pre_wrap_shape.IsArray())
      << "Wrap op requires an array operand.";
  const int64_t rank = pre_wrap_shape.dimensions_size();

  // `dim` indexes shapes, tile indices and start-index vectors below, and the
  // amounts become slice extents, so reject values that would be invalid.
  TF_RET_CHECK(dim >= 0 && dim < rank)
      << "Wrap dimension " << dim << " is out of range for rank " << rank;
  TF_RET_CHECK(left_amount >= 0 && right_amount >= 0)
      << "Wrap amounts must be non-negative, got left_amount=" << left_amount
      << ", right_amount=" << right_amount;

  PartitionedHlo input = GetPartitionedHlo(hlo->operand(0));
  HloSharding element_sharding = hlo->sharding();

  TF_RET_CHECK(
      !(element_sharding.IsReplicated() || element_sharding.IsTileMaximal()))
      << "Wrap op requires sharding along the wrap dimension.";

  input = input.Reshard(element_sharding);

  const int64_t full_pre_wrap_size = pre_wrap_shape.dimensions(dim);
  const int64_t shard_size = input.hlo()->shape().dimensions(dim);

  // Number of shards needed to cover the unpartitioned extent along `dim`,
  // and how far those shards overshoot it.
  const int64_t participating_shards =
      CeilOfRatio(full_pre_wrap_size, shard_size);
  const int64_t right_padding =
      participating_shards * shard_size - full_pre_wrap_size;

  HloInstruction* local_input = input.hlo();
  HloInstruction* padded_local_input = local_input;

  if (right_padding > 0) {
    // Extend the local shard with `right_padding` zeros at the high end.
    auto padding_config = MakeNoPaddingConfig(rank);
    padding_config.mutable_dimensions(dim)->set_edge_padding_high(
        right_padding);
    auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(pre_wrap_shape.element_type())));
    Shape padded_shape = local_input->shape();
    padded_shape.set_dimensions(dim,
                                padded_shape.dimensions(dim) + right_padding);
    padded_local_input = b_.AddInstruction(HloInstruction::CreatePad(
        padded_shape, local_input, zero, padding_config));
  }

  // Builds collective-permute (source, target) device pairs that send data
  // `step` tiles forward along `dim`, wrapping around within the participating
  // shards. Tiles at or beyond `participating_shards` along `dim` are skipped.
  auto make_ring_pairs = [&](int64_t step) {
    std::vector<std::pair<int64_t, int64_t>> pairs;
    element_sharding.tile_assignment().Each(
        [&](absl::Span<const int64_t> indices, int64_t device) {
          if (indices[dim] >= participating_shards) {
            return;
          }
          std::vector<int64_t> dst_idx(indices.begin(), indices.end());
          dst_idx[dim] = (dst_idx[dim] + step) % participating_shards;
          pairs.emplace_back(device,
                             element_sharding.tile_assignment()(dst_idx));
        });
    return pairs;
  };

  HloInstruction* left_halo = nullptr;
  if (left_amount > 0) {
    // Each shard sends the first `left_amount` elements of its (padded) local
    // data to the next shard along `dim`.
    Shape slice_shape = padded_local_input->shape();
    slice_shape.set_dimensions(dim, left_amount);
    std::vector<int64_t> slice_starts(rank, 0);
    std::vector<int64_t> slice_limits(
        padded_local_input->shape().dimensions().begin(),
        padded_local_input->shape().dimensions().end());
    slice_limits[dim] = left_amount;
    HloInstruction* slice_to_send =
        b_.AddInstruction(HloInstruction::CreateSlice(
            slice_shape, padded_local_input, slice_starts, slice_limits,
            std::vector<int64_t>(rank, 1)));

    left_halo = collective_ops_creator_.create_collective_permute(
        &b_, slice_to_send, make_ring_pairs(/*step=*/1), NewChannel());
  }

  HloInstruction* shard_offset = MakePartitionOffsets(
      pre_wrap_shape, element_sharding, MakePartitioningState().partition_id,
      &b_, {dim})[dim];

  HloInstruction* zero_offset = b_.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));

  HloInstruction* right_halo = nullptr;
  if (right_amount > 0) {
    HloInstruction* base_start =
        b_.AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<int32_t>(shard_size - right_amount)));
    HloInstruction* padding_val =
        b_.AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<int32_t>(right_padding)));

    HloInstruction* total_mesh_size =
        b_.AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<int32_t>(participating_shards * shard_size)));

    // shard_offset tells us the current start index.
    // The last shard is the one where shard_offset + shard_size ==
    // total_mesh_size.
    HloInstruction* shard_size_inst =
        b_.AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<int32_t>(shard_size)));
    HloInstruction* current_shard_end = b_.AddInstruction(
        HloInstruction::CreateBinary(shard_offset->shape(), HloOpcode::kAdd,
                                     shard_offset, shard_size_inst));

    HloInstruction* is_last = b_.AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::ChangeElementType(shard_offset->shape(), PRED),
        current_shard_end, total_mesh_size, Comparison::Direction::kEq));

    // On the last shard the slice start is moved back by `right_padding`;
    // every other shard uses `shard_size - right_amount` unchanged.
    HloInstruction* offset = b_.AddInstruction(
        HloInstruction::CreateTernary(shard_offset->shape(), HloOpcode::kSelect,
                                      is_last, padding_val, zero_offset));

    HloInstruction* start_idx = b_.AddInstruction(HloInstruction::CreateBinary(
        base_start->shape(), HloOpcode::kSubtract, base_start, offset));

    Shape dynamic_slice_shape = padded_local_input->shape();
    dynamic_slice_shape.set_dimensions(dim, right_amount);

    std::vector<HloInstruction*> start_indices(rank, zero_offset);
    start_indices[dim] = start_idx;

    std::vector<int64_t> slice_sizes(
        padded_local_input->shape().dimensions().begin(),
        padded_local_input->shape().dimensions().end());
    slice_sizes[dim] = right_amount;

    HloInstruction* slice_to_send =
        b_.AddInstruction(HloInstruction::CreateDynamicSlice(
            dynamic_slice_shape, padded_local_input, start_indices,
            slice_sizes));

    // Each shard sends this slice to the previous shard along `dim`
    // (a step of participating_shards - 1 modulo participating_shards).
    right_halo = collective_ops_creator_.create_collective_permute(
        &b_, slice_to_send, make_ring_pairs(/*step=*/participating_shards - 1),
        NewChannel());
  }

  // Assemble [left halo, padded local shard, right halo]; the halos are only
  // present when their amount is positive.
  std::vector<HloInstruction*> concat_ops;
  if (left_halo != nullptr) {
    concat_ops.push_back(left_halo);
  }
  concat_ops.push_back(padded_local_input);
  if (right_halo != nullptr) {
    concat_ops.push_back(right_halo);
  }

  HloInstruction* super_shard = padded_local_input;
  if (concat_ops.size() > 1) {
    Shape concat_shape = local_input->shape();
    concat_shape.set_dimensions(
        dim, shard_size + right_padding + left_amount + right_amount);
    super_shard = b_.AddInstruction(
        HloInstruction::CreateConcatenate(concat_shape, concat_ops, dim));
  }

  const int64_t post_wrap_shard_size =
      CeilOfRatio(hlo->shape().dimensions(dim), participating_shards);

  Shape result_shape = super_shard->shape();
  result_shape.set_dimensions(dim, post_wrap_shard_size);

  // The output shard starts at
  // shard_offset * (post_wrap_shard_size - shard_size) within the super shard.
  HloInstruction* shard_size_change =
      b_.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<int32_t>(post_wrap_shard_size - shard_size)));

  HloInstruction* start = b_.AddInstruction(
      HloInstruction::CreateBinary(shard_offset->shape(), HloOpcode::kMultiply,
                                   shard_offset, shard_size_change));

  std::vector<HloInstruction*> start_indices(rank, zero_offset);
  start_indices[dim] = start;

  std::vector<int64_t> slice_sizes(super_shard->shape().dimensions().begin(),
                                   super_shard->shape().dimensions().end());
  slice_sizes[dim] = post_wrap_shard_size;

  HloInstruction* result = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
      result_shape, super_shard, start_indices, slice_sizes));

  SetPartitionedHlo(hlo, [&] { return result; });
  return absl::OkStatus();
}
794+
582795
std::unique_ptr<HloInstruction> CreateCustomCallSPMDInternal_RotateRight(
583796
HloInstruction* input, int64_t dim, int64_t amount) {
584797
std::string opaque = absl::StrCat("dimension=", dim, ",amount=", amount);
@@ -624,6 +837,9 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
624837
if (hlo->custom_call_target() == kSPMDOpMultiRotate) {
625838
return HandleCustomCallSPMDInternal_MultiRotate(hlo);
626839
}
840+
if (hlo->custom_call_target() == kSPMDOpWrap) {
841+
return HandleCustomCallSPMDInternal_Wrap(hlo);
842+
}
627843

628844
if (hlo->sharding().HasUniqueDevice()) {
629845
return HandleSingleDevice(hlo);

xla/service/spmd/spmd_partitioner.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
751751
// Convenient custom ops defined by the partitioner itself.
752752
absl::Status HandleCustomCallSPMDInternal_RotateRight(HloInstruction* hlo);
753753
absl::Status HandleCustomCallSPMDInternal_MultiRotate(HloInstruction* hlo);
754+
// Partitions the `_SPMDInternalOp_Wrap` custom call; reads the `dimension`,
// `left_amount` and `right_amount` attributes from the op's opaque string.
absl::Status HandleCustomCallSPMDInternal_Wrap(HloInstruction* hlo);
754755

755756
virtual std::unique_ptr<SpmdPartitioningVisitor> Clone() const;
756757

0 commit comments

Comments
 (0)