 from torch.nn import CrossEntropyLoss
 from transformers import PreTrainedModel
 from transformers.activations import ACT2FN
+from transformers.generation.utils import GenerationConfig
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.utils import logging
-from transformers.generation.utils import GenerationConfig

 from .configuration_baichuan import BaichuanConfig

@@ -19,42 +19,42 @@

 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
-        start = (2 ** (-2 ** -(math.log2(n) - 3)))
+        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
         ratio = start
-        return [start * ratio ** i for i in range(n)]
+        return [start * ratio**i for i in range(n)]

     if math.log2(n).is_integer():
         return _get_interleave_power_of_2(n)
     else:
         closest_power_of_2 = 2 ** math.floor(math.log2(n))
-        return _get_interleave_power_of_2(closest_power_of_2) + \
-            _get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
+        return (
+            _get_interleave_power_of_2(closest_power_of_2)
+            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+        )
+

 def _fill_with_neg_inf(t):
     """FP16-compatible function that fills a tensor with -inf."""
     return t.float().fill_(float("-inf")).type_as(t)

+
 def _gen_alibi_mask(n_head, max_pos):
     """used in inference only"""
     slopes = torch.Tensor(_get_interleave(n_head))
-    alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(
-        n_head, -1, -1)
+    alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
     alibi = alibi.view(n_head, 1, max_pos)
-    alibi_mask = torch.triu(
-        _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
-    )
+    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
     alibi_mask = alibi_mask.unsqueeze(0) + alibi
     return alibi_mask

+
 def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
     """used in training only"""
-    dim = tensor.size(1)
-    _future_mask = torch.triu(
-        _fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1
-    )
+    tensor.size(1)
+    _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
     _future_mask = _future_mask.unsqueeze(0) + alibi
     _future_mask = _future_mask.to(tensor)
-    return _future_mask[:tensor.shape[0] * attn_heads, :maxpos, :maxpos]
+    return _future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]


 class RMSNorm(torch.nn.Module):
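Editorial aside, not part of the diff: the rewrite of start in _get_interleave_power_of_2 above is purely cosmetic, because ** binds tighter than unary minus, so -2 ** -x already parses as -(2 ** -x). A minimal check of both spellings and of the resulting ALiBi slope sequence:

```python
import math

n = 8  # head count, a power of two
old_start = (2 ** (-2 ** -(math.log2(n) - 3)))  # original spelling
new_start = 2 ** (-(2 ** -(math.log2(n) - 3)))  # reformatted spelling
assert old_start == new_start == 0.5            # ** binds tighter than unary minus

# _get_interleave_power_of_2 then returns a geometric sequence start * start**i
slopes = [new_start * new_start**i for i in range(n)]
assert slopes[0] == 0.5 and slopes[-1] == 0.5**8
```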
@@ -76,10 +76,10 @@ def forward(self, hidden_states):

 class MLP(torch.nn.Module):
     def __init__(
-            self,
-            hidden_size: int,
-            intermediate_size: int,
-            hidden_act: str,
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
     ):
         super().__init__()
         self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
@@ -101,24 +101,21 @@ def __init__(self, config: BaichuanConfig):
         self.max_position_embeddings = config.model_max_length

         if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
-            )
+            raise ValueError(f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}")
         self.W_pack = torch.nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
         self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

     def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attention_mask: Optional[torch.Tensor] = None,
-            past_key_value: Optional[Tuple[torch.Tensor]] = None,
-            output_attentions: bool = False,
-            use_cache: bool = False,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
         bsz, q_len, _ = hidden_states.size()

         proj = self.W_pack(hidden_states)
@@ -141,11 +138,11 @@ def forward(
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

         if attention_mask is not None:
-            if q_len == 1: # inference with cache
+            if q_len == 1:  # inference with cache
                 if len(attention_mask.size()) == 4:
-                    attention_mask = attention_mask[:, :, -1:, :]
+                    attention_mask = attention_mask[:, :, -1:, :]
                 else:
-                    attention_mask = attention_mask[:, -1:, :]
+                    attention_mask = attention_mask[:, -1:, :]
             attn_weights = attn_weights + attention_mask
         attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

@@ -177,14 +174,13 @@ def __init__(self, config: BaichuanConfig):
         self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)

     def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attention_mask: Optional[torch.Tensor] = None,
-            past_key_value: Optional[Tuple[torch.Tensor]] = None,
-            output_attentions: Optional[bool] = False,
-            use_cache: Optional[bool] = False,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-
         residual = hidden_states

         hidden_states = self.input_layernorm(hidden_states)
@@ -261,33 +257,36 @@ def set_input_embeddings(self, value):
     def get_alibi_mask(self, tensor, seq_length_with_past):
         if self.training:
             slopes = torch.Tensor(_get_interleave(self.n_head))
-            alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length_with_past).unsqueeze(0).unsqueeze(0).expand(
-                self.n_head,
-                -1, -1)
+            alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length_with_past).unsqueeze(0).unsqueeze(
+                0
+            ).expand(self.n_head, -1, -1)
             alibi = alibi.view(self.n_head, 1, seq_length_with_past)
             mask = _buffered_future_mask(tensor, seq_length_with_past, alibi, self.n_head)
         else:
             if self.first_run:
                 self.first_run = False
-                self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
+                self.register_buffer(
+                    "future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False
+                )
             if seq_length_with_past > self.max_cache_pos:
                 self.max_cache_pos = seq_length_with_past
-                self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
-            mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
+                self.register_buffer(
+                    "future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False
+                )
+            mask = self.future_mask[: self.n_head, :seq_length_with_past, :seq_length_with_past]
         return mask

     def forward(
-            self,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            past_key_values: Optional[List[torch.FloatTensor]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            use_cache: Optional[bool] = False,
-            output_attentions: Optional[bool] = False,
-            output_hidden_states: Optional[bool] = False,
-            return_dict: Optional[bool] = True,
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously")
         elif input_ids is not None:
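For orientation (an aside, not part of the diff): the helpers touched above build one additive attention bias per head, of shape (n_head, seq_len, seq_len), with -inf above the diagonal and linearly growing, head-specific offsets elsewhere. A small shape check, assuming _get_interleave and _gen_alibi_mask from this file are in scope:

```python
import torch

# Assumes _get_interleave and _gen_alibi_mask as defined earlier in this file.
n_head, max_pos = 4, 6
mask = _gen_alibi_mask(n_head, max_pos)

assert mask.shape == (n_head, max_pos, max_pos)
# Future positions are masked with -inf for every head...
assert torch.isinf(mask[0, 0, 1:]).all()
# ...while the last query row carries slope * position for each head.
slopes = torch.Tensor(_get_interleave(n_head)).unsqueeze(1)
assert torch.allclose(mask[:, -1, :], slopes * torch.arange(max_pos))
```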
@@ -318,10 +317,11 @@ def forward(
         if attention_mask is not None:
             if len(attention_mask.shape) == 2:
                 expanded_mask = attention_mask.to(alibi_mask.dtype)
-                expanded_mask = torch.tril(torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
-                ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
+                expanded_mask = torch.tril(
+                    torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
+                ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
             else:
-                expanded_mask = attention_mask
+                expanded_mask = attention_mask
             bsz = inputs_embeds.size(0)
             src_len, tgt_len = alibi_mask.size()[-2:]
             expanded_mask = expanded_mask.unsqueeze(1).expand(bsz, 1, src_len, tgt_len).to(alibi_mask.dtype)
@@ -428,21 +428,20 @@ def get_decoder(self):
         return self.model

     def forward(
-            self,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            past_key_values: Optional[List[torch.FloatTensor]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            labels: Optional[torch.LongTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = False,
-            output_hidden_states: Optional[bool] = False,
-            return_dict: Optional[bool] = True,
-            **kwargs
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
@@ -484,12 +483,12 @@ def forward(
         )

     def prepare_inputs_for_generation(
-            self,
-            input_ids: torch.LongTensor,
-            past_key_values: Optional[torch.Tensor] = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            inputs_embeds: Optional[torch.Tensor] = None,
-            **kwargs
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
     ):
         if past_key_values:
             input_ids = input_ids[:, -1:]
@@ -501,65 +500,58 @@ def prepare_inputs_for_generation(
             model_inputs = {"input_ids": input_ids}

         model_inputs.update(
-            {
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask
-            }
+            {"past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask}
         )
         return model_inputs

     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
         return tuple(
-            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
-            for layer_past in past_key_values
+            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past) for layer_past in past_key_values
         )

     def quantize(self, bits: int):
         try:
             from .quantizer import QLinear
         except ImportError:
-            raise ImportError(
-                f"Needs QLinear to run quantize."
-            )
+            raise ImportError(f"Needs QLinear to run quantize.")

         for layer in self.model.layers:
             layer.self_attn.W_pack = QLinear(
                 bits=bits,
                 weight=layer.self_attn.W_pack.weight,
-                bias = None,
+                bias=None,
             )
             layer.self_attn.o_proj = QLinear(
                 bits=bits,
                 weight=layer.self_attn.o_proj.weight,
-                bias = None,
+                bias=None,
             )
             layer.mlp.gate_proj = QLinear(
                 bits=bits,
                 weight=layer.mlp.gate_proj.weight,
-                bias = None,
+                bias=None,
             )
             layer.mlp.down_proj = QLinear(
                 bits=bits,
                 weight=layer.mlp.down_proj.weight,
-                bias = None,
+                bias=None,
             )
             layer.mlp.up_proj = QLinear(
                 bits=bits,
                 weight=layer.mlp.up_proj.weight,
-                bias = None,
+                bias=None,
             )
         return self

-    def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
+    def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int = 0):
         max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
         max_input_tokens = self.config.model_max_length - max_new_tokens
         max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)
         total_input, round_input = [], []
         for i, message in enumerate(messages[::-1]):
-            content_tokens = tokenizer.encode(message['content'])
-            if message['role'] == 'user':
+            content_tokens = tokenizer.encode(message["content"])
+            if message["role"] == "user":
                 round_input = [self.generation_config.user_token_id] + content_tokens + round_input
                 if total_input and len(total_input) + len(round_input) > max_input_tokens:
                     break
@@ -569,12 +561,13 @@ def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int
                         break
                 else:
                     round_input = []
-            elif message['role'] == 'assistant':
-                round_input = [
-                    self.generation_config.assistant_token_id
-                ] + content_tokens + [
-                    self.generation_config.eos_token_id
-                ] + round_input
+            elif message["role"] == "assistant":
+                round_input = (
+                    [self.generation_config.assistant_token_id]
+                    + content_tokens
+                    + [self.generation_config.eos_token_id]
+                    + round_input
+                )
             else:
                 raise ValueError(f"message role not supported yet: {message['role']}")
         total_input = total_input[-max_input_tokens:]  # truncate left
@@ -583,12 +576,12 @@ def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int
         return total_input

     @torch.no_grad()
-    def chat(self, tokenizer, messages: List[dict], stream=False,
-             generation_config: Optional[GenerationConfig]=None):
+    def chat(self, tokenizer, messages: List[dict], stream=False, generation_config: Optional[GenerationConfig] = None):
         generation_config = generation_config or self.generation_config
         input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens)
         if stream:
             from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
+
             self.__class__.generate = NewGenerationMixin.generate
             self.__class__.sample_stream = NewGenerationMixin.sample_stream
             stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
@@ -603,5 +596,5 @@ def stream_generator():
         else:
             self.__class__.generate = PreTrainedModel.generate  # disable stream
             outputs = self.generate(input_ids, generation_config=generation_config)
-            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
-            return response
+            response = tokenizer.decode(outputs[0][len(input_ids[0]) :], skip_special_tokens=True)
+            return response
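Closing note (illustrative, not from this diff): the chat entry point whose signature is reformatted above expects a list of {"role": ..., "content": ...} dicts with "user"/"assistant" roles, as handled in _build_chat_input. A hedged usage sketch with a placeholder checkpoint path:

```python
# Placeholder checkpoint path; substitute a real Baichuan checkpoint directory.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

path = "path/to/baichuan-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, trust_remote_code=True)
model.generation_config = GenerationConfig.from_pretrained(path)

messages = [{"role": "user", "content": "Hello!"}]
response = model.chat(tokenizer, messages)  # non-streaming path returns a decoded string
print(response)
```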