@@ -4,12 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-try:
-    import bitsandbytes as bnb
-
-    bnb_installed = True
-except ImportError:
-    bnb_installed = False
import pytest
import torch
from torchao.dtypes.nf4tensor import NF4Tensor
@@ -22,19 +16,6 @@ def random():
    set_seed(31)


-def _build_bnb_linear(input_weight):
-    """
-    Builds a bnb.nn.LinearNF4 from a given input weight
-    """
-    param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4")
-    bnb_linear = bnb.nn.LinearNF4(
-        input_weight.size(0), input_weight.size(1), bias=False
-    )
-    bnb_linear.weight = param
-    bnb_linear.cuda()
-    return bnb_linear
-
-
class TestNF4Linear:
    """
    Class for testing our NF4Linear implementation.
@@ -88,18 +69,29 @@ def test_backward_dtype(self, dtype):
        assert inp.grad is not None and inp.grad.dtype == dtype
        assert nf4_linear.weight.grad is None

-    @pytest.mark.skipif(not bnb_installed, reason="bitsandbytes is not installed")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
    def test_nf4_reconstruction_vs_bnb(self, dtype):
        """
        Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when
        reconstructing the respective original weights.
        """
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            pytest.skip("bitsandbytes is not installed")
+            return
+
        dim = 512
        nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype)
        orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
-        bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight)
+
+        param = bnb.nn.Params4bit(orig_weight, requires_grad=False, quant_type="nf4")
+        bnb_nf4_linear = bnb.nn.LinearNF4(
+            orig_weight.size(0), orig_weight.size(1), bias=False
+        )
+        bnb_nf4_linear.weight = param
+        bnb_nf4_linear.cuda()

        # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65
        bnb_reconstruction = bnb_nf4_linear(
@@ -110,18 +102,30 @@ def test_nf4_reconstruction_vs_bnb(self, dtype):
            bnb_reconstruction.T, nf4_linear.weight.get_original_weight(), 1e-2
        )

-    @pytest.mark.skipif(not bnb_installed, reason="bitsandbytes is not installed")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
    def test_nf4_bnb_linear(self, dtype):
        """
        This test ensures that nf4_linear is "no worse" than BNB by ensuring the
        error compared to a bf16 linear is not more than BNB's implementation.
        """
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            pytest.skip("bitsandbytes is not installed")
+            return
+
        dim = 512
        nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype)
        orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
-        bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight)
+
+        param = bnb.nn.Params4bit(orig_weight, requires_grad=False, quant_type="nf4")
+        bnb_nf4_linear = bnb.nn.LinearNF4(
+            orig_weight.size(0), orig_weight.size(1), bias=False
+        )
+        bnb_nf4_linear.weight = param
+        bnb_nf4_linear.cuda()
+
        bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=dtype)

        inp = torch.randn(2, 512, dtype=dtype, device="cuda")
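
A note on the guard introduced in these hunks: pytest ships a built-in helper, pytest.importorskip, which collapses the try/except/skip pattern into a single line and returns the imported module. A minimal sketch of the equivalent runtime guard; the test body below is illustrative, not part of this PR:

import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_with_bnb():
    # Skips the test at runtime if bitsandbytes cannot be imported,
    # mirroring the inlined try/except above; otherwise returns the module.
    bnb = pytest.importorskip("bitsandbytes")
    assert hasattr(bnb.nn, "LinearNF4")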
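For context on the reconstruction check: a bias-free linear computes y = x @ W.T, so feeding an identity matrix through the quantized layer returns the dequantized weight transposed, which is what the bnb_reconstruction.T comparison above relies on. A standalone sketch of that trick, assuming a CUDA machine with bitsandbytes installed; build_bnb_nf4 is a hypothetical helper mirroring the construction inlined in the tests:

import torch
import bitsandbytes as bnb


def build_bnb_nf4(weight: torch.Tensor) -> bnb.nn.LinearNF4:
    # Same construction as inlined in the tests above.
    param = bnb.nn.Params4bit(weight, requires_grad=False, quant_type="nf4")
    linear = bnb.nn.LinearNF4(weight.size(0), weight.size(1), bias=False)
    linear.weight = param
    linear.cuda()  # Params4bit quantizes to NF4 when moved to CUDA
    return linear


dim = 512
weight = torch.randn(dim, dim, dtype=torch.bfloat16)
nf4_linear = build_bnb_nf4(weight)

# y = x @ W.T, so an identity input yields the dequantized weight, transposed.
eye = torch.eye(dim, dtype=torch.bfloat16, device="cuda")
reconstructed = nf4_linear(eye).T

# NF4 is 4-bit, so expect small but nonzero error (the tests allow ~1e-2).
print((reconstructed - weight.cuda()).abs().max())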