1616import io
1717from comfy .utils import ProgressBar
1818from cached_path import cached_path
19- sys .path .append (Install .f5TTSPath )
20- from model import DiT ,UNetT # noqa E402
21- from model .utils_infer import ( # noqa E402
19+ sys .path .append (os . path . join ( Install .f5TTSPath , "src" ) )
20+ from f5_tts . model import DiT ,UNetT # noqa E402
21+ from f5_tts . infer .utils_infer import ( # noqa E402
2222 load_model ,
23+ load_vocoder ,
2324 preprocess_ref_audio_text ,
2425 infer_process ,
2526)
2829
2930class F5TTSCreate :
3031 voice_reg = re .compile (r"\{([^\}]+)\}" )
31- model_types = ["F5" , "E2" ]
32+ model_types = ["F5" , "F5-JP" , "F5-FR" , "E2" ]
33+ vocoder_types = ["vocos" , "bigvgan" ]
3234 tooltip_seed = "Seed. -1 = random"
3335
@staticmethod
def get_model_types():
    """Return the selectable model names: built-ins plus local checkpoints.

    Starts from a copy of ``F5TTSCreate.model_types`` and appends one
    ``"model://<filename>"`` entry for every file found in an ``F5-TTS``
    sub-folder of any registered "checkpoints" directory, provided that:

    * its suffix is listed in ``folder_paths.supported_pt_extensions``,
    * it is a regular file, and
    * a sibling ``<stem>.txt`` file exists (``load_f5_model_url`` uses that
      file as the vocab for ``model:`` URLs).

    Declared as a static method so it can be called on the class or on an
    instance; it never touches instance state.
    """
    model_types = F5TTSCreate.model_types[:]
    for checkpoints_dir in folder_paths.get_folder_paths("checkpoints"):
        f5_model_path = os.path.join(checkpoints_dir, 'F5-TTS')
        if not os.path.isdir(f5_model_path):
            continue
        # os.listdir() order is arbitrary; sort so the list is stable.
        for file in sorted(os.listdir(f5_model_path)):
            if Path(file).suffix not in folder_paths.supported_pt_extensions:
                continue
            model_file = os.path.join(f5_model_path, file)
            if not os.path.isfile(model_file):
                continue
            if os.path.isfile(F5TTSCreate.get_txt_file_path(model_file)):
                model_types.append("model://" + file)
    return model_types
57+
@staticmethod
def get_txt_file_path(file):
    """Return the path of the ``.txt`` file that sits next to ``file``.

    The result keeps the directory part of ``file`` and replaces only the
    final suffix of its name with ``.txt``.
    """
    txt_name = f"{Path(file).stem}.txt"
    return os.path.join(os.path.dirname(file), txt_name)
62+
def is_voice_name(self, word):
    """Match ``word`` (whitespace-stripped) against ``self.voice_reg``.

    Returns the match object when the stripped word starts with a
    ``{name}`` marker, otherwise ``None``.
    """
    candidate = word.strip()
    return self.voice_reg.match(candidate)
3665
@@ -55,50 +84,118 @@ def load_voice(ref_audio, ref_text):
5584 )
5685 return main_voice
5786
def get_model_funcs(self):
    """Map each built-in model name to the bound method that loads it.

    Names that are not keys of this mapping are treated by ``load_model``
    as checkpoint URLs instead.
    """
    loaders = {}
    loaders["F5"] = self.load_f5_model
    loaders["F5-JP"] = self.load_f5_model_jp
    loaders["F5-FR"] = self.load_f5_model_fr
    loaders["E2"] = self.load_e2_model
    return loaders
94+
def get_vocoder(self, vocoder_name):
    """Return the local checkpoint directory for ``vocoder_name``.

    The path is built under ``Install.f5TTSPath``. ``None`` is returned for
    any name other than ``"vocos"`` or ``"bigvgan"``.
    """
    checkpoint_dirs = {
        "vocos": "checkpoints/vocos-mel-24khz",
        "bigvgan": "checkpoints/bigvgan_v2_24khz_100band_256x",
    }
    relative_dir = checkpoint_dirs.get(vocoder_name)
    if relative_dir is None:
        return None
    # The joined path used to be computed and dropped; it is now returned.
    return os.path.join(Install.f5TTSPath, relative_dir)
