[Library] Add RoPE and modulate_fused operators to the nn.py library #332
Conversation
Thanks for contributing to Allo! As these two are not commonly used operations, could you add the corresponding paper references to the operator definitions? If they have a PyTorch implementation, please also provide a link.
def RoPE[
    Ty, H, L, D
](X: "Ty[L, D]", cos: "Ty[L, D // H // 2]", sin: "Ty[L, D // H // 2]") -> "Ty[L, D]":
    # Rotary Position Embedding
Can you add the reference paper here?
Thanks so much for the suggestion! I've added a reference comment pointing to the original paper. You can also find more details here:
Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding" (arXiv:2104.09864)
    X_rotary[i, h * (D // H) + j] = X_1_rotary[i, j]
for i, j in dsl.grid(L, D // H // 2, name="rotary_merge_2"):
    X_rotary[i, h * (D // H) + D // H // 2 + j] = X_2_rotary[i, j]
return X_rotary
I suppose the RoPE kernel also needs the corresponding scheduling function?
I originally planned to add a schedule_RoPE function, but it seems Allo raises an error when trying to access the loop labels. This is the schedule function I tried, which pipelines the inner loop j:
def schedule_RoPE(s):
    lj1 = s.get_loops(s.top_func_name)["rope_split_1"]["j"]
    s.pipeline(lj1)
    lj2 = s.get_loops(s.top_func_name)["rope_split_2"]["j"]
    s.pipeline(lj2)
    lj3 = s.get_loops(s.top_func_name)["rotary_1"]["j"]
    s.pipeline(lj3)
    lj4 = s.get_loops(s.top_func_name)["rotary_2"]["j"]
    s.pipeline(lj4)
    lj5 = s.get_loops(s.top_func_name)["rotary_merge_1"]["j"]
    s.pipeline(lj5)
    lj6 = s.get_loops(s.top_func_name)["rotary_merge_2"]["j"]
    s.pipeline(lj6)
    return s
However, I get the following error message. A potential cause is that the main body of the RoPE operator is nested inside a standard Python for loop, so the nested dsl.grid() loops are not visible to s.get_loops() at the top level. This seems similar to the structure of scaled_dot_product_attention, which I suppose also doesn't have a schedule_* function for the same reason.
(allo) root@2a2f74889c68:~/allo/allo/library# python3 tests.py
Traceback (most recent call last):
File "/root/allo/allo/library/tests.py", line 66, in <module>
test_RoPE()
File "/root/allo/allo/library/tests.py", line 26, in test_RoPE
schedule_RoPE(s)
File "/root/allo/allo/library/nn.py", line 231, in schedule_RoPE
lj1 = s.get_loops(s.top_func_name)["rope_split_1"]["j"]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
File "/root/allo/allo/ir/transform.py", line 66, in __getitem__
raise AttributeError(f"No such loop {name}")
AttributeError: No such loop rope_split_1
Could you please tell me how to add a schedule in this case, or is it fine to leave this function unscheduled for now?
I see. Can you add a name for the h loop so that you can access h first? Also, you can print out the loops with print(s.get_loops(s.top_func_name)), so you can see which loops are accessible.
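Building on that suggestion, a minimal sketch of what schedule_RoPE could look like once the loop bands are exposed; the band names are assumptions carried over from the dsl.grid names above and may change once the outer h loop is given an explicit name:

def schedule_RoPE(s):
    # Inspect the visible loop bands first (per the reviewer's tip);
    # the band names below are assumptions.
    print(s.get_loops(s.top_func_name))
    loops = s.get_loops(s.top_func_name)
    # Pipeline the innermost j loop of each band.
    for band in (
        "rope_split_1", "rope_split_2",
        "rotary_1", "rotary_2",
        "rotary_merge_1", "rotary_merge_2",
    ):
        s.pipeline(loops[band]["j"])
    return s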
Description
Add two new operator functions to the nn.py library for better support of diffusion transformers and LLM-style workloads.
Problems
Currently, the Allo nn.py library lacks support for:
- RoPE (Rotary Positional Embedding), which is widely used in attention layers in LLMs
- modulate_fused, a fused operator for scaling and shifting used in diffusion models (e.g., timestep conditioning); see the sketch after this list

Adding these operators could improve coverage for transformer-based and diffusion model architectures.
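For context, the intended semantics of the fused scale-and-shift can be sketched as follows. This assumes the DiT-style convention out = x * (1 + scale) + shift (Peebles and Xie, arXiv:2212.09748); modulate_fused_reference is a hypothetical NumPy helper, not the operator added by this PR:

import numpy as np

def modulate_fused_reference(X, scale, shift):
    # Fused scale-and-shift as used for timestep conditioning in
    # diffusion transformers (assumed DiT-style convention); the
    # actual broadcasting in the Allo operator may differ.
    return X * (1.0 + scale) + shift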
Proposed Solutions
New operators in /allo/library/nn.py:
- RoPE operator.
- modulate_fused operator, which combines scale and shift into a single operator.
- schedule_modulate_fused function.

New test cases in tests/test_nn.py:
- test_RoPE and test_modulate_fused for the two operators mentioned above.

Examples
(Please provide an example of the input program and the expected behavior, e.g., the generated IR.)
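Until the author fills this in, a hypothetical usage sketch may help; the import path, the instantiate arguments, and the concrete sizes (H=2, L=4, D=8) are all assumptions, not taken from the PR:

import allo
from allo.ir.types import float32
from allo.library.nn import RoPE

# Hypothetical instantiation: 2 heads, sequence length 4, model dim 8.
s = allo.customize(RoPE, instantiate=[float32, 2, 4, 8])
print(s.module)  # inspect the generated IR
mod = s.build()  # build with the default backend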
Checklist
Please make sure to review and check all of these items: