1 | 1 | #include "cuda_utils.cuh" |
2 | 2 | #include <cmath> |
3 | 3 | #include <stdint.h> |
| 4 | +#include <cuda/std/limits> |
4 | 5 |
5 | 6 | #define WARP_SIZE 32 |
6 | 7 | const int BLOCK_SIZE = 1024; |
7 | 8 |
| 9 | +// Helpers to initialize reduction identities for both floating-point and |
| 10 | +// integer types. For floats we keep using +/-INFINITY, while for integers |
| 11 | +// we use well-defined numeric_limits values instead of relying on casting |
| 12 | +// +/-INFINITY to an integer type (which is undefined behaviour and has been |
| 13 | +// observed to break on newer GPU architectures such as Blackwell). |
| 14 | +template <typename T> |
| 15 | +__device__ __forceinline__ T reduce_init_lowest() { |
| 16 | + // Default implementation is used for floating-point types (__half, |
| 17 | + // __nv_bfloat16, float, double). The conversion from -INFINITY (a float
| 18 | + // constant) to these types is well-defined and produces -inf.
| 19 | + return -INFINITY; |
| 20 | +} |
| 21 | + |
| 22 | +template <typename T> |
| 23 | +__device__ __forceinline__ T reduce_init_highest() { |
| 24 | + // Default implementation is used for floating-point types (__half, |
| 25 | + // __nv_bfloat16, float, double). The conversion from INFINITY (a float
| 26 | + // constant) to these types is well-defined and produces +inf.
| 27 | + return INFINITY; |
| 28 | +} |
| 29 | + |
| 30 | +// Integer specializations – use numeric_limits instead of +/-INFINITY. |
| 31 | +template <> |
| 32 | +__device__ __forceinline__ int64_t reduce_init_lowest<int64_t>() { |
| 33 | + return ::cuda::std::numeric_limits<int64_t>::lowest(); |
| 34 | +} |
| 35 | + |
| 36 | +template <> |
| 37 | +__device__ __forceinline__ uint32_t reduce_init_lowest<uint32_t>() { |
| 38 | + return ::cuda::std::numeric_limits<uint32_t>::lowest(); |
| 39 | +} |
| 40 | + |
| 41 | +template <> |
| 42 | +__device__ __forceinline__ uint8_t reduce_init_lowest<uint8_t>() { |
| 43 | + return ::cuda::std::numeric_limits<uint8_t>::lowest(); |
| 44 | +} |
| 45 | + |
| 46 | +template <> |
| 47 | +__device__ __forceinline__ int64_t reduce_init_highest<int64_t>() { |
| 48 | + return ::cuda::std::numeric_limits<int64_t>::max(); |
| 49 | +} |
| 50 | + |
| 51 | +template <> |
| 52 | +__device__ __forceinline__ uint32_t reduce_init_highest<uint32_t>() { |
| 53 | + return ::cuda::std::numeric_limits<uint32_t>::max(); |
| 54 | +} |
| 55 | + |
| 56 | +template <> |
| 57 | +__device__ __forceinline__ uint8_t reduce_init_highest<uint8_t>() { |
| 58 | + return ::cuda::std::numeric_limits<uint8_t>::max(); |
| 59 | +} |
| 60 | + |
8 | 61 | // TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulates in f32
9 | 62 | // but also expects an f32 output so that this can be used for normalization, e.g.
10 | 63 | // in softmax. |
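
Note (not part of the diff): a minimal standalone sketch of what the helpers above are meant to return, assuming a CUDA toolkit recent enough to ship <cuda/std/limits>. Only reduce_init_lowest mirrors a template defined in this file; show_identities, the variable names, and the trimmed specialization set are illustrative.

    // Trimmed copies of the primary template and one integer specialization,
    // plus a tiny kernel that materializes the identities for inspection.
    #include <cstdio>
    #include <cstdint>
    #include <cmath>
    #include <cuda/std/limits>

    template <typename T>
    __device__ __forceinline__ T reduce_init_lowest() {
        // Floating-point default: converting -INFINITY is well-defined.
        return -INFINITY;
    }

    template <>
    __device__ __forceinline__ int64_t reduce_init_lowest<int64_t>() {
        // Integer specialization: avoid casting -INFINITY to an integer type.
        return ::cuda::std::numeric_limits<int64_t>::lowest();
    }

    __global__ void show_identities(float *f_id, int64_t *i_id) {
        *f_id = reduce_init_lowest<float>();    // -inf
        *i_id = reduce_init_lowest<int64_t>();  // INT64_MIN
    }

    int main() {
        float *f_id;
        int64_t *i_id;
        cudaMallocManaged(&f_id, sizeof(float));
        cudaMallocManaged(&i_id, sizeof(int64_t));
        show_identities<<<1, 1>>>(f_id, i_id);
        cudaDeviceSynchronize();
        printf("float identity: %f, int64 identity: %lld\n", *f_id, (long long)*i_id);
        cudaFree(f_id);
        cudaFree(i_id);
        return 0;
    }

With the old code, the -INFINITY initializer was converted to the integer element type, which is undefined behaviour in C++ and is exactly what the numeric_limits specializations above sidestep.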
@@ -102,29 +155,29 @@ __device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, |
102 | 155 |
103 | 156 | if (alpha == nullptr && beta == nullptr) { |
104 | 157 | for (int col = tid; col < ncols; col += block_size) { |
105 | | - float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std; |
| 158 | + float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std; |
106 | 159 | dst[row*ncols + col] = static_cast<T>(lhs); |
107 | 160 | } |
108 | 161 | } |
109 | 162 | else if (alpha == nullptr && beta != nullptr) { |
110 | 163 | for (int col = tid; col < ncols; col += block_size) { |
111 | 164 | float b = static_cast<float>(beta[col]); |
112 | | - float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std; |
| 165 | + float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std; |
113 | 166 | dst[row*ncols + col] = static_cast<T>(lhs + b); |
114 | 167 | } |
115 | 168 | } |
116 | 169 | else if (alpha != nullptr && beta == nullptr) { |
117 | 170 | for (int col = tid; col < ncols; col += block_size) { |
118 | 171 | float a = static_cast<float>(alpha[col]); |
119 | | - float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std; |
| 172 | + float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std; |
120 | 173 | dst[row*ncols + col] = static_cast<T>(lhs * a); |
121 | 174 | } |
122 | 175 | } |
123 | 176 | else { |
124 | 177 | for (int col = tid; col < ncols; col += block_size) { |
125 | 178 | float a = static_cast<float>(alpha[col]); |
126 | 179 | float b = static_cast<float>(beta[col]); |
127 | | - float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std; |
| 180 | + float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std; |
128 | 181 | dst[row*ncols + col] = static_cast<T>(lhs * a + b); |
129 | 182 | } |
130 | 183 | } |
@@ -301,7 +354,9 @@ fast_max(const size_t src_numel, const size_t el_to_sum_per_block, |
301 | 354 | size_t tid = threadIdx.x; |
302 | 355 | size_t dst_id = blockIdx.x; |
303 | 356 |
304 | | - shr[tid] = -INFINITY; |
| 357 | + // Initialize with the lowest representable value for T so that the first |
| 358 | + // comparison in the reduction always picks a real element. |
| 359 | + shr[tid] = reduce_init_lowest<T>(); |
305 | 360 | // Elements summed in this block range from dst_id * el_to_sum_per_block |
306 | 361 | // to (dst_id + 1) * el_to_sum_per_block. |
307 | 362 | size_t start_idx = dst_id * el_to_sum_per_block; |
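
Note (not part of the diff): the hunk above only shows the initialization of shr[tid], so here is a rough, generic sketch of the block-level max reduction such an identity typically feeds into. It is an assumption about the surrounding code, not the actual fast_max body; block_max_sketch, BLOCK_SIZE_SKETCH, and el_per_block are illustrative names, and numeric_limits lowest() stands in for reduce_init_lowest<T>().

    // Strided per-thread scan followed by a shared-memory tree reduction.
    // Starting each slot at the lowest representable value means the first
    // comparison always keeps a real input element.
    #include <cstddef>
    #include <cstdint>
    #include <cuda/std/limits>

    constexpr int BLOCK_SIZE_SKETCH = 1024;

    template <typename T>
    __global__ void block_max_sketch(const T *src, T *dst,
                                     const size_t src_numel,
                                     const size_t el_per_block) {
        __shared__ T shr[BLOCK_SIZE_SKETCH];
        const size_t tid = threadIdx.x;
        const size_t dst_id = blockIdx.x;

        // Identity: lowest representable value of T (the file above uses -inf
        // for floating-point types instead).
        shr[tid] = ::cuda::std::numeric_limits<T>::lowest();

        // Each thread folds a strided slice of this block's range into shr[tid].
        const size_t start_idx = dst_id * el_per_block;
        const size_t stop_idx =
            start_idx + el_per_block < src_numel ? start_idx + el_per_block : src_numel;
        for (size_t i = start_idx + tid; i < stop_idx; i += blockDim.x) {
            shr[tid] = src[i] > shr[tid] ? src[i] : shr[tid];
        }
        __syncthreads();

        // Tree reduction over the per-thread partial maxima (assumes the block
        // is launched with BLOCK_SIZE_SKETCH threads, a power of two).
        for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
            if (tid < s) {
                shr[tid] = shr[tid + s] > shr[tid] ? shr[tid + s] : shr[tid];
            }
            __syncthreads();
        }
        if (tid == 0) {
            dst[dst_id] = shr[0];
        }
    }

A launch such as block_max_sketch<int64_t><<<num_blocks, BLOCK_SIZE_SKETCH>>>(src, dst, src_numel, el_per_block) would then write one maximum per block; the point of the change is only that the starting value in shr is now well-defined for integer T.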
@@ -339,7 +394,9 @@ fast_min(const size_t src_numel, const size_t el_to_sum_per_block, |
339 | 394 | size_t tid = threadIdx.x; |
340 | 395 | size_t dst_id = blockIdx.x; |
341 | 396 |
342 | | - shr[tid] = INFINITY; |
| 397 | + // Initialize with the highest representable value for T so that the first |
| 398 | + // comparison in the reduction always picks a real element. |
| 399 | + shr[tid] = reduce_init_highest<T>(); |
343 | 400 | // Elements summed in this block range from dst_id * el_to_sum_per_block |
344 | 401 | // to (dst_id + 1) * el_to_sum_per_block. |
345 | 402 | size_t start_idx = dst_id * el_to_sum_per_block; |
@@ -378,8 +435,9 @@ fast_argmin(const size_t src_numel, const size_t el_to_sum_per_block, |
378 | 435 | size_t tid = threadIdx.x; |
379 | 436 | size_t dst_id = blockIdx.x; |
380 | 437 |
381 | | - // Not sure how that works on uint32_t and uint8_t but it seems to do ok. |
382 | | - shr[tid] = INFINITY; |
| 438 | + // For floating types this uses +inf; for integer types we use the largest |
| 439 | + // representable value instead of casting INFINITY to an integer. |
| 440 | + shr[tid] = reduce_init_highest<T>(); |
383 | 441 | shr_index[tid] = 0xFFFFFFFF; |
384 | 442 | bool not_set = true; |
385 | 443 | // Elements summed in this block range from dst_id * el_to_sum_per_block |
@@ -427,7 +485,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block, |
427 | 485 | size_t tid = threadIdx.x; |
428 | 486 | size_t dst_id = blockIdx.x; |
429 | 487 |
430 | | - shr[tid] = -INFINITY; |
| 488 | + // For floating types this uses -inf; for integer types we use the lowest |
| 489 | + // representable value instead of casting -INFINITY to an integer. |
| 490 | + shr[tid] = reduce_init_lowest<T>(); |
431 | 491 | shr_index[tid] = 0xFFFFFFFF; |
432 | 492 | bool not_set = true; |
433 | 493 | // Elements summed in this block range from dst_id * el_to_sum_per_block |
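
Note (not part of the diff): the argmin/argmax hunks likewise only show how the value slot and the 0xFFFFFFFF index sentinel are initialized. Below is a generic sketch, under the same assumptions as the previous sketch, of how a value+index pair is usually carried through the scan and the tree reduction; argmax_pair_sketch and its parameter names are illustrative and it is not the actual fast_argmax body.

    // The per-thread scan records both the best value and where it came from;
    // the tree reduction then moves the index together with the winning value.
    // 0xFFFFFFFF marks a slot whose thread never saw an element.
    #include <cstddef>
    #include <cstdint>
    #include <cuda/std/limits>

    constexpr int BLOCK_SIZE_SKETCH = 1024;

    template <typename T>
    __global__ void argmax_pair_sketch(const T *src, uint32_t *dst_index,
                                       const size_t src_numel,
                                       const size_t el_per_block) {
        __shared__ T shr[BLOCK_SIZE_SKETCH];
        __shared__ uint32_t shr_index[BLOCK_SIZE_SKETCH];
        const size_t tid = threadIdx.x;
        const size_t dst_id = blockIdx.x;

        // Identity: lowest representable value plus the index sentinel.
        shr[tid] = ::cuda::std::numeric_limits<T>::lowest();
        shr_index[tid] = 0xFFFFFFFFu;

        const size_t start_idx = dst_id * el_per_block;
        const size_t stop_idx =
            start_idx + el_per_block < src_numel ? start_idx + el_per_block : src_numel;
        for (size_t i = start_idx + tid; i < stop_idx; i += blockDim.x) {
            // The sentinel check guarantees the first element is always taken,
            // even when it happens to equal the identity value.
            if (shr_index[tid] == 0xFFFFFFFFu || src[i] > shr[tid]) {
                shr[tid] = src[i];
                shr_index[tid] = static_cast<uint32_t>(i - start_idx);
            }
        }
        __syncthreads();

        // Tree reduction: slots still holding the sentinel lose against any
        // real element; otherwise the larger value (and its index) wins.
        for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
            if (tid < s && shr_index[tid + s] != 0xFFFFFFFFu &&
                (shr_index[tid] == 0xFFFFFFFFu || shr[tid + s] > shr[tid])) {
                shr[tid] = shr[tid + s];
                shr_index[tid] = shr_index[tid + s];
            }
            __syncthreads();
        }
        if (tid == 0) {
            dst_index[dst_id] = shr_index[0];
        }
    }

As with fast_max and fast_min, the only behavioural change in the diff is the starting value: for integer element types it is now numeric_limits lowest()/max() rather than the result of casting an infinity.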