1212import sys
1313import numpy as np
1414import re
15+ import io
1516from comfy .utils import ProgressBar
1617from cached_path import cached_path
1718sys .path .append (Install .f5TTSPath )
2425sys .path .pop ()
2526
2627
27- class F5TTSAudio :
28-
29- def __init__ (self ):
30- self .use_cli = False
31- self .voice_reg = re .compile (r"\{(\w+)\}" )
28+ class F5TTSCreate :
29+ voice_reg = re .compile (r"\{(\w+)\}" )
3230
33- @staticmethod
34- def get_txt_file_path (file ):
35- p = Path (file )
36- return os .path .join (os .path .dirname (file ), p .stem + ".txt" )
def is_voice_name(self, word):
    """Return a regex match object if *word* is a voice tag like ``{name}``, else None."""
    stripped = word.strip()
    return self.voice_reg.match(stripped)
3733
38- @classmethod
39- def INPUT_TYPES (s ):
40- input_dir = folder_paths .get_input_directory ()
41- files = folder_paths .filter_files_content_types (
42- os .listdir (input_dir ), ["audio" , "video" ]
43- )
44- filesWithTxt = []
45- for file in files :
46- txtFile = F5TTSAudio .get_txt_file_path (file )
47- if os .path .isfile (os .path .join (input_dir , txtFile )):
48- filesWithTxt .append (file )
49- return {
50- "required" : {
51- "sample" : (sorted (filesWithTxt ), {"audio_upload" : True }),
52- "speech" : ("STRING" , {
53- "multiline" : True ,
54- "default" : "Hello World"
55- }),
56- }
57- }
def get_voice_names(self, chunks):
    """Collect every voice name tagged in *chunks*, as a dict of name -> True."""
    return {m[1]: True for text in chunks if (m := self.is_voice_name(text))}
5841
59- CATEGORY = "audio"
def split_text(self, speech):
    """Split *speech* immediately before each ``{voice}`` tag, keeping the tags.

    Uses a zero-width lookahead so the tag itself stays attached to its chunk.
    """
    return re.split(r"(?=\{\w+\})", speech)
6045
61- RETURN_TYPES = ("AUDIO" , )
62- FUNCTION = "create"
@staticmethod
def load_voice(ref_audio, ref_text):
    """Build a voice dict, running F5-TTS preprocessing on the reference pair.

    ``preprocess_ref_audio_text`` may transcribe/trim the reference, so both
    fields are replaced by its output.
    """
    voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    voice["ref_audio"], voice["ref_text"] = preprocess_ref_audio_text(  # noqa E501
        ref_audio, ref_text
    )
    return voice
8154
8255 def load_model (self ):
8356 model_cls = DiT
@@ -95,29 +68,6 @@ def load_model(self):
9568 ema_model = load_model (model_cls , model_cfg , ckpt_file , vocab_file )
9669 return ema_model
9770
98- def load_voice (self , ref_audio , ref_text ):
99- main_voice = {"ref_audio" : ref_audio , "ref_text" : ref_text }
100-
101- main_voice ["ref_audio" ], main_voice ["ref_text" ] = preprocess_ref_audio_text ( # noqa E501
102- ref_audio , ref_text
103- )
104- return main_voice
105-
106- def is_voice_name (self , word ):
107- return self .voice_reg .match (word .strip ())
108-
109- def get_voice_names (self , chunks ):
110- voice_names = {}
111- for text in chunks :
112- match = self .is_voice_name (text )
113- if match :
114- voice_names [match [1 ]] = True
115- return voice_names
116-
117- def split_text (self , speech ):
118- reg1 = r"(?=\{\w+\})"
119- return re .split (reg1 , speech )
120-
12171 def generate_audio (self , voices , model_obj , chunks ):
12272 frame_rate = 44100
12373 generated_audio_segments = []
@@ -133,7 +83,7 @@ def generate_audio(self, voices, model_obj, chunks):
13383 if voice not in voices :
13484 print (f"Voice { voice } not found, using main." )
13585 voice = "main"
136- text = self .voice_reg .sub ("" , text )
86+ text = F5TTSCreate .voice_reg .sub ("" , text )
13787 gen_text = text .strip ()
13888 ref_audio = voices [voice ]["ref_audio" ]
13989 ref_text = voices [voice ]["ref_text" ]
@@ -160,6 +110,137 @@ def generate_audio(self, voices, model_obj, chunks):
160110 os .unlink (wave_file .name )
161111 return audio
162112
def create(self, voices, chunks):
    """Load the TTS model, then synthesize audio for *chunks* using *voices*."""
    return self.generate_audio(voices, self.load_model(), chunks)
