Skip to content

Commit 139a3d7

Browse files
authored
[template] update template decode_generate_ids (#9523)
1 parent 5728eda commit 139a3d7

11 files changed

Lines changed: 25 additions & 24 deletions

File tree

swift/infer_engine/grpo_vllm_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def _create_chat_completion_response(self, result, inputs, request_config, reque
108108
choices = []
109109
for output in result.outputs:
110110
output.token_ids = list(output.token_ids)
111-
response = self.template.decode(output.token_ids, template_inputs=inputs['template_inputs'])
111+
response = self.template.decode_generate_ids(output.token_ids, template_inputs=inputs['template_inputs'])
112112
logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs)
113113
toolcall = self._get_toolcall(response)
114114

swift/infer_engine/lmdeploy_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ async def _infer_stream_async(
226226
toolcall = None
227227
if is_finished:
228228
toolcall = self._get_toolcall(
229-
self.template.decode(output.token_ids, template_inputs=inputs['template_inputs']))
229+
self.template.decode_generate_ids(output.token_ids, template_inputs=inputs['template_inputs']))
230230
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token,
231231
output.status.name == 'FINISH')
232232
choices = [
@@ -261,7 +261,7 @@ async def _infer_full_async(
261261
async for output in generator.async_stream_infer(session_id=session_id, **inputs, **kwargs):
262262
pass
263263

264-
response = self.template.decode(output.token_ids, template_inputs=inputs['template_inputs'])
264+
response = self.template.decode_generate_ids(output.token_ids, template_inputs=inputs['template_inputs'])
265265
logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs)
266266

267267
usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)

swift/infer_engine/sglang_engine.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _create_chat_completion_response(self, output, inputs, return_details: bool
185185
assert output is not None
186186
meta_info = output['meta_info']
187187
usage_info = self._get_usage_info(meta_info['prompt_tokens'], meta_info['completion_tokens'])
188-
response = self.template.decode(output['output_ids'], template_inputs=inputs['template_inputs'])
188+
response = self.template.decode_generate_ids(output['output_ids'], template_inputs=inputs['template_inputs'])
189189
toolcall = self._get_toolcall(response)
190190
token_ids = output['output_ids'] if return_details else None
191191
choice = ChatCompletionResponseChoice(
@@ -289,7 +289,8 @@ def _create_chat_completion_stream_response(self, output, infer_streamer) -> Opt
289289
toolcall = None
290290
if is_finished:
291291
finish_reason = finish_reason['type']
292-
toolcall = self._get_toolcall(self.template.decode(output['output_ids'], **infer_streamer.decode_kwargs))
292+
toolcall = self._get_toolcall(
293+
self.template.decode_generate_ids(output['output_ids'], **infer_streamer.decode_kwargs))
293294
meta_info = output['meta_info']
294295
usage_info = self._get_usage_info(meta_info['prompt_tokens'], meta_info['completion_tokens'])
295296
# TODO: logprobs

swift/infer_engine/transformers_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def _model_generate(**kwargs):
310310
toolcall = None
311311
if is_finished[i]:
312312
toolcall = self._get_toolcall(
313-
self.template.decode(generate_ids, template_inputs=template_inputs[i]))
313+
self.template.decode_generate_ids(generate_ids, template_inputs=template_inputs[i]))
314314
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, usage_info.completion_tokens,
315315
is_finished[i])
316316

@@ -434,7 +434,7 @@ def _infer_full(self, inputs: Dict[str, Any], *, generation_config: GenerationCo
434434

435435
logprobs = self._get_logprobs(logprobs_list, generate_ids, request_config.top_logprobs)
436436
usage_info = self._update_usage_info(usage_info, len(generate_ids))
437-
response = self.template.decode(generate_ids, template_inputs=template_inputs[i])
437+
response = self.template.decode_generate_ids(generate_ids, template_inputs=template_inputs[i])
438438
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, len(generate_ids), True)
439439
toolcall = self._get_toolcall(response)
440440
token_ids = generate_ids if request_config.return_details else None

swift/infer_engine/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def get_printable_text(self, raw_tokens: List[int], is_finished: bool) -> str:
8787
raw_tokens = raw_tokens[self.cache_idx:]
8888
if self.first_token:
8989
raw_tokens = []
90-
response = self.template.decode(
90+
response = self.template.decode_generate_ids(
9191
raw_tokens, is_finished=is_finished, first_token=self.first_token, **self.decode_kwargs)
9292
response = self._align_blank_suffix(response)
9393
return self._get_response(response, is_finished, len(raw_tokens))

swift/infer_engine/vllm_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ def _create_chat_completion_stream_response(self, result, request_config, reques
611611
toolcall = None
612612
if output.is_finished:
613613
toolcall = self._get_toolcall(
614-
self.template.decode(output.token_ids, **infer_streamers[i].decode_kwargs))
614+
self.template.decode_generate_ids(output.token_ids, **infer_streamers[i].decode_kwargs))
615615

616616
choice = ChatCompletionResponseStreamChoice(
617617
index=i,
@@ -664,7 +664,7 @@ def _create_chat_completion_response(
664664
choices = []
665665
for output in result.outputs:
666666
output.token_ids = list(output.token_ids)
667-
response = self.template.decode(output.token_ids, template_inputs=inputs['template_inputs'])
667+
response = self.template.decode_generate_ids(output.token_ids, template_inputs=inputs['template_inputs'])
668668

669669
# Extract reasoning content if reasoning_parser is enabled
670670
reasoning_content = None

swift/template/base.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -721,13 +721,13 @@ def decode_seq_cls(self, logits: torch.Tensor, top_logprobs: int):
721721
logprobs = [self._get_seq_cls_logprobs(pred, logprobs[i], top_logprobs) for i, pred in enumerate(preds)]
722722
return preds, logprobs
723723

724-
def decode(self,
725-
generate_ids: List[int],
726-
*,
727-
is_finished: bool = True,
728-
first_token=True,
729-
template_inputs=None,
730-
**kwargs) -> Any:
724+
def decode_generate_ids(self,
725+
generate_ids: List[int],
726+
*,
727+
is_finished: bool = True,
728+
first_token=True,
729+
template_inputs=None,
730+
**kwargs) -> Any:
731731
if kwargs.get('spaces_between_special_tokens') is None:
732732
kwargs['spaces_between_special_tokens'] = False
733733
generate_ids = self.skip_stop_tokens(generate_ids, is_finished)

swift/template/templates/baai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None
115115
res['logits_processor'] = logits_processor
116116
return res
117117

118-
def decode(self, generate_ids: List[int], **kwargs) -> Any:
118+
def decode_generate_ids(self, generate_ids: List[int], **kwargs) -> Any:
119119
mm_list = self.processor.decode(generate_ids)
120120
for im in mm_list:
121121
if not isinstance(im, Image.Image):

swift/template/templates/deepseek.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,9 @@ def generate(self, model, *args, **kwargs):
186186

187187
return {'sequences': generated_tokens}
188188

189-
def decode(self, generate_ids: List[int], **kwargs) -> Any:
189+
def decode_generate_ids(self, generate_ids: List[int], **kwargs) -> Any:
190190
if 'template_inputs' not in kwargs or not kwargs['template_inputs'].generate_mode:
191-
return super().decode(generate_ids, **kwargs)
191+
return super().decode_generate_ids(generate_ids, **kwargs)
192192
else:
193193
img_size = get_env_args('img_size', int, 384)
194194
patch_size = 16

swift/template/templates/glm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def _swift_encode(self, inputs: StdTemplateInputs):
3333
res_context_list[i] = res_context_list[i][:-len('\n')]
3434
return res_context_list, loss_scale_list, answer_len
3535

36-
def decode(self, *args, **kwargs):
37-
response = super().decode(*args, **kwargs)
36+
def decode_generate_ids(self, *args, **kwargs):
37+
response = super().decode_generate_ids(*args, **kwargs)
3838
return response.lstrip('\n') if self.strip_newline else response
3939

4040

0 commit comments

Comments
 (0)