diff --git a/cosyvoice/bin/average_model.py b/cosyvoice/bin/average_model.py index d095dcd99..81886b2da 100644 --- a/cosyvoice/bin/average_model.py +++ b/cosyvoice/bin/average_model.py @@ -18,7 +18,7 @@ import glob import yaml -import torch +from torch import load,device,true_divide,save def get_args(): @@ -73,7 +73,7 @@ def main(): assert num == len(path_list) for path in path_list: print('Processing {}'.format(path)) - states = torch.load(path, map_location=torch.device('cpu')) + states = load(path, map_location=device('cpu')) for k in states.keys(): if k not in avg.keys(): avg[k] = states[k].clone() @@ -83,9 +83,9 @@ def main(): for k in avg.keys(): if avg[k] is not None: # pytorch 1.6 use true_divide instead of /= - avg[k] = torch.true_divide(avg[k], num) + avg[k] = true_divide(avg[k], num) print('Saving to {}'.format(args.dst_model)) - torch.save(avg, args.dst_model) + save(avg, args.dst_model) if __name__ == '__main__': diff --git a/cosyvoice/bin/export_jit.py b/cosyvoice/bin/export_jit.py index ddd486e97..d78e936fb 100644 --- a/cosyvoice/bin/export_jit.py +++ b/cosyvoice/bin/export_jit.py @@ -19,7 +19,7 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING) import os import sys -import torch +from torch import jit,_C ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append('{}/../..'.format(ROOT_DIR)) sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) @@ -38,12 +38,12 @@ def get_args(): def get_optimized_script(model, preserved_attrs=[]): - script = torch.jit.script(model) + script = jit.script(model) if preserved_attrs != []: - script = torch.jit.freeze(script, preserved_attrs=preserved_attrs) + script = jit.freeze(script, preserved_attrs=preserved_attrs) else: - script = torch.jit.freeze(script) - script = torch.jit.optimize_for_inference(script) + script = jit.freeze(script) + script = jit.optimize_for_inference(script) return script @@ -52,9 +52,9 @@ def main(): logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s') - torch._C._jit_set_fusion_strategy([('STATIC', 1)]) - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) + _C._jit_set_fusion_strategy([('STATIC', 1)]) + _C._jit_set_profiling_mode(False) + _C._jit_set_profiling_executor(False) try: model = CosyVoice(args.model_dir) diff --git a/cosyvoice/bin/export_onnx.py b/cosyvoice/bin/export_onnx.py index 9ddd35894..ab11f1afa 100644 --- a/cosyvoice/bin/export_onnx.py +++ b/cosyvoice/bin/export_onnx.py @@ -22,7 +22,7 @@ import sys import onnxruntime import random -import torch +from torch import rand,ones,onnx,float32,cuda,testing,from_numpy from tqdm import tqdm ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append('{}/../..'.format(ROOT_DIR)) @@ -31,12 +31,12 @@ def get_dummy_input(batch_size, seq_len, out_channels, device): - x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) - mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) - mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) - t = torch.rand((batch_size), dtype=torch.float32, device=device) - spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) - cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + x = rand((batch_size, out_channels, seq_len), dtype=float32, device=device) + mask = ones((batch_size, 1, seq_len), dtype=float32, device=device) + mu = rand((batch_size, out_channels, 
seq_len), dtype=float32, device=device)
+    t = rand((batch_size), dtype=float32, device=device)
+    spks = rand((batch_size, out_channels), dtype=float32, device=device)
+    cond = rand((batch_size, out_channels, seq_len), dtype=float32, device=device)
     return x, mask, mu, t, spks, cond
@@ -71,7 +71,7 @@ def main():
     batch_size, seq_len = 2, 256
     out_channels = model.model.flow.decoder.estimator.out_channels
     x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
-    torch.onnx.export(
+    onnx.export(
         estimator,
         (x, mask, mu, t, spks, cond),
         '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
@@ -93,7 +93,7 @@ def main():
     option = onnxruntime.SessionOptions()
     option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
     option.intra_op_num_threads = 1
-    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+    providers = ['CUDAExecutionProvider' if cuda.is_available() else 'CPUExecutionProvider']
     estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
                                                   sess_options=option, providers=providers)
@@ -109,7 +109,7 @@ def main():
         'cond': cond.cpu().numpy()
     }
     output_onnx = estimator_onnx.run(None, ort_inputs)[0]
