# See the License for the specific language governing permissions and
# limitations under the License.

+ import os
import torch
from ..utils import DTYPE_BITS_MAPPING
from functools import reduce
from peft.tuners.lora import LoraLayer, LoraModel
from peft.utils.other import transpose
from intel_extension_for_transformers.transformers.llm.quantization.autograd import (
-     matmul_kbit,
- )
+     matmul_kbit, )
import intel_extension_for_transformers.qbits as qbits  # pylint: disable=E0611, E0401


class DropoutQBits_(torch.autograd.Function):
+
    @staticmethod
    def forward(ctx, input, probability):
        mask = qbits.dropout_fwd(input, probability)
        if any(ctx.needs_input_grad[:1]):
-             ctx.tensors = (mask,)
+             ctx.tensors = (mask, )
        else:
-             ctx.tensors = (None,)
+             ctx.tensors = (None, )
        return input

    @staticmethod
@@ -51,6 +52,7 @@ def backward(ctx, grad_output):


class DropoutQBits(torch.nn.Module):
+
    def __init__(self, p=0.0):
        super().__init__()
        self.p = p
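Note on the two hunks above: DropoutQBits_ is a custom torch.autograd.Function whose forward draws a dropout mask via the qbits kernel and stashes it for the backward pass (elided by this diff), and DropoutQBits is the thin nn.Module wrapper around it. A rough pure-PyTorch sketch of the same custom-autograd dropout pattern, purely illustrative and not the qbits implementation (the real kernels handle masking and scaling internally):

import torch

class DropoutRef_(torch.autograd.Function):
    """Pure-PyTorch stand-in for the DropoutQBits_ pattern (illustrative only)."""

    @staticmethod
    def forward(ctx, input, probability):
        # Inverted dropout: drop with probability `probability`, rescale survivors.
        mask = (torch.rand_like(input) >= probability).to(input.dtype) / (1.0 - probability)
        ctx.save_for_backward(mask)
        return input * mask

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # Gradient flows only through kept elements; no gradient w.r.t. probability.
        return grad_output * mask, None

x = torch.randn(4, 4, requires_grad=True)
DropoutRef_.apply(x, 0.1).sum().backward()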
@@ -63,6 +65,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:


class ParamsQBits(torch.nn.Parameter):
+
    def __new__(
        cls,
        data=None,
@@ -87,6 +90,7 @@ def __new__(


class QuantizedLinearQBits(torch.nn.Linear):
+
    def __init__(
        self,
        input_features,
@@ -156,6 +160,9 @@ def forward(self, x: torch.Tensor):
        shape[-1] = self.out_features
        out = out.view(shape)

+         if os.environ.get("backend", None) == "use_vllm":
+             return out, None
+
        return out

    def set_fp_weights_bias(self, weight_data, bias=None):
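The added `use_vllm` branch above changes the return convention of forward: when the `backend` environment variable equals "use_vllm", the layer returns an `(output, None)` tuple instead of a bare tensor, so callers must unpack it. A minimal runnable sketch of that convention; FakeQuantLinear is a hypothetical stand-in for a constructed QuantizedLinearQBits and is not part of this diff:

import os
import torch

class FakeQuantLinear(torch.nn.Linear):
    """Hypothetical stand-in mimicking the new return convention of forward."""

    def forward(self, x):
        out = super().forward(x)
        if os.environ.get("backend", None) == "use_vllm":
            return out, None
        return out

layer = FakeQuantLinear(4, 8)
x = torch.randn(1, 4)

os.environ["backend"] = "use_vllm"
out, _ = layer(x)               # vLLM path: (output, None) tuple

os.environ.pop("backend", None)
out = layer(x)                  # default path: plain tensor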
@@ -264,33 +271,24 @@ def quant_weight_w_scale(self, weight, scale, zp, group_size=-1):
        if zp is not None:
            zp = zp.to(device)
        if group_size == -1:
-             return (
-                 weight.div_(scale).round_()
-                 if zp is None
-                 else weight.div_(scale).add_(zp).round_()
-             )
+             return (weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_())
        int_weight = torch.zeros(weight.shape).to(device)
        leng = weight.shape[1] // group_size
        tail_flag = False if weight.shape[1] % group_size == 0 else True
        for i in range(leng):
-             int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(
-                 scale[:, i].unsqueeze(1)
-             )
+             int_weight_tmp = weight[:, i * group_size:(i + 1) * group_size].div_(scale[:, i].unsqueeze(1))
            if zp is not None:
                int_weight_tmp.add_(zp[:, i].unsqueeze(1))
-             int_weight[:, i * group_size : (i + 1) * group_size].copy_(
-                 int_weight_tmp.round_()
-             )
+             int_weight[:, i * group_size:(i + 1) * group_size].copy_(int_weight_tmp.round_())
        if tail_flag:
-             int_weight_tmp = weight[:, leng * group_size :].div_(
-                 scale[:, -1].unsqueeze(1)
-             )
+             int_weight_tmp = weight[:, leng * group_size:].div_(scale[:, -1].unsqueeze(1))
            if zp is not None:
                int_weight_tmp.add_(zp[:, -1].unsqueeze(1))
-             int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_())
+             int_weight[:, leng * group_size:].copy_(int_weight_tmp.round_())
        return int_weight
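The quant_weight_w_scale changes above are a pure yapf reflow; the math is unchanged: each group of `group_size` columns is divided by its own scale, optionally shifted by a zero point, and rounded to integers. A self-contained toy example of that per-group computation (values and names are illustrative, and it avoids the in-place ops the real method uses):

import torch

weight = torch.tensor([[0.5, 1.0, -2.0, 4.0]])   # shape (1, 4)
scale = torch.tensor([[0.5, 2.0]])                # one scale per group of 2 columns
zp = torch.tensor([[8, 8]])                       # optional per-group zero point
group_size = 2

int_weight = torch.zeros_like(weight)
for i in range(weight.shape[1] // group_size):
    # divide the group by its scale, add its zero point, then round
    block = weight[:, i * group_size:(i + 1) * group_size] / scale[:, i].unsqueeze(1)
    int_weight[:, i * group_size:(i + 1) * group_size] = (block + zp[:, i].unsqueeze(1)).round()

print(int_weight)  # tensor([[ 9., 10.,  7., 10.]])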

    def recover_qparms(self):
+
        def recover_idx(ret_idx, k, blocksize):
            g_idx = torch.zeros(k, dtype=int)
            value_range = (k + blocksize - 1) // blocksize
@@ -328,18 +326,12 @@ def recover_int_weight(g_idx, int_weight):
        else:
            g_idx = None
        weight_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 6)
-         weight_dtype = "".join(
-             chr(ascii_code) for ascii_code in weight_dtype_ascii.tolist()
-         )
+         weight_dtype = "".join(chr(ascii_code) for ascii_code in weight_dtype_ascii.tolist())
        bits = 4 if weight_dtype in ["nf4", "int4_clip", "fp4", "int4_fullrange"] else 8
        compute_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 7)
-         compute_dtype = "".join(
-             chr(ascii_code) for ascii_code in compute_dtype_ascii.tolist()
-         )
+         compute_dtype = "".join(chr(ascii_code) for ascii_code in compute_dtype_ascii.tolist())
        scales_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 8)
-         scales_dtype = "".join(
-             chr(ascii_code) for ascii_code in scales_dtype_ascii.tolist()
-         )
+         scales_dtype = "".join(chr(ascii_code) for ascii_code in scales_dtype_ascii.tolist())
        if scales_dtype is None:
            assert False, "scales dtype only support fp32."
        scales = qbits.acquire_packed_weight_info(self.weight, 9)
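Likewise, the three reflowed `"".join(chr(...))` lines above only change layout. The underlying idea is that qbits.acquire_packed_weight_info returns string metadata (weight dtype, compute dtype, scales dtype) encoded as a tensor of ASCII codes, which is decoded back into a Python string. A tiny illustration with a made-up tensor, not an actual qbits return value:

import torch

# Stand-in for what e.g. qbits.acquire_packed_weight_info(weight, 6) might return:
# the dtype name encoded as ASCII codes.
weight_dtype_ascii = torch.tensor([110, 102, 52])   # 'n', 'f', '4'

weight_dtype = "".join(chr(code) for code in weight_dtype_ascii.tolist())
print(weight_dtype)  # "nf4"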
@@ -356,9 +348,7 @@ def recover_int_weight(g_idx, int_weight):
        revert_wei = torch.zeros(in_features, out_features, dtype=torch.float)

-         qbits.dequantize_packed_weight(
-             self.weight, revert_wei, False, compute_dtype, weight_dtype, scales_dtype
-         )
+         qbits.dequantize_packed_weight(self.weight, revert_wei, False, compute_dtype, weight_dtype, scales_dtype)

        int_weight = self.quant_weight_w_scale(
            revert_wei.t(),
@@ -426,9 +416,7 @@ def __init__(
        except:
            qbits_customop_available = False
        if lora_dropout > 0 and qbits_customop_available:
-             self.lora_dropout = torch.nn.ModuleDict(
-                 {adapter_name: DropoutQBits(p=lora_dropout)}
-             )
+             self.lora_dropout = torch.nn.ModuleDict({adapter_name: DropoutQBits(p=lora_dropout)})

    def merge(self, safe_merge: bool = False) -> None:
        """Merge the active adapter weights into the base weights.
@@ -440,10 +428,8 @@ def merge(self, safe_merge: bool = False) -> None:
                NaNs. Defaults to `False`.
        """
        if self.merged:
-             print(
-                 f"Already following adapters were merged {','.join(self.merged_adapters)}. "
-                 f"You are now additionally merging {','.join(self.active_adapters)}."
-             )
+             print(f"Already following adapters were merged {','.join(self.merged_adapters)}. "
+                   f"You are now additionally merging {','.join(self.active_adapters)}.")
        w_dequant = torch.zeros(
            self.out_features,
            self.in_features,
@@ -468,8 +454,7 @@ def merge(self, safe_merge: bool = False) -> None:
                if not torch.isfinite(orig_weights).all():
                    raise ValueError(
-                         f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
-                     )
+                         f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken")

                w_data = orig_weights
            else:
@@ -541,13 +526,10 @@ def unmerge(self) -> None:
            )

    def get_delta_weight(self, adapter) -> torch.Tensor:
-         return (
-             transpose(
-                 self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
-                 False,
-             )
-             * self.scaling[adapter]
-         )
+         return (transpose(
+             self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
+             False,
+         ) * self.scaling[adapter])
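get_delta_weight is also only reformatted here; it computes the usual LoRA weight update, delta_W = (B @ A) * scaling, where peft's transpose(w, False) simply returns w unchanged. A small sketch with made-up shapes (rank r = 2):

import torch

# Illustrative shapes only: out_features=6, in_features=4, LoRA rank r=2.
lora_A = torch.randn(2, 4)   # maps in_features -> r
lora_B = torch.randn(6, 2)   # maps r -> out_features
scaling = 0.5                # typically lora_alpha / r

# Same math as get_delta_weight when fan_in_fan_out is False.
delta_w = (lora_B @ lora_A) * scaling
print(delta_w.shape)         # torch.Size([6, 4]) -- matches the base weight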

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.disable_adapters:
@@ -602,24 +584,18 @@ def _create_new_module(self, lora_config, adapter_name, target, **kwargs):
            bias = kwargs.pop("bias", False)
            in_features, out_features = target.in_features, target.out_features
            if kwargs["fan_in_fan_out"]:
-                 print(
-                     "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
-                     "Setting fan_in_fan_out to False."
-                 )
+                 print("fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
+                       "Setting fan_in_fan_out to False.")
                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
            kwargs["compute_dtype"] = target.compute_dtype
            kwargs["compress_statistics"] = target.compress_statistics
            kwargs["weight_dtype"] = target.weight_dtype
            kwargs["scale_dtype"] = target.scale_dtype
            kwargs["blocksize"] = target.blocksize
            kwargs["scheme"] = target.scheme
-             new_module = QuantizedLoraLinearQBits(
-                 adapter_name, in_features, out_features, bias=bias, **kwargs
-             )
+             new_module = QuantizedLoraLinearQBits(adapter_name, in_features, out_features, bias=bias, **kwargs)
        else:
-             new_module = QBitsLoraModel._create_new_module_(
-                 lora_config, adapter_name, target, **kwargs
-             )
+             new_module = QBitsLoraModel._create_new_module_(lora_config, adapter_name, target, **kwargs)
        return new_module