// Custom-call target names for SPMD-internal ops handled by the partitioner.
constexpr char kSPMDOpRotateRight[] = "_SPMDInternalOp_RotateRight";
constexpr char kSPMDOpMultiRotate[] = "_SPMDInternalOp_MultiRotate";
constexpr char kSPMDOpWrap[] = "_SPMDInternalOp_Wrap";
8788
8889} // namespace
8990
@@ -579,6 +580,218 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_MultiRotate(
579580 return absl::OkStatus ();
580581}
581582
583+ absl::Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_Wrap (
584+ HloInstruction* hlo) {
585+ TF_ASSIGN_OR_RETURN (auto attrs, ParseOpaqueAsAttributes (hlo));
586+ auto dim_it = attrs.find (" dimension" );
587+ TF_RET_CHECK (dim_it != attrs.end ())
588+ << " No dimension attribute in SPMD multi rotate op" ;
589+ int64_t dim = dim_it->second ;
590+
591+ auto left_amount_it = attrs.find (" left_amount" );
592+ TF_RET_CHECK (left_amount_it != attrs.end ())
593+ << " No left_amount attribute in SPMD multi rotate op" ;
594+ int64_t left_amount = left_amount_it->second ;
595+
596+ auto right_amount_it = attrs.find (" right_amount" );
597+ TF_RET_CHECK (right_amount_it != attrs.end ())
598+ << " No right_amount attribute in SPMD multi rotate op" ;
599+ int64_t right_amount = right_amount_it->second ;
600+
601+ int32_t totalResults = left_amount + right_amount + 1 ;
602+
603+ PartitionedHlo input = GetPartitionedHlo (hlo->operand (0 ));
604+ HloSharding element_sharding = hlo->sharding ();
605+
606+ TF_RET_CHECK (
607+ !(element_sharding.IsReplicated () || element_sharding.IsTileMaximal ()))
608+ << " MultiRotate op requires sharding along the rotate dimension." ;
609+
610+ input = input.Reshard (element_sharding);
611+
612+ const Shape& pre_wrap_shape = hlo->operand (0 )->shape ().tuple_shapes (0 );
613+ const int64_t full_pre_wrap_size = pre_wrap_shape.dimensions (dim);
614+ const int64_t shard_size = input.hlo ()->shape ().dimensions (dim);
615+
616+ const int64_t participating_shards =
617+ CeilOfRatio (full_pre_wrap_size, shard_size);
618+ const int64_t right_padding =
619+ participating_shards * shard_size - full_pre_wrap_size;
620+
621+ HloInstruction* local_input = input.hlo ();
622+ HloInstruction* padded_local_input = local_input;
623+
624+ if (right_padding > 0 ) {
625+ auto paddingConfig = MakeNoPaddingConfig (pre_wrap_shape.dimensions_size ());
626+ paddingConfig.mutable_dimensions (dim)->set_edge_padding_high (right_padding);
627+ auto zero = b_.AddInstruction (HloInstruction::CreateConstant (
628+ LiteralUtil::Zero (pre_wrap_shape.element_type ())));
629+ Shape padded_shape = local_input->shape ();
630+ padded_shape.set_dimensions (dim,
631+ padded_shape.dimensions (dim) + right_padding);
632+ padded_local_input = b_.AddInstruction (HloInstruction::CreatePad (
633+ padded_shape, local_input, zero, paddingConfig));
634+ }
635+
636+ HloInstruction* left_halo = nullptr ;
637+ if (left_amount > 0 ) {
638+ std::vector<std::pair<int64_t , int64_t >> pairs;
639+ element_sharding.tile_assignment ().Each (
640+ [&](absl::Span<const int64_t > indices, int64_t device) {
641+ if (indices[dim] >= participating_shards) {
642+ return ;
643+ }
644+ std::vector<int64_t > dst_idx (indices.begin (), indices.end ());
645+ dst_idx[dim] += 1 ;
646+ dst_idx[dim] %= participating_shards;
647+ pairs.emplace_back (device,
648+ element_sharding.tile_assignment ()(dst_idx));
649+ });
650+
651+ Shape slice_shape = padded_local_input->shape ();
652+ slice_shape.set_dimensions (dim, left_amount);
653+ std::vector<int64_t > slice_starts (full_pre_wrap_size.dimensions_size (), 0 );
654+ slice_starts[dim] = 0 ;
655+ std::vector<int64_t > slice_limits (
656+ padded_local_input->shape ().dimensions ().begin (),
657+ padded_local_input->shape ().dimensions ().end ());
658+ slice_limits[dim] = left_amount;
659+ HloInstruction* slice_to_send =
660+ b_.AddInstruction (HloInstruction::CreateSlice (
661+ slice_shape, padded_local_input, slice_starts, slice_limits,
662+ std::vector<int64_t >(full_pre_wrap_size.dimensions_size (), 1 )));
663+
664+ left_halo = collective_ops_creator_.create_collective_permute (
665+ &b_, slice_to_send, pairs, NewChannel ());
666+ }
667+
668+ HloInstruction* shard_offset = MakePartitionOffsets (
669+ pre_wrap_shape, element_sharding, MakePartitioningState ().partition_id ,
670+ &b_, {dim})[dim];
671+
672+ HloInstruction* zero_offset = b_.AddInstruction (
673+ HloInstruction::CreateConstant (LiteralUtil::CreateR0<int32_t >(0 )));
674+
675+ HloInstruction* right_halo = nullptr ;
676+ if (right_amount > 0 ) {
677+ std::vector<std::pair<int64_t , int64_t >> pairs;
678+ element_sharding.tile_assignment ().Each (
679+ [&](absl::Span<const int64_t > indices, int64_t device) {
680+ if (indices[dim] >= participating_shards) {
681+ return ;
682+ }
683+ std::vector<int64_t > dst_idx (indices.begin (), indices.end ());
684+ dst_idx[dim] += participating_shards - 1 ;
685+ dst_idx[dim] %= participating_shards;
686+ pairs.emplace_back (device,
687+ element_sharding.tile_assignment ()(dst_idx));
688+ });
689+
690+ HloInstruction* base_start =
691+ b_.AddInstruction (HloInstruction::CreateConstant (
692+ LiteralUtil::CreateR0<int32_t >(shard_size - right_amount)));
693+ HloInstruction* padding_val =
694+ b_.AddInstruction (HloInstruction::CreateConstant (
695+ LiteralUtil::CreateR0<int32_t >(right_padding)));
696+
697+ HloInstruction* total_mesh_size =
698+ b_.AddInstruction (HloInstruction::CreateConstant (
699+ LiteralUtil::CreateR0<int32_t >(participating_shards * shard_size)));
700+
701+ // shard_offset tells us the current start index.
702+ // The last shard is the one where shard_offset + shard_size ==
703+ // total_mesh_size.
704+ HloInstruction* shard_size_inst =
705+ b_.AddInstruction (HloInstruction::CreateConstant (
706+ LiteralUtil::CreateR0<int32_t >(shard_size)));
707+ HloInstruction* current_shard_end = b_.AddInstruction (
708+ HloInstruction::CreateBinary (shard_offset->shape (), HloOpcode::kAdd ,
709+ shard_offset, shard_size_inst));
710+
711+ HloInstruction* is_last = b_.AddInstruction (HloInstruction::CreateCompare (
712+ ShapeUtil::ChangeElementType (shard_offset->shape (), PRED),
713+ current_shard_end, total_mesh_size, Comparison::Direction::kEq ));
714+
715+ HloInstruction* offset = b_.AddInstruction (
716+ HloInstruction::CreateTernary (shard_offset->shape (), HloOpcode::kSelect ,
717+ is_last, padding_val, zero_offset));
718+
719+ HloInstruction* start_idx = b_.AddInstruction (HloInstruction::CreateBinary (
720+ base_start->shape (), HloOpcode::kSubtract , base_start, offset));
721+
722+ Shape dynamic_slice_shape = padded_local_input->shape ();
723+ dynamic_slice_shape.set_dimensions (dim, right_amount);
724+
725+ std::vector<HloInstruction*> start_indices (
726+ pre_wrap_shape.dimensions_size ());
727+ for (int i = 0 ; i < pre_wrap_shape.dimensions_size (); i++) {
728+ start_indices[i] = zero_offset;
729+ }
730+ start_indices[dim] = start_idx;
731+
732+ std::vector<int64_t > slice_sizes (
733+ padded_local_input->shape ().dimensions ().begin (),
734+ padded_local_input->shape ().dimensions ().end ());
735+ slice_sizes[dim] = right_amount;
736+
737+ HloInstruction* slice_to_send =
738+ b_.AddInstruction (HloInstruction::CreateDynamicSlice (
739+ dynamic_slice_shape, padded_local_input, start_indices,
740+ slice_sizes));
741+
742+ right_halo = collective_ops_creator_.create_collective_permute (
743+ &b_, slice_to_send, pairs, NewChannel ());
744+ }
745+
746+ std::vector<HloInstruction*> concat_ops;
747+ if (left_halo) {
748+ concat_ops.push_back (left_halo);
749+ }
750+ concat_ops.push_back (padded_local_input);
751+ if (right_halo) {
752+ concat_ops.push_back (right_halo);
753+ }
754+
755+ HloInstruction* super_shard = padded_local_input;
756+ if (concat_ops.size () > 1 ) {
757+ Shape concat_shape = local_input->shape ();
758+ concat_shape.set_dimensions (
759+ dim, shard_size + right_padding + left_amount + right_amount);
760+ super_shard = b_.AddInstruction (
761+ HloInstruction::CreateConcatenate (concat_shape, concat_ops, dim));
762+ }
763+
764+ int64_t post_wrap_shard_size =
765+ CeilOfRatio (hlo->shape ().dimensions (dim), participating_shards);
766+
767+ Shape slice_shape = super_shard->shape ();
768+ slice_shape.set_dimensions (dim, post_wrap_shard_size);
769+
770+ HloInstruction* shard_size_change =
771+ b_.AddInstruction (HloInstruction::CreateConstant (
772+ LiteralUtil::CreateR0<int32_t >(post_wrap_shard_size - shard_size)));
773+
774+ HloInstruction* start = b_.AddInstruction (
775+ HloInstruction::CreateBinary (shard_offset->shape (), HloOpcode::kMultiply ,
776+ shard_offset, shard_size_change));
777+
778+ std::vector<HloInstruction*> start_indices (pre_wrap_shape.dimensions_size ());
779+ for (int i = 0 ; i < pre_wrap_shape.dimensions_size (); i++) {
780+ start_indices[i] = zero_offset;
781+ }
782+ start_indices[dim] = start;
783+
784+ std::vector<int64_t > slice_sizes (super_shard->shape ().dimensions ().begin (),
785+ super_shard->shape ().dimensions ().end ());
786+ slice_sizes[dim] = post_wrap_shard_size;
787+
788+ HloInstruction* result = b_.AddInstruction (HloInstruction::CreateDynamicSlice (
789+ slice_shape, super_shard, start_indices, slice_sizes));
790+
791+ SetPartitionedHlo (hlo, [&] { return result; });
792+ return absl::OkStatus ();
793+ }
794+
582795std::unique_ptr<HloInstruction> CreateCustomCallSPMDInternal_RotateRight (
583796 HloInstruction* input, int64_t dim, int64_t amount) {
584797 std::string opaque = absl::StrCat (" dimension=" , dim, " ,amount=" , amount);
@@ -624,6 +837,9 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
624837 if (hlo->custom_call_target () == kSPMDOpMultiRotate ) {
625838 return HandleCustomCallSPMDInternal_MultiRotate (hlo);
626839 }
840+ if (hlo->custom_call_target () == kSPMDOpWrap ) {
841+ return HandleCustomCallSPMDInternal_Wrap (hlo);
842+ }
627843
628844 if (hlo->sharding ().HasUniqueDevice ()) {
629845 return HandleSingleDevice (hlo);
0 commit comments