100+
def load_vocoder(self, vocoder_name):
    """Load a vocoder by name via the imported ``load_vocoder`` helper."""
    vocoder = load_vocoder(vocoder_name=vocoder_name)
    return vocoder
103+
def load_model(self, model, vocoder_name):
    """Load ``model`` and return whatever the selected loader returns.

    Built-in names (the keys of ``get_model_funcs()``) go to their
    dedicated loader; any other value is passed to ``load_f5_model_url``
    as a checkpoint URL.
    """
    loader = self.get_model_funcs().get(model)
    if loader is None:
        return self.load_f5_model_url(model, vocoder_name)
    return loader(vocoder_name)
64110
def get_vocab_file(self):
    """Return the path of the default vocab file under ``Install.f5TTSPath``."""
    vocab_relative = "data/Emilia_ZH_EN_pinyin/vocab.txt"
    return os.path.join(Install.f5TTSPath, vocab_relative)
69115
def load_e2_model(self, vocoder):
    """Load the E2-TTS base checkpoint together with a vocos vocoder.

    Args:
        vocoder: Accepted so the signature matches the other loaders, but
            not consulted; this loader always uses ``"vocos"``.

    Returns:
        A ``(ema_model, vocoder, vocoder_name)`` tuple.
    """
    vocoder_name = "vocos"
    ckpt_url = "hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"
    ckpt_file = str(cached_path(ckpt_url))
    vocab_file = self.get_vocab_file()
    ema_model = load_model(
        UNetT,
        dict(dim=1024, depth=24, heads=16, ff_mult=4),
        ckpt_file,
        vocab_file=vocab_file,
        mel_spec_type=vocoder_name,
    )
    # `vocoder` is rebound here to the loaded vocoder object.
    vocoder = self.load_vocoder(vocoder_name)
    return (ema_model, vocoder, vocoder_name)
132+
def load_f5_model(self, vocoder):
    """Load the stock F5-TTS checkpoint that matches ``vocoder``.

    ``"bigvgan"`` selects the ``F5TTS_Base_bigvgan`` checkpoint; any other
    value selects ``F5TTS_Base``. Loading is delegated to
    ``load_f5_model_url``.
    """
    if vocoder == "bigvgan":
        exp_name, ckpt_step = "F5TTS_Base_bigvgan", 1250000
    else:
        exp_name, ckpt_step = "F5TTS_Base", 1200000
    # NOTE(review): confirm the bigvgan checkpoint is published as
    # .safetensors and not .pt before enabling the vocoder option.
    url = f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"
    return self.load_f5_model_url(url, vocoder)
145+
def load_f5_model_jp(self, vocoder):
    """Load the "F5-JP" checkpoint together with its own vocab file.

    Both files come from the same ``hf://`` folder and are handed to
    ``load_f5_model_url``, whose result is returned unchanged.
    """
    repo_dir = "hf://Jmica/F5TTS/JA_8500000"
    return self.load_f5_model_url(
        f"{repo_dir}/model_8499660.pt",
        vocoder,
        f"{repo_dir}/vocab_updated.txt",
    )
83152
def load_f5_model_fr(self, vocoder):
    """Load the "F5-FR" checkpoint together with its own vocab file.

    Both files come from the same ``hf://`` repository and are handed to
    ``load_f5_model_url``, whose result is returned unchanged.
    """
    repo_dir = "hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced"
    return self.load_f5_model_url(
        f"{repo_dir}/model_1374000.pt",
        vocoder,
        f"{repo_dir}/vocab.txt",
    )
159+
def cached_path(self, url):
    """Resolve a checkpoint or vocab URL to a local file path string.

    ``model:`` URLs (as produced by ``get_model_types``) are looked up
    inside the ``F5-TTS`` sub-folder of every registered "checkpoints"
    directory; the first existing file wins. Every other URL is passed to
    the imported ``cached_path`` helper.

    Raises:
        FileNotFoundError: a ``model:`` URL matched no file on disk.
    """
    if not url.startswith("model:"):
        return str(cached_path(url))

    # Drop the scheme and any number of leading slashes.
    relative_path = re.sub(r"^model:/*", "", url)
    for checkpoints_dir in folder_paths.get_folder_paths("checkpoints"):
        model_file = os.path.join(checkpoints_dir, 'F5-TTS', relative_path)
        if os.path.isfile(model_file):
            return model_file
    raise FileNotFoundError("No model found: " + url)