-    torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+    testing.assert_allclose(output_pytorch, from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)


 if __name__ == "__main__":
diff --git a/cosyvoice/bin/inference.py b/cosyvoice/bin/inference.py
index 2cb831a5f..445aca1e9 100644
--- a/cosyvoice/bin/inference.py
+++ b/cosyvoice/bin/inference.py
@@ -18,7 +18,8 @@ import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
-import torch
+from torch import no_grad, concat, device as torch_device
+from torch.cuda import is_available as cuda_is_available
 from torch.utils.data import DataLoader
 import torchaudio
 from hyperpyyaml import load_hyperpyyaml
@@ -57,8 +58,8 @@ def main():
     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
     # Init cosyvoice models from configs
-    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
-    device = torch.device('cuda' if use_cuda else 'cpu')
+    use_cuda = args.gpu >= 0 and cuda_is_available()
+    device = torch_device('cuda' if use_cuda else 'cpu')
     with open(args.config, 'r') as f:
         configs = load_hyperpyyaml(f)
@@ -73,7 +74,7 @@ def main():
     os.makedirs(args.result_dir, exist_ok=True)
     fn = os.path.join(args.result_dir, 'wav.scp')
     f = open(fn, 'w')
-    with torch.no_grad():
+    with no_grad():
         for _, batch in tqdm(enumerate(test_data_loader)):
             utts = batch["utts"]
             assert len(utts) == 1, "inference mode only support batchsize 1"
@@ -101,7 +102,7 @@ def main():
             tts_speeches = []
             for model_output in model.tts(**model_input):
                 tts_speeches.append(model_output['tts_speech'])
-            tts_speeches = torch.concat(tts_speeches, dim=1)
+            tts_speeches = concat(tts_speeches, dim=1)
             tts_key = '{}_{}'.format(utts[0], tts_index[0])
             tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
             torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index e2d62e228..2fedaae16 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -17,7 +17,7 @@ from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
