@@ -687,6 +687,10 @@ def llama4_rope_with_position_map( # pylint: disable=too-many-arguments
687687 rotary_dim = head_dim
688688 scale = tir .const (scale , "float32" )
689689 is_longrope_scaling = rope_scaling .get ("rope_type" ) == "longrope"
690+ if is_longrope_scaling and "original_max_position_embeddings" in rope_scaling :
691+ original_max_position_embeddings = rope_scaling ["original_max_position_embeddings" ]
692+ else :
693+ original_max_position_embeddings = 0
690694
691695 def _rope ( # pylint: disable=too-many-arguments
692696 x : T .Buffer ,
@@ -770,7 +774,7 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
770774 var_q : T .handle ,
771775 var_k : T .handle ,
772776 var_v : T .handle ,
773- ext_factors : T .Buffer ((rotary_dim // 2 ,), "float32" ), # type: ignore
777+ ext_factors : T .Buffer ((rotary_dim ,), "float32" ), # type: ignore
774778 ):
775779 T .func_attr (
776780 {
@@ -787,37 +791,76 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
787791 position_map = T .match_buffer (
788792 var_position_map , (seq_len ,), "int32" , elem_offset = position_map_elem_offset
789793 )
790- for iters in T .grid (seq_len , fused_heads , head_dim ):
791- with T .block ("llama_fused_rope" ):
792- s , h , d = T .axis .remap ("SSS" , iters )
793- if h < num_q_heads :
794- q [s , h , d ] = T .if_then_else (
795- d < rotary_dim ,
796- _rope (
797- qkv ,
798- s ,
799- h ,
800- d ,
801- position_map [s ],
802- ext_factors if is_longrope_scaling else None ,
803- ),
804- qkv [s , h , d ],
805- )
806- elif h < num_q_heads + num_kv_heads :
807- k [s , h - num_q_heads , d ] = T .if_then_else (
808- d < rotary_dim ,
809- _rope (
810- qkv ,
811- s ,
812- h ,
813- d ,
814- position_map [s ],
815- ext_factors if is_longrope_scaling else None ,
816- ),
817- qkv [s , h , d ],
818- )
819- else :
820- v [s , h - (num_q_heads + num_kv_heads ), d ] = qkv [s , h , d ]
794+ # ext_factors packs both tables: long factors are the first half, short factors are the second half
795+ long_factors = T .Buffer ((rotary_dim // 2 ,), "float32" , data = ext_factors .data )
796+ short_factors = T .Buffer (
797+ (rotary_dim // 2 ,), "float32" , data = ext_factors .data , elem_offset = (rotary_dim // 2 )
798+ )
799+
800+ if seq_len > original_max_position_embeddings :
801+ for iters in T .grid (seq_len , fused_heads , head_dim ):
802+ with T .block ("llama_fused_rope" ):
803+ s , h , d = T .axis .remap ("SSS" , iters )
804+ if h < num_q_heads :
805+ q [s , h , d ] = T .if_then_else (
806+ d < rotary_dim ,
807+ _rope (
808+ qkv ,
809+ s ,
810+ h ,
811+ d ,
812+ position_map [s ],
813+ long_factors if is_longrope_scaling else None ,
814+ ),
815+ qkv [s , h , d ],
816+ )
817+ elif h < num_q_heads + num_kv_heads :
818+ k [s , h - num_q_heads , d ] = T .if_then_else (
819+ d < rotary_dim ,
820+ _rope (
821+ qkv ,
822+ s ,
823+ h ,
824+ d ,
825+ position_map [s ],
826+ long_factors if is_longrope_scaling else None ,
827+ ),
828+ qkv [s , h , d ],
829+ )
830+ else :
831+ v [s , h - (num_q_heads + num_kv_heads ), d ] = qkv [s , h , d ]
832+ else :
833+ for iters in T .grid (seq_len , fused_heads , head_dim ):
834+ with T .block ("llama_fused_rope" ):
835+ s , h , d = T .axis .remap ("SSS" , iters )
836+ if h < num_q_heads :
837+ q [s , h , d ] = T .if_then_else (
838+ d < rotary_dim ,
839+ _rope (
840+ qkv ,
841+ s ,
842+ h ,
843+ d ,
844+ position_map [s ],
845+ short_factors if is_longrope_scaling else None ,
846+ ),
847+ qkv [s , h , d ],
848+ )
849+ elif h < num_q_heads + num_kv_heads :
850+ k [s , h - num_q_heads , d ] = T .if_then_else (
851+ d < rotary_dim ,
852+ _rope (
853+ qkv ,
854+ s ,
855+ h ,
856+ d ,
857+ position_map [s ],
858+ short_factors if is_longrope_scaling else None ,
859+ ),
860+ qkv [s , h , d ],
861+ )
862+ else :
863+ v [s , h - (num_q_heads + num_kv_heads ), d ] = qkv [s , h , d ]
821864
822865 if is_longrope_scaling :
823866 return fused_rope_longrope_scaling
0 commit comments