Skip to content

Commit 7c3a819

Browse files
committed
Added support for VIP models
1 parent 3c23878 commit 7c3a819

File tree

3 files changed

+39
-10
lines changed

3 files changed

+39
-10
lines changed

audio_separator/separator/architectures/mdxc_separator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def load_model(self):
6767

6868
try:
6969
self.model_run = TFC_TDF_net(self.model_data_cfgdict, device=self.torch_device)
70-
self.model_run.load_state_dict(torch.load(self.model_path, map_location="cpu"))
70+
self.model_run.load_state_dict(torch.load(self.model_path, map_location=self.torch_device))
7171
self.model_run.to(self.torch_device).eval()
7272
except RuntimeError as e:
7373
self.logger.error(f"Error: {e}")
@@ -121,7 +121,7 @@ def separate(self, audio_file_path):
121121

122122
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
123123
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
124-
124+
125125
if not isinstance(self.primary_source, np.ndarray):
126126
self.primary_source = source.T
127127

audio_separator/separator/separator.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def __init__(
132132
self.onnx_execution_provider = None
133133
self.model_instance = None
134134

135+
self.model_is_uvr_vip = False
136+
self.model_friendly_name = None
137+
135138
self.setup_accelerated_inferencing_device()
136139

137140
def setup_accelerated_inferencing_device(self):
@@ -347,35 +350,48 @@ def list_supported_model_files(self):
347350
# Return object with list of model names, which are the keys in vr_download_list, mdx_download_list, demucs_download_list, mdx23_download_list, mdx23c_download_list, grouped by type: VR, MDX, Demucs, MDX23, MDX23C
348351
model_files_grouped_by_type = {
349352
"VR": model_downloads_list["vr_download_list"],
350-
"MDX": model_downloads_list["mdx_download_list"],
353+
"MDX": {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"]},
351354
"Demucs": filtered_demucs_v4,
352-
"MDXC": model_downloads_list["mdx23c_download_list"],
353-
# "MDX23": model_downloads_list["mdx23_download_list"],
355+
"MDXC": {**model_downloads_list["mdx23c_download_list"], **model_downloads_list["mdx23c_download_vip_list"]},
354356
}
355357
return model_files_grouped_by_type
356358

359+
def print_uvr_vip_message(self):
360+
"""
361+
This method prints a message to the user if they have downloaded a VIP model, reminding them to support Anjok07 on Patreon.
362+
"""
363+
if self.model_is_uvr_vip:
364+
self.logger.warning(f"The model: '{self.model_friendly_name}' is a VIP model, intended by Anjok07 for access by paying subscribers only.")
365+
self.logger.warning("If you are not already subscribed, please consider supporting the developer of UVR, Anjok07 by subscribing here: https://patreon.com/uvr")
366+
357367
def download_model_files(self, model_filename):
358368
"""
359369
This method downloads the model files for a given model filename, if they are not already present.
360370
"""
361371
model_path = os.path.join(self.model_file_dir, f"{model_filename}")
362372

363373
supported_model_files_grouped = self.list_supported_model_files()
364-
model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
374+
public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
375+
vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
365376

366377
yaml_config_filename = None
367378

368379
self.logger.debug(f"Searching for model_filename {model_filename} in supported_model_files_grouped")
369380
for model_type, model_list in supported_model_files_grouped.items():
370381
for model_friendly_name, model_download_list in model_list.items():
382+
self.model_is_uvr_vip = "VIP" in model_friendly_name
383+
model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix
384+
371385
# If model_download_list is a string, this model only requires a single file so we can just download it
372386
if isinstance(model_download_list, str) and model_download_list == model_filename:
373387
self.logger.debug(f"Single file model identified: {model_friendly_name}")
388+
self.model_friendly_name = model_friendly_name
374389

375390
self.download_file_if_not_exists(f"{model_repo_url_prefix}/{model_filename}", model_path)
391+
self.print_uvr_vip_message()
376392

377393
self.logger.debug(f"Returning path for single model file: {model_path}")
378-
return model_type, model_friendly_name, model_path, yaml_config_filename
394+
return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
379395

380396
# If it's a dict, iterate through each entry check if any of them match model_filename
381397
# If the value is a full URL, download it from that URL.
@@ -389,6 +405,8 @@ def download_model_files(self, model_filename):
389405

390406
if this_model_matches_input_filename:
391407
self.logger.debug(f"Multi-file model identified: {model_friendly_name}, iterating through files to download")
408+
self.model_friendly_name = model_friendly_name
409+
self.print_uvr_vip_message()
392410

393411
for config_key, config_value in model_download_list.items():
394412
self.logger.debug(f"Attempting to identify download URL for config pair: {config_key} -> {config_value}")
@@ -403,6 +421,14 @@ def download_model_files(self, model_filename):
403421
download_url = f"{model_repo_url_prefix}/{config_key}"
404422
self.download_file_if_not_exists(download_url, os.path.join(self.model_file_dir, config_key))
405423

424+
# In case the user specified the YAML filename as the model input instead of the model filename, correct that
425+
if model_filename.endswith(".yaml"):
426+
self.logger.warning(f"The model name you've specified, {model_filename} is actually a model config file, not a model file itself.")
427+
self.logger.warning(f"We found a model matching this config file: {config_key} so we'll use that model file for this run.")
428+
self.logger.warning("To prevent confusing / inconsistent behaviour in future, specify an actual model filename instead.")
429+
model_filename = config_key
430+
model_path = os.path.join(self.model_file_dir, f"{model_filename}")
431+
406432
# For MDXC models, the config_value is the YAML file which needs to be downloaded separately from the application_data repo
407433
yaml_config_filename = config_value
408434
yaml_config_filepath = os.path.join(self.model_file_dir, yaml_config_filename)
@@ -419,7 +445,7 @@ def download_model_files(self, model_filename):
419445
self.download_file_if_not_exists(download_url, os.path.join(self.model_file_dir, config_value))
420446

421447
self.logger.debug(f"All files downloaded for model {model_friendly_name}, returning initial path {model_path}")
422-
return model_type, model_friendly_name, model_path, yaml_config_filename
448+
return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
423449

424450
raise ValueError(f"Model file {model_filename} not found in supported model files")
425451

@@ -562,8 +588,8 @@ def load_model(self, model_filename="UVR-MDX-NET-Inst_HQ_3.onnx"):
562588
load_model_start_time = time.perf_counter()
563589

564590
# Setting up the model path
591+
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
565592
model_name = model_filename.split(".")[0]
566-
model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
567593
self.logger.debug(f"Model downloaded, friendly name: {model_friendly_name}, model_path: {model_path}")
568594

569595
if model_path.lower().endswith(".yaml"):
@@ -639,6 +665,9 @@ def separate(self, audio_file_path):
639665
# Unset more separation params to prevent accidentally re-using the wrong source files or output paths
640666
self.model_instance.clear_file_specific_paths()
641667

668+
# Remind the user one more time if they used a VIP model, so the message doesn't get lost in the logs
669+
self.print_uvr_vip_message()
670+
642671
# Log the completion of the separation process
643672
self.logger.debug("Separation process completed.")
644673
self.logger.info(f'Separation duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - separate_start_time)))}')

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
44

55
[tool.poetry]
66
name = "audio-separator"
7-
version = "0.16.1"
7+
version = "0.16.2"
88
description = "Easy to use audio stem separation, using various models from UVR trained primarily by @Anjok07"
99
authors = ["Andrew Beveridge <[email protected]>"]
1010
license = "MIT"

0 commit comments

Comments
 (0)