@@ -23,9 +23,9 @@ __global__ void rope_f32_kernel(float* x, float* out, int seq_len, int N){
2323 float x2 = x[idx * 2 + 1 ];
2424 int token_pos = idx / N;
2525 int token_idx = idx % N;
26- float exp_v = 1 .0f / powf (theta, token_idx / (N * 2 ));
27- float sin_v = sinf (token_pos / exp_v);
28- float cos_v = cosf (token_pos / exp_v);
26+ float exp_v = 1 .0f / powf (theta, 2 * token_idx / (N * 2 . 0f ));
27+ float sin_v = sinf (token_pos * exp_v);
28+ float cos_v = cosf (token_pos * exp_v);
2929 float out1 = x1 * cos_v - x2 * sin_v;
3030 float out2 = x1 * sin_v + x2 * cos_v;
3131 out[idx * 2 ] = out1;
@@ -38,9 +38,9 @@ __global__ void rope_f32_v2_kernel(float* x, float* out, int seq_len, int N){
3838 int tid = threadIdx .x ;
3939 float x1 = x[token_pos * N * 2 + tid * 2 ];
4040 float x2 = x[token_pos * N * 2 + tid * 2 + 1 ];
41- float exp_v = 1 .0f / powf (theta, ( int )(tid / 2 ) / (N * 2 ));
42- float sin_v = sinf (token_pos / exp_v);
43- float cos_v = cosf (token_pos / exp_v);
41+ float exp_v = 1 .0f / powf (theta, 2 * tid / (N * 2 . 0f ));
42+ float sin_v = sinf (token_pos * exp_v);
43+ float cos_v = cosf (token_pos * exp_v);
4444 float out1 = x1 * cos_v - x2 * sin_v;
4545 float out2 = x1 * sin_v + x2 * cos_v;
4646 out[token_pos * N * 2 + tid * 2 ] = out1;
@@ -52,12 +52,12 @@ __global__ void rope_f32x4_pack_kernel(float* x, float* out, int seq_len, int N)
5252 float4 x_v = FLOAT4 (x[idx * 4 ]);
5353 int token_pos = idx / N;
5454 int token_idx = idx % N;
55- float exp_f_v = 1 .0f / powf (theta, token_idx * 2 / (N * 4 ));
56- float exp_s_v = 1 .0f / powf (theta, (( token_idx * 2 ) + 1 ) / (N * 4 ));
57- float sin_f_v = sinf (token_pos / exp_f_v);
58- float cos_f_v = cosf (token_pos / exp_f_v);
59- float sin_s_v = sinf (token_pos / exp_s_v);
60- float cos_s_v = cosf (token_pos / exp_s_v);
55+ float exp_f_v = 1 .0f / powf (theta, 2 * token_idx * 2 / (N * 4 . 0f ));
56+ float exp_s_v = 1 .0f / powf (theta, 2 * ( token_idx * 2 + 1 ) / (N * 4 . 0f ));
57+ float sin_f_v = sinf (token_pos * exp_f_v);
58+ float cos_f_v = cosf (token_pos * exp_f_v);
59+ float sin_s_v = sinf (token_pos * exp_s_v);
60+ float cos_s_v = cosf (token_pos * exp_s_v);
6161 float4 out_v;
6262 out_v.x = x_v.x * cos_f_v - x_v.y * sin_f_v;
6363 out_v.y = x_v.x * sin_f_v + x_v.y * cos_f_v;
0 commit comments