Skip to content

Commit 4fc576d

Browse files
Small bugfix for int8 test
1 parent bbee99a commit 4fc576d

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

Diff for: bitsandbytes/autograd/_functions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
284284
dtype=torch.float16,
285285
)
286286

287-
if state.threshold > 0.0 and subA is not None:
287+
if state.threshold > 0.0 and subA is not None and subA.numel() > 0:
288288
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
289289

290290
if req_gradA:

Diff for: tests/test_linear8bitlt.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,15 @@ def test_linear_no_igemmlt(device):
6767

6868
@pytest.mark.parametrize("device", get_available_devices())
6969
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
70+
@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
7071
@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
7172
@pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
7273
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
7374
@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
7475
def test_linear_serialization(
7576
device,
7677
has_fp16_weights,
78+
threshold,
7779
serialize_before_forward,
7880
deserialize_before_cuda,
7981
save_before_forward,
@@ -92,7 +94,7 @@ def test_linear_serialization(
9294
linear.out_features,
9395
linear.bias is not None,
9496
has_fp16_weights=has_fp16_weights,
95-
threshold=6.0,
97+
threshold=threshold,
9698
)
9799

98100
linear_custom.weight = bnb.nn.Int8Params(
@@ -137,7 +139,7 @@ def test_linear_serialization(
137139
linear.out_features,
138140
linear.bias is not None,
139141
has_fp16_weights=has_fp16_weights,
140-
threshold=6.0,
142+
threshold=threshold,
141143
)
142144

143145
if deserialize_before_cuda:

0 commit comments

Comments
 (0)