Skip to content

Commit 3d595f1

Browse files
test improvement
1 parent d25ebb4 commit 3d595f1

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

bitsandbytes/functional.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2656,7 +2656,7 @@ def double_quant(
26562656
threshold=threshold,
26572657
)
26582658

2659 -        if threshold > 0.0:
2659 +        if threshold > 0.0 and outlier_cols is not None:
26602660
# Build a COO tensor including all of the outlier columns.
26612661
outlier_rows = torch.arange(0, A.shape[0], device=A.device, dtype=torch.int32)
26622662
outliers = A[:, outlier_cols]

tests/test_functional.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -703,17 +703,16 @@ def test_coo_double_quant(dim1, dim2):
703703
A = torch.randn(dim1, dim2, device="cuda").half()
704704

705705
idx = torch.abs(A) >= threshold
706-
CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)
706+
CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
707707

708-
if coo_tensor is not None:
708+
if outlier_cols is not None:
709709
A1 = A * idx
710-
A2 = torch.zeros_like(A)
711-
A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
710+
A2 = torch.zeros_like(A) + A1
712711
torch.testing.assert_close(A1, A2)
713712

714-
A1 = A * (idx == 0)
713+
A[:, outlier_cols] = 0
715714
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
716-
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
715+
torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2)
717716

718717

719718
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@@ -728,6 +727,7 @@ def test_coo_int8_vectorwise_quant(dim1, dim2):
728727

729728
if outlier_cols is not None:
730729
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
730+
A[:, outlier_cols] = 0
731731
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
732732

733733

tests/test_modules.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,8 @@ def test_linear8bitlt_accumulated_gradient():
349349
l1[0].bias.data.copy_(l2[0].bias.data)
350350
l1[1].bias.data.copy_(l2[1].bias.data)
351351
else:
352 -            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, rtol=1.05, atol=0.04)
353 -            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.00, atol=0.02)
352 +            assert_all_approx_close(l1[0].weight.grad, l2[0].weight.grad, rtol=1.05, atol=0.04, count=1)
353 +            assert_all_approx_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.05, atol=0.04, count=1)
354354

355355

356356
@pytest.mark.parametrize("threshold", [0.0, 2.0])

0 commit comments

Comments (0)