@@ -20,8 +20,9 @@
 from datetime import datetime
 
 import numpy as np
-from paddlenlp_ops import get_output
+from paddlenlp_ops import get_output, speculate_get_output
 from server.utils import datetime_diff, model_server_logger, monitor_logger
+from paddlenlp.utils.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
 
 
 class TokenProcessor(object):
@@ -37,7 +38,14 @@ def __init__(self, cfg):
         self.all_tokens = [[] for _ in range(self.cfg.max_batch_size)]
 
         self.tokens_counter = Counter()
-        self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
+
+        self.is_speculate_decoding = self.cfg.get_speculate_config().speculate_method != "None"
+        if self.is_speculate_decoding:
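+            # speculative buffer layout, as read back in _process_batch_output below:
+            # [0] ready flag (-2 = no output yet), [1] batch size, then SPECULATE_MAX_BSZ accept counts, then MAX_DRAFT_TOKENS token slots per query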
+            self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64")
+        else:
+            self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
         self.worker = None
 
         self.record_time_interval = int(os.getenv("RECORD_TIME_INTERVAL", "600"))
@@ -77,10 +83,15 @@ def process_sampling_results(self):
             try:
                 rank_id = 0
                 is_blocking = True
-                get_output(self.output_tokens, rank_id, is_blocking)
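+                # the speculative op fills the larger layout (accept counts plus draft tokens)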
+                if self.is_speculate_decoding:
+                    speculate_get_output(self.output_tokens, rank_id, is_blocking)
+                else:
+                    get_output(self.output_tokens, rank_id, is_blocking)
 
                 if self.output_tokens[0, 0] == -2:
                     continue
+
                 self._process_batch_output()
             except Exception as e:
                 model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
@@ -101,14 +111,14 @@ def postprocess(self, batch_result, exist_finished_task=False):
         with open(result_file, "a") as f:
             f.write("{}\n".format(result))
 
-    def _get_single_result(self, i, task_id, token_id, task):
+    def _get_single_result(self, i, task_id, token_ids, task):
         """
         processing single results
 
         Args:
             i (int): batch index
             task_id (str): task id
-            token_id (int): token id
+            token_ids (list): token ids
             task (dict): task information
 
         Returns:
@@ -121,7 +131,7 @@ def _get_single_result(self, i, task_id, token_id, task):
         result = {
             "req_id": task_id,
             "is_end": 0,
-            "token_ids": [token_id],
+            "token_ids": token_ids,
             "send_idx": self.tokens_counter[task_id],
             "inference_time_cost": inference_time_cost,
             "infer_seed": task["infer_seed"],
@@ -137,26 +147,32 @@ def _get_single_result(self, i, task_id, token_id, task):
                 result[key] = str(task[key])
 
         # fill some extra information
-        if token_id in task["eos_token_ids"]:
-            result["is_end"] = 1
-            result["token_ids"] = []
-            result["tokens_all_num"] = len(self.all_tokens[i]) + 1
-            result["tokens_all_ids"] = self.all_tokens[i]
-
-            info_dict = {}
-            info_dict["req_id"] = task["req_id"]
-            info_dict["input_token_num"] = len(task["input_ids"])
-            info_dict["output_token_num"] = len(self.all_tokens[i])
-            if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
-                info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
-                                                                  task["preprocess_end_time"])
-            if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
-                info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
-                                                                     task["schedule_start_time"])
-            info_dict["inference_time_cost"] = task["inference_time_cost"]
-            info_dict["version"] = "4.6"
-            info_dict["timestamp"] = time.time()
-            monitor_logger.info(f"{info_dict}")
+        result["token_ids"] = []
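+        # walk the accepted tokens in order; emit those before the first EOS, then stop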
+        for token_id in token_ids:
+            if token_id in task["eos_token_ids"]:
+                result["is_end"] = 1
+                result["token_ids"] = []
+                result["tokens_all_num"] = len(self.all_tokens[i]) + 1
+                result["tokens_all_ids"] = self.all_tokens[i]
+
+                info_dict = {}
+                info_dict["req_id"] = task["req_id"]
+                info_dict["input_token_num"] = len(task["input_ids"])
+                info_dict["output_token_num"] = len(self.all_tokens[i])
+                if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
+                    info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
+                                                                      task["preprocess_end_time"])
+                if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
+                    info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
+                                                                         task["schedule_start_time"])
+                info_dict["inference_time_cost"] = task["inference_time_cost"]
+                info_dict["version"] = "OpenSource"
+                info_dict["timestamp"] = time.time()
+                monitor_logger.info(f"{info_dict}")
+                break
+            else:
+                result["token_ids"].append(token_id)
 
         return result
 
@@ -177,33 +192,45 @@ def _process_batch_output(self):
         """
         tokens = self.output_tokens.numpy()
         batch = self.output_tokens[1, 0]
-        tokens = tokens[2:batch + 2]
+        if not self.is_speculate_decoding:
+            tokens = tokens[2:batch + 2]
+        else:
+            accept_num = tokens[2:batch + 2]
 
         batch_result = list()
         exist_finished_task = False
         for i in range(batch):
            if self.resource_manager.stop_flags[i]:
                 continue
 
-            token_id = int(tokens[i, 0])
-            if token_id < 0:
+            if not self.is_speculate_decoding:
+                token_ids = [int(tokens[i, 0])]
+            else:
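+                # query i's tokens sit at a fixed stride of MAX_DRAFT_TOKENS entries, offset
+                # past the two header slots and the accept_num block; only accept_num[i] are valid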
+                token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i, 0], 0].tolist()
+
+            if any(token_id < 0 for token_id in token_ids):
                 continue
 
             task = self.resource_manager.tasks_list[i]
 
             task_id = task["req_id"]
-            result = self._get_single_result(i, task_id, token_id, task)
-
-            self.tokens_counter[task_id] += 1
-            if token_id not in task["eos_token_ids"]:
-                self.all_tokens[i].append(token_id)
-
-            self.number_of_output_tokens += 1
-            if token_id in task["eos_token_ids"]:
-                self._recycle_resources(task_id, i, task)
-                model_server_logger.info("req_id: {0} finished".format(task_id))
-                model_server_logger.info(f"{self.resource_manager.info()}")
-                exist_finished_task = True
+            result = self._get_single_result(i, task_id, token_ids, task)
+
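+            # per-token bookkeeping; stop at EOS so anything after it is not counted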
+            for token_id in token_ids:
+                self.tokens_counter[task_id] += 1
+                if token_id not in task["eos_token_ids"]:
+                    self.all_tokens[i].append(token_id)
+
+                self.number_of_output_tokens += 1
+                if token_id in task["eos_token_ids"]:
+                    self._recycle_resources(task_id, i, task)
+                    model_server_logger.info("req_id: {0} finished".format(task_id))
+                    model_server_logger.info(f"{self.resource_manager.info()}")
+                    exist_finished_task = True
+                    break
             batch_result.append(result)
 
         self.postprocess(batch_result, exist_finished_task)
@@ -228,7 +252,10 @@ def process_sampling_results(self):
         while self._is_running:
             try:
                 rank_id = 0
-                get_output(self.output_tokens, rank_id, self._is_blocking)
+                if self.is_speculate_decoding:
+                    speculate_get_output(self.output_tokens, rank_id, self._is_blocking)
+                else:
+                    get_output(self.output_tokens, rank_id, self._is_blocking)
 
                 if self.output_tokens[0, 0] == -2:
                     continue