
Commit a7ab4d1

Fix llama4_rope_with_position_map to support partial rotary factor
1 parent d5d3d81

File tree: 1 file changed, +75 -32 lines


python/tvm/relax/frontend/nn/llm/position_embedding.py

Lines changed: 75 additions & 32 deletions
@@ -687,6 +687,10 @@ def llama4_rope_with_position_map(  # pylint: disable=too-many-arguments
     rotary_dim = head_dim
     scale = tir.const(scale, "float32")
     is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"
+    if is_longrope_scaling and "original_max_position_embeddings" in rope_scaling:
+        original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
+    else:
+        original_max_position_embeddings = 0

     def _rope(  # pylint: disable=too-many-arguments
         x: T.Buffer,
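The new branch reads only two keys from rope_scaling, "rope_type" and "original_max_position_embeddings", and defaults the latter to 0 when it is absent so the long/short switch added further down always takes the "long" path. As a point of reference, a config of roughly this shape would exercise it; the long_factor/short_factor entries follow the usual HF-style longrope layout and are shown only as an assumption, not something this hunk reads:

# Hypothetical HF-style longrope config; only "rope_type" and
# "original_max_position_embeddings" are consumed by the hunk above.
rope_scaling = {
    "rope_type": "longrope",
    "original_max_position_embeddings": 8192,  # example value, not from this commit
    "long_factor": [1.0] * 64,   # assumed length: rotary_dim // 2
    "short_factor": [1.0] * 64,  # assumed length: rotary_dim // 2
}
original_max_position_embeddings = rope_scaling.get("original_max_position_embeddings", 0)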
@@ -770,7 +774,7 @@ def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
         var_q: T.handle,
         var_k: T.handle,
         var_v: T.handle,
-        ext_factors: T.Buffer((rotary_dim // 2,), "float32"),  # type: ignore
+        ext_factors: T.Buffer((rotary_dim,), "float32"),  # type: ignore
     ):
         T.func_attr(
             {
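The signature change widens ext_factors from rotary_dim // 2 to rotary_dim floats so a single buffer can carry both factor sets, which the next hunk then views as two halves. A minimal host-side sketch of how a caller might pack that buffer, assuming HF-style long_factor/short_factor lists of length rotary_dim // 2 (pack_ext_factors is an illustrative helper, not part of the TVM API):

import numpy as np

def pack_ext_factors(rope_scaling: dict, rotary_dim: int) -> np.ndarray:
    # Long factors occupy the first rotary_dim // 2 entries, short factors the
    # second half, matching the elem_offset-based views created in the kernel.
    long_factor = np.asarray(rope_scaling["long_factor"], dtype="float32")
    short_factor = np.asarray(rope_scaling["short_factor"], dtype="float32")
    assert long_factor.shape == (rotary_dim // 2,) == short_factor.shape
    return np.concatenate([long_factor, short_factor])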
@@ -787,37 +791,76 @@ def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
         position_map = T.match_buffer(
             var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
         )
-        for iters in T.grid(seq_len, fused_heads, head_dim):
-            with T.block("llama_fused_rope"):
-                s, h, d = T.axis.remap("SSS", iters)
-                if h < num_q_heads:
-                    q[s, h, d] = T.if_then_else(
-                        d < rotary_dim,
-                        _rope(
-                            qkv,
-                            s,
-                            h,
-                            d,
-                            position_map[s],
-                            ext_factors if is_longrope_scaling else None,
-                        ),
-                        qkv[s, h, d],
-                    )
-                elif h < num_q_heads + num_kv_heads:
-                    k[s, h - num_q_heads, d] = T.if_then_else(
-                        d < rotary_dim,
-                        _rope(
-                            qkv,
-                            s,
-                            h,
-                            d,
-                            position_map[s],
-                            ext_factors if is_longrope_scaling else None,
-                        ),
-                        qkv[s, h, d],
-                    )
-                else:
-                    v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+        # long factors is the first half, short factors is the second half
+        long_factors = T.Buffer((rotary_dim // 2,), "float32", data=ext_factors.data)
+        short_factors = T.Buffer(
+            (rotary_dim // 2,), "float32", data=ext_factors.data, elem_offset=(rotary_dim // 2)
+        )
+
+        if seq_len > original_max_position_embeddings:
+            for iters in T.grid(seq_len, fused_heads, head_dim):
+                with T.block("llama_fused_rope"):
+                    s, h, d = T.axis.remap("SSS", iters)
+                    if h < num_q_heads:
+                        q[s, h, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                long_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    elif h < num_q_heads + num_kv_heads:
+                        k[s, h - num_q_heads, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                long_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    else:
+                        v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+        else:
+            for iters in T.grid(seq_len, fused_heads, head_dim):
+                with T.block("llama_fused_rope"):
+                    s, h, d = T.axis.remap("SSS", iters)
+                    if h < num_q_heads:
+                        q[s, h, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                short_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    elif h < num_q_heads + num_kv_heads:
+                        k[s, h - num_q_heads, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                short_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    else:
+                        v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]

     if is_longrope_scaling:
         return fused_rope_longrope_scaling
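Read as plain NumPy, the added branch implements the standard longrope selection: a sequence longer than original_max_position_embeddings rotates with the long factors, otherwise with the short factors, and each factor rescales one rotary frequency. A rough reference sketch under those assumptions, mirroring only the selection logic, not the fused TIR kernel or its exact _rope math:

import numpy as np

def longrope_angles(seq_len, rotary_dim, theta, original_max_position_embeddings,
                    long_factor, short_factor):
    # Pick the factor set exactly as the kernel's seq_len branch does.
    factors = np.asarray(
        long_factor if seq_len > original_max_position_embeddings else short_factor,
        dtype="float32",
    )
    # Assumed HF-style scaling: each inverse frequency is divided by its factor.
    inv_freq = 1.0 / (theta ** (np.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim))
    inv_freq = inv_freq / factors
    positions = np.arange(seq_len, dtype="float32")
    return np.outer(positions, inv_freq)  # (seq_len, rotary_dim // 2) rotation angles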