116+
117+
class F5TTSAudioInputs:
    """ComfyUI node: synthesize speech using an AUDIO input as the voice sample."""

    def __init__(self):
        # Temp wav written from the sample AUDIO input; removed after generation.
        self.wave_file = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "sample_audio": ("AUDIO",),
                "sample_text": ("STRING", {"default": "Text of sample_audio"}),
                "speech": ("STRING", {
                    "multiline": True,
                    "default": "This is what I want to say"
                }),
            },
        }

    CATEGORY = "audio"

    RETURN_TYPES = ("AUDIO", )
    FUNCTION = "create"

    def load_voice_from_input(self, sample_audio, sample_text):
        """Write the first waveform of *sample_audio* to a temp wav and build the voice dict."""
        self.wave_file = tempfile.NamedTemporaryFile(
            suffix=".wav", delete=False
        )
        # BUGFIX: close the handle right away — we only need the path.  The
        # open handle leaked an fd and blocks re-opening/unlinking on Windows.
        self.wave_file.close()
        # Only the first waveform of the batch is used as the reference voice.
        for (batch_number, waveform) in enumerate(
                sample_audio["waveform"].cpu()):
            buff = io.BytesIO()
            torchaudio.save(
                buff, waveform, sample_audio["sample_rate"], format="WAV"
            )
            with open(self.wave_file.name, 'wb') as f:
                f.write(buff.getbuffer())
            break
        r = F5TTSCreate.load_voice(self.wave_file.name, sample_text)
        return r

    def remove_wave_file(self):
        """Best-effort removal of the temp wav; never raises."""
        if self.wave_file is not None:
            try:
                os.unlink(self.wave_file.name)
                self.wave_file = None
            except Exception as e:
                print("F5TTS: Cannot remove? " + self.wave_file.name)
                print(e)

    def create(self, sample_audio, sample_text, speech):
        """Generate speech audio; always cleans up the temp sample wav."""
        try:
            main_voice = self.load_voice_from_input(sample_audio, sample_text)

            f5ttsCreate = F5TTSCreate()

            voices = {}
            chunks = f5ttsCreate.split_text(speech)
            voices['main'] = main_voice

            audio = f5ttsCreate.create(voices, chunks)
        finally:
            self.remove_wave_file()
        return (audio, )

    @classmethod
    def IS_CHANGED(s, sample_audio, sample_text, speech):
        """Cache key for ComfyUI: changes with the sample audio/text or speech.

        BUGFIX: hashlib's update() accepts only bytes-like objects; the
        previous code passed str values and the AUDIO dict itself, which
        raised TypeError at runtime.  Hash the waveform bytes and encoded
        strings instead.
        """
        m = hashlib.sha256()
        m.update(sample_text.encode("utf-8"))
        m.update(sample_audio["waveform"].cpu().numpy().tobytes())
        m.update(str(sample_audio["sample_rate"]).encode("utf-8"))
        m.update(speech.encode("utf-8"))
        return m.digest().hex()
187+
188+
189+ class F5TTSAudio :
def __init__(self):
    # When True, generation shells out to the F5-TTS CLI instead of running
    # the model in-process; disabled by default.
    self.use_cli = False
