Skip to content

[Library] Add RoPE and modulate_fused operators to the nn.py library #332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions allo/library/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,44 @@ def scaled_dot_product_attention[
Z[i, h * (D // H) + j] = C_h[i, j]

return Z


def RoPE[
Ty, H, L, D
](X: "Ty[L, D]", cos: "Ty[L, D // H // 2]", sin: "Ty[L, D // H // 2]") -> "Ty[L, D]":
# Rotary Position Embedding
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the reference paper here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for the suggestion! I've added a comment of ref to the original paper. You can also find more details here:
Su et al., “RoFormer: Enhanced Transformer with Rotary Position Embedding” — arXiv:2104.09864

# Reference: https://arxiv.org/abs/2104.09864
X_rotary: Ty[L, D]
for h in range(H):
X_1_h: Ty[L, D // H // 2]
X_2_h: Ty[L, D // H // 2]
for i, j in dsl.grid(L, D // H // 2, name="rope_split_1"):
X_1_h[i, j] = X[i, h * (D // H) + j]
for i, j in dsl.grid(L, D // H // 2, name="rope_split_2"):
X_2_h[i, j] = X[i, h * (D // H) + D // H // 2 + j]
X_1_rotary: Ty[L, D // H // 2] = 0
X_2_rotary: Ty[L, D // H // 2] = 0
for i, j in dsl.grid(L, D // H // 2, name="rotary_1"):
X_1_rotary[i, j] = cos[i, j] * X_1_h[i, j] - sin[i, j] * X_2_h[i, j]
for i, j in dsl.grid(L, D // H // 2, name="rotary_2"):
X_2_rotary[i, j] = sin[i, j] * X_1_h[i, j] + cos[i, j] * X_2_h[i, j]
for i, j in dsl.grid(L, D // H // 2, name="rotary_merge_1"):
X_rotary[i, h * (D // H) + j] = X_1_rotary[i, j]
for i, j in dsl.grid(L, D // H // 2, name="rotary_merge_2"):
X_rotary[i, h * (D // H) + D // H // 2 + j] = X_2_rotary[i, j]
return X_rotary
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose the RoPE kernel also needs the corresponding scheduling function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally planned to add a `schedule_RoPE` function, but it seems Allo raises an error when trying to access the loop labels. This is the schedule function I tried, which pipelines the inner loop `j`:

def schedule_RoPE(s):
    lj1 = s.get_loops(s.top_func_name)["rope_split_1"]["j"]
    s.pipeline(lj1)
    lj2 = s.get_loops(s.top_func_name)["rope_split_2"]["j"]
    s.pipeline(lj2)
    lj3 = s.get_loops(s.top_func_name)["rotary_1"]["j"]
    s.pipeline(lj3)
    lj4 = s.get_loops(s.top_func_name)["rotary_2"]["j"]
    s.pipeline(lj4)
    lj5 = s.get_loops(s.top_func_name)["rotary_merge_1"]["j"]
    s.pipeline(lj5)
    lj6 = s.get_loops(s.top_func_name)["rotary_merge_2"]["j"]
    s.pipeline(lj6)
    return s

However, I get the following error message. A potential cause is that the main body of the RoPE operator is nested inside a standard Python for loop, so the nested dsl.grid() loops are not visible to s.get_loops() at the top level. This seems similar to the structure of scaled_dot_product_attention, which I suppose also doesn't have a schedule_* function for the same reason.

(allo) root@2a2f74889c68:~/allo/allo/library# python3 tests.py 
Traceback (most recent call last):
  File "/root/allo/allo/library/tests.py", line 66, in <module>
    test_RoPE()
  File "/root/allo/allo/library/tests.py", line 26, in test_RoPE
    schedule_RoPE(s)
  File "/root/allo/allo/library/nn.py", line 231, in schedule_RoPE
    lj1 = s.get_loops(s.top_func_name)["rope_split_1"]["j"]
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
  File "/root/allo/allo/ir/transform.py", line 66, in __getitem__
    raise AttributeError(f"No such loop {name}")
AttributeError: No such loop rope_split_1

Could you please kindly tell me how to add a schedule in this case, or it's fine to leave this function unscheduled for now?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Can you add a name for the `h` loop so that you can access `h` first? Also, you can print out the loops with `print(s.get_loops(s.top_func_name))`, so you can see which loops are accessible.



def modulate_fused[
    Ty, L, D
](X: "Ty[L,D]", scale: "Ty[D]", shift: "Ty[D]") -> "Ty[L, D]":
    """Fused elementwise modulation: Z[i, j] = X[i, j] * (1 + scale[j]) + shift[j].

    X:     input of shape [L, D].
    scale: per-feature scale of shape [D] (applied as 1 + scale).
    shift: per-feature shift of shape [D].
    Returns the modulated tensor of shape [L, D].
    """
    Z: Ty[L, D]
    # The loop label "m_fused" is referenced by schedule_modulate_fused
    # to pipeline the inner j loop — keep it in sync if renamed.
    for i, j in dsl.grid(L, D, name="m_fused"):
        Z[i, j] = X[i, j] * (1 + scale[j]) + shift[j]
    return Z


def schedule_modulate_fused(s):
    """Pipeline the inner `j` loop of the `m_fused` nest.

    s: an Allo schedule for `modulate_fused`; mutated in place and
       returned to allow chaining.
    """
    loops = s.get_loops(s.top_func_name)
    s.pipeline(loops["m_fused"]["j"])
    return s
59 changes: 59 additions & 0 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,5 +303,64 @@ def bert_layer(X, Wq, Wk, Wv, Wp, W1, W2, gamma1, beta1, gamma2, beta2):
print(s.build(target="vhls"))


def np_rope(X, cos, sin, num_heads=8):
    """NumPy reference implementation of rotary position embedding.

    X:        array of shape [L, num_heads, head_dim].
    cos, sin: arrays of shape [L, head_dim // 2], shared across heads.
    num_heads: number of attention heads (second axis of X).
    Returns an array of X's shape with each head rotated pairwise.
    """
    # Derive the half-head width from the input instead of hard-coding 32,
    # so the reference works for any head dimension.
    half = X.shape[-1] // 2
    X1 = X[..., :half]
    X2 = X[..., half:]

    X_rotated = np.zeros_like(X)
    for i in range(num_heads):
        X_1_i = X1[:, i, :]
        X_2_i = X2[:, i, :]
        # Standard 2D rotation of each (x1, x2) feature pair.
        X_rotated[:, i, :] = np.concatenate(
            (X_1_i * cos - X_2_i * sin, X_1_i * sin + X_2_i * cos), axis=-1
        )
    return X_rotated


def test_RoPE():
    """Check the Allo RoPE kernel against the NumPy reference."""
    from allo.library.nn import RoPE

    L, D = 1024, 512
    H = 8
    # Derive the per-head rotary half-width instead of hard-coding 32.
    half = D // H // 2
    s = allo.customize(RoPE, instantiate=[float32, H, L, D])
    mod = s.build()
    Q = np.random.randn(L, D).astype(np.float32)
    cos = np.random.randn(L, half).astype(np.float32)
    sin = np.random.randn(L, half).astype(np.float32)
    allo_out = mod(Q, cos, sin)
    # The kernel uses a flat [L, D] layout; reshape to [L, H, D // H] for
    # the per-head NumPy reference, then flatten back for comparison.
    Q_np = Q.reshape(L, H, D // H)
    np_out = np_rope(Q_np, cos, sin, num_heads=H)
    np_out = np_out.reshape(L, D)
    np.testing.assert_allclose(allo_out, np_out, atol=1e-3)
    print("Passed!")


def np_modulate_fused(x, shift, scale):
    """NumPy reference for modulate_fused: x * (1 + scale) + shift."""
    return x * (1 + scale) + shift


def test_modulate_fused():
    """Check the scheduled Allo modulate_fused kernel against NumPy."""
    from allo.library.nn import modulate_fused, schedule_modulate_fused

    L, D = 1024, 512
    X = np.random.randn(L, D).astype(np.float32)
    s = allo.customize(modulate_fused, instantiate=[float32, L, D])
    schedule_modulate_fused(s)
    print(s.module)
    mod = s.build(target="llvm")
    scale = np.random.randn(D).astype(np.float32)
    shift = np.random.randn(D).astype(np.float32)
    allo_out = mod(X, scale, shift)
    # The kernel does not modify X in place, so comparing against the
    # same array is safe (the redundant X_norm alias was removed).
    np_out = np_modulate_fused(X, shift=shift, scale=scale)
    np.testing.assert_allclose(allo_out, np_out, atol=1e-3)
    print("Passed!")


# Allow running this test file directly as a script via pytest.
if __name__ == "__main__":
    pytest.main([__file__])
Loading