From 497f168190517da71b900a1c58d2657710b6126a Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 15 Dec 2024 18:19:10 +0800 Subject: [PATCH 01/15] feat: add ngrok proxy Signed-off-by: weedge --- bin/inference.py | 2 +- bin/server.py | 32 +++++++++++++++++++++++++++----- models/utils.py | 2 ++ requirements.txt | 3 ++- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/bin/inference.py b/bin/inference.py index f0d9046..f6f695c 100644 --- a/bin/inference.py +++ b/bin/inference.py @@ -91,7 +91,7 @@ def decoder(cur_hidden_state, pipeline, cur_text, tts, codec_chunk_size, codec_p codec_chunk_size, codec_padding_size): wav.append(seg) -def inference(pipeline, audio_processor, tts, configs): +def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, tts:llm2TTS, configs): """ Perform inference for a speech dialogue system. diff --git a/bin/server.py b/bin/server.py index 533b5fd..6e68a75 100644 --- a/bin/server.py +++ b/bin/server.py @@ -36,6 +36,8 @@ def get_args(): parser.add_argument('--max_users', type=int, default=5) parser.add_argument('--llm_exec_nums', type=int, default=1) parser.add_argument('--timeout', type=int, default=600) + parser.add_argument("--ngrok", action='store_true', help="use ngrok proxy") + parser.add_argument("--ssl", action='store_true', help="use ssl") args = parser.parse_args() print(args) return args @@ -410,10 +412,30 @@ def handle_audio(data): else: disconnect() +def ngrok_proxy(port): + """ + run `ngrok config add-authtoken $NGROK_TOKEN` + """ + from pyngrok import ngrok + import nest_asyncio + + ngrok_tunnel = ngrok.connect(port) + print('Public URL:', ngrok_tunnel.public_url) + nest_asyncio.apply() + + if __name__ == "__main__": print("Start Freeze-Omni sever") - cert_file = "web/resources/cert.pem" - key_file = "web/resources/key.pem" - if not os.path.exists(cert_file) or not os.path.exists(key_file): - generate_self_signed_cert(cert_file, key_file) - socketio.run(app, host=configs.ip, port=configs.port, ssl_context=(cert_file, key_file)) + if configs.ssl: + cert_file = "web/resources/cert.pem" + key_file = "web/resources/key.pem" + if not os.path.exists(cert_file) or not os.path.exists(key_file): + generate_self_signed_cert(cert_file, key_file) + + if configs.ngrok and not configs.ssl: + ngrok_proxy(configs.port) + + if configs.ssl: + socketio.run(app, host=configs.ip, port=configs.port, ssl_context=(cert_file, key_file)) + else: + socketio.run(app, host=configs.ip, port=configs.port) diff --git a/models/utils.py b/models/utils.py index 2b86361..0eefa62 100644 --- a/models/utils.py +++ b/models/utils.py @@ -2,6 +2,8 @@ import re import os +import yaml + from models.audioLLM import AudioLLM from models.encoder.cmvn import GlobalCMVN, load_cmvn diff --git a/requirements.txt b/requirements.txt index f89e796..eb4de14 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ soundfile==0.12.1 torch==2.2.0 torchaudio==2.2.0 transformers==4.45.2 -PyYAML==6.0.2 \ No newline at end of file +PyYAML==6.0.2 +pyngrok==7.2.1 \ No newline at end of file From 583e0a50f81acd2f026e7b475a893d9cb50605cb Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 15 Dec 2024 21:20:28 +0800 Subject: [PATCH 02/15] feat: add annotate Signed-off-by: weedge --- bin/inference.py | 7 +++++++ bin/server.py | 7 +++++++ models/audioLLM.py | 9 +++++++++ models/decoder/llm2tts.py | 11 +++++++++++ models/decoder/ticodec/vqvae.py | 3 +++ models/utils.py | 11 +++++++++-- web/vad.py | 10 ++++++++++ 7 files changed, 56 insertions(+), 2 deletions(-) diff --git a/bin/inference.py b/bin/inference.py index f6f695c..58fecb7 100644 --- a/bin/inference.py +++ b/bin/inference.py @@ -95,6 +95,10 @@ def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, """ Perform inference for a speech dialogue system. + 流式语音输入通过语音编码器形成chunk-wise特征,然后通过适配器连接到LLM。 + LLM生成隐藏状态和文本标记,在块分割后分别以块的形式发送到NAR前缀语音解码器和NAR语音解码器。 + 最后,AR语音解码器将生成的令牌发送到语音令牌FIFO中,流式编解码器根据固定的语音令牌块大小从FIFO生成流式语音输出。 + Parameters: - pipeline: Speech dialogue pipeline. - audio_processor: Processes raw audio data into a format suitable for the pipeline. @@ -184,7 +188,10 @@ def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, if __name__ == '__main__': configs = get_args() + # encoder and audio llm pipeline = inferencePipeline(configs) + # decoder tts = llm2TTS(configs.model_path) + # stream chunk to encoder audio_processor = audioEncoderProcessor() inference(pipeline, audio_processor, tts, configs) diff --git a/bin/server.py b/bin/server.py index 6e68a75..058b8a1 100644 --- a/bin/server.py +++ b/bin/server.py @@ -46,6 +46,13 @@ def custom_print(*args, **kwargs): current_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] original_print(f'[{current_time}]', *args, **kwargs) +""" +“模型作为服务器”策略来实现语音到语音对话系统。 +首先,我们同时启动多个模型,并将它们视为服务器。 +然后,当用户的VAD被触发时,语音将以块的形式发送到服务器,服务器将负责调度哪个空闲模型应该响应当前的块。 +由于我们在推理过程中将语音编码器和LLM的所有kv-cache和CNN缓存分开,因此服务器只需要保存每个用户的推理缓存。 +这样,服务器中的任何模型都可以响应任何用户的任何块,并且不需要指定哪个模型用作监视器或生成器。 +""" # init parms configs = get_args() # read server related config diff --git a/models/audioLLM.py b/models/audioLLM.py index 40a4f25..8fc859c 100644 --- a/models/audioLLM.py +++ b/models/audioLLM.py @@ -18,6 +18,13 @@ IGNORE_ID = -1 class AudioLLM(torch.nn.Module): + """ + Modeling of speech input + 为了使 Freeze-Omni 能够支持语音输入并实现对输入语音的快速、低延迟响应,它利用块式流式语音编码器将输入语音特征转换为高维表示。 + 然后,适配器模块将高维表示映射到主干LLM的嵌入空间中。 + 这里的语音编码器模块由几个下采样卷积层和几个 Transformer 块组成,而适配器仅包含几个下采样卷积层。 + 使用下采样的原因是为了降低语音特征的帧率,提高预填充阶段LLM的速度,降低延迟。 + """ def __init__( self, encoder: torch.nn.Module, @@ -208,6 +215,8 @@ def __init__( "hyps": 7, "/hyps": 8, } + num_params = sum(p.numel() for p in self.parameters()) + print('the number of audio llm params: {}M'.format(num_params/1024/1024)) def set_system_role( self, diff --git a/models/decoder/llm2tts.py b/models/decoder/llm2tts.py index 2f008eb..7c96997 100644 --- a/models/decoder/llm2tts.py +++ b/models/decoder/llm2tts.py @@ -15,10 +15,19 @@ from models.decoder.ticodec.vqvae_tester import VqvaeTester class llm2TTS(): + """ + Modeling of speech output + 受 VALL-E [5] 的启发,Freeze-Omni 使用基于令牌的语音解码器,其中包含 NAR 预填充和 AR 生成阶段来实现语音输出功能。 + 语音解码器主要由NAR解码器、AR解码器和编解码器模型的解码器组成。 NAR 解码器和 AR 解码器都是基于变压器块构建的。 + NAR解码器用于根据LLM的输出对语义特征进行建模,然后AR解码器基于NAR解码器的输出生成语音标记。 + 最后,编解码器模型的解码器将语音标记转换为语音流。 + """ def __init__(self, model_path): self.model = self.get_model(model_path).cuda().to( torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 ) + num_params = sum(p.numel() for p in self.model.parameters()) + print('the number of LLM2TTSCodecAR(NAR decoder and AR decoder(transformer blocks)) params: {}M'.format(num_params/1024/1024)) self.infer = self.model.infer self.codec_model = VqvaeTester(config_path=model_path + "/codec/model.json", @@ -27,6 +36,8 @@ def __init__(self, model_path): self.codec_model = self.codec_model.cuda() self.codec_model.vqvae.generator.remove_weight_norm() self.codec_model.vqvae.encoder.remove_weight_norm() + num_params = sum(p.numel() for p in self.codec_model.parameters()) + print('after remove_weight_norm, the number of llm2TTS(vq-vae codec decoder model) params: {}M'.format(num_params/1024/1024)) self.codec_model.eval() def get_model_conf(self, model_path): diff --git a/models/decoder/ticodec/vqvae.py b/models/decoder/ticodec/vqvae.py index 5d352c6..f9d2107 100644 --- a/models/decoder/ticodec/vqvae.py +++ b/models/decoder/ticodec/vqvae.py @@ -33,6 +33,9 @@ def __init__(self, if with_encoder: self.encoder = Encoder(self.h) self.encoder.load_state_dict(ckpt['encoder']) + + num_params = sum(p.numel() for p in self.parameters()) + print('the number of vq-vqe(llm2tts) params: {}M'.format(num_params/1024/1024)) def forward(self, x, global_style_token): # x is the codebook diff --git a/models/utils.py b/models/utils.py index 0eefa62..28bd6aa 100644 --- a/models/utils.py +++ b/models/utils.py @@ -29,6 +29,9 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict: return configs def init_encoder_llm(configs): + """ + init Modeling of speech input (encoder and audio llm) + """ if configs['cmvn_file'] is not None: # read cmvn mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn']) @@ -42,9 +45,13 @@ def init_encoder_llm(configs): input_dim = configs['input_dim'] vocab_size = configs['output_dim'] - # init speech encoder + # init speech encoder (几个下采样卷积层和几个 Transformer) + # 块式流式语音编码器将输入语音特征转换为高维表示 encoder = speechEncoder(input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) - # init audioLLM + # init audioLLM + # 默认: 适配器:CNN 仅包含几个下采样卷积层, + # 适配器模块将高维表示映射到主干LLM的嵌入空间中 + # 使用下采样的原因是为了降低语音特征的帧率,提高预填充阶段LLM的速度,降低延迟 model = AudioLLM(encoder=encoder, **configs['model_conf']) return model diff --git a/web/vad.py b/web/vad.py index 36bc38d..867b463 100644 --- a/web/vad.py +++ b/web/vad.py @@ -12,6 +12,16 @@ from silero_vad.utils_vad import VADIterator class VAD: + """ + 首先使用声学 VAD 模块来检测流式演讲的起点。 + 当VAD被触发时,语音流将被逐块发送到Freeze-Omni,并在LLM最后一层之后添加一个额外的分类层来预测不同的状态。 + 这里定义了三种状态: + - 状态0表示当前LLM可以继续接收语音, + - 状态1或2表示当前块是语音结束。 + - 状态1表示用户将中断对话,LLM将执行新的生成阶段, + - 状态2表示无需中断对话。 + 这两种状态都将停止向 Freeze-Omni 发送语音流并重置 VAD 模块。 + """ def __init__(self, cache_history=10): self.chunk_size = 16 self.chunk_overlap = 3 From 08acb1e0ccd58703fa66f1436b060a81d589cc2a Mon Sep 17 00:00:00 2001 From: weedge Date: Sun, 15 Dec 2024 21:22:51 +0800 Subject: [PATCH 03/15] feat: add annotate Signed-off-by: weedge --- models/decoder/llm2tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/decoder/llm2tts.py b/models/decoder/llm2tts.py index 7c96997..24a9c93 100644 --- a/models/decoder/llm2tts.py +++ b/models/decoder/llm2tts.py @@ -27,7 +27,7 @@ def __init__(self, model_path): torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 ) num_params = sum(p.numel() for p in self.model.parameters()) - print('the number of LLM2TTSCodecAR(NAR decoder and AR decoder(transformer blocks)) params: {}M'.format(num_params/1024/1024)) + print('the number of LLM2TTSCodecAR(NAR decoder and AR decoder(llama transformer blocks)) params: {}M'.format(num_params/1024/1024)) self.infer = self.model.infer self.codec_model = VqvaeTester(config_path=model_path + "/codec/model.json", From 1c35ad9cfed26a1dbef61a9208dc273106f1489a Mon Sep 17 00:00:00 2001 From: weedge Date: Mon, 16 Dec 2024 12:34:45 +0800 Subject: [PATCH 04/15] fix Signed-off-by: weedge --- models/decoder/ticodec/vqvae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/decoder/ticodec/vqvae.py b/models/decoder/ticodec/vqvae.py index f9d2107..46f88f2 100644 --- a/models/decoder/ticodec/vqvae.py +++ b/models/decoder/ticodec/vqvae.py @@ -35,7 +35,7 @@ def __init__(self, self.encoder.load_state_dict(ckpt['encoder']) num_params = sum(p.numel() for p in self.parameters()) - print('the number of vq-vqe(llm2tts) params: {}M'.format(num_params/1024/1024)) + print('the number of vq-vae(llm2tts) params: {}M'.format(num_params/1024/1024)) def forward(self, x, global_style_token): # x is the codebook From 17eba5204eaae7dabe9a85126f9f7a1506a48ba8 Mon Sep 17 00:00:00 2001 From: weedge Date: Mon, 16 Dec 2024 21:56:22 +0800 Subject: [PATCH 05/15] fix Signed-off-by: weedge --- models/decoder/llm2tts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models/decoder/llm2tts.py b/models/decoder/llm2tts.py index 24a9c93..2d4d085 100644 --- a/models/decoder/llm2tts.py +++ b/models/decoder/llm2tts.py @@ -18,9 +18,9 @@ class llm2TTS(): """ Modeling of speech output 受 VALL-E [5] 的启发,Freeze-Omni 使用基于令牌的语音解码器,其中包含 NAR 预填充和 AR 生成阶段来实现语音输出功能。 - 语音解码器主要由NAR解码器、AR解码器和编解码器模型的解码器组成。 NAR 解码器和 AR 解码器都是基于变压器块构建的。 - NAR解码器用于根据LLM的输出对语义特征进行建模,然后AR解码器基于NAR解码器的输出生成语音标记。 - 最后,编解码器模型的解码器将语音标记转换为语音流。 + 语音解码器主要由NAR解码器、AR解码器和编解码器模型的解码器组成。 NAR 解码器和 AR 解码器都是基于transformer块构建的。 + NAR解码器用于根据LLM的输出对语义特征进行建模,然后AR解码器基于NAR解码器的输出生成语音token。 + 最后,编解码器模型的解码器将语音token转换为语音流。 """ def __init__(self, model_path): self.model = self.get_model(model_path).cuda().to( From cb62b9ba1c470777e5d8c916a573e5b96f9176d5 Mon Sep 17 00:00:00 2001 From: weedge Date: Tue, 17 Dec 2024 12:51:48 +0800 Subject: [PATCH 06/15] add max_tokens for llm2TTS run method Signed-off-by: weedge --- bin/inference.py | 14 +++++++++----- models/decoder/llm2tts.py | 5 +++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/bin/inference.py b/bin/inference.py index 58fecb7..1a3850f 100644 --- a/bin/inference.py +++ b/bin/inference.py @@ -69,7 +69,12 @@ def chunk_data_shift(self, xs): self.input_chunk[:, self.chunk_overlap:, :] = xs.squeeze(0) def process(self, - audio: torch.Tensor): + audio: np.ndarray): + """ + # 1. Converts the input audio tensor to the appropriate format. + # 2. Computes the filter bank features (fbank) for the audio. + # 3. Updates the input chunk and history based on the new audio segment. + """ with torch.no_grad(): sample_data = torch.tensor(audio).reshape(1, -1, 1)[:, :, :1] * 32768 self.fbank_shift(sample_data) @@ -80,6 +85,9 @@ def process(self, return self.input_chunk.clone() def decoder(cur_hidden_state, pipeline, cur_text, tts, codec_chunk_size, codec_padding_size, decoder_topk, wav): + """ + Decodes the current hidden state and text to generate audio segments using speech decoder. + """ hidden_state_output = torch.cat(cur_hidden_state).squeeze(1) cur_text_procced = pipeline.post_process(cur_text) print("Synthesis: ", [cur_text_procced]) @@ -95,10 +103,6 @@ def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, """ Perform inference for a speech dialogue system. - 流式语音输入通过语音编码器形成chunk-wise特征,然后通过适配器连接到LLM。 - LLM生成隐藏状态和文本标记,在块分割后分别以块的形式发送到NAR前缀语音解码器和NAR语音解码器。 - 最后,AR语音解码器将生成的令牌发送到语音令牌FIFO中,流式编解码器根据固定的语音令牌块大小从FIFO生成流式语音输出。 - Parameters: - pipeline: Speech dialogue pipeline. - audio_processor: Processes raw audio data into a format suitable for the pipeline. diff --git a/models/decoder/llm2tts.py b/models/decoder/llm2tts.py index 2d4d085..c0a74ee 100644 --- a/models/decoder/llm2tts.py +++ b/models/decoder/llm2tts.py @@ -123,7 +123,7 @@ def find_min_sum_index(self, buffer, syn, N, threshold): return buffer, syn def run(self, hidden, top_k, prefix, codec_chunk_size=40, codec_padding_size=10, - penalty_window_size=-1, penalty=1.1, N=2401, seg_threshold=0.01): + penalty_window_size=-1, penalty=1.1, N=2401, seg_threshold=0.01, max_tokens=1000): """ Run the speech decoder process. @@ -149,7 +149,8 @@ def run(self, hidden, top_k, prefix, codec_chunk_size=40, codec_padding_size=10, dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32): print("Starting TTS...") token = torch.full((1, 0), self.model.vocab_size, dtype=torch.long, device=hidden.device) - for next_token_id in self.infer(hidden, top_k, prefix, penalty_window_size, penalty): + for next_token_id in self.infer( + hidden, top_k, prefix, penalty_window_size, penalty, max_tokens=max_tokens): token = torch.cat([token, next_token_id], dim=-1) if token.size(1) == left_padding + codec_chunk_size + right_padding: syn = self.codec_model.vqvae(token.unsqueeze(-1), From 798e27aa5897a592146147adb5b5b65a3492487a Mon Sep 17 00:00:00 2001 From: weedge Date: Wed, 18 Dec 2024 09:05:09 +0800 Subject: [PATCH 07/15] feat: add stream inference Signed-off-by: weedge --- bin/inference_stream.py | 509 ++++++++++++++++++++++++++++++++++++++ bin/server.py | 4 +- models/audioLLM.py | 3 + models/decoder/llm2tts.py | 2 +- models/encoder/cmvn.py | 1 + web/vad.py | 5 +- 6 files changed, 520 insertions(+), 4 deletions(-) create mode 100644 bin/inference_stream.py diff --git a/bin/inference_stream.py b/bin/inference_stream.py new file mode 100644 index 0000000..9af30fc --- /dev/null +++ b/bin/inference_stream.py @@ -0,0 +1,509 @@ +from __future__ import print_function + +import time +import math +import argparse +import threading +from copy import deepcopy +from dataclasses import dataclass +from typing import Generator + +import soundfile as sf +import numpy as np +import torch +import torchaudio +import torchaudio.compliance.kaldi as k + +from web.queue import PCMQueue, ThreadSafeQueue +from models.pipeline import inferencePipeline +from models.decoder.llm2tts import llm2TTS + + +def get_args(): + parser = argparse.ArgumentParser(description="Freeze-Omni-Inference-stream") + parser.add_argument("--model_path", required=True, help="model_path to load") + parser.add_argument("--llm_path", required=True, help="llm_path to load") + parser.add_argument("--top_k", type=int, default=5) + parser.add_argument("--top_p", type=float, default=0.8) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--input_wav", required=True, help="input wav") + parser.add_argument("--output_wav", required=True, help="output wav") + + args = parser.parse_args() + print(args) + return args + + +class audioEncoderProcessor: + def __init__(self, chunk_size=16): + self.chunk_size = 16 + self.chunk_overlap = 3 + self.feat_dim = 80 + self.frame_size = 400 + self.frame_shift = 160 + self.frame_overlap = self.frame_size - self.frame_shift + self.CHUNK = self.frame_shift * self.chunk_size + self.reset() + + def get_chunk_size(self): + return self.CHUNK + + def reset(self): + self.input_chunk = torch.zeros([1, self.chunk_size + self.chunk_overlap, self.feat_dim]) + self.input_sample = torch.zeros([1, self.CHUNK + self.frame_overlap, 1]) + + def fbank_shift(self, sample_data): + # fbank feature shift + self.input_sample[:, : self.frame_overlap, :] = self.input_sample[ + :, -self.frame_overlap :, : + ].clone() + self.input_sample[:, self.frame_overlap :, :] = sample_data + + def chunk_data_shift(self, xs): + # chunk feature shift + self.input_chunk[:, : self.chunk_overlap, :] = self.input_chunk[ + :, -self.chunk_overlap :, : + ].clone() + self.input_chunk[:, self.chunk_overlap :, :] = xs.squeeze(0) + + def process(self, audio: np.ndarray) -> torch.Tensor: + """ + # 1. Converts the input audio tensor to the appropriate format. + # 2. Computes the filter bank features (fbank) for the audio. + # 3. Updates the input chunk and history based on the new audio segment. + """ + with torch.no_grad(): + sample_data = torch.tensor(audio).reshape(1, -1, 1)[:, :, :1] * 32768 + self.fbank_shift(sample_data) + # use kaldi api to compute fbank + xs = k.fbank( + waveform=self.input_sample.squeeze(-1), + dither=0, + frame_length=25, + frame_shift=10, + num_mel_bins=self.feat_dim, + ) + self.chunk_data_shift(xs) + return self.input_chunk.clone() + + +class PCMListener: + def __init__( + self, + outputs_queue: ThreadSafeQueue, # outputs -> out queue + pipeline: inferencePipeline, + cache_history_size: int = 10, + system_prompt: str = "You are a helpful assistant.", + ) -> None: + self.pcm_fifo_queue = PCMQueue() # pcm data -> in queue + self.listen_thread = threading.Thread(target=self.listen, args=()) + self.stop_listen = False + self.listen_thread.start() + + self.outputs_queue = outputs_queue + + # stream chunk to encoder + self.audio_processor = audioEncoderProcessor() + # encoder and audio llm + self.pipeline = pipeline + + # pre status system prompt + self.status = "pre" + self.init_outputs = pipeline.speech_dialogue(None, stat="pre", role=system_prompt) + + # chunck feat history cache also use ring buffer to do :) + self.history = torch.zeros( + [cache_history_size, self.chunk_size + self.chunk_overlap, self.feat_dim] + ) + + def stop(self): + self.stop_listen = True + self.listen_thread.join(timeout=3) + + def set_status(self, status: str): + if status not in ["sl", "cl", "el"]: + raise ValueError("status must be one of: 'sl', 'cl', 'el'") + self.status = status + + def send(self, pcm_items: np.ndarray): + """ + send float32(<1) numpy ndarray to fifo queue buffer + """ + self.pcm_fifo_queue.put(pcm_items.astype(np.float32) / 32768.0) + + def history_buffering_strategy(self, input_chunk: torch.Tensor) -> torch.Tensor: + # cache fbank feature (input_chunk) + # << 1 + self.history[:-1] = self.history[1:].clone() + # last history = input chunk + self.history[-1:] = input_chunk + # start listen to # copy last 6 chunks + feature_last_chunk = self.history[-6:].unsqueeze(1) + return feature_last_chunk + + def llm_prefill(self, status: str, feature: torch.Tensor, outputs: dict, is_first_pack=False): + """ + Prefills the LLM of speech dialogue system using speech. + + Parameters: + - status: the current state of the audio input. + - feature: audio feature tensor. + - outputs (dict): A dictionary containing the current state of the dialogue system. + - is_first_pack (bool, optional): Indicates if the current input packet is the first one in a new conversation + """ + + if status == "sl": + # Satge1: start listen + # stat will be auto set to 'cl' after Stage1 + outputs = self.pipeline.speech_dialogue(feature, **outputs) + if status == "el": + print("status end listen. Detect vad time out") + + if status == "cl": + if outputs["stat"] == "cl": + # Stage2: continue listen + # stat will be auto set to 'ss' when endpoint is detected + outputs = self.pipeline.speech_dialogue(feature, **outputs) + if is_first_pack: + outputs["stat"] = "cl" + if outputs["stat"] == "el": + print("output stat end listen. Detect invalid break") + if outputs["stat"] == "ss": + print("output stat start speak. send outputs to queue") + self.outputs_queue.put(outputs) + return outputs + + def listen(self): + """ + - status: sl(start listen) -> cl(continue listen) -> ss(start generate speech to speak) + - status from VAD + """ + print("Start listening") + while True: + if self.stop_listen: + print("Stop listening") + break + e = self.pcm_fifo_queue.get(self.audio_processor.chunk_size) + if e is None: + time.sleep(0.01) + continue + print("Received PCM data: ", len(e)) + + fbank_feature = self.audio_processor.process(np.float32(e)) + if self.status == "sl": + feature_last_chunk = self.history_buffering_strategy(fbank_feature) + outputs = deepcopy(self.init_outputs) + outputs["adapter_cache"] = None + outputs["encoder_cache"] = None + outputs["pe_index"] = 0 + outputs["stat"] = "sl" + outputs["last_id"] = None + if "text" in outputs: + del outputs["text"] + if "hidden_state" in outputs: + del outputs["hidden_state"] + + for i in range(feature_last_chunk): + if i == 0: + outputs = self.llm_prefill( + "sl", feature_last_chunk[i], outputs, is_first_pack=True + ) + else: + outputs = self.llm_prefill( + "cl", feature_last_chunk[i], outputs, is_first_pack=True + ) + outputs = self.llm_prefill("cl", fbank_feature, outputs) + + elif self.status == "cl" or self.status == "el": + outputs = self.llm_prefill(self.status, fbank_feature, outputs) + + +@dataclass +class TTSSpeakerArgs: + # https://huggingface.co/VITA-MLLM/Freeze-Omni/blob/main/checkpoints/server.json + # decoder(LLM2TTSCodecAR) + # NAR llama transformer decoder pre_nn_forward -> NAR llama transformer decoder kv_cache_prefix_forward -> AR llama transformer decoder transformer_infer + + # llama transformer decoder + decoder_top_k: int = 2 + decoder_penalty_window_size: int = -1 # <0 no penalty window + decoder_penalty: float = 1.1 + + # codec decoder + decoder_first_chunk_size: int = 20 + decoder_chunk_size: int = 40 + decoder_chunk_overlap_size: int = 10 + + # find_min_sum_index + decoder_N: int = 2401 + decoder_seg_threshold_first_pack: float = 0.1 + decoder_seg_threshold: float = 0.015 + + +class TTSSpeaker: + def __init__( + self, + args: TTSSpeakerArgs, + outputs_in_queue: ThreadSafeQueue, # out queue -> speak + pipeline: inferencePipeline, + tts: llm2TTS, + ) -> None: + self.args = args + self.outputs_in_queue = outputs_in_queue + self.pipeline = pipeline + self.tts = tts + + self.reset() + + # speak thread + self.speak_thread = threading.Thread(target=self.speak, args=()) + self.speak_thread.start() + + def reset(self): + self.stop_speak = False + self.is_generate = False + self.generate_outputs = {} + self.whole_text = "" + + self.tts_over = False + self.tts_over_time = 0 + self.tts_data = ThreadSafeQueue() + + self.stop_tts = False + + def print(self): + print("stop_speak:", self.stop_speak) + print("is_generate:", self.is_generate) + print("whole_text:", self.whole_text) + print("tts_over:", self.tts_over) + print("tts_over_time:", self.tts_over_time) + + @property + def generate_outputs(self): + """Get the generate outputs.""" + return self.generate_outputs + + @property + def whole_text(self): + """Get the whole text.""" + return self.whole_text + + def stop(self): + """Stop the speak thread.""" + self.stop_speak = True + self.speak_thread.join() + + def decoder( + self, cur_hidden_state: list[torch.Tensor], cur_text: str, generate_num: int + ) -> int: + """ + Decodes the current hidden state and text to generate audio segments using speech decoder. + + Parameters: + - cur_hidden_state (list of torch.Tensor): The current hidden state of the language model. + - cur_text (str): The current text to be synthesized. + - generate_num (int): The number of audio segments generated + + Returns: + - int: The updated number of audio segments generated. + """ + hidden_state_output = torch.cat(cur_hidden_state).squeeze(1) + cur_text_procced = self.pipeline.post_process(cur_text) + print("Synthesis: ", [cur_text_procced]) + embeddings = self.pipeline.model.llm_decoder.model.embed_tokens( + torch.tensor(self.pipeline.model.tokenizer.encode(cur_text_procced)).cuda() + ) + codec_chunk_size = self.args.decoder_first_chunk_size + codec_padding_size = self.args.decoder_chunk_overlap_size + seg_threshold = self.args.decoder_seg_threshold_first_pack + if generate_num != 0: + codec_chunk_size = self.args.decoder_chunk_size + seg_threshold = self.args.decoder_seg_threshold + for seg in self.tts.run( + embeddings.reshape(-1, 896).unsqueeze(0), + self.args.decoder_top_k, + hidden_state_output.reshape(-1, 896).unsqueeze(0), + codec_chunk_size=codec_chunk_size, + codec_padding_size=codec_padding_size, + penalty_window_size=self.args.decoder_penalty_window_size, + penalty=self.args.decoder_penalty, + N=self.args.decoder_N, + seg_threshold=seg_threshold, + ): + if generate_num == 0: + try: + split_idx = torch.nonzero(seg.abs() > 0.03, as_tuple=True)[-1][0] + seg = seg[:, :, split_idx:] + except Exception: + print("Do not need to split") + pass + generate_num += 1 + if self.tts_over: + self.tts_data.clear() + self.whole_text = "" + break + self.tts_data.put(seg.squeeze().float().cpu().numpy() * 32768) + return generate_num + + def get_tts_data(self) -> Generator[np.ndarray, None, None]: + """ + get tts bytes data + """ + while True: + if self.stop_speak: + print("Stop speak so break get tts data") + break + output_data = self.tts_data.get() + if output_data is not None: + # print("Get TTS data") + #yield output_data.astype(np.int16).tobytes() + yield output_data + + yield None + + def speak(self): + """ + Generates speech dialogue output based on the current state + """ + while True: + if self.stop_speak: + print("Stop speak") + break + outputs = self.outputs_in_queue.get() + if outputs is None: + time.sleep(0.01) + continue + # Stage3: start speak + self.is_generate = True + outputs = self.pipeline.speech_dialogue(None, **outputs) + # outputs dict need change, so deepcopy + self.generate_outputs = deepcopy(outputs) + + cur_hidden_state = [] + cur_hidden_state.append(outputs["hidden_state"]) + + # Stage4: contiune speak until stat is set to 'sl' + # use 'stop' to interrupt generation, stat need to be manually set as 'sl' + stop = False + cur_text = "" + last_text = "" + generate_num = 0 + while True: + if self.stop_speak: + break + if len(outputs["past_tokens"]) > 500: + stop = True + if stop: + break + del outputs["text"] + del outputs["hidden_state"] + outputs = self.pipeline.speech_dialogue(None, **outputs) + self.generate_outputs = deepcopy(outputs) + if outputs["stat"] == "cs": + cur_hidden_state.append(outputs["hidden_state"]) + if "�" in outputs["text"][len(last_text) :]: + continue + self.whole_text += outputs["text"][len(last_text) :] + cur_text += outputs["text"][len(last_text) :] + # print(self.whole_text]) + if generate_num == 0 or (len(cur_hidden_state) >= 20): + suffix_list = [",", ",", "。", ":", "?", "!", ".", ":", "?", "!", "\n"] + else: + suffix_list = ["。", ":", "?", "!", ".", "?", "!", "\n"] + if outputs["text"][len(last_text) :].endswith(tuple(suffix_list)) and ( + len(cur_hidden_state) >= 4 + ): + if ( + outputs["text"][len(last_text) :].endswith(".") + and last_text[-1].isdigit() + ): + pass + else: + if not self.tts_over: + if len(cur_hidden_state) > 0: + generate_num = self.decoder( + cur_hidden_state, cur_text, generate_num + ) + cur_text = "" + cur_hidden_state = [] + last_text = outputs["text"] + else: + break + if not self.tts_over: + if len(cur_hidden_state) != 0: + generate_num = self.decoder(cur_hidden_state, cur_text, generate_num) + cur_text = "" + self.is_generate = False + outputs["stat"] = "sl" + outputs["last_id"] = None + print(self.whole_text) + + def interrupt(self): + self.stop_generate = True + self.tts_over = True + while True: + time.sleep(0.01) + if self.is_generate is False: + self.stop_generate = False + while True: + time.sleep(0.01) + if self.tts_data.is_empty(): + self.whole_text = "" + self.tts_over = False + self.tts_over_time += 1 + break + break + + +def inference_stream(listener: PCMListener, speaker: TTSSpeaker, configs): + """ + Perform inference for a speech dialogue system. + - 流式语音输入通过语音编码器形成chunk-wise特征,然后通过适配器连接到LLM。 + - LLM生成隐藏状态和文本标记,在块分割后分别以块的形式发送到NAR前缀语音解码器和NAR语音解码器。 + - 最后,AR语音解码器将生成的令牌发送到语音令牌FIFO中,流式编解码器根据固定的语音令牌块大小从FIFO生成流式语音输出。 + + Parameters: + - listener: listen pcm data(chunk) to asr with pipeline(encoder and adpter), + - detail status: sl, cl, el, ss + - speaker: tts speaker with pipeline(audio llm) and decoder (NAR decoder AR decoder and codec decoder) + - detail status: ss, cs + - configs: Input args. (argparse) + + Returns: + - None + """ + wav, fs = sf.read(configs.input_wav) + wav = torch.tensor(wav) + if fs != 16000: + wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav.float()) + fs = 16000 + # send numpy ndarray pcm data + listener.send(wav.numpy()) + + wavs = [] + # get tts speak data + for data in speaker.get_tts_data(): + wavs.append(torch.tensor(data)) + + sf.write(configs.output_wav, torch.cat(wav, -1).squeeze().float().cpu().numpy(), 24000) + + +if __name__ == "__main__": + configs = get_args() + + # encoder and audio llm + pipeline = inferencePipeline(configs) + # decoder + tts = llm2TTS(configs.model_path) + + # listen -> gen_queue -> speak + gen_queue = ThreadSafeQueue() + # listen + listener = PCMListener(gen_queue, pipeline) + # speak + speaker = TTSSpeaker(TTSSpeakerArgs(), gen_queue, pipeline, tts) + + inference_stream(listener, speaker, configs) + + +# format: ruff format bin/inference_stream.py diff --git a/bin/server.py b/bin/server.py index 058b8a1..6d3c643 100644 --- a/bin/server.py +++ b/bin/server.py @@ -260,7 +260,7 @@ def send_pcm(sid): - sid (str): The session ID of the user. """ - chunk_szie = connected_users[sid][1].wakeup_and_vad.get_chunk_size() + chunk_size = connected_users[sid][1].wakeup_and_vad.get_chunk_size() print("Sid: ", sid, " Start listening") while True: @@ -270,7 +270,7 @@ def send_pcm(sid): connected_users[sid][1].stop_tts = True break time.sleep(0.01) - e = connected_users[sid][1].pcm_fifo_queue.get(chunk_szie) + e = connected_users[sid][1].pcm_fifo_queue.get(chunk_size) if e is None: continue print("Sid: ", sid, " Received PCM data: ", len(e)) diff --git a/models/audioLLM.py b/models/audioLLM.py index 8fc859c..28d5421 100644 --- a/models/audioLLM.py +++ b/models/audioLLM.py @@ -259,6 +259,9 @@ def recognize( speech_lengths: torch.Tensor, extra_inputs: Optional[dict] = None, ): + """ + speech encoder(down sample CNN+Transformer) -> adapter(down sample CNN) -> text llm(decoder_only Transformer) + """ assert extra_inputs.get('past_key_values', None) is not None, "must set system role first!!!" buffer = extra_inputs.get('encoder_cache', None) diff --git a/models/decoder/llm2tts.py b/models/decoder/llm2tts.py index c0a74ee..c2a19e4 100644 --- a/models/decoder/llm2tts.py +++ b/models/decoder/llm2tts.py @@ -133,7 +133,7 @@ def run(self, hidden, top_k, prefix, codec_chunk_size=40, codec_padding_size=10, - prefix (str, optional): The hidden state from the language model. - codec_chunk_size (int, default=40): The size of each chunk to process in the codec model. - codec_padding_size (int, default=10): The amount of padding to add on each side of the codec chunk. - - penalty_window_size (int, default=20): The window size for applying penalties during decoding. + - penalty_window_size (int, default=-1): The window size for applying penalties during decoding. - penalty (float, default=1.1): The penalty factor. Yields: diff --git a/models/encoder/cmvn.py b/models/encoder/cmvn.py index 2dcd026..b929910 100644 --- a/models/encoder/cmvn.py +++ b/models/encoder/cmvn.py @@ -4,6 +4,7 @@ import numpy as np +# https://en.wikipedia.org/wiki/Cepstral_mean_and_variance_normalization class GlobalCMVN(torch.nn.Module): def __init__(self, mean: torch.Tensor, diff --git a/web/vad.py b/web/vad.py index 867b463..d42db10 100644 --- a/web/vad.py +++ b/web/vad.py @@ -55,6 +55,7 @@ def reset_vad(self): # reset all parms self.input_chunk = torch.zeros([1, self.chunk_size + self.chunk_overlap, self.feat_dim]) self.input_sample = torch.zeros([1, self.CHUNK + self.frame_overlap , 1]) + # chunck feat history cache also use ring buffer to do :) self.history = torch.zeros([self.cache_history, self.chunk_size + self.chunk_overlap, self.feat_dim]) self.vad_iterator.reset_states() self.in_dialog = False @@ -69,7 +70,7 @@ def run_vad_iterator(self, audio): return speech_dict_out def predict(self, - audio: torch.Tensor): + audio: np.ndarray): """ Predict the Voice Activity Detection (VAD) status and return related features. @@ -123,7 +124,9 @@ def predict(self, # self.vad_iterator.reset_states() else: # cache fbank feature + # << 1 self.history[:-1] = self.history[1:].clone() + # last history = input chunk self.history[-1:] = self.input_chunk # return dict From b5484931b9e5e9fea43e99915e7e103819e04ab3 Mon Sep 17 00:00:00 2001 From: weedge Date: Wed, 18 Dec 2024 13:28:51 +0800 Subject: [PATCH 08/15] feat: add PCMStatChunk for PCMListener Signed-off-by: weedge --- bin/inference.py | 10 +++ bin/inference_stream.py | 143 ++++++++++++++++++++++++++++------------ bin/server.py | 3 + models/audioLLM.py | 8 ++- web/queue.py | 1 + web/vad.py | 8 ++- 6 files changed, 127 insertions(+), 46 deletions(-) diff --git a/bin/inference.py b/bin/inference.py index 1a3850f..6ebf163 100644 --- a/bin/inference.py +++ b/bin/inference.py @@ -40,6 +40,11 @@ def get_args(): print(args) return args +def custom_print(*args, **kwargs): + current_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] + original_print(f'[{current_time}]', *args, **kwargs) + + class audioEncoderProcessor: def __init__(self, chunk_size = 16): self.chunk_size = 16 @@ -135,6 +140,7 @@ def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, for i in range(0, wav_input.shape[0], chunk_size): fbank = audio_processor.process(wav_input[i:i+chunk_size]) outputs = pipeline.speech_dialogue(fbank, **outputs) + print(f"speech_dialogue outputs stat: {outputs['stat']}") outputs['stat'] = 'cl' audio_processor.reset() @@ -191,6 +197,10 @@ def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, print(whole_text) if __name__ == '__main__': + # change print function to add time stamp + original_print = builtins.print + builtins.print = custom_print + configs = get_args() # encoder and audio llm pipeline = inferencePipeline(configs) diff --git a/bin/inference_stream.py b/bin/inference_stream.py index 9af30fc..c893674 100644 --- a/bin/inference_stream.py +++ b/bin/inference_stream.py @@ -1,5 +1,7 @@ from __future__ import print_function +import builtins +import datetime import time import math import argparse @@ -19,6 +21,11 @@ from models.decoder.llm2tts import llm2TTS +def custom_print(*args, **kwargs): + current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + original_print(f"[{current_time}]", *args, **kwargs) + + def get_args(): parser = argparse.ArgumentParser(description="Freeze-Omni-Inference-stream") parser.add_argument("--model_path", required=True, help="model_path to load") @@ -87,6 +94,29 @@ def process(self, audio: np.ndarray) -> torch.Tensor: return self.input_chunk.clone() +class GlobalVars: + speech_dialogue_outputs = {} + + @staticmethod + def deepcopy_outputs(): + return deepcopy(GlobalVars.deepcopy_outputs) + + +@dataclass +class PCMStatChunk: + size: int + status: str # sl(start listen), cl(continue listen), el(end listen) + data: np.ndarray # a numpy array of dtype np.float32 + + def __post_init__(self): + super().__post_init__() + if self.status not in ["sl", "cl", "el"]: + raise ValueError("status must be one of: 'sl','cl','el'") + + def __str__(self): + return f"size:{self.size} status:{self.status} numpy data:{self.data.shape} {super().__str__()}" + + class PCMListener: def __init__( self, @@ -95,10 +125,7 @@ def __init__( cache_history_size: int = 10, system_prompt: str = "You are a helpful assistant.", ) -> None: - self.pcm_fifo_queue = PCMQueue() # pcm data -> in queue - self.listen_thread = threading.Thread(target=self.listen, args=()) - self.stop_listen = False - self.listen_thread.start() + self.pcm_stat_chunk_queue = ThreadSafeQueue() # pcm data -> in queue self.outputs_queue = outputs_queue @@ -107,29 +134,52 @@ def __init__( # encoder and audio llm self.pipeline = pipeline - # pre status system prompt - self.status = "pre" - self.init_outputs = pipeline.speech_dialogue(None, stat="pre", role=system_prompt) + # pre status system prompt, outputs stat: pre -> sl + GlobalVars.speech_dialogue_outputs = pipeline.speech_dialogue( + None, stat="pre", role=system_prompt + ) # chunck feat history cache also use ring buffer to do :) self.history = torch.zeros( - [cache_history_size, self.chunk_size + self.chunk_overlap, self.feat_dim] + [ + cache_history_size, + self.audio_processor.chunk_size + self.audio_processor.chunk_overlap, + self.audio_processor.feat_dim, + ] ) + # start listen thread + self.listen_thread = threading.Thread(target=self.listen, args=()) + self.stop_listen = False + self.listen_thread.start() + def stop(self): self.stop_listen = True self.listen_thread.join(timeout=3) - def set_status(self, status: str): - if status not in ["sl", "cl", "el"]: - raise ValueError("status must be one of: 'sl', 'cl', 'el'") - self.status = status - - def send(self, pcm_items: np.ndarray): + def send(self, pcm_items: np.ndarray, status: str): """ - send float32(<1) numpy ndarray to fifo queue buffer + 将float32(<1) numpy数组按chunk_size大小切分并发送到FIFO队列缓冲区 + + Args: + pcm_items: 输入的PCM数据数组 """ - self.pcm_fifo_queue.put(pcm_items.astype(np.float32) / 32768.0) + # 获取音频处理器的块大小 + chunk_size = self.audio_processor.get_chunk_size() + + # 按chunk_size大小切分数据 + for i in range(0, len(pcm_items), chunk_size): + chunk = pcm_items[i : i + chunk_size] + # 如果最后一块数据大小不足,则用0填充 + if len(chunk) < chunk_size: + padded_chunk = np.zeros(chunk_size, dtype=np.float32) + padded_chunk[: len(chunk)] = chunk + chunk = padded_chunk + # 将数据标准化到[-1,1]范围并发送到队列 + item = PCMStatChunk( + size=chunk_size, status=status, data=chunk.astype(np.float32) / 32768.0 + ) + self.pcm_stat_chunk_queue.put(item) def history_buffering_strategy(self, input_chunk: torch.Tensor) -> torch.Tensor: # cache fbank feature (input_chunk) @@ -163,7 +213,9 @@ def llm_prefill(self, status: str, feature: torch.Tensor, outputs: dict, is_firs if outputs["stat"] == "cl": # Stage2: continue listen # stat will be auto set to 'ss' when endpoint is detected + print("output stat continue listen") outputs = self.pipeline.speech_dialogue(feature, **outputs) + print(f"speech_dialogue --> output stat {outputs['stat']}") if is_first_pack: outputs["stat"] = "cl" if outputs["stat"] == "el": @@ -175,24 +227,23 @@ def llm_prefill(self, status: str, feature: torch.Tensor, outputs: dict, is_firs def listen(self): """ - - status: sl(start listen) -> cl(continue listen) -> ss(start generate speech to speak) - - status from VAD + chunk status from VAD """ print("Start listening") while True: if self.stop_listen: print("Stop listening") break - e = self.pcm_fifo_queue.get(self.audio_processor.chunk_size) - if e is None: + stat_chunk: PCMStatChunk = self.pcm_stat_chunk_queue.get() + if stat_chunk is None: time.sleep(0.01) continue - print("Received PCM data: ", len(e)) + print(f"Received PCM stat chunk: {stat_chunk}") - fbank_feature = self.audio_processor.process(np.float32(e)) - if self.status == "sl": + fbank_feature = self.audio_processor.process(np.float32(stat_chunk.data)) + if stat_chunk.status == "sl": + outputs = GlobalVars.deepcopy_outputs() feature_last_chunk = self.history_buffering_strategy(fbank_feature) - outputs = deepcopy(self.init_outputs) outputs["adapter_cache"] = None outputs["encoder_cache"] = None outputs["pe_index"] = 0 @@ -214,8 +265,8 @@ def listen(self): ) outputs = self.llm_prefill("cl", fbank_feature, outputs) - elif self.status == "cl" or self.status == "el": - outputs = self.llm_prefill(self.status, fbank_feature, outputs) + elif stat_chunk.status == "cl" or stat_chunk.status == "el": + outputs = self.llm_prefill(stat_chunk.status, fbank_feature, outputs) @dataclass @@ -262,7 +313,6 @@ def __init__( def reset(self): self.stop_speak = False self.is_generate = False - self.generate_outputs = {} self.whole_text = "" self.tts_over = False @@ -279,12 +329,7 @@ def print(self): print("tts_over_time:", self.tts_over_time) @property - def generate_outputs(self): - """Get the generate outputs.""" - return self.generate_outputs - - @property - def whole_text(self): + def gen_text(self): """Get the whole text.""" return self.whole_text @@ -356,7 +401,7 @@ def get_tts_data(self) -> Generator[np.ndarray, None, None]: output_data = self.tts_data.get() if output_data is not None: # print("Get TTS data") - #yield output_data.astype(np.int16).tobytes() + # yield output_data.astype(np.int16).tobytes() yield output_data yield None @@ -377,7 +422,7 @@ def speak(self): self.is_generate = True outputs = self.pipeline.speech_dialogue(None, **outputs) # outputs dict need change, so deepcopy - self.generate_outputs = deepcopy(outputs) + GlobalVars.speech_dialogue_outputs = deepcopy(outputs) cur_hidden_state = [] cur_hidden_state.append(outputs["hidden_state"]) @@ -398,7 +443,7 @@ def speak(self): del outputs["text"] del outputs["hidden_state"] outputs = self.pipeline.speech_dialogue(None, **outputs) - self.generate_outputs = deepcopy(outputs) + GlobalVars.speech_dialogue_outputs = deepcopy(outputs) if outputs["stat"] == "cs": cur_hidden_state.append(outputs["hidden_state"]) if "�" in outputs["text"][len(last_text) :]: @@ -439,12 +484,12 @@ def speak(self): print(self.whole_text) def interrupt(self): - self.stop_generate = True + self.stop_speak = True self.tts_over = True while True: time.sleep(0.01) if self.is_generate is False: - self.stop_generate = False + self.stop_speak = False while True: time.sleep(0.01) if self.tts_data.is_empty(): @@ -477,18 +522,34 @@ def inference_stream(listener: PCMListener, speaker: TTSSpeaker, configs): if fs != 16000: wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav.float()) fs = 16000 - # send numpy ndarray pcm data - listener.send(wav.numpy()) + + chunk_size = list.audio_processor.get_chunk_size() + wav_input = torch.zeros(math.ceil(wav.shape[0] / chunk_size) * chunk_size) + wav_input[: wav.shape[0]] = wav + for i in range(0, wav_input.shape[0], chunk_size): + # send numpy ndarray pcm data + listener.send(wav.numpy()) + listener.set_status("cl") wavs = [] # get tts speak data for data in speaker.get_tts_data(): - wavs.append(torch.tensor(data)) + if data: + print(data.shape) + wavs.append(torch.tensor(data)) sf.write(configs.output_wav, torch.cat(wav, -1).squeeze().float().cpu().numpy(), 24000) + print(f"write to {configs.output_wav}") + + listener.stop() + speaker.stop() if __name__ == "__main__": + # change print function to add time stamp + original_print = builtins.print + builtins.print = custom_print + configs = get_args() # encoder and audio llm diff --git a/bin/server.py b/bin/server.py index 6d3c643..b038372 100644 --- a/bin/server.py +++ b/bin/server.py @@ -227,6 +227,7 @@ def llm_prefill(data, outputs, sid, is_first_pack=False): outputs = connected_users[sid][1].pipeline_obj.pipeline_proc.speech_dialogue( torch.tensor(data['feature']), **outputs) + print(f"sl -> speech_dialogue outputs stat: {outputs['stat']}") if data['status'] == 'el': connected_users[sid][1].wakeup_and_vad.in_dialog = False @@ -239,6 +240,7 @@ def llm_prefill(data, outputs, sid, is_first_pack=False): outputs = connected_users[sid][1].pipeline_obj.pipeline_proc.speech_dialogue( torch.tensor(data['feature']), **outputs) + print(f"cl -> speech_dialogue outputs stat: {outputs['stat']}") if is_first_pack: outputs['stat'] = 'cl' if outputs['stat'] == 'el': @@ -275,6 +277,7 @@ def send_pcm(sid): continue print("Sid: ", sid, " Received PCM data: ", len(e)) res = connected_users[sid][1].wakeup_and_vad.predict(np.float32(e)) + print(f"wakeup_and_vad.predict -> res status: {res['status']}") if res['status'] == 'sl': print("Sid: ", sid, " Vad start") diff --git a/models/audioLLM.py b/models/audioLLM.py index 28d5421..a226315 100644 --- a/models/audioLLM.py +++ b/models/audioLLM.py @@ -403,6 +403,8 @@ def _generate_one_step( top_p: float = 1.0, top_k: int = 0, temperature: float = 1.0, + el_prob: float = 0.5, + ss_prob: float = 0.5, ): """ Generates the model's next output based on the current input and state. @@ -413,6 +415,8 @@ def _generate_one_step( - top_p: The threshold for controlling top-p sampling. - top_k: The threshold for controlling top-k sampling. - temperature: Controls the randomness of sampling. + - el_prob: end listen stat logit prob + - ss_prob: start speak stat logit prob Returns: - last_id: The index of the last generated token. @@ -429,9 +433,9 @@ def _generate_one_step( state_1 = state_prob[1] state_2 = state_prob[2] print("State 1 prob: {:.4f}, State 2 prob: {:.4f}".format(state_1.item(), state_2.item())) - if state_2 > 0.5: + if state_2 > el_prob: return None, outputs['past_key_values'], 'el', None - if state_1 > 0.5: + if state_1 > ss_prob: return None, outputs['past_key_values'], 'ss', None return None, outputs['past_key_values'], 'cl', None diff --git a/web/queue.py b/web/queue.py index 5b85f80..c84c3b6 100644 --- a/web/queue.py +++ b/web/queue.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import json import torch import threading diff --git a/web/vad.py b/web/vad.py index d42db10..8ad8a3e 100644 --- a/web/vad.py +++ b/web/vad.py @@ -22,7 +22,7 @@ class VAD: - 状态2表示无需中断对话。 这两种状态都将停止向 Freeze-Omni 发送语音流并重置 VAD 模块。 """ - def __init__(self, cache_history=10): + def __init__(self, cache_history=10, last_chunk_size=6): self.chunk_size = 16 self.chunk_overlap = 3 self.feat_dim = 80 @@ -30,7 +30,9 @@ def __init__(self, cache_history=10): self.frame_shift = 160 self.frame_overlap = self.frame_size - self.frame_shift self.CHUNK = self.frame_shift * self.chunk_size + assert cache_history >= last_chunk_size, "cache_history must >= last_chunk_size" self.cache_history = cache_history + self.last_chunk_size = last_chunk_size self.in_dialog = False with torch.no_grad(): @@ -131,8 +133,8 @@ def predict(self, # return dict if return_dict['status'] == 'sl': - # copy last 6 chunks - return_dict['feature_last_chunk'] = self.history[-6:].unsqueeze(1).numpy().tolist() + # copy last chunk size chunks + return_dict['feature_last_chunk'] = self.history[-self.last_chunk_size:].unsqueeze(1).numpy().tolist() return_dict['feature'] = self.input_chunk.numpy().tolist() return_dict['history_feature'] = self.history.numpy().tolist() elif return_dict['status'] == 'cl' or return_dict['status'] == 'el': From 8880431b04048d97ae398e99cf64906c9510a55e Mon Sep 17 00:00:00 2001 From: weedge Date: Wed, 18 Dec 2024 14:08:01 +0800 Subject: [PATCH 09/15] feat: add PCMStatChunk for PCMListener Signed-off-by: weedge --- bin/inference_stream.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bin/inference_stream.py b/bin/inference_stream.py index c893674..6143ea7 100644 --- a/bin/inference_stream.py +++ b/bin/inference_stream.py @@ -527,9 +527,11 @@ def inference_stream(listener: PCMListener, speaker: TTSSpeaker, configs): wav_input = torch.zeros(math.ceil(wav.shape[0] / chunk_size) * chunk_size) wav_input[: wav.shape[0]] = wav for i in range(0, wav_input.shape[0], chunk_size): - # send numpy ndarray pcm data - listener.send(wav.numpy()) - listener.set_status("cl") + # send numpy ndarray pcm data with status + status = "sl" + if i > 0: + status = "cl" + listener.send(wav_input[i : i + chunk_size].numpy(), status=status) wavs = [] # get tts speak data From 406caa28b2dfc5b6c6d8d3bdc8356edb3b96a6a7 Mon Sep 17 00:00:00 2001 From: weedge Date: Wed, 18 Dec 2024 17:32:18 +0800 Subject: [PATCH 10/15] change feature_last_chunk feature history_feature use torch.Tensor Signed-off-by: weedge --- bin/inference.py | 4 +++- bin/inference_stream.py | 22 ++++++++++------------ bin/server.py | 4 ++-- web/vad.py | 10 +++++----- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/bin/inference.py b/bin/inference.py index 6ebf163..34ac6bb 100644 --- a/bin/inference.py +++ b/bin/inference.py @@ -182,12 +182,14 @@ def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, decoder(cur_hidden_state, pipeline, cur_text, tts, codec_chunk_size, codec_padding_size, decoder_topk, wav) cur_hidden_state = [] + print(f"cur_text:{cur_text}") cur_text = "" if outputs['stat'] == 'sl': break - # print(outputs['text']) + #print(outputs['text']) last_text = outputs['text'] if len(cur_hidden_state) != 0: + print(f"cur_text:{cur_text}") decoder(cur_hidden_state, pipeline, cur_text, tts, codec_chunk_size, codec_padding_size, decoder_topk, wav) diff --git a/bin/inference_stream.py b/bin/inference_stream.py index 6143ea7..bd7d88d 100644 --- a/bin/inference_stream.py +++ b/bin/inference_stream.py @@ -99,22 +99,20 @@ class GlobalVars: @staticmethod def deepcopy_outputs(): - return deepcopy(GlobalVars.deepcopy_outputs) + return deepcopy(GlobalVars.speech_dialogue_outputs) @dataclass class PCMStatChunk: - size: int status: str # sl(start listen), cl(continue listen), el(end listen) data: np.ndarray # a numpy array of dtype np.float32 def __post_init__(self): - super().__post_init__() if self.status not in ["sl", "cl", "el"]: raise ValueError("status must be one of: 'sl','cl','el'") def __str__(self): - return f"size:{self.size} status:{self.status} numpy data:{self.data.shape} {super().__str__()}" + return f"status:{self.status} numpy data:{self.data.shape}" class PCMListener: @@ -176,9 +174,7 @@ def send(self, pcm_items: np.ndarray, status: str): padded_chunk[: len(chunk)] = chunk chunk = padded_chunk # 将数据标准化到[-1,1]范围并发送到队列 - item = PCMStatChunk( - size=chunk_size, status=status, data=chunk.astype(np.float32) / 32768.0 - ) + item = PCMStatChunk(status=status, data=chunk.astype(np.float32) / 32768.0) self.pcm_stat_chunk_queue.put(item) def history_buffering_strategy(self, input_chunk: torch.Tensor) -> torch.Tensor: @@ -205,7 +201,8 @@ def llm_prefill(self, status: str, feature: torch.Tensor, outputs: dict, is_firs if status == "sl": # Satge1: start listen # stat will be auto set to 'cl' after Stage1 - outputs = self.pipeline.speech_dialogue(feature, **outputs) + outputs = self.pipeline.speech_dialogue(torch.tensor(feature.numpy().tolist()), **outputs) + print(f"sl --> output stat {outputs['stat']}") if status == "el": print("status end listen. Detect vad time out") @@ -214,8 +211,8 @@ def llm_prefill(self, status: str, feature: torch.Tensor, outputs: dict, is_firs # Stage2: continue listen # stat will be auto set to 'ss' when endpoint is detected print("output stat continue listen") - outputs = self.pipeline.speech_dialogue(feature, **outputs) - print(f"speech_dialogue --> output stat {outputs['stat']}") + outputs = self.pipeline.speech_dialogue(torch.tensor(feature.numpy().tolist()), **outputs) + print(f"cl --> output stat {outputs['stat']}") if is_first_pack: outputs["stat"] = "cl" if outputs["stat"] == "el": @@ -254,7 +251,7 @@ def listen(self): if "hidden_state" in outputs: del outputs["hidden_state"] - for i in range(feature_last_chunk): + for i in range(len(feature_last_chunk)): if i == 0: outputs = self.llm_prefill( "sl", feature_last_chunk[i], outputs, is_first_pack=True @@ -523,7 +520,8 @@ def inference_stream(listener: PCMListener, speaker: TTSSpeaker, configs): wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav.float()) fs = 16000 - chunk_size = list.audio_processor.get_chunk_size() + # like io_uring + chunk_size = listener.audio_processor.get_chunk_size() wav_input = torch.zeros(math.ceil(wav.shape[0] / chunk_size) * chunk_size) wav_input[: wav.shape[0]] = wav for i in range(0, wav_input.shape[0], chunk_size): diff --git a/bin/server.py b/bin/server.py index b038372..8d61382 100644 --- a/bin/server.py +++ b/bin/server.py @@ -225,7 +225,7 @@ def llm_prefill(data, outputs, sid, is_first_pack=False): # Satge1: start listen # stat will be auto set to 'cl' after Stage1 outputs = connected_users[sid][1].pipeline_obj.pipeline_proc.speech_dialogue( - torch.tensor(data['feature']), + data['feature'], **outputs) print(f"sl -> speech_dialogue outputs stat: {outputs['stat']}") @@ -238,7 +238,7 @@ def llm_prefill(data, outputs, sid, is_first_pack=False): # Stage2: continue listen # stat will be auto set to 'ss' when endpoint is detected outputs = connected_users[sid][1].pipeline_obj.pipeline_proc.speech_dialogue( - torch.tensor(data['feature']), + data['feature'], **outputs) print(f"cl -> speech_dialogue outputs stat: {outputs['stat']}") if is_first_pack: diff --git a/web/vad.py b/web/vad.py index 8ad8a3e..a869ac5 100644 --- a/web/vad.py +++ b/web/vad.py @@ -134,12 +134,12 @@ def predict(self, # return dict if return_dict['status'] == 'sl': # copy last chunk size chunks - return_dict['feature_last_chunk'] = self.history[-self.last_chunk_size:].unsqueeze(1).numpy().tolist() - return_dict['feature'] = self.input_chunk.numpy().tolist() - return_dict['history_feature'] = self.history.numpy().tolist() + return_dict['feature_last_chunk'] = self.history[-self.last_chunk_size:].unsqueeze(1) + return_dict['feature'] = self.input_chunk + return_dict['history_feature'] = self.history elif return_dict['status'] == 'cl' or return_dict['status'] == 'el': return_dict['feature_last_chunk'] = None - return_dict['feature'] = self.input_chunk.numpy().tolist() - return_dict['history_feature'] = self.history.numpy().tolist() + return_dict['feature'] = self.input_chunk + return_dict['history_feature'] = self.history return return_dict From ebe608aa8f2da506e617ff7e7268e7a03af106f1 Mon Sep 17 00:00:00 2001 From: weedge Date: Wed, 18 Dec 2024 21:45:07 +0800 Subject: [PATCH 11/15] add .gitignore Signed-off-by: weedge --- .gitignore | 4 ++++ bin/inference.py | 12 ------------ 2 files changed, 4 insertions(+), 12 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..749ccda --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class diff --git a/bin/inference.py b/bin/inference.py index 34ac6bb..67d2f14 100644 --- a/bin/inference.py +++ b/bin/inference.py @@ -1,14 +1,7 @@ from __future__ import print_function import argparse -import os -import json -import queue import torch -import yaml -import threading -import struct -import time import torchaudio import datetime import builtins @@ -16,15 +9,10 @@ import soundfile as sf import numpy as np -import torch.nn.functional as F import torchaudio.compliance.kaldi as k -from torch.utils.data import DataLoader - from models.pipeline import inferencePipeline from models.decoder.llm2tts import llm2TTS -from web.parms import GlobalParams -from web.pool import TTSObjectPool def get_args(): parser = argparse.ArgumentParser(description='Freeze-Omni') From 3c8760266b924ca417990033f3c8cd77455ebe1c Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 19 Dec 2024 14:50:53 +0800 Subject: [PATCH 12/15] fix: inference Signed-off-by: weedge --- bin/inference.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/bin/inference.py b/bin/inference.py index 67d2f14..b024902 100644 --- a/bin/inference.py +++ b/bin/inference.py @@ -105,11 +105,18 @@ def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, Returns: - None """ - wav, fs = sf.read(configs.input_wav) - wav = torch.tensor(wav) + wav, fs = torchaudio.load(configs.input_wav) if fs != 16000: - wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav.float()) + wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav) fs = 16000 + wav = wav.reshape(-1) + + #wav, fs = sf.read(configs.input_wav) + #wav = torch.tensor(wav) + #if fs != 16000: + # wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav.float()) + # fs = 16000 + print("--->",wav.shape) codec_chunk_size = 40 codec_padding_size = 10 @@ -126,12 +133,14 @@ def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, wav_input = torch.zeros(math.ceil(wav.shape[0] / chunk_size) * chunk_size) wav_input[:wav.shape[0]] = wav for i in range(0, wav_input.shape[0], chunk_size): + print("--->",wav_input.shape, wav.shape,wav_input[i:i+chunk_size].shape) fbank = audio_processor.process(wav_input[i:i+chunk_size]) outputs = pipeline.speech_dialogue(fbank, **outputs) print(f"speech_dialogue outputs stat: {outputs['stat']}") outputs['stat'] = 'cl' audio_processor.reset() + print("listen",outputs.keys()) outputs['adapter_cache'] = None outputs['encoder_cache'] = None outputs['pe_index'] = 0 From feb2d48705f81a636f5cf64275ee38d9310807f7 Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 19 Dec 2024 20:21:55 +0800 Subject: [PATCH 13/15] fix: add GenTTSFrame and el stat Signed-off-by: weedge --- bin/inference_stream.py | 72 ++++++++++++++++++++++++++++++----------- 1 file changed, 54 insertions(+), 18 deletions(-) diff --git a/bin/inference_stream.py b/bin/inference_stream.py index bd7d88d..3831e42 100644 --- a/bin/inference_stream.py +++ b/bin/inference_stream.py @@ -95,10 +95,14 @@ def process(self, audio: np.ndarray) -> torch.Tensor: class GlobalVars: + """ + multi turn conversation interal speech dialogue outputs for listen and speak + """ + speech_dialogue_outputs = {} @staticmethod - def deepcopy_outputs(): + def deepcopy_outputs(): # cp for write return deepcopy(GlobalVars.speech_dialogue_outputs) @@ -112,7 +116,16 @@ def __post_init__(self): raise ValueError("status must be one of: 'sl','cl','el'") def __str__(self): - return f"status:{self.status} numpy data:{self.data.shape}" + return f"status:{self.status} numpy data:{self.data.shape if self.data else None}" + + +@dataclass +class GenTTSFrame: + text: str + data: np.ndarray # a numpy array of dtype np.float32 + + def __str__(self): + return f"text:{self.text} numpy data:{self.data.shape if self.data else None}" class PCMListener: @@ -155,13 +168,17 @@ def stop(self): self.stop_listen = True self.listen_thread.join(timeout=3) - def send(self, pcm_items: np.ndarray, status: str): + def send(self, pcm_items: np.ndarray | None, status: str): """ 将float32(<1) numpy数组按chunk_size大小切分并发送到FIFO队列缓冲区 Args: pcm_items: 输入的PCM数据数组 """ + if pcm_items is None: + item = PCMStatChunk(status=status, data=None) + self.pcm_stat_chunk_queue.put(item) + # 获取音频处理器的块大小 chunk_size = self.audio_processor.get_chunk_size() @@ -201,17 +218,23 @@ def llm_prefill(self, status: str, feature: torch.Tensor, outputs: dict, is_firs if status == "sl": # Satge1: start listen # stat will be auto set to 'cl' after Stage1 - outputs = self.pipeline.speech_dialogue(torch.tensor(feature.numpy().tolist()), **outputs) + outputs = self.pipeline.speech_dialogue( + torch.tensor(feature.numpy().tolist()), **outputs + ) print(f"sl --> output stat {outputs['stat']}") + return outputs if status == "el": - print("status end listen. Detect vad time out") + print("status end listen. start to speak") + return outputs if status == "cl": if outputs["stat"] == "cl": # Stage2: continue listen # stat will be auto set to 'ss' when endpoint is detected print("output stat continue listen") - outputs = self.pipeline.speech_dialogue(torch.tensor(feature.numpy().tolist()), **outputs) + outputs = self.pipeline.speech_dialogue( + torch.tensor(feature.numpy().tolist()), **outputs + ) print(f"cl --> output stat {outputs['stat']}") if is_first_pack: outputs["stat"] = "cl" @@ -237,8 +260,8 @@ def listen(self): continue print(f"Received PCM stat chunk: {stat_chunk}") - fbank_feature = self.audio_processor.process(np.float32(stat_chunk.data)) if stat_chunk.status == "sl": + fbank_feature = self.audio_processor.process(np.float32(stat_chunk.data)) outputs = GlobalVars.deepcopy_outputs() feature_last_chunk = self.history_buffering_strategy(fbank_feature) outputs["adapter_cache"] = None @@ -260,10 +283,21 @@ def listen(self): outputs = self.llm_prefill( "cl", feature_last_chunk[i], outputs, is_first_pack=True ) - outputs = self.llm_prefill("cl", fbank_feature, outputs) - - elif stat_chunk.status == "cl" or stat_chunk.status == "el": - outputs = self.llm_prefill(stat_chunk.status, fbank_feature, outputs) + GlobalVars.speech_dialogue_outputs = self.llm_prefill("cl", fbank_feature, outputs) + + elif stat_chunk.status == "cl": + fbank_feature = self.audio_processor.process(np.float32(stat_chunk.data)) + GlobalVars.speech_dialogue_outputs = self.llm_prefill( + stat_chunk.status, fbank_feature, GlobalVars.deepcopy_outputs() + ) + elif stat_chunk.status == "el": + outputs = GlobalVars.deepcopy_outputs() + outputs["adapter_cache"] = None + outputs["encoder_cache"] = None + outputs["pe_index"] = 0 + outputs["stat"] = "ss" + outputs["last_id"] = None + self.outputs_queue.put(outputs) @dataclass @@ -384,10 +418,11 @@ def decoder( self.tts_data.clear() self.whole_text = "" break - self.tts_data.put(seg.squeeze().float().cpu().numpy() * 32768) + frame = GenTTSFrame(text=cur_text, data=seg.squeeze().float().cpu().numpy() * 32768) + self.tts_data.put(frame) return generate_num - def get_tts_data(self) -> Generator[np.ndarray, None, None]: + def get_tts_data(self) -> Generator[GenTTSFrame, None, None]: """ get tts bytes data """ @@ -529,14 +564,15 @@ def inference_stream(listener: PCMListener, speaker: TTSSpeaker, configs): status = "sl" if i > 0: status = "cl" - listener.send(wav_input[i : i + chunk_size].numpy(), status=status) + listener.send(wav_input[i : i + chunk_size].numpy(), status) + listener.send(None, "el") wavs = [] # get tts speak data - for data in speaker.get_tts_data(): - if data: - print(data.shape) - wavs.append(torch.tensor(data)) + for item in speaker.get_tts_data(): + if item: + print(item) + wavs.append(torch.tensor(item.data)) sf.write(configs.output_wav, torch.cat(wav, -1).squeeze().float().cpu().numpy(), 24000) print(f"write to {configs.output_wav}") From 6882170fefaf22dc019dfc07cac6973366b8d467 Mon Sep 17 00:00:00 2001 From: weedge Date: Thu, 19 Dec 2024 22:16:35 +0800 Subject: [PATCH 14/15] fix: get tts speak data is none sleep 0.01 to yield thread from cpu Signed-off-by: weedge --- bin/inference_stream.py | 56 ++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/bin/inference_stream.py b/bin/inference_stream.py index 3831e42..bae8be7 100644 --- a/bin/inference_stream.py +++ b/bin/inference_stream.py @@ -116,7 +116,9 @@ def __post_init__(self): raise ValueError("status must be one of: 'sl','cl','el'") def __str__(self): - return f"status:{self.status} numpy data:{self.data.shape if self.data else None}" + return ( + f"status:{self.status} numpy data:{self.data.shape if self.data is not None else None}" + ) @dataclass @@ -125,7 +127,7 @@ class GenTTSFrame: data: np.ndarray # a numpy array of dtype np.float32 def __str__(self): - return f"text:{self.text} numpy data:{self.data.shape if self.data else None}" + return f"text:{self.text} numpy data:{self.data.shape if self.data is not None else None}" class PCMListener: @@ -178,21 +180,21 @@ def send(self, pcm_items: np.ndarray | None, status: str): if pcm_items is None: item = PCMStatChunk(status=status, data=None) self.pcm_stat_chunk_queue.put(item) - - # 获取音频处理器的块大小 - chunk_size = self.audio_processor.get_chunk_size() - - # 按chunk_size大小切分数据 - for i in range(0, len(pcm_items), chunk_size): - chunk = pcm_items[i : i + chunk_size] - # 如果最后一块数据大小不足,则用0填充 - if len(chunk) < chunk_size: - padded_chunk = np.zeros(chunk_size, dtype=np.float32) - padded_chunk[: len(chunk)] = chunk - chunk = padded_chunk - # 将数据标准化到[-1,1]范围并发送到队列 - item = PCMStatChunk(status=status, data=chunk.astype(np.float32) / 32768.0) - self.pcm_stat_chunk_queue.put(item) + else: + # 获取音频处理器的块大小 + chunk_size = self.audio_processor.get_chunk_size() + + # 按chunk_size大小切分数据 + for i in range(0, len(pcm_items), chunk_size): + chunk = pcm_items[i : i + chunk_size] + # 如果最后一块数据大小不足,则用0填充 + if len(chunk) < chunk_size: + padded_chunk = np.zeros(chunk_size, dtype=np.float32) + padded_chunk[: len(chunk)] = chunk + chunk = padded_chunk + # 将数据标准化到[-1,1]范围并发送到队列 + item = PCMStatChunk(status=status, data=(chunk.astype(np.float32) / 32768.0)) + self.pcm_stat_chunk_queue.put(item) def history_buffering_strategy(self, input_chunk: torch.Tensor) -> torch.Tensor: # cache fbank feature (input_chunk) @@ -297,6 +299,7 @@ def listen(self): outputs["pe_index"] = 0 outputs["stat"] = "ss" outputs["last_id"] = None + print("end listen put outputs") self.outputs_queue.put(outputs) @@ -417,8 +420,9 @@ def decoder( if self.tts_over: self.tts_data.clear() self.whole_text = "" + self.tts_data.put(GenTTSFrame(text="", data=None)) break - frame = GenTTSFrame(text=cur_text, data=seg.squeeze().float().cpu().numpy() * 32768) + frame = GenTTSFrame(text=cur_text, data=seg.squeeze().float().cpu().numpy()) self.tts_data.put(frame) return generate_num @@ -435,8 +439,8 @@ def get_tts_data(self) -> Generator[GenTTSFrame, None, None]: # print("Get TTS data") # yield output_data.astype(np.int16).tobytes() yield output_data - - yield None + else: + yield None def speak(self): """ @@ -514,6 +518,7 @@ def speak(self): outputs["stat"] = "sl" outputs["last_id"] = None print(self.whole_text) + self.tts_data.put(GenTTSFrame(text="", data=None)) def interrupt(self): self.stop_speak = True @@ -572,9 +577,14 @@ def inference_stream(listener: PCMListener, speaker: TTSSpeaker, configs): for item in speaker.get_tts_data(): if item: print(item) - wavs.append(torch.tensor(item.data)) - - sf.write(configs.output_wav, torch.cat(wav, -1).squeeze().float().cpu().numpy(), 24000) + if item.data is None: + break + else: + wavs.append(item.data) + else: + time.sleep(0.01) # yield thread + + sf.write(configs.output_wav, np.concatenate(wavs, -1), 24000) print(f"write to {configs.output_wav}") listener.stop() From 50021e25bf3fdc8c2694e8c9cd81592da00330a0 Mon Sep 17 00:00:00 2001 From: weedge Date: Fri, 20 Dec 2024 15:41:53 +0800 Subject: [PATCH 15/15] fix: PCMListener prefill Signed-off-by: weedge --- bin/inference.py | 17 ++- bin/inference_stream.py | 227 +++++++++++++++------------------------- models/pipeline.py | 2 +- models/utils.py | 15 ++- 4 files changed, 115 insertions(+), 146 deletions(-) diff --git a/bin/inference.py b/bin/inference.py index b024902..9a2673d 100644 --- a/bin/inference.py +++ b/bin/inference.py @@ -11,6 +11,7 @@ import numpy as np import torchaudio.compliance.kaldi as k +from models.utils import print_outputs from models.pipeline import inferencePipeline from models.decoder.llm2tts import llm2TTS @@ -62,14 +63,14 @@ def chunk_data_shift(self, xs): self.input_chunk[:, self.chunk_overlap:, :] = xs.squeeze(0) def process(self, - audio: np.ndarray): + audio: torch.Tensor): """ # 1. Converts the input audio tensor to the appropriate format. # 2. Computes the filter bank features (fbank) for the audio. # 3. Updates the input chunk and history based on the new audio segment. """ with torch.no_grad(): - sample_data = torch.tensor(audio).reshape(1, -1, 1)[:, :, :1] * 32768 + sample_data = audio.clone().reshape(1, -1, 1)[:, :, :1] * 32768 self.fbank_shift(sample_data) # use kaldi api to compute fbank xs = k.fbank(waveform = self.input_sample.squeeze(-1), dither=0, @@ -126,6 +127,7 @@ def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, # set system role, stat will be set to 'sl' stat = 'pre' outputs = pipeline.speech_dialogue(None, stat=stat, role="You are a helpful assistant.") + print(f"pre-> outputs:[{print_outputs(outputs)}]") chunk_size = audio_processor.get_chunk_size() # Satge1: start listen @@ -134,20 +136,28 @@ def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, wav_input[:wav.shape[0]] = wav for i in range(0, wav_input.shape[0], chunk_size): print("--->",wav_input.shape, wav.shape,wav_input[i:i+chunk_size].shape) + print(f"cl in-> outputs:{print_outputs(outputs)}") + if outputs['stat'] =="sl": + print(f"stat_chunk data:{wav_input[i:i+chunk_size]}") fbank = audio_processor.process(wav_input[i:i+chunk_size]) + if outputs['stat'] =="sl": + print(f"fbank:{fbank}") outputs = pipeline.speech_dialogue(fbank, **outputs) - print(f"speech_dialogue outputs stat: {outputs['stat']}") + print(f"cl out-> outputs:{print_outputs(outputs)}") outputs['stat'] = 'cl' audio_processor.reset() print("listen",outputs.keys()) + print(f"listen-> outputs:[{print_outputs(outputs)}]") outputs['adapter_cache'] = None outputs['encoder_cache'] = None outputs['pe_index'] = 0 outputs['stat'] = 'ss' + print(f"speak get-> outputs:[{print_outputs(outputs)}]") # Stage3: start speak outputs = pipeline.speech_dialogue(None, **outputs) + print(f"ss-> outputs:[{print_outputs(outputs)}]") cur_hidden_state = [] cur_hidden_state.append(outputs['hidden_state']) @@ -166,6 +176,7 @@ def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, del outputs['text'] del outputs['hidden_state'] outputs = pipeline.speech_dialogue(None, **outputs) + print(f"sc-> outputs:[{print_outputs(outputs)}]") if outputs['stat'] == 'cs': cur_hidden_state.append(outputs['hidden_state']) whole_text += outputs['text'][len(last_text):] diff --git a/bin/inference_stream.py b/bin/inference_stream.py index bae8be7..1c76c52 100644 --- a/bin/inference_stream.py +++ b/bin/inference_stream.py @@ -16,6 +16,8 @@ import torchaudio import torchaudio.compliance.kaldi as k +from bin.inference import audioEncoderProcessor +from models.utils import print_outputs from web.queue import PCMQueue, ThreadSafeQueue from models.pipeline import inferencePipeline from models.decoder.llm2tts import llm2TTS @@ -41,59 +43,6 @@ def get_args(): return args -class audioEncoderProcessor: - def __init__(self, chunk_size=16): - self.chunk_size = 16 - self.chunk_overlap = 3 - self.feat_dim = 80 - self.frame_size = 400 - self.frame_shift = 160 - self.frame_overlap = self.frame_size - self.frame_shift - self.CHUNK = self.frame_shift * self.chunk_size - self.reset() - - def get_chunk_size(self): - return self.CHUNK - - def reset(self): - self.input_chunk = torch.zeros([1, self.chunk_size + self.chunk_overlap, self.feat_dim]) - self.input_sample = torch.zeros([1, self.CHUNK + self.frame_overlap, 1]) - - def fbank_shift(self, sample_data): - # fbank feature shift - self.input_sample[:, : self.frame_overlap, :] = self.input_sample[ - :, -self.frame_overlap :, : - ].clone() - self.input_sample[:, self.frame_overlap :, :] = sample_data - - def chunk_data_shift(self, xs): - # chunk feature shift - self.input_chunk[:, : self.chunk_overlap, :] = self.input_chunk[ - :, -self.chunk_overlap :, : - ].clone() - self.input_chunk[:, self.chunk_overlap :, :] = xs.squeeze(0) - - def process(self, audio: np.ndarray) -> torch.Tensor: - """ - # 1. Converts the input audio tensor to the appropriate format. - # 2. Computes the filter bank features (fbank) for the audio. - # 3. Updates the input chunk and history based on the new audio segment. - """ - with torch.no_grad(): - sample_data = torch.tensor(audio).reshape(1, -1, 1)[:, :, :1] * 32768 - self.fbank_shift(sample_data) - # use kaldi api to compute fbank - xs = k.fbank( - waveform=self.input_sample.squeeze(-1), - dither=0, - frame_length=25, - frame_shift=10, - num_mel_bins=self.feat_dim, - ) - self.chunk_data_shift(xs) - return self.input_chunk.clone() - - class GlobalVars: """ multi turn conversation interal speech dialogue outputs for listen and speak @@ -109,7 +58,7 @@ def deepcopy_outputs(): # cp for write @dataclass class PCMStatChunk: status: str # sl(start listen), cl(continue listen), el(end listen) - data: np.ndarray # a numpy array of dtype np.float32 + data: torch.Tensor # chunk tensor def __post_init__(self): if self.status not in ["sl", "cl", "el"]: @@ -117,10 +66,22 @@ def __post_init__(self): def __str__(self): return ( - f"status:{self.status} numpy data:{self.data.shape if self.data is not None else None}" + f"status:{self.status} data shape:{self.data.shape if self.data is not None else None}" ) +@dataclass +class Listen2SpeakFrame: + """ + outputs have stat, so need pass it from listen to speak + """ + + outputs: dict + + def __str__(self): + return f"outputs:{print_outputs(self.outputs)}" + + @dataclass class GenTTSFrame: text: str @@ -148,9 +109,9 @@ def __init__( self.pipeline = pipeline # pre status system prompt, outputs stat: pre -> sl - GlobalVars.speech_dialogue_outputs = pipeline.speech_dialogue( - None, stat="pre", role=system_prompt - ) + self.outputs = self.pipeline.speech_dialogue(None, stat="pre", role=system_prompt) + # print(f"pre-> outputs:[{print_outputs(self.outputs)}]") + GlobalVars.speech_dialogue_outputs = self.outputs # chunck feat history cache also use ring buffer to do :) self.history = torch.zeros( @@ -170,31 +131,13 @@ def stop(self): self.stop_listen = True self.listen_thread.join(timeout=3) - def send(self, pcm_items: np.ndarray | None, status: str): - """ - 将float32(<1) numpy数组按chunk_size大小切分并发送到FIFO队列缓冲区 - - Args: - pcm_items: 输入的PCM数据数组 - """ + def send(self, pcm_items: torch.Tensor | None, status: str): if pcm_items is None: item = PCMStatChunk(status=status, data=None) self.pcm_stat_chunk_queue.put(item) else: - # 获取音频处理器的块大小 - chunk_size = self.audio_processor.get_chunk_size() - - # 按chunk_size大小切分数据 - for i in range(0, len(pcm_items), chunk_size): - chunk = pcm_items[i : i + chunk_size] - # 如果最后一块数据大小不足,则用0填充 - if len(chunk) < chunk_size: - padded_chunk = np.zeros(chunk_size, dtype=np.float32) - padded_chunk[: len(chunk)] = chunk - chunk = padded_chunk - # 将数据标准化到[-1,1]范围并发送到队列 - item = PCMStatChunk(status=status, data=(chunk.astype(np.float32) / 32768.0)) - self.pcm_stat_chunk_queue.put(item) + item = PCMStatChunk(status=status, data=pcm_items) + self.pcm_stat_chunk_queue.put(item) def history_buffering_strategy(self, input_chunk: torch.Tensor) -> torch.Tensor: # cache fbank feature (input_chunk) @@ -220,31 +163,25 @@ def llm_prefill(self, status: str, feature: torch.Tensor, outputs: dict, is_firs if status == "sl": # Satge1: start listen # stat will be auto set to 'cl' after Stage1 - outputs = self.pipeline.speech_dialogue( - torch.tensor(feature.numpy().tolist()), **outputs - ) - print(f"sl --> output stat {outputs['stat']}") + outputs = self.pipeline.speech_dialogue(feature, **outputs) return outputs if status == "el": print("status end listen. start to speak") return outputs if status == "cl": - if outputs["stat"] == "cl": + if outputs["stat"] == "cl" or outputs["stat"] == "sl": # Stage2: continue listen # stat will be auto set to 'ss' when endpoint is detected - print("output stat continue listen") - outputs = self.pipeline.speech_dialogue( - torch.tensor(feature.numpy().tolist()), **outputs - ) - print(f"cl --> output stat {outputs['stat']}") + # print("output stat continue listen") + outputs = self.pipeline.speech_dialogue(feature, **outputs) if is_first_pack: outputs["stat"] = "cl" if outputs["stat"] == "el": print("output stat end listen. Detect invalid break") if outputs["stat"] == "ss": - print("output stat start speak. send outputs to queue") - self.outputs_queue.put(outputs) + # print(f"start speak. start to speak") + pass return outputs def listen(self): @@ -260,47 +197,54 @@ def listen(self): if stat_chunk is None: time.sleep(0.01) continue - print(f"Received PCM stat chunk: {stat_chunk}") + # print(f"Received PCM stat chunk: {stat_chunk}") + + # if self.outputs['stat'] =="sl": + # print(f"stat_chunk data:{stat_chunk.data}") if stat_chunk.status == "sl": - fbank_feature = self.audio_processor.process(np.float32(stat_chunk.data)) - outputs = GlobalVars.deepcopy_outputs() - feature_last_chunk = self.history_buffering_strategy(fbank_feature) - outputs["adapter_cache"] = None - outputs["encoder_cache"] = None - outputs["pe_index"] = 0 - outputs["stat"] = "sl" - outputs["last_id"] = None - if "text" in outputs: - del outputs["text"] - if "hidden_state" in outputs: - del outputs["hidden_state"] - - for i in range(len(feature_last_chunk)): - if i == 0: - outputs = self.llm_prefill( - "sl", feature_last_chunk[i], outputs, is_first_pack=True - ) - else: - outputs = self.llm_prefill( - "cl", feature_last_chunk[i], outputs, is_first_pack=True - ) - GlobalVars.speech_dialogue_outputs = self.llm_prefill("cl", fbank_feature, outputs) + fbank_feature = self.audio_processor.process(stat_chunk.data) + self.outputs = ( + GlobalVars.deepcopy_outputs() + ) # for next turn conversation outputs from speak where in gloab var + # print(f"sl-> outputs:{print_outputs(self.outputs)}") + self.outputs["adapter_cache"] = None + self.outputs["encoder_cache"] = None + self.outputs["pe_index"] = 0 + self.outputs["stat"] = "sl" + self.outputs["last_id"] = None + if "text" in self.outputs: + del self.outputs["text"] + if "hidden_state" in self.outputs: + del self.outputs["hidden_state"] + + self.outputs = self.llm_prefill("sl", fbank_feature, self.outputs) elif stat_chunk.status == "cl": - fbank_feature = self.audio_processor.process(np.float32(stat_chunk.data)) - GlobalVars.speech_dialogue_outputs = self.llm_prefill( - stat_chunk.status, fbank_feature, GlobalVars.deepcopy_outputs() - ) + # print(f"cl in-> outputs:{print_outputs(self.outputs)}") + fbank_feature = self.audio_processor.process(stat_chunk.data) + # if self.outputs['stat'] =="sl": + # print(f"fbank:{fbank_feature}") + self.outputs = self.llm_prefill(stat_chunk.status, fbank_feature, self.outputs) + # print(f"cl out-> outputs:{print_outputs(self.outputs)}") + if self.outputs["stat"] == "ss": + pass + else: + self.outputs["stat"] = "cl" elif stat_chunk.status == "el": - outputs = GlobalVars.deepcopy_outputs() - outputs["adapter_cache"] = None - outputs["encoder_cache"] = None - outputs["pe_index"] = 0 - outputs["stat"] = "ss" - outputs["last_id"] = None - print("end listen put outputs") - self.outputs_queue.put(outputs) + self.audio_processor.reset() + + # print(f"el-> outputs:[{print_outputs(self.outputs)}]") + self.outputs["adapter_cache"] = None + self.outputs["encoder_cache"] = None + self.outputs["pe_index"] = 0 + self.outputs["last_id"] = None + self.outputs["stat"] = "ss" + + if self.outputs["stat"] == "ss": + frame = Listen2SpeakFrame(self.outputs) + # print(f"start to speak, send frame:[{frame}]") + self.outputs_queue.put(frame) @dataclass @@ -450,15 +394,17 @@ def speak(self): if self.stop_speak: print("Stop speak") break - outputs = self.outputs_in_queue.get() - if outputs is None: + frame: Listen2SpeakFrame = self.outputs_in_queue.get() + if frame is None: time.sleep(0.01) continue + + outputs = frame.outputs + # print(f"speak get-> outputs:[{print_outputs(outputs)}]") # Stage3: start speak self.is_generate = True outputs = self.pipeline.speech_dialogue(None, **outputs) - # outputs dict need change, so deepcopy - GlobalVars.speech_dialogue_outputs = deepcopy(outputs) + # print(f"ss-> outputs:[{print_outputs(outputs)}]") cur_hidden_state = [] cur_hidden_state.append(outputs["hidden_state"]) @@ -479,7 +425,7 @@ def speak(self): del outputs["text"] del outputs["hidden_state"] outputs = self.pipeline.speech_dialogue(None, **outputs) - GlobalVars.speech_dialogue_outputs = deepcopy(outputs) + # print(f"sc-> outputs:[{print_outputs(outputs)}]") if outputs["stat"] == "cs": cur_hidden_state.append(outputs["hidden_state"]) if "�" in outputs["text"][len(last_text) :]: @@ -517,6 +463,7 @@ def speak(self): self.is_generate = False outputs["stat"] = "sl" outputs["last_id"] = None + GlobalVars.speech_dialogue_outputs = deepcopy(outputs) print(self.whole_text) self.tts_data.put(GenTTSFrame(text="", data=None)) @@ -565,11 +512,12 @@ def inference_stream(listener: PCMListener, speaker: TTSSpeaker, configs): wav_input = torch.zeros(math.ceil(wav.shape[0] / chunk_size) * chunk_size) wav_input[: wav.shape[0]] = wav for i in range(0, wav_input.shape[0], chunk_size): - # send numpy ndarray pcm data with status - status = "sl" - if i > 0: - status = "cl" - listener.send(wav_input[i : i + chunk_size].numpy(), status) + print("--->", wav_input.shape, wav.shape, wav_input[i : i + chunk_size].shape) + status = "cl" + if i == 0: + status = "sl" + # send pcm data with status + listener.send(wav_input[i : i + chunk_size], status) listener.send(None, "el") wavs = [] @@ -582,8 +530,8 @@ def inference_stream(listener: PCMListener, speaker: TTSSpeaker, configs): else: wavs.append(item.data) else: - time.sleep(0.01) # yield thread - + time.sleep(0.01) # yield thread + sf.write(configs.output_wav, np.concatenate(wavs, -1), 24000) print(f"write to {configs.output_wav}") @@ -602,7 +550,6 @@ def inference_stream(listener: PCMListener, speaker: TTSSpeaker, configs): pipeline = inferencePipeline(configs) # decoder tts = llm2TTS(configs.model_path) - # listen -> gen_queue -> speak gen_queue = ThreadSafeQueue() # listen diff --git a/models/pipeline.py b/models/pipeline.py index 4f0fad3..f17d17b 100644 --- a/models/pipeline.py +++ b/models/pipeline.py @@ -23,7 +23,7 @@ def __init__(self, args): self.model.eval() def speech_dialogue(self, - audio: tuple, + audio: torch.Tensor, role: str=None, stat: str='sl', past_key_values=None, diff --git a/models/utils.py b/models/utils.py index 28bd6aa..6e11596 100644 --- a/models/utils.py +++ b/models/utils.py @@ -1,11 +1,11 @@ -import torch import re import os import yaml +import torch +import numpy as np from models.audioLLM import AudioLLM - from models.encoder.cmvn import GlobalCMVN, load_cmvn from models.encoder.encoder import speechEncoder @@ -55,3 +55,14 @@ def init_encoder_llm(configs): model = AudioLLM(encoder=encoder, **configs['model_conf']) return model + +def print_outputs(outputs: dict): + print_str = "" + for key, item in outputs.items(): + if isinstance(item, (torch.Tensor, np.ndarray)): + print_str += f"{key} shape:{item.shape} " + if isinstance(item, (str, int)): + print_str += f"{key}:{item} " + if isinstance(item, (list, tuple)): + print_str += f"{key} len:{len(item)} " + return print_str \ No newline at end of file