Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion swift/infer_engine/grpo_vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _create_chat_completion_response(self, result, inputs, request_config, reque
choices = []
for output in result.outputs:
output.token_ids = list(output.token_ids)
- response = self.template.decode(output.token_ids, template_inputs=inputs['template_inputs'])
+ response = self.template.decode_generate_ids(output.token_ids, template_inputs=inputs['template_inputs'])
logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs)
toolcall = self._get_toolcall(response)

Expand Down
4 changes: 2 additions & 2 deletions swift/infer_engine/lmdeploy_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ async def _infer_stream_async(
toolcall = None
if is_finished:
toolcall = self._get_toolcall(
- self.template.decode(output.token_ids, template_inputs=inputs['template_inputs']))
+ self.template.decode_generate_ids(output.token_ids, template_inputs=inputs['template_inputs']))
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token,
output.status.name == 'FINISH')
choices = [
Expand Down Expand Up @@ -261,7 +261,7 @@ async def _infer_full_async(
async for output in generator.async_stream_infer(session_id=session_id, **inputs, **kwargs):
pass

- response = self.template.decode(output.token_ids, template_inputs=inputs['template_inputs'])
+ response = self.template.decode_generate_ids(output.token_ids, template_inputs=inputs['template_inputs'])
logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs)

usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
Expand Down
5 changes: 3 additions & 2 deletions swift/infer_engine/sglang_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _create_chat_completion_response(self, output, inputs, return_details: bool
assert output is not None
meta_info = output['meta_info']
usage_info = self._get_usage_info(meta_info['prompt_tokens'], meta_info['completion_tokens'])
- response = self.template.decode(output['output_ids'], template_inputs=inputs['template_inputs'])
+ response = self.template.decode_generate_ids(output['output_ids'], template_inputs=inputs['template_inputs'])
toolcall = self._get_toolcall(response)
token_ids = output['output_ids'] if return_details else None
choice = ChatCompletionResponseChoice(
Expand Down Expand Up @@ -289,7 +289,8 @@ def _create_chat_completion_stream_response(self, output, infer_streamer) -> Opt
toolcall = None
if is_finished:
finish_reason = finish_reason['type']
- toolcall = self._get_toolcall(self.template.decode(output['output_ids'], **infer_streamer.decode_kwargs))
+ toolcall = self._get_toolcall(
+ self.template.decode_generate_ids(output['output_ids'], **infer_streamer.decode_kwargs))
meta_info = output['meta_info']
usage_info = self._get_usage_info(meta_info['prompt_tokens'], meta_info['completion_tokens'])
# TODO: logprobs
Expand Down
4 changes: 2 additions & 2 deletions swift/infer_engine/transformers_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def _model_generate(**kwargs):
toolcall = None
if is_finished[i]:
toolcall = self._get_toolcall(
- self.template.decode(generate_ids, template_inputs=template_inputs[i]))
+ self.template.decode_generate_ids(generate_ids, template_inputs=template_inputs[i]))
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, usage_info.completion_tokens,
is_finished[i])

Expand Down Expand Up @@ -434,7 +434,7 @@ def _infer_full(self, inputs: Dict[str, Any], *, generation_config: GenerationCo

logprobs = self._get_logprobs(logprobs_list, generate_ids, request_config.top_logprobs)
usage_info = self._update_usage_info(usage_info, len(generate_ids))
- response = self.template.decode(generate_ids, template_inputs=template_inputs[i])
+ response = self.template.decode_generate_ids(generate_ids, template_inputs=template_inputs[i])
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, len(generate_ids), True)
toolcall = self._get_toolcall(response)
token_ids = generate_ids if request_config.return_details else None
Expand Down
2 changes: 1 addition & 1 deletion swift/infer_engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def get_printable_text(self, raw_tokens: List[int], is_finished: bool) -> str:
raw_tokens = raw_tokens[self.cache_idx:]
if self.first_token:
raw_tokens = []
- response = self.template.decode(
+ response = self.template.decode_generate_ids(
raw_tokens, is_finished=is_finished, first_token=self.first_token, **self.decode_kwargs)
Comment thread
Jintao-Huang marked this conversation as resolved.
response = self._align_blank_suffix(response)
return self._get_response(response, is_finished, len(raw_tokens))
Expand Down
4 changes: 2 additions & 2 deletions swift/infer_engine/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def _create_chat_completion_stream_response(self, result, request_config, reques
toolcall = None
if output.is_finished:
toolcall = self._get_toolcall(
- self.template.decode(output.token_ids, **infer_streamers[i].decode_kwargs))
+ self.template.decode_generate_ids(output.token_ids, **infer_streamers[i].decode_kwargs))

choice = ChatCompletionResponseStreamChoice(
index=i,
Expand Down Expand Up @@ -664,7 +664,7 @@ def _create_chat_completion_response(
choices = []
for output in result.outputs:
output.token_ids = list(output.token_ids)
- response = self.template.decode(output.token_ids, template_inputs=inputs['template_inputs'])
+ response = self.template.decode_generate_ids(output.token_ids, template_inputs=inputs['template_inputs'])

# Extract reasoning content if reasoning_parser is enabled
reasoning_content = None
Expand Down
Loading