diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..749ccda --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class diff --git a/bin/inference.py b/bin/inference.py index f0d9046..9a2673d 100644 --- a/bin/inference.py +++ b/bin/inference.py @@ -1,14 +1,7 @@ from __future__ import print_function import argparse -import os -import json -import queue import torch -import yaml -import threading -import struct -import time import torchaudio import datetime import builtins @@ -16,15 +9,11 @@ import soundfile as sf import numpy as np -import torch.nn.functional as F import torchaudio.compliance.kaldi as k -from torch.utils.data import DataLoader - +from models.utils import print_outputs from models.pipeline import inferencePipeline from models.decoder.llm2tts import llm2TTS -from web.parms import GlobalParams -from web.pool import TTSObjectPool def get_args(): parser = argparse.ArgumentParser(description='Freeze-Omni') @@ -40,6 +29,11 @@ def get_args(): print(args) return args +def custom_print(*args, **kwargs): + current_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] + original_print(f'[{current_time}]', *args, **kwargs) + + class audioEncoderProcessor: def __init__(self, chunk_size = 16): self.chunk_size = 16 @@ -70,8 +64,13 @@ def chunk_data_shift(self, xs): def process(self, audio: torch.Tensor): + """ + # 1. Converts the input audio tensor to the appropriate format. + # 2. Computes the filter bank features (fbank) for the audio. + # 3. Updates the input chunk and history based on the new audio segment. + """ with torch.no_grad(): - sample_data = torch.tensor(audio).reshape(1, -1, 1)[:, :, :1] * 32768 + sample_data = audio.clone().reshape(1, -1, 1)[:, :, :1] * 32768 self.fbank_shift(sample_data) # use kaldi api to compute fbank xs = k.fbank(waveform = self.input_sample.squeeze(-1), dither=0, @@ -80,6 +79,9 @@ def process(self, return self.input_chunk.clone() def decoder(cur_hidden_state, pipeline, cur_text, tts, codec_chunk_size, codec_padding_size, decoder_topk, wav): + """ + Decodes the current hidden state and text to generate audio segments using speech decoder. + """ hidden_state_output = torch.cat(cur_hidden_state).squeeze(1) cur_text_procced = pipeline.post_process(cur_text) print("Synthesis: ", [cur_text_procced]) @@ -91,7 +93,7 @@ def decoder(cur_hidden_state, pipeline, cur_text, tts, codec_chunk_size, codec_p codec_chunk_size, codec_padding_size): wav.append(seg) -def inference(pipeline, audio_processor, tts, configs): +def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, tts:llm2TTS, configs): """ Perform inference for a speech dialogue system. @@ -104,11 +106,18 @@ def inference(pipeline, audio_processor, tts, configs): Returns: - None """ - wav, fs = sf.read(configs.input_wav) - wav = torch.tensor(wav) + wav, fs = torchaudio.load(configs.input_wav) if fs != 16000: - wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav.float()) + wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav) fs = 16000 + wav = wav.reshape(-1) + + #wav, fs = sf.read(configs.input_wav) + #wav = torch.tensor(wav) + #if fs != 16000: + # wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav.float()) + # fs = 16000 + print("--->",wav.shape) codec_chunk_size = 40 codec_padding_size = 10 @@ -118,6 +127,7 @@ def inference(pipeline, audio_processor, tts, configs): # set system role, stat will be set to 'sl' stat = 'pre' outputs = pipeline.speech_dialogue(None, stat=stat, role="You are a helpful assistant.") + print(f"pre-> outputs:[{print_outputs(outputs)}]") chunk_size = audio_processor.get_chunk_size() # Satge1: start listen @@ -125,18 +135,29 @@ def inference(pipeline, audio_processor, tts, configs): wav_input = torch.zeros(math.ceil(wav.shape[0] / chunk_size) * chunk_size) wav_input[:wav.shape[0]] = wav for i in range(0, wav_input.shape[0], chunk_size): + print("--->",wav_input.shape, wav.shape,wav_input[i:i+chunk_size].shape) + print(f"cl in-> outputs:{print_outputs(outputs)}") + if outputs['stat'] =="sl": + print(f"stat_chunk data:{wav_input[i:i+chunk_size]}") fbank = audio_processor.process(wav_input[i:i+chunk_size]) + if outputs['stat'] =="sl": + print(f"fbank:{fbank}") outputs = pipeline.speech_dialogue(fbank, **outputs) + print(f"cl out-> outputs:{print_outputs(outputs)}") outputs['stat'] = 'cl' audio_processor.reset() + print("listen",outputs.keys()) + print(f"listen-> outputs:[{print_outputs(outputs)}]") outputs['adapter_cache'] = None outputs['encoder_cache'] = None outputs['pe_index'] = 0 outputs['stat'] = 'ss' + print(f"speak get-> outputs:[{print_outputs(outputs)}]") # Stage3: start speak outputs = pipeline.speech_dialogue(None, **outputs) + print(f"ss-> outputs:[{print_outputs(outputs)}]") cur_hidden_state = [] cur_hidden_state.append(outputs['hidden_state']) @@ -155,6 +176,7 @@ def inference(pipeline, audio_processor, tts, configs): del outputs['text'] del outputs['hidden_state'] outputs = pipeline.speech_dialogue(None, **outputs) + print(f"sc-> outputs:[{print_outputs(outputs)}]") if outputs['stat'] == 'cs': cur_hidden_state.append(outputs['hidden_state']) whole_text += outputs['text'][len(last_text):] @@ -168,12 +190,14 @@ def inference(pipeline, audio_processor, tts, configs): decoder(cur_hidden_state, pipeline, cur_text, tts, codec_chunk_size, codec_padding_size, decoder_topk, wav) cur_hidden_state = [] + print(f"cur_text:{cur_text}") cur_text = "" if outputs['stat'] == 'sl': break - # print(outputs['text']) + #print(outputs['text']) last_text = outputs['text'] if len(cur_hidden_state) != 0: + print(f"cur_text:{cur_text}") decoder(cur_hidden_state, pipeline, cur_text, tts, codec_chunk_size, codec_padding_size, decoder_topk, wav) @@ -183,8 +207,15 @@ def inference(pipeline, audio_processor, tts, configs): print(whole_text) if __name__ == '__main__': + # change print function to add time stamp + original_print = builtins.print + builtins.print = custom_print + configs = get_args() + # encoder and audio llm pipeline = inferencePipeline(configs) + # decoder tts = llm2TTS(configs.model_path) + # stream chunk to encoder audio_processor = audioEncoderProcessor() inference(pipeline, audio_processor, tts, configs) diff --git a/bin/inference_stream.py b/bin/inference_stream.py new file mode 100644 index 0000000..1c76c52 --- /dev/null +++ b/bin/inference_stream.py @@ -0,0 +1,563 @@ +from __future__ import print_function + +import builtins +import datetime +import time +import math +import argparse +import threading +from copy import deepcopy +from dataclasses import dataclass +from typing import Generator + +import soundfile as sf +import numpy as np +import torch +import torchaudio +import torchaudio.compliance.kaldi as k + +from bin.inference import audioEncoderProcessor +from models.utils import print_outputs +from web.queue import PCMQueue, ThreadSafeQueue +from models.pipeline import inferencePipeline +from models.decoder.llm2tts import llm2TTS + + +def custom_print(*args, **kwargs): + current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + original_print(f"[{current_time}]", *args, **kwargs) + + +def get_args(): + parser = argparse.ArgumentParser(description="Freeze-Omni-Inference-stream") + parser.add_argument("--model_path", required=True, help="model_path to load") + parser.add_argument("--llm_path", required=True, help="llm_path to load") + parser.add_argument("--top_k", type=int, default=5) + parser.add_argument("--top_p", type=float, default=0.8) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--input_wav", required=True, help="input wav") + parser.add_argument("--output_wav", required=True, help="output wav") + + args = parser.parse_args() + print(args) + return args + + +class GlobalVars: + """ + multi turn conversation interal speech dialogue outputs for listen and speak + """ + + speech_dialogue_outputs = {} + + @staticmethod + def deepcopy_outputs(): # cp for write + return deepcopy(GlobalVars.speech_dialogue_outputs) + + +@dataclass +class PCMStatChunk: + status: str # sl(start listen), cl(continue listen), el(end listen) + data: torch.Tensor # chunk tensor + + def __post_init__(self): + if self.status not in ["sl", "cl", "el"]: + raise ValueError("status must be one of: 'sl','cl','el'") + + def __str__(self): + return ( + f"status:{self.status} data shape:{self.data.shape if self.data is not None else None}" + ) + + +@dataclass +class Listen2SpeakFrame: + """ + outputs have stat, so need pass it from listen to speak + """ + + outputs: dict + + def __str__(self): + return f"outputs:{print_outputs(self.outputs)}" + + +@dataclass +class GenTTSFrame: + text: str + data: np.ndarray # a numpy array of dtype np.float32 + + def __str__(self): + return f"text:{self.text} numpy data:{self.data.shape if self.data is not None else None}" + + +class PCMListener: + def __init__( + self, + outputs_queue: ThreadSafeQueue, # outputs -> out queue + pipeline: inferencePipeline, + cache_history_size: int = 10, + system_prompt: str = "You are a helpful assistant.", + ) -> None: + self.pcm_stat_chunk_queue = ThreadSafeQueue() # pcm data -> in queue + + self.outputs_queue = outputs_queue + + # stream chunk to encoder + self.audio_processor = audioEncoderProcessor() + # encoder and audio llm + self.pipeline = pipeline + + # pre status system prompt, outputs stat: pre -> sl + self.outputs = self.pipeline.speech_dialogue(None, stat="pre", role=system_prompt) + # print(f"pre-> outputs:[{print_outputs(self.outputs)}]") + GlobalVars.speech_dialogue_outputs = self.outputs + + # chunck feat history cache also use ring buffer to do :) + self.history = torch.zeros( + [ + cache_history_size, + self.audio_processor.chunk_size + self.audio_processor.chunk_overlap, + self.audio_processor.feat_dim, + ] + ) + + # start listen thread + self.listen_thread = threading.Thread(target=self.listen, args=()) + self.stop_listen = False + self.listen_thread.start() + + def stop(self): + self.stop_listen = True + self.listen_thread.join(timeout=3) + + def send(self, pcm_items: torch.Tensor | None, status: str): + if pcm_items is None: + item = PCMStatChunk(status=status, data=None) + self.pcm_stat_chunk_queue.put(item) + else: + item = PCMStatChunk(status=status, data=pcm_items) + self.pcm_stat_chunk_queue.put(item) + + def history_buffering_strategy(self, input_chunk: torch.Tensor) -> torch.Tensor: + # cache fbank feature (input_chunk) + # << 1 + self.history[:-1] = self.history[1:].clone() + # last history = input chunk + self.history[-1:] = input_chunk + # start listen to # copy last 6 chunks + feature_last_chunk = self.history[-6:].unsqueeze(1) + return feature_last_chunk + + def llm_prefill(self, status: str, feature: torch.Tensor, outputs: dict, is_first_pack=False): + """ + Prefills the LLM of speech dialogue system using speech. + + Parameters: + - status: the current state of the audio input. + - feature: audio feature tensor. + - outputs (dict): A dictionary containing the current state of the dialogue system. + - is_first_pack (bool, optional): Indicates if the current input packet is the first one in a new conversation + """ + + if status == "sl": + # Satge1: start listen + # stat will be auto set to 'cl' after Stage1 + outputs = self.pipeline.speech_dialogue(feature, **outputs) + return outputs + if status == "el": + print("status end listen. start to speak") + return outputs + + if status == "cl": + if outputs["stat"] == "cl" or outputs["stat"] == "sl": + # Stage2: continue listen + # stat will be auto set to 'ss' when endpoint is detected + # print("output stat continue listen") + outputs = self.pipeline.speech_dialogue(feature, **outputs) + if is_first_pack: + outputs["stat"] = "cl" + if outputs["stat"] == "el": + print("output stat end listen. Detect invalid break") + if outputs["stat"] == "ss": + # print(f"start speak. start to speak") + pass + return outputs + + def listen(self): + """ + chunk status from VAD + """ + print("Start listening") + while True: + if self.stop_listen: + print("Stop listening") + break + stat_chunk: PCMStatChunk = self.pcm_stat_chunk_queue.get() + if stat_chunk is None: + time.sleep(0.01) + continue + # print(f"Received PCM stat chunk: {stat_chunk}") + + # if self.outputs['stat'] =="sl": + # print(f"stat_chunk data:{stat_chunk.data}") + + if stat_chunk.status == "sl": + fbank_feature = self.audio_processor.process(stat_chunk.data) + self.outputs = ( + GlobalVars.deepcopy_outputs() + ) # for next turn conversation outputs from speak where in gloab var + # print(f"sl-> outputs:{print_outputs(self.outputs)}") + self.outputs["adapter_cache"] = None + self.outputs["encoder_cache"] = None + self.outputs["pe_index"] = 0 + self.outputs["stat"] = "sl" + self.outputs["last_id"] = None + if "text" in self.outputs: + del self.outputs["text"] + if "hidden_state" in self.outputs: + del self.outputs["hidden_state"] + + self.outputs = self.llm_prefill("sl", fbank_feature, self.outputs) + + elif stat_chunk.status == "cl": + # print(f"cl in-> outputs:{print_outputs(self.outputs)}") + fbank_feature = self.audio_processor.process(stat_chunk.data) + # if self.outputs['stat'] =="sl": + # print(f"fbank:{fbank_feature}") + self.outputs = self.llm_prefill(stat_chunk.status, fbank_feature, self.outputs) + # print(f"cl out-> outputs:{print_outputs(self.outputs)}") + if self.outputs["stat"] == "ss": + pass + else: + self.outputs["stat"] = "cl" + elif stat_chunk.status == "el": + self.audio_processor.reset() + + # print(f"el-> outputs:[{print_outputs(self.outputs)}]") + self.outputs["adapter_cache"] = None + self.outputs["encoder_cache"] = None + self.outputs["pe_index"] = 0 + self.outputs["last_id"] = None + self.outputs["stat"] = "ss" + + if self.outputs["stat"] == "ss": + frame = Listen2SpeakFrame(self.outputs) + # print(f"start to speak, send frame:[{frame}]") + self.outputs_queue.put(frame) + + +@dataclass +class TTSSpeakerArgs: + # https://huggingface.co/VITA-MLLM/Freeze-Omni/blob/main/checkpoints/server.json + # decoder(LLM2TTSCodecAR) + # NAR llama transformer decoder pre_nn_forward -> NAR llama transformer decoder kv_cache_prefix_forward -> AR llama transformer decoder transformer_infer + + # llama transformer decoder + decoder_top_k: int = 2 + decoder_penalty_window_size: int = -1 # <0 no penalty window + decoder_penalty: float = 1.1 + + # codec decoder + decoder_first_chunk_size: int = 20 + decoder_chunk_size: int = 40 + decoder_chunk_overlap_size: int = 10 + + # find_min_sum_index + decoder_N: int = 2401 + decoder_seg_threshold_first_pack: float = 0.1 + decoder_seg_threshold: float = 0.015 + + +class TTSSpeaker: + def __init__( + self, + args: TTSSpeakerArgs, + outputs_in_queue: ThreadSafeQueue, # out queue -> speak + pipeline: inferencePipeline, + tts: llm2TTS, + ) -> None: + self.args = args + self.outputs_in_queue = outputs_in_queue + self.pipeline = pipeline + self.tts = tts + + self.reset() + + # speak thread + self.speak_thread = threading.Thread(target=self.speak, args=()) + self.speak_thread.start() + + def reset(self): + self.stop_speak = False + self.is_generate = False + self.whole_text = "" + + self.tts_over = False + self.tts_over_time = 0 + self.tts_data = ThreadSafeQueue() + + self.stop_tts = False + + def print(self): + print("stop_speak:", self.stop_speak) + print("is_generate:", self.is_generate) + print("whole_text:", self.whole_text) + print("tts_over:", self.tts_over) + print("tts_over_time:", self.tts_over_time) + + @property + def gen_text(self): + """Get the whole text.""" + return self.whole_text + + def stop(self): + """Stop the speak thread.""" + self.stop_speak = True + self.speak_thread.join() + + def decoder( + self, cur_hidden_state: list[torch.Tensor], cur_text: str, generate_num: int + ) -> int: + """ + Decodes the current hidden state and text to generate audio segments using speech decoder. + + Parameters: + - cur_hidden_state (list of torch.Tensor): The current hidden state of the language model. + - cur_text (str): The current text to be synthesized. + - generate_num (int): The number of audio segments generated + + Returns: + - int: The updated number of audio segments generated. + """ + hidden_state_output = torch.cat(cur_hidden_state).squeeze(1) + cur_text_procced = self.pipeline.post_process(cur_text) + print("Synthesis: ", [cur_text_procced]) + embeddings = self.pipeline.model.llm_decoder.model.embed_tokens( + torch.tensor(self.pipeline.model.tokenizer.encode(cur_text_procced)).cuda() + ) + codec_chunk_size = self.args.decoder_first_chunk_size + codec_padding_size = self.args.decoder_chunk_overlap_size + seg_threshold = self.args.decoder_seg_threshold_first_pack + if generate_num != 0: + codec_chunk_size = self.args.decoder_chunk_size + seg_threshold = self.args.decoder_seg_threshold + for seg in self.tts.run( + embeddings.reshape(-1, 896).unsqueeze(0), + self.args.decoder_top_k, + hidden_state_output.reshape(-1, 896).unsqueeze(0), + codec_chunk_size=codec_chunk_size, + codec_padding_size=codec_padding_size, + penalty_window_size=self.args.decoder_penalty_window_size, + penalty=self.args.decoder_penalty, + N=self.args.decoder_N, + seg_threshold=seg_threshold, + ): + if generate_num == 0: + try: + split_idx = torch.nonzero(seg.abs() > 0.03, as_tuple=True)[-1][0] + seg = seg[:, :, split_idx:] + except Exception: + print("Do not need to split") + pass + generate_num += 1 + if self.tts_over: + self.tts_data.clear() + self.whole_text = "" + self.tts_data.put(GenTTSFrame(text="", data=None)) + break + frame = GenTTSFrame(text=cur_text, data=seg.squeeze().float().cpu().numpy()) + self.tts_data.put(frame) + return generate_num + + def get_tts_data(self) -> Generator[GenTTSFrame, None, None]: + """ + get tts bytes data + """ + while True: + if self.stop_speak: + print("Stop speak so break get tts data") + break + output_data = self.tts_data.get() + if output_data is not None: + # print("Get TTS data") + # yield output_data.astype(np.int16).tobytes() + yield output_data + else: + yield None + + def speak(self): + """ + Generates speech dialogue output based on the current state + """ + while True: + if self.stop_speak: + print("Stop speak") + break + frame: Listen2SpeakFrame = self.outputs_in_queue.get() + if frame is None: + time.sleep(0.01) + continue + + outputs = frame.outputs + # print(f"speak get-> outputs:[{print_outputs(outputs)}]") + # Stage3: start speak + self.is_generate = True + outputs = self.pipeline.speech_dialogue(None, **outputs) + # print(f"ss-> outputs:[{print_outputs(outputs)}]") + + cur_hidden_state = [] + cur_hidden_state.append(outputs["hidden_state"]) + + # Stage4: contiune speak until stat is set to 'sl' + # use 'stop' to interrupt generation, stat need to be manually set as 'sl' + stop = False + cur_text = "" + last_text = "" + generate_num = 0 + while True: + if self.stop_speak: + break + if len(outputs["past_tokens"]) > 500: + stop = True + if stop: + break + del outputs["text"] + del outputs["hidden_state"] + outputs = self.pipeline.speech_dialogue(None, **outputs) + # print(f"sc-> outputs:[{print_outputs(outputs)}]") + if outputs["stat"] == "cs": + cur_hidden_state.append(outputs["hidden_state"]) + if "�" in outputs["text"][len(last_text) :]: + continue + self.whole_text += outputs["text"][len(last_text) :] + cur_text += outputs["text"][len(last_text) :] + # print(self.whole_text]) + if generate_num == 0 or (len(cur_hidden_state) >= 20): + suffix_list = [",", ",", "。", ":", "?", "!", ".", ":", "?", "!", "\n"] + else: + suffix_list = ["。", ":", "?", "!", ".", "?", "!", "\n"] + if outputs["text"][len(last_text) :].endswith(tuple(suffix_list)) and ( + len(cur_hidden_state) >= 4 + ): + if ( + outputs["text"][len(last_text) :].endswith(".") + and last_text[-1].isdigit() + ): + pass + else: + if not self.tts_over: + if len(cur_hidden_state) > 0: + generate_num = self.decoder( + cur_hidden_state, cur_text, generate_num + ) + cur_text = "" + cur_hidden_state = [] + last_text = outputs["text"] + else: + break + if not self.tts_over: + if len(cur_hidden_state) != 0: + generate_num = self.decoder(cur_hidden_state, cur_text, generate_num) + cur_text = "" + self.is_generate = False + outputs["stat"] = "sl" + outputs["last_id"] = None + GlobalVars.speech_dialogue_outputs = deepcopy(outputs) + print(self.whole_text) + self.tts_data.put(GenTTSFrame(text="", data=None)) + + def interrupt(self): + self.stop_speak = True + self.tts_over = True + while True: + time.sleep(0.01) + if self.is_generate is False: + self.stop_speak = False + while True: + time.sleep(0.01) + if self.tts_data.is_empty(): + self.whole_text = "" + self.tts_over = False + self.tts_over_time += 1 + break + break + + +def inference_stream(listener: PCMListener, speaker: TTSSpeaker, configs): + """ + Perform inference for a speech dialogue system. + - 流式语音输入通过语音编码器形成chunk-wise特征,然后通过适配器连接到LLM。 + - LLM生成隐藏状态和文本标记,在块分割后分别以块的形式发送到NAR前缀语音解码器和NAR语音解码器。 + - 最后,AR语音解码器将生成的令牌发送到语音令牌FIFO中,流式编解码器根据固定的语音令牌块大小从FIFO生成流式语音输出。 + + Parameters: + - listener: listen pcm data(chunk) to asr with pipeline(encoder and adpter), + - detail status: sl, cl, el, ss + - speaker: tts speaker with pipeline(audio llm) and decoder (NAR decoder AR decoder and codec decoder) + - detail status: ss, cs + - configs: Input args. (argparse) + + Returns: + - None + """ + wav, fs = sf.read(configs.input_wav) + wav = torch.tensor(wav) + if fs != 16000: + wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav.float()) + fs = 16000 + + # like io_uring + chunk_size = listener.audio_processor.get_chunk_size() + wav_input = torch.zeros(math.ceil(wav.shape[0] / chunk_size) * chunk_size) + wav_input[: wav.shape[0]] = wav + for i in range(0, wav_input.shape[0], chunk_size): + print("--->", wav_input.shape, wav.shape, wav_input[i : i + chunk_size].shape) + status = "cl" + if i == 0: + status = "sl" + # send pcm data with status + listener.send(wav_input[i : i + chunk_size], status) + listener.send(None, "el") + + wavs = [] + # get tts speak data + for item in speaker.get_tts_data(): + if item: + print(item) + if item.data is None: + break + else: + wavs.append(item.data) + else: + time.sleep(0.01) # yield thread + + sf.write(configs.output_wav, np.concatenate(wavs, -1), 24000) + print(f"write to {configs.output_wav}") + + listener.stop() + speaker.stop() + + +if __name__ == "__main__": + # change print function to add time stamp + original_print = builtins.print + builtins.print = custom_print + + configs = get_args() + + # encoder and audio llm + pipeline = inferencePipeline(configs) + # decoder + tts = llm2TTS(configs.model_path) + # listen -> gen_queue -> speak + gen_queue = ThreadSafeQueue() + # listen + listener = PCMListener(gen_queue, pipeline) + # speak + speaker = TTSSpeaker(TTSSpeakerArgs(), gen_queue, pipeline, tts) + + inference_stream(listener, speaker, configs) + + +# format: ruff format bin/inference_stream.py diff --git a/bin/server.py b/bin/server.py index 533b5fd..8d61382 100644 --- a/bin/server.py +++ b/bin/server.py @@ -36,6 +36,8 @@ def get_args(): parser.add_argument('--max_users', type=int, default=5) parser.add_argument('--llm_exec_nums', type=int, default=1) parser.add_argument('--timeout', type=int, default=600) + parser.add_argument("--ngrok", action='store_true', help="use ngrok proxy") + parser.add_argument("--ssl", action='store_true', help="use ssl") args = parser.parse_args() print(args) return args @@ -44,6 +46,13 @@ def custom_print(*args, **kwargs): current_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] original_print(f'[{current_time}]', *args, **kwargs) +""" +“模型作为服务器”策略来实现语音到语音对话系统。 +首先,我们同时启动多个模型,并将它们视为服务器。 +然后,当用户的VAD被触发时,语音将以块的形式发送到服务器,服务器将负责调度哪个空闲模型应该响应当前的块。 +由于我们在推理过程中将语音编码器和LLM的所有kv-cache和CNN缓存分开,因此服务器只需要保存每个用户的推理缓存。 +这样,服务器中的任何模型都可以响应任何用户的任何块,并且不需要指定哪个模型用作监视器或生成器。 +""" # init parms configs = get_args() # read server related config @@ -216,8 +225,9 @@ def llm_prefill(data, outputs, sid, is_first_pack=False): # Satge1: start listen # stat will be auto set to 'cl' after Stage1 outputs = connected_users[sid][1].pipeline_obj.pipeline_proc.speech_dialogue( - torch.tensor(data['feature']), + data['feature'], **outputs) + print(f"sl -> speech_dialogue outputs stat: {outputs['stat']}") if data['status'] == 'el': connected_users[sid][1].wakeup_and_vad.in_dialog = False @@ -228,8 +238,9 @@ def llm_prefill(data, outputs, sid, is_first_pack=False): # Stage2: continue listen # stat will be auto set to 'ss' when endpoint is detected outputs = connected_users[sid][1].pipeline_obj.pipeline_proc.speech_dialogue( - torch.tensor(data['feature']), + data['feature'], **outputs) + print(f"cl -> speech_dialogue outputs stat: {outputs['stat']}") if is_first_pack: outputs['stat'] = 'cl' if outputs['stat'] == 'el': @@ -251,7 +262,7 @@ def send_pcm(sid): - sid (str): The session ID of the user. """ - chunk_szie = connected_users[sid][1].wakeup_and_vad.get_chunk_size() + chunk_size = connected_users[sid][1].wakeup_and_vad.get_chunk_size() print("Sid: ", sid, " Start listening") while True: @@ -261,11 +272,12 @@ def send_pcm(sid): connected_users[sid][1].stop_tts = True break time.sleep(0.01) - e = connected_users[sid][1].pcm_fifo_queue.get(chunk_szie) + e = connected_users[sid][1].pcm_fifo_queue.get(chunk_size) if e is None: continue print("Sid: ", sid, " Received PCM data: ", len(e)) res = connected_users[sid][1].wakeup_and_vad.predict(np.float32(e)) + print(f"wakeup_and_vad.predict -> res status: {res['status']}") if res['status'] == 'sl': print("Sid: ", sid, " Vad start") @@ -410,10 +422,30 @@ def handle_audio(data): else: disconnect() +def ngrok_proxy(port): + """ + run `ngrok config add-authtoken $NGROK_TOKEN` + """ + from pyngrok import ngrok + import nest_asyncio + + ngrok_tunnel = ngrok.connect(port) + print('Public URL:', ngrok_tunnel.public_url) + nest_asyncio.apply() + + if __name__ == "__main__": print("Start Freeze-Omni sever") - cert_file = "web/resources/cert.pem" - key_file = "web/resources/key.pem" - if not os.path.exists(cert_file) or not os.path.exists(key_file): - generate_self_signed_cert(cert_file, key_file) - socketio.run(app, host=configs.ip, port=configs.port, ssl_context=(cert_file, key_file)) + if configs.ssl: + cert_file = "web/resources/cert.pem" + key_file = "web/resources/key.pem" + if not os.path.exists(cert_file) or not os.path.exists(key_file): + generate_self_signed_cert(cert_file, key_file) + + if configs.ngrok and not configs.ssl: + ngrok_proxy(configs.port) + + if configs.ssl: + socketio.run(app, host=configs.ip, port=configs.port, ssl_context=(cert_file, key_file)) + else: + socketio.run(app, host=configs.ip, port=configs.port) diff --git a/models/audioLLM.py b/models/audioLLM.py index 40a4f25..a226315 100644 --- a/models/audioLLM.py +++ b/models/audioLLM.py @@ -18,6 +18,13 @@ IGNORE_ID = -1 class AudioLLM(torch.nn.Module): + """ + Modeling of speech input + 为了使 Freeze-Omni 能够支持语音输入并实现对输入语音的快速、低延迟响应,它利用块式流式语音编码器将输入语音特征转换为高维表示。 + 然后,适配器模块将高维表示映射到主干LLM的嵌入空间中。 + 这里的语音编码器模块由几个下采样卷积层和几个 Transformer 块组成,而适配器仅包含几个下采样卷积层。 + 使用下采样的原因是为了降低语音特征的帧率,提高预填充阶段LLM的速度,降低延迟。 + """ def __init__( self, encoder: torch.nn.Module, @@ -208,6 +215,8 @@ def __init__( "hyps": 7, "/hyps": 8, } + num_params = sum(p.numel() for p in self.parameters()) + print('the number of audio llm params: {}M'.format(num_params/1024/1024)) def set_system_role( self, @@ -250,6 +259,9 @@ def recognize( speech_lengths: torch.Tensor, extra_inputs: Optional[dict] = None, ): + """ + speech encoder(down sample CNN+Transformer) -> adapter(down sample CNN) -> text llm(decoder_only Transformer) + """ assert extra_inputs.get('past_key_values', None) is not None, "must set system role first!!!" buffer = extra_inputs.get('encoder_cache', None) @@ -391,6 +403,8 @@ def _generate_one_step( top_p: float = 1.0, top_k: int = 0, temperature: float = 1.0, + el_prob: float = 0.5, + ss_prob: float = 0.5, ): """ Generates the model's next output based on the current input and state. @@ -401,6 +415,8 @@ def _generate_one_step( - top_p: The threshold for controlling top-p sampling. - top_k: The threshold for controlling top-k sampling. - temperature: Controls the randomness of sampling. + - el_prob: end listen stat logit prob + - ss_prob: start speak stat logit prob Returns: - last_id: The index of the last generated token. @@ -417,9 +433,9 @@ def _generate_one_step( state_1 = state_prob[1] state_2 = state_prob[2] print("State 1 prob: {:.4f}, State 2 prob: {:.4f}".format(state_1.item(), state_2.item())) - if state_2 > 0.5: + if state_2 > el_prob: return None, outputs['past_key_values'], 'el', None - if state_1 > 0.5: + if state_1 > ss_prob: return None, outputs['past_key_values'], 'ss', None return None, outputs['past_key_values'], 'cl', None diff --git a/models/decoder/llm2tts.py b/models/decoder/llm2tts.py index 2f008eb..c2a19e4 100644 --- a/models/decoder/llm2tts.py +++ b/models/decoder/llm2tts.py @@ -15,10 +15,19 @@ from models.decoder.ticodec.vqvae_tester import VqvaeTester class llm2TTS(): + """ + Modeling of speech output + 受 VALL-E [5] 的启发,Freeze-Omni 使用基于令牌的语音解码器,其中包含 NAR 预填充和 AR 生成阶段来实现语音输出功能。 + 语音解码器主要由NAR解码器、AR解码器和编解码器模型的解码器组成。 NAR 解码器和 AR 解码器都是基于transformer块构建的。 + NAR解码器用于根据LLM的输出对语义特征进行建模,然后AR解码器基于NAR解码器的输出生成语音token。 + 最后,编解码器模型的解码器将语音token转换为语音流。 + """ def __init__(self, model_path): self.model = self.get_model(model_path).cuda().to( torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 ) + num_params = sum(p.numel() for p in self.model.parameters()) + print('the number of LLM2TTSCodecAR(NAR decoder and AR decoder(llama transformer blocks)) params: {}M'.format(num_params/1024/1024)) self.infer = self.model.infer self.codec_model = VqvaeTester(config_path=model_path + "/codec/model.json", @@ -27,6 +36,8 @@ def __init__(self, model_path): self.codec_model = self.codec_model.cuda() self.codec_model.vqvae.generator.remove_weight_norm() self.codec_model.vqvae.encoder.remove_weight_norm() + num_params = sum(p.numel() for p in self.codec_model.parameters()) + print('after remove_weight_norm, the number of llm2TTS(vq-vae codec decoder model) params: {}M'.format(num_params/1024/1024)) self.codec_model.eval() def get_model_conf(self, model_path): @@ -112,7 +123,7 @@ def find_min_sum_index(self, buffer, syn, N, threshold): return buffer, syn def run(self, hidden, top_k, prefix, codec_chunk_size=40, codec_padding_size=10, - penalty_window_size=-1, penalty=1.1, N=2401, seg_threshold=0.01): + penalty_window_size=-1, penalty=1.1, N=2401, seg_threshold=0.01, max_tokens=1000): """ Run the speech decoder process. @@ -122,7 +133,7 @@ def run(self, hidden, top_k, prefix, codec_chunk_size=40, codec_padding_size=10, - prefix (str, optional): The hidden state from the language model. - codec_chunk_size (int, default=40): The size of each chunk to process in the codec model. - codec_padding_size (int, default=10): The amount of padding to add on each side of the codec chunk. - - penalty_window_size (int, default=20): The window size for applying penalties during decoding. + - penalty_window_size (int, default=-1): The window size for applying penalties during decoding. - penalty (float, default=1.1): The penalty factor. Yields: @@ -138,7 +149,8 @@ def run(self, hidden, top_k, prefix, codec_chunk_size=40, codec_padding_size=10, dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32): print("Starting TTS...") token = torch.full((1, 0), self.model.vocab_size, dtype=torch.long, device=hidden.device) - for next_token_id in self.infer(hidden, top_k, prefix, penalty_window_size, penalty): + for next_token_id in self.infer( + hidden, top_k, prefix, penalty_window_size, penalty, max_tokens=max_tokens): token = torch.cat([token, next_token_id], dim=-1) if token.size(1) == left_padding + codec_chunk_size + right_padding: syn = self.codec_model.vqvae(token.unsqueeze(-1), diff --git a/models/decoder/ticodec/vqvae.py b/models/decoder/ticodec/vqvae.py index 5d352c6..46f88f2 100644 --- a/models/decoder/ticodec/vqvae.py +++ b/models/decoder/ticodec/vqvae.py @@ -33,6 +33,9 @@ def __init__(self, if with_encoder: self.encoder = Encoder(self.h) self.encoder.load_state_dict(ckpt['encoder']) + + num_params = sum(p.numel() for p in self.parameters()) + print('the number of vq-vae(llm2tts) params: {}M'.format(num_params/1024/1024)) def forward(self, x, global_style_token): # x is the codebook diff --git a/models/encoder/cmvn.py b/models/encoder/cmvn.py index 2dcd026..b929910 100644 --- a/models/encoder/cmvn.py +++ b/models/encoder/cmvn.py @@ -4,6 +4,7 @@ import numpy as np +# https://en.wikipedia.org/wiki/Cepstral_mean_and_variance_normalization class GlobalCMVN(torch.nn.Module): def __init__(self, mean: torch.Tensor, diff --git a/models/pipeline.py b/models/pipeline.py index 4f0fad3..f17d17b 100644 --- a/models/pipeline.py +++ b/models/pipeline.py @@ -23,7 +23,7 @@ def __init__(self, args): self.model.eval() def speech_dialogue(self, - audio: tuple, + audio: torch.Tensor, role: str=None, stat: str='sl', past_key_values=None, diff --git a/models/utils.py b/models/utils.py index 2b86361..6e11596 100644 --- a/models/utils.py +++ b/models/utils.py @@ -1,9 +1,11 @@ -import torch import re import os -from models.audioLLM import AudioLLM +import yaml +import torch +import numpy as np +from models.audioLLM import AudioLLM from models.encoder.cmvn import GlobalCMVN, load_cmvn from models.encoder.encoder import speechEncoder @@ -27,6 +29,9 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict: return configs def init_encoder_llm(configs): + """ + init Modeling of speech input (encoder and audio llm) + """ if configs['cmvn_file'] is not None: # read cmvn mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn']) @@ -40,9 +45,24 @@ def init_encoder_llm(configs): input_dim = configs['input_dim'] vocab_size = configs['output_dim'] - # init speech encoder + # init speech encoder (几个下采样卷积层和几个 Transformer) + # 块式流式语音编码器将输入语音特征转换为高维表示 encoder = speechEncoder(input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) - # init audioLLM + # init audioLLM + # 默认: 适配器:CNN 仅包含几个下采样卷积层, + # 适配器模块将高维表示映射到主干LLM的嵌入空间中 + # 使用下采样的原因是为了降低语音特征的帧率,提高预填充阶段LLM的速度,降低延迟 model = AudioLLM(encoder=encoder, **configs['model_conf']) return model + +def print_outputs(outputs: dict): + print_str = "" + for key, item in outputs.items(): + if isinstance(item, (torch.Tensor, np.ndarray)): + print_str += f"{key} shape:{item.shape} " + if isinstance(item, (str, int)): + print_str += f"{key}:{item} " + if isinstance(item, (list, tuple)): + print_str += f"{key} len:{len(item)} " + return print_str \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f89e796..eb4de14 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ soundfile==0.12.1 torch==2.2.0 torchaudio==2.2.0 transformers==4.45.2 -PyYAML==6.0.2 \ No newline at end of file +PyYAML==6.0.2 +pyngrok==7.2.1 \ No newline at end of file diff --git a/web/queue.py b/web/queue.py index 5b85f80..c84c3b6 100644 --- a/web/queue.py +++ b/web/queue.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import json import torch import threading diff --git a/web/vad.py b/web/vad.py index 36bc38d..a869ac5 100644 --- a/web/vad.py +++ b/web/vad.py @@ -12,7 +12,17 @@ from silero_vad.utils_vad import VADIterator class VAD: - def __init__(self, cache_history=10): + """ + 首先使用声学 VAD 模块来检测流式演讲的起点。 + 当VAD被触发时,语音流将被逐块发送到Freeze-Omni,并在LLM最后一层之后添加一个额外的分类层来预测不同的状态。 + 这里定义了三种状态: + - 状态0表示当前LLM可以继续接收语音, + - 状态1或2表示当前块是语音结束。 + - 状态1表示用户将中断对话,LLM将执行新的生成阶段, + - 状态2表示无需中断对话。 + 这两种状态都将停止向 Freeze-Omni 发送语音流并重置 VAD 模块。 + """ + def __init__(self, cache_history=10, last_chunk_size=6): self.chunk_size = 16 self.chunk_overlap = 3 self.feat_dim = 80 @@ -20,7 +30,9 @@ def __init__(self, cache_history=10): self.frame_shift = 160 self.frame_overlap = self.frame_size - self.frame_shift self.CHUNK = self.frame_shift * self.chunk_size + assert cache_history >= last_chunk_size, "cache_history must >= last_chunk_size" self.cache_history = cache_history + self.last_chunk_size = last_chunk_size self.in_dialog = False with torch.no_grad(): @@ -45,6 +57,7 @@ def reset_vad(self): # reset all parms self.input_chunk = torch.zeros([1, self.chunk_size + self.chunk_overlap, self.feat_dim]) self.input_sample = torch.zeros([1, self.CHUNK + self.frame_overlap , 1]) + # chunck feat history cache also use ring buffer to do :) self.history = torch.zeros([self.cache_history, self.chunk_size + self.chunk_overlap, self.feat_dim]) self.vad_iterator.reset_states() self.in_dialog = False @@ -59,7 +72,7 @@ def run_vad_iterator(self, audio): return speech_dict_out def predict(self, - audio: torch.Tensor): + audio: np.ndarray): """ Predict the Voice Activity Detection (VAD) status and return related features. @@ -113,18 +126,20 @@ def predict(self, # self.vad_iterator.reset_states() else: # cache fbank feature + # << 1 self.history[:-1] = self.history[1:].clone() + # last history = input chunk self.history[-1:] = self.input_chunk # return dict if return_dict['status'] == 'sl': - # copy last 6 chunks - return_dict['feature_last_chunk'] = self.history[-6:].unsqueeze(1).numpy().tolist() - return_dict['feature'] = self.input_chunk.numpy().tolist() - return_dict['history_feature'] = self.history.numpy().tolist() + # copy last chunk size chunks + return_dict['feature_last_chunk'] = self.history[-self.last_chunk_size:].unsqueeze(1) + return_dict['feature'] = self.input_chunk + return_dict['history_feature'] = self.history elif return_dict['status'] == 'cl' or return_dict['status'] == 'el': return_dict['feature_last_chunk'] = None - return_dict['feature'] = self.input_chunk.numpy().tolist() - return_dict['history_feature'] = self.history.numpy().tolist() + return_dict['feature'] = self.input_chunk + return_dict['history_feature'] = self.history return return_dict