|
21 | 21 |
|
22 | 22 | logger = init_logger(__name__) |
23 | 23 |
|
24 | | -_COSYVOICE3_SPEECH_TOKEN_SIZE = 6561 |
25 | | - |
26 | 24 |
|
27 | 25 | def _build_prompt_embed_struct(prompt_payload: dict[str, Any]) -> EmbeddingsStruct | None: |
28 | 26 | """Wrap prompt_payload's flat speech_token/speech_feat/embedding tensors into EmbeddingsStruct.""" |
@@ -85,20 +83,6 @@ def _prompt_speech_token_ids(multi_modal_data: dict[str, Any]) -> list[int]: |
85 | 83 | return _to_token_id_list(speech_token) |
86 | 84 |
|
87 | 85 |
|
88 | | -def _has_speech_stop_token(output_ids: list[Any]) -> bool: |
89 | | - return any(token_id >= _COSYVOICE3_SPEECH_TOKEN_SIZE for token_id in _to_token_id_list(output_ids)) |
90 | | - |
91 | | - |
92 | | -def _set_non_stream_prompt_trim(additional_info: dict[str, Any], prompt_speech_len: int) -> None: |
93 | | - if prompt_speech_len <= 0: |
94 | | - return |
95 | | - meta = additional_info.get("meta") |
96 | | - if not isinstance(meta, dict): |
97 | | - meta = {} |
98 | | - additional_info["meta"] = meta |
99 | | - meta["talker_prefill_offset"] = prompt_speech_len |
100 | | - |
101 | | - |
102 | 86 | def _to_cpu_tensor(x: Any) -> torch.Tensor | None: |
103 | 87 | if isinstance(x, list): |
104 | 88 | if not x: |
@@ -154,8 +138,6 @@ def text2flow( |
154 | 138 | output_ids = _strip_prompt_prefix(raw_output_ids, prefix_ids) |
155 | 139 | output_ids = _strip_prompt_prefix(output_ids, prompt_speech_ids) |
156 | 140 | additional_info = dict(multi_modal_data) |
157 | | - if _has_speech_stop_token(raw_output_ids): |
158 | | - _set_non_stream_prompt_trim(additional_info, len(prompt_speech_ids)) |
159 | 141 | additional_info.setdefault("ids", {})["prompt"] = prefix_ids |
160 | 142 | engine_inputs.append(OmniTokensPrompt(prompt_token_ids=output_ids, additional_information=additional_info)) |
161 | 143 | return engine_inputs |
@@ -389,8 +371,6 @@ def text2flow_token_only( |
389 | 371 | prompt_speech_ids = _prompt_speech_token_ids(multi_modal_data) |
390 | 372 | output_ids = _strip_prompt_prefix(output_ids, prompt_speech_ids) |
391 | 373 | additional_info: dict[str, Any] = dict(multi_modal_data) |
392 | | - if _has_speech_stop_token(raw_output_ids): |
393 | | - _set_non_stream_prompt_trim(additional_info, len(prompt_speech_ids)) |
394 | 374 | additional_info.setdefault("ids", {})["prompt"] = prefix_ids |
395 | 375 | engine_inputs.append( |
396 | 376 | OmniTokensPrompt( |
|
0 commit comments