@@ -132,6 +132,9 @@ def __init__(
132
132
self .onnx_execution_provider = None
133
133
self .model_instance = None
134
134
135
+ self .model_is_uvr_vip = False
136
+ self .model_friendly_name = None
137
+
135
138
self .setup_accelerated_inferencing_device ()
136
139
137
140
def setup_accelerated_inferencing_device (self ):
@@ -347,35 +350,48 @@ def list_supported_model_files(self):
347
350
# Return object with list of model names, which are the keys in vr_download_list, mdx_download_list, demucs_download_list, mdx23_download_list, mdx23c_download_list, grouped by type: VR, MDX, Demucs, MDX23, MDX23C
348
351
model_files_grouped_by_type = {
349
352
"VR" : model_downloads_list ["vr_download_list" ],
350
- "MDX" : model_downloads_list ["mdx_download_list" ],
353
+ "MDX" : { ** model_downloads_list ["mdx_download_list" ], ** model_downloads_list [ "mdx_download_vip_list" ]} ,
351
354
"Demucs" : filtered_demucs_v4 ,
352
- "MDXC" : model_downloads_list ["mdx23c_download_list" ],
353
- # "MDX23": model_downloads_list["mdx23_download_list"],
355
+ "MDXC" : {** model_downloads_list ["mdx23c_download_list" ], ** model_downloads_list ["mdx23c_download_vip_list" ]},
354
356
}
355
357
return model_files_grouped_by_type
356
358
359
+ def print_uvr_vip_message (self ):
360
+ """
361
+ This method prints a message to the user if they have downloaded a VIP model, reminding them to support Anjok07 on Patreon.
362
+ """
363
+ if self .model_is_uvr_vip :
364
+ self .logger .warning (f"The model: '{ self .model_friendly_name } ' is a VIP model, intended by Anjok07 for access by paying subscribers only." )
365
+ self .logger .warning ("If you are not already subscribed, please consider supporting the developer of UVR, Anjok07 by subscribing here: https://patreon.com/uvr" )
366
+
357
367
def download_model_files (self , model_filename ):
358
368
"""
359
369
This method downloads the model files for a given model filename, if they are not already present.
360
370
"""
361
371
model_path = os .path .join (self .model_file_dir , f"{ model_filename } " )
362
372
363
373
supported_model_files_grouped = self .list_supported_model_files ()
364
- model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
374
+ public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
375
+ vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
365
376
366
377
yaml_config_filename = None
367
378
368
379
self .logger .debug (f"Searching for model_filename { model_filename } in supported_model_files_grouped" )
369
380
for model_type , model_list in supported_model_files_grouped .items ():
370
381
for model_friendly_name , model_download_list in model_list .items ():
382
+ self .model_is_uvr_vip = "VIP" in model_friendly_name
383
+ model_repo_url_prefix = vip_model_repo_url_prefix if self .model_is_uvr_vip else public_model_repo_url_prefix
384
+
371
385
# If model_download_list is a string, this model only requires a single file so we can just download it
372
386
if isinstance (model_download_list , str ) and model_download_list == model_filename :
373
387
self .logger .debug (f"Single file model identified: { model_friendly_name } " )
388
+ self .model_friendly_name = model_friendly_name
374
389
375
390
self .download_file_if_not_exists (f"{ model_repo_url_prefix } /{ model_filename } " , model_path )
391
+ self .print_uvr_vip_message ()
376
392
377
393
self .logger .debug (f"Returning path for single model file: { model_path } " )
378
- return model_type , model_friendly_name , model_path , yaml_config_filename
394
+ return model_filename , model_type , model_friendly_name , model_path , yaml_config_filename
379
395
380
396
# If it's a dict, iterate through each entry check if any of them match model_filename
381
397
# If the value is a full URL, download it from that URL.
@@ -389,6 +405,8 @@ def download_model_files(self, model_filename):
389
405
390
406
if this_model_matches_input_filename :
391
407
self .logger .debug (f"Multi-file model identified: { model_friendly_name } , iterating through files to download" )
408
+ self .model_friendly_name = model_friendly_name
409
+ self .print_uvr_vip_message ()
392
410
393
411
for config_key , config_value in model_download_list .items ():
394
412
self .logger .debug (f"Attempting to identify download URL for config pair: { config_key } -> { config_value } " )
@@ -403,6 +421,14 @@ def download_model_files(self, model_filename):
403
421
download_url = f"{ model_repo_url_prefix } /{ config_key } "
404
422
self .download_file_if_not_exists (download_url , os .path .join (self .model_file_dir , config_key ))
405
423
424
+ # In case the user specified the YAML filename as the model input instead of the model filename, correct that
425
+ if model_filename .endswith (".yaml" ):
426
+ self .logger .warning (f"The model name you've specified, { model_filename } is actually a model config file, not a model file itself." )
427
+ self .logger .warning (f"We found a model matching this config file: { config_key } so we'll use that model file for this run." )
428
+ self .logger .warning ("To prevent confusing / inconsistent behaviour in future, specify an actual model filename instead." )
429
+ model_filename = config_key
430
+ model_path = os .path .join (self .model_file_dir , f"{ model_filename } " )
431
+
406
432
# For MDXC models, the config_value is the YAML file which needs to be downloaded separately from the application_data repo
407
433
yaml_config_filename = config_value
408
434
yaml_config_filepath = os .path .join (self .model_file_dir , yaml_config_filename )
@@ -419,7 +445,7 @@ def download_model_files(self, model_filename):
419
445
self .download_file_if_not_exists (download_url , os .path .join (self .model_file_dir , config_value ))
420
446
421
447
self .logger .debug (f"All files downloaded for model { model_friendly_name } , returning initial path { model_path } " )
422
- return model_type , model_friendly_name , model_path , yaml_config_filename
448
+ return model_filename , model_type , model_friendly_name , model_path , yaml_config_filename
423
449
424
450
raise ValueError (f"Model file { model_filename } not found in supported model files" )
425
451
@@ -562,8 +588,8 @@ def load_model(self, model_filename="UVR-MDX-NET-Inst_HQ_3.onnx"):
562
588
load_model_start_time = time .perf_counter ()
563
589
564
590
# Setting up the model path
591
+ model_filename , model_type , model_friendly_name , model_path , yaml_config_filename = self .download_model_files (model_filename )
565
592
model_name = model_filename .split ("." )[0 ]
566
- model_type , model_friendly_name , model_path , yaml_config_filename = self .download_model_files (model_filename )
567
593
self .logger .debug (f"Model downloaded, friendly name: { model_friendly_name } , model_path: { model_path } " )
568
594
569
595
if model_path .lower ().endswith (".yaml" ):
@@ -639,6 +665,9 @@ def separate(self, audio_file_path):
639
665
# Unset more separation params to prevent accidentally re-using the wrong source files or output paths
640
666
self .model_instance .clear_file_specific_paths ()
641
667
668
+ # Remind the user one more time if they used a VIP model, so the message doesn't get lost in the logs
669
+ self .print_uvr_vip_message ()
670
+
642
671
# Log the completion of the separation process
643
672
self .logger .debug ("Separation process completed." )
644
673
self .logger .info (f'Separation duration: { time .strftime ("%H:%M:%S" , time .gmtime (int (time .perf_counter () - separate_start_time )))} ' )
0 commit comments