1717import time
1818import typing as tp
1919import warnings
20+ from tqdm import tqdm
2021from audiocraft .models import MusicGen
2122from audiocraft .data .audio import audio_write
2223from audiocraft .data .audio_utils import apply_fade , apply_tafade , apply_splice_effect
4849os .environ ['USE_FLASH_ATTENTION' ] = '1'
4950os .environ ['XFORMERS_FORCE_DISABLE_TRITON' ]= '1'
5051
52+
5153def interrupt_callback ():
5254 return INTERRUPTED
5355
@@ -162,7 +164,7 @@ def load_melody_filepath(melody_filepath, title, assigned_model):
162164
163165 return gr .update (value = melody_name ), gr .update (maximum = MAX_PROMPT_INDEX , value = 0 ), gr .update (value = assigned_model , interactive = True )
164166
165- def predict (model , text , melody_filepath , duration , dimension , topk , topp , temperature , cfg_coef , background , title , settings_font , settings_font_color , seed , overlap = 1 , prompt_index = 0 , include_title = True , include_settings = True , harmony_only = False ):
167+ def predict (model , text , melody_filepath , duration , dimension , topk , topp , temperature , cfg_coef , background , title , settings_font , settings_font_color , seed , overlap = 1 , prompt_index = 0 , include_title = True , include_settings = True , harmony_only = False , profile = gr . OAuthProfile , progress = gr . Progress ( track_tqdm = True ) ):
166168 global MODEL , INTERRUPTED , INTERRUPTING , MOVE_TO_CPU
167169 output_segments = None
168170 melody_name = "Not Used"
@@ -228,14 +230,16 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
228230 cfg_coef = cfg_coef ,
229231 duration = segment_duration ,
230232 two_step_cfg = False ,
233+ extend_stride = 10 ,
231234 rep_penalty = 0.5
232235 )
236+ MODEL .set_custom_progress_callback (gr .Progress (track_tqdm = True ))
233237
234238 try :
235239 if melody :
236240 # return excess duration, load next model and continue in loop structure building up output_segments
237241 if duration > MODEL .lm .cfg .dataset .segment_duration :
238- output_segments , duration = generate_music_segments (text , melody , seed , MODEL , duration , overlap , MODEL .lm .cfg .dataset .segment_duration , prompt_index , harmony_only = False )
242+ output_segments , duration = generate_music_segments (text , melody , seed , MODEL , duration , overlap , MODEL .lm .cfg .dataset .segment_duration , prompt_index , harmony_only = False , progress = gr . Progress ( track_tqdm = True ) )
239243 else :
240244 # pure original code
241245 sr , melody = melody [0 ], torch .from_numpy (melody [1 ]).to (MODEL .device ).float ().t ().unsqueeze (0 )
@@ -247,20 +251,20 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
247251 descriptions = [text ],
248252 melody_wavs = melody ,
249253 melody_sample_rate = sr ,
250- progress = True
254+ progress = True , progress_callback = gr . Progress ( track_tqdm = True )
251255 )
252256 # All output_segments are populated, so we can break the loop or set duration to 0
253257 break
254258 else :
255259 #output = MODEL.generate(descriptions=[text], progress=False)
256260 if not output_segments :
257- next_segment = MODEL .generate (descriptions = [text ], progress = True )
261+ next_segment = MODEL .generate (descriptions = [text ], progress = True , progress_callback = gr . Progress ( track_tqdm = True ) )
258262 duration -= segment_duration
259263 else :
260264 last_chunk = output_segments [- 1 ][:, :, - overlap * MODEL .sample_rate :]
261- next_segment = MODEL .generate_continuation (last_chunk , MODEL .sample_rate , descriptions = [text ], progress = True )
265+ next_segment = MODEL .generate_continuation (last_chunk , MODEL .sample_rate , descriptions = [text ], progress = True , progress_callback = gr . Progress ( track_tqdm = True ) )
262266 duration -= segment_duration - overlap
263- if next_segment != None :
267+ if next_segment != None :
264268 output_segments .append (next_segment )
265269 except Exception as e :
266270 print (f"Error generating audio: { e } " )
@@ -312,7 +316,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
312316 return None , None , seed
313317 else :
314318 output = output .detach ().cpu ().float ()[0 ]
315- profile : gr . OAuthProfile | None = None
319+
316320 title_file_name = convert_title_to_filename (title )
317321 with NamedTemporaryFile ("wb" , suffix = ".wav" , delete = False , prefix = title_file_name ) as file :
318322 video_description = f"{ text } \n Duration: { str (initial_duration )} Dimension: { dimension } \n Top-k:{ topk } Top-p:{ topp } \n Randomness:{ temperature } \n cfg:{ cfg_coef } overlap: { overlap } \n Seed: { seed } \n Model: { model } \n Melody Condition:{ melody_name } \n Sample Segment: { prompt_index } "
@@ -357,7 +361,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
357361 "background" : background ,
358362 "include_title" : include_title ,
359363 "include_settings" : include_settings ,
360- "profile" : profile ,
364+ "profile" : "Satoshi Nakamoto" if profile . value is None else profile . value . username ,
361365 "commit" : commit_hash (),
362366 "tag" : git_tag (),
363367 "version" : gr .__version__ ,
@@ -396,11 +400,11 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
396400
397401 if waveform_video_path :
398402 modules .user_history .save_file (
399- profile = profile ,
403+ profile = profile . value ,
400404 image = background ,
401- audio = file ,
405+ audio = file . name ,
402406 video = waveform_video_path ,
403- label = text ,
407+ label = title ,
404408 metadata = metadata ,
405409 )
406410
@@ -413,9 +417,9 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
413417 torch .cuda .ipc_collect ()
414418 return waveform_video_path , file .name , seed
415419
416- gr .set_static_paths (paths = ["fonts/" ,"assets/" ])
420+ gr .set_static_paths (paths = ["fonts/" ,"assets/" , "images/" ])
417421def ui (** kwargs ):
418- with gr .Blocks (title = "UnlimitedMusicGen" ,css_paths = "style_20250331.css" , theme = 'Surn/beeuty' ) as interface :
422+ with gr .Blocks (title = "UnlimitedMusicGen" ,css_paths = "style_20250331.css" , theme = 'Surn/beeuty' ) as demo :
419423 with gr .Tab ("UnlimitedMusicGen" ):
420424 gr .Markdown (
421425 """
@@ -482,12 +486,12 @@ def ui(**kwargs):
482486 with gr .Column () as c :
483487 output = gr .Video (label = "Generated Music" )
484488 wave_file = gr .File (label = ".wav file" , elem_id = "output_wavefile" , interactive = True )
485- seed_used = gr .Number (label = 'Seed used' , value = - 1 , interactive = False )
489+ seed_used = gr .Number (label = 'Seed used' , value = - 1 , interactive = False )
486490
487491 radio .change (toggle_audio_src , radio , [melody_filepath ], queue = False , show_progress = False )
488492 melody_filepath .change (load_melody_filepath , inputs = [melody_filepath , title , model ], outputs = [title , prompt_index , model ], api_name = "melody_filepath_change" , queue = False )
489493 reuse_seed .click (fn = lambda x : x , inputs = [seed_used ], outputs = [seed ], queue = False , api_name = "reuse_seed" )
490- submit . click ( predict , inputs = [ model , text , melody_filepath , duration , dimension , topk , topp , temperature , cfg_coef , background , title , settings_font , settings_font_color , seed , overlap , prompt_index , include_title , include_settings , harmony_only ], outputs = [ output , wave_file , seed_used ], api_name = "submit" )
494+
491495 gr .Examples (
492496 examples = [
493497 [
@@ -524,9 +528,24 @@ def ui(**kwargs):
524528 inputs = [text , melody_filepath , model , title ],
525529 outputs = [output ]
526530 )
527- gr . HTML ( value = versions_html (), visible = True , elem_id = "versions" )
531+
528532 with gr .Tab ("User History" ) as history_tab :
529533 modules .user_history .render ()
534+ user_profile = gr .State (None )
535+
536+ with gr .Row ("Versions" ) as versions_row :
537+ gr .HTML (value = versions_html (), visible = True , elem_id = "versions" )
538+
539+ submit .click (
540+ modules .user_history .get_profile ,
541+ inputs = [],
542+ outputs = [user_profile ],
543+ queue = True ,
544+ api_name = "submit"
545+ ).then (
546+ predict ,
547+ inputs = [model , text ,melody_filepath , duration , dimension , topk , topp , temperature , cfg_coef , background , title , settings_font , settings_font_color , seed , overlap , prompt_index , include_title , include_settings , harmony_only , user_profile ],
548+ outputs = [output , wave_file , seed_used ])
530549
531550 # Show the interface
532551 launch_kwargs = {}
0 commit comments