-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathcache.py
More file actions
202 lines (161 loc) · 7.2 KB
/
cache.py
File metadata and controls
202 lines (161 loc) · 7.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
from transformers import AutoConfig
import torch
import flashinfer
from flash_attn import flash_attn_with_kvcache
import math
class KV_Cache:
    """Dense, batched KV cache that grows by appending along the sequence axis.

    ``k_cache`` / ``v_cache`` are allocated with shape
    ``[num_hidden_layers, batch_size, max_length, num_key_value_heads, head_dim]``.
    ``kv_offset`` counts the filled sequence positions. It is advanced only when
    layer 0 is written, so every layer of one forward step is expected to write
    the same number of new positions.
    """

    def __init__(self,
                 config: "AutoConfig",
                 batch_size: int = 1,
                 max_length: int = 256,
                 device: str = 'cuda:0',
                 dtype=torch.float16) -> None:
        """Allocate zero-filled K and V buffers for every layer.

        Args:
            config: model config; ``num_hidden_layers``, ``num_key_value_heads``,
                ``num_attention_heads`` and ``hidden_size`` are read from it.
            batch_size: size of the batch axis of the buffers.
            max_length: number of sequence positions the cache can hold.
            device: device on which the buffers are allocated.
            dtype: element dtype of the buffers.
        """
        self.config = config
        self.max_length = max_length
        self.device = device
        self.dtype = dtype

        self.num_layers = config.num_hidden_layers
        self.num_key_value_heads = config.num_key_value_heads
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        # Number of query heads sharing each KV head (grouped-query attention).
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads

        cache_shape = (
            self.num_layers,
            batch_size,
            max_length,
            self.num_key_value_heads,
            self.head_dim,
        )
        self.k_cache = torch.zeros(cache_shape, device=self.device, dtype=self.dtype)
        self.v_cache = torch.zeros(cache_shape, device=self.device, dtype=self.dtype)
        self.kv_offset = 0

    def gather_kv_incremental(self, indices: torch.LongTensor, offset: int):
        """Compact the positions ``indices`` into ``[offset, offset + len(indices))``.

        The copy is applied to every layer and batch row. Everything after the
        compacted range is zeroed, and ``kv_offset`` is set to the range's end.
        """
        end = offset + len(indices)
        # Indexing with `indices` produces a copy before the assignment,
        # so overlapping source and destination ranges are safe.
        self.k_cache[:, :, offset:end] = self.k_cache[:, :, indices]
        self.v_cache[:, :, offset:end] = self.v_cache[:, :, indices]
        self.k_cache[:, :, end:] = 0.0
        self.v_cache[:, :, end:] = 0.0
        self.kv_offset = end

    def update_kv_cache(self,
                        new_k_cache: torch.Tensor,
                        new_v_cache: torch.Tensor,
                        layer_idx: int,
                        storage_ids: torch.LongTensor = None
                        ):
        """Append new keys/values for ``layer_idx`` and return the filled prefix.

        Args:
            new_k_cache, new_v_cache: new entries laid out as
                ``[bsz, seq, num_heads, head_dim]``; ``seq`` positions are appended.
            layer_idx: layer to write. Writing layer 0 advances ``kv_offset``.
            storage_ids: unused here. It is accepted so the signature matches
                ``StaticKV_Cache.update_kv_cache``.

        Returns:
            ``(k, v)``: this layer's cache sliced to the first ``kv_offset`` positions.

        Raises:
            RuntimeError: if the write range falls outside ``[0, max_length)``.
                ``kv_offset`` is left unchanged in that case.
        """
        new_kv_len = new_k_cache.shape[1]
        end = self.kv_offset + new_kv_len if layer_idx == 0 else self.kv_offset
        start = end - new_kv_len
        # Validate before mutating anything. Without this check, a 1-token write
        # past the end broadcasts into an empty slice and is silently dropped.
        # A longer overflowing write fails only after kv_offset has been advanced.
        if start < 0 or end > self.max_length:
            raise RuntimeError(
                f"KV cache write [{start}, {end}) is out of range "
                f"for max_length={self.max_length}"
            )
        if layer_idx == 0:
            self.kv_offset = end
        self.k_cache[layer_idx][:, start:end] = new_k_cache
        self.v_cache[layer_idx][:, start:end] = new_v_cache
        return self.k_cache[layer_idx][:, :end], self.v_cache[layer_idx][:, :end]

    def compute_attention(self,
                          query_states: torch.Tensor,
                          key_states: torch.Tensor,
                          value_states: torch.Tensor,
                          layer_idx,
                          storage_ids: torch.Tensor = None,
                          attention_mask: torch.Tensor = None):
        """Store this step's K/V in the cache, then attend over the filled prefix.

        Without ``attention_mask``, flash-attn's ``flash_attn_with_kvcache`` runs
        causally on the full batch.

        With a mask, flashinfer's ``single_prefill_with_kv_cache`` runs on batch
        row 0 only. q/k/v are passed without their batch axis, and the mask is
        truncated to ``kv_offset`` columns.
        """
        key_states, value_states = self.update_kv_cache(
            key_states, value_states, layer_idx, storage_ids)
        if attention_mask is None:
            return flash_attn_with_kvcache(
                q=query_states, k_cache=key_states, v_cache=value_states, causal=True)
        # NOTE(review): only batch row 0 is attended on this path, which presumes
        # bsz == 1. Confirm with callers.
        return flashinfer.single_prefill_with_kv_cache(
            q=query_states[0],
            k=key_states[0],
            v=value_states[0],
            kv_layout="NHD",
            custom_mask=attention_mask[:, :self.kv_offset],
            allow_fp16_qk_reduction=True,
        )

    def clear(self):
        """Zero both buffers and reset the fill counter."""
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.kv_offset = 0

    def set_kv_len(self, kv_len: int):
        """Overwrite the fill counter ``kv_offset`` without touching the buffers."""
        self.kv_offset = kv_len

    def get_kv_len(self):
        """Return the number of filled sequence positions (``kv_offset``)."""
        return self.kv_offset
