Skip to content

Commit 5b2348b

Browse files
Bugfix sparse decomp
1 parent 5d536c6 commit 5b2348b

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

bitsandbytes/functional.py

+5
Original file line number | Diff line number | Diff line change
@@ -2793,6 +2793,11 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
27932793
_get_tensor_stream(A),
27942794
)
27952795

2796+
# Zero out values from outlier columns across all rows.
2797+
# The kernel will handle this for outliers themselves, so we can optimize for rows=1.
2798+
if rows > 1 and outlier_cols is not None:
2799+
out_row[:, outlier_cols] = 0
2800+
27962801
return out_row, row_stats, outlier_cols
27972802

27982803

csrc/kernels.cu

+1 -1
Original file line number | Diff line number | Diff line change
@@ -2145,7 +2145,7 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
21452145

21462146
// For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
21472147
// Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped.
2148-
#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE && __CUDACC__
2148+
#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE
21492149
using TReduction = T;
21502150
#else
21512151
using TReduction = float;

0 commit comments

Comments
 (0)