Skip to content

Commit 5926056

Browse files
committed
Merge branch 'update-pyo3' of github.com:huggingface/candle into update-pyo3
2 parents 6c6f713 + d78cc0f commit 5926056

File tree

2 files changed

+70
-10
lines changed

2 files changed

+70
-10
lines changed

candle-core/src/quantized/metal.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ impl QMetalStorage {
3838

3939
let buffer = self.device.allocate_buffer(self.buffer.length())?;
4040
let blit = self.device.blit_command_encoder()?;
41-
blie.set_label("blit_to_cpu")?;
41+
blit.set_label("blit_to_cpu");
4242
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
4343
blit.end_encoding();
4444
self.device.wait_until_completed()?;

candle-kernels/src/reduce.cu

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,63 @@
11
#include "cuda_utils.cuh"
22
#include <cmath>
33
#include <stdint.h>
4+
#include <cuda/std/limits>
45

56
#define WARP_SIZE 32
67
const int BLOCK_SIZE = 1024;
78

9+
// Helpers to initialize reduction identities for both floating-point and
10+
// integer types. For floats we keep using +/-INFINITY, while for integers
11+
// we use well-defined numeric_limits values instead of relying on casting
12+
// +/-INFINITY to an integer type (which is undefined behaviour and has been
13+
// observed to break on newer GPU architectures such as Blackwell).
14+
template <typename T>
15+
__device__ __forceinline__ T reduce_init_lowest() {
16+
// Default implementation is used for floating-point types (__half,
17+
// __nv_bfloat16, float, double). The conversion from -INFINITY (float)
18+
// to these types is well-defined and produces -inf.
19+
return -INFINITY;
20+
}
21+
22+
template <typename T>
23+
__device__ __forceinline__ T reduce_init_highest() {
24+
// Default implementation is used for floating-point types (__half,
25+
// __nv_bfloat16, float, double). The conversion from INFINITY (float)
26+
// to these types is well-defined and produces +inf.
27+
return INFINITY;
28+
}
29+
30+
// Integer specializations – use numeric_limits instead of +/-INFINITY.
31+
template <>
32+
__device__ __forceinline__ int64_t reduce_init_lowest<int64_t>() {
33+
return ::cuda::std::numeric_limits<int64_t>::lowest();
34+
}
35+
36+
template <>
37+
__device__ __forceinline__ uint32_t reduce_init_lowest<uint32_t>() {
38+
return ::cuda::std::numeric_limits<uint32_t>::lowest();
39+
}
40+
41+
template <>
42+
__device__ __forceinline__ uint8_t reduce_init_lowest<uint8_t>() {
43+
return ::cuda::std::numeric_limits<uint8_t>::lowest();
44+
}
45+
46+
template <>
47+
__device__ __forceinline__ int64_t reduce_init_highest<int64_t>() {
48+
return ::cuda::std::numeric_limits<int64_t>::max();
49+
}
50+
51+
template <>
52+
__device__ __forceinline__ uint32_t reduce_init_highest<uint32_t>() {
53+
return ::cuda::std::numeric_limits<uint32_t>::max();
54+
}
55+
56+
template <>
57+
__device__ __forceinline__ uint8_t reduce_init_highest<uint8_t>() {
58+
return ::cuda::std::numeric_limits<uint8_t>::max();
59+
}
60+
861
// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32
962
// but also expect a f32 output so that this can be used for normalization e.g.
1063
// in softmax.
@@ -102,29 +155,29 @@ __device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta,
102155

103156
if (alpha == nullptr && beta == nullptr) {
104157
for (int col = tid; col < ncols; col += block_size) {
105-
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
158+
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
106159
dst[row*ncols + col] = static_cast<T>(lhs);
107160
}
108161
}
109162
else if (alpha == nullptr && beta != nullptr) {
110163
for (int col = tid; col < ncols; col += block_size) {
111164
float b = static_cast<float>(beta[col]);
112-
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
165+
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
113166
dst[row*ncols + col] = static_cast<T>(lhs + b);
114167
}
115168
}
116169
else if (alpha != nullptr && beta == nullptr) {
117170
for (int col = tid; col < ncols; col += block_size) {
118171
float a = static_cast<float>(alpha[col]);
119-
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
172+
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
120173
dst[row*ncols + col] = static_cast<T>(lhs * a);
121174
}
122175
}
123176
else {
124177
for (int col = tid; col < ncols; col += block_size) {
125178
float a = static_cast<float>(alpha[col]);
126179
float b = static_cast<float>(beta[col]);
127-
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
180+
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
128181
dst[row*ncols + col] = static_cast<T>(lhs * a + b);
129182
}
130183
}
@@ -301,7 +354,9 @@ fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
301354
size_t tid = threadIdx.x;
302355
size_t dst_id = blockIdx.x;
303356

304-
shr[tid] = -INFINITY;
357+
// Initialize with the lowest representable value for T so that the first
358+
// comparison in the reduction always picks a real element.
359+
shr[tid] = reduce_init_lowest<T>();
305360
// Elements summed in this block range from dst_id * el_to_sum_per_block
306361
// to (dst_id + 1) * el_to_sum_per_block.
307362
size_t start_idx = dst_id * el_to_sum_per_block;
@@ -339,7 +394,9 @@ fast_min(const size_t src_numel, const size_t el_to_sum_per_block,
339394
size_t tid = threadIdx.x;
340395
size_t dst_id = blockIdx.x;
341396

342-
shr[tid] = INFINITY;
397+
// Initialize with the highest representable value for T so that the first
398+
// comparison in the reduction always picks a real element.
399+
shr[tid] = reduce_init_highest<T>();
343400
// Elements summed in this block range from dst_id * el_to_sum_per_block
344401
// to (dst_id + 1) * el_to_sum_per_block.
345402
size_t start_idx = dst_id * el_to_sum_per_block;
@@ -378,8 +435,9 @@ fast_argmin(const size_t src_numel, const size_t el_to_sum_per_block,
378435
size_t tid = threadIdx.x;
379436
size_t dst_id = blockIdx.x;
380437

381-
// Not sure how that works on uint32_t and uint8_t but it seems to do ok.
382-
shr[tid] = INFINITY;
438+
// For floating types this uses +inf; for integer types we use the largest
439+
// representable value instead of casting INFINITY to an integer.
440+
shr[tid] = reduce_init_highest<T>();
383441
shr_index[tid] = 0xFFFFFFFF;
384442
bool not_set = true;
385443
// Elements summed in this block range from dst_id * el_to_sum_per_block
@@ -427,7 +485,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
427485
size_t tid = threadIdx.x;
428486
size_t dst_id = blockIdx.x;
429487

430-
shr[tid] = -INFINITY;
488+
// For floating types this uses -inf; for integer types we use the lowest
489+
// representable value instead of casting -INFINITY to an integer.
490+
shr[tid] = reduce_init_lowest<T>();
431491
shr_index[tid] = 0xFFFFFFFF;
432492
bool not_set = true;
433493
// Elements summed in this block range from dst_id * el_to_sum_per_block

0 commit comments

Comments
 (0)