
Commit ffee75c

Additional device agnosticism for tests
1 parent 383b62d commit ffee75c
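
The change follows one pattern throughout: each affected test is parametrized over the available backends and takes an explicit device argument, instead of hard-coding .cuda() calls and device="cuda" tensors. The sketch below illustrates the idea; the helper name get_available_devices() comes from the diff, but this implementation of it is an assumption for illustration, not the repository's actual test helper.

# Illustrative sketch only: the real get_available_devices() used by the test
# suite may be implemented differently.
import pytest
import torch


def get_available_devices():
    # Report every backend the current environment can actually run on.
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    return devices


@pytest.mark.parametrize("device", get_available_devices())
def test_example_device_agnostic(device):
    # CUDA-only functionality skips early on other backends.
    if device != "cuda":
        pytest.skip("Only supported on CUDA")
    x = torch.randn(4, 4, device=device, dtype=torch.float16)
    assert x.device.type == device

In the diff itself, pytest.skip is used where a test can only run on CUDA, while pytest.xfail marks paths that are expected to gain CPU support later but are not implemented yet.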

File tree

2 files changed: +113, -82 lines changed

Diff for: tests/test_modules.py

+100, -58 lines changed
@@ -71,8 +71,11 @@ def test_linear8bitlt_inference(device, threshold):
 
 
 # TODO: Remove support for training int8 weights
-@pytest.mark.deprecated
+@pytest.mark.parametrize("device", get_available_devices())
 def test_linear8bitlt_accumulated_gradient(device):
+    if device != "cuda":
+        pytest.skip("Only supported on CUDA")
+
     l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).to(device).half() for i in range(2)])
     l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).to(device).half() for i in range(2)])
     l1[0].weight.data.copy_(l2[0].weight.data)
@@ -114,56 +117,60 @@ def test_linear8bitlt_accumulated_gradient(device):
             assert_all_approx_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.05, atol=0.04, count=1)
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("threshold", [0.0, 2.0])
-def test_linear8bitlt_no_fp16_weights(threshold):
+def test_linear8bitlt_no_fp16_weights(device, threshold):
+    if device == "cpu":
+        pytest.xfail("Not yet supported on CPU")
+
     l1 = (
         bnb.nn.Linear8bitLt(
             32,
             64,
             threshold=threshold,
             has_fp16_weights=False,
         )
-        .cuda()
+        .to(device)
         .half()
     )
     assert l1.weight.dtype == torch.int8
 
     l1.eval()
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = l1(b1)
         assert o1.dtype == torch.float16
 
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device)
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
            assert mlp.fc1.state.idx is not None
         if threshold > 0:
            assert mlp.fc2.state.idx is not None
 
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half()
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
            assert mlp.fc1.state.idx is not None
         if threshold > 0:
            assert mlp.fc2.state.idx is not None
 
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device)
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
@@ -181,11 +188,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
             has_fp16_weights=False,
         )
         .half()
-        .to("cuda")
+        .to(device)
     )
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
@@ -194,20 +201,20 @@ def test_linear8bitlt_no_fp16_weights(threshold):
             assert mlp.fc2.state.idx is not None
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
-    assert mlp.fc1.weight.device.type == "cuda"
-    assert mlp.fc2.weight.device.type == "cuda"
+    assert mlp.fc1.weight.device.type == device
+    assert mlp.fc2.weight.device.type == device
 
     mlp = MLP8bit(
         32,
         64,
         threshold=threshold,
         has_fp16_weights=False,
     )
-    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
+    w1, w2 = mlp.fc1.weight.clone().to(device), mlp.fc2.weight.clone().to(device)  # grab weights before quantization,
     mlp = mlp.cuda().half()  # and this line triggers quantization
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
@@ -217,10 +224,10 @@ def test_linear8bitlt_no_fp16_weights(threshold):
 
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
-    assert mlp.fc1.weight.device.type == "cuda"
-    assert mlp.fc2.weight.device.type == "cuda"
+    assert mlp.fc1.weight.device.type == device
+    assert mlp.fc2.weight.device.type == device
 
-    b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+    b1 = torch.randn(16, 8, 32, device=device, requires_grad=True, dtype=torch.half)
     o1 = mlp(b1)
     assert o1.dtype == torch.float16
     assert o1.requires_grad
@@ -236,33 +243,37 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert (idx == 0).sum().item() <= b1.numel() * 0.005
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize(
     "module",
     [
         lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
-        bnb.nn.LinearFP4,
+        bnb.nn.LinearNF4,
     ],
-    ids=["Int8Lt", "FP4"],
+    ids=["Int8Lt", "NF4"],
 )
-def test_linear_kbit_fp32_bias(module):
+def test_linear_kbit_fp32_bias(device, module):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     # casts model to fp16 -> int8 automatically
-    l1 = module(32, 64).cuda()
+    l1 = module(32, 64).to(device)
     assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias.dtype == torch.float32
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         # casts bias to fp32
         o1 = l1(b1)
         assert l1.bias.dtype == torch.float16
 
     # casts model to fp16 -> int8 automatically
-    l1 = module(32, 64, bias=False).cuda()
+    l1 = module(32, 64, bias=False).to(device)
     assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias is None
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = l1(b1)
         assert l1.bias is None
 
@@ -280,8 +291,12 @@ def test_linear_kbit_fp32_bias(module):
 }
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
-def test_kbit_backprop(module):
+def test_kbit_backprop(device, module):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     b = 16
     dim1 = 36
     dim2 = 84
@@ -297,16 +312,16 @@ def test_kbit_backprop(module):
     kbit[1].weight.detach().copy_(ref[1].weight)
     kbit[0].bias.detach().copy_(ref[0].bias)
     kbit[1].bias.detach().copy_(ref[1].bias)
-    ref = ref.half().cuda()
-    kbit = kbit.half().cuda()
-    kbit = kbit.half().to("cuda")
+    ref = ref.half().to(device)
+    kbit = kbit.half().to(device)
+    kbit = kbit.half().to(device)
 
     errs1 = []
     errs2 = []
     relerrs1 = []
     relerrs2 = []
     for i in range(100):
-        batch = torch.randn(b, dim1).half().cuda()
+        batch = torch.randn(b, dim1, device=device, dtype=torch.float16)
         out1 = ref(batch)
         out2 = kbit(batch)
         out1.mean().backward()
@@ -339,6 +354,7 @@ def test_kbit_backprop(module):
     assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
 
 
+@pytest.mark.deprecated
 def test_fp8linear():
     b = 10
     h = 1024
@@ -369,6 +385,7 @@ def test_fp8linear():
     assert bgraderr < 0.00002
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("embedding_dim", [64, 65])
 @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
 @pytest.mark.parametrize(
@@ -382,7 +399,10 @@ def test_fp8linear():
     ],
     ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
 )
-def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage):
+def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage):
+    if device == "cpu":
+        pytest.xfail("Not yet supported on CPU")
+
     num_embeddings = 128
 
     src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
@@ -402,17 +422,18 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
 
     e.load_state_dict(emb_base.state_dict())
 
-    emb_base.cuda()
-    e.cuda()
+    emb_base.to(device)
+    e.to(device)
 
-    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device)
 
     torch.testing.assert_close(
         actual=e(input_tokens),
         expected=emb_base(input_tokens),
     )
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("embedding_dim", [64, 65])
 @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
 @pytest.mark.parametrize(
@@ -426,7 +447,10 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
     ],
     ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
 )
-def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_storage):
+def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage):
+    if device == "cpu":
+        pytest.xfail("Not yet supported on CPU")
+
     is_8bit = embedding_class is bnb.nn.Embedding8bit
 
     num_embeddings = 128
@@ -446,10 +470,10 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor
 
     e.load_state_dict(emb_base.state_dict())
 
-    emb_base.cuda()
-    e.cuda()
+    emb_base.to(device)
+    e.to(device)
 
-    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device)
 
     torch.testing.assert_close(
         actual=e(input_tokens),
@@ -459,46 +483,64 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor
     )
 
 
-def test_4bit_linear_warnings():
+@pytest.mark.parametrize("device", get_available_devices())
+def test_4bit_linear_warnings(device):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     dim1 = 64
 
     with pytest.warns(UserWarning, match=r"inference or training"):
-        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
-        net = net.cuda()
-        inp = torch.rand(10, dim1).cuda().half()
+        net = nn.Sequential(
+            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
+        )
+        net = net.to(device)
+        inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
         net(inp)
     with pytest.warns(UserWarning, match=r"inference."):
-        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
-        net = net.cuda()
-        inp = torch.rand(1, dim1).cuda().half()
+        net = nn.Sequential(
+            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
+        )
+        net = net.to(device)
+        inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
         net(inp)
 
     with pytest.warns(UserWarning) as record:
-        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
-        net = net.cuda()
-        inp = torch.rand(10, dim1).cuda().half()
+        net = nn.Sequential(
+            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
+        )
+        net = net.to(device)
+        inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
        net(inp)
 
-        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
-        net = net.cuda()
-        inp = torch.rand(1, dim1).cuda().half()
+        net = nn.Sequential(
+            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
+        )
+        net = net.to(device)
+        inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
        net(inp)
 
     assert len(record) == 2
 
 
-def test_4bit_embedding_warnings():
+@pytest.mark.parametrize("device", get_available_devices())
+def test_4bit_embedding_warnings(device):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     num_embeddings = 128
     default_block_size = 64
 
     with pytest.warns(UserWarning, match=r"inference."):
-        net = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=default_block_size + 1)
-        net.cuda()
-        inp = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
+        net = bnb.nn.Embedding4bit(
+            num_embeddings=num_embeddings, embedding_dim=default_block_size + 1, quant_type="nf4"
+        )
+        net.to(device)
+        inp = torch.randint(low=0, high=num_embeddings, size=(1,), device=device)
         net(inp)
 
 
-def test_4bit_embedding_weight_fsdp_fix():
+def test_4bit_embedding_weight_fsdp_fix(requires_cuda):
     num_embeddings = 64
     embedding_dim = 32
 
@@ -515,7 +557,7 @@ def test_4bit_embedding_weight_fsdp_fix():
     assert module.weight.quant_state is not None
 
 
-def test_4bit_linear_weight_fsdp_fix():
+def test_4bit_linear_weight_fsdp_fix(requires_cuda):
     inp_size = 64
     out_size = 32
 
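The two FSDP-fix tests at the end of the diff drop their module-level CUDA setup and instead accept a requires_cuda argument, presumably a fixture that skips the test when no CUDA device is present. A minimal sketch of such a fixture follows, assuming a conftest-style definition that may differ from the project's actual one.

# Hypothetical conftest.py sketch; the project's real requires_cuda fixture
# may be defined differently.
import pytest
import torch


@pytest.fixture
def requires_cuda():
    # Any test that lists this fixture as an argument is skipped when CUDA
    # is unavailable, with no changes needed in the test body itself.
    if not torch.cuda.is_available():
        pytest.skip("Test requires a CUDA device")
    return True

A test opts in simply by naming the fixture, as in def test_4bit_linear_weight_fsdp_fix(requires_cuda).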