192+
193+ @staticmethod
194+ def get_txt_file_path (file ):
195+ p = Path (file )
196+ return os .path .join (os .path .dirname (file ), p .stem + ".txt" )
197+
@classmethod
def INPUT_TYPES(s):
    """ComfyUI input schema.

    Only audio/video files in the input directory that have a matching
    ``.txt`` transcript next to them are offered as samples.
    """
    input_dir = folder_paths.get_input_directory()
    media_files = folder_paths.filter_files_content_types(
        os.listdir(input_dir), ["audio", "video"]
    )
    samples = sorted(
        f for f in media_files
        if os.path.isfile(
            os.path.join(input_dir, F5TTSAudio.get_txt_file_path(f))
        )
    )
    return {
        "required": {
            "sample": (samples, {"audio_upload": True}),
            "speech": ("STRING", {
                "multiline": True,
                "default": "This is what I want to say"
            }),
        }
    }


CATEGORY = "audio"

RETURN_TYPES = ("AUDIO", )
FUNCTION = "create"
225+
def create_with_cli(self, audio_path, audio_text, speech, output_dir):
    """Run the F5-TTS inference CLI and load the resulting wav.

    Returns a ComfyUI AUDIO dict {"waveform", "sample_rate"}.
    """
    cmd = [
        "python", "inference-cli.py", "--model", "F5-TTS",
        "--ref_audio", audio_path, "--ref_text", audio_text,
        "--gen_text", speech,
        "--output_dir", output_dir,
    ]
    subprocess.run(cmd, cwd=Install.f5TTSPath)

    output_audio = os.path.join(output_dir, "out.wav")
    # Read the frame rate from the wav header; torchaudio supplies the samples.
    with wave.open(output_audio, "rb") as wf:
        frame_rate = wf.getframerate()
    waveform, _ = torchaudio.load(output_audio)
    return {"waveform": waveform.unsqueeze(0), "sample_rate": frame_rate}
243+
163244 def load_voice_from_file (self , sample ):
164245 input_dir = folder_paths .get_input_directory ()
165246 txt_file = os .path .join (
@@ -170,7 +251,7 @@ def load_voice_from_file(self, sample):
170251 with open (txt_file , 'r' ) as file :
171252 audio_text = file .read ()
172253 audio_path = folder_paths .get_annotated_filepath (sample )
173- return self .load_voice (audio_path , audio_text )
254+ return F5TTSCreate .load_voice (audio_path , audio_text )
174255
175256 def load_voices_from_files (self , sample , voice_names ):
176257 voices = {}
@@ -194,6 +275,7 @@ def create(self, sample, speech):
194275 # Install.check_install()
195276 main_voice = self .load_voice_from_file (sample )
196277
278+ f5ttsCreate = F5TTSCreate ()
197279 if self .use_cli :
198280 # working...
199281 output_dir = tempfile .mkdtemp ()
@@ -204,21 +286,23 @@ def create(self, sample, speech):
204286 )
205287 shutil .rmtree (output_dir )
206288 else :
207- model_obj = self .load_model ()
208- chunks = self .split_text (speech )
209- voice_names = self .get_voice_names (chunks )
289+ chunks = f5ttsCreate .split_text (speech )
290+ voice_names = f5ttsCreate .get_voice_names (chunks )
210291 voices = self .load_voices_from_files (sample , voice_names )
211292 voices ['main' ] = main_voice
212293
213- audio = self . generate_audio (voices , model_obj , chunks )
294+ audio = f5ttsCreate . create (voices , chunks )
214295 return (audio , )
215296
@classmethod
def IS_CHANGED(s, sample, speech):
    """Cache key for ComfyUI: changes when the sample audio file, its
    ``.txt`` transcript, or the speech text changes.

    BUGFIX: hashlib's update() accepts only bytes-like objects; the previous
    code passed str values, which raised TypeError at runtime.
    """
    m = hashlib.sha256()
    audio_path = folder_paths.get_annotated_filepath(sample)
    audio_txt_path = F5TTSAudio.get_txt_file_path(audio_path)
    last_modified_timestamp = os.path.getmtime(audio_path)
    txt_last_modified_timestamp = os.path.getmtime(audio_txt_path)
    m.update(audio_path.encode("utf-8"))
    m.update(str(last_modified_timestamp).encode("utf-8"))
    m.update(str(txt_last_modified_timestamp).encode("utf-8"))
    m.update(speech.encode("utf-8"))
    return m.digest().hex()
0 commit comments