Skip to content

Commit e799541

Browse files
authored
bugfix: corrected theta calculation in RoPE CUDA kernel (#290)
1 parent db84d12 commit e799541

File tree

1 file changed

+12 / -12 lines changed

1 file changed

+12 / -12 lines changed

kernels/rope/rope.cu

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -23,9 +23,9 @@ __global__ void rope_f32_kernel(float* x, float* out, int seq_len, int N){
2323
float x2 = x[idx * 2 + 1];
2424
int token_pos = idx / N;
2525
int token_idx = idx % N;
26-
float exp_v = 1.0f / powf(theta, token_idx / (N * 2));
27-
float sin_v = sinf(token_pos / exp_v);
28-
float cos_v = cosf(token_pos / exp_v);
26+
float exp_v = 1.0f / powf(theta, 2 * token_idx / (N * 2.0f));
27+
float sin_v = sinf(token_pos * exp_v);
28+
float cos_v = cosf(token_pos * exp_v);
2929
float out1 = x1 * cos_v - x2 * sin_v;
3030
float out2 = x1 * sin_v + x2 * cos_v;
3131
out[idx * 2] = out1;
@@ -38,9 +38,9 @@ __global__ void rope_f32_v2_kernel(float* x, float* out, int seq_len, int N){
3838
int tid = threadIdx.x;
3939
float x1 = x[token_pos * N * 2 + tid * 2];
4040
float x2 = x[token_pos * N * 2 + tid * 2 + 1];
41-
float exp_v = 1.0f / powf(theta, (int)(tid / 2) / (N * 2));
42-
float sin_v = sinf(token_pos / exp_v);
43-
float cos_v = cosf(token_pos / exp_v);
41+
float exp_v = 1.0f / powf(theta, 2 * tid / (N * 2.0f));
42+
float sin_v = sinf(token_pos * exp_v);
43+
float cos_v = cosf(token_pos * exp_v);
4444
float out1 = x1 * cos_v - x2 * sin_v;
4545
float out2 = x1 * sin_v + x2 * cos_v;
4646
out[token_pos * N * 2 + tid * 2] = out1;
@@ -52,12 +52,12 @@ __global__ void rope_f32x4_pack_kernel(float* x, float* out, int seq_len, int N)
5252
float4 x_v = FLOAT4(x[idx * 4]);
5353
int token_pos = idx / N;
5454
int token_idx = idx % N;
55-
float exp_f_v = 1.0f / powf(theta, token_idx * 2 / (N * 4));
56-
float exp_s_v = 1.0f / powf(theta, ((token_idx * 2) + 1) / (N * 4));
57-
float sin_f_v = sinf(token_pos / exp_f_v);
58-
float cos_f_v = cosf(token_pos / exp_f_v);
59-
float sin_s_v = sinf(token_pos / exp_s_v);
60-
float cos_s_v = cosf(token_pos / exp_s_v);
55+
float exp_f_v = 1.0f / powf(theta, 2 * token_idx * 2 / (N * 4.0f));
56+
float exp_s_v = 1.0f / powf(theta, 2 * (token_idx * 2 + 1) / (N * 4.0f));
57+
float sin_f_v = sinf(token_pos * exp_f_v);
58+
float cos_f_v = cosf(token_pos * exp_f_v);
59+
float sin_s_v = sinf(token_pos * exp_s_v);
60+
float cos_s_v = cosf(token_pos * exp_s_v);
6161
float4 out_v;
6262
out_v.x = x_v.x * cos_f_v - x_v.y * sin_f_v;
6363
out_v.y = x_v.x * sin_f_v + x_v.y * cos_f_v;

0 commit comments

Comments (0)