diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index e2d62e22..91c36e1d 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -145,6 +145,8 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
+        if torch.mps.is_available() is True:
+            load_jit = True
         self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py
index 6e10f00f..f8f77d3f 100644
--- a/cosyvoice/cli/frontend.py
+++ b/cosyvoice/cli/frontend.py
@@ -36,6 +36,14 @@
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation


+def get_torch_device():
+    if torch.backends.mps.is_available():
+        return torch.device('mps')
+    elif torch.cuda.is_available():
+        return torch.device('cuda')
+    else:
+        return torch.device('cpu')
+
 class CosyVoiceFrontEnd:

     def __init__(self,
@@ -47,10 +55,12 @@ def __init__(self,
                  allowed_special: str = 'all'):
         self.tokenizer = get_tokenizer()
         self.feat_extractor = feat_extractor
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = get_torch_device()
         option = onnxruntime.SessionOptions()
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
+        if self.device.type == "mps":
+            logging.warning("ONNXRuntime does not support MPS. ONNX models will run on CPU.")
         self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
         self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
                                                                      providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 9ebf8cb0..be5f8153 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -22,6 +22,7 @@
 import uuid
 from cosyvoice.utils.common import fade_in_out
 from cosyvoice.utils.file_utils import convert_onnx_to_trt
+from cosyvoice.cli.frontend import get_torch_device


 class CosyVoiceModel:
@@ -31,7 +32,7 @@ def __init__(self,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
                  fp16: bool):
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = get_torch_device()
         self.llm = llm
         self.flow = flow
         self.hift = hift
@@ -57,7 +58,10 @@ def __init__(self,
         # rtf and decoding related
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
-        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        if torch.cuda.is_available():
+            self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device))
+        else:
+            self.llm_context = nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
@@ -222,7 +226,7 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+        self.empty_cache()

     def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
@@ -276,7 +280,13 @@ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat,
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+        self.empty_cache()
+
+    def empty_cache(self):
+        if torch.mps.is_available():
+            torch.mps.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()


 class CosyVoice2Model(CosyVoiceModel):
@@ -286,7 +296,7 @@ def __init__(self,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
                  fp16: bool):
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = get_torch_device()
         self.llm = llm
         self.flow = flow
         self.hift = hift
@@ -307,7 +317,10 @@ def __init__(self,
         self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
         self.stream_scale_factor = 1
-        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        if torch.cuda.is_available():
+            self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device))
+        else:
+            self.llm_context = nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
@@ -408,4 +421,4 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+        self.empty_cache()
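For reviewers who want to sanity-check the new device fallback in isolation, here is a minimal standalone sketch (not part of the patch) that mirrors the `get_torch_device()` / `empty_cache()` pattern the diff introduces; the script layout and the `__main__` smoke test are illustrative assumptions, not code from the repository.

```python
# Standalone sketch: exercises the MPS -> CUDA -> CPU fallback and the guarded
# cache clearing added by the patch. Function names mirror the patch; the
# __main__ block is illustrative only.
import torch


def get_torch_device():
    # same priority order as the helper added in cosyvoice/cli/frontend.py
    if torch.backends.mps.is_available():
        return torch.device('mps')
    elif torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')


def empty_cache():
    # mirrors CosyVoiceModel.empty_cache(): only touch backends that are present
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


if __name__ == '__main__':
    device = get_torch_device()
    x = torch.randn(4, 4, device=device) @ torch.randn(4, 4, device=device)
    print(f'selected device: {device}, checksum: {x.sum().item():.4f}')
    empty_cache()
```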