
Commit 19e752c

Merge pull request #2541 from Wanglongzhi2001/speculate_decoding
Support speculate decoding
2 parents: 947f230 + fe35dc5

File tree: 3 files changed, +144 −43 lines

llm/server/server/engine/config.py

+29 lines

@@ -19,6 +19,7 @@
 
 from paddlenlp.generation import GenerationConfig
 from server.utils import model_server_logger
+from dataclasses import dataclass
 
 
 class Config:
@@ -203,6 +204,27 @@ def get_model_config(self):
         model_config_json = json.load(open(self.model_config_path, 'r', encoding='utf-8'))
         return model_config_json
 
+    def get_speculate_config(self):
+        """
+        get speculate_decoding related config
+
+        Returns:
+            SpeculateConfig: the speculate related config
+        """
+        speculate_config = SpeculateConfig()
+        model_cfg = self.get_model_config()
+        if model_cfg.get("speculate_method", "None") != "None":
+            speculate_config.speculate_method = str(model_cfg["speculate_method"])
+            speculate_config.speculate_max_draft_token_num = model_cfg[
+                "speculate_max_draft_token_num"]
+            speculate_config.speculate_max_ngram_size = model_cfg[
+                "speculate_max_ngram_size"]
+
+        if speculate_config.speculate_method not in ["None", "inference_with_reference"]:
+            model_server_logger.error(f"Unsupport speculate method: {speculate_config.speculate_method}")
+
+        return speculate_config
+
     def read_from_config(self):
         """
         reset model config from json file
@@ -234,3 +256,10 @@ def get_unique_name(self, name):
 
     def __str__(self) -> str:
        return json.dumps(self.__dict__, indent=4)
+
+
+@dataclass
+class SpeculateConfig:
+    speculate_method: str = "None"
+    speculate_max_draft_token_num: int = 1
+    speculate_max_ngram_size: int = 1
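
For readers skimming the diff, a minimal stand-alone sketch (not part of the commit) of what the new get_speculate_config path does: the three speculate_* fields are read from the model's config JSON when speculate_method is set, otherwise the SpeculateConfig defaults are kept, and any method other than "None" or "inference_with_reference" only produces an error log. The parse_speculate_config helper and the example values below are illustrative, not names from the repository.

# Illustrative mirror of the parsing logic added above, runnable on its own.
from dataclasses import dataclass

@dataclass
class SpeculateConfig:
    speculate_method: str = "None"
    speculate_max_draft_token_num: int = 1
    speculate_max_ngram_size: int = 1

def parse_speculate_config(model_cfg: dict) -> SpeculateConfig:
    cfg = SpeculateConfig()
    if model_cfg.get("speculate_method", "None") != "None":
        cfg.speculate_method = str(model_cfg["speculate_method"])
        cfg.speculate_max_draft_token_num = model_cfg["speculate_max_draft_token_num"]
        cfg.speculate_max_ngram_size = model_cfg["speculate_max_ngram_size"]
    return cfg

# Hypothetical model config that enables ngram-lookup speculate decoding.
print(parse_speculate_config({
    "speculate_method": "inference_with_reference",
    "speculate_max_draft_token_num": 5,
    "speculate_max_ngram_size": 2,
}))
# SpeculateConfig(speculate_method='inference_with_reference', speculate_max_draft_token_num=5, speculate_max_ngram_size=2)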

llm/server/server/engine/infer.py

+46 −1 lines

@@ -29,6 +29,7 @@
 from paddlenlp_ops import step_paddle
 from server.data.processor import DataProcessor
 from server.engine.config import Config
+from paddlenlp.experimental.transformers import InferenceWithReferenceProposer
 from server.utils import get_logger
 from task_queue_manager import TaskQueueManager
 
@@ -46,6 +47,8 @@ def __init__(self, args):
 
         self.config = Config()
         self.model_cfg = self.config.get_model_config()
+        self.speculate_config = self.config.get_speculate_config()
+        self.is_speculate_decoding = self.speculate_config.speculate_method != "None"
         self.format_print_configuration()
 
         self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
@@ -67,6 +70,17 @@ def __init__(self, args):
         self.cache_kvs = {}
         self.init_inputs()
 
+        if self.is_speculate_decoding:
+            logger.info(f'Using speculate decoding, method: {self.speculate_config.speculate_method}.')
+            if self.speculate_config.speculate_method == "inference_with_reference":
+                self.proposer = InferenceWithReferenceProposer(
+                    self.speculate_config.speculate_max_draft_token_num,
+                    self.speculate_config.speculate_max_ngram_size,
+                    self.args.max_batch_size,
+                    self.args.max_seq_len)
+        else:
+            self.proposer = None
+
         self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port)
 
         model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}")
@@ -263,6 +277,18 @@ def init_inputs(self):
             shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
         self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
             shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
+        # speculate decoding input
+        if self.is_speculate_decoding:
+            self.share_inputs["accept_tokens"] = paddle.full(
+                shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
+            )
+            self.share_inputs["accept_num"] = paddle.full(shape=[self.args.max_batch_size], fill_value=0, dtype="int32")
+            self.share_inputs["draft_tokens"] = paddle.full(
+                shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
+            )
+            self.share_inputs["actual_draft_token_num"] = paddle.full(
+                shape=[self.args.max_batch_size], fill_value=self.speculate_config.speculate_max_draft_token_num, dtype="int32"
+            )
 
     def dy_input_preprocess(self, tasks):
         """
@@ -318,10 +344,21 @@ def dy_input_preprocess(self, tasks):
                     task["stop_seqs_len"], dtype="int32")
                 self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
                     task["stop_seqs"], dtype="int64")
+
+            if self.is_speculate_decoding:
+                self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.speculate_config.speculate_max_draft_token_num + 1])
+                self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.speculate_config.speculate_max_draft_token_num])
+
     def step_cuda(self, seq_lens_this_time):
         """
         step cuda
         """
+        # whether speculate decoding
+        if self.is_speculate_decoding:
+            speculate_step_token_num = self.speculate_config.speculate_max_draft_token_num + 1
+        else:
+            speculate_step_token_num = 0
+
         step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
                     self.share_inputs['step_seq_lens_encoder'],
                     self.share_inputs['seq_lens_encoder'],
@@ -334,7 +371,8 @@ def step_cuda(self, seq_lens_this_time):
                     self.share_inputs['free_list'], self.share_inputs['free_list_len'],
                     self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
                     self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
-                    self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id)
+                    self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id,
+                    speculate_step_token_num)
 
     def initialize_engine_ready_check_flag(self):
         """
@@ -459,6 +497,13 @@ def run(self):
                     time.sleep(0.001)
                     continue
 
            if self.proposer is not None:
+                self.proposer.run(
+                    self.share_inputs,
+                    real_batch_size=seq_lens_this_time.shape[0],
+                    seq_lens_this_time=seq_lens_this_time,
+                )
+
            self.infer_engine.predictor.run()
            self.share_inputs['infer_seed'].add_(infer_seed_increment)
            self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED
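
Taken together, the infer.py changes order one decode iteration roughly as sketched below (a simplified paraphrase of the loop above, assuming the inference_with_reference method): the proposer fills share_inputs["draft_tokens"] by n-gram lookup against the prompt and generated prefix, a single predictor.run() lets the target model verify the drafts, and step_cuda passes speculate_max_draft_token_num + 1 so the scheduler can advance more than one token per request per step.

# Simplified paraphrase of the decode loop; names follow the diff, control flow is condensed.
def decode_iteration(engine, seq_lens_this_time):
    if engine.proposer is not None:
        # fill draft_tokens from the prompt / generated prefix (ngram match)
        engine.proposer.run(
            engine.share_inputs,
            real_batch_size=seq_lens_this_time.shape[0],
            seq_lens_this_time=seq_lens_this_time,
        )
    engine.infer_engine.predictor.run()    # target model verifies drafts, fills accept_tokens / accept_num
    engine.step_cuda(seq_lens_this_time)   # may advance up to speculate_max_draft_token_num + 1 tokens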

llm/server/server/engine/token_processor.py

+69 −42 lines

@@ -20,8 +20,9 @@
 from datetime import datetime
 
 import numpy as np
-from paddlenlp_ops import get_output
+from paddlenlp_ops import get_output, speculate_get_output
 from server.utils import datetime_diff, model_server_logger, monitor_logger
+from paddlenlp.utils.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
 
 
 class TokenProcessor(object):
@@ -37,7 +38,12 @@ def __init__(self, cfg):
         self.all_tokens = [[] for _ in range(self.cfg.max_batch_size)]
 
         self.tokens_counter = Counter()
-        self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
+
+        self.is_speculate_decoding = self.cfg.get_speculate_config().speculate_method != "None"
+        if self.is_speculate_decoding:
+            self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64")
+        else:
+            self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
         self.worker = None
 
         self.record_time_interval = int(os.getenv("RECORD_TIME_INTERVAL", "600"))
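
The larger output_tokens buffer matches the layout implied by the indexing in _process_batch_output further down: two header slots (a flag that reads -2 while no output is available, then the batch size), SPECULATE_MAX_BSZ accepted-token counts, and MAX_DRAFT_TOKENS token slots per batch entry. A size check with illustrative constants; the real values come from paddlenlp.utils.env.

# Illustrative constants only, for the size arithmetic.
SPECULATE_MAX_BSZ = 256
MAX_DRAFT_TOKENS = 6

header = 2                                           # [ready flag, batch size]
accept_counts = SPECULATE_MAX_BSZ                    # one accepted-token count per slot
token_region = SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS  # up to MAX_DRAFT_TOKENS tokens per slot
assert token_region + accept_counts + header == SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2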
@@ -77,10 +83,14 @@ def process_sampling_results(self):
             try:
                 rank_id = 0
                 is_blocking = True
-                get_output(self.output_tokens, rank_id, is_blocking)
+                if self.is_speculate_decoding:
+                    speculate_get_output(self.output_tokens, rank_id, is_blocking)
+                else:
+                    get_output(self.output_tokens, rank_id, is_blocking)
 
                 if self.output_tokens[0, 0] == -2:
                     continue
+
                 self._process_batch_output()
             except Exception as e:
                 model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
@@ -101,14 +111,14 @@ def postprocess(self, batch_result, exist_finished_task=False):
             with open(result_file, "a") as f:
                 f.write("{}\n".format(result))
 
-    def _get_single_result(self, i, task_id, token_id, task):
+    def _get_single_result(self, i, task_id, token_ids, task):
         """
         processing single results
 
         Args:
             i (int): batch index
             task_id (str): task id
-            token_id (int): token id
+            token_ids (list): token id
             task (dict): task information
 
         Returns:
@@ -121,7 +131,7 @@ def _get_single_result(self, i, task_id, token_id, task):
         result = {
             "req_id": task_id,
             "is_end": 0,
-            "token_ids": [token_id],
+            "token_ids": token_ids,
             "send_idx": self.tokens_counter[task_id],
             "inference_time_cost": inference_time_cost,
             "infer_seed": task["infer_seed"],
@@ -137,26 +147,31 @@ def _get_single_result(self, i, task_id, token_id, task):
                 result[key] = str(task[key])
 
         # fill some extra information
-        if token_id in task["eos_token_ids"]:
-            result["is_end"] = 1
-            result["token_ids"] = []
-            result["tokens_all_num"] = len(self.all_tokens[i]) + 1
-            result["tokens_all_ids"] = self.all_tokens[i]
-
-            info_dict = {}
-            info_dict["req_id"] = task["req_id"]
-            info_dict["input_token_num"] = len(task["input_ids"])
-            info_dict["output_token_num"] = len(self.all_tokens[i])
-            if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
-                info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
-                                                                  task["preprocess_end_time"])
-            if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
-                info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
-                                                                     task["schedule_start_time"])
-            info_dict["inference_time_cost"] = task["inference_time_cost"]
-            info_dict["version"] = "4.6"
-            info_dict["timestamp"] = time.time()
-            monitor_logger.info(f"{info_dict}")
+        result["token_ids"] = []
+        for token_id in token_ids:
+            if token_id in task["eos_token_ids"]:
+                result["is_end"] = 1
+                result["token_ids"] = []
+                result["tokens_all_num"] = len(self.all_tokens[i]) + 1
+                result["tokens_all_ids"] = self.all_tokens[i]
+
+                info_dict = {}
+                info_dict["req_id"] = task["req_id"]
+                info_dict["input_token_num"] = len(task["input_ids"])
+                info_dict["output_token_num"] = len(self.all_tokens[i])
+                if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
+                    info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
+                                                                      task["preprocess_end_time"])
+                if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
+                    info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
+                                                                         task["schedule_start_time"])
+                info_dict["inference_time_cost"] = task["inference_time_cost"]
+                info_dict["version"] = "OpenSource"
+                info_dict["timestamp"] = time.time()
+                monitor_logger.info(f"{info_dict}")
+                break
+            else:
+                result["token_ids"].append(token_id)
 
         return result
 
@@ -177,33 +192,42 @@ def _process_batch_output(self):
         """
         tokens = self.output_tokens.numpy()
         batch = self.output_tokens[1, 0]
-        tokens = tokens[2:batch + 2]
+        if not self.is_speculate_decoding:
+            tokens = tokens[2:batch + 2]
+        else:
+            accept_num = tokens[2:batch + 2]
 
         batch_result = list()
         exist_finished_task = False
         for i in range(batch):
             if self.resource_manager.stop_flags[i]:
                 continue
 
-            token_id = int(tokens[i, 0])
-            if token_id < 0:
+            if not self.is_speculate_decoding:
+                token_ids = [int(tokens[i, 0])]
+            else:
+                token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i, 0], 0].tolist()
+
+            if any(token_id < 0 for token_id in token_ids):
                 continue
 
             task = self.resource_manager.tasks_list[i]
 
             task_id = task["req_id"]
-            result = self._get_single_result(i, task_id, token_id, task)
-
-            self.tokens_counter[task_id] += 1
-            if token_id not in task["eos_token_ids"]:
-                self.all_tokens[i].append(token_id)
-
-            self.number_of_output_tokens += 1
-            if token_id in task["eos_token_ids"]:
-                self._recycle_resources(task_id, i, task)
-                model_server_logger.info("req_id: {0} finished".format(task_id))
-                model_server_logger.info(f"{self.resource_manager.info()}")
-                exist_finished_task = True
+            result = self._get_single_result(i, task_id, token_ids, task)
+
+            for token_id in token_ids:
+                self.tokens_counter[task_id] += 1
+                if token_id not in task["eos_token_ids"]:
+                    self.all_tokens[i].append(token_id)
+
+                self.number_of_output_tokens += 1
+                if token_id in task["eos_token_ids"]:
+                    self._recycle_resources(task_id, i, task)
+                    model_server_logger.info("req_id: {0} finished".format(task_id))
+                    model_server_logger.info(f"{self.resource_manager.info()}")
+                    exist_finished_task = True
+                    break
             batch_result.append(result)
 
         self.postprocess(batch_result, exist_finished_task)
@@ -228,7 +252,10 @@ def process_sampling_results(self):
         while self._is_running:
             try:
                 rank_id = 0
-                get_output(self.output_tokens, rank_id, self._is_blocking)
+                if self.is_speculate_decoding:
+                    speculate_get_output(self.output_tokens, rank_id, self._is_blocking)
+                else:
+                    get_output(self.output_tokens, rank_id, self._is_blocking)
 
                 if self.output_tokens[0, 0] == -2:
                     continue
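
To make the slicing in _process_batch_output easier to follow, here is a synthetic example of that flat layout: only the offsets mirror the code, the batch size and token values are made up.

# Synthetic buffer in the speculate layout; the slicing matches _process_batch_output above.
import numpy as np

SPECULATE_MAX_BSZ, MAX_DRAFT_TOKENS = 4, 3   # illustrative, not the real env values
buf = np.full((SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1), -1, dtype=np.int64)
buf[0, 0] = 0                 # anything other than -2 means output is ready
buf[1, 0] = 2                 # two active batch slots
buf[2:4, 0] = [2, 1]          # accept_num for slots 0 and 1
base = 2 + SPECULATE_MAX_BSZ
buf[base + 0 * MAX_DRAFT_TOKENS: base + 0 * MAX_DRAFT_TOKENS + 2, 0] = [101, 102]  # slot 0 accepted tokens
buf[base + 1 * MAX_DRAFT_TOKENS: base + 1 * MAX_DRAFT_TOKENS + 1, 0] = [103]       # slot 1 accepted token

batch = int(buf[1, 0])
accept_num = buf[2:batch + 2]
for i in range(batch):
    start = 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
    print(i, buf[start:start + accept_num[i, 0], 0].tolist())
# 0 [101, 102]
# 1 [103]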
