-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathqwen2_liger.py
More file actions
176 lines (152 loc) · 7.57 KB
/
Copy pathqwen2_liger.py
File metadata and controls
176 lines (152 loc) · 7.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from typing import List, Optional, Tuple, Union
from loguru import logger
from transformers.modeling_outputs import CausalLMOutputWithPast
from ..sequence_packing_utils import BaseModelOutputWithPastAndRmpad
try:
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
except:
print("Liger Kernel is not installed, pip install liger-kernel to use this patch")
import torch
import torch.distributed as dist
from lmms_engine.parallel.sequence_parallel.ulysses import (
calculate_seq_len_per_rank,
gather_outputs_and_unpad,
get_ulysses_sequence_parallel_group,
get_ulysses_sequence_parallel_world_size,
pad_to_max_across_ranks,
slice_input_tensor,
)
def qwen2_lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    use_rmpad: bool = False,
    cu_seq_lens: Optional[torch.IntTensor] = None,
    indices: Optional[torch.IntTensor] = None,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""Causal-LM forward pass that uses Liger's fused linear + cross-entropy loss in training.

    In training mode with labels, the loss is computed directly from the final hidden states and
    `self.lm_head.weight` via `LigerFusedLinearCrossEntropyLoss`, so `logits` is left as `None`.
    Otherwise the logits are materialized with `self.lm_head` and, if labels are given, the loss
    comes from `self.loss_function`.

    Args:
        input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache,
        output_attentions, output_hidden_states, return_dict, cache_position:
            Forwarded unchanged to `self.model`. The three `output_*`/`return_dict` flags fall back
            to the corresponding `self.config` values when `None`.
        labels (`torch.LongTensor`, *optional*):
            Target token ids. Positions equal to `-100` are not counted as valid tokens when the
            sequence-parallel loss is normalized. When the backbone returns `word_idx`, labels are
            flattened and gathered with it so they line up with the un-padded hidden states.
        num_logits_to_keep (`int`, *optional*):
            Accepted for signature compatibility; it is not read by this function.
        use_rmpad (`bool`, *optional*):
            If `True`, hidden states and labels are treated as a flat concatenation of sequences
            and the next-token shift is applied per sequence using the `seq_lens` boundaries
            returned by the backbone.
        cu_seq_lens, indices (`torch.IntTensor`, *optional*):
            Forwarded unchanged to `self.model`.
        **loss_kwargs:
            If `num_items_in_batch` is present (and sequence parallelism is off), the fused loss is
            summed and divided by it. In the inference branch all kwargs go to `self.loss_function`.

    Returns:
        `CausalLMOutputWithPast` with `loss`, `logits` (`None` on the fused training path),
        `past_key_values`, `hidden_states` (the backbone's first output, i.e. a single tensor) and
        `attentions`.

    NOTE(review): the backbone output is accessed with `.get(...)` and attribute lookups, so a
    tuple output (`return_dict=False`) is presumably unsupported - confirm against callers.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        cu_seq_lens=cu_seq_lens,
        indices=indices,
    )
    seq_lens = outputs.get("seq_lens", None)
    word_idx = outputs.get("word_idx", None)
    hidden_states = outputs[0]

    # Query the sequence-parallel state once; it is needed in several places below.
    sp_enabled = get_ulysses_sequence_parallel_world_size() > 1

    # Under sequence parallelism each rank only holds a slice of the tokens, so the sequence
    # boundaries must be recomputed for the local slice.
    if sp_enabled and seq_lens is not None:
        seq_lens = calculate_seq_len_per_rank(seq_lens.tolist())

    # Align labels with the un-padded hidden states. This is skipped when there are no labels
    # (plain inference) or the backbone did not return un-padding indices; previously both cases
    # raised AttributeError before the `labels is not None` checks below were ever reached.
    if labels is not None and word_idx is not None:
        labels = labels.view(-1)[word_idx.long()]
        if sp_enabled:
            # Keep only this rank's slice of the labels, matching the sliced hidden states.
            labels = slice_input_tensor(labels, dim=0, padding=True)

    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if self.training and labels is not None:
        if use_rmpad:
            # We need to shift the tokens according to seq lens
            # Otherwise, the first labels of the next seq will be the last labels of the current seq
            shift_hidden_states = []
            shift_labels = []
            for start, end in zip(seq_lens[:-1], seq_lens[1:]):
                # Drop the last hidden state and the first label of every sequence.
                shift_hidden_states.append(hidden_states[start:end, :][:-1, :].contiguous())
                shift_labels.append(labels[start:end][1:].contiguous())
            shift_hidden_states = torch.cat(shift_hidden_states, dim=0)
            shift_labels = torch.cat(shift_labels, dim=0)
        else:
            # We do the same thing as ForCausalLMLoss but using Liger FLCE
            shift_hidden_states = hidden_states[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        if sp_enabled:
            # If using sp, we follow the loss calculation in verl, get loss for each token,
            # then gather and sum them up
            reduction = "none"
        elif "num_items_in_batch" in loss_kwargs:
            reduction = "sum"
        else:
            reduction = "mean"

        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

        if sp_enabled:
            # Pad to max size across ranks, then gather and unpad
            loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
            loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
            # Count this rank's non-ignored labels (-100 means ignore) ...
            num_valid_tokens = (shift_labels != -100).sum().float()
            # ... and sum the counts over the sequence-parallel group to get the global total.
            sp_group = get_ulysses_sequence_parallel_group()
            if sp_group is not None:
                dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
            # Small epsilon avoids division by zero when every label is ignored.
            # NOTE(review): `num_items_in_batch` is not applied on this path - confirm intended.
            loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
        elif reduction == "sum":
            loss /= loss_kwargs["num_items_in_batch"]
    else:  # if in inference mode materialize logits
        logits = self.lm_head(hidden_states)
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **loss_kwargs,
            )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=hidden_states,
        attentions=outputs.attentions,
    )