Skip to content

Commit a894aef

Browse files
committed
Fixed error
1 parent 1d9c8e4 commit a894aef

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

tools/qwen3/qwen3_chat/generate.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ def init_history() -> list[dict[str, Any]]:
1010
return [{"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}]
1111

1212

13+
def _extract_text_ids(output: Any) -> torch.Tensor:
14+
if isinstance(output, tuple):
15+
return output[0]
16+
return output
17+
18+
1319
def generate_response(
1420
model: Any,
1521
processor: Any,
@@ -34,23 +40,24 @@ def generate_response(
3440
clean_up_tokenization_spaces=False,
3541
),
3642
"thinker_do_sample": False,
37-
"thinker_max_new_tokens": 10,
3843
}
3944
if speaker:
4045
gen_kwargs["speaker"] = speaker
4146

47+
input_len = inputs["input_ids"].shape[-1]
48+
4249
if enable_audio:
4350
gen_kwargs["return_audio"] = True
4451
gen_kwargs["talker_do_sample"] = True
4552
text_ids, audio = model.generate(**inputs, **gen_kwargs)
4653
else:
4754
gen_kwargs["return_audio"] = False
48-
text_ids = model.generate(**inputs, **gen_kwargs)
55+
output = model.generate(**inputs, **gen_kwargs)
56+
text_ids = _extract_text_ids(output)
4957
audio = None
5058

59+
generated_ids = text_ids[:, input_len:]
5160
text = processor.batch_decode(
52-
text_ids,
53-
skip_special_tokens=True,
54-
clean_up_tokenization_spaces=False,
61+
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False,
5562
)[0]
5663
return text, audio

0 commit comments

Comments
 (0)