-
Notifications
You must be signed in to change notification settings - Fork 41
[Library] Add RoPE and modulate_fused operators to the nn.py library #332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -199,3 +199,44 @@ def scaled_dot_product_attention[ | |
Z[i, h * (D // H) + j] = C_h[i, j] | ||
|
||
return Z | ||
|
||
|
||
def RoPE[ | ||
Ty, H, L, D | ||
](X: "Ty[L, D]", cos: "Ty[L, D // H // 2]", sin: "Ty[L, D // H // 2]") -> "Ty[L, D]": | ||
# Rotary Position Embedding | ||
# Reference: https://arxiv.org/abs/2104.09864 | ||
X_rotary: Ty[L, D] | ||
for h in range(H): | ||
X_1_h: Ty[L, D // H // 2] | ||
X_2_h: Ty[L, D // H // 2] | ||
for i, j in dsl.grid(L, D // H // 2, name="rope_split_1"): | ||
X_1_h[i, j] = X[i, h * (D // H) + j] | ||
for i, j in dsl.grid(L, D // H // 2, name="rope_split_2"): | ||
X_2_h[i, j] = X[i, h * (D // H) + D // H // 2 + j] | ||
X_1_rotary: Ty[L, D // H // 2] = 0 | ||
X_2_rotary: Ty[L, D // H // 2] = 0 | ||
for i, j in dsl.grid(L, D // H // 2, name="rotary_1"): | ||
X_1_rotary[i, j] = cos[i, j] * X_1_h[i, j] - sin[i, j] * X_2_h[i, j] | ||
for i, j in dsl.grid(L, D // H // 2, name="rotary_2"): | ||
X_2_rotary[i, j] = sin[i, j] * X_1_h[i, j] + cos[i, j] * X_2_h[i, j] | ||
for i, j in dsl.grid(L, D // H // 2, name="rotary_merge_1"): | ||
X_rotary[i, h * (D // H) + j] = X_1_rotary[i, j] | ||
for i, j in dsl.grid(L, D // H // 2, name="rotary_merge_2"): | ||
X_rotary[i, h * (D // H) + D // H // 2 + j] = X_2_rotary[i, j] | ||
return X_rotary | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suppose the RoPE kernel also needs the corresponding scheduling function? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I originally planned to add a
However, I get the following error message. A potential cause is that the main body of the RoPE operator is nested inside a standard Python for loop, so the nested
Could you please kindly tell me how to add a schedule in this case, or it's fine to leave this function unscheduled for now? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. Can you add a name for the |
||
|
||
|
||
def modulate_fused[ | ||
Ty, L, D | ||
](X: "Ty[L,D]", scale: "Ty[D]", shift: "Ty[D]") -> "Ty[L, D]": | ||
Z: Ty[L, D] | ||
for i, j in dsl.grid(L, D, name="m_fused"): | ||
Z[i, j] = X[i, j] * (1 + scale[j]) + shift[j] | ||
return Z | ||
|
||
|
||
def schedule_modulate_fused(s): | ||
lj = s.get_loops(s.top_func_name)["m_fused"]["j"] | ||
s.pipeline(lj) | ||
return s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add the reference paper here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much for the suggestion! I've added a comment of ref to the original paper. You can also find more details here:
Su et al., “RoFormer: Enhanced Transformer with Rotary Position Embedding” — arXiv:2104.09864