1717from comfy .utils import ProgressBar
1818from cached_path import cached_path
1919sys .path .append (Install .f5TTSPath )
20- from model import DiT # noqa E402
20+ from model import DiT , UNetT # noqa E402
2121from model .utils_infer import ( # noqa E402
2222 load_model ,
2323 preprocess_ref_audio_text ,
2828
2929class F5TTSCreate :
3030 voice_reg = re .compile (r"\{(\w+)\}" )
31+ model_types = ["F5" , "E2" ]
3132 tooltip_seed = "Seed. -1 = random"
3233
3334 def is_voice_name (self , word ):
@@ -54,7 +55,33 @@ def load_voice(ref_audio, ref_text):
5455 )
5556 return main_voice
5657
def load_model(self, model="F5"):
    """Load and return the TTS model selected by *model*.

    Args:
        model: Name of the model family; one of the keys of the
            dispatch table below ("F5" or "E2"). Defaults to "F5",
            the same default ``create`` uses, so callers written
            against the former zero-argument signature keep working.

    Returns:
        Whatever the selected ``load_*_model`` method returns.

    Raises:
        KeyError: If *model* is not one of the supported names.
    """
    loaders = {
        "F5": self.load_f5_model,
        "E2": self.load_e2_model,
    }
    # Only the lookup is guarded, so a KeyError raised while a loader
    # itself is running still propagates unchanged.
    try:
        loader = loaders[model]
    except KeyError:
        raise KeyError(
            f"Unknown F5TTS model type {model!r}; "
            f"expected one of {sorted(loaders)}"
        ) from None
    return loader()
64+
def get_vocab_file(self):
    """Return the path of the vocabulary file below ``Install.f5TTSPath``."""
    vocab_relpath = "data/Emilia_ZH_EN_pinyin/vocab.txt"
    return os.path.join(Install.f5TTSPath, vocab_relpath)
69+
def load_e2_model(self):
    """Build the E2-TTS model from its published checkpoint.

    The checkpoint location is resolved through ``cached_path`` and
    handed, together with the vocabulary file, to the module-level
    ``load_model`` helper imported from ``model.utils_infer`` (not to
    the ``load_model`` method of this class).

    Returns:
        The object returned by ``load_model`` for the ``UNetT``
        architecture.
    """
    model_cls = UNetT
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
    repo_name = "E2-TTS"
    exp_name = "E2TTS_Base"
    ckpt_step = 1200000
    # The path must not contain whitespace between its components;
    # stray spaces here would yield an unusable hf:// location.
    ckpt_url = (
        f"hf://SWivid/{repo_name}/{exp_name}"
        f"/model_{ckpt_step}.safetensors"
    )
    ckpt_file = str(cached_path(ckpt_url))
    vocab_file = self.get_vocab_file()
    ema_model = load_model(
        model_cls, model_cfg,
        ckpt_file, vocab_file
    )
    return ema_model
83+
84+ def load_f5_model (self ):
5885 model_cls = DiT
5986 model_cfg = dict (
6087 dim = 1024 , depth = 22 , heads = 16 ,
@@ -64,10 +91,11 @@ def load_model(self):
6491 exp_name = "F5TTS_Base"
6592 ckpt_step = 1200000
6693 ckpt_file = str (cached_path (f"hf://SWivid/{ repo_name } /{ exp_name } /model_{ ckpt_step } .safetensors" )) # noqa E501
67- vocab_file = os .path .join (
68- Install .f5TTSPath , "data/Emilia_ZH_EN_pinyin/vocab.txt"
94+ vocab_file = self .get_vocab_file ()
95+ ema_model = load_model (
96+ model_cls , model_cfg ,
97+ ckpt_file , vocab_file
6998 )
70- ema_model = load_model (model_cls , model_cfg , ckpt_file , vocab_file )
7199 return ema_model
72100
73101 def generate_audio (self , voices , model_obj , chunks , seed ):
@@ -117,8 +145,8 @@ def generate_audio(self, voices, model_obj, chunks, seed):
117145 os .unlink (wave_file .name )
118146 return audio
119147
def create(self, voices, chunks, seed=-1, model="F5"):
    """Load the requested model and generate audio with it.

    Args:
        voices: Passed through unchanged to ``generate_audio``.
        chunks: Passed through unchanged to ``generate_audio``.
        seed: Passed through to ``generate_audio``; per the class
            tooltip, -1 means a random seed.
        model: Model name forwarded to ``self.load_model``;
            defaults to "F5".

    Returns:
        The value returned by ``generate_audio``.
    """
    model_obj = self.load_model(model)
    return self.generate_audio(voices, model_obj, chunks, seed)
123151
124152
@@ -141,6 +169,7 @@ def INPUT_TYPES(s):
141169 "default" : 1 , "min" : - 1 ,
142170 "tooltip" : F5TTSCreate .tooltip_seed ,
143171 }),
172+ "model" : (F5TTSCreate .model_types ,),
144173 },
145174 }
146175
@@ -174,7 +203,7 @@ def remove_wave_file(self):
174203 print ("F5TTS: Cannot remove? " + self .wave_file .name )
175204 print (e )
176205
177- def create (self , sample_audio , sample_text , speech , seed = - 1 ):
206+ def create (self , sample_audio , sample_text , speech , seed = - 1 , model = "F5" ):
178207 try :
179208 main_voice = self .load_voice_from_input (sample_audio , sample_text )
180209
@@ -184,7 +213,7 @@ def create(self, sample_audio, sample_text, speech, seed=-1):
184213 chunks = f5ttsCreate .split_text (speech )
185214 voices ['main' ] = main_voice
186215
187- audio = f5ttsCreate .create (voices , chunks , seed )
216+ audio = f5ttsCreate .create (voices , chunks , seed , model )
188217 finally :
189218 self .remove_wave_file ()
190219 return (audio , )
@@ -233,6 +262,7 @@ def INPUT_TYPES(s):
233262 "default" : 1 , "min" : - 1 ,
234263 "tooltip" : F5TTSCreate .tooltip_seed ,
235264 }),
265+ "model" : (F5TTSCreate .model_types ,),
236266 }
237267 }
238268
@@ -289,7 +319,7 @@ def load_voices_from_files(self, sample, voice_names):
289319 voices [voice_name ] = self .load_voice_from_file (sample_file )
290320 return voices
291321
292- def create (self , sample , speech , seed = - 1 ):
322+ def create (self , sample , speech , seed = - 1 , model = "F5" ):
293323 # Install.check_install()
294324 main_voice = self .load_voice_from_file (sample )
295325
@@ -309,7 +339,7 @@ def create(self, sample, speech, seed=-1):
309339 voices = self .load_voices_from_files (sample , voice_names )
310340 voices ['main' ] = main_voice
311341
312- audio = f5ttsCreate .create (voices , chunks , seed )
342+ audio = f5ttsCreate .create (voices , chunks , seed , model )
313343 return (audio , )
314344
315345 @classmethod
0 commit comments