Skip to content

Commit ebc6cfb

Browse files
committed
fix eb5 big tensor bug
1 parent 29f97bb commit ebc6cfb

4 files changed

Lines changed: 37 additions & 37 deletions

File tree

slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_spaq.cu

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -182,8 +182,8 @@ __global__ void FusedSPAQKernel(const phi::bfloat16 *__restrict__ Xin,
182182
const float *__restrict__ prob,
183183
phi::float8_e4m3fn *__restrict__ out,
184184
float *__restrict__ scales,
185-
const int rows,
186-
const int cols) {
185+
const int64_t rows,
186+
const int64_t cols) {
187187
// Configure shared memory
188188
__shared__ float smem_tile[256]; // Shared memory for activation values
189189
__shared__ float warp_max[2][4]; // Shared memory for warp maxima (2 quant
@@ -192,12 +192,12 @@ __global__ void FusedSPAQKernel(const phi::bfloat16 *__restrict__ Xin,
192192
quant_block_amax[2]; // Shared memory for quant block maxima
193193

194194
const __nv_bfloat16 *X = reinterpret_cast<const __nv_bfloat16 *>(Xin);
195-
const int x_offset = threadIdx.x;
195+
const uint32_t x_offset = threadIdx.x;
196196
const int quant_block_idx =
197197
threadIdx.x / 128; // 0 or 1, two quant blocks per block
198-
const int in_y_idx = blockIdx.y;
199-
const int in_x_idx = blockIdx.x * blockDim.x + x_offset;
200-
const int src_idx = in_y_idx * cols + in_x_idx;
198+
const int64_t in_y_idx = blockIdx.y;
199+
const int64_t in_x_idx = static_cast<uint64_t>(blockIdx.x) * blockDim.x + x_offset;
200+
const int64_t src_idx = in_y_idx * cols + in_x_idx;
201201

202202
// Load data and compute swiGLU activation
203203
if (in_x_idx < cols / 2) [[likely]] {
@@ -255,7 +255,7 @@ __global__ void FusedSPAQKernel(const phi::bfloat16 *__restrict__ Xin,
255255

256256
// Phase 3: Compute scales and quantize the outputs
257257
const float block_max_float = (float)quant_block_amax[quant_block_idx];
258-
const int scale_stride = (cols / 2 + 127) / 128;
258+
const int64_t scale_stride = (cols / 2 + 127) / 128;
259259

260260
float scale = ComputeScale<float, __nv_fp8_e4m3, using_pow2_scaling>(
261261
block_max_float, 0.0f);
@@ -265,8 +265,8 @@ __global__ void FusedSPAQKernel(const phi::bfloat16 *__restrict__ Xin,
265265
float output_scaled_fp32 = smem_tile[x_offset] * scale;
266266

267267

268-
const int g_output_y_offset = in_y_idx;
269-
const int g_output_x_offset = in_x_idx;
268+
const int64_t g_output_y_offset = in_y_idx;
269+
const int64_t g_output_x_offset = in_x_idx;
270270

271271
// Write output and scales
272272
if (g_output_y_offset < rows && g_output_x_offset < cols / 2) {
@@ -284,8 +284,8 @@ void dispatch_fused_spaq(const paddle::Tensor &X,
284284
const paddle::optional<paddle::Tensor> &prob,
285285
paddle::Tensor &out,
286286
paddle::Tensor &scale,
287-
const int rows,
288-
const int cols,
287+
const int64_t rows,
288+
const int64_t cols,
289289
const bool &using_pow2_scaling,
290290
const bool &with_prob) {
291291
constexpr int thread_per_block = 256;
@@ -297,8 +297,8 @@ void dispatch_fused_spaq(const paddle::Tensor &X,
297297
// 1x128 vector Each block handles several sub-row (numel = 4 x blockDim.x)
298298
// of input vector
299299
block.x = thread_per_block;
300-
constexpr int vec_numel = 4;
301-
const int scale_cols = scale.shape().back();
300+
constexpr int64_t vec_numel = 4;
301+
const int64_t scale_cols = scale.shape().back();
302302
DISPATCH_BOOL(
303303
using_pow2_scaling,
304304
k_using_pow2_scaling,

slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_stack_transpose_quant.cu

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,7 @@ __global__ void __launch_bounds__(1024)
183183
for (int j = 0; j < 4; j++) {
184184
float input_fp32 = static_cast<float>(input[i][j]);
185185
float output_scaled = input_fp32 * scale_inv;
186-
shm[threadIdx.x * 4 + j][i * 32 + threadIdx.y] =
186+
shm[static_cast<size_t>(threadIdx.x) * 4 + j][i * 32 + threadIdx.y] =
187187
static_cast<OutT>(output_scaled);
188188
}
189189
}
@@ -193,13 +193,13 @@ __global__ void __launch_bounds__(1024)
193193
for (size_t i = 0; i < 4; i++) {
194194
size_t idx_n = blockIdx.z;
195195
size_t idx_k = block_x * 128 + threadIdx.y + i * 32;
196-
size_t idx_m = block_y * 128 + threadIdx.x * 4;
196+
size_t idx_m = block_y * 128 + static_cast<size_t>(threadIdx.x) * 4;
197197
size_t idx = (idx_n * K + idx_k) * M + idx_m;
198198

199199
using StoreT = VecType<OutT, 4>;
200200
StoreT data;
201201
for (int j = 0; j < 4; j++) {
202-
data[j] = shm[i * 32 + threadIdx.y][threadIdx.x * 4 + j];
202+
data[j] = shm[i * 32 + threadIdx.y][static_cast<size_t>(threadIdx.x) * 4 + j];
203203
}
204204
*reinterpret_cast<StoreT*>(out + idx) = data;
205205
}

slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_swiglu_probs_bwd.cu

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -49,9 +49,9 @@ __global__ void SwigluProbsGradKernel(
4949
BFloat16* do1, // [seq_len*topk, moe_intermediate_size*2]
5050
float* probs_grad, // [seq_len*topk, 1]
5151
BFloat16* o2_s, // [seq_len*topk, moe_intermediate_size]
52-
int moe_intermediate_size) {
53-
const int row_idx = blockIdx.x;
54-
const int tid = threadIdx.x;
52+
int64_t moe_intermediate_size) {
53+
const int64_t row_idx = blockIdx.x;
54+
const int64_t tid = threadIdx.x;
5555

5656
const BFloat16* o1_row = o1 + row_idx * moe_intermediate_size * 2;
5757
const BFloat16* do2_s_row = do2_s + row_idx * moe_intermediate_size;
@@ -64,7 +64,7 @@ __global__ void SwigluProbsGradKernel(
6464

6565
float local_probs_grad = 0.0f;
6666

67-
for (int i = tid; i < moe_intermediate_size; i += blockDim.x) {
67+
for (int64_t i = tid; i < moe_intermediate_size; i += blockDim.x) {
6868
float lhs = static_cast<float>(o1_row[i]);
6969
float rhs = static_cast<float>(o1_row[i + moe_intermediate_size]);
7070

@@ -185,7 +185,7 @@ __global__ void SwigluProbsGradKernelVec4(
185185
BFloat16* do1, // [seq_len*topk, moe_intermediate_size*2]
186186
float* probs_grad, // [seq_len*topk, 1]
187187
BFloat16* o2_s, // [seq_len*topk, moe_intermediate_size]
188-
int moe_intermediate_size) {
188+
int64_t moe_intermediate_size) {
189189
constexpr int numel_per_thread = 4;
190190
constexpr int k_warp_size = 32;
191191
const int64_t row_idx = blockIdx.x;
@@ -210,7 +210,7 @@ __global__ void SwigluProbsGradKernelVec4(
210210

211211
float local_probs_grad = 0.0f;
212212

213-
const int vec_numel = (int64_t)moe_intermediate_size / numel_per_thread;
213+
const int64_t vec_numel = (int64_t)moe_intermediate_size / numel_per_thread;
214214
for (int64_t i = tid; i < vec_numel; i += blockDim.x) {
215215
float4 lhs_vec4 = load_and_cast_float4(o1_row_left_half_vec4 + i);
216216
float4 rhs_vec4 = load_and_cast_float4(o1_row_right_half_vec4 + i);
@@ -262,13 +262,13 @@ std::vector<paddle::Tensor> SwigluProbsGradCUDABackward(
262262
const paddle::Tensor& unzipped_probs,
263263
bool inplace) {
264264
auto o1_dims = o1.dims();
265-
int o1_outer_dim = 1;
265+
int64_t o1_outer_dim = 1;
266266
for (int i = 0; i < o1_dims.size() - 1; i++) {
267267
o1_outer_dim *= o1_dims[i];
268268
}
269269

270-
const int moe_intermediate_size_2 = o1_dims[o1_dims.size() - 1];
271-
const int moe_intermediate_size = moe_intermediate_size_2 / 2;
270+
const int64_t moe_intermediate_size_2 = o1_dims[o1_dims.size() - 1];
271+
const int64_t moe_intermediate_size = moe_intermediate_size_2 / 2;
272272

273273
auto do1 = inplace ? o1 : paddle::empty_like(o1);
274274
auto probs_grad =

slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_split_quant.cu

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,8 @@ __device__ void BlockLoad(const phi::bfloat16* X,
1414
__nv_bfloat16 input[4][4],
1515
size_t K) {
1616
for (size_t i = 0; i < 4; i++) {
17-
size_t off_m = blockIdx.x * 128 + threadIdx.y + i * 32;
18-
size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
17+
size_t off_m = static_cast<size_t>(blockIdx.x) * 128 + threadIdx.y + i * 32;
18+
size_t off_k = static_cast<size_t>(blockIdx.y) * 128 + threadIdx.x * VecSize;
1919
size_t offset = off_m * K + off_k;
2020

2121
for (size_t j = 0; j < 4; j += VecSize) {
@@ -45,15 +45,15 @@ __device__ void BlockColumnMax(const __nv_bfloat16 input[4][4],
4545

4646
// Reduce [(32), 32, 4] => [32, 4]
4747
for (int i = 0; i < 4; i++) {
48-
shm[threadIdx.y * 128 + i * 32 + threadIdx.x] = warp_max[i];
48+
shm[static_cast<size_t>(threadIdx.y) * 128 + i * 32 + threadIdx.x] = warp_max[i];
4949
}
5050
__syncthreads();
5151
for (int offset = 16; offset > 0; offset /= 2) {
5252
if (threadIdx.y < offset) {
5353
for (int i = 0; i < 4; i++) {
54-
shm[threadIdx.y * 128 + i * 32 + threadIdx.x] =
55-
__hmax(shm[threadIdx.y * 128 + i * 32 + threadIdx.x],
56-
shm[(threadIdx.y + offset) * 128 + i * 32 + threadIdx.x]);
54+
shm[static_cast<size_t>(threadIdx.y) * 128 + i * 32 + threadIdx.x] =
55+
__hmax(shm[static_cast<size_t>(threadIdx.y) * 128 + i * 32 + threadIdx.x],
56+
shm[(static_cast<size_t>(threadIdx.y) + offset) * 128 + i * 32 + threadIdx.x]);
5757
}
5858
}
5959
__syncthreads();
@@ -79,7 +79,7 @@ __device__ void BlockStoreScale(float* scale,
7979
}
8080
if (threadIdx.y == 0) {
8181
size_t idx_m = blockIdx.x - off_m / 128;
82-
size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
82+
size_t off_k = static_cast<size_t>(blockIdx.y) * 128 + threadIdx.x * VecSize;
8383
size_t offset = idx_m * K + off_k;
8484

8585
for (size_t j = 0; j < 4; j += VecSize) {
@@ -103,15 +103,15 @@ __device__ void BlockStoreOut(OutT* out,
103103
const OutT shm[128][129],
104104
size_t K) {
105105
for (size_t i = 0; i < 4; i++) {
106-
size_t idx_m = blockIdx.x * 128 + threadIdx.x * 4;
107-
size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 32;
106+
size_t idx_m = static_cast<size_t>(blockIdx.x) * 128 + threadIdx.x * 4;
107+
size_t idx_k = static_cast<size_t>(blockIdx.y) * 128 + threadIdx.y + i * 32;
108108
size_t idx = idx_k * cur_tokens + (idx_m - off_m);
109109

110110
if (idx_k < K) {
111111
using StoreT = VecType<OutT, VecSize>;
112112
StoreT data;
113113
for (int j = 0; j < VecSize; j++) {
114-
data[j] = shm[i * 32 + threadIdx.y][threadIdx.x * 4 + j];
114+
data[j] = shm[i * 32 + threadIdx.y][static_cast<size_t>(threadIdx.x) * 4 + j];
115115
}
116116
*reinterpret_cast<StoreT*>(out + idx) = data;
117117
}
@@ -123,7 +123,7 @@ __device__ std::pair<size_t, size_t> GetExpertIdx(int64_t* tokens_per_expert,
123123
__shared__ size_t expert_idx_, off_m_;
124124

125125
if (threadIdx.x == 0 && threadIdx.y == 0) {
126-
size_t idx_m = blockIdx.x * 128;
126+
size_t idx_m = static_cast<size_t>(blockIdx.x) * 128;
127127
size_t off_m = 0, next_off_m = 0;
128128
size_t expert_idx;
129129
for (expert_idx = 0; expert_idx < num_experts; expert_idx++) {
@@ -176,7 +176,7 @@ __global__ void __launch_bounds__(1024)
176176
for (int k = 0; k < VecSize; k++) {
177177
float input_fp32 = static_cast<float>(input[i][j + k]);
178178
float output_scaled = input_fp32 * scale_inv[j + k];
179-
shm[threadIdx.x * VecSize + j * 32 + k][i * 32 + threadIdx.y] =
179+
shm[static_cast<size_t>(threadIdx.x) * VecSize + j * 32 + k][i * 32 + threadIdx.y] =
180180
static_cast<OutT>(output_scaled);
181181
}
182182
}

0 commit comments

Comments
 (0)