@@ -452,6 +452,7 @@ def __init__(
         self.rope_fusion = neox_args.rope_fusion
         self.attention_type = neox_args.attention_config[layer_number]
         self.use_flash_attention = self.attention_type == "flash"
+        self.use_ring_attention = self.attention_type == "ring"
         self.use_triton = (
             self.use_flash_attention
             and self.pos_emb == "alibi"
@@ -460,7 +461,7 @@ def __init__(
                 >= packaging.version.Version("2.4.0.post1")
             )
         )
-        self.sparse = self.attention_type not in ("global", "flash")
+        self.sparse = self.attention_type not in ("global", "flash", "ring")

         if self.gqa:
             assert not self.sparse
@@ -489,6 +490,12 @@ def __init__(
                 self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton
                 self.flash_qkv_fn = flash_attn_func
                 self.flash_varlen_qkv_fn = flash_attn_varlen_func
+            elif self.use_ring_attention:
+                from ring_flash_attn.zigzag_ring_flash_attn import (
+                    zigzag_ring_flash_attn_func,
+                )
+
+                self.ring_attn_fn = zigzag_ring_flash_attn_func
             else:
                 self.scale_mask_softmax = FusedScaleMaskSoftmax(
                     input_in_fp16=self.fp16,
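
For readers unfamiliar with the zigzag variant imported here: `zigzag_ring_flash_attn_func` assumes every context-parallel rank already holds a zigzag shard of the sequence, i.e. the sequence is cut into `2 * world_size` chunks and rank `i` keeps chunks `i` and `2 * world_size - 1 - i`, which evens out the causal-attention workload across ranks. A minimal sketch of that layout (the helper name and example sizes are illustrative, not part of this diff):

```python
import torch


def zigzag_shard(x: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    """Return the zigzag shard of `x` along `dim` for one context-parallel rank.

    The sequence is split into 2 * world_size equal chunks; rank i keeps
    chunk i and chunk (2 * world_size - 1 - i). Pairing an early chunk with
    a late one gives every rank a similar amount of causal-attention work.
    """
    chunks = x.chunk(2 * world_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=dim)


# With world_size=4 and 16 tokens, rank 0 holds tokens [0, 1, 14, 15]
# and rank 3 holds tokens [6, 7, 8, 9].
tokens = torch.arange(16)
print([zigzag_shard(tokens, r, 4).tolist() for r in range(4)])
```
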
@@ -736,6 +743,96 @@ def flash_attention(self, query_layer, key_layer, value_layer):

         return matmul_result

+    def ring_attention(self, query_layer, key_layer, value_layer):
+        # [b, np, sq, sk]
+        output_size = (
+            query_layer.size(1),
+            query_layer.size(2),
+            query_layer.size(0),
+            key_layer.size(0),
+        )
+
+        # [sk, b, np, hn] -> [b, sk, np, hn]
+        key_layer = key_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+        value_layer = value_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+
+        # [sq, b, np, hn] -> [b, sq, np, hn]
+        query_layer = query_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[2], output_size[1], -1
+        )
+
+        # only pass in window_size or alibi_slopes kwarg
+        # if we use Sliding Window Attention / AliBi.
+        # Flash attn defaults to (-1,-1), or
+        # does not have this kwarg prior to v2.3.0
+        extra_kwargs = (
+            {"window_size": (self.sliding_window_width, -1)}
+            if self.sliding_window_width is not None
+            else {}
+        )
+        if self.pos_emb == "alibi":
+            extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to(
+                query_layer.device
+            ).to(torch.float32)
+
+        if not self.training:
+            batch_size = output_size[0]
+            max_seqlen_q = output_size[2]
+            max_seqlen_k = output_size[3]
+
+            cu_seqlens_q = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_q,
+                step=max_seqlen_q,
+                dtype=torch.int32,
+                device=query_layer.device,
+            )
+
+            cu_seqlens_k = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_k,
+                step=max_seqlen_k,
+                dtype=torch.int32,
+                device=key_layer.device,
+            )
+
+            q_shape = query_layer.shape
+            k_shape = key_layer.shape
+            v_shape = value_layer.shape
+            is_causal = max_seqlen_q == max_seqlen_k
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                0.0,
+                softmax_scale=None,
+                causal=is_causal,
+                group=mpu.get_context_parallel_group(),
+                **extra_kwargs,
+            )
+            output = output.reshape(q_shape)
+        else:
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                self.dropout_p if self.training else 0.0,
+                softmax_scale=None,
+                causal=True,
+                group=mpu.get_context_parallel_group(),
+                **extra_kwargs,
+            )
+
+        matmul_result = output
+        # [b, sq, np, hn] -> [b, np, sq, hn]
+        matmul_result = matmul_result.transpose(1, 2)
+
+        return matmul_result
+
     def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
         # TODO: sparse attn dropout?
         # TODO: pad to block size
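
The method above delegates the actual ring schedule to `ring_flash_attn`, but the reason a block-at-a-time pass over K/V can reproduce exact softmax attention is worth spelling out: each rank keeps a running output plus a running log-sum-exp per query and rescales both as new K/V blocks arrive around the ring. A self-contained, single-process sketch of that accumulation (plain PyTorch, two K/V blocks standing in for two ranks; not the library's implementation):

```python
import torch

torch.manual_seed(0)
sq, sk, d = 4, 8, 16
q = torch.randn(sq, d)
k = torch.randn(sk, d)
v = torch.randn(sk, d)
scale = d ** -0.5

# Reference: ordinary softmax attention over the full K/V.
ref = torch.softmax(q @ k.T * scale, dim=-1) @ v

# Ring-style accumulation: process K/V in two blocks, keeping a running
# output and a running log-sum-exp per query row.
out = torch.zeros(sq, d)
lse = torch.full((sq, 1), float("-inf"))
for k_blk, v_blk in zip(k.chunk(2), v.chunk(2)):
    scores = q @ k_blk.T * scale                       # [sq, sk_blk]
    blk_lse = torch.logsumexp(scores, dim=-1, keepdim=True)
    blk_out = torch.softmax(scores, dim=-1) @ v_blk
    new_lse = torch.logaddexp(lse, blk_lse)
    # Rescale the running and block outputs onto the new common denominator.
    out = out * (lse - new_lse).exp() + blk_out * (blk_lse - new_lse).exp()
    lse = new_lse

assert torch.allclose(out, ref, atol=1e-5)
```
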
@@ -831,7 +928,7 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
         value_layer = value_layer.view(*new_kv_shape)

         # if not using Flash attention, we repeat K/V heads to match Q head counts
-        if not self.use_flash_attention:
+        if not (self.use_flash_attention or self.use_ring_attention):
             key_layer = torch.repeat_interleave(
                 key_layer,
                 repeats=int(
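
The widened guard above reflects that both fused kernels handle grouped-query attention natively: `flash_attn_func` and `zigzag_ring_flash_attn_func` accept K/V with fewer heads than Q and broadcast them inside the kernel, whereas the eager matmul path needs matching head counts. A shape-only illustration (head counts and sizes are made up):

```python
import torch

b, sq, sk, hn = 2, 1024, 1024, 128
num_q_heads, num_kv_heads = 32, 8       # illustrative per-partition head counts

q = torch.randn(b, sq, num_q_heads, hn)
k = torch.randn(b, sk, num_kv_heads, hn)   # fused flash/ring kernels take this as-is

# The fallback path has no in-kernel broadcast, so K/V heads are expanded first:
k_expanded = torch.repeat_interleave(k, repeats=num_q_heads // num_kv_heads, dim=2)
assert k_expanded.shape == (b, sk, num_q_heads, hn)
```
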
@@ -945,6 +1042,8 @@ def forward(self, hidden_states, attention_mask, layer_past=None):

         if self.use_flash_attention:
             context_layer = self.flash_attention(query_layer, key_layer, value_layer)
+        elif self.use_ring_attention:
+            context_layer = self.ring_attention(query_layer, key_layer, value_layer)
         elif not self.sparse:
             context_layer = self.attention(
                 query_layer, key_layer, value_layer, layer_past, attention_mask
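
Which branch `forward` takes is decided entirely by the per-layer attention type resolved in `__init__`, and the import in `__init__` means `ring_flash_attn` must be installed whenever a layer is configured as "ring". A sketch of how a run would opt into the new path; the `attention_config` shorthand in the comment follows the usual NeoX convention and is an assumption here, not something this diff shows:

```python
# YAML (NeoX shorthand, assumed): "attention_config": [[["ring"], "all"]]
# which expands to one entry per layer:
num_layers = 12
attention_config = ["ring"] * num_layers

# Per-layer flags, mirroring __init__ above:
layer_number = 0
attention_type = attention_config[layer_number]
use_flash_attention = attention_type == "flash"
use_ring_attention = attention_type == "ring"
sparse = attention_type not in ("global", "flash", "ring")

# forward() then dispatches: flash -> flash_attention, ring -> ring_attention,
# other dense types -> the eager softmax path, otherwise the sparse kernels.
assert use_ring_attention and not sparse
```
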