# interface.py
import copy
from dataclasses import dataclass, field
from enum import IntEnum, auto
from typing import List, Optional, Type

import torch

from ..._utils import get_sm_version
from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention


class SpeculativeDecodingMode(IntEnum):
    MTP = auto()
    MTP_EAGLE = auto()
    EAGLE3 = auto()
    EAGLE3_ONE_MODEL = auto()
    NGRAM = auto()
    DRAFT_TARGET = auto()
    USER_PROVIDED = auto()
    NONE = auto()
    AUTO = auto()

    def is_mtp(self):
        return (self == SpeculativeDecodingMode.MTP
                or self == SpeculativeDecodingMode.MTP_EAGLE)

    def is_mtp_vanilla(self):
        return self == SpeculativeDecodingMode.MTP

    def is_mtp_eagle(self):
        return self == SpeculativeDecodingMode.MTP_EAGLE

    def is_eagle3(self):
        return self == SpeculativeDecodingMode.EAGLE3

    def use_one_engine(self):
        return self.is_mtp() or self.is_eagle3_one_model()

    def is_eagle3_one_model(self):
        return self == SpeculativeDecodingMode.EAGLE3_ONE_MODEL

    def is_ngram(self):
        return self == SpeculativeDecodingMode.NGRAM

    def is_user_provided(self):
        return self == SpeculativeDecodingMode.USER_PROVIDED

    def is_none(self):
        return self == SpeculativeDecodingMode.NONE

    def is_draft_target(self):
        return self == SpeculativeDecodingMode.DRAFT_TARGET

    def without_logits(self):
        return self.is_mtp() or self.is_eagle3_one_model()

    def needs_kv_cache_rewind(self):
        return self.is_mtp() or self.is_eagle3_one_model() or self.is_ngram()

    def support_overlap_scheduler(self):
        return self.is_mtp() or self.is_eagle3_one_model()

    def has_draft_model(self):
        return self.is_eagle3() or self.is_draft_target()

    def needs_kv_cache_recompute(self):
        """
        Whether the draft model needs to recompute the KV cache.
        If true, the first draft-model forward recomputes the KV cache for
        the accepted draft tokens.
        """
        return self.is_eagle3()

    def need_load_draft_weights(self):
        """
        Whether the draft model and the target model share one model engine
        and the draft model needs to load its weights from a separate
        checkpoint.
        """
        return self.is_eagle3_one_model()

    def has_spec_decoder(self):
        return self.is_mtp() or self.is_eagle3() or self.is_eagle3_one_model()

    def has_spec_drafter(self):
        return (self.is_eagle3() or self.is_draft_target() or self.is_ngram()
                or self.is_user_provided())

    def extend_ctx(self, attention_backend: Type[AttentionBackend]):
        """
        If true, treat generation requests with draft tokens as
        chunked context requests at the kernel level. Required for
        any spec-dec mode that uses the SpecExecutor.
        """
        # FIXME: only the TRT-LLM attention backend supports EAGLE3
        # generation-phase kernels on Blackwell (SM 100).
        return ((self.is_eagle3() or self.is_draft_target())
                and not (issubclass(attention_backend, TrtllmAttention)
                         and get_sm_version() == 100)
                ) or self.is_ngram() or self.is_user_provided()

    def attention_need_spec_dec_mode(self):
        """
        If true, the attention backend kernel needs to run in spec-dec mode
        (multi-token query mode).
        """
        return self.is_eagle3_one_model()

    @staticmethod
    def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":
        if name is None:
            return SpeculativeDecodingMode.NONE
        return SpeculativeDecodingMode[name.upper()]
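

# Illustrative sketch (not part of the original module): the mode is normally
# built from a config string via from_string and then queried through the
# predicate helpers above. The string "eagle3" below is just an example value
# for this sketch.
#
#   mode = SpeculativeDecodingMode.from_string("eagle3")
#   assert mode.is_eagle3() and mode.has_draft_model()
#   assert not mode.use_one_engine()
#   # Unknown names raise KeyError, since from_string indexes the enum by name.
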
@dataclass
class SpecMetadata:
    """
    Metadata for speculative decoding.
    """
    # The max number of requests in a single batch.
    max_num_requests: int
    # The max number of draft tokens.
    max_draft_len: int
    # The number of gen-phase sequences in the batch.
    num_generations: int = 0
    # Whether CUDA graph is enabled.
    is_cuda_graph: bool = field(default=False, repr=False)
    # The mode of speculative decoding.
    spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
    # Draft tokens.
    draft_tokens: Optional[torch.Tensor] = None
    # The length of the draft tokens.
    draft_lens: Optional[torch.Tensor] = None
    # The request ID of each sequence in the batch.
    # The shape is (batch_size).
    request_ids: Optional[List[int]] = None
    # Sequence length for each request.
    seq_lens: Optional[List[int]] = None
    # The gather ids for logits.
    gather_ids: Optional[torch.Tensor] = None
    # The number of tokens for the speculative model/layer.
    num_tokens: int = 0
    # The number of tokens for the speculative model/layer on each rank.
    _all_rank_num_tokens: Optional[List[int]] = field(init=False,
                                                      default=None,
                                                      repr=False)
    # Exposed through the property below; the property object serves as the
    # dataclass default, so this field can be declared without one.
    all_rank_num_tokens: Optional[List[int]]
    # The max number of tokens among all ranks.
    all_rank_max_num_tokens: Optional[int] = None
    # The number of sequences for the speculative model/layer on each rank.
    all_rank_num_seqs: Optional[List[int]] = None
    # The number of extra KV tokens.
    # Some speculative decoding methods need different KV lengths for the
    # draft/target layers, but KVCacheManager only supports the same KV length
    # for all layers. Extra KV tokens are added in the KV cache manager to
    # handle this.
    num_extra_kv_tokens: Optional[int] = 0
    # The number of layers in the target model.
    num_layers: int = 0
    # Whether the spec-dec structure is a tree or a chain (linear tree).
    is_spec_dec_tree: bool = False
    # If the spec-dec tree never changes, the mask does not need to be
    # recomputed every step.
    is_spec_dec_dynamic_tree: bool = False

    def __post_init__(self):
        pass

    def prepare(self):
        """
        Hook to be called before the forward step of the model.
        """

    def create_cuda_graph_metadata(self, max_batch_size: int):
        """
        Creates metadata for CUDA graph execution.
        """
        if self.is_cuda_graph:
            return self

        cuda_graph_metadata = copy.copy(self)
        cuda_graph_metadata.is_cuda_graph = True
        cuda_graph_metadata.max_num_requests = max_batch_size
        cuda_graph_metadata.__post_init__()
        return cuda_graph_metadata

    def maybe_capture_hidden_states(self, layer_id: int,
                                    hidden_states: torch.Tensor,
                                    residual: torch.Tensor) -> None:
        """
        Some spec-decode algorithms require hidden states from the target
        model. Use this method to record them. By default, does nothing.
        """

    @property
    def all_rank_num_tokens(self) -> Optional[List[int]]:
        return self._all_rank_num_tokens

    @all_rank_num_tokens.setter
    def all_rank_num_tokens(self, value: Optional[List[int]]):
        # The generated __init__ passes the property object itself when no
        # value is supplied; treat that sentinel as None.
        value = value if value is not SpecMetadata.all_rank_num_tokens else None
        self._all_rank_num_tokens = value
        self.all_rank_max_num_tokens = max(value) if value is not None else None
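

# Illustrative sketch (not part of the original module): constructing a
# SpecMetadata and duplicating it for CUDA graph capture. The field values
# below are arbitrary examples chosen for this sketch, not defaults from
# this repo.
#
#   metadata = SpecMetadata(
#       max_num_requests=8,
#       max_draft_len=4,
#       spec_dec_mode=SpeculativeDecodingMode.from_string("eagle3"),
#   )
#   # The setter also refreshes all_rank_max_num_tokens.
#   metadata.all_rank_num_tokens = [16, 12]
#   assert metadata.all_rank_max_num_tokens == 16
#
#   graph_metadata = metadata.create_cuda_graph_metadata(max_batch_size=8)
#   assert graph_metadata.is_cuda_graph and graph_metadata is not metadata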