@@ -67,13 +67,15 @@ def test_linear_no_igemmlt(device):
67
67
68
68
@pytest .mark .parametrize ("device" , get_available_devices ())
69
69
@pytest .mark .parametrize ("has_fp16_weights" , TRUE_FALSE , ids = id_formatter ("has_fp16_weights" ))
70
+ @pytest .mark .parametrize ("threshold" , [0.0 , 6.0 ], ids = id_formatter ("threshold" ))
70
71
@pytest .mark .parametrize ("serialize_before_forward" , TRUE_FALSE , ids = id_formatter ("serialize_before_forward" ))
71
72
@pytest .mark .parametrize ("deserialize_before_cuda" , TRUE_FALSE , ids = id_formatter ("deserialize_before_cuda" ))
72
73
@pytest .mark .parametrize ("save_before_forward" , TRUE_FALSE , ids = id_formatter ("save_before_forward" ))
73
74
@pytest .mark .parametrize ("load_before_cuda" , TRUE_FALSE , ids = id_formatter ("load_before_cuda" ))
74
75
def test_linear_serialization (
75
76
device ,
76
77
has_fp16_weights ,
78
+ threshold ,
77
79
serialize_before_forward ,
78
80
deserialize_before_cuda ,
79
81
save_before_forward ,
@@ -92,7 +94,7 @@ def test_linear_serialization(
92
94
linear .out_features ,
93
95
linear .bias is not None ,
94
96
has_fp16_weights = has_fp16_weights ,
95
- threshold = 6.0 ,
97
+ threshold = threshold ,
96
98
)
97
99
98
100
linear_custom .weight = bnb .nn .Int8Params (
@@ -137,7 +139,7 @@ def test_linear_serialization(
137
139
linear .out_features ,
138
140
linear .bias is not None ,
139
141
has_fp16_weights = has_fp16_weights ,
140
- threshold = 6.0 ,
142
+ threshold = threshold ,
141
143
)
142
144
143
145
if deserialize_before_cuda :
0 commit comments