
[Library] Add RoPE and modulate_fused operators to the nn.py library #332


Open · wants to merge 3 commits into main

Conversation

silvenachen (Contributor) commented Mar 29, 2025

Description

This PR adds two new operator functions to the nn.py library to better support diffusion transformers and LLM-style workloads.

Problems

Currently, the Allo nn.py library lacks support for:

  • RoPE (Rotary Positional Embedding), which is widely used in attention layers in LLMs
  • modulate_fused, a fused operator for scaling and shifting used in diffusion models (e.g., timestep conditioning)

Adding these operators could improve coverage for transformer-based and diffusion model architectures.
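
For concreteness, below is a minimal NumPy sketch of the elementwise math modulate_fused is intended to fuse; the adaLN-style formula out = x * (1 + scale) + shift and the per-channel shapes are assumptions for illustration, not necessarily the exact signature in this PR.

import numpy as np

def modulate_fused_ref(x, shift, scale):
    # Assumed reference semantics: fused scale-and-shift used for timestep
    # conditioning in DiT-style diffusion transformers.
    # x:     (L, D) token features
    # shift: (D,)   per-channel shift from the conditioning branch
    # scale: (D,)   per-channel scale from the conditioning branch
    return x * (1.0 + scale) + shift

x = np.random.rand(16, 64).astype(np.float32)
shift = np.random.rand(64).astype(np.float32)
scale = np.random.rand(64).astype(np.float32)
out = modulate_fused_ref(x, shift, scale)  # shape (16, 64)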

Proposed Solutions

New operators in /allo/library/nn.py

  • Implemented RoPE operator.
  • Implemented modulate_fused operator, which combines scale and shift into a single operator.
  • Added accompanying schedule_modulate_fused function.

New test cases in tests/test_nn.py

  • test_RoPE and test_modulate_fused for the two operators above (a rough sketch of the test structure follows below).
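
As a rough sketch of what such a test might look like (the import path, instantiation order, argument order, and reference formula are assumptions modeled on the pattern of the existing nn.py tests, not the exact code in this PR):

import allo
import numpy as np
from allo.ir.types import float32
from allo.library.nn import modulate_fused  # assumed import path for the new operator

def test_modulate_fused_sketch():
    L, D = 16, 64  # small sizes for a quick functional check
    # Customize and build the operator (pattern follows the other nn.py tests).
    s = allo.customize(modulate_fused, instantiate=[float32, L, D])
    mod = s.build()
    x = np.random.rand(L, D).astype(np.float32)
    shift = np.random.rand(D).astype(np.float32)
    scale = np.random.rand(D).astype(np.float32)
    out = mod(x, shift, scale)
    ref = x * (1.0 + scale) + shift  # assumed reference formula
    np.testing.assert_allclose(out, ref, rtol=1e-5)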

Examples

(Please provide an example of the input program and the expected behavior, e.g., the generated IR.)

Checklist

Please make sure to review and check all of these items:

  • PR's title starts with a category (e.g. [Bugfix], [IR], [Builder], etc)
  • All changes have test coverage (It would be good to provide ~2 different test cases to test the robustness of your code)
  • Pass the formatting check locally
  • Code is well-documented

chhzh123 (Member) left a comment

Thanks for contributing to Allo! As these two are not commonly used operations, could you add the corresponding paper references to the operator definitions? If they have a PyTorch implementation, please also provide a link.

def RoPE[
    Ty, H, L, D
](X: "Ty[L, D]", cos: "Ty[L, D // H // 2]", sin: "Ty[L, D // H // 2]") -> "Ty[L, D]":
    # Rotary Position Embedding
chhzh123 (Member) commented:

Can you add the reference paper here?

silvenachen (Contributor, Author) replied:

Thanks so much for the suggestion! I've added a comment referencing the original paper. You can also find more details here:
Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding," arXiv:2104.09864

            X_rotary[i, h * (D // H) + j] = X_1_rotary[i, j]
        for i, j in dsl.grid(L, D // H // 2, name="rotary_merge_2"):
            X_rotary[i, h * (D // H) + D // H // 2 + j] = X_2_rotary[i, j]
    return X_rotary
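
For reference, the rotation this kernel appears to implement can be written as a short NumPy golden model; the half-split (non-interleaved, GPT-NeoX style) convention below is an assumption inferred from the index arithmetic in the snippet above.

import numpy as np

def rope_ref(X, cos, sin, H):
    # X:   (L, D)           input activations
    # cos: (L, D // H // 2) precomputed cosines per position
    # sin: (L, D // H // 2) precomputed sines per position
    # Assumed convention: each head's D // H channels are split into a first
    # and a second half, which are rotated pairwise (RoFormer, arXiv:2104.09864).
    L, D = X.shape
    hd = D // H
    out = np.empty_like(X)
    for h in range(H):
        x1 = X[:, h * hd : h * hd + hd // 2]
        x2 = X[:, h * hd + hd // 2 : (h + 1) * hd]
        out[:, h * hd : h * hd + hd // 2] = x1 * cos - x2 * sin
        out[:, h * hd + hd // 2 : (h + 1) * hd] = x1 * sin + x2 * cos
    return out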
chhzh123 (Member) commented:

I suppose the RoPE kernel also needs the corresponding scheduling function?

silvenachen (Contributor, Author) replied:

I originally planned to add a schedule_RoPE function, but Allo raises an error when I try to access the loop labels. This is the schedule function I wrote to pipeline the inner loop j:

def schedule_RoPE(s):
    # Pipeline the inner j loop of each labeled dsl.grid loop.
    lj1 = s.get_loops(s.top_func_name)["rope_split_1"]["j"]
    s.pipeline(lj1)
    lj2 = s.get_loops(s.top_func_name)["rope_split_2"]["j"]
    s.pipeline(lj2)
    lj3 = s.get_loops(s.top_func_name)["rotary_1"]["j"]
    s.pipeline(lj3)
    lj4 = s.get_loops(s.top_func_name)["rotary_2"]["j"]
    s.pipeline(lj4)
    lj5 = s.get_loops(s.top_func_name)["rotary_merge_1"]["j"]
    s.pipeline(lj5)
    lj6 = s.get_loops(s.top_func_name)["rotary_merge_2"]["j"]
    s.pipeline(lj6)
    return s

However, I get the following error message. A potential cause is that the main body of the RoPE operator is nested inside a standard Python for loop, so the nested dsl.grid() loops are not visible to s.get_loops() at the top level. This seems similar to the structure of scaled_dot_product_attention, which I suppose also doesn't have a schedule_* function for the same reason.

(allo) root@2a2f74889c68:~/allo/allo/library# python3 tests.py 
Traceback (most recent call last):
  File "/root/allo/allo/library/tests.py", line 66, in <module>
    test_RoPE()
  File "/root/allo/allo/library/tests.py", line 26, in test_RoPE
    schedule_RoPE(s)
  File "/root/allo/allo/library/nn.py", line 231, in schedule_RoPE
    lj1 = s.get_loops(s.top_func_name)["rope_split_1"]["j"]
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
  File "/root/allo/allo/ir/transform.py", line 66, in __getitem__
    raise AttributeError(f"No such loop {name}")
AttributeError: No such loop rope_split_1
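
For context, a rough sketch of the structure described above (loop names are taken from the snippets in this thread; the surrounding Python loop over heads is an assumption about how the kernel is written):

# Illustration only, not the actual kernel. The dsl.grid loops sit inside an
# ordinary Python for-loop over heads, so labels such as "rope_split_1" do not
# show up as top-level loop bands in s.get_loops(s.top_func_name).
for h in range(H):  # plain Python loop, not a named dsl loop
    for i, j in dsl.grid(L, D // H // 2, name="rope_split_1"):
        ...
    for i, j in dsl.grid(L, D // H // 2, name="rotary_merge_2"):
        ...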

Could you please tell me how to add a schedule in this case, or is it fine to leave this function unscheduled for now?

chhzh123 (Member) replied:

I see. Can you add a name for the h loop so that you can access h first? Also, you can print out the loops with print(s.get_loops(s.top_func_name)) so you can see what kinds of loops are accessible.
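
A minimal sketch of this suggestion, assuming the h loop is given a name (here "rope_h", a hypothetical label) and that the nested j loops then become reachable through it; the exact keys returned by get_loops are assumptions and should be checked with the print call first:

def schedule_RoPE(s):
    # Print the available loop bands first, as suggested, to see the real names.
    loops = s.get_loops(s.top_func_name)
    print(loops)
    # Hypothetical access pattern: drill down from the named outer loop to the
    # inner j loop and pipeline it. Replace "rope_h" with whatever the print
    # above reports.
    lj = loops["rope_h"]["j"]
    s.pipeline(lj)
    return s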
