Skip to content

Commit da40911

Browse files
More cleanup
1 parent 975c356 commit da40911

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

bitsandbytes/autograd/_functions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def forward(
228228
subA = None
229229

230230
# 3. Int8 Matmul + Dequant + Bias
231-
output = torch.ops.bitsandbytes.int8_scaled_mm(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)
231+
output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)
232232

233233
# 4. Mixed-precision decomposition matmul
234234
if subA is not None and state.subB is not None:
@@ -278,7 +278,7 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
278278
if req_gradB:
279279
Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))
280280

281-
grad_B = torch.ops.bitsandbytes.int8_scaled_mm(
281+
grad_B = torch.ops.bitsandbytes.int8_scaled_mm.default(
282282
Cgrad.t().contiguous(),
283283
CAt.t(),
284284
SCgradt,

bitsandbytes/backends/cuda/ops.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,11 @@ def _(
170170
A: torch.Tensor,
171171
threshold=0.0,
172172
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
173-
# TODO: Optimize/write CUDA kernel for this?
174-
175173
# Use CUDA kernel for rowwise and COO tensor
176-
quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant(A, threshold=threshold)
174+
quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default(
175+
A,
176+
threshold=threshold,
177+
)
177178

178179
# PyTorch impl for colwise
179180
col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold)

bitsandbytes/functional.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ def dequantize_blockwise(
873873
)
874874
return out
875875

876-
return torch.ops.bitsandbytes.dequantize_blockwise(
876+
return torch.ops.bitsandbytes.dequantize_blockwise.default(
877877
A,
878878
absmax,
879879
quant_state.code.to(A.device),
@@ -2238,7 +2238,7 @@ def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor):
22382238
`torch.Tensor` with dtype `torch.float32`: The dequantized tensor.
22392239
"""
22402240
# To dequantize we divide by 127, or multiply by the reciprocal.
2241-
return torch.ops.bitsandbytes.int8_vectorwise_dequant(A, stats)
2241+
return torch.ops.bitsandbytes.int8_vectorwise_dequant.default(A, stats)
22422242

22432243

22442244
def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):

0 commit comments

Comments (0)