@@ -10,6 +10,12 @@ def init_history() -> list[dict[str, Any]]:
1010 return [{"role" : "system" , "content" : [{"type" : "text" , "text" : SYSTEM_PROMPT }]}]
1111
1212
def _extract_text_ids(output: Any) -> torch.Tensor:
    """Return the text token ids from the result of a ``generate`` call.

    Args:
        output: Whatever ``generate`` returned. Handled shapes:
            * a tuple, whose first element is taken as the text result;
            * an object exposing a ``sequences`` attribute, which is
              unwrapped to that attribute;
            * anything else, which is returned unchanged.

    Returns:
        The extracted token ids. Presumably a 2-D ``(batch, seq)`` tensor,
        since the caller slices it with ``[:, input_len:]`` -- this helper
        itself does not check the type or shape.
    """
    # A tuple result carries the text part in slot 0.
    first = output[0] if isinstance(output, tuple) else output
    # NOTE(review): structured generation outputs (e.g. when the model is
    # asked to return a dict-style result) keep the ids under `.sequences`;
    # plain tensors have no such attribute and pass through untouched.
    return getattr(first, "sequences", first)
17+
18+
1319def generate_response (
1420 model : Any ,
1521 processor : Any ,
@@ -34,23 +40,24 @@ def generate_response(
3440 clean_up_tokenization_spaces = False ,
3541 ),
3642 "thinker_do_sample" : False ,
37- "thinker_max_new_tokens" : 10 ,
3843 }
3944 if speaker :
4045 gen_kwargs ["speaker" ] = speaker
4146
47+ input_len = inputs ["input_ids" ].shape [- 1 ]
48+
4249 if enable_audio :
4350 gen_kwargs ["return_audio" ] = True
4451 gen_kwargs ["talker_do_sample" ] = True
4552 text_ids , audio = model .generate (** inputs , ** gen_kwargs )
4653 else :
4754 gen_kwargs ["return_audio" ] = False
48- text_ids = model .generate (** inputs , ** gen_kwargs )
55+ output = model .generate (** inputs , ** gen_kwargs )
56+ text_ids = _extract_text_ids (output )
4957 audio = None
5058
59+ generated_ids = text_ids [:, input_len :]
5160 text = processor .batch_decode (
52- text_ids ,
53- skip_special_tokens = True ,
54- clean_up_tokenization_spaces = False ,
61+ generated_ids , skip_special_tokens = True , clean_up_tokenization_spaces = False ,
5562 )[0 ]
5663 return text , audio
0 commit comments