class StaticKV_Cache:
    """Fixed-size KV cache written at explicit slots, with an eager attention path.

    Buffers have shape ``[num_hidden_layers, num_key_value_heads, max_length, head_dim]``
    and have no batch axis. ``update_kv_cache`` scatters new entries into the
    slots ``storage_ids``. ``compute_attention`` always attends over all
    ``max_length`` slots and relies on ``attention_mask`` to hide the rest.

    ``update_kv_cache`` does not advance ``kv_offset``. It changes only through
    ``gather_kv_incremental``, ``set_kv_len`` and ``clear``.
    """

    def __init__(self,
                 config: "AutoConfig",
                 batch_size: int = 1,
                 max_length: int = 256,
                 device: str = 'cuda:0',
                 dtype=torch.float16) -> None:
        """Allocate zero-filled K and V buffers for every layer.

        Args:
            config: model config; ``num_hidden_layers``, ``num_key_value_heads``,
                ``num_attention_heads`` and ``hidden_size`` are read from it.
            batch_size: accepted for signature parity but unused, because the
                buffers have no batch axis.
            max_length: number of cache slots per layer and head.
            device: device on which the buffers are allocated.
            dtype: element dtype of the buffers.
        """
        self.config = config
        self.max_length = max_length
        self.device = device
        self.dtype = dtype

        self.num_layers = config.num_hidden_layers
        self.num_key_value_heads = config.num_key_value_heads
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        # Number of query heads sharing each KV head (grouped-query attention).
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads

        cache_shape = (
            self.num_layers,
            self.num_key_value_heads,
            max_length,
            self.head_dim,
        )
        self.k_cache = torch.zeros(cache_shape, device=self.device, dtype=self.dtype)
        self.v_cache = torch.zeros(cache_shape, device=self.device, dtype=self.dtype)
        self.kv_offset = 0

    def gather_kv_incremental(self, indices: list[int], offset: int):
        """Compact slots ``indices`` into ``[offset, offset + len(indices))``.

        The copy is applied to every layer and head. All slots after that range
        are zeroed, and ``kv_offset`` is set to the range's end.
        """
        end = offset + len(indices)
        # Indexing with a list produces a copy before the assignment,
        # so overlapping source and destination ranges are safe.
        self.k_cache[..., offset:end, :] = self.k_cache[..., indices, :]
        self.v_cache[..., offset:end, :] = self.v_cache[..., indices, :]
        self.k_cache[..., end:, :] = 0.0
        self.v_cache[..., end:, :] = 0.0
        self.kv_offset = end

    def update_kv_cache(self,
                        new_k_cache: torch.Tensor,
                        new_v_cache: torch.Tensor,
                        layer_idx: int,
                        storage_ids: torch.LongTensor
                        ):
        """Scatter new entries into slots ``storage_ids`` of layer ``layer_idx``.

        Entries are copied along the slot axis (dim -2) with ``index_copy_``.

        Returns:
            ``(k, v)``: the full ``[num_kv_heads, max_length, head_dim]`` buffers
            of this layer.
        """
        self.k_cache[layer_idx].index_copy_(dim=-2, index=storage_ids, source=new_k_cache)
        self.v_cache[layer_idx].index_copy_(dim=-2, index=storage_ids, source=new_v_cache)
        return self.k_cache[layer_idx], self.v_cache[layer_idx]

    def clear(self):
        """Zero both buffers and reset ``kv_offset``."""
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.kv_offset = 0

    def set_kv_len(self, kv_len: int):
        """Overwrite ``kv_offset`` without touching the buffers."""
        self.kv_offset = kv_len

    def get_kv_len(self):
        """Return ``kv_offset`` (mirrors ``KV_Cache.get_kv_len``)."""
        return self.kv_offset

    def compute_attention(self,
                          query_states: torch.Tensor,
                          key_states: torch.Tensor,
                          value_states: torch.Tensor,
                          layer_idx,
                          storage_ids: torch.Tensor,
                          attention_mask: torch.Tensor):
        """Store this step's K/V, then run masked softmax attention over every slot.

        Args:
            query_states: ``[bsz, num_attention_heads, q_len, head_dim]``; bsz must be 1.
            key_states, value_states: batched new entries. Row 0, presumably
                ``[num_kv_heads, q_len, head_dim]``, is written to ``storage_ids``.
            layer_idx: layer whose cache is written and read.
            storage_ids: cache slots receiving the new entries.
            attention_mask: boolean ``[q_len, max_length]``; True marks slots
                a query may attend to.

        Returns:
            ``[bsz, q_len, num_attention_heads, head_dim]``.

        Raises:
            ValueError: if ``bsz != 1``. Only batch row 0 is processed, and a
                larger bsz would otherwise yield a silently mis-shaped result.
        """
        bsz, _, q_len, _ = query_states.shape
        if bsz != 1:
            raise ValueError(f"StaticKV_Cache.compute_attention supports bsz == 1, got {bsz}")
        key_states, value_states = self.update_kv_cache(
            key_states[0], value_states[0], layer_idx, storage_ids)

        # Fold query heads so that each KV head sees its group's queries stacked
        # group-major along the row axis: [num_kv_heads, groups * q_len, head_dim].
        queries = query_states[0].reshape(
            self.num_key_value_heads, q_len * self.num_key_value_groups, self.head_dim)
        scores = torch.matmul(queries, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)

        # Tile the mask in the same group-major order as the folded query rows.
        mask = attention_mask[None, :, :].repeat(1, self.num_key_value_groups, 1)
        scores.masked_fill_(~mask, torch.finfo(scores.dtype).min)
        # Softmax in float32 for stability, then cast back to the query dtype.
        probs = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(queries.dtype)

        out = torch.matmul(probs, value_states)
        out = out.reshape(bsz, self.num_attention_heads, q_len, -1)
        return out.transpose(1, 2).contiguous()