Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
67 changes: 49 additions & 18 deletions bin/inference.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,19 @@
from __future__ import print_function

import argparse
import os
import json
import queue
import torch
import yaml
import threading
import struct
import time
import torchaudio
import datetime
import builtins
import math

import soundfile as sf
import numpy as np
import torch.nn.functional as F
import torchaudio.compliance.kaldi as k

from torch.utils.data import DataLoader

from models.utils import print_outputs
from models.pipeline import inferencePipeline
from models.decoder.llm2tts import llm2TTS
from web.parms import GlobalParams
from web.pool import TTSObjectPool

def get_args():
parser = argparse.ArgumentParser(description='Freeze-Omni')
Expand All @@ -40,6 +29,11 @@ def get_args():
print(args)
return args

def custom_print(*args, **kwargs):
    """Print with a leading ``[YYYY-MM-DD HH:MM:SS.mmm]`` timestamp.

    Accepts the same positional and keyword arguments as the built-in
    ``print`` and forwards them unchanged after the timestamp prefix.
    Relies on the module-level name ``original_print`` being bound to the
    real print function before this is called.
    """
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
    # %f yields six microsecond digits; dropping the last three leaves
    # millisecond precision.
    prefix = '[' + stamp[:-3] + ']'
    original_print(prefix, *args, **kwargs)


class audioEncoderProcessor:
def __init__(self, chunk_size = 16):
self.chunk_size = 16
Expand Down Expand Up @@ -70,8 +64,13 @@ def chunk_data_shift(self, xs):

def process(self,
audio: torch.Tensor):
"""
# 1. Converts the input audio tensor to the appropriate format.
# 2. Computes the filter bank features (fbank) for the audio.
# 3. Updates the input chunk and history based on the new audio segment.
"""
with torch.no_grad():
sample_data = torch.tensor(audio).reshape(1, -1, 1)[:, :, :1] * 32768
sample_data = audio.clone().reshape(1, -1, 1)[:, :, :1] * 32768
self.fbank_shift(sample_data)
# use kaldi api to compute fbank
xs = k.fbank(waveform = self.input_sample.squeeze(-1), dither=0,
Expand All @@ -80,6 +79,9 @@ def process(self,
return self.input_chunk.clone()

def decoder(cur_hidden_state, pipeline, cur_text, tts, codec_chunk_size, codec_padding_size, decoder_topk, wav):
"""
Decodes the current hidden state and text to generate audio segments using speech decoder.
"""
hidden_state_output = torch.cat(cur_hidden_state).squeeze(1)
cur_text_procced = pipeline.post_process(cur_text)
print("Synthesis: ", [cur_text_procced])
Expand All @@ -91,7 +93,7 @@ def decoder(cur_hidden_state, pipeline, cur_text, tts, codec_chunk_size, codec_p
codec_chunk_size, codec_padding_size):
wav.append(seg)

def inference(pipeline, audio_processor, tts, configs):
def inference(pipeline:inferencePipeline, audio_processor:audioEncoderProcessor, tts:llm2TTS, configs):
"""
Perform inference for a speech dialogue system.

Expand All @@ -104,11 +106,18 @@ def inference(pipeline, audio_processor, tts, configs):
Returns:
- None
"""
wav, fs = sf.read(configs.input_wav)
wav = torch.tensor(wav)
wav, fs = torchaudio.load(configs.input_wav)
if fs != 16000:
wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav.float())
wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav)
fs = 16000
wav = wav.reshape(-1)

#wav, fs = sf.read(configs.input_wav)
#wav = torch.tensor(wav)
#if fs != 16000:
# wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav.float())
# fs = 16000
print("--->",wav.shape)

codec_chunk_size = 40
codec_padding_size = 10
Expand All @@ -118,25 +127,37 @@ def inference(pipeline, audio_processor, tts, configs):
# set system role, stat will be set to 'sl'
stat = 'pre'
outputs = pipeline.speech_dialogue(None, stat=stat, role="You are a helpful assistant.")
print(f"pre-> outputs:[{print_outputs(outputs)}]")
chunk_size = audio_processor.get_chunk_size()

# Stage1: start listen
# stat will be auto set to 'cl' after Stage1
wav_input = torch.zeros(math.ceil(wav.shape[0] / chunk_size) * chunk_size)
wav_input[:wav.shape[0]] = wav
for i in range(0, wav_input.shape[0], chunk_size):
print("--->",wav_input.shape, wav.shape,wav_input[i:i+chunk_size].shape)
print(f"cl in-> outputs:{print_outputs(outputs)}")
if outputs['stat'] =="sl":
print(f"stat_chunk data:{wav_input[i:i+chunk_size]}")
fbank = audio_processor.process(wav_input[i:i+chunk_size])
if outputs['stat'] =="sl":
print(f"fbank:{fbank}")
outputs = pipeline.speech_dialogue(fbank, **outputs)
print(f"cl out-> outputs:{print_outputs(outputs)}")
outputs['stat'] = 'cl'
audio_processor.reset()

print("listen",outputs.keys())
print(f"listen-> outputs:[{print_outputs(outputs)}]")
outputs['adapter_cache'] = None
outputs['encoder_cache'] = None
outputs['pe_index'] = 0
outputs['stat'] = 'ss'
print(f"speak get-> outputs:[{print_outputs(outputs)}]")

# Stage3: start speak
outputs = pipeline.speech_dialogue(None, **outputs)
print(f"ss-> outputs:[{print_outputs(outputs)}]")
cur_hidden_state = []
cur_hidden_state.append(outputs['hidden_state'])

Expand All @@ -155,6 +176,7 @@ def inference(pipeline, audio_processor, tts, configs):
del outputs['text']
del outputs['hidden_state']
outputs = pipeline.speech_dialogue(None, **outputs)
print(f"sc-> outputs:[{print_outputs(outputs)}]")
if outputs['stat'] == 'cs':
cur_hidden_state.append(outputs['hidden_state'])
whole_text += outputs['text'][len(last_text):]
Expand All @@ -168,12 +190,14 @@ def inference(pipeline, audio_processor, tts, configs):
decoder(cur_hidden_state, pipeline, cur_text, tts,
codec_chunk_size, codec_padding_size, decoder_topk, wav)
cur_hidden_state = []
print(f"cur_text:{cur_text}")
cur_text = ""
if outputs['stat'] == 'sl':
break
# print(outputs['text'])
#print(outputs['text'])
last_text = outputs['text']
if len(cur_hidden_state) != 0:
print(f"cur_text:{cur_text}")
decoder(cur_hidden_state, pipeline, cur_text, tts,
codec_chunk_size, codec_padding_size, decoder_topk, wav)

Expand All @@ -183,8 +207,15 @@ def inference(pipeline, audio_processor, tts, configs):
print(whole_text)

if __name__ == '__main__':
    # Keep a handle on the stock print, then install custom_print in its
    # place so every print() call is prefixed with a timestamp.
    original_print = builtins.print
    builtins.print = custom_print

    configs = get_args()

    # Encoder and audio LLM.
    pipeline = inferencePipeline(configs)
    # Speech decoder.
    tts = llm2TTS(configs.model_path)
    # Feeds audio to the encoder chunk by chunk.
    audio_processor = audioEncoderProcessor()

    inference(
        pipeline=pipeline,
        audio_processor=audio_processor,
        tts=tts,
        configs=configs,
    )
Loading