
Commit 1851035

pre-commit-ci[bot] authored and char-1ee committed

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

1 parent 59ba43b · commit 1851035
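Aside from whitespace, quoting, and line-wrapping changes, one of the few expression-level rewrites in the diff below is in _get_interleave_power_of_2, where the hook adds explicit parentheses around the negated exponent. A minimal sketch (not part of the commit) that checks the old and new spellings are equivalent, since Python's unary minus binds more loosely than **:

import math

# Hypothetical check, not from the repository: the rewrite of
#     start = (2 ** (-2 ** -(math.log2(n) - 3)))
# to
#     start = 2 ** (-(2 ** -(math.log2(n) - 3)))
# only adds parentheses that operator precedence already implied.
for n in (4, 8, 16, 32):
    old = 2 ** (-2 ** -(math.log2(n) - 3))
    new = 2 ** (-(2 ** -(math.log2(n) - 3)))
    assert old == new, (n, old, new)
print("ALiBi slope base unchanged for all tested head counts")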

File tree

3 files changed: +108 / -122 lines changed


Diff for: colossalai/inference/modeling/models/baichuan_13b.py (+98 / -105)
@@ -8,9 +8,9 @@
 from torch.nn import CrossEntropyLoss
 from transformers import PreTrainedModel
 from transformers.activations import ACT2FN
+from transformers.generation.utils import GenerationConfig
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.utils import logging
-from transformers.generation.utils import GenerationConfig
 
 from .configuration_baichuan import BaichuanConfig
 
@@ -19,42 +19,42 @@
 
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
-        start = (2 ** (-2 ** -(math.log2(n) - 3)))
+        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
         ratio = start
-        return [start * ratio ** i for i in range(n)]
+        return [start * ratio**i for i in range(n)]
 
     if math.log2(n).is_integer():
         return _get_interleave_power_of_2(n)
     else:
         closest_power_of_2 = 2 ** math.floor(math.log2(n))
-        return _get_interleave_power_of_2(closest_power_of_2) + \
-            _get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
+        return (
+            _get_interleave_power_of_2(closest_power_of_2)
+            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+        )
+
 
 def _fill_with_neg_inf(t):
     """FP16-compatible function that fills a tensor with -inf."""
     return t.float().fill_(float("-inf")).type_as(t)
 
+
 def _gen_alibi_mask(n_head, max_pos):
     """used in inference only"""
     slopes = torch.Tensor(_get_interleave(n_head))
-    alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(
-        n_head, -1, -1)
+    alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
     alibi = alibi.view(n_head, 1, max_pos)
-    alibi_mask = torch.triu(
-        _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
-    )
+    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
     alibi_mask = alibi_mask.unsqueeze(0) + alibi
     return alibi_mask
 
+
 def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
     """used in training only"""
-    dim = tensor.size(1)
-    _future_mask = torch.triu(
-        _fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1
-    )
+    tensor.size(1)
+    _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
     _future_mask = _future_mask.unsqueeze(0) + alibi
     _future_mask = _future_mask.to(tensor)
-    return _future_mask[:tensor.shape[0] * attn_heads, :maxpos, :maxpos]
+    return _future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]
 
 
 class RMSNorm(torch.nn.Module):
@@ -76,10 +76,10 @@ def forward(self, hidden_states):
 
 class MLP(torch.nn.Module):
     def __init__(
-            self,
-            hidden_size: int,
-            intermediate_size: int,
-            hidden_act: str,
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
     ):
         super().__init__()
         self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
@@ -101,24 +101,21 @@ def __init__(self, config: BaichuanConfig):
         self.max_position_embeddings = config.model_max_length
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
-            )
+            raise ValueError(f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}")
         self.W_pack = torch.nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
         self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
     def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attention_mask: Optional[torch.Tensor] = None,
-            past_key_value: Optional[Tuple[torch.Tensor]] = None,
-            output_attentions: bool = False,
-            use_cache: bool = False,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
         bsz, q_len, _ = hidden_states.size()
 
         proj = self.W_pack(hidden_states)
@@ -141,11 +138,11 @@ def forward(
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
         if attention_mask is not None:
-            if q_len == 1: # inference with cache
+            if q_len == 1:  # inference with cache
                 if len(attention_mask.size()) == 4:
-                    attention_mask = attention_mask[:, :, -1:, :]
+                    attention_mask = attention_mask[:, :, -1:, :]
                 else:
-                    attention_mask = attention_mask[:, -1:, :]
+                    attention_mask = attention_mask[:, -1:, :]
             attn_weights = attn_weights + attention_mask
             attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 
@@ -177,14 +174,13 @@ def __init__(self, config: BaichuanConfig):
         self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
 
     def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attention_mask: Optional[torch.Tensor] = None,
-            past_key_value: Optional[Tuple[torch.Tensor]] = None,
-            output_attentions: Optional[bool] = False,
-            use_cache: Optional[bool] = False,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
@@ -261,33 +257,36 @@ def set_input_embeddings(self, value):
     def get_alibi_mask(self, tensor, seq_length_with_past):
         if self.training:
             slopes = torch.Tensor(_get_interleave(self.n_head))
-            alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length_with_past).unsqueeze(0).unsqueeze(0).expand(
-                self.n_head,
-                -1, -1)
+            alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length_with_past).unsqueeze(0).unsqueeze(
+                0
+            ).expand(self.n_head, -1, -1)
             alibi = alibi.view(self.n_head, 1, seq_length_with_past)
             mask = _buffered_future_mask(tensor, seq_length_with_past, alibi, self.n_head)
         else:
             if self.first_run:
                 self.first_run = False
-                self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
+                self.register_buffer(
+                    "future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False
+                )
             if seq_length_with_past > self.max_cache_pos:
                 self.max_cache_pos = seq_length_with_past
-                self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
-            mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
+                self.register_buffer(
+                    "future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False
+                )
+            mask = self.future_mask[: self.n_head, :seq_length_with_past, :seq_length_with_past]
         return mask
 
     def forward(
-            self,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            past_key_values: Optional[List[torch.FloatTensor]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            use_cache: Optional[bool] = False,
-            output_attentions: Optional[bool] = False,
-            output_hidden_states: Optional[bool] = False,
-            return_dict: Optional[bool] = True,
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
-
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously")
         elif input_ids is not None:
@@ -318,10 +317,11 @@ def forward(
         if attention_mask is not None:
             if len(attention_mask.shape) == 2:
                 expanded_mask = attention_mask.to(alibi_mask.dtype)
-                expanded_mask = torch.tril(torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
-                ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
+                expanded_mask = torch.tril(
+                    torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
+                ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
             else:
-                expanded_mask = attention_mask
+                expanded_mask = attention_mask
             bsz = inputs_embeds.size(0)
             src_len, tgt_len = alibi_mask.size()[-2:]
             expanded_mask = expanded_mask.unsqueeze(1).expand(bsz, 1, src_len, tgt_len).to(alibi_mask.dtype)
@@ -428,21 +428,20 @@ def get_decoder(self):
         return self.model
 
     def forward(
-            self,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            past_key_values: Optional[List[torch.FloatTensor]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            labels: Optional[torch.LongTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = False,
-            output_hidden_states: Optional[bool] = False,
-            return_dict: Optional[bool] = True,
-            **kwargs
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
@@ -484,12 +483,12 @@ def forward(
         )
 
     def prepare_inputs_for_generation(
-            self,
-            input_ids: torch.LongTensor,
-            past_key_values: Optional[torch.Tensor] = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            inputs_embeds: Optional[torch.Tensor] = None,
-            **kwargs
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
     ):
         if past_key_values:
             input_ids = input_ids[:, -1:]
@@ -501,65 +500,58 @@ def prepare_inputs_for_generation(
         model_inputs = {"input_ids": input_ids}
 
         model_inputs.update(
-            {
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask
-            }
+            {"past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask}
         )
         return model_inputs
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
         return tuple(
-            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
-            for layer_past in past_key_values
+            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past) for layer_past in past_key_values
         )
 
     def quantize(self, bits: int):
         try:
             from .quantizer import QLinear
         except ImportError:
-            raise ImportError(
-                f"Needs QLinear to run quantize."
-            )
+            raise ImportError(f"Needs QLinear to run quantize.")
 
         for layer in self.model.layers:
             layer.self_attn.W_pack = QLinear(
                 bits=bits,
                 weight=layer.self_attn.W_pack.weight,
-                bias = None,
+                bias=None,
             )
             layer.self_attn.o_proj = QLinear(
                 bits=bits,
                 weight=layer.self_attn.o_proj.weight,
-                bias = None,
+                bias=None,
             )
             layer.mlp.gate_proj = QLinear(
                 bits=bits,
                 weight=layer.mlp.gate_proj.weight,
-                bias = None,
+                bias=None,
             )
             layer.mlp.down_proj = QLinear(
                 bits=bits,
                 weight=layer.mlp.down_proj.weight,
-                bias = None,
+                bias=None,
            )
             layer.mlp.up_proj = QLinear(
                 bits=bits,
                 weight=layer.mlp.up_proj.weight,
-                bias = None,
+                bias=None,
             )
         return self
 
-    def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
+    def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int = 0):
         max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
         max_input_tokens = self.config.model_max_length - max_new_tokens
         max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)
         total_input, round_input = [], []
         for i, message in enumerate(messages[::-1]):
-            content_tokens = tokenizer.encode(message['content'])
-            if message['role'] == 'user':
+            content_tokens = tokenizer.encode(message["content"])
+            if message["role"] == "user":
                 round_input = [self.generation_config.user_token_id] + content_tokens + round_input
                 if total_input and len(total_input) + len(round_input) > max_input_tokens:
                     break
@@ -569,12 +561,13 @@ def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int
                         break
                     else:
                         round_input = []
-            elif message['role'] == 'assistant':
-                round_input = [
-                    self.generation_config.assistant_token_id
-                ] + content_tokens + [
-                    self.generation_config.eos_token_id
-                ] + round_input
+            elif message["role"] == "assistant":
+                round_input = (
+                    [self.generation_config.assistant_token_id]
+                    + content_tokens
+                    + [self.generation_config.eos_token_id]
+                    + round_input
+                )
             else:
                 raise ValueError(f"message role not supported yet: {message['role']}")
         total_input = total_input[-max_input_tokens:]  # truncate left
@@ -583,12 +576,12 @@ def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int
         return total_input
 
     @torch.no_grad()
-    def chat(self, tokenizer, messages: List[dict], stream=False,
-             generation_config: Optional[GenerationConfig]=None):
+    def chat(self, tokenizer, messages: List[dict], stream=False, generation_config: Optional[GenerationConfig] = None):
         generation_config = generation_config or self.generation_config
         input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens)
         if stream:
             from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
+
             self.__class__.generate = NewGenerationMixin.generate
             self.__class__.sample_stream = NewGenerationMixin.sample_stream
             stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
@@ -603,5 +596,5 @@ def stream_generator():
         else:
             self.__class__.generate = PreTrainedModel.generate  # disable stream
             outputs = self.generate(input_ids, generation_config=generation_config)
-            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
-            return response
+            response = tokenizer.decode(outputs[0][len(input_ids[0]) :], skip_special_tokens=True)
+            return response
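
For reference, a hedged sketch (not part of this commit) of re-running the repository's pre-commit hooks locally on the touched file. It assumes pre-commit is installed and the command is issued from the repository root, where the .pre-commit-config.yaml lives:

import subprocess

# Re-run the configured hooks on the single file this commit touched.
# `pre-commit run --files <path>` is the standard CLI invocation; hooks exit
# non-zero when they modify files, so check=False avoids raising on a reformat.
subprocess.run(
    ["pre-commit", "run", "--files", "colossalai/inference/modeling/models/baichuan_13b.py"],
    check=False,
)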
