@@ -71,8 +71,11 @@ def test_linear8bitlt_inference(device, threshold):
 
 
 # TODO: Remove support for training int8 weights
-@pytest.mark.deprecated
+@pytest.mark.parametrize("device", get_available_devices())
 def test_linear8bitlt_accumulated_gradient(device):
+    if device != "cuda":
+        pytest.skip("Only supported on CUDA")
+
     l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).to(device).half() for i in range(2)])
     l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).to(device).half() for i in range(2)])
     l1[0].weight.data.copy_(l2[0].weight.data)
@@ -114,56 +117,60 @@ def test_linear8bitlt_accumulated_gradient(device):
             assert_all_approx_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.05, atol=0.04, count=1)
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("threshold", [0.0, 2.0])
-def test_linear8bitlt_no_fp16_weights(threshold):
+def test_linear8bitlt_no_fp16_weights(device, threshold):
+    if device == "cpu":
+        pytest.xfail("Not yet supported on CPU")
+
     l1 = (
         bnb.nn.Linear8bitLt(
             32,
             64,
             threshold=threshold,
             has_fp16_weights=False,
         )
-        .cuda()
+        .to(device)
         .half()
     )
     assert l1.weight.dtype == torch.int8
 
     l1.eval()
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = l1(b1)
         assert o1.dtype == torch.float16
 
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device)
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
             assert mlp.fc1.state.idx is not None
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
 
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half()
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
             assert mlp.fc1.state.idx is not None
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
 
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device)
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
@@ -181,11 +188,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
             has_fp16_weights=False,
         )
         .half()
-        .to("cuda")
+        .to(device)
     )
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
@@ -194,20 +201,20 @@ def test_linear8bitlt_no_fp16_weights(threshold):
             assert mlp.fc2.state.idx is not None
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
-    assert mlp.fc1.weight.device.type == "cuda"
-    assert mlp.fc2.weight.device.type == "cuda"
+    assert mlp.fc1.weight.device.type == device
+    assert mlp.fc2.weight.device.type == device
 
     mlp = MLP8bit(
         32,
         64,
         threshold=threshold,
         has_fp16_weights=False,
     )
-    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
+    w1, w2 = mlp.fc1.weight.clone().to(device), mlp.fc2.weight.clone().to(device)  # grab weights before quantization,
     mlp = mlp.cuda().half()  # and this line triggers quantization
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
         if threshold > 0:
@@ -217,10 +224,10 @@ def test_linear8bitlt_no_fp16_weights(threshold):
 
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
-    assert mlp.fc1.weight.device.type == "cuda"
-    assert mlp.fc2.weight.device.type == "cuda"
+    assert mlp.fc1.weight.device.type == device
+    assert mlp.fc2.weight.device.type == device
 
-    b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+    b1 = torch.randn(16, 8, 32, device=device, requires_grad=True, dtype=torch.half)
     o1 = mlp(b1)
     assert o1.dtype == torch.float16
     assert o1.requires_grad
@@ -236,33 +243,37 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert (idx == 0).sum().item() <= b1.numel() * 0.005
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize(
     "module",
     [
         lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
-        bnb.nn.LinearFP4,
+        bnb.nn.LinearNF4,
     ],
-    ids=["Int8Lt", "FP4"],
+    ids=["Int8Lt", "NF4"],
 )
-def test_linear_kbit_fp32_bias(module):
+def test_linear_kbit_fp32_bias(device, module):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     # casts model to fp16 -> int8 automatically
-    l1 = module(32, 64).cuda()
+    l1 = module(32, 64).to(device)
     assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias.dtype == torch.float32
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         # casts bias to fp32
         o1 = l1(b1)
         assert l1.bias.dtype == torch.float16
 
     # casts model to fp16 -> int8 automatically
-    l1 = module(32, 64, bias=False).cuda()
+    l1 = module(32, 64, bias=False).to(device)
     assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias is None
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
         o1 = l1(b1)
         assert l1.bias is None
 
@@ -280,8 +291,12 @@ def test_linear_kbit_fp32_bias(module):
 }
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
-def test_kbit_backprop(module):
+def test_kbit_backprop(device, module):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     b = 16
     dim1 = 36
     dim2 = 84
@@ -297,16 +312,16 @@ def test_kbit_backprop(module):
     kbit[1].weight.detach().copy_(ref[1].weight)
     kbit[0].bias.detach().copy_(ref[0].bias)
     kbit[1].bias.detach().copy_(ref[1].bias)
-    ref = ref.half().cuda()
-    kbit = kbit.half().cuda()
-    kbit = kbit.half().to("cuda")
+    ref = ref.half().to(device)
+    kbit = kbit.half().to(device)
+    kbit = kbit.half().to(device)
 
     errs1 = []
     errs2 = []
     relerrs1 = []
     relerrs2 = []
     for i in range(100):
-        batch = torch.randn(b, dim1).half().cuda()
+        batch = torch.randn(b, dim1, device=device, dtype=torch.float16)
         out1 = ref(batch)
         out2 = kbit(batch)
         out1.mean().backward()
@@ -339,6 +354,7 @@ def test_kbit_backprop(module):
     assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
 
 
+@pytest.mark.deprecated
 def test_fp8linear():
     b = 10
     h = 1024
@@ -369,6 +385,7 @@ def test_fp8linear():
     assert bgraderr < 0.00002
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("embedding_dim", [64, 65])
 @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
 @pytest.mark.parametrize(
@@ -382,7 +399,10 @@ def test_fp8linear():
     ],
     ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
 )
-def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage):
+def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage):
+    if device == "cpu":
+        pytest.xfail("Not yet supported on CPU")
+
     num_embeddings = 128
 
     src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
@@ -402,17 +422,18 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
 
     e.load_state_dict(emb_base.state_dict())
 
-    emb_base.cuda()
-    e.cuda()
+    emb_base.to(device)
+    e.to(device)
 
-    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device)
 
     torch.testing.assert_close(
         actual=e(input_tokens),
         expected=emb_base(input_tokens),
     )
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("embedding_dim", [64, 65])
 @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
 @pytest.mark.parametrize(
@@ -426,7 +447,10 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
     ],
     ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
 )
-def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_storage):
+def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage):
+    if device == "cpu":
+        pytest.xfail("Not yet supported on CPU")
+
     is_8bit = embedding_class is bnb.nn.Embedding8bit
 
     num_embeddings = 128
@@ -446,10 +470,10 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor
 
     e.load_state_dict(emb_base.state_dict())
 
-    emb_base.cuda()
-    e.cuda()
+    emb_base.to(device)
+    e.to(device)
 
-    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device)
 
     torch.testing.assert_close(
         actual=e(input_tokens),
@@ -459,46 +483,64 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor
     )
 
 
-def test_4bit_linear_warnings():
+@pytest.mark.parametrize("device", get_available_devices())
+def test_4bit_linear_warnings(device):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     dim1 = 64
 
     with pytest.warns(UserWarning, match=r"inference or training"):
-        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
-        net = net.cuda()
-        inp = torch.rand(10, dim1).cuda().half()
+        net = nn.Sequential(
+            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
+        )
+        net = net.to(device)
+        inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
         net(inp)
     with pytest.warns(UserWarning, match=r"inference."):
-        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
-        net = net.cuda()
-        inp = torch.rand(1, dim1).cuda().half()
+        net = nn.Sequential(
+            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
+        )
+        net = net.to(device)
+        inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
         net(inp)
 
     with pytest.warns(UserWarning) as record:
-        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
-        net = net.cuda()
-        inp = torch.rand(10, dim1).cuda().half()
+        net = nn.Sequential(
+            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
+        )
+        net = net.to(device)
+        inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
         net(inp)
 
-        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
-        net = net.cuda()
-        inp = torch.rand(1, dim1).cuda().half()
+        net = nn.Sequential(
+            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
+        )
+        net = net.to(device)
+        inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
         net(inp)
 
     assert len(record) == 2
 
 
-def test_4bit_embedding_warnings():
+@pytest.mark.parametrize("device", get_available_devices())
+def test_4bit_embedding_warnings(device):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     num_embeddings = 128
     default_block_size = 64
 
     with pytest.warns(UserWarning, match=r"inference."):
-        net = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=default_block_size + 1)
-        net.cuda()
-        inp = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
+        net = bnb.nn.Embedding4bit(
+            num_embeddings=num_embeddings, embedding_dim=default_block_size + 1, quant_type="nf4"
+        )
+        net.to(device)
+        inp = torch.randint(low=0, high=num_embeddings, size=(1,), device=device)
         net(inp)
 
 
-def test_4bit_embedding_weight_fsdp_fix():
+def test_4bit_embedding_weight_fsdp_fix(requires_cuda):
     num_embeddings = 64
     embedding_dim = 32
 
@@ -515,7 +557,7 @@ def test_4bit_embedding_weight_fsdp_fix():
     assert module.weight.quant_state is not None
 
 
-def test_4bit_linear_weight_fsdp_fix():
+def test_4bit_linear_weight_fsdp_fix(requires_cuda):
     inp_size = 64
     out_size = 32
 
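
Note: the device-parametrized tests above assume a get_available_devices() helper imported
from the repo's test utilities. A minimal sketch of what such a helper might look like
(hypothetical; the real helper in the bitsandbytes test suite may enumerate more backends):

    import torch

    def get_available_devices():
        # Always include the CPU path; add CUDA only when a GPU backend is usable.
        devices = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")
        return devices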