diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index da24d8802bb..b51ccc65fc2 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -252,6 +252,58 @@ class CoreMetadataInvocation(BaseInvocation): default=None, description="The high resolution fix img2img strength used in the upscale pass.", ) + hrf_scale: Optional[float] = InputField( + default=None, + description="The high resolution fix upscale factor.", + ) + hrf_latent_interpolation_mode: Optional[str] = InputField( + default=None, + description="The latent interpolation mode used in the high resolution fix upscale pass.", + ) + hrf_upscale_model: Optional[ModelIdentifierField] = InputField( + default=None, + description="The Spandrel upscale model used in the high resolution fix upscale pass.", + ) + hrf_tile_controlnet_model: Optional[ModelIdentifierField] = InputField( + default=None, + description="The tile ControlNet model used in the high resolution fix upscale pass.", + ) + hrf_structure: Optional[float] = InputField( + default=None, + description="Legacy high resolution fix tile ControlNet structure value.", + ) + hrf_tile_control_weight: Optional[float] = InputField( + default=None, + description="The high resolution fix tile ControlNet control weight.", + ) + hrf_tile_control_end: Optional[float] = InputField( + default=None, + description="The high resolution fix tile ControlNet end step percentage.", + ) + hrf_tile_size: Optional[int] = InputField( + default=None, + description="The high resolution fix tiled processing tile size.", + ) + hrf_tile_overlap: Optional[int] = InputField( + default=None, + description="The high resolution fix tiled processing tile overlap.", + ) + hrf_steps: Optional[int] = InputField( + default=None, + description="The number of steps used for the high resolution fix refinement pass.", + ) + hrf_model: Optional[ModelIdentifierField] = InputField( + default=None, + description="The 
optional model override used for the high resolution fix refinement pass.", + ) + hrf_lora_mode: Optional[str] = InputField( + default=None, + description="The LoRA mode used for the high resolution fix refinement pass.", + ) + hrf_loras: Optional[list[LoRAMetadataField]] = InputField( + default=None, + description="The dedicated LoRAs used for the high resolution fix refinement pass.", + ) # SDXL positive_style_prompt: Optional[str] = InputField( diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 05edf886890..da515108716 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -350,13 +350,48 @@ "options_withCount_other": "{{count}} options" }, "hrf": { - "hrf": "High Resolution Fix", - "enableHrf": "Enable High Resolution Fix", + "hrf": "Upscale", + "enableHrf": "Enable", + "enableUpscale": "Enable Upscale", + "scale": "Scale", + "strength": "Denoise Strength", "upscaleMethod": "Upscale Method", + "latentInterpolationMode": "Latent Interpolation", + "tileControlWeight": "Tile Control Weight", + "tileControlEnd": "Tile Control Duration", + "steps": "Refinement Steps", + "model": "Refinement Model", + "reuseGenerateModel": "Reuse Base Model", + "loraMode": "Refinement LoRAs", + "reuseGenerateLoras": "Reuse", + "noLoras": "None", + "dedicatedLoras": "LoRAs", + "addDedicatedLora": "Add LoRA", + "noCompatibleModels": "No Compatible Models", + "latent": "Latent", + "upscaleModelMethod": "Upscale Model", + "nearest": "Nearest", + "bilinear": "Bilinear", + "bicubic": "Bicubic", + "area": "Area", + "nearestExact": "Nearest Exact", "metadata": { "enabled": "High Resolution Fix Enabled", "strength": "High Resolution Fix Strength", - "method": "High Resolution Fix Method" + "method": "High Resolution Fix Method", + "scale": "High Resolution Fix Scale", + "latentInterpolationMode": "High Resolution Fix Latent Interpolation", + "upscaleModel": "High Resolution Fix 
Upscale Model", + "tileControlNetModel": "High Resolution Fix Tile ControlNet", + "legacyStructure": "Legacy High Resolution Fix Structure", + "tileControlWeight": "High Resolution Fix Tile Control Weight", + "tileControlEnd": "High Resolution Fix Tile Control Duration", + "tileSize": "High Resolution Fix Tile Size", + "tileOverlap": "High Resolution Fix Tile Overlap", + "steps": "High Resolution Fix Refinement Steps", + "model": "High Resolution Fix Refinement Model", + "loraMode": "High Resolution Fix LoRA Mode", + "loras": "High Resolution Fix Dedicated LoRAs" } }, "prompt": { @@ -1683,6 +1718,15 @@ "modelIncompatibleScaledBboxHeight": "Scaled bbox height is {{height}} but {{model}} requires multiple of {{multiple}}", "fluxModelMultipleControlLoRAs": "Can only use 1 Control LoRA at a time", "incompatibleLoRAs": "Incompatible LoRA(s) added", + "hrfExternalModelUnsupported": "High Resolution Fix is not supported for external models", + "hrfRefinerUnsupported": "High Resolution Fix is not supported when SDXL Refiner is enabled", + "hrfUpscaleModelBaseUnsupported": "High Resolution Fix with an upscale model is supported for SD1.5 and SDXL models only", + "hrfModelOverrideMethodUnsupported": "High Resolution Fix model override is only supported with upscale-model HRF", + "hrfModelOverrideExternalUnsupported": "High Resolution Fix model override does not support external models", + "hrfModelOverrideBaseUnsupported": "High Resolution Fix model override is supported for SD1.5 and SDXL models only", + "hrfModelOverrideBaseMismatch": "High Resolution Fix model override must use the same base as the Generate model", + "hrfUpscaleModelMissing": "High Resolution Fix needs an upscale model", + "hrfTileControlNetModelMissing": "High Resolution Fix needs a tile ControlNet model", "canvasIsFiltering": "Canvas is busy (filtering)", "canvasIsTransforming": "Canvas is busy (transforming)", "canvasIsRasterizing": "Canvas is busy (rasterizing)", @@ -2200,6 +2244,118 @@ "Generate high 
quality images at a larger resolution than optimal for the model. Generally used to prevent duplication in the generated image." ] }, + "hrfEnabled": { + "heading": "Enable Upscale", + "paragraphs": [ + "Adds an optional second pass that upscales the base image and refines details at the larger size.", + "When disabled, the Upscale settings are saved but are not applied to new generations." + ] + }, + "hrfMethod": { + "heading": "Upscale Method", + "paragraphs": [ + "Latent resizes the image in latent space before refinement. It is faster and does not require an upscale model.", + "Upscale Model enlarges the image with a selected upscale model before refinement, which can preserve more texture and detail." + ] + }, + "hrfScale": { + "heading": "Scale", + "paragraphs": [ + "Controls final output size as a multiplier of the base Width and Height.", + "Higher scale makes a larger image and uses more time and VRAM. Lower scale stays closer to the base size." + ] + }, + "hrfDenoisingStrength": { + "heading": "Denoise Strength", + "paragraphs": [ + "Controls how much the refinement pass can change the upscaled image.", + "Lower strength preserves more of the upscaled image. Higher strength allows more new detail and more prompt influence." + ] + }, + "hrfLatentInterpolation": { + "heading": "Latent Interpolation", + "paragraphs": [ + "Controls how Latent Upscale resizes the latent image before refinement.", + "Smoother modes such as Bicubic and Bilinear can soften transitions. Nearest modes preserve harder edges but can look blockier." + ] + }, + "hrfUpscaleModel": { + "heading": "Upscale Model", + "paragraphs": [ + "Selects the image-to-image upscale model used to enlarge the base image before refinement.", + "Different upscale models may work better for different image types, such as photos, illustration, or line art." 
+ ] + }, + "hrfTileControl": { + "heading": "Tile Control", + "paragraphs": [ + "Selects a compatible Tile or Union ControlNet model used during Upscale Model refinement.", + "Only same-base Tile-style ControlNet models are shown. Tile guidance helps preserve structure and reduce drift after the image is enlarged." + ] + }, + "hrfTileControlWeight": { + "heading": "Tile Control Weight", + "paragraphs": [ + "Controls how strongly Tile Control guides the refinement pass.", + "Higher weight preserves structure more strongly. Lower weight gives the model and prompt more freedom to change details." + ] + }, + "hrfTileControlEnd": { + "heading": "Tile Control Duration", + "paragraphs": [ + "Controls how long Tile Control remains active during the refinement pass.", + "Higher values keep tile guidance active for more of the pass. Lower values use tile guidance early, then let refinement finish more freely." + ] + }, + "hrfTileSize": { + "heading": "Tile Size", + "paragraphs": [ + "Controls the tile size used during Upscale Model refinement.", + "Larger tiles can improve consistency but use more VRAM. Smaller tiles use less VRAM and may be safer on memory-limited systems." + ] + }, + "hrfTileOverlap": { + "heading": "Tile Overlap", + "paragraphs": [ + "Controls how much adjacent tiles overlap during Upscale Model refinement.", + "Higher overlap can reduce visible seams but uses more VRAM and time. Lower overlap is faster but may show more tile boundaries." + ] + }, + "hrfModel": { + "heading": "Refinement Model", + "paragraphs": [ + "Selects an optional model used only for the Upscale refinement pass.", + "Reuse Base Model keeps the base generation model for refinement. Choosing another model lets you use a compatible refinement model for the Upscale pass." + ] + }, + "hrfLoraMode": { + "heading": "Refinement LoRAs", + "paragraphs": [ + "Controls which LoRAs are applied during the Upscale refinement pass.", + "Reuse applies the base generation LoRAs again. 
None disables LoRAs for refinement. Dedicated lets you select LoRAs used only for refinement." + ] + }, + "hrfLoraSelect": { + "heading": "Dedicated Refinement LoRAs", + "paragraphs": [ + "Selects LoRAs that apply only during the Upscale refinement pass.", + "Dedicated refinement LoRAs can add detail or style after upscaling without affecting the base image generation." + ] + }, + "hrfLoraWeight": { + "heading": "Refinement LoRA Weight", + "paragraphs": [ + "Controls how strongly this dedicated LoRA affects the Upscale refinement pass.", + "Higher weight increases the LoRA's impact on refined details. Lower weight makes the effect more subtle." + ] + }, + "hrfSteps": { + "heading": "Refinement Steps", + "paragraphs": [ + "Controls the number of denoising steps used for the Upscale refinement pass.", + "When custom refinement steps are off, Upscale reuses the Generate step count. Higher values can add detail but take longer." + ] + }, "paramIterations": { "heading": "Iterations", "paragraphs": [ diff --git a/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts b/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts index 95fa75cfa32..93b581aaa00 100644 --- a/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts +++ b/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts @@ -6,7 +6,23 @@ export type Feature = | 'fluxDypePreset' | 'fluxDypeScale' | 'fluxDypeExponent' + | 'hrfDenoisingStrength' + | 'hrfEnabled' | 'hrf' + | 'hrfLatentInterpolation' + | 'hrfLoraMode' + | 'hrfLoraSelect' + | 'hrfLoraWeight' + | 'hrfMethod' + | 'hrfModel' + | 'hrfScale' + | 'hrfSteps' + | 'hrfTileControl' + | 'hrfTileControlEnd' + | 'hrfTileControlWeight' + | 'hrfTileOverlap' + | 'hrfTileSize' + | 'hrfUpscaleModel' | 'paramNegativeConditioning' | 'paramPositiveConditioning' | 'paramScheduler' diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts 
b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts index 24dee85a66a..4d24fc16309 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts @@ -8,15 +8,20 @@ import type { import { describe, expect, it } from 'vitest'; import { + modelChanged, paramsSliceConfig, + selectHrfFinalDimensions, selectModelSupportsDimensions, selectModelSupportsGuidance, + selectModelSupportsHrf, + selectModelSupportsHrfUpscaleModel, selectModelSupportsNegativePrompt, selectModelSupportsRefImages, selectModelSupportsSeed, selectModelSupportsSteps, + setHrfMethod, } from './paramsSlice'; -import { getInitialParamsState } from './types'; +import { getInitialParamsState, type ParamsState } from './types'; const buildExternalModelIdentifier = (config: ExternalApiModelConfig) => ({ @@ -132,6 +137,254 @@ describe('paramsSlice selectors for external models', () => { expect(selectModelSupportsSteps.resultFunc(model)).toBe(false); expect(selectModelSupportsDimensions.resultFunc(model, config)).toBe(true); }); + + it('supports HRF only for SD1.5 and SDXL models', () => { + expect( + selectModelSupportsHrf.resultFunc({ + key: 'sd1', + hash: 'h', + name: 'SD1', + base: 'sd-1', + type: 'main', + }) + ).toBe(true); + expect( + selectModelSupportsHrf.resultFunc({ + key: 'sdxl', + hash: 'h', + name: 'SDXL', + base: 'sdxl', + type: 'main', + }) + ).toBe(true); + + const unsupportedBases = ['sd-2', 'sd-3', 'flux', 'flux2', 'anima', 'cogview4', 'qwen-image', 'z-image'] as const; + unsupportedBases.forEach((base) => { + expect( + selectModelSupportsHrf.resultFunc({ + key: base, + hash: 'h', + name: base, + base, + type: 'main', + }) + ).toBe(false); + }); + + const config = createExternalConfig({ modes: ['txt2img'], supports_reference_images: false }); + expect(selectModelSupportsHrf.resultFunc(buildExternalModelIdentifier(config))).toBe(false); + }); +}); + 
+describe('paramsSlice HRF selectors', () => { + it('rounds final dimensions down to the model grid', () => { + expect(selectHrfFinalDimensions.resultFunc(513, 512, 1.5, 'flux')).toEqual({ width: 768, height: 768 }); + }); + + it('supports upscale-model HRF only for SD1.5 and SDXL models', () => { + expect( + selectModelSupportsHrfUpscaleModel.resultFunc({ + key: 'sd1', + hash: 'h', + name: 'SD1', + base: 'sd-1', + type: 'main', + }) + ).toBe(true); + expect( + selectModelSupportsHrfUpscaleModel.resultFunc({ + key: 'sdxl', + hash: 'h', + name: 'SDXL', + base: 'sdxl', + type: 'main', + }) + ).toBe(true); + expect( + selectModelSupportsHrfUpscaleModel.resultFunc({ + key: 'flux', + hash: 'h', + name: 'FLUX', + base: 'flux', + type: 'main', + }) + ).toBe(false); + }); +}); + +describe('paramsSlice HRF reducers', () => { + it('clears upscale-model-only refinement overrides when switching to latent HRF', () => { + const state = getInitialParamsState(); + state.hrfMethod = 'upscale_model'; + state.hrfSteps = 12; + state.hrfModel = { key: 'hrf-model', hash: 'h', name: 'HRF Model', base: 'sdxl', type: 'main' }; + state.hrfLoraMode = 'dedicated'; + state.hrfLoras = [ + { + id: 'lora', + isEnabled: true, + model: { key: 'lora', hash: 'h', name: 'LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ] as ParamsState['hrfLoras']; + + const nextState = paramsSliceConfig.slice.reducer(state, setHrfMethod('latent')); + + expect(nextState.hrfMethod).toBe('latent'); + expect(nextState.hrfSteps).toBeNull(); + expect(nextState.hrfModel).toBeNull(); + expect(nextState.hrfLoraMode).toBe('reuse_generate'); + expect(nextState.hrfLoras).toEqual([]); + }); + + it('filters dedicated HRF LoRAs when the Generate model base changes', () => { + const state = getInitialParamsState(); + const previousModel = { key: 'sdxl', hash: 'h', name: 'SDXL', base: 'sdxl', type: 'main' } as const; + const nextModel = { key: 'sd1', hash: 'h', name: 'SD1', base: 'sd-1', type: 'main' } as const; + 
state.model = previousModel; + state.hrfMethod = 'upscale_model'; + state.hrfLoraMode = 'dedicated'; + state.hrfLoras = [ + { + id: 'sdxl-lora', + isEnabled: true, + model: { key: 'sdxl-lora', hash: 'h', name: 'SDXL LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + { + id: 'sd1-lora', + isEnabled: true, + model: { key: 'sd1-lora', hash: 'h', name: 'SD1 LoRA', base: 'sd-1', type: 'lora' }, + weight: 0.7, + }, + ] as ParamsState['hrfLoras']; + + const nextState = paramsSliceConfig.slice.reducer(state, modelChanged({ model: nextModel, previousModel })); + + expect(nextState.hrfLoraMode).toBe('dedicated'); + expect(nextState.hrfLoras).toHaveLength(1); + expect(nextState.hrfLoras[0]?.model.key).toBe('sd1-lora'); + }); + + it('preserves HRF settings when switching through an unsupported base and back', () => { + const state = getInitialParamsState(); + const sdxlModel = { key: 'sdxl', hash: 'h', name: 'SDXL', base: 'sdxl', type: 'main' } as const; + const externalModel = buildExternalModelIdentifier( + createExternalConfig({ modes: ['txt2img'], supports_reference_images: false }) + ); + state.model = sdxlModel; + state.hrfEnabled = true; + state.hrfMethod = 'upscale_model'; + state.hrfModel = { key: 'hrf-sdxl', hash: 'h', name: 'HRF SDXL', base: 'sdxl', type: 'main' }; + state.hrfTileControlNetModel = { + key: 'tile-sdxl', + hash: 'h', + name: 'Tile SDXL', + base: 'sdxl', + type: 'controlnet', + }; + state.hrfLoraMode = 'dedicated'; + state.hrfLoras = [ + { + id: 'sdxl-lora', + isEnabled: true, + model: { key: 'sdxl-lora', hash: 'h', name: 'SDXL LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ] as ParamsState['hrfLoras']; + + const unsupportedState = paramsSliceConfig.slice.reducer( + state, + modelChanged({ model: externalModel, previousModel: sdxlModel }) + ); + const returnedState = paramsSliceConfig.slice.reducer( + unsupportedState, + modelChanged({ model: sdxlModel, previousModel: externalModel }) + ); + + 
expect(unsupportedState.hrfEnabled).toBe(true); + expect(unsupportedState.hrfModel?.key).toBe('hrf-sdxl'); + expect(unsupportedState.hrfTileControlNetModel?.key).toBe('tile-sdxl'); + expect(unsupportedState.hrfLoraMode).toBe('dedicated'); + expect(unsupportedState.hrfLoras).toHaveLength(1); + expect(returnedState.hrfEnabled).toBe(true); + expect(returnedState.hrfModel?.key).toBe('hrf-sdxl'); + expect(returnedState.hrfTileControlNetModel?.key).toBe('tile-sdxl'); + expect(returnedState.hrfLoraMode).toBe('dedicated'); + expect(returnedState.hrfLoras[0]?.model.key).toBe('sdxl-lora'); + }); + + it('cleans preserved HRF settings when returning to a different supported base', () => { + const state = getInitialParamsState(); + const sdxlModel = { key: 'sdxl', hash: 'h', name: 'SDXL', base: 'sdxl', type: 'main' } as const; + const sd1Model = { key: 'sd1', hash: 'h', name: 'SD1', base: 'sd-1', type: 'main' } as const; + const fluxModel = { key: 'flux', hash: 'h', name: 'FLUX', base: 'flux', type: 'main' } as const; + state.model = sdxlModel; + state.hrfEnabled = true; + state.hrfMethod = 'upscale_model'; + state.hrfModel = { key: 'hrf-sdxl', hash: 'h', name: 'HRF SDXL', base: 'sdxl', type: 'main' }; + state.hrfTileControlNetModel = { + key: 'tile-sdxl', + hash: 'h', + name: 'Tile SDXL', + base: 'sdxl', + type: 'controlnet', + }; + state.hrfLoraMode = 'dedicated'; + state.hrfLoras = [ + { + id: 'sdxl-lora', + isEnabled: true, + model: { key: 'sdxl-lora', hash: 'h', name: 'SDXL LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ] as ParamsState['hrfLoras']; + + const unsupportedState = paramsSliceConfig.slice.reducer( + state, + modelChanged({ model: fluxModel, previousModel: sdxlModel }) + ); + const sd1State = paramsSliceConfig.slice.reducer( + unsupportedState, + modelChanged({ model: sd1Model, previousModel: fluxModel }) + ); + + expect(unsupportedState.hrfModel?.key).toBe('hrf-sdxl'); + expect(unsupportedState.hrfTileControlNetModel?.key).toBe('tile-sdxl'); + 
expect(unsupportedState.hrfLoras).toHaveLength(1); + expect(sd1State.hrfEnabled).toBe(true); + expect(sd1State.hrfModel).toBeNull(); + expect(sd1State.hrfTileControlNetModel).toBeNull(); + expect(sd1State.hrfLoraMode).toBe('dedicated'); + expect(sd1State.hrfLoras).toEqual([]); + }); +}); + +describe('paramsSlice HRF migrations', () => { + it('migrates legacy HRF Structure to explicit Tile Control Weight', () => { + const v5State = { + ...getInitialParamsState(), + _version: 5, + hrfStructure: 0, + } as unknown as Record; + delete v5State.hrfTileControlWeight; + delete v5State.hrfSteps; + delete v5State.hrfModel; + delete v5State.hrfLoraMode; + delete v5State.hrfLoras; + + const migrated = paramsSliceConfig.persistConfig!.migrate(v5State); + + expect(migrated).toMatchObject({ + _version: 7, + hrfTileControlWeight: 0.625, + hrfSteps: null, + hrfModel: null, + hrfLoraMode: 'reuse_generate', + hrfLoras: [], + }); + expect('hrfStructure' in migrated).toBe(false); + }); }); describe('paramsSliceConfig persisted state migration', () => { @@ -155,7 +408,7 @@ describe('paramsSliceConfig persisted state migration', () => { const result = migrate?.(v2State) as ReturnType; - expect(result._version).toBe(3); + expect(result._version).toBe(7); expect(result.qwenImageVaeModel).toBeNull(); expect(result.qwenImageQwenVLEncoderModel).toBeNull(); // Existing params should be preserved diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index a5200ef1ff8..59576a298af 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -7,7 +7,16 @@ import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMul import { isPlainObject } from 'es-toolkit'; import { clamp } from 'es-toolkit/compat'; import { logout } from 'features/auth/store/authSlice'; -import type { 
AspectRatioID, InfillMethod, ParamsState, RgbaColor } from 'features/controlLayers/store/types'; +import type { + AspectRatioID, + HrfLatentInterpolationMode, + HrfLoraMode, + HrfMethod, + InfillMethod, + LoRA, + ParamsState, + RgbaColor, +} from 'features/controlLayers/store/types'; import { ASPECT_RATIO_MAP, DEFAULT_ASPECT_RATIO_CONFIG, @@ -21,7 +30,7 @@ import { SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS, SUPPORTS_REF_IMAGES_BASE_MODELS, } from 'features/modelManagerV2/models'; -import type { BaseModelType } from 'features/nodes/types/common'; +import { type BaseModelType, type ModelIdentifierField, zModelIdentifierField } from 'features/nodes/types/common'; import { CLIP_SKIP_MAP } from 'features/parameters/types/constants'; import type { ParameterCanvasCoherenceMode, @@ -39,15 +48,21 @@ import type { ParameterPrecision, ParameterScheduler, ParameterSDXLRefinerModel, + ParameterSpandrelImageToImageModel, ParameterT5EncoderModel, ParameterVAEModel, } from 'features/parameters/types/parameterSchemas'; import { getExternalPanelControl, hasExternalPanelControl } from 'features/parameters/util/externalPanelSchema'; import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension'; import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models'; -import type { AnyModelConfigWithExternal } from 'services/api/types'; +import type { AnyModelConfigWithExternal, ControlNetModelConfig, LoRAModelConfig } from 'services/api/types'; import { isExternalApiModelConfig, isNonRefinerMainModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; +import { v4 as uuidv4 } from 'uuid'; + +const DEFAULT_HRF_LORA_WEIGHT = 0.75; + +const selectHrfLoRA = (state: ParamsState, id: string) => state.hrfLoras.find((lora) => lora.id === id); const slice = createSlice({ name: 'params', @@ -115,6 +130,102 @@ const slice = createSlice({ setOptimizedDenoisingEnabled: (state, action: PayloadAction) => { 
state.optimizedDenoisingEnabled = action.payload; }, + setHrfEnabled: (state, action: PayloadAction) => { + state.hrfEnabled = action.payload && !state.refinerModel; + }, + setHrfMethod: (state, action: PayloadAction) => { + state.hrfMethod = action.payload; + if (action.payload === 'latent') { + state.hrfSteps = null; + state.hrfModel = null; + state.hrfLoraMode = 'reuse_generate'; + state.hrfLoras = []; + } + }, + setHrfScale: (state, action: PayloadAction) => { + state.hrfScale = action.payload; + }, + setHrfStrength: (state, action: PayloadAction) => { + state.hrfStrength = action.payload; + }, + setHrfLatentInterpolationMode: (state, action: PayloadAction) => { + state.hrfLatentInterpolationMode = action.payload; + }, + setHrfUpscaleModel: (state, action: PayloadAction) => { + const result = zParamsState.shape.hrfUpscaleModel.safeParse(action.payload); + if (result.success) { + state.hrfUpscaleModel = result.data; + } + }, + setHrfTileControlNetModel: (state, action: PayloadAction) => { + const result = zParamsState.shape.hrfTileControlNetModel.safeParse(action.payload); + if (result.success) { + state.hrfTileControlNetModel = result.data; + } + }, + setHrfTileControlWeight: (state, action: PayloadAction) => { + state.hrfTileControlWeight = action.payload; + }, + setHrfTileControlEnd: (state, action: PayloadAction) => { + state.hrfTileControlEnd = action.payload; + }, + setHrfTileSize: (state, action: PayloadAction) => { + state.hrfTileSize = action.payload; + }, + setHrfTileOverlap: (state, action: PayloadAction) => { + state.hrfTileOverlap = action.payload; + }, + setHrfSteps: (state, action: PayloadAction) => { + const result = zParamsState.shape.hrfSteps.safeParse(action.payload); + if (result.success) { + state.hrfSteps = result.data; + } + }, + setHrfModel: (state, action: PayloadAction) => { + const result = zParamsState.shape.hrfModel.safeParse(action.payload); + if (result.success) { + state.hrfModel = result.data; + } + }, + setHrfLoraMode: (state, 
action: PayloadAction) => { + state.hrfLoraMode = action.payload; + }, + setHrfLoras: (state, action: PayloadAction) => { + state.hrfLoras = action.payload; + }, + hrfLoraAdded: { + reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => { + const { model, id } = action.payload; + const parsedModel = zModelIdentifierField.parse(model); + const defaultLoRAConfig: Pick = { + weight: model.default_settings?.weight ?? DEFAULT_HRF_LORA_WEIGHT, + isEnabled: true, + }; + state.hrfLoras = state.hrfLoras.filter((lora) => lora.model.key !== parsedModel.key); + state.hrfLoras.push({ ...defaultLoRAConfig, model: parsedModel, id }); + }, + prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }), + }, + hrfLoraDeleted: (state, action: PayloadAction<{ id: string }>) => { + const { id } = action.payload; + state.hrfLoras = state.hrfLoras.filter((lora) => lora.id !== id); + }, + hrfLoraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => { + const { id, weight } = action.payload; + const lora = selectHrfLoRA(state, id); + if (!lora) { + return; + } + lora.weight = weight; + }, + hrfLoraIsEnabledChanged: (state, action: PayloadAction<{ id: string; isEnabled: boolean }>) => { + const { id, isEnabled } = action.payload; + const lora = selectHrfLoRA(state, id); + if (!lora) { + return; + } + lora.isEnabled = isEnabled; + }, setSeamlessXAxis: (state, action: PayloadAction) => { state.seamlessXAxis = action.payload; }, @@ -141,6 +252,17 @@ const slice = createSlice({ return; } + if (isHrfSupportedBase(model.base)) { + if (state.hrfModel?.base !== model.base) { + state.hrfModel = null; + } + if (state.hrfTileControlNetModel?.base !== model.base) { + state.hrfTileControlNetModel = null; + } + const effectiveHrfBase = state.hrfModel?.base ?? 
model.base; + state.hrfLoras = state.hrfLoras.filter((lora) => lora.model.base === effectiveHrfBase); + } + applyClipSkip(state, model, state.clipSkip); }, vaeSelected: (state, action: PayloadAction) => { @@ -333,6 +455,9 @@ const slice = createSlice({ return; } state.refinerModel = result.data; + if (state.refinerModel) { + state.hrfEnabled = false; + } }, setRefinerSteps: (state, action: PayloadAction) => { state.refinerSteps = action.payload; @@ -642,6 +767,25 @@ export const { setSeed, setImg2imgStrength, setOptimizedDenoisingEnabled, + setHrfEnabled, + setHrfMethod, + setHrfScale, + setHrfStrength, + setHrfLatentInterpolationMode, + setHrfUpscaleModel, + setHrfTileControlNetModel, + setHrfTileControlWeight, + setHrfTileControlEnd, + setHrfTileSize, + setHrfTileOverlap, + setHrfSteps, + setHrfModel, + setHrfLoraMode, + setHrfLoras, + hrfLoraAdded, + hrfLoraDeleted, + hrfLoraWeightChanged, + hrfLoraIsEnabledChanged, setSeamlessXAxis, setSeamlessYAxis, setShouldRandomizeSeed, @@ -734,6 +878,52 @@ export const paramsSliceConfig: SliceConfig = { state.qwenImageQwenVLEncoderModel = null; } + if (state._version === 3) { + // v3 -> v4, add Generate tab high resolution fix settings + state._version = 4; + state.hrfEnabled = false; + state.hrfScale = 2; + state.hrfStrength = 0.45; + state.hrfLatentInterpolationMode = 'bicubic'; + } + + if (state._version === 4) { + // v4 -> v5, add Generate tab upscale-model high resolution fix settings + state._version = 5; + state.hrfMethod = 'latent'; + state.hrfUpscaleModel = null; + state.hrfTileControlNetModel = null; + state.hrfStructure = 0; + state.hrfTileSize = 1024; + state.hrfTileOverlap = 128; + } + + if (state._version === 5) { + // v5 -> v6, add explicit Generate tab HRF Tile ControlNet timing + state._version = 6; + state.hrfTileControlEnd = 0.2; + } + + if (state._version === 6) { + // v6 -> v7, replace the Invoke Upscale "Structure" abstraction with explicit Generate HRF controls + state._version = 7; + if 
(!('qwenImageVaeModel' in state)) { + state.qwenImageVaeModel = null; + } + if (!('qwenImageQwenVLEncoderModel' in state)) { + state.qwenImageQwenVLEncoderModel = null; + } + if (!('hrfTileControlWeight' in state)) { + const legacyHrfStructure = typeof state.hrfStructure === 'number' ? state.hrfStructure : 0; + state.hrfTileControlWeight = (legacyHrfStructure + 10) * 0.0325 + 0.3; + state.hrfSteps = null; + state.hrfModel = null; + state.hrfLoraMode = 'reuse_generate'; + state.hrfLoras = []; + } + delete state.hrfStructure; + } + return zParamsState.parse(state); }, }, @@ -802,6 +992,25 @@ export const selectInfillPatchmatchDownscaleSize = createParamsSelector( export const selectInfillColorValue = createParamsSelector((params) => params.infillColorValue); export const selectImg2imgStrength = createParamsSelector((params) => params.img2imgStrength); export const selectOptimizedDenoisingEnabled = createParamsSelector((params) => params.optimizedDenoisingEnabled); +export const selectHrfEnabled = createParamsSelector((params) => params.hrfEnabled); +export const selectHrfMethod = createParamsSelector((params) => params.hrfMethod); +export const selectHrfScale = createParamsSelector((params) => params.hrfScale); +export const selectHrfStrength = createParamsSelector((params) => params.hrfStrength); +export const selectHrfLatentInterpolationMode = createParamsSelector((params) => params.hrfLatentInterpolationMode); +export const selectHrfUpscaleModel = createParamsSelector((params) => params.hrfUpscaleModel); +export const selectHrfTileControlNetModel = createParamsSelector((params) => params.hrfTileControlNetModel); +export const selectHrfTileControlWeight = createParamsSelector((params) => params.hrfTileControlWeight); +export const selectHrfTileControlEnd = createParamsSelector((params) => params.hrfTileControlEnd); +export const selectHrfTileSize = createParamsSelector((params) => params.hrfTileSize); +export const selectHrfTileOverlap = 
createParamsSelector((params) => params.hrfTileOverlap); +export const selectHrfSteps = createParamsSelector((params) => params.hrfSteps); +export const selectHrfModel = createParamsSelector((params) => params.hrfModel); +export const selectHrfLoraMode = createParamsSelector((params) => params.hrfLoraMode); +export const selectHrfLoras = createParamsSelector((params) => params.hrfLoras); +export const buildSelectHrfLoRA = (id: string) => + createSelector([selectParamsSlice], (params) => { + return selectHrfLoRA(params, id); + }); export const selectPositivePrompt = createParamsSelector((params) => params.positivePrompt); export const selectNegativePrompt = createParamsSelector((params) => params.negativePrompt); export const selectNegativePromptWithFallback = createParamsSelector((params) => params.negativePrompt ?? ''); @@ -884,6 +1093,21 @@ export const selectModelSupportsDimensions = createSelector(selectModel, selectM } return true; }); +export const isHrfSupportedBase = (base: BaseModelType | null | undefined): boolean => + base === 'sd-1' || base === 'sdxl'; + +export const selectModelSupportsHrf = createSelector(selectModel, (model) => { + if (!model) { + return false; + } + return isHrfSupportedBase(model.base); +}); +export const selectModelSupportsHrfUpscaleModel = createSelector(selectModel, (model) => { + if (!model) { + return false; + } + return isHrfSupportedBase(model.base); +}); export const selectSeedControl = createSelector(selectModelConfig, (modelConfig) => { if (modelConfig && isExternalApiModelConfig(modelConfig)) { return getExternalPanelControl(modelConfig, 'image', 'seed'); @@ -930,6 +1154,19 @@ export const selectRefinerSteps = createParamsSelector((params) => params.refine export const selectWidth = createParamsSelector((params) => params.dimensions.width); export const selectHeight = createParamsSelector((params) => params.dimensions.height); +export const selectHrfFinalDimensions = createSelector( + selectWidth, + selectHeight, + 
selectHrfScale, + selectBase, + (width, height, hrfScale, base) => { + const gridSize = getGridSize(base as BaseModelType | undefined); + return { + width: Math.max(roundDownToMultiple(width * hrfScale, gridSize), 64), + height: Math.max(roundDownToMultiple(height * hrfScale, gridSize), 64), + }; + } +); export const selectAspectRatioID = createParamsSelector((params) => params.dimensions.aspectRatio.id); export const selectAspectRatioValue = createParamsSelector((params) => params.dimensions.aspectRatio.value); export const selectAspectRatioIsLocked = createParamsSelector((params) => params.dimensions.aspectRatio.isLocked); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 7a7ebeade71..7bb72dbefb2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -24,6 +24,7 @@ import { zParameterScheduler, zParameterSDXLRefinerModel, zParameterSeed, + zParameterSpandrelImageToImageModel, zParameterSteps, zParameterStrength, zParameterT5EncoderModel, @@ -749,8 +750,17 @@ const zPositivePromptHistory = z export const zInfillMethod = z.enum(['patchmatch', 'lama', 'cv2', 'color', 'tile']); export type InfillMethod = z.infer; +export const zHrfLatentInterpolationMode = z.enum(['nearest', 'bilinear', 'bicubic', 'area', 'nearest-exact']); +export type HrfLatentInterpolationMode = z.infer; + +export const zHrfMethod = z.enum(['latent', 'upscale_model']); +export type HrfMethod = z.infer; + +export const zHrfLoraMode = z.enum(['reuse_generate', 'none', 'dedicated']); +export type HrfLoraMode = z.infer; + export const zParamsState = z.object({ - _version: z.literal(3), + _version: z.literal(7), maskBlur: z.number(), maskBlurMethod: zParameterMaskBlurMethod, canvasCoherenceMode: zParameterCanvasCoherenceMode, @@ -765,6 +775,21 @@ export const zParamsState = z.object({ guidance: 
zParameterGuidance, img2imgStrength: zParameterStrength, optimizedDenoisingEnabled: z.boolean(), + hrfEnabled: z.boolean(), + hrfMethod: zHrfMethod, + hrfScale: z.number().min(1).max(8), + hrfStrength: zParameterStrength, + hrfLatentInterpolationMode: zHrfLatentInterpolationMode, + hrfUpscaleModel: zParameterSpandrelImageToImageModel.nullable(), + hrfTileControlNetModel: zModelIdentifierField.nullable(), + hrfTileControlWeight: z.number().min(0).max(2), + hrfTileControlEnd: z.number().min(0).max(1), + hrfTileSize: z.number().int().min(8), + hrfTileOverlap: z.number().int().min(8), + hrfSteps: z.number().int().min(1).nullable(), + hrfModel: zParameterModel.nullable(), + hrfLoraMode: zHrfLoraMode, + hrfLoras: z.array(zLoRA), iterations: z.number(), scheduler: zParameterScheduler, fluxScheduler: zParameterFluxScheduler, @@ -839,7 +864,7 @@ export const zParamsState = z.object({ }); export type ParamsState = z.infer; export const getInitialParamsState = (): ParamsState => ({ - _version: 3, + _version: 7, maskBlur: 16, maskBlurMethod: 'box', canvasCoherenceMode: 'Gaussian Blur', @@ -854,6 +879,21 @@ export const getInitialParamsState = (): ParamsState => ({ guidance: 4, img2imgStrength: 0.75, optimizedDenoisingEnabled: true, + hrfEnabled: false, + hrfMethod: 'latent', + hrfScale: 2, + hrfStrength: 0.45, + hrfLatentInterpolationMode: 'bicubic', + hrfUpscaleModel: null, + hrfTileControlNetModel: null, + hrfTileControlWeight: 0.625, + hrfTileControlEnd: 0.2, + hrfTileSize: 1024, + hrfTileOverlap: 128, + hrfSteps: null, + hrfModel: null, + hrfLoraMode: 'reuse_generate', + hrfLoras: [], iterations: 1, scheduler: 'dpmpp_3m_k', fluxScheduler: 'euler', diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.test.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.test.tsx index c7433153825..362fc05d511 100644 --- 
a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.test.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.test.tsx @@ -21,4 +21,25 @@ describe('ImageMetadataActions', () => { expect(handlers).toContain(ImageMetadataHandlers.QwenImageQuantization); expect(handlers).toContain(ImageMetadataHandlers.QwenImageShift); }); + + it('includes HRF refinement metadata handlers in the recall parameters UI', () => { + const element = (ImageMetadataActions as unknown as { type: (props: { metadata: unknown }) => unknown }).type({ + metadata: { model: { key: 'test' } }, + }) as { + props: { + children: Array<{ props?: { handler?: unknown } }>; + }; + }; + + const handlers = element.props.children + .map((child) => child.props?.handler) + .filter((handler): handler is unknown => handler !== undefined); + + expect(handlers).toContain(ImageMetadataHandlers.HrfTileControlWeight); + expect(handlers).toContain(ImageMetadataHandlers.HrfTileControlEnd); + expect(handlers).toContain(ImageMetadataHandlers.HrfSteps); + expect(handlers).toContain(ImageMetadataHandlers.HrfModel); + expect(handlers).toContain(ImageMetadataHandlers.HrfLoraMode); + expect(handlers).toContain(ImageMetadataHandlers.HrfLoRAs); + }); }); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 105ad3dfd67..d3bc3ca0da9 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -49,6 +49,22 @@ export const ImageMetadataActions = memo((props: Props) => { + + + + + + + + + + + + + + + + diff --git a/invokeai/frontend/web/src/features/metadata/parsing.test.tsx 
b/invokeai/frontend/web/src/features/metadata/parsing.test.tsx index bb295303273..1a775fb95a2 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.test.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.test.tsx @@ -1,6 +1,6 @@ import type { AppStore } from 'app/store/store'; import type * as paramsSliceModule from 'features/controlLayers/store/paramsSlice'; -import { ImageMetadataHandlers } from 'features/metadata/parsing'; +import { ImageMetadataHandlers, MetadataUtils } from 'features/metadata/parsing'; import type * as modelsApiModule from 'services/api/endpoints/models'; import { beforeEach, describe, expect, it, vi } from 'vitest'; @@ -22,7 +22,7 @@ vi.mock('features/controlLayers/store/paramsSlice', async (importOriginal) => { return { ...mod, selectBase: () => currentBase }; }); -const fakeModel = (type: 'vae' | 'qwen3_encoder', base: string) => ({ +const fakeModel = (type: 'main' | 'vae' | 'qwen3_encoder' | 'lora', base: string) => ({ key: `${type}-key`, hash: 'hash', name: `Some ${type}`, @@ -31,6 +31,7 @@ const fakeModel = (type: 'vae' | 'qwen3_encoder', base: string) => ({ }); let nextResolved: ReturnType = fakeModel('vae', 'flux2'); +let resolvedModels: Record> = {}; vi.mock('services/api/endpoints/models', async (importOriginal) => { const mod = await importOriginal(); @@ -48,15 +49,22 @@ vi.mock('services/api/endpoints/models', async (importOriginal) => { const makeStore = (): AppStore => ({ - dispatch: vi.fn(() => ({ - unwrap: () => Promise.resolve(nextResolved), - })), + dispatch: vi.fn((action) => { + if (action?.type === 'generation/modelSelected') { + currentBase = action.payload.base; + return action; + } + return { + unwrap: () => Promise.resolve(resolvedModels[action?.key] ?? 
nextResolved), + }; + }), getState: () => ({}), }) as unknown as AppStore; beforeEach(() => { currentBase = 'flux2'; nextResolved = fakeModel('vae', 'flux2'); + resolvedModels = {}; }); describe('ImageMetadataHandlers — Klein recall gating', () => { @@ -171,4 +179,89 @@ describe('ImageMetadataHandlers — Klein recall gating', () => { expect(parsed).toBe(3.5); }); }); + + describe('HRF LoRAs', () => { + it('parses dedicated HRF LoRAs against the image metadata model base', async () => { + currentBase = 'sd-1'; + const mainModel = fakeModel('main', 'sdxl'); + const hrfLora = fakeModel('lora', 'sdxl'); + resolvedModels = { + [hrfLora.key]: hrfLora, + }; + const store = makeStore(); + + const parsed = await ImageMetadataHandlers.HrfLoRAs.parse( + { + model: mainModel, + hrf_loras: [{ model: hrfLora, weight: 0.6 }], + }, + store + ); + + expect(parsed).toEqual([ + expect.objectContaining({ model: expect.objectContaining({ key: hrfLora.key }), weight: 0.6 }), + ]); + }); + + it('filters HRF LoRAs that do not match the metadata HRF model base', async () => { + currentBase = 'sdxl'; + const mainModel = fakeModel('main', 'sdxl'); + const hrfModel = fakeModel('main', 'sd-1'); + const hrfLora = fakeModel('lora', 'sdxl'); + resolvedModels = { + [hrfLora.key]: hrfLora, + }; + const store = makeStore(); + + const parsed = await ImageMetadataHandlers.HrfLoRAs.parse( + { + model: mainModel, + hrf_model: hrfModel, + hrf_loras: [{ model: hrfLora, weight: 0.6 }], + }, + store + ); + + expect(parsed).toEqual([]); + }); + + it('recalls dedicated HRF LoRAs after the recalled main model changes base', async () => { + currentBase = 'sd-1'; + const mainModel = fakeModel('main', 'sdxl'); + const hrfLora = fakeModel('lora', 'sdxl'); + resolvedModels = { + [mainModel.key]: mainModel, + [hrfLora.key]: hrfLora, + }; + const store = makeStore(); + + const recalled = await MetadataUtils.recallByHandlers({ + metadata: { + model: mainModel, + hrf_loras: [{ model: hrfLora, weight: 0.6 }], + }, + 
handlers: [ImageMetadataHandlers.HrfLoRAs, ImageMetadataHandlers.MainModel], + store, + silent: true, + }); + + expect(Object.keys(ImageMetadataHandlers).indexOf('HrfLoRAs')).toBeGreaterThan( + Object.keys(ImageMetadataHandlers).indexOf('MainModel') + ); + expect(recalled.has(ImageMetadataHandlers.MainModel)).toBe(true); + expect(recalled.has(ImageMetadataHandlers.HrfLoRAs)).toBe(true); + expect(store.dispatch).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'generation/modelSelected', + payload: expect.objectContaining({ base: 'sdxl' }), + }) + ); + expect(store.dispatch).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'params/setHrfLoras', + payload: [expect.objectContaining({ model: expect.objectContaining({ key: hrfLora.key }), weight: 0.6 })], + }) + ); + }); + }); }); diff --git a/invokeai/frontend/web/src/features/metadata/parsing.tsx b/invokeai/frontend/web/src/features/metadata/parsing.tsx index c5a31d03a34..0e2f341aedd 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.tsx @@ -40,6 +40,21 @@ import { setFluxDypeScale, setFluxScheduler, setGuidance, + setHrfEnabled, + setHrfLatentInterpolationMode, + setHrfLoraMode, + setHrfLoras, + setHrfMethod, + setHrfModel, + setHrfScale, + setHrfSteps, + setHrfStrength, + setHrfTileControlEnd, + setHrfTileControlNetModel, + setHrfTileControlWeight, + setHrfTileOverlap, + setHrfTileSize, + setHrfUpscaleModel, setImg2imgStrength, setRefinerCFGScale, setRefinerNegativeAestheticScore, @@ -64,8 +79,22 @@ import { zImageVaeModelSelected, } from 'features/controlLayers/store/paramsSlice'; import { refImagesRecalled } from 'features/controlLayers/store/refImagesSlice'; -import type { CanvasMetadata, LoRA, RefImageState } from 'features/controlLayers/store/types'; -import { zCanvasMetadata, zCanvasReferenceImageState_OLD, zRefImageState } from 'features/controlLayers/store/types'; +import type { + CanvasMetadata, + 
HrfLatentInterpolationMode, + HrfLoraMode, + HrfMethod as HrfMethodType, + LoRA, + RefImageState, +} from 'features/controlLayers/store/types'; +import { + zCanvasMetadata, + zCanvasReferenceImageState_OLD, + zHrfLatentInterpolationMode, + zHrfLoraMode, + zHrfMethod, + zRefImageState, +} from 'features/controlLayers/store/types'; import type { ModelIdentifierField, ModelType } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifier } from 'features/nodes/types/v2/common'; @@ -621,6 +650,253 @@ const DenoisingStrength: SingleMetadataHandler = { }; //#endregion DenoisingStrength +//#region High Resolution Fix +const HrfEnabled: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfEnabled', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_enabled'); + const parsed = z.boolean().parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfEnabled(value)); + }, + i18nKey: 'hrf.metadata.enabled', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfMethod: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfMethod', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_method'); + const parsed = zHrfMethod.parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfMethod(value)); + }, + i18nKey: 'hrf.metadata.method', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfStrength: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfStrength', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_strength'); + const parsed = zParameterStrength.parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfStrength(value)); + }, + i18nKey: 
'hrf.metadata.strength', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfScale: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfScale', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_scale'); + const parsed = z.number().min(1).max(8).parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfScale(value)); + }, + i18nKey: 'hrf.metadata.scale', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfLatentInterpolationModeMetadata: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfLatentInterpolationMode', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_latent_interpolation_mode'); + const parsed = zHrfLatentInterpolationMode.parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfLatentInterpolationMode(value)); + }, + i18nKey: 'hrf.metadata.latentInterpolationMode', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => ( + + ), +}; + +const HrfUpscaleModel: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfUpscaleModel', + parse: (metadata, store) => { + const raw = getProperty(metadata, 'hrf_upscale_model'); + return parseModelIdentifier(raw, store, 'spandrel_image_to_image'); + }, + recall: (value, store) => { + store.dispatch(setHrfUpscaleModel(value)); + }, + i18nKey: 'hrf.metadata.upscaleModel', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => ( + + ), +}; + +const HrfTileControlNetModel: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfTileControlNetModel', + parse: (metadata, store) => { + const raw = getProperty(metadata, 'hrf_tile_controlnet_model'); + return parseModelIdentifier(raw, store, 'controlnet'); + }, + recall: (value, store) => { + 
store.dispatch(setHrfTileControlNetModel(value)); + }, + i18nKey: 'hrf.metadata.tileControlNetModel', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => ( + + ), +}; + +const HrfStructure: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfStructure', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_structure'); + const parsed = z.number().min(-10).max(10).parse(raw); + return Promise.resolve((parsed + 10) * 0.0325 + 0.3); + }, + recall: (value, store) => { + store.dispatch(setHrfTileControlWeight(value)); + }, + i18nKey: 'hrf.metadata.legacyStructure', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfTileControlWeight: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfTileControlWeight', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_tile_control_weight'); + const parsed = z.number().min(0).max(2).parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfTileControlWeight(value)); + }, + i18nKey: 'hrf.metadata.tileControlWeight', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfTileControlEnd: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfTileControlEnd', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_tile_control_end'); + const parsed = z.number().min(0).max(1).parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfTileControlEnd(value)); + }, + i18nKey: 'hrf.metadata.tileControlEnd', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfTileSize: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfTileSize', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_tile_size'); + const parsed 
= z.number().int().min(8).parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfTileSize(value)); + }, + i18nKey: 'hrf.metadata.tileSize', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfTileOverlap: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfTileOverlap', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_tile_overlap'); + const parsed = z.number().int().min(8).parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfTileOverlap(value)); + }, + i18nKey: 'hrf.metadata.tileOverlap', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfSteps: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfSteps', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_steps'); + const parsed = zParameterSteps.parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfSteps(value)); + }, + i18nKey: 'hrf.metadata.steps', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfModel: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfModel', + parse: (metadata, store) => { + const raw = getProperty(metadata, 'hrf_model'); + return parseModelIdentifier(raw, store, 'main'); + }, + recall: (value, store) => { + store.dispatch(setHrfModel(value as ParameterModel)); + }, + i18nKey: 'hrf.metadata.model', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => ( + + ), +}; + +const HrfLoraModeMetadata: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfLoraMode', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_lora_mode'); + const parsed = zHrfLoraMode.parse(raw); + return Promise.resolve(parsed); + }, + 
recall: (value, store) => { + store.dispatch(setHrfLoraMode(value)); + }, + i18nKey: 'hrf.metadata.loraMode', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; +//#endregion High Resolution Fix + //#region SeamlessX const SeamlessX: SingleMetadataHandler = { [SingleMetadataKey]: true, @@ -1310,6 +1586,58 @@ const LoRAs: CollectionMetadataHandler = { }; //#endregion LoRAs +//#region HRF LoRAs +const HrfLoRAs: CollectionMetadataHandler = { + [CollectionMetadataKey]: true, + type: 'HrfLoRAs', + parse: async (metadata, store) => { + const rawArray = getProperty(metadata, 'hrf_loras'); + + if (!rawArray) { + return []; + } + + assert(isArray(rawArray)); + + const loras: LoRA[] = []; + const effectiveHrfBase = getHrfLoRACompatibilityBase(metadata, store); + + for (const rawItem of rawArray) { + try { + const rawIdentifier = getProperty(rawItem, 'model'); + const identifier = await parseModelIdentifier(rawIdentifier, store, 'lora'); + assert(identifier.type === 'lora'); + assert(isCompatibleWithBase(identifier, effectiveHrfBase)); + + const weight = getProperty(rawItem, 'weight'); + + loras.push({ + id: getPrefixedId('hrf_lora'), + model: identifier, + weight: zLoRAWeight.parse(weight), + isEnabled: true, + }); + } catch { + continue; + } + } + + return loras; + }, + recallOne: (value, store) => { + store.dispatch(setHrfLoras([value])); + }, + recall: (values, store) => { + store.dispatch(setHrfLoras(values)); + }, + i18nKey: 'hrf.metadata.loras', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: CollectionMetadataValueProps) => ( + + ), +}; +//#endregion HRF LoRAs + //#region CanvasLayers const CanvasLayers: SingleMetadataHandler = { [SingleMetadataKey]: true, @@ -1620,6 +1948,21 @@ export const ImageMetadataHandlers = { Seed, Steps, DenoisingStrength, + HrfEnabled, + HrfMethod, + HrfStrength, + HrfScale, + HrfLatentInterpolationMode: HrfLatentInterpolationModeMetadata, + HrfUpscaleModel, + 
HrfTileControlNetModel, + HrfStructure, + HrfTileControlWeight, + HrfTileControlEnd, + HrfTileSize, + HrfTileOverlap, + HrfSteps, + HrfModel, + HrfLoraMode: HrfLoraModeMetadata, SeamlessX, SeamlessY, RefinerModel, @@ -1651,6 +1994,7 @@ export const ImageMetadataHandlers = { QwenImageShift, ZImageShift, LoRAs, + HrfLoRAs, CanvasLayers, RefImages, ImageSize, @@ -1737,10 +2081,8 @@ const recallByHandlers = async (arg: { (handler) => !skip.some((skippedHandler) => skippedHandler.type === handler.type) ); - // It's possible for some metadata item's recall to clobber the recall of another. For example, the model recall - // may change the width and height. If we are also recalling the width and height directly, we need to ensure that the - // model is recalled first, so it doesn't accidentally override the width and height. This is the only known case - // where the order of recall matters. + // It's possible for some metadata item's recall to clobber another or to affect compatibility checks. For example, + // model recall may change dimensions, and LoRA-like handlers validate against the current model base. 
const sortedHandlers = filteredHandlers.sort((a, b) => { if (a === ImageMetadataHandlers.MainModel) { return -1; // MainModel should be recalled first @@ -1988,13 +2330,43 @@ const parseModelIdentifier = async (raw: unknown, store: AppStore, type: ModelTy }; const isCompatibleWithMainModel = (candidate: ModelIdentifierField, store: AppStore) => { - const base = selectBase(store.getState()); + return isCompatibleWithBase(candidate, selectBase(store.getState())); +}; + +const isCompatibleWithBase = (candidate: ModelIdentifierField, base: string | null | undefined) => { if (!base) { return true; } return candidate.base === base; }; +const getMetadataModelBase = (metadata: unknown, path: string): string | null => { + const raw = getProperty(metadata, path); + if (!raw) { + return null; + } + + const identifierResult = zModelIdentifierField.safeParse(raw); + if (identifierResult.success) { + return identifierResult.data.base; + } + + const oldIdentifierResult = zModelIdentifier.safeParse(raw); + if (oldIdentifierResult.success) { + return oldIdentifierResult.data.base_model; + } + + return null; +}; + +const getHrfLoRACompatibilityBase = (metadata: unknown, store: AppStore): string | null | undefined => { + return ( + getMetadataModelBase(metadata, 'hrf_model') ?? + getMetadataModelBase(metadata, 'model') ?? 
+ selectBase(store.getState()) + ); +}; + const throwIfImageDoesNotExist = async (name: string, store: AppStore): Promise => { try { await store.dispatch(imagesApi.endpoints.getImageDTO.initiate(name, { subscribe: false })).unwrap(); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts new file mode 100644 index 00000000000..24acc9d9f1e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts @@ -0,0 +1,859 @@ +import type { RootState } from 'app/store/store'; +import type { HrfMethod } from 'features/controlLayers/store/types'; +import { describe, expect, it } from 'vitest'; + +import { addHighResFix } from './addHighResFix'; +import { Graph } from './Graph'; + +const buildState = (overrides?: { + base?: string; + hrfEnabled?: boolean; + hrfMethod?: HrfMethod; + hrfSteps?: number | null; + hrfModel?: unknown; + hrfLoraMode?: 'reuse_generate' | 'none' | 'dedicated'; + hrfLoras?: unknown[]; + hrfTileSize?: number; + loras?: unknown[]; + refinerModel?: unknown; +}): RootState => + ({ + ui: { activeTab: 'generate' }, + loras: { loras: overrides?.loras ?? [] }, + params: { + model: { key: 'model', hash: 'model-hash', name: 'model', base: overrides?.base ?? 'sdxl', type: 'main' }, + dimensions: { width: 512, height: 512 }, + hrfEnabled: overrides?.hrfEnabled ?? true, + hrfMethod: overrides?.hrfMethod ?? 'latent', + hrfScale: 2, + hrfStrength: 0.35, + hrfLatentInterpolationMode: 'bilinear', + hrfUpscaleModel: { + key: 'upscale', + hash: 'upscale-hash', + name: 'upscale', + base: 'any', + type: 'spandrel_image_to_image', + }, + hrfTileControlNetModel: { + key: 'tile', + hash: 'tile-hash', + name: 'tile', + base: overrides?.base ?? 'sdxl', + type: 'controlnet', + }, + hrfTileControlWeight: 0.625, + hrfTileControlEnd: 0.2, + hrfTileSize: overrides?.hrfTileSize ?? 
1024, + hrfTileOverlap: 128, + hrfSteps: overrides?.hrfSteps ?? null, + hrfModel: overrides?.hrfModel ?? null, + hrfLoraMode: overrides?.hrfLoraMode ?? 'reuse_generate', + hrfLoras: overrides?.hrfLoras ?? [], + optimizedDenoisingEnabled: true, + refinerModel: overrides?.refinerModel ?? null, + }, + }) as unknown as RootState; + +const buildClassicGraph = () => { + const g = new Graph('test_graph'); + const seed = g.addNode({ id: 'seed', type: 'integer' }); + const noise = g.addNode({ id: 'noise', type: 'noise', use_cpu: true, width: 512, height: 512 }); + const modelLoader = g.addNode({ + id: 'model_loader', + type: 'sdxl_model_loader', + model: { key: 'model', hash: 'model-hash', name: 'model', base: 'sdxl', type: 'main' }, + }); + const denoise = g.addNode({ + id: 'denoise', + type: 'denoise_latents', + cfg_scale: 7.5, + scheduler: 'euler', + steps: 30, + }); + const l2i = g.addNode({ id: 'l2i', type: 'l2i', fp32: true }); + + g.addEdge(seed, 'value', noise, 'seed'); + g.addEdge(noise, 'noise', denoise, 'noise'); + g.addEdge(modelLoader, 'unet', denoise, 'unet'); + g.addEdge(modelLoader, 'vae', l2i, 'vae'); + g.addEdge(denoise, 'latents', l2i, 'latents'); + + return { g, seed, noise, modelLoader, denoise, l2i }; +}; + +const addSDXLConditioning = (g: Graph, denoise: ReturnType['denoise']) => { + const posCond = g.addNode({ + id: 'pos_cond', + type: 'sdxl_compel_prompt', + original_width: 512, + original_height: 512, + target_width: 512, + target_height: 512, + }); + const posCollect = g.addNode({ id: 'pos_collect', type: 'collect' }); + const negCond = g.addNode({ + id: 'neg_cond', + type: 'sdxl_compel_prompt', + prompt: 'negative', + style: 'negative', + original_width: 512, + original_height: 512, + target_width: 512, + target_height: 512, + }); + const negCollect = g.addNode({ id: 'neg_collect', type: 'collect' }); + + g.addEdge(posCond, 'conditioning', posCollect, 'item'); + g.addEdge(posCollect, 'collection', denoise, 'positive_conditioning'); + 
g.addEdge(negCond, 'conditioning', negCollect, 'item'); + g.addEdge(negCollect, 'collection', denoise, 'negative_conditioning'); + + return { posCond, posCollect, negCond, negCollect }; +}; + +const addSDXLRegionalConditioning = (g: Graph, conditioning: ReturnType) => { + const mask = g.addNode({ + id: 'region_mask', + type: 'alpha_mask_to_tensor', + image: { image_name: 'region-mask' }, + }); + const regionPrompt = g.addNode({ id: 'region_prompt', type: 'string', value: 'regional prompt' }); + const regionPosCond = g.addNode({ + id: 'region_pos_cond', + type: 'sdxl_compel_prompt', + style: 'regional positive style', + original_width: 512, + original_height: 512, + target_width: 512, + target_height: 512, + }); + const regionNegCond = g.addNode({ + id: 'region_neg_cond', + type: 'sdxl_compel_prompt', + prompt: 'regional negative', + style: 'regional negative style', + original_width: 512, + original_height: 512, + target_width: 512, + target_height: 512, + }); + + g.addEdge(regionPrompt, 'value', regionPosCond, 'prompt'); + g.addEdgeFromObj({ + source: { node_id: mask.id, field: 'mask' }, + destination: { node_id: regionPosCond.id, field: 'mask' }, + }); + g.addEdgeFromObj({ + source: { node_id: mask.id, field: 'mask' }, + destination: { node_id: regionNegCond.id, field: 'mask' }, + }); + g.addEdge(regionPosCond, 'conditioning', conditioning.posCollect, 'item'); + g.addEdge(regionNegCond, 'conditioning', conditioning.negCollect, 'item'); + + return { mask, regionPrompt, regionPosCond, regionNegCond }; +}; + +const addSeamlessToClassicGraph = ({ + g, + modelLoader, + denoise, +}: Pick, 'g' | 'modelLoader' | 'denoise'>) => { + const seamless = g.addNode({ + id: 'seamless', + type: 'seamless', + seamless_x: true, + seamless_y: false, + }); + + g.deleteEdgesTo(denoise, ['unet']); + g.addEdge(modelLoader, 'unet', seamless, 'unet'); + g.addEdge(modelLoader, 'vae', seamless, 'vae'); + g.addEdge(seamless, 'unet', denoise, 'unet'); + + return seamless; +}; + +const 
buildTransformerGraph = () => { + const g = new Graph('test_transformer_graph'); + const seed = g.addNode({ id: 'seed', type: 'integer' }); + const denoise = g.addNode({ + id: 'sd3_denoise', + type: 'sd3_denoise', + cfg_scale: 4, + width: 512, + height: 512, + steps: 20, + denoising_start: 0, + denoising_end: 1, + }); + const l2i = g.addNode({ id: 'sd3_l2i', type: 'sd3_l2i' }); + + g.addEdge(seed, 'value', denoise, 'seed'); + g.addEdge(denoise, 'latents', l2i, 'latents'); + + return { g, seed, denoise, l2i }; +}; + +describe('addHighResFix', () => { + it('reroutes classic txt2img graphs through latent resize and a second denoise pass', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + + addHighResFix({ g, state: buildState(), generationMode: 'txt2img', denoise, l2i, noise, seed }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const resize = nodes.find((node) => node.type === 'lresize'); + const hrfDenoise = nodes.find((node) => node.id.startsWith('hrf_denoise_latents')); + const hrfNoise = nodes.find((node) => node.id.startsWith('hrf_noise')); + + expect(resize).toMatchObject({ type: 'lresize', width: 1024, height: 1024, mode: 'bilinear' }); + expect(hrfDenoise).toMatchObject({ type: 'denoise_latents', denoising_start: 0.65, denoising_end: 1 }); + expect(hrfNoise).toMatchObject({ type: 'noise', width: 1024, height: 1024, use_cpu: true }); + expect(l2i).toMatchObject({ type: 'l2i', tiled: true, tile_size: 1024 }); + expect(graph.edges).not.toContainEqual({ + source: { node_id: 'denoise', field: 'latents' }, + destination: { node_id: 'l2i', field: 'latents' }, + }); + expect(g.getMetadataNode()).toMatchObject({ + width: 512, + height: 512, + hrf_enabled: true, + hrf_method: 'latent', + hrf_strength: 0.35, + hrf_scale: 2, + hrf_latent_interpolation_mode: 'bilinear', + }); + }); + + it('ignores stale upscale-model-only settings when latent HRF is selected', () => { + const { g, seed, noise, denoise, l2i } = 
buildClassicGraph(); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'latent', + hrfSteps: 12, + hrfModel: { key: 'hrf-model', hash: 'hrf-hash', name: 'HRF Model', base: 'sdxl', type: 'main' }, + hrfLoraMode: 'dedicated', + hrfTileSize: 1536, + hrfLoras: [ + { + id: 'hrf-lora-id', + isEnabled: true, + model: { key: 'hrf-lora', hash: 'hrf-lora-hash', name: 'HRF LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ], + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const nodes = Object.values(g.getGraph().nodes); + const hrfDenoise = nodes.find((node) => node.id.startsWith('hrf_denoise_latents')); + + expect(nodes.some((node) => node.type === 'lresize')).toBe(true); + expect(nodes.some((node) => node.type === 'spandrel_image_to_image_autoscale')).toBe(false); + expect(nodes.some((node) => node.id.startsWith('hrf_sdxl_model_loader'))).toBe(false); + expect(nodes.some((node) => node.type === 'sdxl_lora_collection_loader')).toBe(false); + expect(hrfDenoise).toMatchObject({ steps: 30 }); + expect(l2i).toMatchObject({ type: 'l2i', tiled: true, tile_size: 1024 }); + expect(g.getMetadataNode()).toMatchObject({ hrf_method: 'latent' }); + expect(g.getMetadataNode().hrf_steps).toBeUndefined(); + }); + + it('preserves the original graph and writes disabled metadata when HRF is off', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + + addHighResFix({ + g, + state: buildState({ hrfEnabled: false }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + expect(Object.values(graph.nodes).some((node) => node.type === 'lresize')).toBe(false); + expect(graph.edges).toContainEqual({ + source: { node_id: 'denoise', field: 'latents' }, + destination: { node_id: 'l2i', field: 'latents' }, + }); + expect(g.getMetadataNode()).toMatchObject({ hrf_enabled: false }); + }); + + it('applies custom HRF steps only to the second pass', () => { + const { g, seed, noise, denoise, 
l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + + addHighResFix({ + g, + state: buildState({ hrfMethod: 'upscale_model', hrfSteps: 12 }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const hrfDenoise = Object.values(g.getGraph().nodes).find( + (node) => node.type === 'tiled_multi_diffusion_denoise_latents' + ); + + expect(denoise.steps).toBe(30); + expect(hrfDenoise).toMatchObject({ steps: 12 }); + expect(g.getMetadataNode()).toMatchObject({ hrf_steps: 12 }); + }); + + it('preserves unsupported model-family graphs and writes disabled metadata', () => { + const { g, seed, denoise, l2i } = buildTransformerGraph(); + + addHighResFix({ g, state: buildState({ base: 'sd-3' }), generationMode: 'txt2img', denoise, l2i, seed }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + expect(nodes.some((node) => node.type === 'lresize')).toBe(false); + expect(nodes.some((node) => node.id.startsWith('hrf_noise'))).toBe(false); + expect(graph.edges).toContainEqual({ + source: { node_id: 'sd3_denoise', field: 'latents' }, + destination: { node_id: 'sd3_l2i', field: 'latents' }, + }); + expect(g.getMetadataNode()).toMatchObject({ hrf_enabled: false }); + }); + + it('clones SDXL conditioning with final HRF dimensions for the second pass', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + const { posCond, posCollect } = addSDXLConditioning(g, denoise); + + addHighResFix({ g, state: buildState(), generationMode: 'txt2img', denoise, l2i, noise, seed }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const hrfDenoise = nodes.find((node) => node.id.startsWith('hrf_denoise_latents')); + const hrfPosCond = nodes.find((node) => node.id.startsWith('hrf_pos_cond')); + const hrfPosCollect = nodes.find((node) => node.id.startsWith('hrf_pos_collect_conditioning_collect')); + + if (!hrfDenoise || !hrfPosCond || !hrfPosCollect) { + throw new Error('Expected HRF denoise and 
cloned SDXL conditioning nodes'); + } + + expect(posCond).toMatchObject({ original_width: 512, original_height: 512 }); + expect(hrfPosCond).toMatchObject({ + type: 'sdxl_compel_prompt', + original_width: 1024, + original_height: 1024, + target_width: 1024, + target_height: 1024, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfPosCond.id, field: 'conditioning' }, + destination: { node_id: hrfPosCollect.id, field: 'item' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfPosCollect.id, field: 'collection' }, + destination: { node_id: hrfDenoise.id, field: 'positive_conditioning' }, + }); + expect(graph.edges).not.toContainEqual({ + source: { node_id: posCollect.id, field: 'collection' }, + destination: { node_id: hrfDenoise.id, field: 'positive_conditioning' }, + }); + }); + + it('reroutes SDXL txt2img graphs through an upscale model, tiled encode, and tiled second pass', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + + addHighResFix({ + g, + state: buildState({ hrfMethod: 'upscale_model' }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const intermediateL2i = nodes.find((node) => node.id.startsWith('hrf_intermediate_l2i')); + const spandrel = nodes.find((node) => node.type === 'spandrel_image_to_image_autoscale'); + const unsharp = nodes.find((node) => node.type === 'unsharp_mask'); + const i2l = nodes.find((node) => node.id.startsWith('hrf_i2l')); + const tiledDenoise = nodes.find((node) => node.type === 'tiled_multi_diffusion_denoise_latents'); + const tileControlNet = nodes.find((node) => node.id.startsWith('hrf_controlnet')); + + if (!intermediateL2i || !spandrel || !unsharp || !i2l || !tiledDenoise || !tileControlNet) { + throw new Error('Expected upscale-model HRF nodes'); + } + + expect(nodes.some((node) => node.type === 'lresize')).toBe(false); + 
expect(intermediateL2i).toMatchObject({ type: 'l2i', is_intermediate: true }); + expect(spandrel).toMatchObject({ + type: 'spandrel_image_to_image_autoscale', + image_to_image_model: { key: 'upscale' }, + scale: 2, + tile_size: 1024, + fit_to_multiple_of_8: true, + }); + expect(i2l).toMatchObject({ type: 'i2l', tiled: true, tile_size: 1024 }); + expect(l2i).toMatchObject({ type: 'l2i', tiled: true, tile_size: 1024 }); + expect(tiledDenoise).toMatchObject({ + type: 'tiled_multi_diffusion_denoise_latents', + tile_height: 1024, + tile_width: 1024, + tile_overlap: 128, + denoising_start: 0.65, + denoising_end: 1, + }); + expect(tileControlNet).toMatchObject({ + type: 'controlnet', + begin_step_percent: 0, + end_step_percent: 0.2, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: tileControlNet.id, field: 'control' }, + destination: { node_id: tiledDenoise.id, field: 'control' }, + }); + const positiveConditioningEdge = graph.edges.find( + (edge) => edge.destination.node_id === tiledDenoise.id && edge.destination.field === 'positive_conditioning' + ); + if (!positiveConditioningEdge) { + throw new Error('Expected positive conditioning edge'); + } + expect(graph.nodes[positiveConditioningEdge.source.node_id]?.type).not.toBe('collect'); + expect(graph.edges).not.toContainEqual({ + source: { node_id: 'denoise', field: 'latents' }, + destination: { node_id: 'l2i', field: 'latents' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: 'denoise', field: 'latents' }, + destination: { node_id: intermediateL2i.id, field: 'latents' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: tiledDenoise.id, field: 'latents' }, + destination: { node_id: 'l2i', field: 'latents' }, + }); + const metadataNode = g.getMetadataNode(); + expect(graph.edges).not.toContainEqual({ + source: { node_id: spandrel.id, field: 'width' }, + destination: { node_id: metadataNode.id, field: 'width' }, + }); + expect(graph.edges).not.toContainEqual({ + source: { 
node_id: spandrel.id, field: 'height' }, + destination: { node_id: metadataNode.id, field: 'height' }, + }); + expect(g.getMetadataNode()).toMatchObject({ + width: 512, + height: 512, + hrf_enabled: true, + hrf_method: 'upscale_model', + hrf_upscale_model: { key: 'upscale' }, + hrf_tile_controlnet_model: { key: 'tile' }, + hrf_tile_control_weight: 0.625, + hrf_tile_control_end: 0.2, + hrf_tile_size: 1024, + hrf_tile_overlap: 128, + hrf_lora_mode: 'reuse_generate', + }); + }); + + it('uses a dedicated SDXL HRF model and custom steps for the second pass only', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfSteps: 14, + hrfModel: { key: 'hrf-model', hash: 'hrf-hash', name: 'HRF Model', base: 'sdxl', type: 'main' }, + hrfLoraMode: 'none', + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const hrfModelLoader = nodes.find((node) => node.id.startsWith('hrf_sdxl_model_loader')); + const hrfPosCond = nodes.find((node) => node.id.startsWith('hrf_pos_cond')); + const tiledDenoise = nodes.find((node) => node.type === 'tiled_multi_diffusion_denoise_latents'); + const i2l = nodes.find((node) => node.id.startsWith('hrf_i2l')); + + if (!hrfModelLoader || !hrfPosCond || !tiledDenoise || !i2l) { + throw new Error('Expected dedicated HRF model graph nodes'); + } + + expect(hrfModelLoader).toMatchObject({ type: 'sdxl_model_loader', model: { key: 'hrf-model' } }); + expect(tiledDenoise).toMatchObject({ steps: 14 }); + expect(denoise.steps).toBe(30); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfModelLoader.id, field: 'clip' }, + destination: { node_id: hrfPosCond.id, field: 'clip' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfModelLoader.id, field: 'vae' }, + destination: { node_id: i2l.id, field: 
'vae' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfModelLoader.id, field: 'vae' }, + destination: { node_id: 'l2i', field: 'vae' }, + }); + expect(nodes.some((node) => node.type === 'sdxl_lora_collection_loader')).toBe(false); + expect(g.getMetadataNode()).toMatchObject({ + hrf_steps: 14, + hrf_model: { key: 'hrf-model' }, + hrf_lora_mode: 'none', + }); + }); + + it('preserves regional SDXL conditioning when a dedicated HRF model uses classic fallback', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + const conditioning = addSDXLConditioning(g, denoise); + const { mask, regionPrompt } = addSDXLRegionalConditioning(g, conditioning); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfModel: { key: 'hrf-model', hash: 'hrf-hash', name: 'HRF Model', base: 'sdxl', type: 'main' }, + hrfLoraMode: 'none', + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const classicHrfDenoise = nodes.find((node) => node.id.startsWith('hrf_denoise_latents')); + const hrfRegionPosCond = nodes.find((node) => node.id.startsWith('hrf_region_pos_cond')); + const hrfPositiveCollect = nodes.find((node) => node.id.startsWith('hrf_positive_conditioning_collect')); + + if (!classicHrfDenoise || !hrfRegionPosCond || !hrfPositiveCollect) { + throw new Error('Expected classic HRF denoise and cloned regional conditioning nodes'); + } + + expect(nodes.some((node) => node.type === 'tiled_multi_diffusion_denoise_latents')).toBe(false); + expect(hrfRegionPosCond).toMatchObject({ + type: 'sdxl_compel_prompt', + original_width: 1024, + original_height: 1024, + target_width: 1024, + target_height: 1024, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: regionPrompt.id, field: 'value' }, + destination: { node_id: hrfRegionPosCond.id, field: 'prompt' }, + }); + expect(graph.edges).toContainEqual({ + source: { 
node_id: mask.id, field: 'mask' }, + destination: { node_id: hrfRegionPosCond.id, field: 'mask' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfRegionPosCond.id, field: 'conditioning' }, + destination: { node_id: hrfPositiveCollect.id, field: 'item' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfPositiveCollect.id, field: 'collection' }, + destination: { node_id: classicHrfDenoise.id, field: 'positive_conditioning' }, + }); + }); + + it('applies dedicated HRF LoRA CLIP outputs to all regional conditioning clones', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + const conditioning = addSDXLConditioning(g, denoise); + addSDXLRegionalConditioning(g, conditioning); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfLoraMode: 'dedicated', + hrfLoras: [ + { + id: 'hrf-lora-id', + isEnabled: true, + model: { key: 'hrf-lora', hash: 'hrf-lora-hash', name: 'HRF LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ], + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const loraLoader = nodes.find((node) => node.type === 'sdxl_lora_collection_loader'); + const classicHrfDenoise = nodes.find((node) => node.id.startsWith('hrf_denoise_latents')); + const hrfPosCond = nodes.find((node) => node.id.startsWith('hrf_pos_cond')); + const hrfRegionPosCond = nodes.find((node) => node.id.startsWith('hrf_region_pos_cond')); + const hrfRegionNegCond = nodes.find((node) => node.id.startsWith('hrf_region_neg_cond')); + + if (!loraLoader || !classicHrfDenoise || !hrfPosCond || !hrfRegionPosCond || !hrfRegionNegCond) { + throw new Error('Expected dedicated HRF LoRA and cloned regional conditioning nodes'); + } + + for (const cond of [hrfPosCond, hrfRegionPosCond, hrfRegionNegCond]) { + expect(graph.edges).toContainEqual({ + source: { node_id: loraLoader.id, field: 'clip' }, + 
destination: { node_id: cond.id, field: 'clip' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: loraLoader.id, field: 'clip2' }, + destination: { node_id: cond.id, field: 'clip2' }, + }); + } + expect(graph.edges).toContainEqual({ + source: { node_id: loraLoader.id, field: 'unet' }, + destination: { node_id: classicHrfDenoise.id, field: 'unet' }, + }); + }); + + it('applies only dedicated HRF LoRAs when dedicated LoRA mode is selected', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfLoraMode: 'dedicated', + hrfLoras: [ + { + id: 'hrf-lora-id', + isEnabled: true, + model: { key: 'hrf-lora', hash: 'hrf-lora-hash', name: 'HRF LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ], + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const loraLoader = nodes.find((node) => node.type === 'sdxl_lora_collection_loader'); + const tiledDenoise = nodes.find((node) => node.type === 'tiled_multi_diffusion_denoise_latents'); + + if (!loraLoader || !tiledDenoise) { + throw new Error('Expected dedicated HRF LoRA graph nodes'); + } + + expect(graph.edges).toContainEqual({ + source: { node_id: loraLoader.id, field: 'unet' }, + destination: { node_id: tiledDenoise.id, field: 'unet' }, + }); + expect(g.getMetadataNode()).toMatchObject({ + hrf_lora_mode: 'dedicated', + hrf_loras: [{ model: { key: 'hrf-lora' }, weight: 0.6 }], + }); + }); + + it('ignores incompatible stale dedicated HRF LoRAs in the graph and metadata', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfLoraMode: 'dedicated', + hrfLoras: [ + { + id: 'stale-sd1-lora-id', + isEnabled: true, + model: { key: 
'stale-sd1-lora', hash: 'stale-hash', name: 'Stale SD1 LoRA', base: 'sd-1', type: 'lora' }, + weight: 0.6, + }, + ], + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const nodes = Object.values(g.getGraph().nodes); + + expect(nodes.some((node) => node.type === 'sdxl_lora_collection_loader')).toBe(false); + expect(nodes.some((node) => node.type === 'lora_selector')).toBe(false); + expect(g.getMetadataNode()).toMatchObject({ hrf_lora_mode: 'dedicated' }); + expect(g.getMetadataNode().hrf_loras).toEqual([]); + }); + + it('preserves seamless routing when a dedicated HRF model is selected', () => { + const { g, seed, noise, modelLoader, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + addSeamlessToClassicGraph({ g, modelLoader, denoise }); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfModel: { key: 'hrf-model', hash: 'hrf-hash', name: 'HRF Model', base: 'sdxl', type: 'main' }, + hrfLoraMode: 'none', + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const hrfModelLoader = nodes.find((node) => node.id.startsWith('hrf_sdxl_model_loader')); + const hrfSeamless = nodes.find((node) => node.id.startsWith('hrf_seamless')); + const tiledDenoise = nodes.find((node) => node.type === 'tiled_multi_diffusion_denoise_latents'); + const i2l = nodes.find((node) => node.id.startsWith('hrf_i2l')); + + if (!hrfModelLoader || !hrfSeamless || !tiledDenoise || !i2l) { + throw new Error('Expected dedicated HRF seamless graph nodes'); + } + + expect(hrfSeamless).toMatchObject({ type: 'seamless', seamless_x: true, seamless_y: false }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfModelLoader.id, field: 'unet' }, + destination: { node_id: hrfSeamless.id, field: 'unet' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfModelLoader.id, field: 'vae' }, + destination: 
{ node_id: hrfSeamless.id, field: 'vae' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfSeamless.id, field: 'unet' }, + destination: { node_id: tiledDenoise.id, field: 'unet' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfSeamless.id, field: 'vae' }, + destination: { node_id: i2l.id, field: 'vae' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfSeamless.id, field: 'vae' }, + destination: { node_id: 'l2i', field: 'vae' }, + }); + }); + + it('preserves seamless routing before dedicated HRF LoRAs', () => { + const { g, seed, noise, modelLoader, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + addSeamlessToClassicGraph({ g, modelLoader, denoise }); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfLoraMode: 'dedicated', + hrfLoras: [ + { + id: 'hrf-lora-id', + isEnabled: true, + model: { key: 'hrf-lora', hash: 'hrf-lora-hash', name: 'HRF LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ], + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const hrfSeamless = nodes.find((node) => node.id.startsWith('hrf_seamless')); + const loraLoader = nodes.find((node) => node.type === 'sdxl_lora_collection_loader'); + const tiledDenoise = nodes.find((node) => node.type === 'tiled_multi_diffusion_denoise_latents'); + + if (!hrfSeamless || !loraLoader || !tiledDenoise) { + throw new Error('Expected HRF seamless and dedicated LoRA graph nodes'); + } + + expect(graph.edges).toContainEqual({ + source: { node_id: 'model_loader', field: 'vae' }, + destination: { node_id: hrfSeamless.id, field: 'vae' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfSeamless.id, field: 'unet' }, + destination: { node_id: loraLoader.id, field: 'unet' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: loraLoader.id, field: 'unet' }, 
+ destination: { node_id: tiledDenoise.id, field: 'unet' }, + }); + expect(graph.edges).not.toContainEqual({ + source: { node_id: hrfSeamless.id, field: 'unet' }, + destination: { node_id: tiledDenoise.id, field: 'unet' }, + }); + }); + + it('uses a regular second denoise for upscale-model HRF when reference image adapters are connected', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + + const ipAdapter = g.addNode({ + id: 'ip_adapter', + type: 'ip_adapter', + weight: 1, + method: 'full', + ip_adapter_model: { key: 'ip', hash: 'ip-hash', name: 'ip', base: 'sdxl', type: 'ip_adapter' }, + clip_vision_model: 'ViT-H', + begin_step_percent: 0, + end_step_percent: 1, + image: { image_name: 'test' }, + }); + const ipAdapterCollector = g.addNode({ id: 'ip_adapter_collector', type: 'collect' }); + g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollector, 'item'); + g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); + + addHighResFix({ + g, + state: buildState({ hrfMethod: 'upscale_model' }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const classicHrfDenoise = nodes.find((node) => node.id.startsWith('hrf_denoise_latents')); + + if (!classicHrfDenoise) { + throw new Error('Expected classic HRF denoise node'); + } + + expect(nodes.some((node) => node.type === 'tiled_multi_diffusion_denoise_latents')).toBe(false); + expect(graph.edges).toContainEqual({ + source: { node_id: ipAdapterCollector.id, field: 'collection' }, + destination: { node_id: classicHrfDenoise.id, field: 'ip_adapter' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: classicHrfDenoise.id, field: 'latents' }, + destination: { node_id: 'l2i', field: 'latents' }, + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts 
b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts new file mode 100644 index 00000000000..f7dbb860a47 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts @@ -0,0 +1,855 @@ +import type { RootState } from 'app/store/store'; +import { roundDownToMultiple } from 'common/util/roundDownToMultiple'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; +import { isHrfSupportedBase, selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; +import type { GenerationMode, LoRA } from 'features/controlLayers/store/types'; +import type { BaseModelType } from 'features/nodes/types/common'; +import { zModelIdentifierField } from 'features/nodes/types/common'; +import { addLoRAs } from 'features/nodes/util/graph/generation/addLoRAs'; +import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs'; +import type { Graph } from 'features/nodes/util/graph/generation/Graph'; +import type { DenoiseLatentsNodes, LatentToImageNodes } from 'features/nodes/util/graph/types'; +import { getGridSize } from 'features/parameters/util/optimalDimension'; +import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import type { AnyInvocation, AnyInvocationInputField, AnyInvocationOutputField, Invocation } from 'services/api/types'; +import { assert } from 'tsafe'; + +type AddHighResFixArg = { + g: Graph; + state: RootState; + generationMode: GenerationMode; + denoise: Invocation<DenoiseLatentsNodes>; + l2i: Invocation<LatentToImageNodes>; + noise?: Invocation<'noise'>; + seed: Invocation<'integer'>; +}; + +type HrfModelLoader = Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>; +type HrfVaeSource = HrfModelLoader | Invocation<'seamless'>; +type FreshHrfModelInputs = { + modelLoader: HrfModelLoader; + vaeSource: HrfVaeSource; +}; +type ConditioningSource = { + node_id: string; + field: AnyInvocationOutputField; +}; + +const SKIPPED_DENOISE_INPUT_FIELDS = new Set(['latents', 'noise']); +const 
FRESH_HRF_MODEL_INPUT_FIELDS = new Set(['unet', 'positive_conditioning', 'negative_conditioning']); +const TILED_DENOISE_INPUT_FIELDS = new Set(['positive_conditioning', 'negative_conditioning', 'unet']); +const SKIPPED_LATENT_TO_IMAGE_INPUT_FIELDS = new Set(['latents', 'metadata']); +const LATENT_HRF_L2I_TILE_SIZE = 1024; + +const getHighResFixFinalDimensions = (state: RootState) => { + const params = selectParamsSlice(state); + const gridSize = getGridSize(params.model?.base as BaseModelType | undefined); + + return { + width: Math.max(roundDownToMultiple(params.dimensions.width * params.hrfScale, gridSize), 64), + height: Math.max(roundDownToMultiple(params.dimensions.height * params.hrfScale, gridSize), 64), + }; +}; + +const getHighResFixDenoisingStartAndEnd = (state: RootState): { denoising_start: number; denoising_end: number } => { + const params = selectParamsSlice(state); + const model = params.model; + const { hrfStrength, optimizedDenoisingEnabled } = params; + + switch (model?.base) { + case 'sd-3': + case 'flux': + case 'flux2': { + if (model.base === 'flux' && 'variant' in model && model.variant === 'dev_fill') { + return { denoising_start: 0, denoising_end: 1 }; + } + + const exponent = optimizedDenoisingEnabled ? 
0.2 : 1; + return { denoising_start: 1 - hrfStrength ** exponent, denoising_end: 1 }; + } + case 'anima': + case 'sd-1': + case 'sd-2': + case 'sdxl': + case 'cogview4': + case 'qwen-image': + case 'z-image': { + return { denoising_start: 1 - hrfStrength, denoising_end: 1 }; + } + default: { + assert(false, `Unsupported base for high resolution fix: ${model?.base}`); + } + } +}; + +const shouldApplyHighResFix = (state: RootState, generationMode: GenerationMode) => { + const params = selectParamsSlice(state); + const model = params.model; + + return ( + selectActiveTab(state) === 'generate' && + generationMode === 'txt2img' && + params.hrfEnabled && + model !== null && + isHrfSupportedBase(model.base) && + !params.refinerModel + ); +}; + +const shouldWriteDisabledMetadata = (state: RootState, generationMode: GenerationMode) => { + return selectActiveTab(state) === 'generate' && generationMode === 'txt2img'; +}; + +const cloneSDXLCompelPromptForFinalDimensions = ( + g: Graph, + node: Invocation<'sdxl_compel_prompt'>, + finalDimensions: { width: number; height: number } +) => { + const clone = g.addNode({ + ...node, + id: getPrefixedId(`hrf_${node.id.split(':')[0]}`), + original_width: finalDimensions.width, + original_height: finalDimensions.height, + target_width: finalDimensions.width, + target_height: finalDimensions.height, + }); + + for (const edge of g.getEdgesTo(node)) { + g.addEdgeFromObj({ + source: { ...edge.source }, + destination: { node_id: clone.id, field: edge.destination.field }, + }); + } + + return clone; +}; + +const cloneSDXLConditioningForFinalDimensions = ( + g: Graph, + sourceNodeId: string, + finalDimensions: { width: number; height: number }, + outputType: 'collection' | 'single' +) => { + const sourceNode = g.getNode(sourceNodeId); + + if (sourceNode.type === 'sdxl_compel_prompt') { + return { + nodeId: cloneSDXLCompelPromptForFinalDimensions(g, sourceNode, finalDimensions).id, + field: 'conditioning', + }; + } + + if (sourceNode.type !== 
'collect') { + return null; + } + + const itemEdges = g.getEdgesTo(sourceNode).filter((edge) => edge.destination.field === 'item'); + const hasSDXLConditioning = itemEdges.some((edge) => g.getNode(edge.source.node_id).type === 'sdxl_compel_prompt'); + + if (!hasSDXLConditioning) { + if (outputType === 'single' && itemEdges.length === 1) { + return { nodeId: itemEdges[0]!.source.node_id, field: itemEdges[0]!.source.field }; + } + + return null; + } + + if (outputType === 'single') { + if (itemEdges.length !== 1) { + return null; + } + + const edge = itemEdges[0]!; + const itemNode = g.getNode(edge.source.node_id); + const source = + itemNode.type === 'sdxl_compel_prompt' + ? cloneSDXLCompelPromptForFinalDimensions(g, itemNode, finalDimensions) + : itemNode; + + return { nodeId: source.id, field: edge.source.field }; + } + + const collect = g.addNode({ + type: 'collect', + id: getPrefixedId(`hrf_${sourceNode.id.split(':')[0]}_conditioning_collect`), + }); + + for (const edge of itemEdges) { + const itemNode = g.getNode(edge.source.node_id); + const source = + itemNode.type === 'sdxl_compel_prompt' + ? 
cloneSDXLCompelPromptForFinalDimensions(g, itemNode, finalDimensions) + : itemNode; + + g.addEdgeFromObj({ + source: { node_id: source.id, field: edge.source.field }, + destination: { node_id: collect.id, field: 'item' }, + }); + } + + return { nodeId: collect.id, field: 'collection' }; +}; + +const copyDenoiseInputs = ( + g: Graph, + from: AnyInvocation, + to: AnyInvocation, + finalDimensions: { width: number; height: number }, + allowedInputFields?: Set<string>, + skippedInputFields?: Set<string> +) => { + for (const edge of g.getEdgesTo(from)) { + if (SKIPPED_DENOISE_INPUT_FIELDS.has(edge.destination.field)) { + continue; + } + if (skippedInputFields?.has(edge.destination.field)) { + continue; + } + if (allowedInputFields && !allowedInputFields.has(edge.destination.field)) { + continue; + } + + const finalSizeConditioning = ['positive_conditioning', 'negative_conditioning'].includes(edge.destination.field) + ? cloneSDXLConditioningForFinalDimensions( + g, + edge.source.node_id, + finalDimensions, + to.type === 'tiled_multi_diffusion_denoise_latents' ? 'single' : 'collection' + ) + : null; + + g.addEdgeFromObj({ + source: finalSizeConditioning + ? 
{ + node_id: finalSizeConditioning.nodeId, + field: finalSizeConditioning.field as AnyInvocationOutputField, + } + : { ...edge.source }, + destination: { node_id: to.id, field: edge.destination.field as AnyInvocationInputField }, + }); + } +}; + +const copyInputEdges = (g: Graph, from: AnyInvocation, to: AnyInvocation, skippedInputFields: Set<string>) => { + for (const edge of g.getEdgesTo(from)) { + if (skippedInputFields.has(edge.destination.field)) { + continue; + } + g.addEdgeFromObj({ + source: { ...edge.source }, + destination: { node_id: to.id, field: edge.destination.field }, + }); + } +}; + +const hasUnsupportedTiledDenoiseInputs = (g: Graph, denoise: Invocation<'denoise_latents'>) => { + return g.getEdgesTo(denoise).some((edge) => { + const field = edge.destination.field; + if (SKIPPED_DENOISE_INPUT_FIELDS.has(field)) { + return false; + } + + if (!TILED_DENOISE_INPUT_FIELDS.has(field)) { + return true; + } + + if (['positive_conditioning', 'negative_conditioning'].includes(field)) { + const sourceNode = g.getNode(edge.source.node_id); + if (sourceNode.type === 'collect') { + const itemEdges = g.getEdgesTo(sourceNode).filter((itemEdge) => itemEdge.destination.field === 'item'); + return itemEdges.length !== 1; + } + } + + return false; + }); +}; + +const addTileControlNet = ( + g: Graph, + hrfDenoise: AnyInvocation, + imageSource: Invocation<'unsharp_mask'>, + tileControlNetModel: NonNullable<ReturnType<typeof selectParamsSlice>['hrfTileControlNetModel']>, + tileControlWeight: number, + tileControlEnd: number +) => { + const controlNet = g.addNode({ + id: getPrefixedId('hrf_controlnet'), + type: 'controlnet', + control_model: tileControlNetModel, + control_mode: 'balanced', + resize_mode: 'just_resize', + control_weight: tileControlWeight, + begin_step_percent: 0, + end_step_percent: tileControlEnd, + }); + + g.addEdge(imageSource, 'image', controlNet, 'image'); + g.addEdgeFromObj({ + source: { node_id: controlNet.id, field: 'control' }, + destination: { node_id: hrfDenoise.id, field: 'control' }, + 
}); +}; + +const getConditioningSources = ( + g: Graph, + denoise: Invocation<'denoise_latents'>, + field: 'positive_conditioning' | 'negative_conditioning' +): ConditioningSource[] => { + const edge = g.getEdgesTo(denoise).find((edge) => edge.destination.field === field); + assert(edge, `Missing ${field} edge for HRF second pass`); + + const sourceNode = g.getNode(edge.source.node_id); + if (sourceNode.type !== 'collect') { + return [{ node_id: edge.source.node_id, field: edge.source.field as AnyInvocationOutputField }]; + } + + const itemEdges = g.getEdgesTo(sourceNode).filter((edge) => edge.destination.field === 'item'); + assert(itemEdges.length > 0, `HRF second pass expects at least one ${field} conditioning source`); + + return itemEdges.map((edge) => ({ + node_id: edge.source.node_id, + field: edge.source.field as AnyInvocationOutputField, + })); +}; + +const connectConditioningToDenoise = ( + g: Graph, + hrfDenoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>, + field: 'positive_conditioning' | 'negative_conditioning', + conditioning: Array<Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>> +) => { + assert(conditioning.length > 0, `HRF second pass expects at least one ${field} conditioning node`); + + if (conditioning.length === 1) { + g.addEdgeFromObj({ + source: { node_id: conditioning[0]!.id, field: 'conditioning' }, + destination: { node_id: hrfDenoise.id, field }, + }); + return; + } + + assert(hrfDenoise.type === 'denoise_latents', 'Tiled HRF second pass supports only single conditioning inputs'); + + const collect = g.addNode({ + type: 'collect', + id: getPrefixedId(`hrf_${field}_collect`), + }); + + for (const cond of conditioning) { + g.addEdge(cond, 'conditioning', collect, 'item'); + } + + g.addEdgeFromObj({ + source: { node_id: collect.id, field: 'collection' }, + destination: { node_id: hrfDenoise.id, field }, + }); +}; + +const cloneSD1ConditioningSources = ( + g: Graph, + denoise: Invocation<'denoise_latents'>, + 
field: 'positive_conditioning' | 'negative_conditioning' +) => { + return getConditioningSources(g, denoise, field).map((source) => { + const sourceNode = g.getNode(source.node_id); + assert(sourceNode.type === 'compel', 'SD1 HRF refinement requires SD1 conditioning'); + + const clone = g.addNode({ + ...sourceNode, + id: getPrefixedId(`hrf_${sourceNode.id.split(':')[0]}`), + }); + copyInputEdges(g, sourceNode, clone, new Set(['clip'])); + return clone; + }); +}; + +const cloneSDXLConditioningSources = ( + g: Graph, + denoise: Invocation<'denoise_latents'>, + field: 'positive_conditioning' | 'negative_conditioning', + finalDimensions: { width: number; height: number } +) => { + return getConditioningSources(g, denoise, field).map((source) => { + const sourceNode = g.getNode(source.node_id); + assert(sourceNode.type === 'sdxl_compel_prompt', 'SDXL HRF refinement requires SDXL conditioning'); + + const clone = g.addNode({ + ...sourceNode, + id: getPrefixedId(`hrf_${sourceNode.id.split(':')[0]}`), + original_width: finalDimensions.width, + original_height: finalDimensions.height, + target_width: finalDimensions.width, + target_height: finalDimensions.height, + }); + copyInputEdges(g, sourceNode, clone, new Set(['clip', 'clip2'])); + return clone; + }); +}; + +const findOriginalSeamlessNode = (g: Graph, denoise: Invocation<'denoise_latents'>) => { + let current: AnyInvocation = denoise; + const visited = new Set(); + + while (!visited.has(current.id)) { + visited.add(current.id); + const unetEdge = g.getEdgesTo(current).find((edge) => edge.destination.field === 'unet'); + if (!unetEdge) { + return null; + } + + const sourceNode = g.getNode(unetEdge.source.node_id); + if (sourceNode.type === 'seamless') { + return sourceNode; + } + + current = sourceNode; + } + + return null; +}; + +const addFreshHrfSeamless = ( + g: Graph, + denoise: Invocation<'denoise_latents'>, + hrfDenoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>, + 
modelLoader: HrfModelLoader, + useHrfModelVae: boolean +) => { + const originalSeamless = findOriginalSeamlessNode(g, denoise); + if (!originalSeamless) { + return null; + } + + const seamless = g.addNode({ + ...originalSeamless, + id: getPrefixedId('hrf_seamless'), + }); + + g.addEdge(modelLoader, 'unet', seamless, 'unet'); + + if (useHrfModelVae) { + g.addEdge(modelLoader, 'vae', seamless, 'vae'); + } else { + const originalVaeEdge = g.getEdgesTo(originalSeamless).find((edge) => edge.destination.field === 'vae'); + if (originalVaeEdge) { + g.addEdgeFromObj({ + source: { ...originalVaeEdge.source }, + destination: { node_id: seamless.id, field: 'vae' }, + }); + } else { + g.addEdge(modelLoader, 'vae', seamless, 'vae'); + } + } + + g.addEdge(seamless, 'unet', hrfDenoise, 'unet'); + return seamless; +}; + +const getEnabledHrfLoRAs = (state: RootState): LoRA[] | null => { + const params = selectParamsSlice(state); + + if (params.hrfLoraMode === 'none') { + return null; + } + + if (params.hrfLoraMode === 'dedicated') { + return getEnabledDedicatedHrfLoRAs(state); + } + + return state.loras.loras.filter((lora) => lora.isEnabled); +}; + +const getEnabledDedicatedHrfLoRAs = (state: RootState): LoRA[] => { + const params = selectParamsSlice(state); + const effectiveHrfBase = params.hrfModel?.base ?? 
params.model?.base; + if (!effectiveHrfBase) { + return []; + } + return params.hrfLoras.filter((lora) => lora.isEnabled && lora.model.base === effectiveHrfBase); +}; + +const getHrfLoRAMetadata = (state: RootState) => { + const params = selectParamsSlice(state); + if (params.hrfLoraMode !== 'dedicated') { + return undefined; + } + + return getEnabledDedicatedHrfLoRAs(state).map((lora) => ({ + model: zModelIdentifierField.parse(lora.model), + weight: lora.weight, + })); +}; + +const addFreshSD1HrfModelInputs = ( + state: RootState, + g: Graph, + denoise: Invocation<'denoise_latents'>, + hrfDenoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'> +): FreshHrfModelInputs => { + const params = selectParamsSlice(state); + const hrfModel = params.hrfModel ?? params.model; + assert(hrfModel?.base === 'sd-1', 'SD1 HRF refinement model must be an SD1 model'); + + const modelLoader = g.addNode({ + type: 'main_model_loader', + id: getPrefixedId('hrf_sd1_model_loader'), + model: hrfModel, + }); + const clipSkip = g.addNode({ + type: 'clip_skip', + id: getPrefixedId('hrf_clip_skip'), + skipped_layers: params.clipSkip, + }); + + const positiveConditioning = cloneSD1ConditioningSources(g, denoise, 'positive_conditioning'); + const negativeConditioning = cloneSD1ConditioningSources(g, denoise, 'negative_conditioning'); + const hrfSeamless = addFreshHrfSeamless(g, denoise, hrfDenoise, modelLoader, params.hrfModel !== null); + + g.addEdge(modelLoader, 'clip', clipSkip, 'clip'); + for (const cond of positiveConditioning) { + g.addEdge(clipSkip, 'clip', cond, 'clip'); + } + for (const cond of negativeConditioning) { + g.addEdge(clipSkip, 'clip', cond, 'clip'); + } + if (!hrfSeamless) { + g.addEdgeFromObj({ + source: { node_id: modelLoader.id, field: 'unet' }, + destination: { node_id: hrfDenoise.id, field: 'unet' }, + }); + } + connectConditioningToDenoise(g, hrfDenoise, 'positive_conditioning', positiveConditioning); + 
connectConditioningToDenoise(g, hrfDenoise, 'negative_conditioning', negativeConditioning); + + const enabledLoRAs = getEnabledHrfLoRAs(state); + if (enabledLoRAs?.length) { + addLoRAs( + state, + g, + hrfDenoise, + modelLoader, + hrfSeamless, + clipSkip, + positiveConditioning[0]!, + negativeConditioning[0]!, + { + loras: enabledLoRAs, + idPrefix: 'hrf', + metadataKey: params.hrfLoraMode === 'dedicated' ? 'hrf_loras' : 'loras', + extraPositiveConditioning: positiveConditioning.slice(1), + extraNegativeConditioning: negativeConditioning.slice(1), + } + ); + } + + return { modelLoader, vaeSource: hrfSeamless ?? modelLoader }; +}; + +const addFreshSDXLHrfModelInputs = ( + state: RootState, + g: Graph, + denoise: Invocation<'denoise_latents'>, + hrfDenoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>, + finalDimensions: { width: number; height: number } +): FreshHrfModelInputs => { + const params = selectParamsSlice(state); + const hrfModel = params.hrfModel ?? 
params.model; + assert(hrfModel?.base === 'sdxl', 'SDXL HRF refinement model must be an SDXL model'); + + const modelLoader = g.addNode({ + type: 'sdxl_model_loader', + id: getPrefixedId('hrf_sdxl_model_loader'), + model: hrfModel, + }); + + const positiveConditioning = cloneSDXLConditioningSources(g, denoise, 'positive_conditioning', finalDimensions); + const negativeConditioning = cloneSDXLConditioningSources(g, denoise, 'negative_conditioning', finalDimensions); + const hrfSeamless = addFreshHrfSeamless(g, denoise, hrfDenoise, modelLoader, params.hrfModel !== null); + + for (const cond of positiveConditioning) { + g.addEdge(modelLoader, 'clip', cond, 'clip'); + g.addEdge(modelLoader, 'clip2', cond, 'clip2'); + } + for (const cond of negativeConditioning) { + g.addEdge(modelLoader, 'clip', cond, 'clip'); + g.addEdge(modelLoader, 'clip2', cond, 'clip2'); + } + if (!hrfSeamless) { + g.addEdgeFromObj({ + source: { node_id: modelLoader.id, field: 'unet' }, + destination: { node_id: hrfDenoise.id, field: 'unet' }, + }); + } + connectConditioningToDenoise(g, hrfDenoise, 'positive_conditioning', positiveConditioning); + connectConditioningToDenoise(g, hrfDenoise, 'negative_conditioning', negativeConditioning); + + const enabledLoRAs = getEnabledHrfLoRAs(state); + if (enabledLoRAs?.length) { + addSDXLLoRAs(state, g, hrfDenoise, modelLoader, hrfSeamless, positiveConditioning[0]!, negativeConditioning[0]!, { + loras: enabledLoRAs, + idPrefix: 'hrf', + metadataKey: params.hrfLoraMode === 'dedicated' ? 'hrf_loras' : 'loras', + extraPositiveConditioning: positiveConditioning.slice(1), + extraNegativeConditioning: negativeConditioning.slice(1), + }); + } + + return { modelLoader, vaeSource: hrfSeamless ?? 
modelLoader }; +}; + +const addFreshHrfModelInputs = ( + state: RootState, + g: Graph, + denoise: Invocation<'denoise_latents'>, + hrfDenoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>, + finalDimensions: { width: number; height: number } +): FreshHrfModelInputs => { + const params = selectParamsSlice(state); + + if (params.model?.base === 'sdxl') { + return addFreshSDXLHrfModelInputs(state, g, denoise, hrfDenoise, finalDimensions); + } + + return addFreshSD1HrfModelInputs(state, g, denoise, hrfDenoise); +}; + +const addLatentHighResFix = ({ + g, + state, + denoise, + l2i, + noise, + seed, +}: AddHighResFixArg): Invocation => { + const params = selectParamsSlice(state); + const finalDimensions = getHighResFixFinalDimensions(state); + const { denoising_start, denoising_end } = getHighResFixDenoisingStartAndEnd(state); + + const resizeLatents = g.addNode({ + id: getPrefixedId('hrf_resize_latents'), + type: 'lresize', + ...finalDimensions, + mode: params.hrfLatentInterpolationMode, + antialias: false, + }); + + const hrfDenoise = g.addNode({ + ...denoise, + id: getPrefixedId(`hrf_${denoise.type}`), + denoising_start, + denoising_end, + ...(denoise.type === 'denoise_latents' ? 
{} : finalDimensions), + } as Invocation); + + copyDenoiseInputs(g, denoise, hrfDenoise, finalDimensions); + + if (denoise.type === 'denoise_latents') { + assert(noise, 'SD1.5/SD2/SDXL high resolution fix graphs require a noise node'); + const classicHrfDenoise = hrfDenoise as Invocation<'denoise_latents'>; + + const hrfNoise = g.addNode({ + type: 'noise', + id: getPrefixedId('hrf_noise'), + use_cpu: noise.use_cpu, + ...finalDimensions, + }); + g.addEdge(seed, 'value', hrfNoise, 'seed'); + g.addEdge(hrfNoise, 'noise', classicHrfDenoise, 'noise'); + } + + g.deleteEdgesTo(l2i, ['latents']); + g.addEdge(denoise, 'latents', resizeLatents, 'latents'); + g.addEdge(resizeLatents, 'latents', hrfDenoise, 'latents'); + g.addEdge(hrfDenoise, 'latents', l2i, 'latents'); + if (l2i.type === 'l2i') { + g.updateNode(l2i, { tile_size: LATENT_HRF_L2I_TILE_SIZE, tiled: true }); + } + + g.upsertMetadata({ + width: params.dimensions.width, + height: params.dimensions.height, + hrf_enabled: true, + hrf_method: 'latent', + hrf_strength: params.hrfStrength, + hrf_scale: params.hrfScale, + hrf_latent_interpolation_mode: params.hrfLatentInterpolationMode, + }); + + return l2i; +}; + +const addUpscaleModelHighResFix = ({ g, state, denoise, l2i, noise, seed }: AddHighResFixArg): Invocation<'l2i'> => { + const params = selectParamsSlice(state); + const finalDimensions = getHighResFixFinalDimensions(state); + const { denoising_start, denoising_end } = getHighResFixDenoisingStartAndEnd(state); + + assert(params.model?.base === 'sd-1' || params.model?.base === 'sdxl', 'Upscale model HRF supports SD1.5 and SDXL'); + assert(params.hrfUpscaleModel, 'Upscale model HRF requires a Spandrel upscale model'); + assert(params.hrfTileControlNetModel, 'Upscale model HRF requires a tile ControlNet model'); + assert(denoise.type === 'denoise_latents', 'Upscale model HRF requires classic SD denoise latents'); + assert(l2i.type === 'l2i', 'Upscale model HRF requires classic SD latents-to-image'); + assert(noise, 
'Upscale model HRF requires a noise node'); + + const intermediateL2i = g.addNode({ + ...l2i, + id: getPrefixedId('hrf_intermediate_l2i'), + is_intermediate: true, + board: undefined, + metadata: undefined, + }); + copyInputEdges(g, l2i, intermediateL2i, SKIPPED_LATENT_TO_IMAGE_INPUT_FIELDS); + + const spandrelAutoscale = g.addNode({ + type: 'spandrel_image_to_image_autoscale', + id: getPrefixedId('hrf_spandrel_autoscale'), + image_to_image_model: params.hrfUpscaleModel, + fit_to_multiple_of_8: true, + scale: params.hrfScale, + tile_size: params.hrfTileSize, + }); + + const unsharpMask = g.addNode({ + type: 'unsharp_mask', + id: getPrefixedId('hrf_unsharp_2'), + radius: 2, + strength: 60, + }); + + const i2l = g.addNode({ + type: 'i2l', + id: getPrefixedId('hrf_i2l'), + fp32: l2i.fp32, + tile_size: params.hrfTileSize, + tiled: true, + }); + copyInputEdges(g, l2i, i2l, SKIPPED_LATENT_TO_IMAGE_INPUT_FIELDS); + g.updateNode(l2i, { tile_size: params.hrfTileSize, tiled: true }); + + const hrfNoise = g.addNode({ + type: 'noise', + id: getPrefixedId('hrf_noise'), + use_cpu: noise.use_cpu, + }); + + const useClassicDenoise = hasUnsupportedTiledDenoiseInputs(g, denoise); + const hrfSteps = params.hrfSteps ?? denoise.steps; + const needsFreshHrfModelInputs = params.hrfModel !== null || params.hrfLoraMode !== 'reuse_generate'; + const hrfDenoise = useClassicDenoise + ? 
g.addNode({ + ...denoise, + id: getPrefixedId('hrf_denoise_latents'), + steps: hrfSteps, + denoising_start, + denoising_end, + }) + : g.addNode({ + type: 'tiled_multi_diffusion_denoise_latents', + id: getPrefixedId('hrf_tiled_multidiffusion_denoise_latents'), + tile_height: params.hrfTileSize, + tile_width: params.hrfTileSize, + tile_overlap: params.hrfTileOverlap, + steps: hrfSteps, + cfg_scale: denoise.cfg_scale, + cfg_rescale_multiplier: denoise.cfg_rescale_multiplier, + scheduler: denoise.scheduler, + denoising_start, + denoising_end, + }); + + copyDenoiseInputs( + g, + denoise, + hrfDenoise, + finalDimensions, + useClassicDenoise ? undefined : TILED_DENOISE_INPUT_FIELDS, + needsFreshHrfModelInputs ? FRESH_HRF_MODEL_INPUT_FIELDS : undefined + ); + + if (needsFreshHrfModelInputs) { + if (params.hrfModel) { + g.deleteEdgesTo(i2l, ['vae']); + g.deleteEdgesTo(l2i, ['vae']); + } + + const { vaeSource: hrfVaeSource } = addFreshHrfModelInputs(state, g, denoise, hrfDenoise, finalDimensions); + if (params.hrfModel) { + g.addEdge(hrfVaeSource, 'vae', i2l, 'vae'); + g.addEdge(hrfVaeSource, 'vae', l2i, 'vae'); + } + } + addTileControlNet( + g, + hrfDenoise, + unsharpMask, + params.hrfTileControlNetModel, + params.hrfTileControlWeight, + params.hrfTileControlEnd + ); + + g.deleteEdgesTo(l2i, ['latents']); + g.addEdge(denoise, 'latents', intermediateL2i, 'latents'); + g.addEdge(intermediateL2i, 'image', spandrelAutoscale, 'image'); + g.addEdge(spandrelAutoscale, 'image', unsharpMask, 'image'); + g.addEdge(unsharpMask, 'image', i2l, 'image'); + g.addEdge(seed, 'value', hrfNoise, 'seed'); + g.addEdgeFromObj({ + source: { node_id: unsharpMask.id, field: 'width' }, + destination: { node_id: hrfNoise.id, field: 'width' }, + }); + g.addEdgeFromObj({ + source: { node_id: unsharpMask.id, field: 'height' }, + destination: { node_id: hrfNoise.id, field: 'height' }, + }); + g.addEdgeFromObj({ + source: { node_id: hrfNoise.id, field: 'noise' }, + destination: { node_id: hrfDenoise.id, 
field: 'noise' }, + }); + g.addEdgeFromObj({ + source: { node_id: i2l.id, field: 'latents' }, + destination: { node_id: hrfDenoise.id, field: 'latents' }, + }); + g.addEdgeFromObj({ + source: { node_id: hrfDenoise.id, field: 'latents' }, + destination: { node_id: l2i.id, field: 'latents' }, + }); + + g.upsertMetadata({ + width: params.dimensions.width, + height: params.dimensions.height, + hrf_enabled: true, + hrf_method: 'upscale_model', + hrf_strength: params.hrfStrength, + hrf_scale: params.hrfScale, + hrf_upscale_model: params.hrfUpscaleModel, + hrf_tile_controlnet_model: params.hrfTileControlNetModel, + hrf_tile_control_weight: params.hrfTileControlWeight, + hrf_tile_control_end: params.hrfTileControlEnd, + hrf_tile_size: params.hrfTileSize, + hrf_tile_overlap: params.hrfTileOverlap, + hrf_steps: params.hrfSteps ?? undefined, + hrf_model: params.hrfModel ?? undefined, + hrf_lora_mode: params.hrfLoraMode, + hrf_loras: getHrfLoRAMetadata(state), + }); + + return l2i; +}; + +export const addHighResFix = (arg: AddHighResFixArg): Invocation => { + const { g, state, generationMode, l2i } = arg; + const params = selectParamsSlice(state); + + if (!shouldApplyHighResFix(state, generationMode)) { + if (shouldWriteDisabledMetadata(state, generationMode)) { + g.upsertMetadata({ hrf_enabled: false }); + } + return l2i; + } + + if (params.hrfMethod === 'upscale_model') { + return addUpscaleModelHighResFix(arg); + } + + return addLatentHighResFix(arg); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts index 79a8521efba..15b29e120f4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts @@ -1,9 +1,18 @@ import type { RootState } from 'app/store/store'; import { getPrefixedId } from 'features/controlLayers/konva/util'; +import type { LoRA } from 
'features/controlLayers/store/types'; import { zModelIdentifierField } from 'features/nodes/types/common'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Invocation, S } from 'services/api/types'; +type AddLoRAsOptions = { + loras?: LoRA[]; + metadataKey?: 'loras' | 'hrf_loras'; + idPrefix?: string; + extraPositiveConditioning?: Invocation<'compel'>[]; + extraNegativeConditioning?: Invocation<'compel'>[]; +}; + export const addLoRAs = ( state: RootState, g: Graph, @@ -12,12 +21,15 @@ export const addLoRAs = ( seamless: Invocation<'seamless'> | null, clipSkip: Invocation<'clip_skip'>, posCond: Invocation<'compel'>, - negCond: Invocation<'compel'> + negCond: Invocation<'compel'>, + options?: AddLoRAsOptions ): void => { - const enabledLoRAs = state.loras.loras.filter( + const enabledLoRAs = (options?.loras ?? state.loras.loras).filter( (l) => l.isEnabled && (l.model.base === 'sd-1' || l.model.base === 'sd-2') ); const loraCount = enabledLoRAs.length; + const positiveConditioning = [posCond, ...(options?.extraPositiveConditioning ?? [])]; + const negativeConditioning = [negCond, ...(options?.extraNegativeConditioning ?? [])]; if (loraCount === 0) { return; @@ -29,11 +41,11 @@ export const addLoRAs = ( // each LoRA to the UNet and CLIP. const loraCollector = g.addNode({ type: 'collect', - id: getPrefixedId('lora_collector'), + id: getPrefixedId(options?.idPrefix ? `${options.idPrefix}_lora_collector` : 'lora_collector'), }); const loraCollectionLoader = g.addNode({ type: 'lora_collection_loader', - id: getPrefixedId('lora_collection_loader'), + id: getPrefixedId(options?.idPrefix ? 
`${options.idPrefix}_lora_collection_loader` : 'lora_collection_loader'), }); g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras'); @@ -42,11 +54,17 @@ export const addLoRAs = ( g.addEdge(clipSkip, 'clip', loraCollectionLoader, 'clip'); // Reroute UNet & CLIP connections through the LoRA collection loader g.deleteEdgesTo(denoise, ['unet']); - g.deleteEdgesTo(posCond, ['clip']); - g.deleteEdgesTo(negCond, ['clip']); g.addEdge(loraCollectionLoader, 'unet', denoise, 'unet'); - g.addEdge(loraCollectionLoader, 'clip', posCond, 'clip'); - g.addEdge(loraCollectionLoader, 'clip', negCond, 'clip'); + + for (const cond of positiveConditioning) { + g.deleteEdgesTo(cond, ['clip']); + g.addEdge(loraCollectionLoader, 'clip', cond, 'clip'); + } + + for (const cond of negativeConditioning) { + g.deleteEdgesTo(cond, ['clip']); + g.addEdge(loraCollectionLoader, 'clip', cond, 'clip'); + } for (const lora of enabledLoRAs) { const { weight } = lora; @@ -54,7 +72,7 @@ export const addLoRAs = ( const loraSelector = g.addNode({ type: 'lora_selector', - id: getPrefixedId('lora_selector'), + id: getPrefixedId(options?.idPrefix ? 
`${options.idPrefix}_lora_selector` : 'lora_selector'), lora: parsedModel, weight, }); @@ -67,5 +85,9 @@ export const addLoRAs = ( g.addEdge(loraSelector, 'lora', loraCollector, 'item'); } - g.upsertMetadata({ loras: loraMetadata }); + if (options?.metadataKey === 'hrf_loras') { + g.upsertMetadata({ hrf_loras: loraMetadata }); + } else { + g.upsertMetadata({ loras: loraMetadata }); + } }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts index a38c9757cea..1fd7bc448c0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts @@ -1,9 +1,18 @@ import type { RootState } from 'app/store/store'; import { getPrefixedId } from 'features/controlLayers/konva/util'; +import type { LoRA } from 'features/controlLayers/store/types'; import { zModelIdentifierField } from 'features/nodes/types/common'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Invocation, S } from 'services/api/types'; +type AddSDXLLoRAsOptions = { + loras?: LoRA[]; + metadataKey?: 'loras' | 'hrf_loras'; + idPrefix?: string; + extraPositiveConditioning?: Invocation<'sdxl_compel_prompt'>[]; + extraNegativeConditioning?: Invocation<'sdxl_compel_prompt'>[]; +}; + export const addSDXLLoRAs = ( state: RootState, g: Graph, @@ -11,10 +20,13 @@ export const addSDXLLoRAs = ( modelLoader: Invocation<'sdxl_model_loader'>, seamless: Invocation<'seamless'> | null, posCond: Invocation<'sdxl_compel_prompt'>, - negCond: Invocation<'sdxl_compel_prompt'> + negCond: Invocation<'sdxl_compel_prompt'>, + options?: AddSDXLLoRAsOptions ): void => { - const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'sdxl'); + const enabledLoRAs = (options?.loras ?? 
state.loras.loras).filter((l) => l.isEnabled && l.model.base === 'sdxl'); const loraCount = enabledLoRAs.length; + const positiveConditioning = [posCond, ...(options?.extraPositiveConditioning ?? [])]; + const negativeConditioning = [negCond, ...(options?.extraNegativeConditioning ?? [])]; if (loraCount === 0) { return; @@ -25,12 +37,14 @@ export const addSDXLLoRAs = ( // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies // each LoRA to the UNet and CLIP. const loraCollector = g.addNode({ - id: getPrefixedId('lora_collector'), + id: getPrefixedId(options?.idPrefix ? `${options.idPrefix}_lora_collector` : 'lora_collector'), type: 'collect', }); const loraCollectionLoader = g.addNode({ type: 'sdxl_lora_collection_loader', - id: getPrefixedId('sdxl_lora_collection_loader'), + id: getPrefixedId( + options?.idPrefix ? `${options.idPrefix}_sdxl_lora_collection_loader` : 'sdxl_lora_collection_loader' + ), }); g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras'); @@ -40,13 +54,19 @@ export const addSDXLLoRAs = ( g.addEdge(modelLoader, 'clip2', loraCollectionLoader, 'clip2'); // Reroute UNet & CLIP connections through the LoRA collection loader g.deleteEdgesTo(denoise, ['unet']); - g.deleteEdgesTo(posCond, ['clip', 'clip2']); - g.deleteEdgesTo(negCond, ['clip', 'clip2']); g.addEdge(loraCollectionLoader, 'unet', denoise, 'unet'); - g.addEdge(loraCollectionLoader, 'clip', posCond, 'clip'); - g.addEdge(loraCollectionLoader, 'clip', negCond, 'clip'); - g.addEdge(loraCollectionLoader, 'clip2', posCond, 'clip2'); - g.addEdge(loraCollectionLoader, 'clip2', negCond, 'clip2'); + + for (const cond of positiveConditioning) { + g.deleteEdgesTo(cond, ['clip', 'clip2']); + g.addEdge(loraCollectionLoader, 'clip', cond, 'clip'); + g.addEdge(loraCollectionLoader, 'clip2', cond, 'clip2'); + } + + for (const cond of negativeConditioning) { + g.deleteEdgesTo(cond, ['clip', 'clip2']); + 
g.addEdge(loraCollectionLoader, 'clip', cond, 'clip'); + g.addEdge(loraCollectionLoader, 'clip2', cond, 'clip2'); + } for (const lora of enabledLoRAs) { const { weight } = lora; @@ -54,7 +74,7 @@ export const addSDXLLoRAs = ( const loraSelector = g.addNode({ type: 'lora_selector', - id: getPrefixedId('lora_selector'), + id: getPrefixedId(options?.idPrefix ? `${options.idPrefix}_lora_selector` : 'lora_selector'), lora: parsedModel, weight, }); @@ -67,5 +87,9 @@ export const addSDXLLoRAs = ( g.addEdge(loraSelector, 'lora', loraCollector, 'item'); } - g.upsertMetadata({ loras: loraMetadata }); + if (options?.metadataKey === 'hrf_loras') { + g.upsertMetadata({ hrf_loras: loraMetadata }); + } else { + g.upsertMetadata({ loras: loraMetadata }); + } }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.test.ts index 0afe04ed13b..bc05a75802f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.test.ts @@ -34,6 +34,7 @@ const defaultParams: { let params = { ...defaultParams }; vi.mock('features/controlLayers/store/paramsSlice', () => ({ + isHrfSupportedBase: vi.fn((base) => base === 'sd-1' || base === 'sdxl'), selectMainModelConfig: vi.fn(() => model), selectParamsSlice: vi.fn(() => params), selectAnimaVaeModel: vi.fn(() => animaVaeModel), diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.ts index 0ab76ccefb6..b018f9871a1 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.ts @@ -11,6 +11,7 @@ import { import { selectCanvasMetadata, selectCanvasSlice } from 
'features/controlLayers/store/selectors'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { addAnimaLoRAs } from 'features/nodes/util/graph/generation/addAnimaLoRAs'; +import { addHighResFix } from 'features/nodes/util/graph/generation/addHighResFix'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; @@ -251,6 +252,17 @@ export const buildAnimaGraph = async (arg: GraphBuilderArg): Promise>(false); } + if (generationMode === 'txt2img' && selectActiveTab(state) === 'generate') { + canvasOutput = addHighResFix({ + g, + state, + generationMode, + denoise, + l2i, + seed, + }); + } + if (state.system.shouldUseNSFWChecker) { canvasOutput = addNSFWChecker(g, canvasOutput); } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildCogView4Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildCogView4Graph.ts index 6adee057545..45c34e605b5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildCogView4Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildCogView4Graph.ts @@ -3,6 +3,7 @@ import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { selectCanvasMetadata } from 'features/controlLayers/store/selectors'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { addHighResFix } from 'features/nodes/util/graph/generation/addHighResFix'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addNSFWChecker } from 
'features/nodes/util/graph/generation/addNSFWChecker'; @@ -167,6 +168,17 @@ export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise>(false); } + if (generationMode === 'txt2img' && selectActiveTab(state) === 'generate') { + canvasOutput = addHighResFix({ + g, + state, + generationMode, + denoise, + l2i, + seed, + }); + } + if (state.system.shouldUseNSFWChecker) { canvasOutput = addNSFWChecker(g, canvasOutput); } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts index 5b9f3d0a468..f2f1c65d082 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts @@ -112,6 +112,7 @@ const mockParams = { }; vi.mock('features/controlLayers/store/paramsSlice', () => ({ + isHrfSupportedBase: vi.fn((base) => base === 'sd-1' || base === 'sdxl'), selectMainModelConfig: vi.fn(() => currentModel), selectParamsSlice: vi.fn(() => mockParams), selectKleinVaeModel: vi.fn(() => currentKleinVae), diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index dafcd9310ec..e310942502c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -16,6 +16,7 @@ import { addFlux2KleinLoRAs } from 'features/nodes/util/graph/generation/addFlux import { addFLUXFill } from 'features/nodes/util/graph/generation/addFLUXFill'; import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs'; import { addFLUXReduxes } from 'features/nodes/util/graph/generation/addFLUXRedux'; +import { addHighResFix } from 'features/nodes/util/graph/generation/addHighResFix'; import { 
addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; @@ -582,6 +583,17 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise ({ + isHrfSupportedBase: vi.fn((base) => base === 'sd-1' || base === 'sdxl'), selectMainModelConfig: vi.fn(() => model), selectParamsSlice: vi.fn(() => params), })); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildQwenImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildQwenImageGraph.ts index 0d92d325afd..06fe43c55e9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildQwenImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildQwenImageGraph.ts @@ -7,6 +7,7 @@ import { isQwenImageReferenceImageConfig } from 'features/controlLayers/store/ty import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { zImageField } from 'features/nodes/types/common'; +import { addHighResFix } from 'features/nodes/util/graph/generation/addHighResFix'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; @@ -283,6 +284,17 @@ export const buildQwenImageGraph = async (arg: GraphBuilderArg): Promise>(false); } + if (generationMode === 'txt2img' && selectActiveTab(state) === 'generate') { + canvasOutput = addHighResFix({ + g, + state, + generationMode, + denoise, + l2i, + seed, + }); + } + if (state.system.shouldUseNSFWChecker) { canvasOutput = addNSFWChecker(g, canvasOutput); } diff --git 
a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index eae07532011..c8d28fd9c9a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -4,9 +4,9 @@ import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { addControlNets, addT2IAdapters } from 'features/nodes/util/graph/generation/addControlAdapters'; +import { addHighResFix } from 'features/nodes/util/graph/generation/addHighResFix'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; -// import { addHRF } from 'features/nodes/util/graph/generation/addHRF'; import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters'; import { addLoRAs } from 'features/nodes/util/graph/generation/addLoRAs'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; @@ -305,6 +305,18 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise>(false); } + if (generationMode === 'txt2img' && selectActiveTab(state) === 'generate') { + canvasOutput = addHighResFix({ + g, + state, + generationMode, + denoise, + l2i, + seed, + }); + } + if (state.system.shouldUseNSFWChecker) { canvasOutput = addNSFWChecker(g, canvasOutput); } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index 9d65076a70d..93c432add10 100644 --- 
a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -4,6 +4,7 @@ import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { addControlNets, addT2IAdapters } from 'features/nodes/util/graph/generation/addControlAdapters'; +import { addHighResFix } from 'features/nodes/util/graph/generation/addHighResFix'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters'; @@ -312,6 +313,18 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise>(false); } + if (generationMode === 'txt2img' && selectActiveTab(state) === 'generate') { + canvasOutput = addHighResFix({ + g, + state, + generationMode, + denoise, + l2i, + seed, + }); + } + if (state.system.shouldUseNSFWChecker) { canvasOutput = addNSFWChecker(g, canvasOutput); } diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts index 632006050e6..41ba098ac17 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts @@ -12,7 +12,7 @@ vi.mock('i18next', () => ({ import type { ParamsState, RefImagesState } from 'features/controlLayers/store/types'; import type { DynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; -import type { MainModelConfig } from 'services/api/types'; +import type { MainModelConfig, MainOrExternalModelConfig } from 'services/api/types'; import { 
getReasonsWhyCannotEnqueueCanvasTab, getReasonsWhyCannotEnqueueGenerateTab } from './readiness'; @@ -50,6 +50,39 @@ const flux2GGUF9BModel = { const kleinVaeModel = { key: 'vae', name: 'VAE', base: 'flux2', type: 'vae' }; const kleinQwen3Model = { key: 'qwen3', name: 'Qwen3', base: 'flux2', type: 'qwen3_encoder' }; +const externalModel = { + key: 'external', + hash: 'h', + name: 'External', + base: 'external', + type: 'external_image_generator', + format: 'external_api', +} as unknown as MainOrExternalModelConfig; + +const sdxlModel = { + key: 'sdxl', + hash: 'h', + name: 'SDXL', + base: 'sdxl', + type: 'main', + format: 'checkpoint', +} as unknown as MainModelConfig; + +const upscaleModel = { + key: 'upscale', + hash: 'h', + name: 'Upscale', + base: 'any', + type: 'spandrel_image_to_image', +}; + +const tileControlNetModel = { + key: 'tile', + hash: 'h', + name: 'Tile ControlNet', + base: 'sdxl', + type: 'controlnet', +}; const baseDynamicPrompts: DynamicPromptsState = { _version: 1, @@ -71,14 +104,28 @@ const baseParams = { positivePrompt: 'test', kleinVaeModel: null, kleinQwen3EncoderModel: null, + hrfEnabled: false, + hrfMethod: 'latent', + hrfUpscaleModel: null, + hrfTileControlNetModel: null, + hrfModel: null, + hrfLoraMode: 'reuse_generate', + hrfLoras: [], + refinerModel: null, } as unknown as ParamsState; // --- Helpers --- const buildGenerateTabArg = (overrides: { - model?: MainModelConfig | null; + model?: MainOrExternalModelConfig | null; kleinVaeModel?: unknown; kleinQwen3EncoderModel?: unknown; + hrfEnabled?: boolean; + hrfMethod?: ParamsState['hrfMethod']; + hrfUpscaleModel?: unknown; + hrfTileControlNetModel?: unknown; + hrfModel?: unknown; + refinerModel?: unknown; hasFlux2DiffusersVaeSource?: boolean; hasFlux2DiffusersQwen3Source?: boolean; }) => ({ @@ -88,6 +135,12 @@ const buildGenerateTabArg = (overrides: { ...baseParams, kleinVaeModel: overrides.kleinVaeModel ?? null, kleinQwen3EncoderModel: overrides.kleinQwen3EncoderModel ?? 
null, + hrfEnabled: overrides.hrfEnabled ?? false, + hrfMethod: overrides.hrfMethod ?? 'latent', + hrfUpscaleModel: overrides.hrfUpscaleModel ?? null, + hrfTileControlNetModel: overrides.hrfTileControlNetModel ?? null, + hrfModel: overrides.hrfModel ?? null, + refinerModel: overrides.refinerModel ?? null, } as unknown as ParamsState, refImages: baseRefImages, loras: [], @@ -139,6 +192,24 @@ const hasFlux2VaeReason = (reasons: { content: string }[]) => const hasFlux2Qwen3Reason = (reasons: { content: string }[]) => reasons.some((r) => r.content.includes('noFlux2KleinQwen3EncoderModelSelected')); +const hasHrfExternalReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('hrfExternalModelUnsupported')); + +const hasHrfRefinerReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('hrfRefinerUnsupported')); + +const hasHrfUpscaleModelBaseReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('hrfUpscaleModelBaseUnsupported')); + +const hasHrfUpscaleModelMissingReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('hrfUpscaleModelMissing')); + +const hasHrfTileControlNetMissingReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('hrfTileControlNetModelMissing')); + +const hasHrfModelOverrideBaseMismatchReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('hrfModelOverrideBaseMismatch')); + // --- Tests --- describe('FLUX.2 Klein readiness checks – generate tab', () => { @@ -221,6 +292,107 @@ describe('FLUX.2 Klein readiness checks – generate tab', () => { }); }); +describe('High Resolution Fix readiness checks - generate tab', () => { + it('ignores stale HRF state for external models', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ model: externalModel, hrfEnabled: true }) + ); + expect(hasHrfExternalReason(reasons)).toBe(false); + }); + + 
it('errors when HRF is enabled with SDXL Refiner', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ model: sdxlModel, hrfEnabled: true, refinerModel: { key: 'refiner' } }) + ); + expect(hasHrfRefinerReason(reasons)).toBe(true); + }); + + it('ignores stale HRF state for unsupported model bases', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2DiffusersModel, + hrfEnabled: true, + hrfMethod: 'upscale_model', + hrfUpscaleModel: null, + hrfTileControlNetModel: null, + hrfModel: { key: 'anima', hash: 'h', name: 'Anima', base: 'anima', type: 'main' }, + }) + ); + expect(hasHrfUpscaleModelBaseReason(reasons)).toBe(false); + expect(hasHrfUpscaleModelMissingReason(reasons)).toBe(false); + expect(hasHrfTileControlNetMissingReason(reasons)).toBe(false); + expect(hasHrfModelOverrideBaseMismatchReason(reasons)).toBe(false); + }); + + it('errors when upscale-model HRF is missing required models', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ model: sdxlModel, hrfEnabled: true, hrfMethod: 'upscale_model' }) + ); + expect(hasHrfUpscaleModelMissingReason(reasons)).toBe(true); + expect(hasHrfTileControlNetMissingReason(reasons)).toBe(true); + }); + + it('does not error when upscale-model HRF has required SDXL models', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: sdxlModel, + hrfEnabled: true, + hrfMethod: 'upscale_model', + hrfUpscaleModel: upscaleModel, + hrfTileControlNetModel: tileControlNetModel, + }) + ); + expect(hasHrfUpscaleModelBaseReason(reasons)).toBe(false); + expect(hasHrfUpscaleModelMissingReason(reasons)).toBe(false); + expect(hasHrfTileControlNetMissingReason(reasons)).toBe(false); + }); + + it('does not apply stale upscale-model-only readiness checks to latent HRF', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: sdxlModel, + 
hrfEnabled: true, + hrfMethod: 'latent', + hrfModel: { key: 'sd1', hash: 'h', name: 'SD1', base: 'sd-1', type: 'main' }, + hrfUpscaleModel: null, + hrfTileControlNetModel: null, + }) + ); + + expect(hasHrfUpscaleModelMissingReason(reasons)).toBe(false); + expect(hasHrfTileControlNetMissingReason(reasons)).toBe(false); + expect(hasHrfModelOverrideBaseMismatchReason(reasons)).toBe(false); + }); + + it('errors when dedicated HRF model base differs from the Generate model base', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: sdxlModel, + hrfEnabled: true, + hrfMethod: 'upscale_model', + hrfUpscaleModel: upscaleModel, + hrfTileControlNetModel: tileControlNetModel, + hrfModel: { key: 'sd1', hash: 'h', name: 'SD1', base: 'sd-1', type: 'main' }, + }) + ); + expect(hasHrfModelOverrideBaseMismatchReason(reasons)).toBe(true); + }); + + it('validates Tile ControlNet against the dedicated HRF model base', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: sdxlModel, + hrfEnabled: true, + hrfMethod: 'upscale_model', + hrfUpscaleModel: upscaleModel, + hrfTileControlNetModel: { ...tileControlNetModel, base: 'sd-1' }, + hrfModel: { key: 'hrf-sdxl', hash: 'h', name: 'HRF SDXL', base: 'sdxl', type: 'main' }, + }) + ); + expect(hasHrfTileControlNetMissingReason(reasons)).toBe(true); + }); +}); + describe('FLUX.2 Klein readiness checks – canvas tab', () => { it('no errors when main model is diffusers', () => { const reasons = getReasonsWhyCannotEnqueueCanvasTab(buildCanvasTabArg({ model: flux2DiffusersModel }) as never); diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index 230fa3348d6..2a4116d5f9d 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -8,7 +8,7 @@ import { useAssertSingleton } from 
'common/hooks/useAssertSingleton'; import { debounce, groupBy, upperFirst } from 'es-toolkit/compat'; import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; import { selectAddedLoRAs } from 'features/controlLayers/store/lorasSlice'; -import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; +import { isHrfSupportedBase, selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; import type { CanvasState, LoRA, ParamsState, RefImagesState } from 'features/controlLayers/store/types'; @@ -362,6 +362,33 @@ export const getReasonsWhyCannotEnqueueGenerateTab = (arg: { }); } + if (params.hrfEnabled && model && !isExternalApiModelConfig(model) && isHrfSupportedBase(model.base)) { + if (params.refinerModel) { + reasons.push({ content: i18n.t('parameters.invoke.hrfRefinerUnsupported') }); + } + if (params.hrfMethod === 'upscale_model') { + const hrfBase = + params.hrfModel?.base ?? (model && !isExternalApiModelConfig(model) ? 
model.base : params.model?.base); + if (params.hrfModel) { + if (params.hrfModel.base === 'external') { + reasons.push({ content: i18n.t('parameters.invoke.hrfModelOverrideExternalUnsupported') }); + } else if (!isHrfSupportedBase(params.hrfModel.base)) { + reasons.push({ content: i18n.t('parameters.invoke.hrfModelOverrideBaseUnsupported') }); + } else if (model && !isExternalApiModelConfig(model) && params.hrfModel.base !== model.base) { + reasons.push({ content: i18n.t('parameters.invoke.hrfModelOverrideBaseMismatch') }); + } + } + if (!params.hrfUpscaleModel) { + reasons.push({ content: i18n.t('parameters.invoke.hrfUpscaleModelMissing') }); + } + if (!params.hrfTileControlNetModel) { + reasons.push({ content: i18n.t('parameters.invoke.hrfTileControlNetModelMissing') }); + } else if (hrfBase && params.hrfTileControlNetModel.base !== hrfBase) { + reasons.push({ content: i18n.t('parameters.invoke.hrfTileControlNetModelMissing') }); + } + } + } + return reasons; }; const getReasonsWhyCannotEnqueueWorkflowsTab = async (arg: { diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx new file mode 100644 index 00000000000..1a2a9fb3973 --- /dev/null +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx @@ -0,0 +1,1116 @@ +import type { ComboboxOnChange, FormLabelProps } from '@invoke-ai/ui-library'; +import { + Box, + Button, + ButtonGroup, + Card, + CardBody, + CardHeader, + Combobox, + CompositeNumberInput, + CompositeSlider, + Expander, + Flex, + FormControl, + FormControlGroup, + FormLabel, + IconButton, + StandaloneAccordion, + Switch, + Text, + Tooltip, +} from '@invoke-ai/ui-library'; +import { createSelector } from '@reduxjs/toolkit'; +import { EMPTY_ARRAY } from 
'app/store/constants'; +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; +import { useModelCombobox } from 'common/hooks/useModelCombobox'; +import { DEFAULT_LORA_WEIGHT_CONFIG } from 'features/controlLayers/store/lorasSlice'; +import { + buildSelectHrfLoRA, + hrfLoraAdded, + hrfLoraDeleted, + hrfLoraIsEnabledChanged, + hrfLoraWeightChanged, + selectBase, + selectHrfEnabled, + selectHrfFinalDimensions, + selectHrfLatentInterpolationMode, + selectHrfLoraMode, + selectHrfLoras, + selectHrfMethod, + selectHrfModel, + selectHrfScale, + selectHrfSteps, + selectHrfStrength, + selectHrfTileControlEnd, + selectHrfTileControlNetModel, + selectHrfTileControlWeight, + selectHrfTileOverlap, + selectHrfTileSize, + selectHrfUpscaleModel, + selectIsRefinerModelSelected, + selectModelSupportsHrf, + selectSteps, + setHrfEnabled, + setHrfLatentInterpolationMode, + setHrfLoraMode, + setHrfMethod, + setHrfModel, + setHrfScale, + setHrfSteps, + setHrfStrength, + setHrfTileControlEnd, + setHrfTileControlNetModel, + setHrfTileControlWeight, + setHrfTileOverlap, + setHrfTileSize, + setHrfUpscaleModel, +} from 'features/controlLayers/store/paramsSlice'; +import type { LoRA } from 'features/controlLayers/store/types'; +import { zHrfLatentInterpolationMode, zHrfMethod } from 'features/controlLayers/store/types'; +import { CONSTRAINTS as STEPS_CONSTRAINTS } from 'features/parameters/components/Core/ParamSteps'; +import { ModelPicker } from 'features/parameters/components/ModelPicker'; +import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle'; +import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; +import type { ChangeEvent } from 'react'; +import { memo, useCallback, useMemo } from 'react'; +import { 
useTranslation } from 'react-i18next'; +import { PiTrashSimpleBold } from 'react-icons/pi'; +import { + modelConfigsAdapterSelectors, + selectModelConfigsQuery, + useGetModelConfigQuery, +} from 'services/api/endpoints/models'; +import { + useControlNetModels, + useLoRAModels, + useMainModels, + useSpandrelImageToImageModels, +} from 'services/api/hooks/modelsByType'; +import { + type ControlNetModelConfig, + isControlNetModelConfig, + isExternalApiModelConfig, + isMainOrExternalModelConfig, + type LoRAModelConfig, + type MainOrExternalModelConfig, + type SpandrelImageToImageModelConfig, +} from 'services/api/types'; + +const SCALE_CONSTRAINTS = { + initial: 2, + sliderMin: 1, + sliderMax: 4, + numberInputMin: 1, + numberInputMax: 8, + coarseStep: 0.05, + fineStep: 0.01, +}; + +const STRENGTH_CONSTRAINTS = { + initial: 0.45, + sliderMin: 0, + sliderMax: 1, + numberInputMin: 0, + numberInputMax: 1, + coarseStep: 0.01, + fineStep: 0.01, +}; + +const TILE_CONTROL_WEIGHT_CONSTRAINTS = { + initial: 0.625, + sliderMin: 0, + sliderMax: 1.5, + numberInputMin: 0, + numberInputMax: 2, + coarseStep: 0.025, + fineStep: 0.005, +}; + +const TILE_CONTROL_END_CONSTRAINTS = { + initial: 0.2, + sliderMin: 0, + sliderMax: 1, + numberInputMin: 0, + numberInputMax: 1, + coarseStep: 0.01, + fineStep: 0.01, +}; + +const TILE_SIZE_CONSTRAINTS = { + initial: 1024, + sliderMin: 512, + sliderMax: 1536, + numberInputMin: 512, + numberInputMax: 1536, + coarseStep: 64, + fineStep: 64, +}; + +const TILE_OVERLAP_CONSTRAINTS = { + initial: 128, + sliderMin: 32, + sliderMax: 256, + numberInputMin: 16, + numberInputMax: 512, + coarseStep: 16, + fineStep: 8, +}; + +const formLabelProps: FormLabelProps = { + m: 0, + w: '10.5rem', + minW: '10.5rem', + maxW: '10.5rem', + flexShrink: 0, + whiteSpace: 'normal', + lineHeight: 1.2, + overflowWrap: 'break-word', +}; + +const formControlProps = { + alignItems: 'center', + gap: 3, + minW: 0, + w: 'full', +}; + +type DisabledProps = { + isDisabled?: boolean; 
+}; + +const selectHrfTileControlNetModelConfig = createSelector( + selectModelConfigsQuery, + selectHrfTileControlNetModel, + (modelConfigs, modelIdentifierField) => { + if (!modelConfigs.data || !modelIdentifierField) { + return null; + } + const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, modelIdentifierField.key); + if (!modelConfig || !isControlNetModelConfig(modelConfig)) { + return null; + } + return modelConfig; + } +); + +const selectHrfModelConfig = createSelector(selectModelConfigsQuery, selectHrfModel, (modelConfigs, model) => { + if (!modelConfigs.data || !model) { + return null; + } + const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, model.key); + if (!modelConfig || !isMainOrExternalModelConfig(modelConfig) || isExternalApiModelConfig(modelConfig)) { + return null; + } + return modelConfig; +}); + +const selectHrfLoRAIds = createMemoizedSelector(selectHrfLoras, (loras) => loras.map(({ id }) => id)); +const selectHrfLoRAModelKeys = createMemoizedSelector(selectHrfLoras, (loras) => loras.map(({ model }) => model.key)); + +const selectBadges = createMemoizedSelector( + [selectHrfEnabled, selectHrfMethod, selectHrfScale, selectHrfStrength, selectHrfFinalDimensions], + (enabled, method, scale, strength, finalDimensions) => { + if (!enabled) { + return EMPTY_ARRAY; + } + + const methodBadge = method === 'upscale_model' ? 
'Model' : 'Latent'; + return [ + methodBadge, + `${scale}x`, + `${Math.round(strength * 100)}%`, + `${finalDimensions.width}x${finalDimensions.height}`, + ]; + } +); + +const ParamHrfEnabled = memo(() => { + const dispatch = useAppDispatch(); + const enabled = useAppSelector(selectHrfEnabled); + const { t } = useTranslation(); + + const onChange = useCallback( + (event: ChangeEvent) => { + dispatch(setHrfEnabled(event.target.checked)); + }, + [dispatch] + ); + + return ( + + + {t('hrf.enableHrf')} + + + + ); +}); + +ParamHrfEnabled.displayName = 'ParamHrfEnabled'; + +const ParamHrfMethod = memo(({ isDisabled = false }: DisabledProps) => { + const dispatch = useAppDispatch(); + const method = useAppSelector(selectHrfMethod); + const { t } = useTranslation(); + + const onClickLatent = useCallback(() => { + dispatch(setHrfMethod('latent')); + }, [dispatch]); + + const onClickUpscaleModel = useCallback(() => { + dispatch(setHrfMethod('upscale_model')); + }, [dispatch]); + + return ( + + + {t('hrf.upscaleMethod')} + + + + + + + ); +}); + +ParamHrfMethod.displayName = 'ParamHrfMethod'; + +const ParamHrfScale = memo(({ isDisabled = false }: DisabledProps) => { + const dispatch = useAppDispatch(); + const scale = useAppSelector(selectHrfScale); + const { t } = useTranslation(); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfScale(v)); + }, + [dispatch] + ); + + return ( + + + {t('hrf.scale')} + + + + + ); +}); + +ParamHrfScale.displayName = 'ParamHrfScale'; + +const ParamHrfStrength = memo(({ isDisabled = false }: DisabledProps) => { + const dispatch = useAppDispatch(); + const strength = useAppSelector(selectHrfStrength); + const { t } = useTranslation(); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfStrength(v)); + }, + [dispatch] + ); + + return ( + + + {t('hrf.strength')} + + + + + ); +}); + +ParamHrfStrength.displayName = 'ParamHrfStrength'; + +const ParamHrfLatentInterpolationMode = memo(({ isDisabled = false }: 
DisabledProps) => { + const dispatch = useAppDispatch(); + const mode = useAppSelector(selectHrfLatentInterpolationMode); + const { t } = useTranslation(); + + const options = useMemo( + () => [ + { label: t('hrf.bilinear'), value: 'bilinear' }, + { label: t('hrf.bicubic'), value: 'bicubic' }, + { label: t('hrf.nearest'), value: 'nearest' }, + { label: t('hrf.nearestExact'), value: 'nearest-exact' }, + { label: t('hrf.area'), value: 'area' }, + ], + [t] + ); + + const value = useMemo(() => options.find((o) => o.value === mode), [mode, options]); + + const onChange = useCallback( + (v) => { + const result = zHrfLatentInterpolationMode.safeParse(v?.value); + if (!result.success) { + return; + } + dispatch(setHrfLatentInterpolationMode(result.data)); + }, + [dispatch] + ); + + return ( + + + {t('hrf.latentInterpolationMode')} + + + + ); +}); + +ParamHrfLatentInterpolationMode.displayName = 'ParamHrfLatentInterpolationMode'; + +const ParamHrfUpscaleModel = memo(({ isDisabled = false }: DisabledProps) => { + const { t } = useTranslation(); + const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels(); + const model = useAppSelector(selectHrfUpscaleModel); + const dispatch = useAppDispatch(); + + const tooltipLabel = useMemo(() => { + if (!modelConfigs.length || !model) { + return; + } + return modelConfigs.find((m) => m.key === model.key)?.description; + }, [modelConfigs, model]); + + const _onChange = useCallback( + (v: SpandrelImageToImageModelConfig | null) => { + dispatch(setHrfUpscaleModel(v)); + }, + [dispatch] + ); + + const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({ + modelConfigs, + onChange: _onChange, + selectedModel: model, + isLoading, + }); + + return ( + + + {t('upscaling.upscaleModel')} + + + + + + + + + + ); +}); + +ParamHrfUpscaleModel.displayName = 'ParamHrfUpscaleModel'; + +const ParamHrfTileControlNetModel = memo(({ isDisabled = false }: DisabledProps) => { + const dispatch = useAppDispatch(); + const 
{ t } = useTranslation(); + const tileControlNetModel = useAppSelector(selectHrfTileControlNetModelConfig); + const generateBaseModel = useAppSelector(selectBase); + const hrfModel = useAppSelector(selectHrfModel); + const currentBaseModel = hrfModel?.base ?? generateBaseModel; + const [modelConfigs, { isLoading }] = useControlNetModels(); + + const onChange = useCallback( + (controlNetModel: ControlNetModelConfig) => { + dispatch(setHrfTileControlNetModel(controlNetModel)); + }, + [dispatch] + ); + + const filteredModelConfigs = useMemo(() => { + if (!currentBaseModel || !['sd-1', 'sdxl'].includes(currentBaseModel)) { + return []; + } + return modelConfigs.filter((model) => { + const isCompatible = model.base === currentBaseModel; + const isTileOrMultiModel = + model.name.toLowerCase().includes('tile') || model.name.toLowerCase().includes('union'); + return isCompatible && isTileOrMultiModel; + }); + }, [modelConfigs, currentBaseModel]); + + const getIsOptionDisabled = useCallback( + (model: ControlNetModelConfig): boolean => { + return currentBaseModel !== model.base; + }, + [currentBaseModel] + ); + const isMissingModel = !filteredModelConfigs.length; + const isInvalid = !isDisabled && isMissingModel; + + return ( + + + {t('upscaling.tileControl')} + + + + ); +}); + +ParamHrfTileControlNetModel.displayName = 'ParamHrfTileControlNetModel'; + +const ParamHrfTileControlWeight = memo(({ isDisabled = false }: DisabledProps) => { + const dispatch = useAppDispatch(); + const tileControlWeight = useAppSelector(selectHrfTileControlWeight); + const { t } = useTranslation(); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfTileControlWeight(v)); + }, + [dispatch] + ); + + return ( + + + {t('hrf.tileControlWeight')} + + + + + ); +}); + +ParamHrfTileControlWeight.displayName = 'ParamHrfTileControlWeight'; + +const ParamHrfTileControlEnd = memo(({ isDisabled = false }: DisabledProps) => { + const dispatch = useAppDispatch(); + const tileControlEnd = 
useAppSelector(selectHrfTileControlEnd); + const { t } = useTranslation(); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfTileControlEnd(v)); + }, + [dispatch] + ); + + return ( + + + {t('hrf.tileControlEnd')} + + + + + ); +}); + +ParamHrfTileControlEnd.displayName = 'ParamHrfTileControlEnd'; + +const ParamHrfTileSize = memo(({ isDisabled = false }: DisabledProps) => { + const dispatch = useAppDispatch(); + const tileSize = useAppSelector(selectHrfTileSize); + const { t } = useTranslation(); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfTileSize(v)); + }, + [dispatch] + ); + + return ( + + + {t('upscaling.tileSize')} + + + + + ); +}); + +ParamHrfTileSize.displayName = 'ParamHrfTileSize'; + +const ParamHrfTileOverlap = memo(({ isDisabled = false }: DisabledProps) => { + const dispatch = useAppDispatch(); + const tileOverlap = useAppSelector(selectHrfTileOverlap); + const { t } = useTranslation(); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfTileOverlap(v)); + }, + [dispatch] + ); + + return ( + + + {t('upscaling.tileOverlap')} + + + + + ); +}); + +ParamHrfTileOverlap.displayName = 'ParamHrfTileOverlap'; + +const ParamHrfSteps = memo(({ isDisabled = false }: DisabledProps) => { + const dispatch = useAppDispatch(); + const hrfSteps = useAppSelector(selectHrfSteps); + const generateSteps = useAppSelector(selectSteps); + const { t } = useTranslation(); + + const onToggle = useCallback( + (event: ChangeEvent) => { + dispatch(setHrfSteps(event.target.checked ? 
generateSteps : null)); + }, + [dispatch, generateSteps] + ); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfSteps(v)); + }, + [dispatch] + ); + + const isCustom = hrfSteps !== null; + + return ( + + + {t('hrf.steps')} + + + + + + + ); +}); + +ParamHrfSteps.displayName = 'ParamHrfSteps'; + +const ParamHrfModel = memo(({ isDisabled = false }: DisabledProps) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const selectedModelConfig = useAppSelector(selectHrfModelConfig); + const currentBaseModel = useAppSelector(selectBase); + + const filter = useCallback( + (model: MainOrExternalModelConfig) => { + return ( + !isExternalApiModelConfig(model) && model.base === currentBaseModel && ['sd-1', 'sdxl'].includes(model.base) + ); + }, + [currentBaseModel] + ); + const [modelConfigs, { isLoading }] = useMainModels(filter); + + const onChange = useCallback( + (model: MainOrExternalModelConfig | null) => { + dispatch(setHrfModel(model && !isExternalApiModelConfig(model) ? 
model : null)); + }, + [dispatch] + ); + + return ( + + + {t('hrf.model')} + + + + ); +}); + +ParamHrfModel.displayName = 'ParamHrfModel'; + +const ParamHrfLoraMode = memo(({ isDisabled = false }: DisabledProps) => { + const dispatch = useAppDispatch(); + const mode = useAppSelector(selectHrfLoraMode); + const { t } = useTranslation(); + + const onClickReuseGenerate = useCallback(() => { + dispatch(setHrfLoraMode('reuse_generate')); + }, [dispatch]); + + const onClickNone = useCallback(() => { + dispatch(setHrfLoraMode('none')); + }, [dispatch]); + + const onClickDedicated = useCallback(() => { + dispatch(setHrfLoraMode('dedicated')); + }, [dispatch]); + + return ( + + + {t('hrf.loraMode')} + + + + + + + + ); +}); + +ParamHrfLoraMode.displayName = 'ParamHrfLoraMode'; + +const ParamHrfLoraSelect = memo(({ isDisabled = false }: DisabledProps) => { + const dispatch = useAppDispatch(); + const [modelConfigs, { isLoading }] = useLoRAModels(); + const { t } = useTranslation(); + const addedLoRAModelKeys = useAppSelector(selectHrfLoRAModelKeys); + const currentBaseModel = useAppSelector(selectBase); + const hrfModel = useAppSelector(selectHrfModel); + const hrfBase = hrfModel?.base ?? currentBaseModel; + + const compatibleLoRAs = useMemo(() => { + if (!hrfBase) { + return EMPTY_ARRAY; + } + return modelConfigs.filter((model) => model.base === hrfBase); + }, [hrfBase, modelConfigs]); + + const getIsDisabled = useCallback( + (model: LoRAModelConfig): boolean => { + return addedLoRAModelKeys.includes(model.key); + }, + [addedLoRAModelKeys] + ); + + const onChange = useCallback( + (model: LoRAModelConfig | null) => { + if (!model) { + return; + } + dispatch(hrfLoraAdded({ model })); + }, + [dispatch] + ); + + const placeholder = useMemo(() => { + if (isLoading) { + return t('common.loading'); + } + if (compatibleLoRAs.length === 0) { + return hrfBase ? 
t('models.noCompatibleLoRAs') : t('models.selectModel'); + } + return t('hrf.addDedicatedLora'); + }, [compatibleLoRAs.length, hrfBase, isLoading, t]); + + return ( + + + {t('hrf.dedicatedLoras')} + + + + ); +}); + +ParamHrfLoraSelect.displayName = 'ParamHrfLoraSelect'; + +const HrfLoRAList = memo(({ isDisabled = false }: DisabledProps) => { + const ids = useAppSelector(selectHrfLoRAIds); + + if (!ids.length) { + return null; + } + + return ( + + {ids.map((id) => ( + + ))} + + ); +}); + +HrfLoRAList.displayName = 'HrfLoRAList'; + +const HrfLoRACard = memo((props: { id: string } & DisabledProps) => { + const selectLoRA = useMemo(() => buildSelectHrfLoRA(props.id), [props.id]); + const lora = useAppSelector(selectLoRA); + + if (!lora) { + return null; + } + return ; +}); + +HrfLoRACard.displayName = 'HrfLoRACard'; + +const HrfLoRAContent = memo(({ lora, isDisabled = false }: { lora: LoRA } & DisabledProps) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const { data: loraConfig } = useGetModelConfigQuery(lora.model.key); + const isWeightDisabled = isDisabled || !lora.isEnabled; + + const onChange = useCallback( + (v: number) => { + dispatch(hrfLoraWeightChanged({ id: lora.id, weight: v })); + }, + [dispatch, lora.id] + ); + + const onToggle = useCallback(() => { + dispatch(hrfLoraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled })); + }, [dispatch, lora.id, lora.isEnabled]); + + const onRemove = useCallback(() => { + dispatch(hrfLoraDeleted({ id: lora.id })); + }, [dispatch, lora.id]); + + return ( + + + + + {loraConfig?.name ?? 
lora.model.key.substring(0, 8)} + + + + } + isDisabled={isDisabled} + /> + + + + + + + + + + + ); +}); + +HrfLoRAContent.displayName = 'HrfLoRAContent'; + +export const HighResFixSettingsAccordion = memo(() => { + const { t } = useTranslation(); + const badges = useAppSelector(selectBadges); + const hrfEnabled = useAppSelector(selectHrfEnabled); + const method = useAppSelector(selectHrfMethod); + const modelSupportsHrf = useAppSelector(selectModelSupportsHrf); + const isRefinerModelSelected = useAppSelector(selectIsRefinerModelSelected); + const hrfLoraMode = useAppSelector(selectHrfLoraMode); + const { isOpen, onToggle } = useStandaloneAccordionToggle({ + id: 'high-res-fix-settings-generate-tab', + defaultIsOpen: false, + }); + const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({ + id: 'high-res-fix-settings-generate-tab-advanced', + defaultIsOpen: false, + }); + + const parsedMethod = zHrfMethod.parse(method); + const isDisabled = !hrfEnabled; + + if (!modelSupportsHrf || isRefinerModelSelected) { + return null; + } + + return ( + + + + + + + + + + {parsedMethod === 'upscale_model' && ( + <> + + + + )} + + + + + + {parsedMethod === 'latent' && } + {parsedMethod === 'upscale_model' && ( + <> + + + + + + + + + )} + + {parsedMethod === 'upscale_model' && hrfLoraMode === 'dedicated' && ( + <> + + + + )} + + + + ); +}); + +HighResFixSettingsAccordion.displayName = 'HighResFixSettingsAccordion'; diff --git a/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelGenerate.tsx b/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelGenerate.tsx index 06a122cc4ef..8a14b55413d 100644 --- a/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelGenerate.tsx +++ b/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelGenerate.tsx @@ -7,6 +7,7 @@ import { Prompts } from 'features/parameters/components/Prompts/Prompts'; import { 
AdvancedSettingsAccordion } from 'features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion'; import { ExternalSettingsAccordion } from 'features/settingsAccordions/components/ExternalSettingsAccordion/ExternalSettingsAccordion'; import { GenerationSettingsAccordion } from 'features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion'; +import { HighResFixSettingsAccordion } from 'features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion'; import { GenerateTabImageSettingsAccordion } from 'features/settingsAccordions/components/ImageSettingsAccordion/GenerateTabImageSettingsAccordion'; import { RefinerSettingsAccordion } from 'features/settingsAccordions/components/RefinerSettingsAccordion/RefinerSettingsAccordion'; import { StylePresetMenu } from 'features/stylePresets/components/StylePresetMenu'; @@ -44,6 +45,7 @@ export const ParametersPanelGenerate = memo(() => { + {isSDXL && } {!isCogview4 && !isExternal && } {isExternal && } diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 62070fcbbbe..e9f4ea682de 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -7691,6 +7691,84 @@ export type components = { * @default null */ hrf_strength?: number | null; + /** + * Hrf Scale + * @description The high resolution fix upscale factor. + * @default null + */ + hrf_scale?: number | null; + /** + * Hrf Latent Interpolation Mode + * @description The latent interpolation mode used in the high resolution fix upscale pass. + * @default null + */ + hrf_latent_interpolation_mode?: string | null; + /** + * Hrf Upscale Model + * @description The Spandrel upscale model used in the high resolution fix upscale pass. 
+ * @default null + */ + hrf_upscale_model?: components["schemas"]["ModelIdentifierField"] | null; + /** + * Hrf Tile Controlnet Model + * @description The tile ControlNet model used in the high resolution fix upscale pass. + * @default null + */ + hrf_tile_controlnet_model?: components["schemas"]["ModelIdentifierField"] | null; + /** + * Hrf Structure + * @description Legacy high resolution fix tile ControlNet structure value. + * @default null + */ + hrf_structure?: number | null; + /** + * Hrf Tile Control Weight + * @description The high resolution fix tile ControlNet control weight. + * @default null + */ + hrf_tile_control_weight?: number | null; + /** + * Hrf Tile Control End + * @description The high resolution fix tile ControlNet end step percentage. + * @default null + */ + hrf_tile_control_end?: number | null; + /** + * Hrf Tile Size + * @description The high resolution fix tiled processing tile size. + * @default null + */ + hrf_tile_size?: number | null; + /** + * Hrf Tile Overlap + * @description The high resolution fix tiled processing tile overlap. + * @default null + */ + hrf_tile_overlap?: number | null; + /** + * Hrf Steps + * @description The number of steps used for the high resolution fix refinement pass. + * @default null + */ + hrf_steps?: number | null; + /** + * Hrf Model + * @description The optional model override used for the high resolution fix refinement pass. + * @default null + */ + hrf_model?: components["schemas"]["ModelIdentifierField"] | null; + /** + * Hrf Lora Mode + * @description The LoRA mode used for the high resolution fix refinement pass. + * @default null + */ + hrf_lora_mode?: string | null; + /** + * Hrf Loras + * @description The dedicated LoRAs used for the high resolution fix refinement pass. + * @default null + */ + hrf_loras?: components["schemas"]["LoRAMetadataField"][] | null; /** * Positive Style Prompt * @description The positive style prompt parameter