172+
def load_f5_model_url(self, url, vocoder_name, vocab_url=None):
    """Load a DiT (F5) checkpoint from ``url`` plus the named vocoder.

    Args:
        url: Checkpoint location; resolved with ``self.cached_path``.
        vocoder_name: Passed to ``self.load_vocoder`` and also given to
            ``load_model`` as ``mel_spec_type``.
        vocab_url: Optional vocab file location. When omitted, ``model:``
            URLs use the ``.txt`` file next to the checkpoint and all
            other URLs use ``self.get_vocab_file()``.

    Returns:
        A ``(ema_model, vocoder, vocoder_name)`` tuple.
    """
    vocoder = self.load_vocoder(vocoder_name)
    ckpt_file = str(self.cached_path(url))

    if vocab_url is not None:
        vocab_file = str(self.cached_path(vocab_url))
    elif url.startswith("model:"):
        vocab_file = F5TTSCreate.get_txt_file_path(ckpt_file)
    else:
        vocab_file = self.get_vocab_file()

    model_cfg = dict(
        dim=1024, depth=22, heads=16,
        ff_mult=2, text_dim=512, conv_layers=4
    )
    ema_model = load_model(
        DiT, model_cfg,
        ckpt_file, vocab_file=vocab_file,
        mel_spec_type=vocoder_name,
    )
    return (ema_model, vocoder, vocoder_name)
100195
101- def generate_audio (self , voices , model_obj , chunks , seed ):
196+ def generate_audio (
197+ self , voices , model_obj , chunks , seed , vocoder , mel_spec_type
198+ ):
102199 if seed >= 0 :
103200 torch .manual_seed (seed )
104201 else :
@@ -127,7 +224,8 @@ def generate_audio(self, voices, model_obj, chunks, seed):
127224 print (f"Voice: { voice } " )
128225 print ("text:" + text )
129226 audio , final_sample_rate , spectragram = infer_process (
130- ref_audio , ref_text , gen_text , model_obj
227+ ref_audio , ref_text , gen_text , model_obj ,
228+ vocoder = vocoder , mel_spec_type = mel_spec_type
131229 )
132230 generated_audio_segments .append (audio )
133231 frame_rate = final_sample_rate
@@ -147,9 +245,20 @@ def generate_audio(self, voices, model_obj, chunks, seed):
147245 os .unlink (wave_file .name )
148246 return audio
149247
def create(
    self, voices, chunks, seed=-1, model="F5", vocoder_name="vocos"
):
    """Load the requested model and generate audio for ``chunks``.

    ``self.load_model`` supplies the model object, the vocoder object and
    the mel-spec type; all three are forwarded to ``self.generate_audio``
    along with ``voices``, ``chunks`` and ``seed``. The result of
    ``generate_audio`` is returned as-is.
    """
    loaded = self.load_model(model, vocoder_name)
    model_obj, vocoder, mel_spec_type = loaded
    return self.generate_audio(
        voices, model_obj, chunks, seed,
        vocoder, mel_spec_type=mel_spec_type,
    )
