Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 54 additions & 39 deletions nodes/tools/installers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""
This module provides an advanced utility node for installing the Nunchaku Python wheel.
It dynamically fetches available versions from GitHub and Hugging Face, mirrors the data
for ModelScope, allows the user to select an installer backend (pip or uv), and
automatically finds the most compatible wheel. The installation status is displayed
directly on the node UI.
It dynamically fetches available versions from GitHub, Hugging Face, and ModelScope,
allows the user to select an installer backend (pip or uv), and automatically finds
the most compatible wheel. The installation status is displayed directly on the node UI.
"""

import copy # # NEW: Added for deepcopying release data.
import importlib.metadata
import json
import platform
Expand All @@ -20,9 +18,14 @@
from packaging.version import parse as parse_version

# --- Helper Functions ---
## CHANGED: Defined separate URLs for each source.

# CHANGE: Defined separate, direct API URLs for each source.
GITHUB_API_URL = "https://api.github.com/repos/nunchaku-tech/nunchaku"
HF_API_URL = "https://huggingface.co/api/models/mit-han-lab/nunchaku/tree/main"
# CHANGE: Added the direct ModelScope API URL, replacing the previous mirror logic.
MODEL_SCOPE_API_URL = (
"https://modelscope.cn/api/v1/models/nunchaku-tech/nunchaku/repo/files?Revision=master&PageSize=500"
)


def is_nunchaku_installed() -> bool:
Expand All @@ -35,18 +38,19 @@ def is_nunchaku_installed() -> bool:

def _get_json_from_url(url: str) -> List[Dict] | Dict:
try:
## CHANGED: Added a User-Agent header for API requests.
headers = {"User-Agent": "ComfyUI-Nunchaku-InstallerNode"}
req = urllib.request.Request(url, headers=headers)
with urllib.request.urlopen(req) as response:
# Added a timeout for network robustness.
with urllib.request.urlopen(req, timeout=10) as response:
if response.status == 200:
return json.loads(response.read())
print(f"Warning: Received status code {response.status} from {url}")
return []
except Exception:
except Exception as e:
print(f"Error fetching data from {url}: {e}")
return []


## NEW: Function dedicated to fetching from GitHub.
def get_nunchaku_releases_from_github() -> List[Dict]:
releases = _get_json_from_url(f"{GITHUB_API_URL}/releases")
if isinstance(releases, list):
Expand All @@ -56,10 +60,14 @@ def get_nunchaku_releases_from_github() -> List[Dict]:
return []


## NEW: Helper to parse wheel files from sources like Hugging Face.
# CHANGE: Uses a more robust regex to extract the version.
def _parse_wheels_from_file_list(file_list: List[Dict], source_name: str, path_key: str, url_prefix: str) -> List[Dict]:
releases = {}
wheel_regex = re.compile(r"nunchaku-(.+?)\+")

# This new regex `nunchaku-([^-+]+)` stops at the first `-` or `+`,
# making it more reliable for various wheel filename formats.
wheel_regex = re.compile(r"nunchaku-([^-+]+)")

for file_info in file_list:
filename = file_info.get(path_key)
if filename and filename.endswith(".whl"):
Expand All @@ -77,10 +85,10 @@ def _parse_wheels_from_file_list(file_list: List[Dict], source_name: str, path_k
releases[tag_name]["assets"].append(
{"name": filename, "browser_download_url": f"{url_prefix}{filename}"}
)

return list(releases.values())


## NEW: Function dedicated to fetching from Hugging Face.
def get_nunchaku_releases_from_huggingface() -> List[Dict]:
api_response = _get_json_from_url(HF_API_URL)
if not isinstance(api_response, list):
Expand All @@ -90,30 +98,49 @@ def get_nunchaku_releases_from_huggingface() -> List[Dict]:
)


## NEW: Major function to fetch from all sources at once and mirror HF to ModelScope.
# NEW: Function to fetch releases directly from the ModelScope API.
def get_nunchaku_releases_from_modelscope() -> List[Dict]:
api_response = _get_json_from_url(MODEL_SCOPE_API_URL)

# Navigate the specific ModelScope API response structure: Data -> Files
if isinstance(api_response, dict):
inner_data = api_response.get("Data", {})
file_list = inner_data.get("Files") if isinstance(inner_data, dict) else None
else:
file_list = None

if not isinstance(file_list, list):
return []

# Use the "Name" key (capitalized) to get the filename.
return _parse_wheels_from_file_list(
file_list,
"modelscope",
"Name",
"https://modelscope.cn/models/nunchaku-tech/nunchaku/resolve/master/",
)


def fetch_and_structure_all_releases() -> Dict[str, Dict[str, Dict]]:
structured_releases = {"github": {}, "huggingface": {}, "modelscope": {}}
source_map = {"github": get_nunchaku_releases_from_github, "huggingface": get_nunchaku_releases_from_huggingface}

source_map = {
"github": get_nunchaku_releases_from_github,
"huggingface": get_nunchaku_releases_from_huggingface,
"modelscope": get_nunchaku_releases_from_modelscope, # Calls the new ModelScope function.
}

for source_name, fetch_func in source_map.items():
for release in fetch_func():
if tag := release.get("tag_name"):
structured_releases[source_name][tag] = release

if structured_releases["huggingface"]:
structured_releases["modelscope"] = copy.deepcopy(structured_releases["huggingface"])
modelscope_prefix = "https://modelscope.cn/models/nunchaku-tech/nunchaku/resolve/master/"
for release_data in structured_releases["modelscope"].values():
release_data["source"] = "modelscope"
for asset in release_data.get("assets", []):
asset["browser_download_url"] = f"{modelscope_prefix}{asset['name']}"

if not any(structured_releases.values()):
return {"github": {"latest": {"tag_name": "latest"}}}

return structured_releases


## NEW: Separates official and dev versions for the UI.
def prepare_version_lists(structured_data: Dict[str, Dict[str, Dict]]) -> Tuple[List[str], List[str]]:
official_tags, dev_tags = set(), set()
for source_data in structured_data.values():
Expand Down Expand Up @@ -175,7 +202,6 @@ def find_compatible_wheel(assets: List[Dict], sys_info: Dict[str, str]) -> Optio


def install_wheel(wheel_url: str, backend: str) -> str:
## CHANGED: Corrected installation command logic with an explicit if/else.
if backend == "uv":
command = [sys.executable, "-m", "uv", "pip", "install", wheel_url]
else: # Default to pip
Expand All @@ -187,7 +213,7 @@ def install_wheel(wheel_url: str, backend: str) -> str:
)
output_log = []
for line in iter(process.stdout.readline, ""):
print(line, end="") # This prints the live output from the installer
print(line, end="")
output_log.append(line)
process.wait()
full_log = "".join(output_log)
Expand All @@ -202,7 +228,7 @@ def install_wheel(wheel_url: str, backend: str) -> str:

# --- ComfyUI Node Definition ---

## NEW: Pre-fetch all release data on startup to improve performance.
# Pre-fetch all release data on startup to improve performance.
ALL_RELEASES_DATA = fetch_and_structure_all_releases()
OFFICIAL_VERSIONS, DEV_VERSIONS = prepare_version_lists(ALL_RELEASES_DATA)
DEV_CHOICES = ["None"] + DEV_VERSIONS
Expand All @@ -222,7 +248,6 @@ def IS_CHANGED(cls, **kwargs):

@classmethod
def INPUT_TYPES(cls):
## CHANGED: Inputs now include a source selector and a separate dropdown for dev versions.
return {
"required": {
"source": (["github", "huggingface", "modelscope"], {}),
Expand All @@ -237,12 +262,10 @@ def INPUT_TYPES(cls):

def run(self, source: str, version: str, dev_version_github: str, backend: str):
try:
# MODIFICATION START: Automatic uninstall if nunchaku is detected
# CHANGE: Added automatic uninstallation of any pre-existing nunchaku version.
if is_nunchaku_installed():
print("An existing version of Nunchaku was detected. Attempting to uninstall automatically...")
# Command to uninstall without user confirmation (-y)
uninstall_command = [sys.executable, "-m", "pip", "uninstall", "nunchaku", "-y"]

process = subprocess.Popen(
uninstall_command,
stdout=subprocess.PIPE,
Expand All @@ -251,28 +274,22 @@ def run(self, source: str, version: str, dev_version_github: str, backend: str):
encoding="utf-8",
errors="replace",
)

# Capture and print output for logging
output_log = []
for line in iter(process.stdout.readline, ""):
print(line, end="")
output_log.append(line)
process.wait()

if process.returncode != 0:
full_log = "".join(output_log)
raise subprocess.CalledProcessError(process.returncode, uninstall_command, output=full_log)

# If uninstall is successful, inform the user and stop execution.
status_message = (
"✅ An existing version of Nunchaku was detected and uninstalled.\n\n"
"**Please restart ComfyUI completely.**\n\n"
"Then, run this node again to install the desired version."
)
return (status_message,)
# MODIFICATION END

## NEW: Logic to prioritize dev version selection.
if dev_version_github != "None":
final_version_tag = f"v{dev_version_github}"
source = "github"
Expand All @@ -285,15 +302,13 @@ def run(self, source: str, version: str, dev_version_github: str, backend: str):

source_versions = ALL_RELEASES_DATA.get(source, {})

## CHANGED: 'latest' is now resolved dynamically based on the selected source.
if final_version_tag == "latest":
official_tags = [v.lstrip("v") for v in source_versions.keys() if "dev" not in v]
if not official_tags:
raise RuntimeError(f"No official versions found on source '{source}'.")
final_version_tag = f"v{sorted(official_tags, key=parse_version, reverse=True)[0]}"

release_data = source_versions.get(final_version_tag)
## NEW: Better error handling if a version isn't on the selected source.
if not release_data:
available_on = [s for s, data in ALL_RELEASES_DATA.items() if final_version_tag in data]
msg = f"Version '{final_version_tag}' not available from '{source}'."
Expand Down
Loading