diff --git a/src/constants.ts b/src/constants.ts index 4a998ec69..98fdbd9f1 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -181,6 +181,14 @@ export enum TorchMirrorUrl { NightlyCpu = 'https://download.pytorch.org/whl/nightly/cpu', } +export type TorchUpdatePolicy = 'auto' | 'defer' | 'pinned'; + +export type TorchPinnedPackages = { + torch?: string; + torchaudio?: string; + torchvision?: string; +}; + /** Legacy NVIDIA torch mirror used by older installs (CUDA 12.9). */ export const LEGACY_NVIDIA_TORCH_MIRROR = 'https://download.pytorch.org/whl/cu129'; @@ -216,6 +224,10 @@ export const NVIDIA_TORCH_PACKAGES: string[] = [ `torchaudio==${NVIDIA_TORCH_VERSION}`, `torchvision==${NVIDIA_TORCHVISION_VERSION}`, ]; +/** Minimum NVIDIA driver version recommended for the current CUDA torch build. */ +export const NVIDIA_DRIVER_MIN_VERSION = '580'; +/** Recommended NVIDIA torch package set key (torch/torchaudio/torchvision). */ +export const NVIDIA_TORCH_RECOMMENDED_VERSION = `${NVIDIA_TORCH_VERSION}|${NVIDIA_TORCHVISION_VERSION}`; /** The log files used by the desktop process. */ export enum LogFile { diff --git a/src/install/installationManager.ts b/src/install/installationManager.ts index 1f951d751..b7a8737f2 100644 --- a/src/install/installationManager.ts +++ b/src/install/installationManager.ts @@ -5,7 +5,17 @@ import { promisify } from 'node:util'; import { strictIpcMain as ipcMain } from '@/infrastructure/ipcChannels'; -import { IPC_CHANNELS, InstallStage, ProgressStatus } from '../constants'; +import { useComfySettings } from '../config/comfySettings'; +import { + IPC_CHANNELS, + InstallStage, + NVIDIA_DRIVER_MIN_VERSION, + NVIDIA_TORCHVISION_VERSION, + NVIDIA_TORCH_RECOMMENDED_VERSION, + NVIDIA_TORCH_VERSION, + ProgressStatus, + TorchMirrorUrl, +} from '../constants'; import { PythonImportVerificationError } from '../infrastructure/pythonImportVerificationError'; import { useAppState } from '../main-process/appState'; import type { AppWindow } from '../main-process/appWindow'; @@ -23,7 +33,8 @@ import { InstallWizard } from './installWizard'; import { Troubleshooting } from './troubleshooting'; const execAsync = promisify(exec); -const NVIDIA_DRIVER_MIN_VERSION = '580'; +const TORCH_MIRROR_CUDA_PATH = new URL(TorchMirrorUrl.Cuda).pathname; +const TORCH_MIRROR_NIGHTLY_CUDA_PATH = new URL(TorchMirrorUrl.NightlyCuda).pathname; /** * Extracts the NVIDIA driver version from `nvidia-smi` output. @@ -91,7 +102,7 @@ export class InstallationManager implements HasTelemetry { // Convert from old format if (state === 'upgraded') installation.upgradeConfig(); - // Install updated manager requirements + // Install updated requirements if (installation.needsRequirementsUpdate) await this.updatePackages(installation); // Resolve issues and re-run validation @@ -382,7 +393,7 @@ export class InstallationManager implements HasTelemetry { await installation.virtualEnvironment.installComfyUIRequirements(callbacks); await installation.virtualEnvironment.installComfyUIManagerRequirements(callbacks); await this.warnIfNvidiaDriverTooOld(installation); - await installation.virtualEnvironment.ensureRecommendedNvidiaTorch(callbacks); + await this.maybeUpdateNvidiaTorch(installation, callbacks); await installation.validate(); } catch (error) { log.error('Error auto-updating packages:', error); @@ -390,6 +401,172 @@ export class InstallationManager implements HasTelemetry { } } + private async maybeUpdateNvidiaTorch(installation: ComfyInstallation, callbacks: ProcessCallbacks): Promise { + const virtualEnvironment = installation.virtualEnvironment; + if (virtualEnvironment.selectedDevice !== 'nvidia') return; + + const config = useDesktopConfig(); + const updatePolicy = config.get('torchUpdatePolicy'); + const recommendedVersion = NVIDIA_TORCH_RECOMMENDED_VERSION; + const lastPromptedVersion = config.get('torchLastPromptedVersion'); + if (updatePolicy === 'pinned' && lastPromptedVersion === recommendedVersion) { + log.info('Skipping NVIDIA PyTorch update because updates are pinned for this version.'); + return; + } + + const installedVersions = await virtualEnvironment.getInstalledTorchPackageVersions(); + if (!installedVersions) { + log.warn('Skipping NVIDIA PyTorch update because installed versions could not be read.'); + return; + } + + const isOutOfDate = await virtualEnvironment.isNvidiaTorchOutOfDate(installedVersions); + if (!isOutOfDate) return; + + if (config.get('torchOutOfDateRecommendedVersion') !== recommendedVersion) { + config.set('torchOutOfDateRecommendedVersion', recommendedVersion); + config.set('torchOutOfDatePackages', installedVersions); + } + + if (updatePolicy === 'defer' && lastPromptedVersion === recommendedVersion) { + log.info('Skipping NVIDIA PyTorch update because updates are deferred for this version.'); + return; + } + + const updateApproved = updatePolicy === 'auto' && lastPromptedVersion === recommendedVersion; + let shouldAttemptUpdate = updateApproved; + + if (!updateApproved) { + const currentTorch = installedVersions.torch ?? 'unknown'; + const currentTorchaudio = installedVersions.torchaudio ?? 'unknown'; + const currentTorchvision = installedVersions.torchvision ?? 'unknown'; + + const { response } = await this.appWindow.showMessageBox({ + type: 'question', + title: 'Update PyTorch?', + message: + 'Your NVIDIA PyTorch build is out of date. We can update it to the recommended build for improved performance. This update may affect memory usage and compatibility with some custom nodes.', + detail: [ + `Current: torch ${currentTorch}, torchaudio ${currentTorchaudio}, torchvision ${currentTorchvision}`, + `Recommended: torch ${NVIDIA_TORCH_VERSION}, torchaudio ${NVIDIA_TORCH_VERSION}, torchvision ${NVIDIA_TORCHVISION_VERSION}`, + ].join('\n'), + buttons: ['Update PyTorch', 'Ask again later', 'Silence until next version', 'Silence forever'], + defaultId: 0, + cancelId: 1, + }); + + switch (response) { + case 1: + log.info('Deferring NVIDIA PyTorch update prompt.'); + return; + case 2: + config.set('torchLastPromptedVersion', recommendedVersion); + config.set('torchUpdatePolicy', 'defer'); + config.delete('torchPinnedPackages'); + virtualEnvironment.updateTorchUpdatePolicy('defer', undefined, recommendedVersion); + return; + case 3: + config.set('torchLastPromptedVersion', recommendedVersion); + config.set('torchUpdatePolicy', 'pinned'); + config.set('torchPinnedPackages', installedVersions); + virtualEnvironment.updateTorchUpdatePolicy('pinned', installedVersions, recommendedVersion); + return; + default: + config.set('torchLastPromptedVersion', recommendedVersion); + config.set('torchUpdatePolicy', 'auto'); + config.delete('torchPinnedPackages'); + config.delete('torchUpdateFailureSilencedVersion'); + virtualEnvironment.updateTorchUpdatePolicy('auto', undefined, recommendedVersion); + shouldAttemptUpdate = true; + } + } else { + virtualEnvironment.updateTorchUpdatePolicy('auto', undefined, recommendedVersion); + } + + if (!shouldAttemptUpdate) return; + + const torchMirrorOverride = await this.updateTorchMirrorForRecommendedVersion(); + + try { + await virtualEnvironment.ensureRecommendedNvidiaTorch(callbacks, torchMirrorOverride); + config.delete('torchUpdateFailureSilencedVersion'); + } catch (error) { + log.error('Error updating NVIDIA PyTorch packages:', error); + if (config.get('torchUpdateFailureSilencedVersion') === recommendedVersion) return; + + const { response } = await this.appWindow.showMessageBox({ + type: 'warning', + title: 'PyTorch update failed', + message: + 'We could not install the recommended NVIDIA PyTorch build. This may be because your configured torch mirror does not provide it.', + detail: 'We will retry the update on each startup.', + buttons: ['OK', "Don't show again"], + defaultId: 0, + cancelId: 0, + }); + + if (response === 1) { + config.set('torchUpdateFailureSilencedVersion', recommendedVersion); + } + } + } + + private async updateTorchMirrorForRecommendedVersion(): Promise { + let settings; + try { + settings = useComfySettings(); + } catch (error) { + log.warn('Unable to access Comfy settings to update torch mirror.', error); + return undefined; + } + + const currentMirror = settings.get('Comfy-Desktop.UV.TorchInstallMirror'); + const updatedMirror = this.getRecommendedTorchMirror(currentMirror); + if (!updatedMirror || updatedMirror === currentMirror) return updatedMirror ?? currentMirror; + + settings.set('Comfy-Desktop.UV.TorchInstallMirror', updatedMirror); + try { + await settings.saveSettings(); + } catch (error) { + log.warn('Failed to persist torch mirror update.', error); + } + + return updatedMirror; + } + + private getRecommendedTorchMirror(mirror: string | undefined): string | undefined { + const defaultTorchMirror = String(TorchMirrorUrl.Default); + if (!mirror?.trim() || mirror === defaultTorchMirror) return TorchMirrorUrl.Cuda; + + let parsedMirror: URL; + try { + parsedMirror = new URL(mirror); + } catch (error) { + log.warn('Unable to parse torch mirror URL for normalization.', error); + return mirror; + } + + const path = parsedMirror.pathname; + if (!path.includes('/whl/')) return mirror; + + let updatedPath = path; + const nightlyCudaPattern = /\/whl\/nightly\/cu\d+/i; + const cudaPattern = /\/whl\/cu\d+/i; + + if (nightlyCudaPattern.test(updatedPath)) { + updatedPath = updatedPath.replace(nightlyCudaPattern, TORCH_MIRROR_NIGHTLY_CUDA_PATH); + } else if (cudaPattern.test(updatedPath)) { + updatedPath = updatedPath.replace(cudaPattern, TORCH_MIRROR_CUDA_PATH); + } else { + return mirror; + } + + if (updatedPath === path) return mirror; + parsedMirror.pathname = updatedPath; + + return parsedMirror.toString(); + } + /** * Warns the user if their NVIDIA driver is too old for the required CUDA build. * @param installation The current installation. diff --git a/src/main-process/comfyInstallation.ts b/src/main-process/comfyInstallation.ts index d2e5e9627..27e09ef2c 100644 --- a/src/main-process/comfyInstallation.ts +++ b/src/main-process/comfyInstallation.ts @@ -72,6 +72,9 @@ export class ComfyInstallation { pythonMirror: useComfySettings().get('Comfy-Desktop.UV.PythonInstallMirror'), pypiMirror: useComfySettings().get('Comfy-Desktop.UV.PypiInstallMirror'), torchMirror: useComfySettings().get('Comfy-Desktop.UV.TorchInstallMirror'), + torchUpdatePolicy: useDesktopConfig().get('torchUpdatePolicy'), + torchPinnedPackages: useDesktopConfig().get('torchPinnedPackages'), + torchUpdateDecisionVersion: useDesktopConfig().get('torchLastPromptedVersion'), }); } diff --git a/src/store/desktopSettings.ts b/src/store/desktopSettings.ts index 6010253f5..0b44be570 100644 --- a/src/store/desktopSettings.ts +++ b/src/store/desktopSettings.ts @@ -1,3 +1,4 @@ +import type { TorchPinnedPackages, TorchUpdatePolicy } from '../constants'; import type { GpuType, TorchDeviceType } from '../preload'; export type DesktopInstallState = 'started' | 'installed' | 'upgraded'; @@ -35,4 +36,16 @@ export type DesktopSettings = { versionConsentedMetrics?: string; /** Whether the user has generated an image successfully. */ hasGeneratedSuccessfully?: boolean; + /** How to handle NVIDIA PyTorch updates. */ + torchUpdatePolicy?: TorchUpdatePolicy; + /** The pinned NVIDIA torch package versions when updates are disabled. */ + torchPinnedPackages?: TorchPinnedPackages; + /** The recommended NVIDIA torch version tied to the current update decision. */ + torchLastPromptedVersion?: string; + /** The recommended NVIDIA torch version whose update failure prompt is suppressed. */ + torchUpdateFailureSilencedVersion?: string; + /** The recommended NVIDIA torch version recorded when we first detected an out-of-date torch install. */ + torchOutOfDateRecommendedVersion?: string; + /** The torch package versions recorded when we first detected an out-of-date torch install. */ + torchOutOfDatePackages?: TorchPinnedPackages; }; diff --git a/src/virtualEnvironment.ts b/src/virtualEnvironment.ts index c626cc04c..c2cd2b6c0 100644 --- a/src/virtualEnvironment.ts +++ b/src/virtualEnvironment.ts @@ -14,9 +14,12 @@ import { LEGACY_NVIDIA_TORCH_MIRROR, NVIDIA_TORCHVISION_VERSION, NVIDIA_TORCH_PACKAGES, + NVIDIA_TORCH_RECOMMENDED_VERSION, NVIDIA_TORCH_VERSION, PYPI_FALLBACK_INDEX_URLS, TorchMirrorUrl, + TorchPinnedPackages, + TorchUpdatePolicy, } from './constants'; import { PythonImportVerificationError } from './infrastructure/pythonImportVerificationError'; import { useAppState } from './main-process/appState'; @@ -92,13 +95,7 @@ export function getPipInstallArgs(config: PipInstallConfig): string[] { return installArgs; } -/** - * Returns the default torch mirror for the given device. - * @param device The device type - * @returns The default torch mirror - */ -function getDefaultTorchMirror(device: TorchDeviceType): string { - log.debug('Falling back to default torch mirror'); +function getDeviceDefaultTorchMirror(device: TorchDeviceType): string { switch (device) { case 'mps': return TorchMirrorUrl.NightlyCpu; @@ -143,6 +140,9 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor { readonly pythonMirror?: string; readonly pypiMirror?: string; readonly torchMirror?: string; + torchUpdatePolicy?: TorchUpdatePolicy; + torchPinnedPackages?: TorchPinnedPackages; + torchUpdateDecisionVersion?: string; uvPty: pty.IPty | undefined; /** The environment variables to set for uv. */ @@ -196,6 +196,9 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor { pythonMirror, pypiMirror, torchMirror, + torchUpdatePolicy, + torchPinnedPackages, + torchUpdateDecisionVersion, }: { telemetry: ITelemetry; selectedDevice?: TorchDeviceType; @@ -203,6 +206,9 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor { pythonMirror?: string; pypiMirror?: string; torchMirror?: string; + torchUpdatePolicy?: TorchUpdatePolicy; + torchPinnedPackages?: TorchPinnedPackages; + torchUpdateDecisionVersion?: string; } ) { this.basePath = basePath; @@ -212,6 +218,9 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor { this.pythonMirror = pythonMirror; this.pypiMirror = pypiMirror; this.torchMirror = fixDeviceMirrorMismatch(selectedDevice!, torchMirror); + this.torchUpdatePolicy = torchUpdatePolicy; + this.torchPinnedPackages = torchPinnedPackages; + this.torchUpdateDecisionVersion = torchUpdateDecisionVersion; // uv defaults to .venv this.venvPath = path.join(basePath, '.venv'); @@ -275,6 +284,35 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor { return primary; } + updateTorchUpdatePolicy( + policy: TorchUpdatePolicy | undefined, + pinnedPackages?: TorchPinnedPackages, + decisionVersion?: string + ) { + this.torchUpdatePolicy = policy; + if (pinnedPackages !== undefined) { + this.torchPinnedPackages = pinnedPackages; + } else if (policy !== 'pinned') { + this.torchPinnedPackages = undefined; + } + if (decisionVersion !== undefined) { + this.torchUpdateDecisionVersion = decisionVersion; + } + } + + isUsingRecommendedTorchMirror(): boolean { + if (!this.torchMirror) return true; + return this.torchMirror === getDeviceDefaultTorchMirror(this.selectedDevice); + } + + private shouldSkipNvidiaTorchUpgrade(): boolean { + if (this.torchUpdatePolicy === 'pinned' && this.torchUpdateDecisionVersion === NVIDIA_TORCH_RECOMMENDED_VERSION) + return true; + if (this.torchUpdatePolicy === 'defer' && this.torchUpdateDecisionVersion === NVIDIA_TORCH_RECOMMENDED_VERSION) + return true; + return false; + } + public async create(callbacks?: ProcessCallbacks): Promise { try { await this.createEnvironment(callbacks); @@ -642,7 +680,11 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor { return; } - const torchMirror = this.torchMirror || getDefaultTorchMirror(this.selectedDevice); + let torchMirror = this.torchMirror; + if (!torchMirror) { + log.info('Falling back to default torch mirror'); + torchMirror = getDeviceDefaultTorchMirror(this.selectedDevice); + } const config: PipInstallConfig = { packages: ['torch', 'torchvision', 'torchaudio'], indexUrl: torchMirror, @@ -663,8 +705,12 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor { * Ensures NVIDIA installs use the recommended PyTorch packages. * @param callbacks The callbacks to use for the command. */ - async ensureRecommendedNvidiaTorch(callbacks?: ProcessCallbacks): Promise { + async ensureRecommendedNvidiaTorch(callbacks?: ProcessCallbacks, torchMirrorOverride?: string): Promise { if (this.selectedDevice !== 'nvidia') return; + if (this.shouldSkipNvidiaTorchUpgrade()) { + log.info('Skipping NVIDIA PyTorch upgrade due to pinned policy or deferred updates.'); + return; + } const installedVersions = await this.getInstalledTorchPackageVersions(); if (installedVersions && this.meetsMinimumNvidiaTorchVersions(installedVersions)) { @@ -672,7 +718,11 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor { return; } - const torchMirror = this.torchMirror || getDefaultTorchMirror(this.selectedDevice); + let torchMirror = torchMirrorOverride ?? this.torchMirror; + if (!torchMirror) { + log.info('Falling back to default torch mirror'); + torchMirror = getDeviceDefaultTorchMirror(this.selectedDevice); + } const config: PipInstallConfig = { packages: NVIDIA_TORCH_PACKAGES, indexUrl: torchMirror, @@ -708,7 +758,7 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor { * Reads installed torch package versions using `uv pip list --format=json`. * @returns The torch package versions when available, otherwise `undefined`. */ - private async getInstalledTorchPackageVersions(): Promise { + async getInstalledTorchPackageVersions(): Promise { let stdout = ''; let stderr = ''; const callbacks: ProcessCallbacks = { @@ -771,6 +821,15 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor { return versions; } + async isNvidiaTorchOutOfDate(installedVersions?: TorchPackageVersions): Promise { + if (this.selectedDevice !== 'nvidia') return false; + + const resolvedVersions = installedVersions ?? (await this.getInstalledTorchPackageVersions()); + if (!resolvedVersions) return false; + + return !this.meetsMinimumNvidiaTorchVersions(resolvedVersions); + } + /** * Installs AMD ROCm SDK packages on Windows. * @param callbacks The callbacks to use for the command. @@ -997,6 +1056,7 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor { */ private async needsNvidiaTorchUpgrade(): Promise { if (this.selectedDevice !== 'nvidia') return false; + if (this.shouldSkipNvidiaTorchUpgrade()) return false; const installedVersions = await this.getInstalledTorchPackageVersions(); if (!installedVersions) { diff --git a/tests/unit/install/installationManager.test.ts b/tests/unit/install/installationManager.test.ts index af9f92969..f48a8109c 100644 --- a/tests/unit/install/installationManager.test.ts +++ b/tests/unit/install/installationManager.test.ts @@ -4,7 +4,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; import { ComfyServerConfig } from '@/config/comfyServerConfig'; import { ComfySettings } from '@/config/comfySettings'; -import { IPC_CHANNELS } from '@/constants'; +import { IPC_CHANNELS, TorchMirrorUrl } from '@/constants'; import { InstallationManager, isNvidiaDriverBelowMinimum, @@ -156,6 +156,37 @@ describe('InstallationManager', () => { await Promise.resolve(); }); + describe('getRecommendedTorchMirror', () => { + it('returns the CUDA mirror when the mirror is empty or default', () => { + const helper = manager as unknown as { getRecommendedTorchMirror: (mirror?: string) => string | undefined }; + + expect(helper.getRecommendedTorchMirror()).toBe(TorchMirrorUrl.Cuda); + expect(helper.getRecommendedTorchMirror('')).toBe(TorchMirrorUrl.Cuda); + expect(helper.getRecommendedTorchMirror(TorchMirrorUrl.Default)).toBe(TorchMirrorUrl.Cuda); + }); + + it('normalizes CUDA mirrors to the recommended CUDA path', () => { + const helper = manager as unknown as { getRecommendedTorchMirror: (mirror?: string) => string | undefined }; + + expect(helper.getRecommendedTorchMirror('https://download.pytorch.org/whl/cu128')).toBe(TorchMirrorUrl.Cuda); + }); + + it('normalizes nightly CUDA mirrors to the recommended CUDA path', () => { + const helper = manager as unknown as { getRecommendedTorchMirror: (mirror?: string) => string | undefined }; + + expect(helper.getRecommendedTorchMirror('https://download.pytorch.org/whl/nightly/cu118')).toBe( + TorchMirrorUrl.NightlyCuda + ); + }); + + it('returns the original mirror when it cannot be normalized', () => { + const helper = manager as unknown as { getRecommendedTorchMirror: (mirror?: string) => string | undefined }; + + expect(helper.getRecommendedTorchMirror('https://example.com/simple')).toBe('https://example.com/simple'); + expect(helper.getRecommendedTorchMirror('not a url')).toBe('not a url'); + }); + }); + describe('ensureInstalled', () => { beforeEach(() => { vi.spyOn(ComfyInstallation, 'fromConfig').mockImplementation(() => diff --git a/tests/unit/virtualEnvironment.test.ts b/tests/unit/virtualEnvironment.test.ts index 2e6af3a4d..d45fcc24e 100644 --- a/tests/unit/virtualEnvironment.test.ts +++ b/tests/unit/virtualEnvironment.test.ts @@ -3,7 +3,12 @@ import { type ChildProcess, spawn } from 'node:child_process'; import path from 'node:path'; import { test as baseTest, describe, expect, vi } from 'vitest'; -import { TorchMirrorUrl } from '@/constants'; +import { + NVIDIA_TORCHVISION_VERSION, + NVIDIA_TORCH_RECOMMENDED_VERSION, + NVIDIA_TORCH_VERSION, + TorchMirrorUrl, +} from '@/constants'; import type { ITelemetry } from '@/services/telemetry'; import { VirtualEnvironment, getPipInstallArgs } from '@/virtualEnvironment'; @@ -313,4 +318,242 @@ describe('VirtualEnvironment', () => { expect(uvEnv.UV_PYTHON_INSTALL_MIRROR).toBeUndefined(); }); }); + + describe('isUsingRecommendedTorchMirror', () => { + test('returns true when using default mirror for NVIDIA', () => { + vi.stubGlobal('process', { + ...process, + resourcesPath: '/test/resources', + }); + + const env = new VirtualEnvironment('/mock/venv', { + telemetry: mockTelemetry, + selectedDevice: 'nvidia', + pythonVersion: '3.12', + torchMirror: TorchMirrorUrl.Cuda, + }); + + expect(env.isUsingRecommendedTorchMirror()).toBe(true); + }); + + test('returns false when using a custom mirror', () => { + vi.stubGlobal('process', { + ...process, + resourcesPath: '/test/resources', + }); + + const env = new VirtualEnvironment('/mock/venv', { + telemetry: mockTelemetry, + selectedDevice: 'nvidia', + pythonVersion: '3.12', + torchMirror: 'https://download.pytorch.org/whl/cu128', + }); + + expect(env.isUsingRecommendedTorchMirror()).toBe(false); + }); + }); + + describe('updateTorchUpdatePolicy', () => { + test('clears pinned packages when policy is not pinned', () => { + vi.stubGlobal('process', { + ...process, + resourcesPath: '/test/resources', + }); + + const env = new VirtualEnvironment('/mock/venv', { + telemetry: mockTelemetry, + selectedDevice: 'nvidia', + pythonVersion: '3.12', + torchUpdatePolicy: 'pinned', + torchPinnedPackages: { torch: '2.8.0+cu130' }, + }); + + env.updateTorchUpdatePolicy('auto'); + + expect(env.torchUpdatePolicy).toBe('auto'); + expect(env.torchPinnedPackages).toBeUndefined(); + }); + + test('stores pinned packages and decision version when provided', () => { + vi.stubGlobal('process', { + ...process, + resourcesPath: '/test/resources', + }); + + const env = new VirtualEnvironment('/mock/venv', { + telemetry: mockTelemetry, + selectedDevice: 'nvidia', + pythonVersion: '3.12', + }); + + env.updateTorchUpdatePolicy('pinned', { torch: NVIDIA_TORCH_VERSION }, 'decision'); + + expect(env.torchUpdatePolicy).toBe('pinned'); + expect(env.torchPinnedPackages).toEqual({ torch: NVIDIA_TORCH_VERSION }); + expect(env.torchUpdateDecisionVersion).toBe('decision'); + }); + }); + + describe('ensureRecommendedNvidiaTorch', () => { + test('skips upgrade when updates are pinned for the current recommended version', async () => { + vi.stubGlobal('process', { + ...process, + resourcesPath: '/test/resources', + }); + + const env = new VirtualEnvironment('/mock/venv', { + telemetry: mockTelemetry, + selectedDevice: 'nvidia', + pythonVersion: '3.12', + torchUpdatePolicy: 'pinned', + torchUpdateDecisionVersion: NVIDIA_TORCH_RECOMMENDED_VERSION, + }); + + const versionsSpy = vi.spyOn(env, 'getInstalledTorchPackageVersions'); + + await env.ensureRecommendedNvidiaTorch(); + + expect(versionsSpy).not.toHaveBeenCalled(); + }); + + test('does not skip upgrade when pinned decision version differs', async () => { + vi.stubGlobal('process', { + ...process, + resourcesPath: '/test/resources', + }); + + const env = new VirtualEnvironment('/mock/venv', { + telemetry: mockTelemetry, + selectedDevice: 'nvidia', + pythonVersion: '3.12', + torchUpdatePolicy: 'pinned', + torchUpdateDecisionVersion: '2.8.0+cu130|0.23.0+cu130', + }); + + const versionsSpy = vi.spyOn(env, 'getInstalledTorchPackageVersions').mockResolvedValue({ + torch: NVIDIA_TORCH_VERSION, + torchaudio: NVIDIA_TORCH_VERSION, + torchvision: NVIDIA_TORCHVISION_VERSION, + }); + + await env.ensureRecommendedNvidiaTorch(); + + expect(versionsSpy).toHaveBeenCalled(); + }); + + test('skips upgrade when updates are deferred for the recommended version', async () => { + vi.stubGlobal('process', { + ...process, + resourcesPath: '/test/resources', + }); + + const env = new VirtualEnvironment('/mock/venv', { + telemetry: mockTelemetry, + selectedDevice: 'nvidia', + pythonVersion: '3.12', + torchUpdatePolicy: 'defer', + torchUpdateDecisionVersion: NVIDIA_TORCH_RECOMMENDED_VERSION, + }); + + const versionsSpy = vi.spyOn(env, 'getInstalledTorchPackageVersions'); + + await env.ensureRecommendedNvidiaTorch(); + + expect(versionsSpy).not.toHaveBeenCalled(); + }); + }); + + describe('isNvidiaTorchOutOfDate', () => { + test('returns false when device is not NVIDIA', async ({ virtualEnv }) => { + const versionsSpy = vi.spyOn(virtualEnv, 'getInstalledTorchPackageVersions'); + + await expect(virtualEnv.isNvidiaTorchOutOfDate()).resolves.toBe(false); + expect(versionsSpy).not.toHaveBeenCalled(); + }); + + test('returns true when installed versions are below recommended', async () => { + vi.stubGlobal('process', { + ...process, + resourcesPath: '/test/resources', + }); + + const env = new VirtualEnvironment('/mock/venv', { + telemetry: mockTelemetry, + selectedDevice: 'nvidia', + pythonVersion: '3.12', + }); + + mockSpawnOutputOnce( + JSON.stringify([ + { name: 'torch', version: '2.8.0+cu130' }, + { name: 'torchaudio', version: '2.8.0+cu130' }, + { name: 'torchvision', version: '0.23.0+cu130' }, + ]) + ); + + await expect(env.isNvidiaTorchOutOfDate()).resolves.toBe(true); + }); + + test('returns false when installed versions meet the recommended minimums', async () => { + vi.stubGlobal('process', { + ...process, + resourcesPath: '/test/resources', + }); + + const env = new VirtualEnvironment('/mock/venv', { + telemetry: mockTelemetry, + selectedDevice: 'nvidia', + pythonVersion: '3.12', + }); + + mockSpawnOutputOnce( + JSON.stringify([ + { name: 'torch', version: NVIDIA_TORCH_VERSION }, + { name: 'torchaudio', version: NVIDIA_TORCH_VERSION }, + { name: 'torchvision', version: NVIDIA_TORCHVISION_VERSION }, + ]) + ); + + await expect(env.isNvidiaTorchOutOfDate()).resolves.toBe(false); + }); + }); + + describe('getInstalledTorchPackageVersions', () => { + test('returns parsed torch versions from uv json output', async ({ virtualEnv }) => { + const output = `[{"name":"aiohappyeyeballs","version":"2.6.1"},{"name":"aiohttp","version":"3.13.3"},{"name":"aiosignal","version":"1.4.0"},{"name":"alembic","version":"1.18.1"},{"name":"annotated-types","version":"0.7.0"},{"name":"attrs","version":"25.4.0"},{"name":"av","version":"16.1.0"},{"name":"certifi","version":"2026.1.4"},{"name":"charset-normalizer","version":"3.4.4"},{"name":"comfy-kitchen","version":"0.2.6"},{"name":"comfyui-embedded-docs","version":"0.4.0"},{"name":"comfyui-frontend-package","version":"1.36.14"},{"name":"comfyui-workflow-templates","version":"0.8.4"},{"name":"comfyui-workflow-templates-core","version":"0.3.88"},{"name":"comfyui-workflow-templates-media-api","version":"0.3.39"},{"name":"comfyui-workflow-templates-media-image","version":"0.3.55"},{"name":"comfyui-workflow-templates-media-other","version":"0.3.80"},{"name":"comfyui-workflow-templates-media-video","version":"0.3.38"},{"name":"einops","version":"0.8.1"},{"name":"filelock","version":"3.20.3"},{"name":"frozenlist","version":"1.8.0"},{"name":"fsspec","version":"2026.1.0"},{"name":"greenlet","version":"3.3.0"},{"name":"hf-xet","version":"1.2.0"},{"name":"huggingface-hub","version":"0.36.0"},{"name":"idna","version":"3.11"},{"name":"jinja2","version":"3.1.6"},{"name":"kornia","version":"0.8.2"},{"name":"kornia-rs","version":"0.1.10"},{"name":"mako","version":"1.3.10"},{"name":"markupsafe","version":"3.0.3"},{"name":"mpmath","version":"1.3.0"},{"name":"multidict","version":"6.7.0"},{"name":"networkx","version":"3.6.1"},{"name":"numpy","version":"2.4.1"},{"name":"nvidia-cublas","version":"13.0.0.19"},{"name":"nvidia-cuda-cupti","version":"13.0.48"},{"name":"nvidia-cuda-nvrtc","version":"13.0.48"},{"name":"nvidia-cuda-runtime","version":"13.0.48"},{"name":"nvidia-cudnn-cu13","version":"9.13.0.50"},{"name":"nvidia-cufft","version":"12.0.0.15"},{"name":"nvidia-cufile","version":"1.15.0.42"},{"name":"nvidia-curand","version":"10.4.0.35"},{"name":"nvidia-cusolver","version":"12.0.3.29"},{"name":"nvidia-cusparse","version":"12.6.2.49"},{"name":"nvidia-cusparselt-cu13","version":"0.8.0"},{"name":"nvidia-nccl-cu13","version":"2.27.7"},{"name":"nvidia-nvjitlink","version":"13.0.39"},{"name":"nvidia-nvshmem-cu13","version":"3.3.24"},{"name":"nvidia-nvtx","version":"13.0.39"},{"name":"packaging","version":"25.0"},{"name":"pillow","version":"12.1.0"},{"name":"pip","version":"24.0"},{"name":"propcache","version":"0.4.1"},{"name":"psutil","version":"7.2.1"},{"name":"pydantic","version":"2.12.5"},{"name":"pydantic-core","version":"2.41.5"},{"name":"pydantic-settings","version":"2.12.0"},{"name":"python-dotenv","version":"1.2.1"},{"name":"pyyaml","version":"6.0.3"},{"name":"regex","version":"2026.1.15"},{"name":"requests","version":"2.32.5"},{"name":"safetensors","version":"0.7.0"},{"name":"scipy","version":"1.17.0"},{"name":"sentencepiece","version":"0.2.1"},{"name":"setuptools","version":"80.9.0"},{"name":"spandrel","version":"0.4.1"},{"name":"sqlalchemy","version":"2.0.45"},{"name":"sympy","version":"1.14.0"},{"name":"tokenizers","version":"0.22.2"},{"name":"torch","version":"2.9.1+cu130"},{"name":"torchaudio","version":"2.9.1+cu130"},{"name":"torchsde","version":"0.2.6"},{"name":"torchvision","version":"0.24.1+cu130"},{"name":"tqdm","version":"4.67.1"},{"name":"trampoline","version":"0.1.2"},{"name":"transformers","version":"4.57.5"},{"name":"triton","version":"3.5.1"},{"name":"typing-extensions","version":"4.15.0"},{"name":"typing-inspection","version":"0.4.2"},{"name":"urllib3","version":"2.6.3"},{"name":"yarl","version":"1.22.0"}]`; + + mockSpawnOutputOnce(output); + + await expect(virtualEnv.getInstalledTorchPackageVersions()).resolves.toEqual({ + torch: '2.9.1+cu130', + torchaudio: '2.9.1+cu130', + torchvision: '0.24.1+cu130', + }); + }); + + test('returns undefined when uv output contains no torch packages', async ({ virtualEnv }) => { + const output = JSON.stringify([{ name: 'numpy', version: '2.1.0' }]); + mockSpawnOutputOnce(output); + + await expect(virtualEnv.getInstalledTorchPackageVersions()).resolves.toBeUndefined(); + }); + + test('returns undefined when uv output is not JSON', async ({ virtualEnv }) => { + mockSpawnOutputOnce('not json'); + + await expect(virtualEnv.getInstalledTorchPackageVersions()).resolves.toBeUndefined(); + }); + + test('returns undefined when uv output is not an array', async ({ virtualEnv }) => { + mockSpawnOutputOnce(JSON.stringify({ name: 'torch', version: NVIDIA_TORCH_VERSION })); + + await expect(virtualEnv.getInstalledTorchPackageVersions()).resolves.toBeUndefined(); + }); + + test('returns undefined when uv exits with non-zero code', async ({ virtualEnv }) => { + mockSpawnOutputOnce(JSON.stringify([]), 1); + + await expect(virtualEnv.getInstalledTorchPackageVersions()).resolves.toBeUndefined(); + }); + }); });