153262
154263
155264class F5TTSAudioInputs :
@@ -158,6 +267,7 @@ def __init__(self):
158267
159268 @classmethod
160269 def INPUT_TYPES (s ):
270+ model_types = F5TTSCreate .get_model_types ()
161271 return {
162272 "required" : {
163273 "sample_audio" : ("AUDIO" ,),
@@ -171,7 +281,8 @@ def INPUT_TYPES(s):
171281 "default" : 1 , "min" : - 1 ,
172282 "tooltip" : F5TTSCreate .tooltip_seed ,
173283 }),
174- "model" : (F5TTSCreate .model_types ,),
284+ "model" : (model_types ,),
285+ # "vocoder": (F5TTSCreate.vocoder_types,),
175286 },
176287 }
177288
@@ -213,7 +324,10 @@ def remove_wave_file(self):
213324 print ("F5TTS: Cannot remove? " + self .wave_file_name )
214325 print (e )
215326
216- def create (self , sample_audio , sample_text , speech , seed = - 1 , model = "F5" ):
327+ def create (
328+ self , sample_audio , sample_text , speech , seed = - 1 , model = "F5"
329+ ):
330+ vocoder = "vocos"
217331 try :
218332 main_voice = self .load_voice_from_input (sample_audio , sample_text )
219333
@@ -223,7 +337,9 @@ def create(self, sample_audio, sample_text, speech, seed=-1, model="F5"):
223337 chunks = f5ttsCreate .split_text (speech )
224338 voices ['main' ] = main_voice
225339
226- audio = f5ttsCreate .create (voices , chunks , seed , model )
340+ audio = f5ttsCreate .create (
341+ voices , chunks , seed , model , vocoder
342+ )
227343 finally :
228344 self .remove_wave_file ()
229345 return (audio , )
@@ -243,11 +359,6 @@ class F5TTSAudio:
243359 def __init__ (self ):
244360 self .use_cli = False
245361
246- @staticmethod
247- def get_txt_file_path (file ):
248- p = Path (file )
249- return os .path .join (os .path .dirname (file ), p .stem + ".txt" )
250-
251362 @classmethod
252363 def INPUT_TYPES (s ):
253364 input_dir = folder_paths .get_input_directory ()
@@ -256,11 +367,13 @@ def INPUT_TYPES(s):
256367 )
257368 filesWithTxt = []
258369 for file in files :
259- txtFile = F5TTSAudio .get_txt_file_path (file )
370+ txtFile = F5TTSCreate .get_txt_file_path (file )
260371 if os .path .isfile (os .path .join (input_dir , txtFile )):
261372 filesWithTxt .append (file )
262373 filesWithTxt = sorted (filesWithTxt )
263374
375+ model_types = F5TTSCreate .get_model_types ()
376+
264377 return {
265378 "required" : {
266379 "sample" : (filesWithTxt , {"audio_upload" : True }),
@@ -273,7 +386,8 @@ def INPUT_TYPES(s):
273386 "default" : 1 , "min" : - 1 ,
274387 "tooltip" : F5TTSCreate .tooltip_seed ,
275388 }),
276- "model" : (F5TTSCreate .model_types ,),
389+ "model" : (model_types ,),
390+ # "vocoder": (F5TTSCreate.vocoder_types,),
277391 }
278392 }
279393
@@ -304,12 +418,14 @@ def load_voice_from_file(self, sample):
304418 input_dir = folder_paths .get_input_directory ()
305419 txt_file = os .path .join (
306420 input_dir ,
307- F5TTSAudio .get_txt_file_path (sample )
421+ F5TTSCreate .get_txt_file_path (sample )
308422 )
309423 audio_text = ''
310- with open (txt_file , 'r' ) as file :
424+ with open (txt_file , 'r' , encoding = 'utf-8' ) as file :
311425 audio_text = file .read ()
312426 audio_path = folder_paths .get_annotated_filepath (sample )
427+ print ("audio_text" )
428+ print (audio_text )
313429 return F5TTSCreate .load_voice (audio_path , audio_text )
314430
315431 def load_voices_from_files (self , sample , voice_names ):
@@ -330,7 +446,8 @@ def load_voices_from_files(self, sample, voice_names):
330446 voices [voice_name ] = self .load_voice_from_file (sample_file )
331447 return voices
332448
333- def create (self , sample , speech , seed = - 1 , model = "F5" ):
449+ def create (self , sample , speech , seed = - 2 , model = "F5" ):
450+ vocoder = "vocos"
334451 # Install.check_install()
335452 main_voice = self .load_voice_from_file (sample )
336453
@@ -350,14 +467,14 @@ def create(self, sample, speech, seed=-1, model="F5"):
350467 voices = self .load_voices_from_files (sample , voice_names )
351468 voices ['main' ] = main_voice
352469
353- audio = f5ttsCreate .create (voices , chunks , seed , model )
470+ audio = f5ttsCreate .create (voices , chunks , seed , model , vocoder )
354471 return (audio , )
355472
356473 @classmethod
357474 def IS_CHANGED (s , sample , speech , seed , model ):
358475 m = hashlib .sha256 ()
359476 audio_path = folder_paths .get_annotated_filepath (sample )
360- audio_txt_path = F5TTSAudio .get_txt_file_path (audio_path )
477+ audio_txt_path = F5TTSCreate .get_txt_file_path (audio_path )
361478 last_modified_timestamp = os .path .getmtime (audio_path )
362479 txt_last_modified_timestamp = os .path .getmtime (audio_txt_path )
363480 m .update (audio_path )
0 commit comments