-import torch
+from torch.cuda import is_available as cuda_is_available
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import 
CosyVoiceModel, CosyVoice2Model from cosyvoice.utils.file_utils import logging @@ -42,7 +42,7 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): '{}/spk2info.pt'.format(model_dir), configs['allowed_special']) self.sample_rate = configs['sample_rate'] - if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): + if cuda_is_available() is False and (load_jit is True or load_trt is True or fp16 is True): load_jit, load_trt, fp16 = False, False, False logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16) @@ -142,7 +142,7 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): '{}/spk2info.pt'.format(model_dir), configs['allowed_special']) self.sample_rate = configs['sample_rate'] - if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): + if cuda_is_available() is False and (load_jit is True or load_trt is True or fp16 is True): load_jit, load_trt, fp16 = False, False, False logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16) diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py index 6e10f00fe..ab9c26ad8 100644 --- a/cosyvoice/cli/frontend.py +++ b/cosyvoice/cli/frontend.py @@ -15,15 +15,16 @@ from typing import Generator import json import onnxruntime -import torch +from torch import tensor,load,device,int32 +from torch.cuda import is_available import numpy as np import whisper from typing import Callable -import torchaudio.compliance.kaldi as kaldi -import torchaudio +from torchaudio.compliance import kaldi +from torchaudio.transforms import Resample import os import re -import inflect +from inflect import engine try: import ttsfrd use_ttsfrd = True @@ -44,19 +45,21 @@ def __init__(self, campplus_model: str, speech_tokenizer_model: str, spk2info: str = '', - allowed_special: str = 'all'): + allowed_special: str = 'all', + refresh_fst_cache: bool = False): self.tokenizer = get_tokenizer() self.feat_extractor = feat_extractor - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = device('cuda' if is_available() else 'cpu') option = onnxruntime.SessionOptions() option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL option.intra_op_num_threads = 1 self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"]) self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, - providers=["CUDAExecutionProvider" if torch.cuda.is_available() else + providers=["CUDAExecutionProvider" if is_available() else "CPUExecutionProvider"]) + self.refresh_fst_cache = refresh_fst_cache if os.path.exists(spk2info): - self.spk2info = torch.load(spk2info, map_location=self.device) + self.spk2info = load(spk2info, map_location=self.device,weights_only=True) else: self.spk2info = {} self.allowed_special = allowed_special @@ -68,19 +71,19 @@ def __init__(self, 'failed to initialize ttsfrd resource' self.frd.set_lang_type('pinyinvg') else: - self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True) + self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=self.refresh_fst_cache) self.en_tn_model = EnNormalizer() - self.inflect_parser = 
inflect.engine() + self.inflect_parser = engine() # from inflect def _extract_text_token(self, text): if isinstance(text, Generator): logging.info('get tts_text generator, will return _extract_text_token_generator!') # NOTE add a dummy text_token_len for compatibility - return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device) + return self._extract_text_token_generator(text), tensor([0], dtype=int32).to(self.device) else: text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special) - text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device) - text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device) + text_token = tensor([text_token], dtype=int32).to(self.device) + text_token_len = tensor([text_token.shape[1]], dtype=int32).to(self.device) return text_token, text_token_len def _extract_text_token_generator(self, text_generator): @@ -97,8 +100,8 @@ def _extract_speech_token(self, speech): feat.detach().cpu().numpy(), self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() - speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device) - speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device) + speech_token = tensor([speech_token], dtype=int32).to(self.device) + speech_token_len = tensor([speech_token.shape[1]], dtype=int32).to(self.device) return speech_token, speech_token_len def _extract_spk_embedding(self, speech): @@ -109,13 +112,13 @@ def _extract_spk_embedding(self, speech): feat = feat - feat.mean(dim=0, keepdim=True) embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() - embedding = torch.tensor([embedding]).to(self.device) + embedding = tensor([embedding]).to(self.device) return embedding def _extract_speech_feat(self, speech): speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device) speech_feat = speech_feat.unsqueeze(dim=0) - speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device) + speech_feat_len = tensor([speech_feat.shape[1]], dtype=int32).to(self.device) return speech_feat, speech_feat_len def text_normalize(self, text, split=True, text_frontend=True): @@ -157,7 +160,7 @@ def frontend_sft(self, tts_text, spk_id): def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate): tts_text_token, tts_text_token_len = self._extract_text_token(tts_text) prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text) - prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k) + prompt_speech_resample = Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k) speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample) speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k) if resample_rate == 24000: @@ -200,7 +203,7 @@ def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resampl def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate): prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k) - prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k) + prompt_speech_resample = Resample(orig_freq=16000, 
new_freq=resample_rate)(prompt_speech_16k) prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample) embedding = self._extract_spk_embedding(prompt_speech_16k) source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 9ebf8cb0c..393deb027 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -13,7 +13,8 @@ # limitations under the License. import os from typing import Generator -import torch +from torch import device,cuda,concat,tensor,zeros,load,jit,int32 +from torch.nn import Module as nnModule import numpy as np import threading import time @@ -27,11 +28,11 @@ class CosyVoiceModel: def __init__(self, - llm: torch.nn.Module, - flow: torch.nn.Module, - hift: torch.nn.Module, + llm: nnModule, + flow: nnModule, + hift: nnModule, fp16: bool): - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = device('cuda' if cuda.is_available() else 'cpu') self.llm = llm self.flow = flow self.hift = hift @@ -57,7 +58,7 @@ def __init__(self, # rtf and decoding related self.stream_scale_factor = 1 assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf' - self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() + self.llm_context = cuda.stream(cuda.Stream(self.device)) if cuda.is_available() else nullcontext() self.lock = threading.Lock() # dict used to store session related variable self.tts_speech_token_dict = {} @@ -67,25 +68,25 @@ def __init__(self, self.hift_cache_dict = {} def load(self, llm_model, flow_model, hift_model): - self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True) + self.llm.load_state_dict(load(llm_model, map_location=self.device,weights_only=True), strict=True) self.llm.to(self.device).eval() - self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True) + self.flow.load_state_dict(load(flow_model, map_location=self.device,weights_only=True), strict=True) self.flow.to(self.device).eval() # in case hift_model is a hifigan model - hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()} + hift_state_dict = {k.replace('generator.', ''): v for k, v in load(hift_model, map_location=self.device,weights_only=True).items()} self.hift.load_state_dict(hift_state_dict, strict=True) self.hift.to(self.device).eval() def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model): - llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device) + llm_text_encoder = jit.load(llm_text_encoder_model, map_location=self.device) self.llm.text_encoder = llm_text_encoder - llm_llm = torch.jit.load(llm_llm_model, map_location=self.device) + llm_llm = jit.load(llm_llm_model, map_location=self.device) self.llm.llm = llm_llm - flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) + flow_encoder = jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16): - assert torch.cuda.is_available(), 'tensorrt only supports gpu!' + assert cuda.is_available(), 'tensorrt only supports gpu!' 
if not os.path.exists(flow_decoder_estimator_model): convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16) if os.path.getsize(flow_decoder_estimator_model) == 0: @@ -104,29 +105,29 @@ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uui assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!' for i in self.llm.inference_bistream(text=text, prompt_text=prompt_text.to(self.device), - prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), + prompt_text_len=tensor([prompt_text.shape[1]], dtype=int32).to(self.device), prompt_speech_token=llm_prompt_speech_token.to(self.device), - prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_speech_token_len=tensor([llm_prompt_speech_token.shape[1]], dtype=int32).to(self.device), embedding=llm_embedding.to(self.device)): self.tts_speech_token_dict[uuid].append(i) else: for i in self.llm.inference(text=text.to(self.device), - text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device), + text_len=tensor([text.shape[1]], dtype=int32).to(self.device), prompt_text=prompt_text.to(self.device), - prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), + prompt_text_len=tensor([prompt_text.shape[1]], dtype=int32).to(self.device), prompt_speech_token=llm_prompt_speech_token.to(self.device), - prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_speech_token_len=tensor([llm_prompt_speech_token.shape[1]], dtype=int32).to(self.device), embedding=llm_embedding.to(self.device)): self.tts_speech_token_dict[uuid].append(i) self.llm_end_dict[uuid] = True def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): tts_mel, flow_cache = self.flow.inference(token=token.to(self.device), - token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + token_len=tensor([token.shape[1]], dtype=int32).to(self.device), prompt_token=prompt_token.to(self.device), - prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token_len=tensor([prompt_token.shape[1]], dtype=int32).to(self.device), prompt_feat=prompt_feat.to(self.device), - prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat_len=tensor([prompt_feat.shape[1]], dtype=int32).to(self.device), embedding=embedding.to(self.device), flow_cache=self.flow_cache_dict[uuid]) self.flow_cache_dict[uuid] = flow_cache @@ -137,9 +138,9 @@ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize= # append hift cache if self.hift_cache_dict[uuid] is not None: hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] - tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) + tts_mel = concat([hift_cache_mel, tts_mel], dim=2) else: - hift_cache_source = torch.zeros(1, 1, 0) + hift_cache_source = zeros(1, 1, 0) # keep overlap mel and hift cache if finalize is False: self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:] @@ -160,18 +161,18 @@ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize= tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) return tts_speech - def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), - 
prompt_text=torch.zeros(1, 0, dtype=torch.int32), - llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), - flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), - prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): + def tts(self, text, flow_embedding, llm_embedding=zeros(0, 192), + prompt_text=zeros(1, 0, dtype=int32), + llm_prompt_speech_token=zeros(1, 0, dtype=int32), + flow_prompt_speech_token=zeros(1, 0, dtype=int32), + prompt_speech_feat=zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread this_uuid = str(uuid.uuid1()) with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None - self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0) - self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2) + self.mel_overlap_dict[this_uuid] = zeros(1, 80, 0) + self.flow_cache_dict[this_uuid] = zeros(1, 80, 0, 2) p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) p.start() if stream is True: @@ -179,7 +180,7 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), while True: time.sleep(0.1) if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len: - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \ + this_tts_speech_token = tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \ .unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, @@ -196,7 +197,7 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), break p.join() # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) + this_tts_speech_token = tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, @@ -207,7 +208,7 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), else: # deal with all tokens p.join() - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) + this_tts_speech_token = tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, @@ -222,7 +223,7 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), self.mel_overlap_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) self.flow_cache_dict.pop(this_uuid) - torch.cuda.empty_cache() + cuda.empty_cache() def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread @@ -230,13 +231,13 @@ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True self.hift_cache_dict[this_uuid] = None - self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0) - self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2) + 
self.mel_overlap_dict[this_uuid] = zeros(1, 80, 0) + self.flow_cache_dict[this_uuid] = zeros(1, 80, 0, 2) if stream is True: token_hop_len = self.token_min_hop_len while True: if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len: - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \ + this_tts_speech_token = tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \ .unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, @@ -252,7 +253,7 @@ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len: break # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) + this_tts_speech_token = tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, @@ -262,7 +263,7 @@ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, yield {'tts_speech': this_tts_speech.cpu()} else: # deal with all tokens - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) + this_tts_speech_token = tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, @@ -276,17 +277,17 @@ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, self.llm_end_dict.pop(this_uuid) self.mel_overlap_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) - torch.cuda.empty_cache() + cuda.empty_cache() class CosyVoice2Model(CosyVoiceModel): def __init__(self, - llm: torch.nn.Module, - flow: torch.nn.Module, - hift: torch.nn.Module, + llm: nnModule, + flow: nnModule, + hift: nnModule, fp16: bool): - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = device('cuda' if cuda.is_available() else 'cpu') self.llm = llm self.flow = flow self.hift = hift @@ -307,7 +308,7 @@ def __init__(self, self.speech_window = np.hamming(2 * self.source_cache_len) # rtf and decoding related self.stream_scale_factor = 1 - self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() + self.llm_context = cuda.stream(cuda.Stream(self.device)) if cuda.is_available() else nullcontext() self.lock = threading.Lock() # dict used to store session related variable self.tts_speech_token_dict = {} @@ -315,25 +316,25 @@ def __init__(self, self.hift_cache_dict = {} def load_jit(self, flow_encoder_model): - flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) + flow_encoder = jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0): tts_mel, _ = self.flow.inference(token=token.to(self.device), - token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + token_len=tensor([token.shape[1]], dtype=int32).to(self.device), 
prompt_token=prompt_token.to(self.device), - prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token_len=tensor([prompt_token.shape[1]], dtype=int32).to(self.device), prompt_feat=prompt_feat.to(self.device), - prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat_len=tensor([prompt_feat.shape[1]], dtype=int32).to(self.device), embedding=embedding.to(self.device), finalize=finalize) tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:] # append hift cache if self.hift_cache_dict[uuid] is not None: hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] - tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) + tts_mel = concat([hift_cache_mel, tts_mel], dim=2) else: - hift_cache_source = torch.zeros(1, 1, 0) + hift_cache_source = zeros(1, 1, 0) # keep overlap mel and hift cache if finalize is False: tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) @@ -352,11 +353,11 @@ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_off tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) return tts_speech - def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), - prompt_text=torch.zeros(1, 0, dtype=torch.int32), - llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), - flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), - prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): + def tts(self, text, flow_embedding, llm_embedding=zeros(0, 192), + prompt_text=zeros(1, 0, dtype=int32), + llm_prompt_speech_token=zeros(1, 0, dtype=int32), + flow_prompt_speech_token=zeros(1, 0, dtype=int32), + prompt_speech_feat=zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread this_uuid = str(uuid.uuid1()) with self.lock: @@ -369,7 +370,7 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), while True: time.sleep(0.1) if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len: - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) + this_tts_speech_token = tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, @@ -383,7 +384,7 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), break p.join() # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) + this_tts_speech_token = tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, @@ -395,7 +396,7 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), else: # deal with all tokens p.join() - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) + this_tts_speech_token = 
tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, @@ -408,4 +409,4 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), with self.lock: self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) - torch.cuda.empty_cache() + cuda.empty_cache() diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index bbd3305da..65b2c4fb3 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, Optional, Callable, List, Generator -import torch -from torch import nn +from torch import nn,device,Tensor,concat,tril,tensor,ones,zeros,int32,inference_mode +from torch import bool as torch_bool import torch.nn.functional as F from transformers import Qwen2ForCausalLM from torch.nn.utils.rnn import pad_sequence, unpad_sequence @@ -23,7 +23,7 @@ from cosyvoice.utils.file_utils import logging -class TransformerLM(torch.nn.Module): +class TransformerLM(nn.Module): def __init__( self, text_encoder_input_size: int, @@ -31,8 +31,8 @@ def __init__( llm_output_size: int, text_token_size: int, speech_token_size: int, - text_encoder: torch.nn.Module, - llm: torch.nn.Module, + text_encoder: nn.Module, + llm: nn.Module, sampling: Callable, length_normalized_loss: bool = True, lsm_weight: float = 0.0, @@ -42,7 +42,7 @@ def __init__( self.llm_input_size = llm_input_size self.speech_token_size = speech_token_size # 1. build text token inputs related modules - self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size) + self.text_embedding = nn.Embedding(text_token_size, text_encoder_input_size) self.text_encoder = text_encoder self.text_encoder_affine_layer = nn.Linear( self.text_encoder.output_size(), @@ -52,7 +52,7 @@ def __init__( # 2. build speech token language model related modules self.sos_eos = 0 self.task_id = 1 - self.llm_embedding = torch.nn.Embedding(2, llm_input_size) + self.llm_embedding = nn.Embedding(2, llm_input_size) self.llm = llm self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1) self.criterion_ce = LabelSmoothingLoss( @@ -63,16 +63,16 @@ def __init__( ) # 3. [Optional] build speech token related modules - self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size) - self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size) + self.speech_embedding = nn.Embedding(speech_token_size, llm_input_size) + self.spk_embed_affine_layer = nn.Linear(spk_embed_dim, llm_input_size) # 4. 
sampling method self.sampling = sampling def encode( self, - text: torch.Tensor, - text_lengths: torch.Tensor, + text: Tensor, + text_lengths: Tensor, ): encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1) encoder_out_lens = encoder_mask.squeeze(1).sum(1) @@ -82,17 +82,17 @@ def encode( def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len): text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True) speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True) - lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) + lm_input = [concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) for i in range(len(text_token))] - lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32) + lm_input_len = tensor([i.size(0) for i in lm_input], dtype=int32) lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID) return lm_input, lm_input_len def forward( self, batch: dict, - device: torch.device, - ) -> Dict[str, Optional[torch.Tensor]]: + device: device, + ) -> Dict[str, Optional[Tensor]]: """ Args: text: (B, L, D) @@ -107,7 +107,7 @@ def forward( embedding = batch['embedding'].to(device) # 1. prepare llm_target - lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + + lm_target = [tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))] lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device) @@ -140,7 +140,7 @@ def forward( def sampling_ids( self, - weighted_scores: torch.Tensor, + weighted_scores: Tensor, decoded_tokens: List, sampling: int, ignore_eos: bool = True, @@ -155,25 +155,25 @@ def sampling_ids( raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials)) return top_ids - @torch.inference_mode() + @inference_mode() def inference( self, - text: torch.Tensor, - text_len: torch.Tensor, - prompt_text: torch.Tensor, - prompt_text_len: torch.Tensor, - prompt_speech_token: torch.Tensor, - prompt_speech_token_len: torch.Tensor, - embedding: torch.Tensor, + text: Tensor, + text_len: Tensor, + prompt_text: Tensor, + prompt_text_len: Tensor, + prompt_speech_token: Tensor, + prompt_speech_token_len: Tensor, + embedding: Tensor, sampling: int = 25, max_token_text_ratio: float = 20, min_token_text_ratio: float = 2, - ) -> Generator[torch.Tensor, None, None]: + ) -> Generator[Tensor, None, None]: if self.fp16 is True: embedding = embedding.half() device = text.device - text = torch.concat([prompt_text, text], dim=1) + text = concat([prompt_text, text], dim=1) text_len += prompt_text_len text = self.text_embedding(text) @@ -186,7 +186,7 @@ def inference( embedding = self.spk_embed_affine_layer(embedding) embedding = embedding.unsqueeze(dim=1) else: - embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype) + embedding = zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype) # 3. 
concat llm_input sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) @@ -194,8 +194,8 @@ def inference( if prompt_speech_token_len != 0: prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) else: - prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) - lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1) + prompt_speech_token_emb = zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + lm_input = concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1) # 4. cal min/max_length min_len = int((text_len - prompt_text_len) * min_token_text_ratio) @@ -204,12 +204,12 @@ def inference( # 5. step by step decode out_tokens = [] offset = 0 - att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device) + att_cache, cnn_cache = zeros((0, 0, 0, 0), device=lm_input.device), zeros((0, 0, 0, 0), device=lm_input.device) for i in range(max_len): y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache, - att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), - device=lm_input.device)).to(torch.bool)) + att_mask=tril(ones((1, lm_input.shape[1], lm_input.shape[1]), + device=lm_input.device)).to(torch_bool)) logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) # force continue decode first token if i == 0: @@ -224,7 +224,7 @@ def inference( lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) -class Qwen2Encoder(torch.nn.Module): +class Qwen2Encoder(nn.Module): def __init__(self, pretrain_path): super().__init__() self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) @@ -250,13 +250,13 @@ def __init__( llm_input_size: int, llm_output_size: int, speech_token_size: int, - llm: torch.nn.Module, + llm: nn.Module, sampling: Callable, length_normalized_loss: bool = True, lsm_weight: float = 0.0, mix_ratio: List[int] = [5, 15], ): - torch.nn.Module.__init__(self) + nn.Module.__init__(self) self.llm_input_size = llm_input_size self.llm_output_size = llm_output_size self.speech_token_size = speech_token_size @@ -266,7 +266,7 @@ def __init__( self.task_id = 1 self.fill_token = 2 - self.llm_embedding = torch.nn.Embedding(2, llm_input_size) + self.llm_embedding = nn.Embedding(2, llm_input_size) self.llm = llm self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3) self.criterion_ce = LabelSmoothingLoss( @@ -277,28 +277,28 @@ def __init__( ) # 3. [Optional] build speech token related modules - self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size) + self.speech_embedding = nn.Embedding(speech_token_size + 3, llm_input_size) # 4. 
sampling method self.sampling = sampling self.mix_ratio = mix_ratio - @torch.inference_mode() + @inference_mode() def inference( self, - text: torch.Tensor, - text_len: torch.Tensor, - prompt_text: torch.Tensor, - prompt_text_len: torch.Tensor, - prompt_speech_token: torch.Tensor, - prompt_speech_token_len: torch.Tensor, - embedding: torch.Tensor, + text: Tensor, + text_len: Tensor, + prompt_text: Tensor, + prompt_text_len: Tensor, + prompt_speech_token: Tensor, + prompt_speech_token_len: Tensor, + embedding: Tensor, sampling: int = 25, max_token_text_ratio: float = 20, min_token_text_ratio: float = 2, - ) -> Generator[torch.Tensor, None, None]: + ) -> Generator[Tensor, None, None]: device = text.device - text = torch.concat([prompt_text, text], dim=1) + text = concat([prompt_text, text], dim=1) text_len += prompt_text_len text = self.llm.model.model.embed_tokens(text) @@ -308,8 +308,8 @@ def inference( if prompt_speech_token_len != 0: prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) else: - prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) - lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1) + prompt_speech_token_emb = zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + lm_input = concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1) # 4. cal min/max_length min_len = int((text_len - prompt_text_len) * min_token_text_ratio) @@ -320,7 +320,7 @@ def inference( cache = None for i in range(max_len): y_pred, cache = self.llm.forward_one_step(lm_input, - masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), + masks=tril(ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch_bool), cache=cache) logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() @@ -333,19 +333,19 @@ def inference( out_tokens.append(top_ids) lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) - @torch.inference_mode() + @inference_mode() def inference_bistream( self, text: Generator, - prompt_text: torch.Tensor, - prompt_text_len: torch.Tensor, - prompt_speech_token: torch.Tensor, - prompt_speech_token_len: torch.Tensor, - embedding: torch.Tensor, + prompt_text: Tensor, + prompt_text_len: Tensor, + prompt_speech_token: Tensor, + prompt_speech_token_len: Tensor, + embedding: Tensor, sampling: int = 25, max_token_text_ratio: float = 20, min_token_text_ratio: float = 2, - ) -> Generator[torch.Tensor, None, None]: + ) -> Generator[Tensor, None, None]: device = prompt_text.device # 1. prepare input @@ -354,8 +354,8 @@ def inference_bistream( if prompt_speech_token_len != 0: prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) else: - prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device) - lm_input = torch.concat([sos_eos_emb], dim=1) + prompt_speech_token_emb = zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device) + lm_input = concat([sos_eos_emb], dim=1) # 2. 
iterate text out_tokens = [] @@ -364,13 +364,13 @@ def inference_bistream( text_cache = self.llm.model.model.embed_tokens(prompt_text) next_fill_index = -1 for this_text in text: - text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1) + text_cache = concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1) # prompt_speech_token_emb not empty, try append to lm_input while prompt_speech_token_emb.size(1) != 0: if text_cache.size(1) >= self.mix_ratio[0]: lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]] logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1))) - lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1) + lm_input = concat([lm_input, lm_input_text, lm_input_speech], dim=1) text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:] else: logging.info('not enough text token to decode, wait for more') @@ -385,7 +385,7 @@ def inference_bistream( if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2: lm_input = lm_input_text else: - lm_input = torch.concat([lm_input, lm_input_text], dim=1) + lm_input = concat([lm_input, lm_input_text], dim=1) text_cache = text_cache[:, self.mix_ratio[0]:] else: logging.info('not enough text token to decode, wait for more') @@ -393,7 +393,7 @@ def inference_bistream( while True: seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) y_pred, cache = self.llm.forward_one_step(lm_input, - masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), + masks=tril(ones((1, seq_len, seq_len), device=lm_input.device)).to(torch_bool), cache=cache) logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) if next_fill_index != -1 and len(out_tokens) == next_fill_index: @@ -414,12 +414,12 @@ def inference_bistream( lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) # 3. final decode - lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1) + lm_input = concat([lm_input, text_cache, task_id_emb], dim=1) logging.info('no more text token, decode until met eos') while True: seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) y_pred, cache = self.llm.forward_one_step(lm_input, - masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), + masks=tril(ones((1, seq_len, seq_len), device=lm_input.device)).to(torch_bool), cache=cache) logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item() diff --git a/cosyvoice/utils/class_utils.py b/cosyvoice/utils/class_utils.py index c49de00c8..2abdd777d 100644 --- a/cosyvoice/utils/class_utils.py +++ b/cosyvoice/utils/class_utils.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import torch +from torch import nn from cosyvoice.transformer.activation import Swish from cosyvoice.transformer.subsampling import ( @@ -39,12 +39,12 @@ COSYVOICE_ACTIVATION_CLASSES = { - "hardtanh": torch.nn.Hardtanh, - "tanh": torch.nn.Tanh, - "relu": torch.nn.ReLU, - "selu": torch.nn.SELU, - "swish": getattr(torch.nn, "SiLU", Swish), - "gelu": torch.nn.GELU, + "hardtanh": nn.Hardtanh, + "tanh": nn.Tanh, + "relu": nn.ReLU, + "selu": nn.SELU, + "swish": getattr(nn, "SiLU", Swish), + "gelu": nn.GELU, } COSYVOICE_SUBSAMPLE_CLASSES = { @@ -55,7 +55,7 @@ "conv2d": Conv2dSubsampling4, "conv2d6": Conv2dSubsampling6, "conv2d8": Conv2dSubsampling8, - 'paraformer_dummy': torch.nn.Identity + 'paraformer_dummy': nn.Identity } COSYVOICE_EMB_CLASSES = {