From ad925c8e5a60ba7ea0279b641fa740820b5353c3 Mon Sep 17 00:00:00 2001 From: Astra orion <13394741+AsuraAce@users.noreply.github.com> Date: Mon, 4 May 2026 14:40:22 +0200 Subject: [PATCH 1/9] feat: add generate tab highres fix --- invokeai/app/invocations/metadata.py | 8 + invokeai/frontend/web/public/locales/en.json | 15 +- .../controlLayers/store/paramsSlice.test.ts | 18 ++ .../controlLayers/store/paramsSlice.ts | 66 ++++- .../src/features/controlLayers/store/types.ts | 15 +- .../ImageMetadataActions.tsx | 5 + .../web/src/features/metadata/parsing.tsx | 104 +++++++- .../graph/generation/addHighResFix.test.ts | 197 ++++++++++++++ .../util/graph/generation/addHighResFix.ts | 251 ++++++++++++++++++ .../util/graph/generation/buildAnimaGraph.ts | 12 + .../graph/generation/buildCogView4Graph.ts | 12 + .../util/graph/generation/buildFLUXGraph.ts | 12 + .../graph/generation/buildQwenImageGraph.ts | 12 + .../util/graph/generation/buildSD1Graph.ts | 14 +- .../util/graph/generation/buildSD3Graph.ts | 12 + .../util/graph/generation/buildSDXLGraph.ts | 13 + .../util/graph/generation/buildZImageGraph.ts | 12 + .../features/queue/store/readiness.test.ts | 40 ++- .../web/src/features/queue/store/readiness.ts | 9 + .../HighResFixSettingsAccordion.tsx | 247 +++++++++++++++++ .../ParametersPanelGenerate.tsx | 2 + .../frontend/web/src/services/api/schema.ts | 12 + 22 files changed, 1079 insertions(+), 9 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts create mode 100644 invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index da24d8802bb..545d277e823 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -252,6 +252,14 @@ class CoreMetadataInvocation(BaseInvocation): default=None, description="The high resolution fix img2img strength used in the upscale pass.", ) + hrf_scale: Optional[float] = InputField( + default=None, + description="The high resolution fix latent upscale factor.", + ) + hrf_latent_interpolation_mode: Optional[str] = InputField( + default=None, + description="The latent interpolation mode used in the high resolution fix upscale pass.", + ) # SDXL positive_style_prompt: Optional[str] = InputField( diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index c62122db222..7d952d9e11d 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -352,11 +352,22 @@ "hrf": { "hrf": "High Resolution Fix", "enableHrf": "Enable High Resolution Fix", + "scale": "Scale", + "strength": "Denoise Strength", "upscaleMethod": "Upscale Method", + "latentInterpolationMode": "Latent Interpolation", + "latent": "Latent", + "nearest": "Nearest", + "bilinear": "Bilinear", + "bicubic": "Bicubic", + "area": "Area", + "nearestExact": "Nearest Exact", "metadata": { "enabled": "High Resolution Fix Enabled", "strength": "High Resolution Fix Strength", - "method": "High Resolution Fix Method" + "method": "High Resolution Fix Method", + "scale": "High Resolution Fix Scale", + "latentInterpolationMode": "High Resolution Fix Latent Interpolation" } }, "prompt": { @@ -1676,6 +1687,8 @@ "modelIncompatibleScaledBboxHeight": "Scaled bbox height is {{height}} but {{model}} requires multiple of {{multiple}}", "fluxModelMultipleControlLoRAs": "Can only use 1 Control LoRA at a time", "incompatibleLoRAs": "Incompatible LoRA(s) added", + "hrfExternalModelUnsupported": "High Resolution Fix is not supported for external models", + "hrfRefinerUnsupported": "High Resolution Fix is not supported when SDXL Refiner is enabled", "canvasIsFiltering": "Canvas is busy (filtering)", "canvasIsTransforming": "Canvas is busy (transforming)", "canvasIsRasterizing": "Canvas is busy (rasterizing)", diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts index 7d665a38185..c8719482a21 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts @@ -8,8 +8,10 @@ import type { import { describe, expect, it } from 'vitest'; import { + selectHrfFinalDimensions, selectModelSupportsDimensions, selectModelSupportsGuidance, + selectModelSupportsHrf, selectModelSupportsNegativePrompt, selectModelSupportsRefImages, selectModelSupportsSeed, @@ -130,4 +132,20 @@ describe('paramsSlice selectors for external models', () => { expect(selectModelSupportsSteps.resultFunc(model)).toBe(false); expect(selectModelSupportsDimensions.resultFunc(model, config)).toBe(true); }); + + it('returns false for HRF support on external models', () => { + const config = createExternalConfig({ + modes: ['txt2img'], + supports_reference_images: false, + }); + const model = buildExternalModelIdentifier(config); + + expect(selectModelSupportsHrf.resultFunc(model)).toBe(false); + }); +}); + +describe('paramsSlice HRF selectors', () => { + it('rounds final dimensions down to the model grid', () => { + expect(selectHrfFinalDimensions.resultFunc(513, 512, 1.5, 'flux')).toEqual({ width: 768, height: 768 }); + }); }); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index e7466f61e12..867295569e6 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -7,7 +7,13 @@ import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMul import { isPlainObject } from 'es-toolkit'; import { clamp } from 'es-toolkit/compat'; import { logout } from 'features/auth/store/authSlice'; -import type { AspectRatioID, InfillMethod, ParamsState, RgbaColor } from 'features/controlLayers/store/types'; +import type { + AspectRatioID, + HrfLatentInterpolationMode, + InfillMethod, + ParamsState, + RgbaColor, +} from 'features/controlLayers/store/types'; import { ASPECT_RATIO_MAP, DEFAULT_ASPECT_RATIO_CONFIG, @@ -115,6 +121,18 @@ const slice = createSlice({ setOptimizedDenoisingEnabled: (state, action: PayloadAction) => { state.optimizedDenoisingEnabled = action.payload; }, + setHrfEnabled: (state, action: PayloadAction) => { + state.hrfEnabled = action.payload && !state.refinerModel; + }, + setHrfScale: (state, action: PayloadAction) => { + state.hrfScale = action.payload; + }, + setHrfStrength: (state, action: PayloadAction) => { + state.hrfStrength = action.payload; + }, + setHrfLatentInterpolationMode: (state, action: PayloadAction) => { + state.hrfLatentInterpolationMode = action.payload; + }, setSeamlessXAxis: (state, action: PayloadAction) => { state.seamlessXAxis = action.payload; }, @@ -136,6 +154,10 @@ const slice = createSlice({ const model = result.data; state.model = model; + if (model?.base === 'external') { + state.hrfEnabled = false; + } + // If the model base changes (e.g. SD1.5 -> SDXL), we need to change a few things if (model === null || previousModel?.base === model.base) { return; @@ -313,6 +335,9 @@ const slice = createSlice({ return; } state.refinerModel = result.data; + if (state.refinerModel) { + state.hrfEnabled = false; + } }, setRefinerSteps: (state, action: PayloadAction) => { state.refinerSteps = action.payload; @@ -614,6 +639,10 @@ export const { setSeed, setImg2imgStrength, setOptimizedDenoisingEnabled, + setHrfEnabled, + setHrfScale, + setHrfStrength, + setHrfLatentInterpolationMode, setSeamlessXAxis, setSeamlessYAxis, setShouldRandomizeSeed, @@ -695,6 +724,15 @@ export const paramsSliceConfig: SliceConfig = { state.positivePromptHistory = []; } + if (state._version === 2) { + // v2 -> v3, add Generate tab high resolution fix settings + state._version = 3; + state.hrfEnabled = false; + state.hrfScale = 2; + state.hrfStrength = 0.45; + state.hrfLatentInterpolationMode = 'bicubic'; + } + return zParamsState.parse(state); }, }, @@ -761,6 +799,10 @@ export const selectInfillPatchmatchDownscaleSize = createParamsSelector( export const selectInfillColorValue = createParamsSelector((params) => params.infillColorValue); export const selectImg2imgStrength = createParamsSelector((params) => params.img2imgStrength); export const selectOptimizedDenoisingEnabled = createParamsSelector((params) => params.optimizedDenoisingEnabled); +export const selectHrfEnabled = createParamsSelector((params) => params.hrfEnabled); +export const selectHrfScale = createParamsSelector((params) => params.hrfScale); +export const selectHrfStrength = createParamsSelector((params) => params.hrfStrength); +export const selectHrfLatentInterpolationMode = createParamsSelector((params) => params.hrfLatentInterpolationMode); export const selectPositivePrompt = createParamsSelector((params) => params.positivePrompt); export const selectNegativePrompt = createParamsSelector((params) => params.negativePrompt); export const selectNegativePromptWithFallback = createParamsSelector((params) => params.negativePrompt ?? ''); @@ -843,6 +885,15 @@ export const selectModelSupportsDimensions = createSelector(selectModel, selectM } return true; }); +export const selectModelSupportsHrf = createSelector(selectModel, (model) => { + if (!model) { + return false; + } + if (model.base === 'external') { + return false; + } + return true; +}); export const selectSeedControl = createSelector(selectModelConfig, (modelConfig) => { if (modelConfig && isExternalApiModelConfig(modelConfig)) { return getExternalPanelControl(modelConfig, 'image', 'seed'); @@ -889,6 +940,19 @@ export const selectRefinerSteps = createParamsSelector((params) => params.refine export const selectWidth = createParamsSelector((params) => params.dimensions.width); export const selectHeight = createParamsSelector((params) => params.dimensions.height); +export const selectHrfFinalDimensions = createSelector( + selectWidth, + selectHeight, + selectHrfScale, + selectBase, + (width, height, hrfScale, base) => { + const gridSize = getGridSize(base as BaseModelType | undefined); + return { + width: Math.max(roundDownToMultiple(width * hrfScale, gridSize), 64), + height: Math.max(roundDownToMultiple(height * hrfScale, gridSize), 64), + }; + } +); export const selectAspectRatioID = createParamsSelector((params) => params.dimensions.aspectRatio.id); export const selectAspectRatioValue = createParamsSelector((params) => params.dimensions.aspectRatio.value); export const selectAspectRatioIsLocked = createParamsSelector((params) => params.dimensions.aspectRatio.isLocked); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index baf8cf04fd0..e647e8999a8 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -748,8 +748,11 @@ const zPositivePromptHistory = z export const zInfillMethod = z.enum(['patchmatch', 'lama', 'cv2', 'color', 'tile']); export type InfillMethod = z.infer; +export const zHrfLatentInterpolationMode = z.enum(['nearest', 'bilinear', 'bicubic', 'area', 'nearest-exact']); +export type HrfLatentInterpolationMode = z.infer; + export const zParamsState = z.object({ - _version: z.literal(2), + _version: z.literal(3), maskBlur: z.number(), maskBlurMethod: zParameterMaskBlurMethod, canvasCoherenceMode: zParameterCanvasCoherenceMode, @@ -764,6 +767,10 @@ export const zParamsState = z.object({ guidance: zParameterGuidance, img2imgStrength: zParameterStrength, optimizedDenoisingEnabled: z.boolean(), + hrfEnabled: z.boolean(), + hrfScale: z.number().min(1).max(8), + hrfStrength: zParameterStrength, + hrfLatentInterpolationMode: zHrfLatentInterpolationMode, iterations: z.number(), scheduler: zParameterScheduler, fluxScheduler: zParameterFluxScheduler, @@ -833,7 +840,7 @@ export const zParamsState = z.object({ }); export type ParamsState = z.infer; export const getInitialParamsState = (): ParamsState => ({ - _version: 2, + _version: 3, maskBlur: 16, maskBlurMethod: 'box', canvasCoherenceMode: 'Gaussian Blur', @@ -848,6 +855,10 @@ export const getInitialParamsState = (): ParamsState => ({ guidance: 4, img2imgStrength: 0.75, optimizedDenoisingEnabled: true, + hrfEnabled: false, + hrfScale: 2, + hrfStrength: 0.45, + hrfLatentInterpolationMode: 'bicubic', iterations: 1, scheduler: 'dpmpp_3m_k', fluxScheduler: 'euler', diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 105ad3dfd67..3ef175459ca 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -49,6 +49,11 @@ export const ImageMetadataActions = memo((props: Props) => { + + + + + diff --git a/invokeai/frontend/web/src/features/metadata/parsing.tsx b/invokeai/frontend/web/src/features/metadata/parsing.tsx index a8ad7373218..dd6d580fecf 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.tsx @@ -36,6 +36,10 @@ import { setFluxDypeScale, setFluxScheduler, setGuidance, + setHrfEnabled, + setHrfLatentInterpolationMode, + setHrfScale, + setHrfStrength, setImg2imgStrength, setRefinerCFGScale, setRefinerNegativeAestheticScore, @@ -60,8 +64,18 @@ import { zImageVaeModelSelected, } from 'features/controlLayers/store/paramsSlice'; import { refImagesRecalled } from 'features/controlLayers/store/refImagesSlice'; -import type { CanvasMetadata, LoRA, RefImageState } from 'features/controlLayers/store/types'; -import { zCanvasMetadata, zCanvasReferenceImageState_OLD, zRefImageState } from 'features/controlLayers/store/types'; +import type { + CanvasMetadata, + HrfLatentInterpolationMode, + LoRA, + RefImageState, +} from 'features/controlLayers/store/types'; +import { + zCanvasMetadata, + zCanvasReferenceImageState_OLD, + zHrfLatentInterpolationMode, + zRefImageState, +} from 'features/controlLayers/store/types'; import type { ModelIdentifierField, ModelType } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifier } from 'features/nodes/types/v2/common'; @@ -610,6 +624,87 @@ const DenoisingStrength: SingleMetadataHandler = { }; //#endregion DenoisingStrength +//#region High Resolution Fix +const HrfEnabled: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfEnabled', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_enabled'); + const parsed = z.boolean().parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfEnabled(value)); + }, + i18nKey: 'hrf.metadata.enabled', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfMethod: UnrecallableMetadataHandler = { + [UnrecallableMetadataKey]: true, + type: 'HrfMethod', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_method'); + const parsed = z.string().parse(raw); + return Promise.resolve(parsed); + }, + i18nKey: 'hrf.metadata.method', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: UnrecallableMetadataValueProps) => , +}; + +const HrfStrength: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfStrength', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_strength'); + const parsed = zParameterStrength.parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfStrength(value)); + }, + i18nKey: 'hrf.metadata.strength', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfScale: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfScale', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_scale'); + const parsed = z.number().min(1).max(8).parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfScale(value)); + }, + i18nKey: 'hrf.metadata.scale', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfLatentInterpolationModeMetadata: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfLatentInterpolationMode', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_latent_interpolation_mode'); + const parsed = zHrfLatentInterpolationMode.parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfLatentInterpolationMode(value)); + }, + i18nKey: 'hrf.metadata.latentInterpolationMode', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => ( + + ), +}; +//#endregion High Resolution Fix + //#region SeamlessX const SeamlessX: SingleMetadataHandler = { [SingleMetadataKey]: true, @@ -1521,6 +1616,11 @@ export const ImageMetadataHandlers = { Seed, Steps, DenoisingStrength, + HrfEnabled, + HrfMethod, + HrfStrength, + HrfScale, + HrfLatentInterpolationMode: HrfLatentInterpolationModeMetadata, SeamlessX, SeamlessY, RefinerModel, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts new file mode 100644 index 00000000000..68f56fd77d7 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts @@ -0,0 +1,197 @@ +import type { RootState } from 'app/store/store'; +import { describe, expect, it } from 'vitest'; + +import { addHighResFix } from './addHighResFix'; +import { Graph } from './Graph'; + +const buildState = (overrides?: { base?: string; hrfEnabled?: boolean; refinerModel?: unknown }): RootState => + ({ + ui: { activeTab: 'generate' }, + params: { + model: { key: 'model', name: 'model', base: overrides?.base ?? 'sdxl', type: 'main' }, + dimensions: { width: 512, height: 512 }, + hrfEnabled: overrides?.hrfEnabled ?? true, + hrfScale: 2, + hrfStrength: 0.35, + hrfLatentInterpolationMode: 'bilinear', + optimizedDenoisingEnabled: true, + refinerModel: overrides?.refinerModel ?? null, + }, + }) as unknown as RootState; + +const buildClassicGraph = () => { + const g = new Graph('test_graph'); + const seed = g.addNode({ id: 'seed', type: 'integer' }); + const noise = g.addNode({ id: 'noise', type: 'noise', use_cpu: true, width: 512, height: 512 }); + const denoise = g.addNode({ + id: 'denoise', + type: 'denoise_latents', + cfg_scale: 7.5, + scheduler: 'euler', + steps: 30, + }); + const l2i = g.addNode({ id: 'l2i', type: 'l2i', fp32: true }); + + g.addEdge(seed, 'value', noise, 'seed'); + g.addEdge(noise, 'noise', denoise, 'noise'); + g.addEdge(denoise, 'latents', l2i, 'latents'); + + return { g, seed, noise, denoise, l2i }; +}; + +const addSDXLConditioning = (g: Graph, denoise: ReturnType['denoise']) => { + const posCond = g.addNode({ + id: 'pos_cond', + type: 'sdxl_compel_prompt', + original_width: 512, + original_height: 512, + target_width: 512, + target_height: 512, + }); + const posCollect = g.addNode({ id: 'pos_collect', type: 'collect' }); + + g.addEdge(posCond, 'conditioning', posCollect, 'item'); + g.addEdge(posCollect, 'collection', denoise, 'positive_conditioning'); + + return { posCond, posCollect }; +}; + +const buildTransformerGraph = () => { + const g = new Graph('test_transformer_graph'); + const seed = g.addNode({ id: 'seed', type: 'integer' }); + const denoise = g.addNode({ + id: 'sd3_denoise', + type: 'sd3_denoise', + cfg_scale: 4, + width: 512, + height: 512, + steps: 20, + denoising_start: 0, + denoising_end: 1, + }); + const l2i = g.addNode({ id: 'sd3_l2i', type: 'sd3_l2i' }); + + g.addEdge(seed, 'value', denoise, 'seed'); + g.addEdge(denoise, 'latents', l2i, 'latents'); + + return { g, seed, denoise, l2i }; +}; + +describe('addHighResFix', () => { + it('reroutes classic txt2img graphs through latent resize and a second denoise pass', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + + addHighResFix({ g, state: buildState(), generationMode: 'txt2img', denoise, l2i, noise, seed }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const resize = nodes.find((node) => node.type === 'lresize'); + const hrfDenoise = nodes.find((node) => node.id.startsWith('hrf_denoise_latents')); + const hrfNoise = nodes.find((node) => node.id.startsWith('hrf_noise')); + + expect(resize).toMatchObject({ type: 'lresize', width: 1024, height: 1024, mode: 'bilinear' }); + expect(hrfDenoise).toMatchObject({ type: 'denoise_latents', denoising_start: 0.65, denoising_end: 1 }); + expect(hrfNoise).toMatchObject({ type: 'noise', width: 1024, height: 1024, use_cpu: true }); + expect(graph.edges).not.toContainEqual({ + source: { node_id: 'denoise', field: 'latents' }, + destination: { node_id: 'l2i', field: 'latents' }, + }); + expect(g.getMetadataNode()).toMatchObject({ + width: 1024, + height: 1024, + hrf_enabled: true, + hrf_method: 'latent', + hrf_strength: 0.35, + hrf_scale: 2, + hrf_latent_interpolation_mode: 'bilinear', + }); + }); + + it('preserves the original graph and writes disabled metadata when HRF is off', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + + addHighResFix({ + g, + state: buildState({ hrfEnabled: false }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + expect(Object.values(graph.nodes).some((node) => node.type === 'lresize')).toBe(false); + expect(graph.edges).toContainEqual({ + source: { node_id: 'denoise', field: 'latents' }, + destination: { node_id: 'l2i', field: 'latents' }, + }); + expect(g.getMetadataNode()).toMatchObject({ hrf_enabled: false }); + }); + + it('reroutes transformer txt2img graphs through latent resize and a final-size second denoise pass', () => { + const { g, seed, denoise, l2i } = buildTransformerGraph(); + + addHighResFix({ g, state: buildState({ base: 'sd-3' }), generationMode: 'txt2img', denoise, l2i, seed }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const resize = nodes.find((node) => node.type === 'lresize'); + const hrfDenoise = nodes.find((node) => node.id.startsWith('hrf_sd3_denoise')); + + if (!hrfDenoise) { + throw new Error('Expected HRF SD3 denoise node'); + } + + expect(resize).toMatchObject({ type: 'lresize', width: 1024, height: 1024, mode: 'bilinear' }); + expect(hrfDenoise).toMatchObject({ type: 'sd3_denoise', width: 1024, height: 1024, denoising_end: 1 }); + expect((hrfDenoise as { denoising_start: number }).denoising_start).toBeCloseTo(1 - 0.35 ** 0.2); + expect(nodes.some((node) => node.id.startsWith('hrf_noise'))).toBe(false); + expect(graph.edges).toContainEqual({ + source: { node_id: 'seed', field: 'value' }, + destination: { node_id: hrfDenoise.id, field: 'seed' }, + }); + expect(graph.edges).not.toContainEqual({ + source: { node_id: 'sd3_denoise', field: 'latents' }, + destination: { node_id: 'sd3_l2i', field: 'latents' }, + }); + }); + + it('clones SDXL conditioning with final HRF dimensions for the second pass', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + const { posCond, posCollect } = addSDXLConditioning(g, denoise); + + addHighResFix({ g, state: buildState(), generationMode: 'txt2img', denoise, l2i, noise, seed }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const hrfDenoise = nodes.find((node) => node.id.startsWith('hrf_denoise_latents')); + const hrfPosCond = nodes.find((node) => node.id.startsWith('hrf_pos_cond')); + const hrfPosCollect = nodes.find((node) => node.id.startsWith('hrf_sdxl_conditioning_collect')); + + if (!hrfDenoise || !hrfPosCond || !hrfPosCollect) { + throw new Error('Expected HRF denoise and cloned SDXL conditioning nodes'); + } + + expect(posCond).toMatchObject({ original_width: 512, original_height: 512 }); + expect(hrfPosCond).toMatchObject({ + type: 'sdxl_compel_prompt', + original_width: 1024, + original_height: 1024, + target_width: 1024, + target_height: 1024, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfPosCond.id, field: 'conditioning' }, + destination: { node_id: hrfPosCollect.id, field: 'item' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfPosCollect.id, field: 'collection' }, + destination: { node_id: hrfDenoise.id, field: 'positive_conditioning' }, + }); + expect(graph.edges).not.toContainEqual({ + source: { node_id: posCollect.id, field: 'collection' }, + destination: { node_id: hrfDenoise.id, field: 'positive_conditioning' }, + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts new file mode 100644 index 00000000000..92760ad7e1d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts @@ -0,0 +1,251 @@ +import type { RootState } from 'app/store/store'; +import { roundDownToMultiple } from 'common/util/roundDownToMultiple'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; +import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; +import type { GenerationMode } from 'features/controlLayers/store/types'; +import type { BaseModelType } from 'features/nodes/types/common'; +import type { Graph } from 'features/nodes/util/graph/generation/Graph'; +import type { DenoiseLatentsNodes, LatentToImageNodes } from 'features/nodes/util/graph/types'; +import { getGridSize } from 'features/parameters/util/optimalDimension'; +import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import type { AnyInvocationInputField, AnyInvocationOutputField, Invocation } from 'services/api/types'; +import { assert } from 'tsafe'; + +type AddHighResFixArg = { + g: Graph; + state: RootState; + generationMode: GenerationMode; + denoise: Invocation; + l2i: Invocation; + noise?: Invocation<'noise'>; + seed: Invocation<'integer'>; +}; + +const SKIPPED_DENOISE_INPUT_FIELDS = new Set(['latents', 'noise']); + +const getHighResFixFinalDimensions = (state: RootState) => { + const params = selectParamsSlice(state); + const gridSize = getGridSize(params.model?.base as BaseModelType | undefined); + + return { + width: Math.max(roundDownToMultiple(params.dimensions.width * params.hrfScale, gridSize), 64), + height: Math.max(roundDownToMultiple(params.dimensions.height * params.hrfScale, gridSize), 64), + }; +}; + +const getHighResFixDenoisingStartAndEnd = (state: RootState): { denoising_start: number; denoising_end: number } => { + const params = selectParamsSlice(state); + const model = params.model; + const { hrfStrength, optimizedDenoisingEnabled } = params; + + switch (model?.base) { + case 'sd-3': + case 'flux': + case 'flux2': { + if (model.base === 'flux' && 'variant' in model && model.variant === 'dev_fill') { + return { denoising_start: 0, denoising_end: 1 }; + } + + const exponent = optimizedDenoisingEnabled ? 0.2 : 1; + return { denoising_start: 1 - hrfStrength ** exponent, denoising_end: 1 }; + } + case 'anima': + case 'sd-1': + case 'sd-2': + case 'sdxl': + case 'cogview4': + case 'qwen-image': + case 'z-image': { + return { denoising_start: 1 - hrfStrength, denoising_end: 1 }; + } + default: { + assert(false, `Unsupported base for high resolution fix: ${model?.base}`); + } + } +}; + +const shouldApplyHighResFix = (state: RootState, generationMode: GenerationMode) => { + const params = selectParamsSlice(state); + const model = params.model; + + return ( + selectActiveTab(state) === 'generate' && + generationMode === 'txt2img' && + params.hrfEnabled && + model !== null && + model.base !== 'external' && + !params.refinerModel + ); +}; + +const shouldWriteDisabledMetadata = (state: RootState, generationMode: GenerationMode) => { + return selectActiveTab(state) === 'generate' && generationMode === 'txt2img'; +}; + +const cloneSDXLCompelPromptForFinalDimensions = ( + g: Graph, + node: Invocation<'sdxl_compel_prompt'>, + finalDimensions: { width: number; height: number } +) => { + const clone = g.addNode({ + ...node, + id: getPrefixedId(`hrf_${node.id.split(':')[0]}`), + original_width: finalDimensions.width, + original_height: finalDimensions.height, + target_width: finalDimensions.width, + target_height: finalDimensions.height, + }); + + for (const edge of g.getEdgesTo(node)) { + g.addEdgeFromObj({ + source: { ...edge.source }, + destination: { node_id: clone.id, field: edge.destination.field }, + }); + } + + return clone; +}; + +const cloneSDXLConditioningForFinalDimensions = ( + g: Graph, + sourceNodeId: string, + finalDimensions: { width: number; height: number } +) => { + const sourceNode = g.getNode(sourceNodeId); + + if (sourceNode.type === 'sdxl_compel_prompt') { + return { + nodeId: cloneSDXLCompelPromptForFinalDimensions(g, sourceNode, finalDimensions).id, + field: 'conditioning', + }; + } + + if (sourceNode.type !== 'collect') { + return null; + } + + const itemEdges = g.getEdgesTo(sourceNode).filter((edge) => edge.destination.field === 'item'); + const hasSDXLConditioning = itemEdges.some((edge) => g.getNode(edge.source.node_id).type === 'sdxl_compel_prompt'); + + if (!hasSDXLConditioning) { + return null; + } + + const collect = g.addNode({ + type: 'collect', + id: getPrefixedId('hrf_sdxl_conditioning_collect'), + }); + + for (const edge of itemEdges) { + const itemNode = g.getNode(edge.source.node_id); + const source = + itemNode.type === 'sdxl_compel_prompt' + ? cloneSDXLCompelPromptForFinalDimensions(g, itemNode, finalDimensions) + : itemNode; + + g.addEdgeFromObj({ + source: { node_id: source.id, field: edge.source.field }, + destination: { node_id: collect.id, field: 'item' }, + }); + } + + return { nodeId: collect.id, field: 'collection' }; +}; + +const copyDenoiseInputs = ( + g: Graph, + from: Invocation, + to: Invocation, + finalDimensions: { width: number; height: number } +) => { + for (const edge of g.getEdgesTo(from)) { + if (SKIPPED_DENOISE_INPUT_FIELDS.has(edge.destination.field)) { + continue; + } + + const finalSizeConditioning = ['positive_conditioning', 'negative_conditioning'].includes(edge.destination.field) + ? cloneSDXLConditioningForFinalDimensions(g, edge.source.node_id, finalDimensions) + : null; + + g.addEdgeFromObj({ + source: finalSizeConditioning + ? { + node_id: finalSizeConditioning.nodeId, + field: finalSizeConditioning.field as AnyInvocationOutputField, + } + : { ...edge.source }, + destination: { node_id: to.id, field: edge.destination.field as AnyInvocationInputField }, + }); + } +}; + +export const addHighResFix = ({ + g, + state, + generationMode, + denoise, + l2i, + noise, + seed, +}: AddHighResFixArg): Invocation => { + const params = selectParamsSlice(state); + + if (!shouldApplyHighResFix(state, generationMode)) { + if (shouldWriteDisabledMetadata(state, generationMode)) { + g.upsertMetadata({ hrf_enabled: false }); + } + return l2i; + } + + const finalDimensions = getHighResFixFinalDimensions(state); + const { denoising_start, denoising_end } = getHighResFixDenoisingStartAndEnd(state); + + const resizeLatents = g.addNode({ + id: getPrefixedId('hrf_resize_latents'), + type: 'lresize', + ...finalDimensions, + mode: params.hrfLatentInterpolationMode, + antialias: false, + }); + + const hrfDenoise = g.addNode({ + ...denoise, + id: getPrefixedId(`hrf_${denoise.type}`), + denoising_start, + denoising_end, + ...(denoise.type === 'denoise_latents' ? {} : finalDimensions), + } as Invocation); + + copyDenoiseInputs(g, denoise, hrfDenoise, finalDimensions); + + if (denoise.type === 'denoise_latents') { + assert(noise, 'SD1.5/SD2/SDXL high resolution fix graphs require a noise node'); + const classicHrfDenoise = hrfDenoise as Invocation<'denoise_latents'>; + + const hrfNoise = g.addNode({ + type: 'noise', + id: getPrefixedId('hrf_noise'), + use_cpu: noise.use_cpu, + ...finalDimensions, + }); + g.addEdge(seed, 'value', hrfNoise, 'seed'); + g.addEdge(hrfNoise, 'noise', classicHrfDenoise, 'noise'); + } + + g.deleteEdgesTo(l2i, ['latents']); + g.addEdge(denoise, 'latents', resizeLatents, 'latents'); + g.addEdge(resizeLatents, 'latents', hrfDenoise, 'latents'); + g.addEdge(hrfDenoise, 'latents', l2i, 'latents'); + + g.upsertMetadata({ + width: finalDimensions.width, + height: finalDimensions.height, + hrf_enabled: true, + hrf_method: 'latent', + hrf_strength: params.hrfStrength, + hrf_scale: params.hrfScale, + hrf_latent_interpolation_mode: params.hrfLatentInterpolationMode, + }); + + return l2i; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.ts index 0ab76ccefb6..b018f9871a1 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.ts @@ -11,6 +11,7 @@ import { import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { addAnimaLoRAs } from 'features/nodes/util/graph/generation/addAnimaLoRAs'; +import { addHighResFix } from 'features/nodes/util/graph/generation/addHighResFix'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; @@ -251,6 +252,17 @@ export const buildAnimaGraph = async (arg: GraphBuilderArg): Promise>(false); } + if (generationMode === 'txt2img' && selectActiveTab(state) === 'generate') { + canvasOutput = addHighResFix({ + g, + state, + generationMode, + denoise, + l2i, + seed, + }); + } + if (state.system.shouldUseNSFWChecker) { canvasOutput = addNSFWChecker(g, canvasOutput); } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildCogView4Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildCogView4Graph.ts index 6adee057545..45c34e605b5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildCogView4Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildCogView4Graph.ts @@ -3,6 +3,7 @@ import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { selectCanvasMetadata } from 'features/controlLayers/store/selectors'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { addHighResFix } from 'features/nodes/util/graph/generation/addHighResFix'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; @@ -167,6 +168,17 @@ export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise>(false); } + if (generationMode === 'txt2img' && selectActiveTab(state) === 'generate') { + canvasOutput = addHighResFix({ + g, + state, + generationMode, + denoise, + l2i, + seed, + }); + } + if (state.system.shouldUseNSFWChecker) { canvasOutput = addNSFWChecker(g, canvasOutput); } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index dafcd9310ec..e310942502c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -16,6 +16,7 @@ import { addFlux2KleinLoRAs } from 'features/nodes/util/graph/generation/addFlux import { addFLUXFill } from 'features/nodes/util/graph/generation/addFLUXFill'; import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs'; import { addFLUXReduxes } from 'features/nodes/util/graph/generation/addFLUXRedux'; +import { addHighResFix } from 'features/nodes/util/graph/generation/addHighResFix'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; @@ -582,6 +583,17 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise>(false); } + if (generationMode === 'txt2img' && selectActiveTab(state) === 'generate') { + canvasOutput = addHighResFix({ + g, + state, + generationMode, + denoise, + l2i, + seed, + }); + } + if (state.system.shouldUseNSFWChecker) { canvasOutput = addNSFWChecker(g, canvasOutput); } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index eae07532011..c8d28fd9c9a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -4,9 +4,9 @@ import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { addControlNets, addT2IAdapters } from 'features/nodes/util/graph/generation/addControlAdapters'; +import { addHighResFix } from 'features/nodes/util/graph/generation/addHighResFix'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; -// import { addHRF } from 'features/nodes/util/graph/generation/addHRF'; import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters'; import { addLoRAs } from 'features/nodes/util/graph/generation/addLoRAs'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; @@ -305,6 +305,18 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise>(false); } + if (generationMode === 'txt2img' && selectActiveTab(state) === 'generate') { + canvasOutput = addHighResFix({ + g, + state, + generationMode, + denoise, + l2i, + seed, + }); + } + if (state.system.shouldUseNSFWChecker) { canvasOutput = addNSFWChecker(g, canvasOutput); } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index 9d65076a70d..93c432add10 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -4,6 +4,7 @@ import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { addControlNets, addT2IAdapters } from 'features/nodes/util/graph/generation/addControlAdapters'; +import { addHighResFix } from 'features/nodes/util/graph/generation/addHighResFix'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters'; @@ -312,6 +313,18 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise>(false); } + if (generationMode === 'txt2img' && selectActiveTab(state) === 'generate') { + canvasOutput = addHighResFix({ + g, + state, + generationMode, + denoise, + l2i, + seed, + }); + } + if (state.system.shouldUseNSFWChecker) { canvasOutput = addNSFWChecker(g, canvasOutput); } diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts index 632006050e6..10afa43f8df 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts @@ -12,7 +12,7 @@ vi.mock('i18next', () => ({ import type { ParamsState, RefImagesState } from 'features/controlLayers/store/types'; import type { DynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; -import type { MainModelConfig } from 'services/api/types'; +import type { MainModelConfig, MainOrExternalModelConfig } from 'services/api/types'; import { getReasonsWhyCannotEnqueueCanvasTab, getReasonsWhyCannotEnqueueGenerateTab } from './readiness'; @@ -50,6 +50,14 @@ const flux2GGUF9BModel = { const kleinVaeModel = { key: 'vae', name: 'VAE', base: 'flux2', type: 'vae' }; const kleinQwen3Model = { key: 'qwen3', name: 'Qwen3', base: 'flux2', type: 'qwen3_encoder' }; +const externalModel = { + key: 'external', + hash: 'h', + name: 'External', + base: 'external', + type: 'external_image_generator', + format: 'external_api', +} as unknown as MainOrExternalModelConfig; const baseDynamicPrompts: DynamicPromptsState = { _version: 1, @@ -71,14 +79,18 @@ const baseParams = { positivePrompt: 'test', kleinVaeModel: null, kleinQwen3EncoderModel: null, + hrfEnabled: false, + refinerModel: null, } as unknown as ParamsState; // --- Helpers --- const buildGenerateTabArg = (overrides: { - model?: MainModelConfig | null; + model?: MainOrExternalModelConfig | null; kleinVaeModel?: unknown; kleinQwen3EncoderModel?: unknown; + hrfEnabled?: boolean; + refinerModel?: unknown; hasFlux2DiffusersVaeSource?: boolean; hasFlux2DiffusersQwen3Source?: boolean; }) => ({ @@ -88,6 +100,8 @@ const buildGenerateTabArg = (overrides: { ...baseParams, kleinVaeModel: overrides.kleinVaeModel ?? null, kleinQwen3EncoderModel: overrides.kleinQwen3EncoderModel ?? null, + hrfEnabled: overrides.hrfEnabled ?? false, + refinerModel: overrides.refinerModel ?? null, } as unknown as ParamsState, refImages: baseRefImages, loras: [], @@ -139,6 +153,12 @@ const hasFlux2VaeReason = (reasons: { content: string }[]) => const hasFlux2Qwen3Reason = (reasons: { content: string }[]) => reasons.some((r) => r.content.includes('noFlux2KleinQwen3EncoderModelSelected')); +const hasHrfExternalReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('hrfExternalModelUnsupported')); + +const hasHrfRefinerReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('hrfRefinerUnsupported')); + // --- Tests --- describe('FLUX.2 Klein readiness checks – generate tab', () => { @@ -221,6 +241,22 @@ describe('FLUX.2 Klein readiness checks – generate tab', () => { }); }); +describe('High Resolution Fix readiness checks - generate tab', () => { + it('errors when HRF is enabled for external models', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ model: externalModel, hrfEnabled: true }) + ); + expect(hasHrfExternalReason(reasons)).toBe(true); + }); + + it('errors when HRF is enabled with SDXL Refiner', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ hrfEnabled: true, refinerModel: { key: 'refiner' } }) + ); + expect(hasHrfRefinerReason(reasons)).toBe(true); + }); +}); + describe('FLUX.2 Klein readiness checks – canvas tab', () => { it('no errors when main model is diffusers', () => { const reasons = getReasonsWhyCannotEnqueueCanvasTab(buildCanvasTabArg({ model: flux2DiffusersModel }) as never); diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index 5802a2aed5c..99efd3cecf3 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -358,6 +358,15 @@ export const getReasonsWhyCannotEnqueueGenerateTab = (arg: { }); } + if (params.hrfEnabled) { + if (params.model?.base === 'external' || (model && isExternalApiModelConfig(model))) { + reasons.push({ content: i18n.t('parameters.invoke.hrfExternalModelUnsupported') }); + } + if (params.refinerModel) { + reasons.push({ content: i18n.t('parameters.invoke.hrfRefinerUnsupported') }); + } + } + return reasons; }; const getReasonsWhyCannotEnqueueWorkflowsTab = async (arg: { diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx new file mode 100644 index 00000000000..cee9236ae7c --- /dev/null +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx @@ -0,0 +1,247 @@ +import type { ComboboxOnChange } from '@invoke-ai/ui-library'; +import { + Combobox, + CompositeNumberInput, + CompositeSlider, + Flex, + FormControl, + FormControlGroup, + FormLabel, + StandaloneAccordion, + Switch, +} from '@invoke-ai/ui-library'; +import { EMPTY_ARRAY } from 'app/store/constants'; +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; +import { + selectHrfEnabled, + selectHrfFinalDimensions, + selectHrfLatentInterpolationMode, + selectHrfScale, + selectHrfStrength, + selectIsRefinerModelSelected, + selectModelSupportsHrf, + setHrfEnabled, + setHrfLatentInterpolationMode, + setHrfScale, + setHrfStrength, +} from 'features/controlLayers/store/paramsSlice'; +import { zHrfLatentInterpolationMode } from 'features/controlLayers/store/types'; +import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; +import type { ChangeEvent } from 'react'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; + +const SCALE_CONSTRAINTS = { + initial: 2, + sliderMin: 1, + sliderMax: 4, + numberInputMin: 1, + numberInputMax: 8, + coarseStep: 0.05, + fineStep: 0.01, +}; + +const STRENGTH_CONSTRAINTS = { + initial: 0.45, + sliderMin: 0, + sliderMax: 1, + numberInputMin: 0, + numberInputMax: 1, + coarseStep: 0.01, + fineStep: 0.01, +}; + +const selectBadges = createMemoizedSelector( + [selectHrfEnabled, selectHrfScale, selectHrfStrength, selectHrfFinalDimensions], + (enabled, scale, strength, finalDimensions) => { + if (!enabled) { + return EMPTY_ARRAY; + } + + return [`${scale}x`, `${Math.round(strength * 100)}%`, `${finalDimensions.width}x${finalDimensions.height}`]; + } +); + +const ParamHrfEnabled = memo(() => { + const dispatch = useAppDispatch(); + const enabled = useAppSelector(selectHrfEnabled); + const { t } = useTranslation(); + + const onChange = useCallback( + (event: ChangeEvent) => { + dispatch(setHrfEnabled(event.target.checked)); + }, + [dispatch] + ); + + return ( + + + {t('hrf.enableHrf')} + + + + ); +}); + +ParamHrfEnabled.displayName = 'ParamHrfEnabled'; + +const ParamHrfScale = memo(() => { + const dispatch = useAppDispatch(); + const scale = useAppSelector(selectHrfScale); + const { t } = useTranslation(); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfScale(v)); + }, + [dispatch] + ); + + return ( + + + {t('hrf.scale')} + + + + + ); +}); + +ParamHrfScale.displayName = 'ParamHrfScale'; + +const ParamHrfStrength = memo(() => { + const dispatch = useAppDispatch(); + const strength = useAppSelector(selectHrfStrength); + const { t } = useTranslation(); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfStrength(v)); + }, + [dispatch] + ); + + return ( + + + {t('hrf.strength')} + + + + + ); +}); + +ParamHrfStrength.displayName = 'ParamHrfStrength'; + +const ParamHrfLatentInterpolationMode = memo(() => { + const dispatch = useAppDispatch(); + const mode = useAppSelector(selectHrfLatentInterpolationMode); + const { t } = useTranslation(); + + const options = useMemo( + () => [ + { label: t('hrf.bilinear'), value: 'bilinear' }, + { label: t('hrf.bicubic'), value: 'bicubic' }, + { label: t('hrf.nearest'), value: 'nearest' }, + { label: t('hrf.nearestExact'), value: 'nearest-exact' }, + { label: t('hrf.area'), value: 'area' }, + ], + [t] + ); + + const value = useMemo(() => options.find((o) => o.value === mode), [mode, options]); + + const onChange = useCallback( + (v) => { + const result = zHrfLatentInterpolationMode.safeParse(v?.value); + if (!result.success) { + return; + } + dispatch(setHrfLatentInterpolationMode(result.data)); + }, + [dispatch] + ); + + return ( + + + {t('hrf.latentInterpolationMode')} + + + + ); +}); + +ParamHrfLatentInterpolationMode.displayName = 'ParamHrfLatentInterpolationMode'; + +export const HighResFixSettingsAccordion = memo(() => { + const { t } = useTranslation(); + const badges = useAppSelector(selectBadges); + const enabled = useAppSelector(selectHrfEnabled); + const modelSupportsHrf = useAppSelector(selectModelSupportsHrf); + const isRefinerModelSelected = useAppSelector(selectIsRefinerModelSelected); + const { isOpen, onToggle } = useStandaloneAccordionToggle({ + id: 'high-res-fix-settings-generate-tab', + defaultIsOpen: false, + }); + + if (!modelSupportsHrf || isRefinerModelSelected) { + return null; + } + + return ( + + + + {enabled && ( + + + + + + )} + + + ); +}); + +HighResFixSettingsAccordion.displayName = 'HighResFixSettingsAccordion'; diff --git a/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelGenerate.tsx b/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelGenerate.tsx index 06a122cc4ef..1b1d4bfa492 100644 --- a/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelGenerate.tsx +++ b/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelGenerate.tsx @@ -7,6 +7,7 @@ import { Prompts } from 'features/parameters/components/Prompts/Prompts'; import { AdvancedSettingsAccordion } from 'features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion'; import { ExternalSettingsAccordion } from 'features/settingsAccordions/components/ExternalSettingsAccordion/ExternalSettingsAccordion'; import { GenerationSettingsAccordion } from 'features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion'; +import { HighResFixSettingsAccordion } from 'features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion'; import { GenerateTabImageSettingsAccordion } from 'features/settingsAccordions/components/ImageSettingsAccordion/GenerateTabImageSettingsAccordion'; import { RefinerSettingsAccordion } from 'features/settingsAccordions/components/RefinerSettingsAccordion/RefinerSettingsAccordion'; import { StylePresetMenu } from 'features/stylePresets/components/StylePresetMenu'; @@ -43,6 +44,7 @@ export const ParametersPanelGenerate = memo(() => { + {isSDXL && } {!isCogview4 && !isExternal && } diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 182ae210e24..cebd7e9cca4 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -7691,6 +7691,18 @@ export type components = { * @default null */ hrf_strength?: number | null; + /** + * Hrf Scale + * @description The high resolution fix latent upscale factor. + * @default null + */ + hrf_scale?: number | null; + /** + * Hrf Latent Interpolation Mode + * @description The latent interpolation mode used in the high resolution fix upscale pass. + * @default null + */ + hrf_latent_interpolation_mode?: string | null; /** * Positive Style Prompt * @description The positive style prompt parameter From e4ba8a65c2953acb1ba6fa89ee34c0cf533b14c5 Mon Sep 17 00:00:00 2001 From: Astra orion <13394741+AsuraAce@users.noreply.github.com> Date: Mon, 4 May 2026 15:00:41 +0200 Subject: [PATCH 2/9] feat: add upscale model highres fix --- invokeai/app/invocations/metadata.py | 22 +- invokeai/frontend/web/public/locales/en.json | 11 +- .../controlLayers/store/paramsSlice.test.ts | 22 + .../controlLayers/store/paramsSlice.ts | 59 ++- .../src/features/controlLayers/store/types.ts | 20 +- .../ImageMetadataActions.tsx | 7 +- .../web/src/features/metadata/parsing.tsx | 106 ++++- .../graph/generation/addHighResFix.test.ts | 153 ++++++- .../util/graph/generation/addHighResFix.ts | 236 ++++++++++- .../features/queue/store/readiness.test.ts | 79 ++++ .../web/src/features/queue/store/readiness.ts | 13 + .../HighResFixSettingsAccordion.tsx | 390 +++++++++++++++++- .../frontend/web/src/services/api/schema.ts | 32 +- 13 files changed, 1116 insertions(+), 34 deletions(-) diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 545d277e823..4a0a023e355 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -254,12 +254,32 @@ class CoreMetadataInvocation(BaseInvocation): ) hrf_scale: Optional[float] = InputField( default=None, - description="The high resolution fix latent upscale factor.", + description="The high resolution fix upscale factor.", ) hrf_latent_interpolation_mode: Optional[str] = InputField( default=None, description="The latent interpolation mode used in the high resolution fix upscale pass.", ) + hrf_upscale_model: Optional[ModelIdentifierField] = InputField( + default=None, + description="The Spandrel upscale model used in the high resolution fix upscale pass.", + ) + hrf_tile_controlnet_model: Optional[ModelIdentifierField] = InputField( + default=None, + description="The tile ControlNet model used in the high resolution fix upscale pass.", + ) + hrf_structure: Optional[float] = InputField( + default=None, + description="The high resolution fix tile ControlNet structure value.", + ) + hrf_tile_size: Optional[int] = InputField( + default=None, + description="The high resolution fix tiled processing tile size.", + ) + hrf_tile_overlap: Optional[int] = InputField( + default=None, + description="The high resolution fix tiled processing tile overlap.", + ) # SDXL positive_style_prompt: Optional[str] = InputField( diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 7d952d9e11d..3074c0b13e2 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -357,6 +357,7 @@ "upscaleMethod": "Upscale Method", "latentInterpolationMode": "Latent Interpolation", "latent": "Latent", + "upscaleModelMethod": "Upscale Model", "nearest": "Nearest", "bilinear": "Bilinear", "bicubic": "Bicubic", @@ -367,7 +368,12 @@ "strength": "High Resolution Fix Strength", "method": "High Resolution Fix Method", "scale": "High Resolution Fix Scale", - "latentInterpolationMode": "High Resolution Fix Latent Interpolation" + "latentInterpolationMode": "High Resolution Fix Latent Interpolation", + "upscaleModel": "High Resolution Fix Upscale Model", + "tileControlNetModel": "High Resolution Fix Tile ControlNet", + "structure": "High Resolution Fix Structure", + "tileSize": "High Resolution Fix Tile Size", + "tileOverlap": "High Resolution Fix Tile Overlap" } }, "prompt": { @@ -1689,6 +1695,9 @@ "incompatibleLoRAs": "Incompatible LoRA(s) added", "hrfExternalModelUnsupported": "High Resolution Fix is not supported for external models", "hrfRefinerUnsupported": "High Resolution Fix is not supported when SDXL Refiner is enabled", + "hrfUpscaleModelBaseUnsupported": "High Resolution Fix with an upscale model is supported for SD1.5 and SDXL models only", + "hrfUpscaleModelMissing": "High Resolution Fix needs an upscale model", + "hrfTileControlNetModelMissing": "High Resolution Fix needs a tile ControlNet model", "canvasIsFiltering": "Canvas is busy (filtering)", "canvasIsTransforming": "Canvas is busy (transforming)", "canvasIsRasterizing": "Canvas is busy (rasterizing)", diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts index c8719482a21..928b89f8d9f 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts @@ -12,6 +12,7 @@ import { selectModelSupportsDimensions, selectModelSupportsGuidance, selectModelSupportsHrf, + selectModelSupportsHrfUpscaleModel, selectModelSupportsNegativePrompt, selectModelSupportsRefImages, selectModelSupportsSeed, @@ -148,4 +149,25 @@ describe('paramsSlice HRF selectors', () => { it('rounds final dimensions down to the model grid', () => { expect(selectHrfFinalDimensions.resultFunc(513, 512, 1.5, 'flux')).toEqual({ width: 768, height: 768 }); }); + + it('supports upscale-model HRF only for SD1.5 and SDXL models', () => { + expect( + selectModelSupportsHrfUpscaleModel.resultFunc({ + key: 'sdxl', + hash: 'h', + name: 'SDXL', + base: 'sdxl', + type: 'main', + }) + ).toBe(true); + expect( + selectModelSupportsHrfUpscaleModel.resultFunc({ + key: 'flux', + hash: 'h', + name: 'FLUX', + base: 'flux', + type: 'main', + }) + ).toBe(false); + }); }); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 867295569e6..e6b93003ff6 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -10,6 +10,7 @@ import { logout } from 'features/auth/store/authSlice'; import type { AspectRatioID, HrfLatentInterpolationMode, + HrfMethod, InfillMethod, ParamsState, RgbaColor, @@ -27,7 +28,7 @@ import { SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS, SUPPORTS_REF_IMAGES_BASE_MODELS, } from 'features/modelManagerV2/models'; -import type { BaseModelType } from 'features/nodes/types/common'; +import type { BaseModelType, ModelIdentifierField } from 'features/nodes/types/common'; import { CLIP_SKIP_MAP } from 'features/parameters/types/constants'; import type { ParameterCanvasCoherenceMode, @@ -45,13 +46,14 @@ import type { ParameterPrecision, ParameterScheduler, ParameterSDXLRefinerModel, + ParameterSpandrelImageToImageModel, ParameterT5EncoderModel, ParameterVAEModel, } from 'features/parameters/types/parameterSchemas'; import { getExternalPanelControl, hasExternalPanelControl } from 'features/parameters/util/externalPanelSchema'; import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension'; import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models'; -import type { AnyModelConfigWithExternal } from 'services/api/types'; +import type { AnyModelConfigWithExternal, ControlNetModelConfig } from 'services/api/types'; import { isExternalApiModelConfig, isNonRefinerMainModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; @@ -124,6 +126,9 @@ const slice = createSlice({ setHrfEnabled: (state, action: PayloadAction) => { state.hrfEnabled = action.payload && !state.refinerModel; }, + setHrfMethod: (state, action: PayloadAction) => { + state.hrfMethod = action.payload; + }, setHrfScale: (state, action: PayloadAction) => { state.hrfScale = action.payload; }, @@ -133,6 +138,27 @@ const slice = createSlice({ setHrfLatentInterpolationMode: (state, action: PayloadAction) => { state.hrfLatentInterpolationMode = action.payload; }, + setHrfUpscaleModel: (state, action: PayloadAction) => { + const result = zParamsState.shape.hrfUpscaleModel.safeParse(action.payload); + if (result.success) { + state.hrfUpscaleModel = result.data; + } + }, + setHrfTileControlNetModel: (state, action: PayloadAction) => { + const result = zParamsState.shape.hrfTileControlNetModel.safeParse(action.payload); + if (result.success) { + state.hrfTileControlNetModel = result.data; + } + }, + setHrfStructure: (state, action: PayloadAction) => { + state.hrfStructure = action.payload; + }, + setHrfTileSize: (state, action: PayloadAction) => { + state.hrfTileSize = action.payload; + }, + setHrfTileOverlap: (state, action: PayloadAction) => { + state.hrfTileOverlap = action.payload; + }, setSeamlessXAxis: (state, action: PayloadAction) => { state.seamlessXAxis = action.payload; }, @@ -640,9 +666,15 @@ export const { setImg2imgStrength, setOptimizedDenoisingEnabled, setHrfEnabled, + setHrfMethod, setHrfScale, setHrfStrength, setHrfLatentInterpolationMode, + setHrfUpscaleModel, + setHrfTileControlNetModel, + setHrfStructure, + setHrfTileSize, + setHrfTileOverlap, setSeamlessXAxis, setSeamlessYAxis, setShouldRandomizeSeed, @@ -733,6 +765,17 @@ export const paramsSliceConfig: SliceConfig = { state.hrfLatentInterpolationMode = 'bicubic'; } + if (state._version === 3) { + // v3 -> v4, add Generate tab upscale-model high resolution fix settings + state._version = 4; + state.hrfMethod = 'latent'; + state.hrfUpscaleModel = null; + state.hrfTileControlNetModel = null; + state.hrfStructure = 0; + state.hrfTileSize = 1024; + state.hrfTileOverlap = 128; + } + return zParamsState.parse(state); }, }, @@ -800,9 +843,15 @@ export const selectInfillColorValue = createParamsSelector((params) => params.in export const selectImg2imgStrength = createParamsSelector((params) => params.img2imgStrength); export const selectOptimizedDenoisingEnabled = createParamsSelector((params) => params.optimizedDenoisingEnabled); export const selectHrfEnabled = createParamsSelector((params) => params.hrfEnabled); +export const selectHrfMethod = createParamsSelector((params) => params.hrfMethod); export const selectHrfScale = createParamsSelector((params) => params.hrfScale); export const selectHrfStrength = createParamsSelector((params) => params.hrfStrength); export const selectHrfLatentInterpolationMode = createParamsSelector((params) => params.hrfLatentInterpolationMode); +export const selectHrfUpscaleModel = createParamsSelector((params) => params.hrfUpscaleModel); +export const selectHrfTileControlNetModel = createParamsSelector((params) => params.hrfTileControlNetModel); +export const selectHrfStructure = createParamsSelector((params) => params.hrfStructure); +export const selectHrfTileSize = createParamsSelector((params) => params.hrfTileSize); +export const selectHrfTileOverlap = createParamsSelector((params) => params.hrfTileOverlap); export const selectPositivePrompt = createParamsSelector((params) => params.positivePrompt); export const selectNegativePrompt = createParamsSelector((params) => params.negativePrompt); export const selectNegativePromptWithFallback = createParamsSelector((params) => params.negativePrompt ?? ''); @@ -894,6 +943,12 @@ export const selectModelSupportsHrf = createSelector(selectModel, (model) => { } return true; }); +export const selectModelSupportsHrfUpscaleModel = createSelector(selectModel, (model) => { + if (!model) { + return false; + } + return model.base === 'sd-1' || model.base === 'sdxl'; +}); export const selectSeedControl = createSelector(selectModelConfig, (modelConfig) => { if (modelConfig && isExternalApiModelConfig(modelConfig)) { return getExternalPanelControl(modelConfig, 'image', 'seed'); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index e647e8999a8..1901cb036c3 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -23,6 +23,7 @@ import { zParameterScheduler, zParameterSDXLRefinerModel, zParameterSeed, + zParameterSpandrelImageToImageModel, zParameterSteps, zParameterStrength, zParameterT5EncoderModel, @@ -751,8 +752,11 @@ export type InfillMethod = z.infer; export const zHrfLatentInterpolationMode = z.enum(['nearest', 'bilinear', 'bicubic', 'area', 'nearest-exact']); export type HrfLatentInterpolationMode = z.infer; +export const zHrfMethod = z.enum(['latent', 'upscale_model']); +export type HrfMethod = z.infer; + export const zParamsState = z.object({ - _version: z.literal(3), + _version: z.literal(4), maskBlur: z.number(), maskBlurMethod: zParameterMaskBlurMethod, canvasCoherenceMode: zParameterCanvasCoherenceMode, @@ -768,9 +772,15 @@ export const zParamsState = z.object({ img2imgStrength: zParameterStrength, optimizedDenoisingEnabled: z.boolean(), hrfEnabled: z.boolean(), + hrfMethod: zHrfMethod, hrfScale: z.number().min(1).max(8), hrfStrength: zParameterStrength, hrfLatentInterpolationMode: zHrfLatentInterpolationMode, + hrfUpscaleModel: zParameterSpandrelImageToImageModel.nullable(), + hrfTileControlNetModel: zModelIdentifierField.nullable(), + hrfStructure: z.number().min(-10).max(10), + hrfTileSize: z.number().int().min(8), + hrfTileOverlap: z.number().int().min(8), iterations: z.number(), scheduler: zParameterScheduler, fluxScheduler: zParameterFluxScheduler, @@ -840,7 +850,7 @@ export const zParamsState = z.object({ }); export type ParamsState = z.infer; export const getInitialParamsState = (): ParamsState => ({ - _version: 3, + _version: 4, maskBlur: 16, maskBlurMethod: 'box', canvasCoherenceMode: 'Gaussian Blur', @@ -856,9 +866,15 @@ export const getInitialParamsState = (): ParamsState => ({ img2imgStrength: 0.75, optimizedDenoisingEnabled: true, hrfEnabled: false, + hrfMethod: 'latent', hrfScale: 2, hrfStrength: 0.45, hrfLatentInterpolationMode: 'bicubic', + hrfUpscaleModel: null, + hrfTileControlNetModel: null, + hrfStructure: 0, + hrfTileSize: 1024, + hrfTileOverlap: 128, iterations: 1, scheduler: 'dpmpp_3m_k', fluxScheduler: 'euler', diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 3ef175459ca..579af481627 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -50,10 +50,15 @@ export const ImageMetadataActions = memo((props: Props) => { - + + + + + + diff --git a/invokeai/frontend/web/src/features/metadata/parsing.tsx b/invokeai/frontend/web/src/features/metadata/parsing.tsx index dd6d580fecf..7da59ac2d9c 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.tsx @@ -38,8 +38,14 @@ import { setGuidance, setHrfEnabled, setHrfLatentInterpolationMode, + setHrfMethod, setHrfScale, setHrfStrength, + setHrfStructure, + setHrfTileControlNetModel, + setHrfTileOverlap, + setHrfTileSize, + setHrfUpscaleModel, setImg2imgStrength, setRefinerCFGScale, setRefinerNegativeAestheticScore, @@ -67,6 +73,7 @@ import { refImagesRecalled } from 'features/controlLayers/store/refImagesSlice'; import type { CanvasMetadata, HrfLatentInterpolationMode, + HrfMethod as HrfMethodType, LoRA, RefImageState, } from 'features/controlLayers/store/types'; @@ -74,6 +81,7 @@ import { zCanvasMetadata, zCanvasReferenceImageState_OLD, zHrfLatentInterpolationMode, + zHrfMethod, zRefImageState, } from 'features/controlLayers/store/types'; import type { ModelIdentifierField, ModelType } from 'features/nodes/types/common'; @@ -641,17 +649,20 @@ const HrfEnabled: SingleMetadataHandler = { ValueComponent: ({ value }: SingleMetadataValueProps) => , }; -const HrfMethod: UnrecallableMetadataHandler = { - [UnrecallableMetadataKey]: true, +const HrfMethod: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'HrfMethod', parse: (metadata, _store) => { const raw = getProperty(metadata, 'hrf_method'); - const parsed = z.string().parse(raw); + const parsed = zHrfMethod.parse(raw); return Promise.resolve(parsed); }, + recall: (value, store) => { + store.dispatch(setHrfMethod(value)); + }, i18nKey: 'hrf.metadata.method', LabelComponent: MetadataLabel, - ValueComponent: ({ value }: UnrecallableMetadataValueProps) => , + ValueComponent: ({ value }: SingleMetadataValueProps) => , }; const HrfStrength: SingleMetadataHandler = { @@ -703,6 +714,88 @@ const HrfLatentInterpolationModeMetadata: SingleMetadataHandler ), }; + +const HrfUpscaleModel: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfUpscaleModel', + parse: (metadata, store) => { + const raw = getProperty(metadata, 'hrf_upscale_model'); + return parseModelIdentifier(raw, store, 'spandrel_image_to_image'); + }, + recall: (value, store) => { + store.dispatch(setHrfUpscaleModel(value)); + }, + i18nKey: 'hrf.metadata.upscaleModel', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => ( + + ), +}; + +const HrfTileControlNetModel: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfTileControlNetModel', + parse: (metadata, store) => { + const raw = getProperty(metadata, 'hrf_tile_controlnet_model'); + return parseModelIdentifier(raw, store, 'controlnet'); + }, + recall: (value, store) => { + store.dispatch(setHrfTileControlNetModel(value)); + }, + i18nKey: 'hrf.metadata.tileControlNetModel', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => ( + + ), +}; + +const HrfStructure: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfStructure', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_structure'); + const parsed = z.number().min(-10).max(10).parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfStructure(value)); + }, + i18nKey: 'hrf.metadata.structure', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfTileSize: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfTileSize', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_tile_size'); + const parsed = z.number().int().min(8).parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfTileSize(value)); + }, + i18nKey: 'hrf.metadata.tileSize', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfTileOverlap: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfTileOverlap', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_tile_overlap'); + const parsed = z.number().int().min(8).parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfTileOverlap(value)); + }, + i18nKey: 'hrf.metadata.tileOverlap', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; //#endregion High Resolution Fix //#region SeamlessX @@ -1621,6 +1714,11 @@ export const ImageMetadataHandlers = { HrfStrength, HrfScale, HrfLatentInterpolationMode: HrfLatentInterpolationModeMetadata, + HrfUpscaleModel, + HrfTileControlNetModel, + HrfStructure, + HrfTileSize, + HrfTileOverlap, SeamlessX, SeamlessY, RefinerModel, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts index 68f56fd77d7..49a182fba11 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts @@ -1,19 +1,43 @@ import type { RootState } from 'app/store/store'; +import type { HrfMethod } from 'features/controlLayers/store/types'; import { describe, expect, it } from 'vitest'; import { addHighResFix } from './addHighResFix'; import { Graph } from './Graph'; -const buildState = (overrides?: { base?: string; hrfEnabled?: boolean; refinerModel?: unknown }): RootState => +const buildState = (overrides?: { + base?: string; + hrfEnabled?: boolean; + hrfMethod?: HrfMethod; + refinerModel?: unknown; +}): RootState => ({ ui: { activeTab: 'generate' }, params: { - model: { key: 'model', name: 'model', base: overrides?.base ?? 'sdxl', type: 'main' }, + model: { key: 'model', hash: 'model-hash', name: 'model', base: overrides?.base ?? 'sdxl', type: 'main' }, dimensions: { width: 512, height: 512 }, hrfEnabled: overrides?.hrfEnabled ?? true, + hrfMethod: overrides?.hrfMethod ?? 'latent', hrfScale: 2, hrfStrength: 0.35, hrfLatentInterpolationMode: 'bilinear', + hrfUpscaleModel: { + key: 'upscale', + hash: 'upscale-hash', + name: 'upscale', + base: 'any', + type: 'spandrel_image_to_image', + }, + hrfTileControlNetModel: { + key: 'tile', + hash: 'tile-hash', + name: 'tile', + base: overrides?.base ?? 'sdxl', + type: 'controlnet', + }, + hrfStructure: 0, + hrfTileSize: 1024, + hrfTileOverlap: 128, optimizedDenoisingEnabled: true, refinerModel: overrides?.refinerModel ?? null, }, @@ -23,6 +47,11 @@ const buildClassicGraph = () => { const g = new Graph('test_graph'); const seed = g.addNode({ id: 'seed', type: 'integer' }); const noise = g.addNode({ id: 'noise', type: 'noise', use_cpu: true, width: 512, height: 512 }); + const modelLoader = g.addNode({ + id: 'model_loader', + type: 'sdxl_model_loader', + model: { key: 'model', hash: 'model-hash', name: 'model', base: 'sdxl', type: 'main' }, + }); const denoise = g.addNode({ id: 'denoise', type: 'denoise_latents', @@ -34,9 +63,11 @@ const buildClassicGraph = () => { g.addEdge(seed, 'value', noise, 'seed'); g.addEdge(noise, 'noise', denoise, 'noise'); + g.addEdge(modelLoader, 'unet', denoise, 'unet'); + g.addEdge(modelLoader, 'vae', l2i, 'vae'); g.addEdge(denoise, 'latents', l2i, 'latents'); - return { g, seed, noise, denoise, l2i }; + return { g, seed, noise, modelLoader, denoise, l2i }; }; const addSDXLConditioning = (g: Graph, denoise: ReturnType['denoise']) => { @@ -194,4 +225,120 @@ describe('addHighResFix', () => { destination: { node_id: hrfDenoise.id, field: 'positive_conditioning' }, }); }); + + it('reroutes SDXL txt2img graphs through an upscale model, tiled encode, and tiled second pass', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + + addHighResFix({ + g, + state: buildState({ hrfMethod: 'upscale_model' }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const intermediateL2i = nodes.find((node) => node.id.startsWith('hrf_intermediate_l2i')); + const spandrel = nodes.find((node) => node.type === 'spandrel_image_to_image_autoscale'); + const unsharp = nodes.find((node) => node.type === 'unsharp_mask'); + const i2l = nodes.find((node) => node.id.startsWith('hrf_i2l')); + const tiledDenoise = nodes.find((node) => node.type === 'tiled_multi_diffusion_denoise_latents'); + + if (!intermediateL2i || !spandrel || !unsharp || !i2l || !tiledDenoise) { + throw new Error('Expected upscale-model HRF nodes'); + } + + expect(nodes.some((node) => node.type === 'lresize')).toBe(false); + expect(intermediateL2i).toMatchObject({ type: 'l2i', is_intermediate: true }); + expect(spandrel).toMatchObject({ + type: 'spandrel_image_to_image_autoscale', + image_to_image_model: { key: 'upscale' }, + scale: 2, + tile_size: 1024, + fit_to_multiple_of_8: true, + }); + expect(i2l).toMatchObject({ type: 'i2l', tiled: true, tile_size: 1024 }); + expect(l2i).toMatchObject({ type: 'l2i', tiled: true, tile_size: 1024 }); + expect(tiledDenoise).toMatchObject({ + type: 'tiled_multi_diffusion_denoise_latents', + tile_height: 1024, + tile_width: 1024, + tile_overlap: 128, + denoising_start: 0.65, + denoising_end: 1, + }); + expect(graph.edges).not.toContainEqual({ + source: { node_id: 'denoise', field: 'latents' }, + destination: { node_id: 'l2i', field: 'latents' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: 'denoise', field: 'latents' }, + destination: { node_id: intermediateL2i.id, field: 'latents' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: tiledDenoise.id, field: 'latents' }, + destination: { node_id: 'l2i', field: 'latents' }, + }); + expect(g.getMetadataNode()).toMatchObject({ + hrf_enabled: true, + hrf_method: 'upscale_model', + hrf_upscale_model: { key: 'upscale' }, + hrf_tile_controlnet_model: { key: 'tile' }, + hrf_structure: 0, + hrf_tile_size: 1024, + hrf_tile_overlap: 128, + }); + }); + + it('uses a regular second denoise for upscale-model HRF when reference image adapters are connected', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + + const ipAdapter = g.addNode({ + id: 'ip_adapter', + type: 'ip_adapter', + weight: 1, + method: 'full', + ip_adapter_model: { key: 'ip', hash: 'ip-hash', name: 'ip', base: 'sdxl', type: 'ip_adapter' }, + clip_vision_model: 'ViT-H', + begin_step_percent: 0, + end_step_percent: 1, + image: { image_name: 'test' }, + }); + const ipAdapterCollector = g.addNode({ id: 'ip_adapter_collector', type: 'collect' }); + g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollector, 'item'); + g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); + + addHighResFix({ + g, + state: buildState({ hrfMethod: 'upscale_model' }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const classicHrfDenoise = nodes.find((node) => node.id.startsWith('hrf_denoise_latents')); + + if (!classicHrfDenoise) { + throw new Error('Expected classic HRF denoise node'); + } + + expect(nodes.some((node) => node.type === 'tiled_multi_diffusion_denoise_latents')).toBe(false); + expect(graph.edges).toContainEqual({ + source: { node_id: ipAdapterCollector.id, field: 'collection' }, + destination: { node_id: classicHrfDenoise.id, field: 'ip_adapter' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: classicHrfDenoise.id, field: 'latents' }, + destination: { node_id: 'l2i', field: 'latents' }, + }); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts index 92760ad7e1d..120e818a1f7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts @@ -8,7 +8,7 @@ import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { DenoiseLatentsNodes, LatentToImageNodes } from 'features/nodes/util/graph/types'; import { getGridSize } from 'features/parameters/util/optimalDimension'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; -import type { AnyInvocationInputField, AnyInvocationOutputField, Invocation } from 'services/api/types'; +import type { AnyInvocation, AnyInvocationInputField, AnyInvocationOutputField, Invocation } from 'services/api/types'; import { assert } from 'tsafe'; type AddHighResFixArg = { @@ -22,6 +22,8 @@ type AddHighResFixArg = { }; const SKIPPED_DENOISE_INPUT_FIELDS = new Set(['latents', 'noise']); +const TILED_DENOISE_INPUT_FIELDS = new Set(['positive_conditioning', 'negative_conditioning', 'unet']); +const SKIPPED_LATENT_TO_IMAGE_INPUT_FIELDS = new Set(['latents', 'metadata']); const getHighResFixFinalDimensions = (state: RootState) => { const params = selectParamsSlice(state); @@ -154,14 +156,18 @@ const cloneSDXLConditioningForFinalDimensions = ( const copyDenoiseInputs = ( g: Graph, - from: Invocation, - to: Invocation, - finalDimensions: { width: number; height: number } + from: AnyInvocation, + to: AnyInvocation, + finalDimensions: { width: number; height: number }, + allowedInputFields?: Set ) => { for (const edge of g.getEdgesTo(from)) { if (SKIPPED_DENOISE_INPUT_FIELDS.has(edge.destination.field)) { continue; } + if (allowedInputFields && !allowedInputFields.has(edge.destination.field)) { + continue; + } const finalSizeConditioning = ['positive_conditioning', 'negative_conditioning'].includes(edge.destination.field) ? cloneSDXLConditioningForFinalDimensions(g, edge.source.node_id, finalDimensions) @@ -179,24 +185,78 @@ const copyDenoiseInputs = ( } }; -export const addHighResFix = ({ +const copyInputEdges = (g: Graph, from: AnyInvocation, to: AnyInvocation, skippedInputFields: Set) => { + for (const edge of g.getEdgesTo(from)) { + if (skippedInputFields.has(edge.destination.field)) { + continue; + } + g.addEdgeFromObj({ + source: { ...edge.source }, + destination: { node_id: to.id, field: edge.destination.field }, + }); + } +}; + +const hasUnsupportedTiledDenoiseInputs = (g: Graph, denoise: Invocation<'denoise_latents'>) => { + return g.getEdgesTo(denoise).some((edge) => { + const field = edge.destination.field; + return !SKIPPED_DENOISE_INPUT_FIELDS.has(field) && !TILED_DENOISE_INPUT_FIELDS.has(field); + }); +}; + +const addTileControlNets = ( + g: Graph, + hrfDenoise: AnyInvocation, + imageSource: Invocation<'unsharp_mask'>, + tileControlNetModel: NonNullable['hrfTileControlNetModel']>, + structure: number +) => { + const controlNet1 = g.addNode({ + id: getPrefixedId('hrf_controlnet_1'), + type: 'controlnet', + control_model: tileControlNetModel, + control_mode: 'balanced', + resize_mode: 'just_resize', + control_weight: (structure + 10) * 0.0325 + 0.3, + begin_step_percent: 0, + end_step_percent: (structure + 10) * 0.025 + 0.3, + }); + + const controlNet2 = g.addNode({ + id: getPrefixedId('hrf_controlnet_2'), + type: 'controlnet', + control_model: tileControlNetModel, + control_mode: 'balanced', + resize_mode: 'just_resize', + control_weight: ((structure + 10) * 0.0325 + 0.15) * 0.45, + begin_step_percent: (structure + 10) * 0.025 + 0.3, + end_step_percent: 0.85, + }); + + const controlNetCollector = g.addNode({ + type: 'collect', + id: getPrefixedId('hrf_controlnet_collector'), + }); + + g.addEdge(imageSource, 'image', controlNet1, 'image'); + g.addEdge(imageSource, 'image', controlNet2, 'image'); + g.addEdge(controlNet1, 'control', controlNetCollector, 'item'); + g.addEdge(controlNet2, 'control', controlNetCollector, 'item'); + g.addEdgeFromObj({ + source: { node_id: controlNetCollector.id, field: 'collection' }, + destination: { node_id: hrfDenoise.id, field: 'control' }, + }); +}; + +const addLatentHighResFix = ({ g, state, - generationMode, denoise, l2i, noise, seed, }: AddHighResFixArg): Invocation => { const params = selectParamsSlice(state); - - if (!shouldApplyHighResFix(state, generationMode)) { - if (shouldWriteDisabledMetadata(state, generationMode)) { - g.upsertMetadata({ hrf_enabled: false }); - } - return l2i; - } - const finalDimensions = getHighResFixFinalDimensions(state); const { denoising_start, denoising_end } = getHighResFixDenoisingStartAndEnd(state); @@ -249,3 +309,151 @@ export const addHighResFix = ({ return l2i; }; + +const addUpscaleModelHighResFix = ({ g, state, denoise, l2i, noise, seed }: AddHighResFixArg): Invocation<'l2i'> => { + const params = selectParamsSlice(state); + const finalDimensions = getHighResFixFinalDimensions(state); + const { denoising_start, denoising_end } = getHighResFixDenoisingStartAndEnd(state); + + assert(params.model?.base === 'sd-1' || params.model?.base === 'sdxl', 'Upscale model HRF supports SD1.5 and SDXL'); + assert(params.hrfUpscaleModel, 'Upscale model HRF requires a Spandrel upscale model'); + assert(params.hrfTileControlNetModel, 'Upscale model HRF requires a tile ControlNet model'); + assert(denoise.type === 'denoise_latents', 'Upscale model HRF requires classic SD denoise latents'); + assert(l2i.type === 'l2i', 'Upscale model HRF requires classic SD latents-to-image'); + assert(noise, 'Upscale model HRF requires a noise node'); + + const intermediateL2i = g.addNode({ + ...l2i, + id: getPrefixedId('hrf_intermediate_l2i'), + is_intermediate: true, + board: undefined, + metadata: undefined, + }); + copyInputEdges(g, l2i, intermediateL2i, SKIPPED_LATENT_TO_IMAGE_INPUT_FIELDS); + + const spandrelAutoscale = g.addNode({ + type: 'spandrel_image_to_image_autoscale', + id: getPrefixedId('hrf_spandrel_autoscale'), + image_to_image_model: params.hrfUpscaleModel, + fit_to_multiple_of_8: true, + scale: params.hrfScale, + tile_size: params.hrfTileSize, + }); + + const unsharpMask = g.addNode({ + type: 'unsharp_mask', + id: getPrefixedId('hrf_unsharp_2'), + radius: 2, + strength: 60, + }); + + const i2l = g.addNode({ + type: 'i2l', + id: getPrefixedId('hrf_i2l'), + fp32: l2i.fp32, + tile_size: params.hrfTileSize, + tiled: true, + }); + copyInputEdges(g, l2i, i2l, SKIPPED_LATENT_TO_IMAGE_INPUT_FIELDS); + g.updateNode(l2i, { tile_size: params.hrfTileSize, tiled: true }); + + const hrfNoise = g.addNode({ + type: 'noise', + id: getPrefixedId('hrf_noise'), + use_cpu: noise.use_cpu, + }); + + const useClassicDenoise = hasUnsupportedTiledDenoiseInputs(g, denoise); + const hrfDenoise = useClassicDenoise + ? g.addNode({ + ...denoise, + id: getPrefixedId('hrf_denoise_latents'), + denoising_start, + denoising_end, + }) + : g.addNode({ + type: 'tiled_multi_diffusion_denoise_latents', + id: getPrefixedId('hrf_tiled_multidiffusion_denoise_latents'), + tile_height: params.hrfTileSize, + tile_width: params.hrfTileSize, + tile_overlap: params.hrfTileOverlap, + steps: denoise.steps, + cfg_scale: denoise.cfg_scale, + cfg_rescale_multiplier: denoise.cfg_rescale_multiplier, + scheduler: denoise.scheduler, + denoising_start, + denoising_end, + }); + + copyDenoiseInputs( + g, + denoise, + hrfDenoise, + finalDimensions, + useClassicDenoise ? undefined : TILED_DENOISE_INPUT_FIELDS + ); + addTileControlNets(g, hrfDenoise, unsharpMask, params.hrfTileControlNetModel, params.hrfStructure); + + g.deleteEdgesTo(l2i, ['latents']); + g.addEdge(denoise, 'latents', intermediateL2i, 'latents'); + g.addEdge(intermediateL2i, 'image', spandrelAutoscale, 'image'); + g.addEdge(spandrelAutoscale, 'image', unsharpMask, 'image'); + g.addEdge(unsharpMask, 'image', i2l, 'image'); + g.addEdge(seed, 'value', hrfNoise, 'seed'); + g.addEdgeFromObj({ + source: { node_id: unsharpMask.id, field: 'width' }, + destination: { node_id: hrfNoise.id, field: 'width' }, + }); + g.addEdgeFromObj({ + source: { node_id: unsharpMask.id, field: 'height' }, + destination: { node_id: hrfNoise.id, field: 'height' }, + }); + g.addEdgeFromObj({ + source: { node_id: hrfNoise.id, field: 'noise' }, + destination: { node_id: hrfDenoise.id, field: 'noise' }, + }); + g.addEdgeFromObj({ + source: { node_id: i2l.id, field: 'latents' }, + destination: { node_id: hrfDenoise.id, field: 'latents' }, + }); + g.addEdgeFromObj({ + source: { node_id: hrfDenoise.id, field: 'latents' }, + destination: { node_id: l2i.id, field: 'latents' }, + }); + + g.upsertMetadata({ + width: finalDimensions.width, + height: finalDimensions.height, + hrf_enabled: true, + hrf_method: 'upscale_model', + hrf_strength: params.hrfStrength, + hrf_scale: params.hrfScale, + hrf_upscale_model: params.hrfUpscaleModel, + hrf_tile_controlnet_model: params.hrfTileControlNetModel, + hrf_structure: params.hrfStructure, + hrf_tile_size: params.hrfTileSize, + hrf_tile_overlap: params.hrfTileOverlap, + }); + g.addEdgeToMetadata(spandrelAutoscale, 'width', 'width'); + g.addEdgeToMetadata(spandrelAutoscale, 'height', 'height'); + + return l2i; +}; + +export const addHighResFix = (arg: AddHighResFixArg): Invocation => { + const { g, state, generationMode, l2i } = arg; + const params = selectParamsSlice(state); + + if (!shouldApplyHighResFix(state, generationMode)) { + if (shouldWriteDisabledMetadata(state, generationMode)) { + g.upsertMetadata({ hrf_enabled: false }); + } + return l2i; + } + + if (params.hrfMethod === 'upscale_model') { + return addUpscaleModelHighResFix(arg); + } + + return addLatentHighResFix(arg); +}; diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts index 10afa43f8df..cd06a052464 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts @@ -59,6 +59,31 @@ const externalModel = { format: 'external_api', } as unknown as MainOrExternalModelConfig; +const sdxlModel = { + key: 'sdxl', + hash: 'h', + name: 'SDXL', + base: 'sdxl', + type: 'main', + format: 'checkpoint', +} as unknown as MainModelConfig; + +const upscaleModel = { + key: 'upscale', + hash: 'h', + name: 'Upscale', + base: 'any', + type: 'spandrel_image_to_image', +}; + +const tileControlNetModel = { + key: 'tile', + hash: 'h', + name: 'Tile ControlNet', + base: 'sdxl', + type: 'controlnet', +}; + const baseDynamicPrompts: DynamicPromptsState = { _version: 1, maxPrompts: 100, @@ -80,6 +105,9 @@ const baseParams = { kleinVaeModel: null, kleinQwen3EncoderModel: null, hrfEnabled: false, + hrfMethod: 'latent', + hrfUpscaleModel: null, + hrfTileControlNetModel: null, refinerModel: null, } as unknown as ParamsState; @@ -90,6 +118,9 @@ const buildGenerateTabArg = (overrides: { kleinVaeModel?: unknown; kleinQwen3EncoderModel?: unknown; hrfEnabled?: boolean; + hrfMethod?: ParamsState['hrfMethod']; + hrfUpscaleModel?: unknown; + hrfTileControlNetModel?: unknown; refinerModel?: unknown; hasFlux2DiffusersVaeSource?: boolean; hasFlux2DiffusersQwen3Source?: boolean; @@ -101,6 +132,9 @@ const buildGenerateTabArg = (overrides: { kleinVaeModel: overrides.kleinVaeModel ?? null, kleinQwen3EncoderModel: overrides.kleinQwen3EncoderModel ?? null, hrfEnabled: overrides.hrfEnabled ?? false, + hrfMethod: overrides.hrfMethod ?? 'latent', + hrfUpscaleModel: overrides.hrfUpscaleModel ?? null, + hrfTileControlNetModel: overrides.hrfTileControlNetModel ?? null, refinerModel: overrides.refinerModel ?? null, } as unknown as ParamsState, refImages: baseRefImages, @@ -159,6 +193,15 @@ const hasHrfExternalReason = (reasons: { content: string }[]) => const hasHrfRefinerReason = (reasons: { content: string }[]) => reasons.some((r) => r.content.includes('hrfRefinerUnsupported')); +const hasHrfUpscaleModelBaseReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('hrfUpscaleModelBaseUnsupported')); + +const hasHrfUpscaleModelMissingReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('hrfUpscaleModelMissing')); + +const hasHrfTileControlNetMissingReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('hrfTileControlNetModelMissing')); + // --- Tests --- describe('FLUX.2 Klein readiness checks – generate tab', () => { @@ -255,6 +298,42 @@ describe('High Resolution Fix readiness checks - generate tab', () => { ); expect(hasHrfRefinerReason(reasons)).toBe(true); }); + + it('errors when upscale-model HRF is enabled for unsupported model bases', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2DiffusersModel, + hrfEnabled: true, + hrfMethod: 'upscale_model', + hrfUpscaleModel: upscaleModel, + hrfTileControlNetModel: tileControlNetModel, + }) + ); + expect(hasHrfUpscaleModelBaseReason(reasons)).toBe(true); + }); + + it('errors when upscale-model HRF is missing required models', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ model: sdxlModel, hrfEnabled: true, hrfMethod: 'upscale_model' }) + ); + expect(hasHrfUpscaleModelMissingReason(reasons)).toBe(true); + expect(hasHrfTileControlNetMissingReason(reasons)).toBe(true); + }); + + it('does not error when upscale-model HRF has required SDXL models', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: sdxlModel, + hrfEnabled: true, + hrfMethod: 'upscale_model', + hrfUpscaleModel: upscaleModel, + hrfTileControlNetModel: tileControlNetModel, + }) + ); + expect(hasHrfUpscaleModelBaseReason(reasons)).toBe(false); + expect(hasHrfUpscaleModelMissingReason(reasons)).toBe(false); + expect(hasHrfTileControlNetMissingReason(reasons)).toBe(false); + }); }); describe('FLUX.2 Klein readiness checks – canvas tab', () => { diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index 99efd3cecf3..c81995a7b78 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -365,6 +365,19 @@ export const getReasonsWhyCannotEnqueueGenerateTab = (arg: { if (params.refinerModel) { reasons.push({ content: i18n.t('parameters.invoke.hrfRefinerUnsupported') }); } + if (params.hrfMethod === 'upscale_model') { + if (model && !isExternalApiModelConfig(model) && !['sd-1', 'sdxl'].includes(model.base)) { + reasons.push({ content: i18n.t('parameters.invoke.hrfUpscaleModelBaseUnsupported') }); + } + if (!params.hrfUpscaleModel) { + reasons.push({ content: i18n.t('parameters.invoke.hrfUpscaleModelMissing') }); + } + if (!params.hrfTileControlNetModel) { + reasons.push({ content: i18n.t('parameters.invoke.hrfTileControlNetModelMissing') }); + } else if (model && !isExternalApiModelConfig(model) && params.hrfTileControlNetModel.base !== model.base) { + reasons.push({ content: i18n.t('parameters.invoke.hrfTileControlNetModelMissing') }); + } + } } return reasons; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx index cee9236ae7c..ea1ec3b1154 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx @@ -1,5 +1,8 @@ import type { ComboboxOnChange } from '@invoke-ai/ui-library'; import { + Box, + Button, + ButtonGroup, Combobox, CompositeNumberInput, CompositeSlider, @@ -9,29 +12,53 @@ import { FormLabel, StandaloneAccordion, Switch, + Tooltip, } from '@invoke-ai/ui-library'; +import { createSelector } from '@reduxjs/toolkit'; import { EMPTY_ARRAY } from 'app/store/constants'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; +import { useModelCombobox } from 'common/hooks/useModelCombobox'; import { + selectBase, selectHrfEnabled, selectHrfFinalDimensions, selectHrfLatentInterpolationMode, + selectHrfMethod, selectHrfScale, selectHrfStrength, + selectHrfStructure, + selectHrfTileControlNetModel, + selectHrfTileOverlap, + selectHrfTileSize, + selectHrfUpscaleModel, selectIsRefinerModelSelected, selectModelSupportsHrf, setHrfEnabled, setHrfLatentInterpolationMode, + setHrfMethod, setHrfScale, setHrfStrength, + setHrfStructure, + setHrfTileControlNetModel, + setHrfTileOverlap, + setHrfTileSize, + setHrfUpscaleModel, } from 'features/controlLayers/store/paramsSlice'; -import { zHrfLatentInterpolationMode } from 'features/controlLayers/store/types'; +import { zHrfLatentInterpolationMode, zHrfMethod } from 'features/controlLayers/store/types'; +import { ModelPicker } from 'features/parameters/components/ModelPicker'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import type { ChangeEvent } from 'react'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; +import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models'; +import { useControlNetModels, useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType'; +import { + type ControlNetModelConfig, + isControlNetModelConfig, + type SpandrelImageToImageModelConfig, +} from 'services/api/types'; const SCALE_CONSTRAINTS = { initial: 2, @@ -53,14 +80,78 @@ const STRENGTH_CONSTRAINTS = { fineStep: 0.01, }; +const STRUCTURE_CONSTRAINTS = { + initial: 0, + sliderMin: -10, + sliderMax: 10, + numberInputMin: -10, + numberInputMax: 10, + coarseStep: 1, + fineStep: 1, +}; + +const TILE_SIZE_CONSTRAINTS = { + initial: 1024, + sliderMin: 512, + sliderMax: 1536, + numberInputMin: 512, + numberInputMax: 1536, + coarseStep: 64, + fineStep: 64, +}; + +const TILE_OVERLAP_CONSTRAINTS = { + initial: 128, + sliderMin: 32, + sliderMax: 256, + numberInputMin: 16, + numberInputMax: 512, + coarseStep: 16, + fineStep: 8, +}; + +const selectHrfTileControlNetModelConfig = createSelector( + selectModelConfigsQuery, + selectHrfTileControlNetModel, + (modelConfigs, modelIdentifierField) => { + if (!modelConfigs.data || !modelIdentifierField) { + return null; + } + const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, modelIdentifierField.key); + if (!modelConfig || !isControlNetModelConfig(modelConfig)) { + return null; + } + return modelConfig; + } +); + const selectBadges = createMemoizedSelector( - [selectHrfEnabled, selectHrfScale, selectHrfStrength, selectHrfFinalDimensions], - (enabled, scale, strength, finalDimensions) => { + [ + selectHrfEnabled, + selectHrfMethod, + selectHrfScale, + selectHrfStrength, + selectHrfFinalDimensions, + selectHrfUpscaleModel, + ], + (enabled, method, scale, strength, finalDimensions, upscaleModel) => { if (!enabled) { return EMPTY_ARRAY; } - return [`${scale}x`, `${Math.round(strength * 100)}%`, `${finalDimensions.width}x${finalDimensions.height}`]; + const methodBadge = method === 'upscale_model' ? 'Model' : 'Latent'; + const badges = [ + methodBadge, + `${scale}x`, + `${Math.round(strength * 100)}%`, + `${finalDimensions.width}x${finalDimensions.height}`, + ]; + + if (method === 'upscale_model' && upscaleModel) { + badges.push(upscaleModel.name); + } + + return badges; } ); @@ -88,6 +179,38 @@ const ParamHrfEnabled = memo(() => { ParamHrfEnabled.displayName = 'ParamHrfEnabled'; +const ParamHrfMethod = memo(() => { + const dispatch = useAppDispatch(); + const method = useAppSelector(selectHrfMethod); + const { t } = useTranslation(); + + const onClickLatent = useCallback(() => { + dispatch(setHrfMethod('latent')); + }, [dispatch]); + + const onClickUpscaleModel = useCallback(() => { + dispatch(setHrfMethod('upscale_model')); + }, [dispatch]); + + return ( + + + {t('hrf.upscaleMethod')} + + + + + + + ); +}); + +ParamHrfMethod.displayName = 'ParamHrfMethod'; + const ParamHrfScale = memo(() => { const dispatch = useAppDispatch(); const scale = useAppSelector(selectHrfScale); @@ -213,10 +336,254 @@ const ParamHrfLatentInterpolationMode = memo(() => { ParamHrfLatentInterpolationMode.displayName = 'ParamHrfLatentInterpolationMode'; +const ParamHrfUpscaleModel = memo(() => { + const { t } = useTranslation(); + const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels(); + const model = useAppSelector(selectHrfUpscaleModel); + const dispatch = useAppDispatch(); + + const tooltipLabel = useMemo(() => { + if (!modelConfigs.length || !model) { + return; + } + return modelConfigs.find((m) => m.key === model.key)?.description; + }, [modelConfigs, model]); + + const _onChange = useCallback( + (v: SpandrelImageToImageModelConfig | null) => { + dispatch(setHrfUpscaleModel(v)); + }, + [dispatch] + ); + + const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({ + modelConfigs, + onChange: _onChange, + selectedModel: model, + isLoading, + }); + + return ( + + + {t('upscaling.upscaleModel')} + + + + + + + + + + ); +}); + +ParamHrfUpscaleModel.displayName = 'ParamHrfUpscaleModel'; + +const ParamHrfTileControlNetModel = memo(() => { + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + const tileControlNetModel = useAppSelector(selectHrfTileControlNetModelConfig); + const currentBaseModel = useAppSelector(selectBase); + const [modelConfigs, { isLoading }] = useControlNetModels(); + + const onChange = useCallback( + (controlNetModel: ControlNetModelConfig) => { + dispatch(setHrfTileControlNetModel(controlNetModel)); + }, + [dispatch] + ); + + const filteredModelConfigs = useMemo(() => { + if (!currentBaseModel || !['sd-1', 'sdxl'].includes(currentBaseModel)) { + return []; + } + return modelConfigs.filter((model) => { + const isCompatible = model.base === currentBaseModel; + const isTileOrMultiModel = + model.name.toLowerCase().includes('tile') || model.name.toLowerCase().includes('union'); + return isCompatible && isTileOrMultiModel; + }); + }, [modelConfigs, currentBaseModel]); + + const getIsOptionDisabled = useCallback( + (model: ControlNetModelConfig): boolean => { + return currentBaseModel !== model.base; + }, + [currentBaseModel] + ); + + return ( + + + {t('upscaling.tileControl')} + + + + ); +}); + +ParamHrfTileControlNetModel.displayName = 'ParamHrfTileControlNetModel'; + +const ParamHrfStructure = memo(() => { + const dispatch = useAppDispatch(); + const structure = useAppSelector(selectHrfStructure); + const { t } = useTranslation(); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfStructure(v)); + }, + [dispatch] + ); + + return ( + + + {t('upscaling.structure')} + + + + + ); +}); + +ParamHrfStructure.displayName = 'ParamHrfStructure'; + +const ParamHrfTileSize = memo(() => { + const dispatch = useAppDispatch(); + const tileSize = useAppSelector(selectHrfTileSize); + const { t } = useTranslation(); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfTileSize(v)); + }, + [dispatch] + ); + + return ( + + + {t('upscaling.tileSize')} + + + + + ); +}); + +ParamHrfTileSize.displayName = 'ParamHrfTileSize'; + +const ParamHrfTileOverlap = memo(() => { + const dispatch = useAppDispatch(); + const tileOverlap = useAppSelector(selectHrfTileOverlap); + const { t } = useTranslation(); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfTileOverlap(v)); + }, + [dispatch] + ); + + return ( + + + {t('upscaling.tileOverlap')} + + + + + ); +}); + +ParamHrfTileOverlap.displayName = 'ParamHrfTileOverlap'; + export const HighResFixSettingsAccordion = memo(() => { const { t } = useTranslation(); const badges = useAppSelector(selectBadges); const enabled = useAppSelector(selectHrfEnabled); + const method = useAppSelector(selectHrfMethod); const modelSupportsHrf = useAppSelector(selectModelSupportsHrf); const isRefinerModelSelected = useAppSelector(selectIsRefinerModelSelected); const { isOpen, onToggle } = useStandaloneAccordionToggle({ @@ -224,6 +591,8 @@ export const HighResFixSettingsAccordion = memo(() => { defaultIsOpen: false, }); + const parsedMethod = zHrfMethod.parse(method); + if (!modelSupportsHrf || isRefinerModelSelected) { return null; } @@ -234,9 +603,20 @@ export const HighResFixSettingsAccordion = memo(() => { {enabled && ( + - + {parsedMethod === 'latent' ? ( + + ) : ( + <> + + + + + + + )} )} diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index cebd7e9cca4..b6e4716994b 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -7693,7 +7693,7 @@ export type components = { hrf_strength?: number | null; /** * Hrf Scale - * @description The high resolution fix latent upscale factor. + * @description The high resolution fix upscale factor. * @default null */ hrf_scale?: number | null; @@ -7703,6 +7703,36 @@ export type components = { * @default null */ hrf_latent_interpolation_mode?: string | null; + /** + * Hrf Upscale Model + * @description The Spandrel upscale model used in the high resolution fix upscale pass. + * @default null + */ + hrf_upscale_model?: components["schemas"]["ModelIdentifierField"] | null; + /** + * Hrf Tile Controlnet Model + * @description The tile ControlNet model used in the high resolution fix upscale pass. + * @default null + */ + hrf_tile_controlnet_model?: components["schemas"]["ModelIdentifierField"] | null; + /** + * Hrf Structure + * @description The high resolution fix tile ControlNet structure value. + * @default null + */ + hrf_structure?: number | null; + /** + * Hrf Tile Size + * @description The high resolution fix tiled processing tile size. + * @default null + */ + hrf_tile_size?: number | null; + /** + * Hrf Tile Overlap + * @description The high resolution fix tiled processing tile overlap. + * @default null + */ + hrf_tile_overlap?: number | null; /** * Positive Style Prompt * @description The positive style prompt parameter From 086af093ffb074cd92515d2546f3760cf5b9b24b Mon Sep 17 00:00:00 2001 From: Astra orion <13394741+AsuraAce@users.noreply.github.com> Date: Mon, 4 May 2026 15:25:48 +0200 Subject: [PATCH 3/9] fix: tune highres fix tile control --- invokeai/app/invocations/metadata.py | 4 + invokeai/frontend/web/public/locales/en.json | 2 + .../controlLayers/store/paramsSlice.ts | 11 +++ .../src/features/controlLayers/store/types.ts | 6 +- .../web/src/features/metadata/parsing.tsx | 18 ++++ .../graph/generation/addHighResFix.test.ts | 21 ++++- .../util/graph/generation/addHighResFix.ts | 91 +++++++++++++------ .../HighResFixSettingsAccordion.tsx | 59 ++++++++++++ .../frontend/web/src/services/api/schema.ts | 6 ++ 9 files changed, 185 insertions(+), 33 deletions(-) diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 4a0a023e355..ad0dd98b075 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -272,6 +272,10 @@ class CoreMetadataInvocation(BaseInvocation): default=None, description="The high resolution fix tile ControlNet structure value.", ) + hrf_tile_control_end: Optional[float] = InputField( + default=None, + description="The high resolution fix tile ControlNet end step percentage.", + ) hrf_tile_size: Optional[int] = InputField( default=None, description="The high resolution fix tiled processing tile size.", diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 3074c0b13e2..74adf28b9dd 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -356,6 +356,7 @@ "strength": "Denoise Strength", "upscaleMethod": "Upscale Method", "latentInterpolationMode": "Latent Interpolation", + "tileControlEnd": "Tile Control Duration", "latent": "Latent", "upscaleModelMethod": "Upscale Model", "nearest": "Nearest", @@ -372,6 +373,7 @@ "upscaleModel": "High Resolution Fix Upscale Model", "tileControlNetModel": "High Resolution Fix Tile ControlNet", "structure": "High Resolution Fix Structure", + "tileControlEnd": "High Resolution Fix Tile Control Duration", "tileSize": "High Resolution Fix Tile Size", "tileOverlap": "High Resolution Fix Tile Overlap" } diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index e6b93003ff6..f8ccfd41231 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -153,6 +153,9 @@ const slice = createSlice({ setHrfStructure: (state, action: PayloadAction) => { state.hrfStructure = action.payload; }, + setHrfTileControlEnd: (state, action: PayloadAction) => { + state.hrfTileControlEnd = action.payload; + }, setHrfTileSize: (state, action: PayloadAction) => { state.hrfTileSize = action.payload; }, @@ -673,6 +676,7 @@ export const { setHrfUpscaleModel, setHrfTileControlNetModel, setHrfStructure, + setHrfTileControlEnd, setHrfTileSize, setHrfTileOverlap, setSeamlessXAxis, @@ -776,6 +780,12 @@ export const paramsSliceConfig: SliceConfig = { state.hrfTileOverlap = 128; } + if (state._version === 4) { + // v4 -> v5, add explicit Generate tab HRF Tile ControlNet timing + state._version = 5; + state.hrfTileControlEnd = 0.2; + } + return zParamsState.parse(state); }, }, @@ -850,6 +860,7 @@ export const selectHrfLatentInterpolationMode = createParamsSelector((params) => export const selectHrfUpscaleModel = createParamsSelector((params) => params.hrfUpscaleModel); export const selectHrfTileControlNetModel = createParamsSelector((params) => params.hrfTileControlNetModel); export const selectHrfStructure = createParamsSelector((params) => params.hrfStructure); +export const selectHrfTileControlEnd = createParamsSelector((params) => params.hrfTileControlEnd); export const selectHrfTileSize = createParamsSelector((params) => params.hrfTileSize); export const selectHrfTileOverlap = createParamsSelector((params) => params.hrfTileOverlap); export const selectPositivePrompt = createParamsSelector((params) => params.positivePrompt); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 1901cb036c3..2124017067d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -756,7 +756,7 @@ export const zHrfMethod = z.enum(['latent', 'upscale_model']); export type HrfMethod = z.infer; export const zParamsState = z.object({ - _version: z.literal(4), + _version: z.literal(5), maskBlur: z.number(), maskBlurMethod: zParameterMaskBlurMethod, canvasCoherenceMode: zParameterCanvasCoherenceMode, @@ -779,6 +779,7 @@ export const zParamsState = z.object({ hrfUpscaleModel: zParameterSpandrelImageToImageModel.nullable(), hrfTileControlNetModel: zModelIdentifierField.nullable(), hrfStructure: z.number().min(-10).max(10), + hrfTileControlEnd: z.number().min(0).max(1), hrfTileSize: z.number().int().min(8), hrfTileOverlap: z.number().int().min(8), iterations: z.number(), @@ -850,7 +851,7 @@ export const zParamsState = z.object({ }); export type ParamsState = z.infer; export const getInitialParamsState = (): ParamsState => ({ - _version: 4, + _version: 5, maskBlur: 16, maskBlurMethod: 'box', canvasCoherenceMode: 'Gaussian Blur', @@ -873,6 +874,7 @@ export const getInitialParamsState = (): ParamsState => ({ hrfUpscaleModel: null, hrfTileControlNetModel: null, hrfStructure: 0, + hrfTileControlEnd: 0.2, hrfTileSize: 1024, hrfTileOverlap: 128, iterations: 1, diff --git a/invokeai/frontend/web/src/features/metadata/parsing.tsx b/invokeai/frontend/web/src/features/metadata/parsing.tsx index 7da59ac2d9c..db3bda626d6 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.tsx @@ -42,6 +42,7 @@ import { setHrfScale, setHrfStrength, setHrfStructure, + setHrfTileControlEnd, setHrfTileControlNetModel, setHrfTileOverlap, setHrfTileSize, @@ -765,6 +766,22 @@ const HrfStructure: SingleMetadataHandler = { ValueComponent: ({ value }: SingleMetadataValueProps) => , }; +const HrfTileControlEnd: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfTileControlEnd', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_tile_control_end'); + const parsed = z.number().min(0).max(1).parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfTileControlEnd(value)); + }, + i18nKey: 'hrf.metadata.tileControlEnd', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + const HrfTileSize: SingleMetadataHandler = { [SingleMetadataKey]: true, type: 'HrfTileSize', @@ -1717,6 +1734,7 @@ export const ImageMetadataHandlers = { HrfUpscaleModel, HrfTileControlNetModel, HrfStructure, + HrfTileControlEnd, HrfTileSize, HrfTileOverlap, SeamlessX, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts index 49a182fba11..8c902bd2a54 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts @@ -36,6 +36,7 @@ const buildState = (overrides?: { type: 'controlnet', }, hrfStructure: 0, + hrfTileControlEnd: 0.2, hrfTileSize: 1024, hrfTileOverlap: 128, optimizedDenoisingEnabled: true, @@ -247,8 +248,9 @@ describe('addHighResFix', () => { const unsharp = nodes.find((node) => node.type === 'unsharp_mask'); const i2l = nodes.find((node) => node.id.startsWith('hrf_i2l')); const tiledDenoise = nodes.find((node) => node.type === 'tiled_multi_diffusion_denoise_latents'); + const tileControlNet = nodes.find((node) => node.id.startsWith('hrf_controlnet')); - if (!intermediateL2i || !spandrel || !unsharp || !i2l || !tiledDenoise) { + if (!intermediateL2i || !spandrel || !unsharp || !i2l || !tiledDenoise || !tileControlNet) { throw new Error('Expected upscale-model HRF nodes'); } @@ -271,6 +273,22 @@ describe('addHighResFix', () => { denoising_start: 0.65, denoising_end: 1, }); + expect(tileControlNet).toMatchObject({ + type: 'controlnet', + begin_step_percent: 0, + end_step_percent: 0.2, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: tileControlNet.id, field: 'control' }, + destination: { node_id: tiledDenoise.id, field: 'control' }, + }); + const positiveConditioningEdge = graph.edges.find( + (edge) => edge.destination.node_id === tiledDenoise.id && edge.destination.field === 'positive_conditioning' + ); + if (!positiveConditioningEdge) { + throw new Error('Expected positive conditioning edge'); + } + expect(graph.nodes[positiveConditioningEdge.source.node_id]?.type).not.toBe('collect'); expect(graph.edges).not.toContainEqual({ source: { node_id: 'denoise', field: 'latents' }, destination: { node_id: 'l2i', field: 'latents' }, @@ -289,6 +307,7 @@ describe('addHighResFix', () => { hrf_upscale_model: { key: 'upscale' }, hrf_tile_controlnet_model: { key: 'tile' }, hrf_structure: 0, + hrf_tile_control_end: 0.2, hrf_tile_size: 1024, hrf_tile_overlap: 128, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts index 120e818a1f7..50850d636b4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts @@ -111,7 +111,8 @@ const cloneSDXLCompelPromptForFinalDimensions = ( const cloneSDXLConditioningForFinalDimensions = ( g: Graph, sourceNodeId: string, - finalDimensions: { width: number; height: number } + finalDimensions: { width: number; height: number }, + outputType: 'collection' | 'single' ) => { const sourceNode = g.getNode(sourceNodeId); @@ -130,9 +131,28 @@ const cloneSDXLConditioningForFinalDimensions = ( const hasSDXLConditioning = itemEdges.some((edge) => g.getNode(edge.source.node_id).type === 'sdxl_compel_prompt'); if (!hasSDXLConditioning) { + if (outputType === 'single' && itemEdges.length === 1) { + return { nodeId: itemEdges[0]!.source.node_id, field: itemEdges[0]!.source.field }; + } + return null; } + if (outputType === 'single') { + if (itemEdges.length !== 1) { + return null; + } + + const edge = itemEdges[0]!; + const itemNode = g.getNode(edge.source.node_id); + const source = + itemNode.type === 'sdxl_compel_prompt' + ? cloneSDXLCompelPromptForFinalDimensions(g, itemNode, finalDimensions) + : itemNode; + + return { nodeId: source.id, field: edge.source.field }; + } + const collect = g.addNode({ type: 'collect', id: getPrefixedId('hrf_sdxl_conditioning_collect'), @@ -170,7 +190,12 @@ const copyDenoiseInputs = ( } const finalSizeConditioning = ['positive_conditioning', 'negative_conditioning'].includes(edge.destination.field) - ? cloneSDXLConditioningForFinalDimensions(g, edge.source.node_id, finalDimensions) + ? cloneSDXLConditioningForFinalDimensions( + g, + edge.source.node_id, + finalDimensions, + to.type === 'tiled_multi_diffusion_denoise_latents' ? 'single' : 'collection' + ) : null; g.addEdgeFromObj({ @@ -200,50 +225,48 @@ const copyInputEdges = (g: Graph, from: AnyInvocation, to: AnyInvocation, skippe const hasUnsupportedTiledDenoiseInputs = (g: Graph, denoise: Invocation<'denoise_latents'>) => { return g.getEdgesTo(denoise).some((edge) => { const field = edge.destination.field; - return !SKIPPED_DENOISE_INPUT_FIELDS.has(field) && !TILED_DENOISE_INPUT_FIELDS.has(field); + if (SKIPPED_DENOISE_INPUT_FIELDS.has(field)) { + return false; + } + + if (!TILED_DENOISE_INPUT_FIELDS.has(field)) { + return true; + } + + if (['positive_conditioning', 'negative_conditioning'].includes(field)) { + const sourceNode = g.getNode(edge.source.node_id); + if (sourceNode.type === 'collect') { + const itemEdges = g.getEdgesTo(sourceNode).filter((itemEdge) => itemEdge.destination.field === 'item'); + return itemEdges.length !== 1; + } + } + + return false; }); }; -const addTileControlNets = ( +const addTileControlNet = ( g: Graph, hrfDenoise: AnyInvocation, imageSource: Invocation<'unsharp_mask'>, tileControlNetModel: NonNullable['hrfTileControlNetModel']>, - structure: number + structure: number, + tileControlEnd: number ) => { - const controlNet1 = g.addNode({ - id: getPrefixedId('hrf_controlnet_1'), + const controlNet = g.addNode({ + id: getPrefixedId('hrf_controlnet'), type: 'controlnet', control_model: tileControlNetModel, control_mode: 'balanced', resize_mode: 'just_resize', control_weight: (structure + 10) * 0.0325 + 0.3, begin_step_percent: 0, - end_step_percent: (structure + 10) * 0.025 + 0.3, - }); - - const controlNet2 = g.addNode({ - id: getPrefixedId('hrf_controlnet_2'), - type: 'controlnet', - control_model: tileControlNetModel, - control_mode: 'balanced', - resize_mode: 'just_resize', - control_weight: ((structure + 10) * 0.0325 + 0.15) * 0.45, - begin_step_percent: (structure + 10) * 0.025 + 0.3, - end_step_percent: 0.85, - }); - - const controlNetCollector = g.addNode({ - type: 'collect', - id: getPrefixedId('hrf_controlnet_collector'), + end_step_percent: tileControlEnd, }); - g.addEdge(imageSource, 'image', controlNet1, 'image'); - g.addEdge(imageSource, 'image', controlNet2, 'image'); - g.addEdge(controlNet1, 'control', controlNetCollector, 'item'); - g.addEdge(controlNet2, 'control', controlNetCollector, 'item'); + g.addEdge(imageSource, 'image', controlNet, 'image'); g.addEdgeFromObj({ - source: { node_id: controlNetCollector.id, field: 'collection' }, + source: { node_id: controlNet.id, field: 'control' }, destination: { node_id: hrfDenoise.id, field: 'control' }, }); }; @@ -392,7 +415,14 @@ const addUpscaleModelHighResFix = ({ g, state, denoise, l2i, noise, seed }: AddH finalDimensions, useClassicDenoise ? undefined : TILED_DENOISE_INPUT_FIELDS ); - addTileControlNets(g, hrfDenoise, unsharpMask, params.hrfTileControlNetModel, params.hrfStructure); + addTileControlNet( + g, + hrfDenoise, + unsharpMask, + params.hrfTileControlNetModel, + params.hrfStructure, + params.hrfTileControlEnd + ); g.deleteEdgesTo(l2i, ['latents']); g.addEdge(denoise, 'latents', intermediateL2i, 'latents'); @@ -431,6 +461,7 @@ const addUpscaleModelHighResFix = ({ g, state, denoise, l2i, noise, seed }: AddH hrf_upscale_model: params.hrfUpscaleModel, hrf_tile_controlnet_model: params.hrfTileControlNetModel, hrf_structure: params.hrfStructure, + hrf_tile_control_end: params.hrfTileControlEnd, hrf_tile_size: params.hrfTileSize, hrf_tile_overlap: params.hrfTileOverlap, }); diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx index ea1ec3b1154..211d5c8990b 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx @@ -29,6 +29,7 @@ import { selectHrfScale, selectHrfStrength, selectHrfStructure, + selectHrfTileControlEnd, selectHrfTileControlNetModel, selectHrfTileOverlap, selectHrfTileSize, @@ -41,6 +42,7 @@ import { setHrfScale, setHrfStrength, setHrfStructure, + setHrfTileControlEnd, setHrfTileControlNetModel, setHrfTileOverlap, setHrfTileSize, @@ -90,6 +92,16 @@ const STRUCTURE_CONSTRAINTS = { fineStep: 1, }; +const TILE_CONTROL_END_CONSTRAINTS = { + initial: 0.2, + sliderMin: 0, + sliderMax: 1, + numberInputMin: 0, + numberInputMax: 1, + coarseStep: 0.01, + fineStep: 0.01, +}; + const TILE_SIZE_CONSTRAINTS = { initial: 1024, sliderMin: 512, @@ -491,6 +503,52 @@ const ParamHrfStructure = memo(() => { ParamHrfStructure.displayName = 'ParamHrfStructure'; +const ParamHrfTileControlEnd = memo(() => { + const dispatch = useAppDispatch(); + const tileControlEnd = useAppSelector(selectHrfTileControlEnd); + const { t } = useTranslation(); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfTileControlEnd(v)); + }, + [dispatch] + ); + + return ( + + + {t('hrf.tileControlEnd')} + + + + + ); +}); + +ParamHrfTileControlEnd.displayName = 'ParamHrfTileControlEnd'; + const ParamHrfTileSize = memo(() => { const dispatch = useAppDispatch(); const tileSize = useAppSelector(selectHrfTileSize); @@ -613,6 +671,7 @@ export const HighResFixSettingsAccordion = memo(() => { + diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index b6e4716994b..9655f861f15 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -7721,6 +7721,12 @@ export type components = { * @default null */ hrf_structure?: number | null; + /** + * Hrf Tile Control End + * @description The high resolution fix tile ControlNet end step percentage. + * @default null + */ + hrf_tile_control_end?: number | null; /** * Hrf Tile Size * @description The high resolution fix tiled processing tile size. From 202b6d25868e12467642711ddaa63e6a3c118963 Mon Sep 17 00:00:00 2001 From: Astra orion <13394741+AsuraAce@users.noreply.github.com> Date: Tue, 5 May 2026 14:11:38 +0200 Subject: [PATCH 4/9] Add Generate tab HRF controls --- invokeai/app/invocations/metadata.py | 22 +- invokeai/frontend/web/public/locales/en.json | 25 +- .../controlLayers/store/paramsSlice.test.ts | 86 +++ .../controlLayers/store/paramsSlice.ts | 114 +++- .../src/features/controlLayers/store/types.ts | 19 +- .../src/features/metadata/parsing.test.tsx | 59 +- .../web/src/features/metadata/parsing.tsx | 139 ++++- .../graph/generation/addHighResFix.test.ts | 507 +++++++++++++++- .../util/graph/generation/addHighResFix.ts | 385 +++++++++++- .../nodes/util/graph/generation/addLoRAs.ts | 42 +- .../util/graph/generation/addSDXLLoRAs.ts | 48 +- .../features/queue/store/readiness.test.ts | 53 ++ .../web/src/features/queue/store/readiness.ts | 13 +- .../HighResFixSettingsAccordion.tsx | 550 +++++++++++++++--- .../frontend/web/src/services/api/schema.ts | 32 +- 15 files changed, 1956 insertions(+), 138 deletions(-) diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index ad0dd98b075..b51ccc65fc2 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -270,7 +270,11 @@ class CoreMetadataInvocation(BaseInvocation): ) hrf_structure: Optional[float] = InputField( default=None, - description="The high resolution fix tile ControlNet structure value.", + description="Legacy high resolution fix tile ControlNet structure value.", + ) + hrf_tile_control_weight: Optional[float] = InputField( + default=None, + description="The high resolution fix tile ControlNet control weight.", ) hrf_tile_control_end: Optional[float] = InputField( default=None, @@ -284,6 +288,22 @@ class CoreMetadataInvocation(BaseInvocation): default=None, description="The high resolution fix tiled processing tile overlap.", ) + hrf_steps: Optional[int] = InputField( + default=None, + description="The number of steps used for the high resolution fix refinement pass.", + ) + hrf_model: Optional[ModelIdentifierField] = InputField( + default=None, + description="The optional model override used for the high resolution fix refinement pass.", + ) + hrf_lora_mode: Optional[str] = InputField( + default=None, + description="The LoRA mode used for the high resolution fix refinement pass.", + ) + hrf_loras: Optional[list[LoRAMetadataField]] = InputField( + default=None, + description="The dedicated LoRAs used for the high resolution fix refinement pass.", + ) # SDXL positive_style_prompt: Optional[str] = InputField( diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 74adf28b9dd..6dbccdb247a 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -350,13 +350,23 @@ "options_withCount_other": "{{count}} options" }, "hrf": { - "hrf": "High Resolution Fix", + "hrf": "Hi-Res Fix", "enableHrf": "Enable High Resolution Fix", "scale": "Scale", "strength": "Denoise Strength", "upscaleMethod": "Upscale Method", "latentInterpolationMode": "Latent Interpolation", + "tileControlWeight": "Tile Control Weight", "tileControlEnd": "Tile Control Duration", + "steps": "Refinement Steps", + "model": "Refinement Model", + "reuseGenerateModel": "Reuse Base Model", + "loraMode": "Refinement LoRAs", + "reuseGenerateLoras": "Reuse", + "noLoras": "None", + "dedicatedLoras": "Dedicated", + "addDedicatedLora": "Add Dedicated LoRA", + "noCompatibleModels": "No Compatible Models", "latent": "Latent", "upscaleModelMethod": "Upscale Model", "nearest": "Nearest", @@ -372,10 +382,15 @@ "latentInterpolationMode": "High Resolution Fix Latent Interpolation", "upscaleModel": "High Resolution Fix Upscale Model", "tileControlNetModel": "High Resolution Fix Tile ControlNet", - "structure": "High Resolution Fix Structure", + "legacyStructure": "Legacy High Resolution Fix Structure", + "tileControlWeight": "High Resolution Fix Tile Control Weight", "tileControlEnd": "High Resolution Fix Tile Control Duration", "tileSize": "High Resolution Fix Tile Size", - "tileOverlap": "High Resolution Fix Tile Overlap" + "tileOverlap": "High Resolution Fix Tile Overlap", + "steps": "High Resolution Fix Refinement Steps", + "model": "High Resolution Fix Refinement Model", + "loraMode": "High Resolution Fix LoRA Mode", + "loras": "High Resolution Fix Dedicated LoRAs" } }, "prompt": { @@ -1698,6 +1713,10 @@ "hrfExternalModelUnsupported": "High Resolution Fix is not supported for external models", "hrfRefinerUnsupported": "High Resolution Fix is not supported when SDXL Refiner is enabled", "hrfUpscaleModelBaseUnsupported": "High Resolution Fix with an upscale model is supported for SD1.5 and SDXL models only", + "hrfModelOverrideMethodUnsupported": "High Resolution Fix model override is only supported with upscale-model HRF", + "hrfModelOverrideExternalUnsupported": "High Resolution Fix model override does not support external models", + "hrfModelOverrideBaseUnsupported": "High Resolution Fix model override is supported for SD1.5 and SDXL models only", + "hrfModelOverrideBaseMismatch": "High Resolution Fix model override must use the same base as the Generate model", "hrfUpscaleModelMissing": "High Resolution Fix needs an upscale model", "hrfTileControlNetModelMissing": "High Resolution Fix needs a tile ControlNet model", "canvasIsFiltering": "Canvas is busy (filtering)", diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts index 928b89f8d9f..2e057fd0d25 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts @@ -8,6 +8,8 @@ import type { import { describe, expect, it } from 'vitest'; import { + modelChanged, + paramsSliceConfig, selectHrfFinalDimensions, selectModelSupportsDimensions, selectModelSupportsGuidance, @@ -17,7 +19,9 @@ import { selectModelSupportsRefImages, selectModelSupportsSeed, selectModelSupportsSteps, + setHrfMethod, } from './paramsSlice'; +import { getInitialParamsState, type ParamsState } from './types'; const buildExternalModelIdentifier = (config: ExternalApiModelConfig) => ({ @@ -171,3 +175,85 @@ describe('paramsSlice HRF selectors', () => { ).toBe(false); }); }); + +describe('paramsSlice HRF reducers', () => { + it('clears upscale-model-only refinement overrides when switching to latent HRF', () => { + const state = getInitialParamsState(); + state.hrfMethod = 'upscale_model'; + state.hrfSteps = 12; + state.hrfModel = { key: 'hrf-model', hash: 'h', name: 'HRF Model', base: 'sdxl', type: 'main' }; + state.hrfLoraMode = 'dedicated'; + state.hrfLoras = [ + { + id: 'lora', + isEnabled: true, + model: { key: 'lora', hash: 'h', name: 'LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ] as ParamsState['hrfLoras']; + + const nextState = paramsSliceConfig.slice.reducer(state, setHrfMethod('latent')); + + expect(nextState.hrfMethod).toBe('latent'); + expect(nextState.hrfSteps).toBeNull(); + expect(nextState.hrfModel).toBeNull(); + expect(nextState.hrfLoraMode).toBe('reuse_generate'); + expect(nextState.hrfLoras).toEqual([]); + }); + + it('filters dedicated HRF LoRAs when the Generate model base changes', () => { + const state = getInitialParamsState(); + const previousModel = { key: 'sdxl', hash: 'h', name: 'SDXL', base: 'sdxl', type: 'main' } as const; + const nextModel = { key: 'sd1', hash: 'h', name: 'SD1', base: 'sd-1', type: 'main' } as const; + state.model = previousModel; + state.hrfMethod = 'upscale_model'; + state.hrfLoraMode = 'dedicated'; + state.hrfLoras = [ + { + id: 'sdxl-lora', + isEnabled: true, + model: { key: 'sdxl-lora', hash: 'h', name: 'SDXL LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + { + id: 'sd1-lora', + isEnabled: true, + model: { key: 'sd1-lora', hash: 'h', name: 'SD1 LoRA', base: 'sd-1', type: 'lora' }, + weight: 0.7, + }, + ] as ParamsState['hrfLoras']; + + const nextState = paramsSliceConfig.slice.reducer(state, modelChanged({ model: nextModel, previousModel })); + + expect(nextState.hrfLoraMode).toBe('dedicated'); + expect(nextState.hrfLoras).toHaveLength(1); + expect(nextState.hrfLoras[0]?.model.key).toBe('sd1-lora'); + }); +}); + +describe('paramsSlice HRF migrations', () => { + it('migrates legacy HRF Structure to explicit Tile Control Weight', () => { + const v5State = { + ...getInitialParamsState(), + _version: 5, + hrfStructure: 0, + } as unknown as Record; + delete v5State.hrfTileControlWeight; + delete v5State.hrfSteps; + delete v5State.hrfModel; + delete v5State.hrfLoraMode; + delete v5State.hrfLoras; + + const migrated = paramsSliceConfig.persistConfig!.migrate(v5State); + + expect(migrated).toMatchObject({ + _version: 6, + hrfTileControlWeight: 0.625, + hrfSteps: null, + hrfModel: null, + hrfLoraMode: 'reuse_generate', + hrfLoras: [], + }); + expect('hrfStructure' in migrated).toBe(false); + }); +}); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index f8ccfd41231..ab812ca641c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -10,8 +10,10 @@ import { logout } from 'features/auth/store/authSlice'; import type { AspectRatioID, HrfLatentInterpolationMode, + HrfLoraMode, HrfMethod, InfillMethod, + LoRA, ParamsState, RgbaColor, } from 'features/controlLayers/store/types'; @@ -28,7 +30,7 @@ import { SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS, SUPPORTS_REF_IMAGES_BASE_MODELS, } from 'features/modelManagerV2/models'; -import type { BaseModelType, ModelIdentifierField } from 'features/nodes/types/common'; +import { type BaseModelType, type ModelIdentifierField, zModelIdentifierField } from 'features/nodes/types/common'; import { CLIP_SKIP_MAP } from 'features/parameters/types/constants'; import type { ParameterCanvasCoherenceMode, @@ -53,9 +55,14 @@ import type { import { getExternalPanelControl, hasExternalPanelControl } from 'features/parameters/util/externalPanelSchema'; import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension'; import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models'; -import type { AnyModelConfigWithExternal, ControlNetModelConfig } from 'services/api/types'; +import type { AnyModelConfigWithExternal, ControlNetModelConfig, LoRAModelConfig } from 'services/api/types'; import { isExternalApiModelConfig, isNonRefinerMainModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; +import { v4 as uuidv4 } from 'uuid'; + +const DEFAULT_HRF_LORA_WEIGHT = 0.75; + +const selectHrfLoRA = (state: ParamsState, id: string) => state.hrfLoras.find((lora) => lora.id === id); const slice = createSlice({ name: 'params', @@ -128,6 +135,12 @@ const slice = createSlice({ }, setHrfMethod: (state, action: PayloadAction) => { state.hrfMethod = action.payload; + if (action.payload === 'latent') { + state.hrfSteps = null; + state.hrfModel = null; + state.hrfLoraMode = 'reuse_generate'; + state.hrfLoras = []; + } }, setHrfScale: (state, action: PayloadAction) => { state.hrfScale = action.payload; @@ -150,8 +163,8 @@ const slice = createSlice({ state.hrfTileControlNetModel = result.data; } }, - setHrfStructure: (state, action: PayloadAction) => { - state.hrfStructure = action.payload; + setHrfTileControlWeight: (state, action: PayloadAction) => { + state.hrfTileControlWeight = action.payload; }, setHrfTileControlEnd: (state, action: PayloadAction) => { state.hrfTileControlEnd = action.payload; @@ -162,6 +175,57 @@ const slice = createSlice({ setHrfTileOverlap: (state, action: PayloadAction) => { state.hrfTileOverlap = action.payload; }, + setHrfSteps: (state, action: PayloadAction) => { + const result = zParamsState.shape.hrfSteps.safeParse(action.payload); + if (result.success) { + state.hrfSteps = result.data; + } + }, + setHrfModel: (state, action: PayloadAction) => { + const result = zParamsState.shape.hrfModel.safeParse(action.payload); + if (result.success) { + state.hrfModel = result.data; + } + }, + setHrfLoraMode: (state, action: PayloadAction) => { + state.hrfLoraMode = action.payload; + }, + setHrfLoras: (state, action: PayloadAction) => { + state.hrfLoras = action.payload; + }, + hrfLoraAdded: { + reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => { + const { model, id } = action.payload; + const parsedModel = zModelIdentifierField.parse(model); + const defaultLoRAConfig: Pick = { + weight: model.default_settings?.weight ?? DEFAULT_HRF_LORA_WEIGHT, + isEnabled: true, + }; + state.hrfLoras = state.hrfLoras.filter((lora) => lora.model.key !== parsedModel.key); + state.hrfLoras.push({ ...defaultLoRAConfig, model: parsedModel, id }); + }, + prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }), + }, + hrfLoraDeleted: (state, action: PayloadAction<{ id: string }>) => { + const { id } = action.payload; + state.hrfLoras = state.hrfLoras.filter((lora) => lora.id !== id); + }, + hrfLoraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => { + const { id, weight } = action.payload; + const lora = selectHrfLoRA(state, id); + if (!lora) { + return; + } + lora.weight = weight; + }, + hrfLoraIsEnabledChanged: (state, action: PayloadAction<{ id: string; isEnabled: boolean }>) => { + const { id, isEnabled } = action.payload; + const lora = selectHrfLoRA(state, id); + if (!lora) { + return; + } + lora.isEnabled = isEnabled; + }, setSeamlessXAxis: (state, action: PayloadAction) => { state.seamlessXAxis = action.payload; }, @@ -185,6 +249,7 @@ const slice = createSlice({ if (model?.base === 'external') { state.hrfEnabled = false; + state.hrfModel = null; } // If the model base changes (e.g. SD1.5 -> SDXL), we need to change a few things @@ -192,6 +257,15 @@ const slice = createSlice({ return; } + if (state.hrfModel?.base !== model.base) { + state.hrfModel = null; + } + if (state.hrfTileControlNetModel?.base !== model.base) { + state.hrfTileControlNetModel = null; + } + const effectiveHrfBase = state.hrfModel?.base ?? model.base; + state.hrfLoras = state.hrfLoras.filter((lora) => lora.model.base === effectiveHrfBase); + applyClipSkip(state, model, state.clipSkip); }, vaeSelected: (state, action: PayloadAction) => { @@ -675,10 +749,18 @@ export const { setHrfLatentInterpolationMode, setHrfUpscaleModel, setHrfTileControlNetModel, - setHrfStructure, + setHrfTileControlWeight, setHrfTileControlEnd, setHrfTileSize, setHrfTileOverlap, + setHrfSteps, + setHrfModel, + setHrfLoraMode, + setHrfLoras, + hrfLoraAdded, + hrfLoraDeleted, + hrfLoraWeightChanged, + hrfLoraIsEnabledChanged, setSeamlessXAxis, setSeamlessYAxis, setShouldRandomizeSeed, @@ -786,6 +868,18 @@ export const paramsSliceConfig: SliceConfig = { state.hrfTileControlEnd = 0.2; } + if (state._version === 5) { + // v5 -> v6, replace the Invoke Upscale "Structure" abstraction with explicit Generate HRF controls + state._version = 6; + const legacyHrfStructure = typeof state.hrfStructure === 'number' ? state.hrfStructure : 0; + state.hrfTileControlWeight = (legacyHrfStructure + 10) * 0.0325 + 0.3; + delete state.hrfStructure; + state.hrfSteps = null; + state.hrfModel = null; + state.hrfLoraMode = 'reuse_generate'; + state.hrfLoras = []; + } + return zParamsState.parse(state); }, }, @@ -859,10 +953,18 @@ export const selectHrfStrength = createParamsSelector((params) => params.hrfStre export const selectHrfLatentInterpolationMode = createParamsSelector((params) => params.hrfLatentInterpolationMode); export const selectHrfUpscaleModel = createParamsSelector((params) => params.hrfUpscaleModel); export const selectHrfTileControlNetModel = createParamsSelector((params) => params.hrfTileControlNetModel); -export const selectHrfStructure = createParamsSelector((params) => params.hrfStructure); +export const selectHrfTileControlWeight = createParamsSelector((params) => params.hrfTileControlWeight); export const selectHrfTileControlEnd = createParamsSelector((params) => params.hrfTileControlEnd); export const selectHrfTileSize = createParamsSelector((params) => params.hrfTileSize); export const selectHrfTileOverlap = createParamsSelector((params) => params.hrfTileOverlap); +export const selectHrfSteps = createParamsSelector((params) => params.hrfSteps); +export const selectHrfModel = createParamsSelector((params) => params.hrfModel); +export const selectHrfLoraMode = createParamsSelector((params) => params.hrfLoraMode); +export const selectHrfLoras = createParamsSelector((params) => params.hrfLoras); +export const buildSelectHrfLoRA = (id: string) => + createSelector([selectParamsSlice], (params) => { + return selectHrfLoRA(params, id); + }); export const selectPositivePrompt = createParamsSelector((params) => params.positivePrompt); export const selectNegativePrompt = createParamsSelector((params) => params.negativePrompt); export const selectNegativePromptWithFallback = createParamsSelector((params) => params.negativePrompt ?? ''); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 2124017067d..e7e5c41e7c4 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -755,8 +755,11 @@ export type HrfLatentInterpolationMode = z.infer; +export const zHrfLoraMode = z.enum(['reuse_generate', 'none', 'dedicated']); +export type HrfLoraMode = z.infer; + export const zParamsState = z.object({ - _version: z.literal(5), + _version: z.literal(6), maskBlur: z.number(), maskBlurMethod: zParameterMaskBlurMethod, canvasCoherenceMode: zParameterCanvasCoherenceMode, @@ -778,10 +781,14 @@ export const zParamsState = z.object({ hrfLatentInterpolationMode: zHrfLatentInterpolationMode, hrfUpscaleModel: zParameterSpandrelImageToImageModel.nullable(), hrfTileControlNetModel: zModelIdentifierField.nullable(), - hrfStructure: z.number().min(-10).max(10), + hrfTileControlWeight: z.number().min(0).max(2), hrfTileControlEnd: z.number().min(0).max(1), hrfTileSize: z.number().int().min(8), hrfTileOverlap: z.number().int().min(8), + hrfSteps: z.number().int().min(1).nullable(), + hrfModel: zParameterModel.nullable(), + hrfLoraMode: zHrfLoraMode, + hrfLoras: z.array(zLoRA), iterations: z.number(), scheduler: zParameterScheduler, fluxScheduler: zParameterFluxScheduler, @@ -851,7 +858,7 @@ export const zParamsState = z.object({ }); export type ParamsState = z.infer; export const getInitialParamsState = (): ParamsState => ({ - _version: 5, + _version: 6, maskBlur: 16, maskBlurMethod: 'box', canvasCoherenceMode: 'Gaussian Blur', @@ -873,10 +880,14 @@ export const getInitialParamsState = (): ParamsState => ({ hrfLatentInterpolationMode: 'bicubic', hrfUpscaleModel: null, hrfTileControlNetModel: null, - hrfStructure: 0, + hrfTileControlWeight: 0.625, hrfTileControlEnd: 0.2, hrfTileSize: 1024, hrfTileOverlap: 128, + hrfSteps: null, + hrfModel: null, + hrfLoraMode: 'reuse_generate', + hrfLoras: [], iterations: 1, scheduler: 'dpmpp_3m_k', fluxScheduler: 'euler', diff --git a/invokeai/frontend/web/src/features/metadata/parsing.test.tsx b/invokeai/frontend/web/src/features/metadata/parsing.test.tsx index bb295303273..dbe8c3af80a 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.test.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.test.tsx @@ -1,6 +1,6 @@ import type { AppStore } from 'app/store/store'; import type * as paramsSliceModule from 'features/controlLayers/store/paramsSlice'; -import { ImageMetadataHandlers } from 'features/metadata/parsing'; +import { ImageMetadataHandlers, MetadataUtils } from 'features/metadata/parsing'; import type * as modelsApiModule from 'services/api/endpoints/models'; import { beforeEach, describe, expect, it, vi } from 'vitest'; @@ -22,7 +22,7 @@ vi.mock('features/controlLayers/store/paramsSlice', async (importOriginal) => { return { ...mod, selectBase: () => currentBase }; }); -const fakeModel = (type: 'vae' | 'qwen3_encoder', base: string) => ({ +const fakeModel = (type: 'main' | 'vae' | 'qwen3_encoder' | 'lora', base: string) => ({ key: `${type}-key`, hash: 'hash', name: `Some ${type}`, @@ -31,6 +31,7 @@ const fakeModel = (type: 'vae' | 'qwen3_encoder', base: string) => ({ }); let nextResolved: ReturnType = fakeModel('vae', 'flux2'); +let resolvedModels: Record> = {}; vi.mock('services/api/endpoints/models', async (importOriginal) => { const mod = await importOriginal(); @@ -48,15 +49,22 @@ vi.mock('services/api/endpoints/models', async (importOriginal) => { const makeStore = (): AppStore => ({ - dispatch: vi.fn(() => ({ - unwrap: () => Promise.resolve(nextResolved), - })), + dispatch: vi.fn((action) => { + if (action?.type === 'generation/modelSelected') { + currentBase = action.payload.base; + return action; + } + return { + unwrap: () => Promise.resolve(resolvedModels[action?.key] ?? nextResolved), + }; + }), getState: () => ({}), }) as unknown as AppStore; beforeEach(() => { currentBase = 'flux2'; nextResolved = fakeModel('vae', 'flux2'); + resolvedModels = {}; }); describe('ImageMetadataHandlers — Klein recall gating', () => { @@ -171,4 +179,45 @@ describe('ImageMetadataHandlers — Klein recall gating', () => { expect(parsed).toBe(3.5); }); }); + + describe('HRF LoRAs', () => { + it('recalls dedicated HRF LoRAs after the recalled main model changes base', async () => { + currentBase = 'sd-1'; + const mainModel = fakeModel('main', 'sdxl'); + const hrfLora = fakeModel('lora', 'sdxl'); + resolvedModels = { + [mainModel.key]: mainModel, + [hrfLora.key]: hrfLora, + }; + const store = makeStore(); + + const recalled = await MetadataUtils.recallByHandlers({ + metadata: { + model: mainModel, + hrf_loras: [{ model: hrfLora, weight: 0.6 }], + }, + handlers: [ImageMetadataHandlers.HrfLoRAs, ImageMetadataHandlers.MainModel], + store, + silent: true, + }); + + expect(Object.keys(ImageMetadataHandlers).indexOf('HrfLoRAs')).toBeGreaterThan( + Object.keys(ImageMetadataHandlers).indexOf('MainModel') + ); + expect(recalled.has(ImageMetadataHandlers.MainModel)).toBe(true); + expect(recalled.has(ImageMetadataHandlers.HrfLoRAs)).toBe(true); + expect(store.dispatch).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'generation/modelSelected', + payload: expect.objectContaining({ base: 'sdxl' }), + }) + ); + expect(store.dispatch).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'params/setHrfLoras', + payload: [expect.objectContaining({ model: expect.objectContaining({ key: hrfLora.key }), weight: 0.6 })], + }) + ); + }); + }); }); diff --git a/invokeai/frontend/web/src/features/metadata/parsing.tsx b/invokeai/frontend/web/src/features/metadata/parsing.tsx index db3bda626d6..96d4636d46b 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.tsx @@ -38,12 +38,16 @@ import { setGuidance, setHrfEnabled, setHrfLatentInterpolationMode, + setHrfLoraMode, + setHrfLoras, setHrfMethod, + setHrfModel, setHrfScale, + setHrfSteps, setHrfStrength, - setHrfStructure, setHrfTileControlEnd, setHrfTileControlNetModel, + setHrfTileControlWeight, setHrfTileOverlap, setHrfTileSize, setHrfUpscaleModel, @@ -74,6 +78,7 @@ import { refImagesRecalled } from 'features/controlLayers/store/refImagesSlice'; import type { CanvasMetadata, HrfLatentInterpolationMode, + HrfLoraMode, HrfMethod as HrfMethodType, LoRA, RefImageState, @@ -82,6 +87,7 @@ import { zCanvasMetadata, zCanvasReferenceImageState_OLD, zHrfLatentInterpolationMode, + zHrfLoraMode, zHrfMethod, zRefImageState, } from 'features/controlLayers/store/types'; @@ -756,12 +762,28 @@ const HrfStructure: SingleMetadataHandler = { parse: (metadata, _store) => { const raw = getProperty(metadata, 'hrf_structure'); const parsed = z.number().min(-10).max(10).parse(raw); + return Promise.resolve((parsed + 10) * 0.0325 + 0.3); + }, + recall: (value, store) => { + store.dispatch(setHrfTileControlWeight(value)); + }, + i18nKey: 'hrf.metadata.legacyStructure', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfTileControlWeight: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfTileControlWeight', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_tile_control_weight'); + const parsed = z.number().min(0).max(2).parse(raw); return Promise.resolve(parsed); }, recall: (value, store) => { - store.dispatch(setHrfStructure(value)); + store.dispatch(setHrfTileControlWeight(value)); }, - i18nKey: 'hrf.metadata.structure', + i18nKey: 'hrf.metadata.tileControlWeight', LabelComponent: MetadataLabel, ValueComponent: ({ value }: SingleMetadataValueProps) => , }; @@ -813,6 +835,55 @@ const HrfTileOverlap: SingleMetadataHandler = { LabelComponent: MetadataLabel, ValueComponent: ({ value }: SingleMetadataValueProps) => , }; + +const HrfSteps: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfSteps', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_steps'); + const parsed = zParameterSteps.parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfSteps(value)); + }, + i18nKey: 'hrf.metadata.steps', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; + +const HrfModel: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfModel', + parse: (metadata, store) => { + const raw = getProperty(metadata, 'hrf_model'); + return parseModelIdentifier(raw, store, 'main'); + }, + recall: (value, store) => { + store.dispatch(setHrfModel(value as ParameterModel)); + }, + i18nKey: 'hrf.metadata.model', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => ( + + ), +}; + +const HrfLoraModeMetadata: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'HrfLoraMode', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'hrf_lora_mode'); + const parsed = zHrfLoraMode.parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setHrfLoraMode(value)); + }, + i18nKey: 'hrf.metadata.loraMode', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => , +}; //#endregion High Resolution Fix //#region SeamlessX @@ -1452,6 +1523,57 @@ const LoRAs: CollectionMetadataHandler = { }; //#endregion LoRAs +//#region HRF LoRAs +const HrfLoRAs: CollectionMetadataHandler = { + [CollectionMetadataKey]: true, + type: 'HrfLoRAs', + parse: async (metadata, store) => { + const rawArray = getProperty(metadata, 'hrf_loras'); + + if (!rawArray) { + return []; + } + + assert(isArray(rawArray)); + + const loras: LoRA[] = []; + + for (const rawItem of rawArray) { + try { + const rawIdentifier = getProperty(rawItem, 'model'); + const identifier = await parseModelIdentifier(rawIdentifier, store, 'lora'); + assert(identifier.type === 'lora'); + assert(isCompatibleWithMainModel(identifier, store)); + + const weight = getProperty(rawItem, 'weight'); + + loras.push({ + id: getPrefixedId('hrf_lora'), + model: identifier, + weight: zLoRAWeight.parse(weight), + isEnabled: true, + }); + } catch { + continue; + } + } + + return loras; + }, + recallOne: (value, store) => { + store.dispatch(setHrfLoras([value])); + }, + recall: (values, store) => { + store.dispatch(setHrfLoras(values)); + }, + i18nKey: 'hrf.metadata.loras', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: CollectionMetadataValueProps) => ( + + ), +}; +//#endregion HRF LoRAs + //#region CanvasLayers const CanvasLayers: SingleMetadataHandler = { [SingleMetadataKey]: true, @@ -1734,9 +1856,13 @@ export const ImageMetadataHandlers = { HrfUpscaleModel, HrfTileControlNetModel, HrfStructure, + HrfTileControlWeight, HrfTileControlEnd, HrfTileSize, HrfTileOverlap, + HrfSteps, + HrfModel, + HrfLoraMode: HrfLoraModeMetadata, SeamlessX, SeamlessY, RefinerModel, @@ -1766,6 +1892,7 @@ export const ImageMetadataHandlers = { QwenImageShift, ZImageShift, LoRAs, + HrfLoRAs, CanvasLayers, RefImages, ImageSize, @@ -1850,10 +1977,8 @@ const recallByHandlers = async (arg: { (handler) => !skip.some((skippedHandler) => skippedHandler.type === handler.type) ); - // It's possible for some metadata item's recall to clobber the recall of another. For example, the model recall - // may change the width and height. If we are also recalling the width and height directly, we need to ensure that the - // model is recalled first, so it doesn't accidentally override the width and height. This is the only known case - // where the order of recall matters. + // It's possible for some metadata item's recall to clobber another or to affect compatibility checks. For example, + // model recall may change dimensions, and LoRA-like handlers validate against the current model base. const sortedHandlers = filteredHandlers.sort((a, b) => { if (a === ImageMetadataHandlers.MainModel) { return -1; // MainModel should be recalled first diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts index 8c902bd2a54..7cb5f48131f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts @@ -9,10 +9,17 @@ const buildState = (overrides?: { base?: string; hrfEnabled?: boolean; hrfMethod?: HrfMethod; + hrfSteps?: number | null; + hrfModel?: unknown; + hrfLoraMode?: 'reuse_generate' | 'none' | 'dedicated'; + hrfLoras?: unknown[]; + hrfTileSize?: number; + loras?: unknown[]; refinerModel?: unknown; }): RootState => ({ ui: { activeTab: 'generate' }, + loras: { loras: overrides?.loras ?? [] }, params: { model: { key: 'model', hash: 'model-hash', name: 'model', base: overrides?.base ?? 'sdxl', type: 'main' }, dimensions: { width: 512, height: 512 }, @@ -35,10 +42,14 @@ const buildState = (overrides?: { base: overrides?.base ?? 'sdxl', type: 'controlnet', }, - hrfStructure: 0, + hrfTileControlWeight: 0.625, hrfTileControlEnd: 0.2, - hrfTileSize: 1024, + hrfTileSize: overrides?.hrfTileSize ?? 1024, hrfTileOverlap: 128, + hrfSteps: overrides?.hrfSteps ?? null, + hrfModel: overrides?.hrfModel ?? null, + hrfLoraMode: overrides?.hrfLoraMode ?? 'reuse_generate', + hrfLoras: overrides?.hrfLoras ?? [], optimizedDenoisingEnabled: true, refinerModel: overrides?.refinerModel ?? null, }, @@ -81,11 +92,86 @@ const addSDXLConditioning = (g: Graph, denoise: ReturnType) => { + const mask = g.addNode({ + id: 'region_mask', + type: 'alpha_mask_to_tensor', + image: { image_name: 'region-mask' }, + }); + const regionPrompt = g.addNode({ id: 'region_prompt', type: 'string', value: 'regional prompt' }); + const regionPosCond = g.addNode({ + id: 'region_pos_cond', + type: 'sdxl_compel_prompt', + style: 'regional positive style', + original_width: 512, + original_height: 512, + target_width: 512, + target_height: 512, + }); + const regionNegCond = g.addNode({ + id: 'region_neg_cond', + type: 'sdxl_compel_prompt', + prompt: 'regional negative', + style: 'regional negative style', + original_width: 512, + original_height: 512, + target_width: 512, + target_height: 512, + }); + + g.addEdge(regionPrompt, 'value', regionPosCond, 'prompt'); + g.addEdgeFromObj({ + source: { node_id: mask.id, field: 'mask' }, + destination: { node_id: regionPosCond.id, field: 'mask' }, + }); + g.addEdgeFromObj({ + source: { node_id: mask.id, field: 'mask' }, + destination: { node_id: regionNegCond.id, field: 'mask' }, + }); + g.addEdge(regionPosCond, 'conditioning', conditioning.posCollect, 'item'); + g.addEdge(regionNegCond, 'conditioning', conditioning.negCollect, 'item'); + + return { mask, regionPrompt, regionPosCond, regionNegCond }; +}; + +const addSeamlessToClassicGraph = ({ + g, + modelLoader, + denoise, +}: Pick, 'g' | 'modelLoader' | 'denoise'>) => { + const seamless = g.addNode({ + id: 'seamless', + type: 'seamless', + seamless_x: true, + seamless_y: false, + }); + + g.deleteEdgesTo(denoise, ['unet']); + g.addEdge(modelLoader, 'unet', seamless, 'unet'); + g.addEdge(modelLoader, 'vae', seamless, 'vae'); + g.addEdge(seamless, 'unet', denoise, 'unet'); + + return seamless; }; const buildTransformerGraph = () => { @@ -124,6 +210,7 @@ describe('addHighResFix', () => { expect(resize).toMatchObject({ type: 'lresize', width: 1024, height: 1024, mode: 'bilinear' }); expect(hrfDenoise).toMatchObject({ type: 'denoise_latents', denoising_start: 0.65, denoising_end: 1 }); expect(hrfNoise).toMatchObject({ type: 'noise', width: 1024, height: 1024, use_cpu: true }); + expect(l2i).toMatchObject({ type: 'l2i', tiled: true, tile_size: 1024 }); expect(graph.edges).not.toContainEqual({ source: { node_id: 'denoise', field: 'latents' }, destination: { node_id: 'l2i', field: 'latents' }, @@ -139,6 +226,46 @@ describe('addHighResFix', () => { }); }); + it('ignores stale upscale-model-only settings when latent HRF is selected', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'latent', + hrfSteps: 12, + hrfModel: { key: 'hrf-model', hash: 'hrf-hash', name: 'HRF Model', base: 'sdxl', type: 'main' }, + hrfLoraMode: 'dedicated', + hrfTileSize: 1536, + hrfLoras: [ + { + id: 'hrf-lora-id', + isEnabled: true, + model: { key: 'hrf-lora', hash: 'hrf-lora-hash', name: 'HRF LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ], + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const nodes = Object.values(g.getGraph().nodes); + const hrfDenoise = nodes.find((node) => node.id.startsWith('hrf_denoise_latents')); + + expect(nodes.some((node) => node.type === 'lresize')).toBe(true); + expect(nodes.some((node) => node.type === 'spandrel_image_to_image_autoscale')).toBe(false); + expect(nodes.some((node) => node.id.startsWith('hrf_sdxl_model_loader'))).toBe(false); + expect(nodes.some((node) => node.type === 'sdxl_lora_collection_loader')).toBe(false); + expect(hrfDenoise).toMatchObject({ steps: 30 }); + expect(l2i).toMatchObject({ type: 'l2i', tiled: true, tile_size: 1024 }); + expect(g.getMetadataNode()).toMatchObject({ hrf_method: 'latent' }); + expect(g.getMetadataNode().hrf_steps).toBeUndefined(); + }); + it('preserves the original graph and writes disabled metadata when HRF is off', () => { const { g, seed, noise, denoise, l2i } = buildClassicGraph(); @@ -161,6 +288,29 @@ describe('addHighResFix', () => { expect(g.getMetadataNode()).toMatchObject({ hrf_enabled: false }); }); + it('applies custom HRF steps only to the second pass', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + + addHighResFix({ + g, + state: buildState({ hrfMethod: 'upscale_model', hrfSteps: 12 }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const hrfDenoise = Object.values(g.getGraph().nodes).find( + (node) => node.type === 'tiled_multi_diffusion_denoise_latents' + ); + + expect(denoise.steps).toBe(30); + expect(hrfDenoise).toMatchObject({ steps: 12 }); + expect(g.getMetadataNode()).toMatchObject({ hrf_steps: 12 }); + }); + it('reroutes transformer txt2img graphs through latent resize and a final-size second denoise pass', () => { const { g, seed, denoise, l2i } = buildTransformerGraph(); @@ -199,7 +349,7 @@ describe('addHighResFix', () => { const nodes = Object.values(graph.nodes); const hrfDenoise = nodes.find((node) => node.id.startsWith('hrf_denoise_latents')); const hrfPosCond = nodes.find((node) => node.id.startsWith('hrf_pos_cond')); - const hrfPosCollect = nodes.find((node) => node.id.startsWith('hrf_sdxl_conditioning_collect')); + const hrfPosCollect = nodes.find((node) => node.id.startsWith('hrf_pos_collect_conditioning_collect')); if (!hrfDenoise || !hrfPosCond || !hrfPosCollect) { throw new Error('Expected HRF denoise and cloned SDXL conditioning nodes'); @@ -306,10 +456,357 @@ describe('addHighResFix', () => { hrf_method: 'upscale_model', hrf_upscale_model: { key: 'upscale' }, hrf_tile_controlnet_model: { key: 'tile' }, - hrf_structure: 0, + hrf_tile_control_weight: 0.625, hrf_tile_control_end: 0.2, hrf_tile_size: 1024, hrf_tile_overlap: 128, + hrf_lora_mode: 'reuse_generate', + }); + }); + + it('uses a dedicated SDXL HRF model and custom steps for the second pass only', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfSteps: 14, + hrfModel: { key: 'hrf-model', hash: 'hrf-hash', name: 'HRF Model', base: 'sdxl', type: 'main' }, + hrfLoraMode: 'none', + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const hrfModelLoader = nodes.find((node) => node.id.startsWith('hrf_sdxl_model_loader')); + const hrfPosCond = nodes.find((node) => node.id.startsWith('hrf_pos_cond')); + const tiledDenoise = nodes.find((node) => node.type === 'tiled_multi_diffusion_denoise_latents'); + const i2l = nodes.find((node) => node.id.startsWith('hrf_i2l')); + + if (!hrfModelLoader || !hrfPosCond || !tiledDenoise || !i2l) { + throw new Error('Expected dedicated HRF model graph nodes'); + } + + expect(hrfModelLoader).toMatchObject({ type: 'sdxl_model_loader', model: { key: 'hrf-model' } }); + expect(tiledDenoise).toMatchObject({ steps: 14 }); + expect(denoise.steps).toBe(30); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfModelLoader.id, field: 'clip' }, + destination: { node_id: hrfPosCond.id, field: 'clip' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfModelLoader.id, field: 'vae' }, + destination: { node_id: i2l.id, field: 'vae' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfModelLoader.id, field: 'vae' }, + destination: { node_id: 'l2i', field: 'vae' }, + }); + expect(nodes.some((node) => node.type === 'sdxl_lora_collection_loader')).toBe(false); + expect(g.getMetadataNode()).toMatchObject({ + hrf_steps: 14, + hrf_model: { key: 'hrf-model' }, + hrf_lora_mode: 'none', + }); + }); + + it('preserves regional SDXL conditioning when a dedicated HRF model uses classic fallback', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + const conditioning = addSDXLConditioning(g, denoise); + const { mask, regionPrompt } = addSDXLRegionalConditioning(g, conditioning); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfModel: { key: 'hrf-model', hash: 'hrf-hash', name: 'HRF Model', base: 'sdxl', type: 'main' }, + hrfLoraMode: 'none', + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const classicHrfDenoise = nodes.find((node) => node.id.startsWith('hrf_denoise_latents')); + const hrfRegionPosCond = nodes.find((node) => node.id.startsWith('hrf_region_pos_cond')); + const hrfPositiveCollect = nodes.find((node) => node.id.startsWith('hrf_positive_conditioning_collect')); + + if (!classicHrfDenoise || !hrfRegionPosCond || !hrfPositiveCollect) { + throw new Error('Expected classic HRF denoise and cloned regional conditioning nodes'); + } + + expect(nodes.some((node) => node.type === 'tiled_multi_diffusion_denoise_latents')).toBe(false); + expect(hrfRegionPosCond).toMatchObject({ + type: 'sdxl_compel_prompt', + original_width: 1024, + original_height: 1024, + target_width: 1024, + target_height: 1024, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: regionPrompt.id, field: 'value' }, + destination: { node_id: hrfRegionPosCond.id, field: 'prompt' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: mask.id, field: 'mask' }, + destination: { node_id: hrfRegionPosCond.id, field: 'mask' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfRegionPosCond.id, field: 'conditioning' }, + destination: { node_id: hrfPositiveCollect.id, field: 'item' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfPositiveCollect.id, field: 'collection' }, + destination: { node_id: classicHrfDenoise.id, field: 'positive_conditioning' }, + }); + }); + + it('applies dedicated HRF LoRA CLIP outputs to all regional conditioning clones', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + const conditioning = addSDXLConditioning(g, denoise); + addSDXLRegionalConditioning(g, conditioning); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfLoraMode: 'dedicated', + hrfLoras: [ + { + id: 'hrf-lora-id', + isEnabled: true, + model: { key: 'hrf-lora', hash: 'hrf-lora-hash', name: 'HRF LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ], + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const loraLoader = nodes.find((node) => node.type === 'sdxl_lora_collection_loader'); + const classicHrfDenoise = nodes.find((node) => node.id.startsWith('hrf_denoise_latents')); + const hrfPosCond = nodes.find((node) => node.id.startsWith('hrf_pos_cond')); + const hrfRegionPosCond = nodes.find((node) => node.id.startsWith('hrf_region_pos_cond')); + const hrfRegionNegCond = nodes.find((node) => node.id.startsWith('hrf_region_neg_cond')); + + if (!loraLoader || !classicHrfDenoise || !hrfPosCond || !hrfRegionPosCond || !hrfRegionNegCond) { + throw new Error('Expected dedicated HRF LoRA and cloned regional conditioning nodes'); + } + + for (const cond of [hrfPosCond, hrfRegionPosCond, hrfRegionNegCond]) { + expect(graph.edges).toContainEqual({ + source: { node_id: loraLoader.id, field: 'clip' }, + destination: { node_id: cond.id, field: 'clip' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: loraLoader.id, field: 'clip2' }, + destination: { node_id: cond.id, field: 'clip2' }, + }); + } + expect(graph.edges).toContainEqual({ + source: { node_id: loraLoader.id, field: 'unet' }, + destination: { node_id: classicHrfDenoise.id, field: 'unet' }, + }); + }); + + it('applies only dedicated HRF LoRAs when dedicated LoRA mode is selected', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfLoraMode: 'dedicated', + hrfLoras: [ + { + id: 'hrf-lora-id', + isEnabled: true, + model: { key: 'hrf-lora', hash: 'hrf-lora-hash', name: 'HRF LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ], + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const loraLoader = nodes.find((node) => node.type === 'sdxl_lora_collection_loader'); + const tiledDenoise = nodes.find((node) => node.type === 'tiled_multi_diffusion_denoise_latents'); + + if (!loraLoader || !tiledDenoise) { + throw new Error('Expected dedicated HRF LoRA graph nodes'); + } + + expect(graph.edges).toContainEqual({ + source: { node_id: loraLoader.id, field: 'unet' }, + destination: { node_id: tiledDenoise.id, field: 'unet' }, + }); + expect(g.getMetadataNode()).toMatchObject({ + hrf_lora_mode: 'dedicated', + hrf_loras: [{ model: { key: 'hrf-lora' }, weight: 0.6 }], + }); + }); + + it('ignores incompatible stale dedicated HRF LoRAs in the graph and metadata', () => { + const { g, seed, noise, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfLoraMode: 'dedicated', + hrfLoras: [ + { + id: 'stale-sd1-lora-id', + isEnabled: true, + model: { key: 'stale-sd1-lora', hash: 'stale-hash', name: 'Stale SD1 LoRA', base: 'sd-1', type: 'lora' }, + weight: 0.6, + }, + ], + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const nodes = Object.values(g.getGraph().nodes); + + expect(nodes.some((node) => node.type === 'sdxl_lora_collection_loader')).toBe(false); + expect(nodes.some((node) => node.type === 'lora_selector')).toBe(false); + expect(g.getMetadataNode()).toMatchObject({ hrf_lora_mode: 'dedicated' }); + expect(g.getMetadataNode().hrf_loras).toEqual([]); + }); + + it('preserves seamless routing when a dedicated HRF model is selected', () => { + const { g, seed, noise, modelLoader, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + addSeamlessToClassicGraph({ g, modelLoader, denoise }); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfModel: { key: 'hrf-model', hash: 'hrf-hash', name: 'HRF Model', base: 'sdxl', type: 'main' }, + hrfLoraMode: 'none', + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const hrfModelLoader = nodes.find((node) => node.id.startsWith('hrf_sdxl_model_loader')); + const hrfSeamless = nodes.find((node) => node.id.startsWith('hrf_seamless')); + const tiledDenoise = nodes.find((node) => node.type === 'tiled_multi_diffusion_denoise_latents'); + const i2l = nodes.find((node) => node.id.startsWith('hrf_i2l')); + + if (!hrfModelLoader || !hrfSeamless || !tiledDenoise || !i2l) { + throw new Error('Expected dedicated HRF seamless graph nodes'); + } + + expect(hrfSeamless).toMatchObject({ type: 'seamless', seamless_x: true, seamless_y: false }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfModelLoader.id, field: 'unet' }, + destination: { node_id: hrfSeamless.id, field: 'unet' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfModelLoader.id, field: 'vae' }, + destination: { node_id: hrfSeamless.id, field: 'vae' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfSeamless.id, field: 'unet' }, + destination: { node_id: tiledDenoise.id, field: 'unet' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfSeamless.id, field: 'vae' }, + destination: { node_id: i2l.id, field: 'vae' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfSeamless.id, field: 'vae' }, + destination: { node_id: 'l2i', field: 'vae' }, + }); + }); + + it('preserves seamless routing before dedicated HRF LoRAs', () => { + const { g, seed, noise, modelLoader, denoise, l2i } = buildClassicGraph(); + addSDXLConditioning(g, denoise); + addSeamlessToClassicGraph({ g, modelLoader, denoise }); + + addHighResFix({ + g, + state: buildState({ + hrfMethod: 'upscale_model', + hrfLoraMode: 'dedicated', + hrfLoras: [ + { + id: 'hrf-lora-id', + isEnabled: true, + model: { key: 'hrf-lora', hash: 'hrf-lora-hash', name: 'HRF LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ], + }), + generationMode: 'txt2img', + denoise, + l2i, + noise, + seed, + }); + + const graph = g.getGraph(); + const nodes = Object.values(graph.nodes); + const hrfSeamless = nodes.find((node) => node.id.startsWith('hrf_seamless')); + const loraLoader = nodes.find((node) => node.type === 'sdxl_lora_collection_loader'); + const tiledDenoise = nodes.find((node) => node.type === 'tiled_multi_diffusion_denoise_latents'); + + if (!hrfSeamless || !loraLoader || !tiledDenoise) { + throw new Error('Expected HRF seamless and dedicated LoRA graph nodes'); + } + + expect(graph.edges).toContainEqual({ + source: { node_id: 'model_loader', field: 'vae' }, + destination: { node_id: hrfSeamless.id, field: 'vae' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: hrfSeamless.id, field: 'unet' }, + destination: { node_id: loraLoader.id, field: 'unet' }, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: loraLoader.id, field: 'unet' }, + destination: { node_id: tiledDenoise.id, field: 'unet' }, + }); + expect(graph.edges).not.toContainEqual({ + source: { node_id: hrfSeamless.id, field: 'unet' }, + destination: { node_id: tiledDenoise.id, field: 'unet' }, }); }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts index 50850d636b4..f941757d0f7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts @@ -2,8 +2,11 @@ import type { RootState } from 'app/store/store'; import { roundDownToMultiple } from 'common/util/roundDownToMultiple'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; -import type { GenerationMode } from 'features/controlLayers/store/types'; +import type { GenerationMode, LoRA } from 'features/controlLayers/store/types'; import type { BaseModelType } from 'features/nodes/types/common'; +import { zModelIdentifierField } from 'features/nodes/types/common'; +import { addLoRAs } from 'features/nodes/util/graph/generation/addLoRAs'; +import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { DenoiseLatentsNodes, LatentToImageNodes } from 'features/nodes/util/graph/types'; import { getGridSize } from 'features/parameters/util/optimalDimension'; @@ -21,9 +24,22 @@ type AddHighResFixArg = { seed: Invocation<'integer'>; }; +type HrfModelLoader = Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>; +type HrfVaeSource = HrfModelLoader | Invocation<'seamless'>; +type FreshHrfModelInputs = { + modelLoader: HrfModelLoader; + vaeSource: HrfVaeSource; +}; +type ConditioningSource = { + node_id: string; + field: AnyInvocationOutputField; +}; + const SKIPPED_DENOISE_INPUT_FIELDS = new Set(['latents', 'noise']); +const FRESH_HRF_MODEL_INPUT_FIELDS = new Set(['unet', 'positive_conditioning', 'negative_conditioning']); const TILED_DENOISE_INPUT_FIELDS = new Set(['positive_conditioning', 'negative_conditioning', 'unet']); const SKIPPED_LATENT_TO_IMAGE_INPUT_FIELDS = new Set(['latents', 'metadata']); +const LATENT_HRF_L2I_TILE_SIZE = 1024; const getHighResFixFinalDimensions = (state: RootState) => { const params = selectParamsSlice(state); @@ -155,7 +171,7 @@ const cloneSDXLConditioningForFinalDimensions = ( const collect = g.addNode({ type: 'collect', - id: getPrefixedId('hrf_sdxl_conditioning_collect'), + id: getPrefixedId(`hrf_${sourceNode.id.split(':')[0]}_conditioning_collect`), }); for (const edge of itemEdges) { @@ -179,12 +195,16 @@ const copyDenoiseInputs = ( from: AnyInvocation, to: AnyInvocation, finalDimensions: { width: number; height: number }, - allowedInputFields?: Set + allowedInputFields?: Set, + skippedInputFields?: Set ) => { for (const edge of g.getEdgesTo(from)) { if (SKIPPED_DENOISE_INPUT_FIELDS.has(edge.destination.field)) { continue; } + if (skippedInputFields?.has(edge.destination.field)) { + continue; + } if (allowedInputFields && !allowedInputFields.has(edge.destination.field)) { continue; } @@ -250,7 +270,7 @@ const addTileControlNet = ( hrfDenoise: AnyInvocation, imageSource: Invocation<'unsharp_mask'>, tileControlNetModel: NonNullable['hrfTileControlNetModel']>, - structure: number, + tileControlWeight: number, tileControlEnd: number ) => { const controlNet = g.addNode({ @@ -259,7 +279,7 @@ const addTileControlNet = ( control_model: tileControlNetModel, control_mode: 'balanced', resize_mode: 'just_resize', - control_weight: (structure + 10) * 0.0325 + 0.3, + control_weight: tileControlWeight, begin_step_percent: 0, end_step_percent: tileControlEnd, }); @@ -271,6 +291,329 @@ const addTileControlNet = ( }); }; +const getConditioningSources = ( + g: Graph, + denoise: Invocation<'denoise_latents'>, + field: 'positive_conditioning' | 'negative_conditioning' +): ConditioningSource[] => { + const edge = g.getEdgesTo(denoise).find((edge) => edge.destination.field === field); + assert(edge, `Missing ${field} edge for HRF second pass`); + + const sourceNode = g.getNode(edge.source.node_id); + if (sourceNode.type !== 'collect') { + return [{ node_id: edge.source.node_id, field: edge.source.field as AnyInvocationOutputField }]; + } + + const itemEdges = g.getEdgesTo(sourceNode).filter((edge) => edge.destination.field === 'item'); + assert(itemEdges.length > 0, `HRF second pass expects at least one ${field} conditioning source`); + + return itemEdges.map((edge) => ({ + node_id: edge.source.node_id, + field: edge.source.field as AnyInvocationOutputField, + })); +}; + +const connectConditioningToDenoise = ( + g: Graph, + hrfDenoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>, + field: 'positive_conditioning' | 'negative_conditioning', + conditioning: Array | Invocation<'sdxl_compel_prompt'>> +) => { + assert(conditioning.length > 0, `HRF second pass expects at least one ${field} conditioning node`); + + if (conditioning.length === 1) { + g.addEdgeFromObj({ + source: { node_id: conditioning[0]!.id, field: 'conditioning' }, + destination: { node_id: hrfDenoise.id, field }, + }); + return; + } + + assert(hrfDenoise.type === 'denoise_latents', 'Tiled HRF second pass supports only single conditioning inputs'); + + const collect = g.addNode({ + type: 'collect', + id: getPrefixedId(`hrf_${field}_collect`), + }); + + for (const cond of conditioning) { + g.addEdge(cond, 'conditioning', collect, 'item'); + } + + g.addEdgeFromObj({ + source: { node_id: collect.id, field: 'collection' }, + destination: { node_id: hrfDenoise.id, field }, + }); +}; + +const cloneSD1ConditioningSources = ( + g: Graph, + denoise: Invocation<'denoise_latents'>, + field: 'positive_conditioning' | 'negative_conditioning' +) => { + return getConditioningSources(g, denoise, field).map((source) => { + const sourceNode = g.getNode(source.node_id); + assert(sourceNode.type === 'compel', 'SD1 HRF refinement requires SD1 conditioning'); + + const clone = g.addNode({ + ...sourceNode, + id: getPrefixedId(`hrf_${sourceNode.id.split(':')[0]}`), + }); + copyInputEdges(g, sourceNode, clone, new Set(['clip'])); + return clone; + }); +}; + +const cloneSDXLConditioningSources = ( + g: Graph, + denoise: Invocation<'denoise_latents'>, + field: 'positive_conditioning' | 'negative_conditioning', + finalDimensions: { width: number; height: number } +) => { + return getConditioningSources(g, denoise, field).map((source) => { + const sourceNode = g.getNode(source.node_id); + assert(sourceNode.type === 'sdxl_compel_prompt', 'SDXL HRF refinement requires SDXL conditioning'); + + const clone = g.addNode({ + ...sourceNode, + id: getPrefixedId(`hrf_${sourceNode.id.split(':')[0]}`), + original_width: finalDimensions.width, + original_height: finalDimensions.height, + target_width: finalDimensions.width, + target_height: finalDimensions.height, + }); + copyInputEdges(g, sourceNode, clone, new Set(['clip', 'clip2'])); + return clone; + }); +}; + +const findOriginalSeamlessNode = (g: Graph, denoise: Invocation<'denoise_latents'>) => { + let current: AnyInvocation = denoise; + const visited = new Set(); + + while (!visited.has(current.id)) { + visited.add(current.id); + const unetEdge = g.getEdgesTo(current).find((edge) => edge.destination.field === 'unet'); + if (!unetEdge) { + return null; + } + + const sourceNode = g.getNode(unetEdge.source.node_id); + if (sourceNode.type === 'seamless') { + return sourceNode; + } + + current = sourceNode; + } + + return null; +}; + +const addFreshHrfSeamless = ( + g: Graph, + denoise: Invocation<'denoise_latents'>, + hrfDenoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>, + modelLoader: HrfModelLoader, + useHrfModelVae: boolean +) => { + const originalSeamless = findOriginalSeamlessNode(g, denoise); + if (!originalSeamless) { + return null; + } + + const seamless = g.addNode({ + ...originalSeamless, + id: getPrefixedId('hrf_seamless'), + }); + + g.addEdge(modelLoader, 'unet', seamless, 'unet'); + + if (useHrfModelVae) { + g.addEdge(modelLoader, 'vae', seamless, 'vae'); + } else { + const originalVaeEdge = g.getEdgesTo(originalSeamless).find((edge) => edge.destination.field === 'vae'); + if (originalVaeEdge) { + g.addEdgeFromObj({ + source: { ...originalVaeEdge.source }, + destination: { node_id: seamless.id, field: 'vae' }, + }); + } else { + g.addEdge(modelLoader, 'vae', seamless, 'vae'); + } + } + + g.addEdge(seamless, 'unet', hrfDenoise, 'unet'); + return seamless; +}; + +const getEnabledHrfLoRAs = (state: RootState): LoRA[] | null => { + const params = selectParamsSlice(state); + + if (params.hrfLoraMode === 'none') { + return null; + } + + if (params.hrfLoraMode === 'dedicated') { + return getEnabledDedicatedHrfLoRAs(state); + } + + return state.loras.loras.filter((lora) => lora.isEnabled); +}; + +const getEnabledDedicatedHrfLoRAs = (state: RootState): LoRA[] => { + const params = selectParamsSlice(state); + const effectiveHrfBase = params.hrfModel?.base ?? params.model?.base; + if (!effectiveHrfBase) { + return []; + } + return params.hrfLoras.filter((lora) => lora.isEnabled && lora.model.base === effectiveHrfBase); +}; + +const getHrfLoRAMetadata = (state: RootState) => { + const params = selectParamsSlice(state); + if (params.hrfLoraMode !== 'dedicated') { + return undefined; + } + + return getEnabledDedicatedHrfLoRAs(state).map((lora) => ({ + model: zModelIdentifierField.parse(lora.model), + weight: lora.weight, + })); +}; + +const addFreshSD1HrfModelInputs = ( + state: RootState, + g: Graph, + denoise: Invocation<'denoise_latents'>, + hrfDenoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'> +): FreshHrfModelInputs => { + const params = selectParamsSlice(state); + const hrfModel = params.hrfModel ?? params.model; + assert(hrfModel?.base === 'sd-1', 'SD1 HRF refinement model must be an SD1 model'); + + const modelLoader = g.addNode({ + type: 'main_model_loader', + id: getPrefixedId('hrf_sd1_model_loader'), + model: hrfModel, + }); + const clipSkip = g.addNode({ + type: 'clip_skip', + id: getPrefixedId('hrf_clip_skip'), + skipped_layers: params.clipSkip, + }); + + const positiveConditioning = cloneSD1ConditioningSources(g, denoise, 'positive_conditioning'); + const negativeConditioning = cloneSD1ConditioningSources(g, denoise, 'negative_conditioning'); + const hrfSeamless = addFreshHrfSeamless(g, denoise, hrfDenoise, modelLoader, params.hrfModel !== null); + + g.addEdge(modelLoader, 'clip', clipSkip, 'clip'); + for (const cond of positiveConditioning) { + g.addEdge(clipSkip, 'clip', cond, 'clip'); + } + for (const cond of negativeConditioning) { + g.addEdge(clipSkip, 'clip', cond, 'clip'); + } + if (!hrfSeamless) { + g.addEdgeFromObj({ + source: { node_id: modelLoader.id, field: 'unet' }, + destination: { node_id: hrfDenoise.id, field: 'unet' }, + }); + } + connectConditioningToDenoise(g, hrfDenoise, 'positive_conditioning', positiveConditioning); + connectConditioningToDenoise(g, hrfDenoise, 'negative_conditioning', negativeConditioning); + + const enabledLoRAs = getEnabledHrfLoRAs(state); + if (enabledLoRAs?.length) { + addLoRAs( + state, + g, + hrfDenoise, + modelLoader, + hrfSeamless, + clipSkip, + positiveConditioning[0]!, + negativeConditioning[0]!, + { + loras: enabledLoRAs, + idPrefix: 'hrf', + metadataKey: params.hrfLoraMode === 'dedicated' ? 'hrf_loras' : 'loras', + extraPositiveConditioning: positiveConditioning.slice(1), + extraNegativeConditioning: negativeConditioning.slice(1), + } + ); + } + + return { modelLoader, vaeSource: hrfSeamless ?? modelLoader }; +}; + +const addFreshSDXLHrfModelInputs = ( + state: RootState, + g: Graph, + denoise: Invocation<'denoise_latents'>, + hrfDenoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>, + finalDimensions: { width: number; height: number } +): FreshHrfModelInputs => { + const params = selectParamsSlice(state); + const hrfModel = params.hrfModel ?? params.model; + assert(hrfModel?.base === 'sdxl', 'SDXL HRF refinement model must be an SDXL model'); + + const modelLoader = g.addNode({ + type: 'sdxl_model_loader', + id: getPrefixedId('hrf_sdxl_model_loader'), + model: hrfModel, + }); + + const positiveConditioning = cloneSDXLConditioningSources(g, denoise, 'positive_conditioning', finalDimensions); + const negativeConditioning = cloneSDXLConditioningSources(g, denoise, 'negative_conditioning', finalDimensions); + const hrfSeamless = addFreshHrfSeamless(g, denoise, hrfDenoise, modelLoader, params.hrfModel !== null); + + for (const cond of positiveConditioning) { + g.addEdge(modelLoader, 'clip', cond, 'clip'); + g.addEdge(modelLoader, 'clip2', cond, 'clip2'); + } + for (const cond of negativeConditioning) { + g.addEdge(modelLoader, 'clip', cond, 'clip'); + g.addEdge(modelLoader, 'clip2', cond, 'clip2'); + } + if (!hrfSeamless) { + g.addEdgeFromObj({ + source: { node_id: modelLoader.id, field: 'unet' }, + destination: { node_id: hrfDenoise.id, field: 'unet' }, + }); + } + connectConditioningToDenoise(g, hrfDenoise, 'positive_conditioning', positiveConditioning); + connectConditioningToDenoise(g, hrfDenoise, 'negative_conditioning', negativeConditioning); + + const enabledLoRAs = getEnabledHrfLoRAs(state); + if (enabledLoRAs?.length) { + addSDXLLoRAs(state, g, hrfDenoise, modelLoader, hrfSeamless, positiveConditioning[0]!, negativeConditioning[0]!, { + loras: enabledLoRAs, + idPrefix: 'hrf', + metadataKey: params.hrfLoraMode === 'dedicated' ? 'hrf_loras' : 'loras', + extraPositiveConditioning: positiveConditioning.slice(1), + extraNegativeConditioning: negativeConditioning.slice(1), + }); + } + + return { modelLoader, vaeSource: hrfSeamless ?? modelLoader }; +}; + +const addFreshHrfModelInputs = ( + state: RootState, + g: Graph, + denoise: Invocation<'denoise_latents'>, + hrfDenoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>, + finalDimensions: { width: number; height: number } +): FreshHrfModelInputs => { + const params = selectParamsSlice(state); + + if (params.model?.base === 'sdxl') { + return addFreshSDXLHrfModelInputs(state, g, denoise, hrfDenoise, finalDimensions); + } + + return addFreshSD1HrfModelInputs(state, g, denoise, hrfDenoise); +}; + const addLatentHighResFix = ({ g, state, @@ -319,6 +662,9 @@ const addLatentHighResFix = ({ g.addEdge(denoise, 'latents', resizeLatents, 'latents'); g.addEdge(resizeLatents, 'latents', hrfDenoise, 'latents'); g.addEdge(hrfDenoise, 'latents', l2i, 'latents'); + if (l2i.type === 'l2i') { + g.updateNode(l2i, { tile_size: LATENT_HRF_L2I_TILE_SIZE, tiled: true }); + } g.upsertMetadata({ width: finalDimensions.width, @@ -387,10 +733,13 @@ const addUpscaleModelHighResFix = ({ g, state, denoise, l2i, noise, seed }: AddH }); const useClassicDenoise = hasUnsupportedTiledDenoiseInputs(g, denoise); + const hrfSteps = params.hrfSteps ?? denoise.steps; + const needsFreshHrfModelInputs = params.hrfModel !== null || params.hrfLoraMode !== 'reuse_generate'; const hrfDenoise = useClassicDenoise ? g.addNode({ ...denoise, id: getPrefixedId('hrf_denoise_latents'), + steps: hrfSteps, denoising_start, denoising_end, }) @@ -400,7 +749,7 @@ const addUpscaleModelHighResFix = ({ g, state, denoise, l2i, noise, seed }: AddH tile_height: params.hrfTileSize, tile_width: params.hrfTileSize, tile_overlap: params.hrfTileOverlap, - steps: denoise.steps, + steps: hrfSteps, cfg_scale: denoise.cfg_scale, cfg_rescale_multiplier: denoise.cfg_rescale_multiplier, scheduler: denoise.scheduler, @@ -413,14 +762,28 @@ const addUpscaleModelHighResFix = ({ g, state, denoise, l2i, noise, seed }: AddH denoise, hrfDenoise, finalDimensions, - useClassicDenoise ? undefined : TILED_DENOISE_INPUT_FIELDS + useClassicDenoise ? undefined : TILED_DENOISE_INPUT_FIELDS, + needsFreshHrfModelInputs ? FRESH_HRF_MODEL_INPUT_FIELDS : undefined ); + + if (needsFreshHrfModelInputs) { + if (params.hrfModel) { + g.deleteEdgesTo(i2l, ['vae']); + g.deleteEdgesTo(l2i, ['vae']); + } + + const { vaeSource: hrfVaeSource } = addFreshHrfModelInputs(state, g, denoise, hrfDenoise, finalDimensions); + if (params.hrfModel) { + g.addEdge(hrfVaeSource, 'vae', i2l, 'vae'); + g.addEdge(hrfVaeSource, 'vae', l2i, 'vae'); + } + } addTileControlNet( g, hrfDenoise, unsharpMask, params.hrfTileControlNetModel, - params.hrfStructure, + params.hrfTileControlWeight, params.hrfTileControlEnd ); @@ -460,10 +823,14 @@ const addUpscaleModelHighResFix = ({ g, state, denoise, l2i, noise, seed }: AddH hrf_scale: params.hrfScale, hrf_upscale_model: params.hrfUpscaleModel, hrf_tile_controlnet_model: params.hrfTileControlNetModel, - hrf_structure: params.hrfStructure, + hrf_tile_control_weight: params.hrfTileControlWeight, hrf_tile_control_end: params.hrfTileControlEnd, hrf_tile_size: params.hrfTileSize, hrf_tile_overlap: params.hrfTileOverlap, + hrf_steps: params.hrfSteps ?? undefined, + hrf_model: params.hrfModel ?? undefined, + hrf_lora_mode: params.hrfLoraMode, + hrf_loras: getHrfLoRAMetadata(state), }); g.addEdgeToMetadata(spandrelAutoscale, 'width', 'width'); g.addEdgeToMetadata(spandrelAutoscale, 'height', 'height'); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts index 79a8521efba..15b29e120f4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts @@ -1,9 +1,18 @@ import type { RootState } from 'app/store/store'; import { getPrefixedId } from 'features/controlLayers/konva/util'; +import type { LoRA } from 'features/controlLayers/store/types'; import { zModelIdentifierField } from 'features/nodes/types/common'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Invocation, S } from 'services/api/types'; +type AddLoRAsOptions = { + loras?: LoRA[]; + metadataKey?: 'loras' | 'hrf_loras'; + idPrefix?: string; + extraPositiveConditioning?: Invocation<'compel'>[]; + extraNegativeConditioning?: Invocation<'compel'>[]; +}; + export const addLoRAs = ( state: RootState, g: Graph, @@ -12,12 +21,15 @@ export const addLoRAs = ( seamless: Invocation<'seamless'> | null, clipSkip: Invocation<'clip_skip'>, posCond: Invocation<'compel'>, - negCond: Invocation<'compel'> + negCond: Invocation<'compel'>, + options?: AddLoRAsOptions ): void => { - const enabledLoRAs = state.loras.loras.filter( + const enabledLoRAs = (options?.loras ?? state.loras.loras).filter( (l) => l.isEnabled && (l.model.base === 'sd-1' || l.model.base === 'sd-2') ); const loraCount = enabledLoRAs.length; + const positiveConditioning = [posCond, ...(options?.extraPositiveConditioning ?? [])]; + const negativeConditioning = [negCond, ...(options?.extraNegativeConditioning ?? [])]; if (loraCount === 0) { return; @@ -29,11 +41,11 @@ export const addLoRAs = ( // each LoRA to the UNet and CLIP. const loraCollector = g.addNode({ type: 'collect', - id: getPrefixedId('lora_collector'), + id: getPrefixedId(options?.idPrefix ? `${options.idPrefix}_lora_collector` : 'lora_collector'), }); const loraCollectionLoader = g.addNode({ type: 'lora_collection_loader', - id: getPrefixedId('lora_collection_loader'), + id: getPrefixedId(options?.idPrefix ? `${options.idPrefix}_lora_collection_loader` : 'lora_collection_loader'), }); g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras'); @@ -42,11 +54,17 @@ export const addLoRAs = ( g.addEdge(clipSkip, 'clip', loraCollectionLoader, 'clip'); // Reroute UNet & CLIP connections through the LoRA collection loader g.deleteEdgesTo(denoise, ['unet']); - g.deleteEdgesTo(posCond, ['clip']); - g.deleteEdgesTo(negCond, ['clip']); g.addEdge(loraCollectionLoader, 'unet', denoise, 'unet'); - g.addEdge(loraCollectionLoader, 'clip', posCond, 'clip'); - g.addEdge(loraCollectionLoader, 'clip', negCond, 'clip'); + + for (const cond of positiveConditioning) { + g.deleteEdgesTo(cond, ['clip']); + g.addEdge(loraCollectionLoader, 'clip', cond, 'clip'); + } + + for (const cond of negativeConditioning) { + g.deleteEdgesTo(cond, ['clip']); + g.addEdge(loraCollectionLoader, 'clip', cond, 'clip'); + } for (const lora of enabledLoRAs) { const { weight } = lora; @@ -54,7 +72,7 @@ export const addLoRAs = ( const loraSelector = g.addNode({ type: 'lora_selector', - id: getPrefixedId('lora_selector'), + id: getPrefixedId(options?.idPrefix ? `${options.idPrefix}_lora_selector` : 'lora_selector'), lora: parsedModel, weight, }); @@ -67,5 +85,9 @@ export const addLoRAs = ( g.addEdge(loraSelector, 'lora', loraCollector, 'item'); } - g.upsertMetadata({ loras: loraMetadata }); + if (options?.metadataKey === 'hrf_loras') { + g.upsertMetadata({ hrf_loras: loraMetadata }); + } else { + g.upsertMetadata({ loras: loraMetadata }); + } }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts index a38c9757cea..1fd7bc448c0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts @@ -1,9 +1,18 @@ import type { RootState } from 'app/store/store'; import { getPrefixedId } from 'features/controlLayers/konva/util'; +import type { LoRA } from 'features/controlLayers/store/types'; import { zModelIdentifierField } from 'features/nodes/types/common'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Invocation, S } from 'services/api/types'; +type AddSDXLLoRAsOptions = { + loras?: LoRA[]; + metadataKey?: 'loras' | 'hrf_loras'; + idPrefix?: string; + extraPositiveConditioning?: Invocation<'sdxl_compel_prompt'>[]; + extraNegativeConditioning?: Invocation<'sdxl_compel_prompt'>[]; +}; + export const addSDXLLoRAs = ( state: RootState, g: Graph, @@ -11,10 +20,13 @@ export const addSDXLLoRAs = ( modelLoader: Invocation<'sdxl_model_loader'>, seamless: Invocation<'seamless'> | null, posCond: Invocation<'sdxl_compel_prompt'>, - negCond: Invocation<'sdxl_compel_prompt'> + negCond: Invocation<'sdxl_compel_prompt'>, + options?: AddSDXLLoRAsOptions ): void => { - const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'sdxl'); + const enabledLoRAs = (options?.loras ?? state.loras.loras).filter((l) => l.isEnabled && l.model.base === 'sdxl'); const loraCount = enabledLoRAs.length; + const positiveConditioning = [posCond, ...(options?.extraPositiveConditioning ?? [])]; + const negativeConditioning = [negCond, ...(options?.extraNegativeConditioning ?? [])]; if (loraCount === 0) { return; @@ -25,12 +37,14 @@ export const addSDXLLoRAs = ( // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies // each LoRA to the UNet and CLIP. const loraCollector = g.addNode({ - id: getPrefixedId('lora_collector'), + id: getPrefixedId(options?.idPrefix ? `${options.idPrefix}_lora_collector` : 'lora_collector'), type: 'collect', }); const loraCollectionLoader = g.addNode({ type: 'sdxl_lora_collection_loader', - id: getPrefixedId('sdxl_lora_collection_loader'), + id: getPrefixedId( + options?.idPrefix ? `${options.idPrefix}_sdxl_lora_collection_loader` : 'sdxl_lora_collection_loader' + ), }); g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras'); @@ -40,13 +54,19 @@ export const addSDXLLoRAs = ( g.addEdge(modelLoader, 'clip2', loraCollectionLoader, 'clip2'); // Reroute UNet & CLIP connections through the LoRA collection loader g.deleteEdgesTo(denoise, ['unet']); - g.deleteEdgesTo(posCond, ['clip', 'clip2']); - g.deleteEdgesTo(negCond, ['clip', 'clip2']); g.addEdge(loraCollectionLoader, 'unet', denoise, 'unet'); - g.addEdge(loraCollectionLoader, 'clip', posCond, 'clip'); - g.addEdge(loraCollectionLoader, 'clip', negCond, 'clip'); - g.addEdge(loraCollectionLoader, 'clip2', posCond, 'clip2'); - g.addEdge(loraCollectionLoader, 'clip2', negCond, 'clip2'); + + for (const cond of positiveConditioning) { + g.deleteEdgesTo(cond, ['clip', 'clip2']); + g.addEdge(loraCollectionLoader, 'clip', cond, 'clip'); + g.addEdge(loraCollectionLoader, 'clip2', cond, 'clip2'); + } + + for (const cond of negativeConditioning) { + g.deleteEdgesTo(cond, ['clip', 'clip2']); + g.addEdge(loraCollectionLoader, 'clip', cond, 'clip'); + g.addEdge(loraCollectionLoader, 'clip2', cond, 'clip2'); + } for (const lora of enabledLoRAs) { const { weight } = lora; @@ -54,7 +74,7 @@ export const addSDXLLoRAs = ( const loraSelector = g.addNode({ type: 'lora_selector', - id: getPrefixedId('lora_selector'), + id: getPrefixedId(options?.idPrefix ? `${options.idPrefix}_lora_selector` : 'lora_selector'), lora: parsedModel, weight, }); @@ -67,5 +87,9 @@ export const addSDXLLoRAs = ( g.addEdge(loraSelector, 'lora', loraCollector, 'item'); } - g.upsertMetadata({ loras: loraMetadata }); + if (options?.metadataKey === 'hrf_loras') { + g.upsertMetadata({ hrf_loras: loraMetadata }); + } else { + g.upsertMetadata({ loras: loraMetadata }); + } }; diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts index cd06a052464..575f3a9612a 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts @@ -108,6 +108,9 @@ const baseParams = { hrfMethod: 'latent', hrfUpscaleModel: null, hrfTileControlNetModel: null, + hrfModel: null, + hrfLoraMode: 'reuse_generate', + hrfLoras: [], refinerModel: null, } as unknown as ParamsState; @@ -121,6 +124,7 @@ const buildGenerateTabArg = (overrides: { hrfMethod?: ParamsState['hrfMethod']; hrfUpscaleModel?: unknown; hrfTileControlNetModel?: unknown; + hrfModel?: unknown; refinerModel?: unknown; hasFlux2DiffusersVaeSource?: boolean; hasFlux2DiffusersQwen3Source?: boolean; @@ -135,6 +139,7 @@ const buildGenerateTabArg = (overrides: { hrfMethod: overrides.hrfMethod ?? 'latent', hrfUpscaleModel: overrides.hrfUpscaleModel ?? null, hrfTileControlNetModel: overrides.hrfTileControlNetModel ?? null, + hrfModel: overrides.hrfModel ?? null, refinerModel: overrides.refinerModel ?? null, } as unknown as ParamsState, refImages: baseRefImages, @@ -202,6 +207,9 @@ const hasHrfUpscaleModelMissingReason = (reasons: { content: string }[]) => const hasHrfTileControlNetMissingReason = (reasons: { content: string }[]) => reasons.some((r) => r.content.includes('hrfTileControlNetModelMissing')); +const hasHrfModelOverrideBaseMismatchReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('hrfModelOverrideBaseMismatch')); + // --- Tests --- describe('FLUX.2 Klein readiness checks – generate tab', () => { @@ -334,6 +342,51 @@ describe('High Resolution Fix readiness checks - generate tab', () => { expect(hasHrfUpscaleModelMissingReason(reasons)).toBe(false); expect(hasHrfTileControlNetMissingReason(reasons)).toBe(false); }); + + it('does not apply stale upscale-model-only readiness checks to latent HRF', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: sdxlModel, + hrfEnabled: true, + hrfMethod: 'latent', + hrfModel: { key: 'sd1', hash: 'h', name: 'SD1', base: 'sd-1', type: 'main' }, + hrfUpscaleModel: null, + hrfTileControlNetModel: null, + }) + ); + + expect(hasHrfUpscaleModelMissingReason(reasons)).toBe(false); + expect(hasHrfTileControlNetMissingReason(reasons)).toBe(false); + expect(hasHrfModelOverrideBaseMismatchReason(reasons)).toBe(false); + }); + + it('errors when dedicated HRF model base differs from the Generate model base', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: sdxlModel, + hrfEnabled: true, + hrfMethod: 'upscale_model', + hrfUpscaleModel: upscaleModel, + hrfTileControlNetModel: tileControlNetModel, + hrfModel: { key: 'sd1', hash: 'h', name: 'SD1', base: 'sd-1', type: 'main' }, + }) + ); + expect(hasHrfModelOverrideBaseMismatchReason(reasons)).toBe(true); + }); + + it('validates Tile ControlNet against the dedicated HRF model base', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: sdxlModel, + hrfEnabled: true, + hrfMethod: 'upscale_model', + hrfUpscaleModel: upscaleModel, + hrfTileControlNetModel: { ...tileControlNetModel, base: 'sd-1' }, + hrfModel: { key: 'hrf-sdxl', hash: 'h', name: 'HRF SDXL', base: 'sdxl', type: 'main' }, + }) + ); + expect(hasHrfTileControlNetMissingReason(reasons)).toBe(true); + }); }); describe('FLUX.2 Klein readiness checks – canvas tab', () => { diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index c81995a7b78..e5c1452bae2 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -366,15 +366,26 @@ export const getReasonsWhyCannotEnqueueGenerateTab = (arg: { reasons.push({ content: i18n.t('parameters.invoke.hrfRefinerUnsupported') }); } if (params.hrfMethod === 'upscale_model') { + const hrfBase = + params.hrfModel?.base ?? (model && !isExternalApiModelConfig(model) ? model.base : params.model?.base); if (model && !isExternalApiModelConfig(model) && !['sd-1', 'sdxl'].includes(model.base)) { reasons.push({ content: i18n.t('parameters.invoke.hrfUpscaleModelBaseUnsupported') }); } + if (params.hrfModel) { + if (params.hrfModel.base === 'external') { + reasons.push({ content: i18n.t('parameters.invoke.hrfModelOverrideExternalUnsupported') }); + } else if (!['sd-1', 'sdxl'].includes(params.hrfModel.base)) { + reasons.push({ content: i18n.t('parameters.invoke.hrfModelOverrideBaseUnsupported') }); + } else if (model && !isExternalApiModelConfig(model) && params.hrfModel.base !== model.base) { + reasons.push({ content: i18n.t('parameters.invoke.hrfModelOverrideBaseMismatch') }); + } + } if (!params.hrfUpscaleModel) { reasons.push({ content: i18n.t('parameters.invoke.hrfUpscaleModelMissing') }); } if (!params.hrfTileControlNetModel) { reasons.push({ content: i18n.t('parameters.invoke.hrfTileControlNetModelMissing') }); - } else if (model && !isExternalApiModelConfig(model) && params.hrfTileControlNetModel.base !== model.base) { + } else if (hrfBase && params.hrfTileControlNetModel.base !== hrfBase) { reasons.push({ content: i18n.t('parameters.invoke.hrfTileControlNetModelMissing') }); } } diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx index 211d5c8990b..b9bb71536f7 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx @@ -1,17 +1,23 @@ -import type { ComboboxOnChange } from '@invoke-ai/ui-library'; +import type { ComboboxOnChange, FormLabelProps } from '@invoke-ai/ui-library'; import { Box, Button, ButtonGroup, + Card, + CardBody, + CardHeader, Combobox, CompositeNumberInput, CompositeSlider, + Expander, Flex, FormControl, FormControlGroup, FormLabel, + IconButton, StandaloneAccordion, Switch, + Text, Tooltip, } from '@invoke-ai/ui-library'; import { createSelector } from '@reduxjs/toolkit'; @@ -20,45 +26,76 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { useModelCombobox } from 'common/hooks/useModelCombobox'; +import { DEFAULT_LORA_WEIGHT_CONFIG } from 'features/controlLayers/store/lorasSlice'; import { + buildSelectHrfLoRA, + hrfLoraAdded, + hrfLoraDeleted, + hrfLoraIsEnabledChanged, + hrfLoraWeightChanged, selectBase, selectHrfEnabled, selectHrfFinalDimensions, selectHrfLatentInterpolationMode, + selectHrfLoraMode, + selectHrfLoras, selectHrfMethod, + selectHrfModel, selectHrfScale, + selectHrfSteps, selectHrfStrength, - selectHrfStructure, selectHrfTileControlEnd, selectHrfTileControlNetModel, + selectHrfTileControlWeight, selectHrfTileOverlap, selectHrfTileSize, selectHrfUpscaleModel, selectIsRefinerModelSelected, selectModelSupportsHrf, + selectSteps, setHrfEnabled, setHrfLatentInterpolationMode, + setHrfLoraMode, setHrfMethod, + setHrfModel, setHrfScale, + setHrfSteps, setHrfStrength, - setHrfStructure, setHrfTileControlEnd, setHrfTileControlNetModel, + setHrfTileControlWeight, setHrfTileOverlap, setHrfTileSize, setHrfUpscaleModel, } from 'features/controlLayers/store/paramsSlice'; +import type { LoRA } from 'features/controlLayers/store/types'; import { zHrfLatentInterpolationMode, zHrfMethod } from 'features/controlLayers/store/types'; +import { CONSTRAINTS as STEPS_CONSTRAINTS } from 'features/parameters/components/Core/ParamSteps'; import { ModelPicker } from 'features/parameters/components/ModelPicker'; +import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import type { ChangeEvent } from 'react'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models'; -import { useControlNetModels, useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType'; +import { PiTrashSimpleBold } from 'react-icons/pi'; +import { + modelConfigsAdapterSelectors, + selectModelConfigsQuery, + useGetModelConfigQuery, +} from 'services/api/endpoints/models'; +import { + useControlNetModels, + useLoRAModels, + useMainModels, + useSpandrelImageToImageModels, +} from 'services/api/hooks/modelsByType'; import { type ControlNetModelConfig, isControlNetModelConfig, + isExternalApiModelConfig, + isMainOrExternalModelConfig, + type LoRAModelConfig, + type MainOrExternalModelConfig, type SpandrelImageToImageModelConfig, } from 'services/api/types'; @@ -82,14 +119,14 @@ const STRENGTH_CONSTRAINTS = { fineStep: 0.01, }; -const STRUCTURE_CONSTRAINTS = { - initial: 0, - sliderMin: -10, - sliderMax: 10, - numberInputMin: -10, - numberInputMax: 10, - coarseStep: 1, - fineStep: 1, +const TILE_CONTROL_WEIGHT_CONSTRAINTS = { + initial: 0.625, + sliderMin: 0, + sliderMax: 1.5, + numberInputMin: 0, + numberInputMax: 2, + coarseStep: 0.025, + fineStep: 0.005, }; const TILE_CONTROL_END_CONSTRAINTS = { @@ -122,6 +159,24 @@ const TILE_OVERLAP_CONSTRAINTS = { fineStep: 8, }; +const formLabelProps: FormLabelProps = { + m: 0, + w: '10.5rem', + minW: '10.5rem', + maxW: '10.5rem', + flexShrink: 0, + whiteSpace: 'normal', + lineHeight: 1.2, + overflowWrap: 'break-word', +}; + +const formControlProps = { + alignItems: 'center', + gap: 3, + minW: 0, + w: 'full', +}; + const selectHrfTileControlNetModelConfig = createSelector( selectModelConfigsQuery, selectHrfTileControlNetModel, @@ -137,33 +192,34 @@ const selectHrfTileControlNetModelConfig = createSelector( } ); +const selectHrfModelConfig = createSelector(selectModelConfigsQuery, selectHrfModel, (modelConfigs, model) => { + if (!modelConfigs.data || !model) { + return null; + } + const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, model.key); + if (!modelConfig || !isMainOrExternalModelConfig(modelConfig) || isExternalApiModelConfig(modelConfig)) { + return null; + } + return modelConfig; +}); + +const selectHrfLoRAIds = createMemoizedSelector(selectHrfLoras, (loras) => loras.map(({ id }) => id)); +const selectHrfLoRAModelKeys = createMemoizedSelector(selectHrfLoras, (loras) => loras.map(({ model }) => model.key)); + const selectBadges = createMemoizedSelector( - [ - selectHrfEnabled, - selectHrfMethod, - selectHrfScale, - selectHrfStrength, - selectHrfFinalDimensions, - selectHrfUpscaleModel, - ], - (enabled, method, scale, strength, finalDimensions, upscaleModel) => { + [selectHrfEnabled, selectHrfMethod, selectHrfScale, selectHrfStrength, selectHrfFinalDimensions], + (enabled, method, scale, strength, finalDimensions) => { if (!enabled) { return EMPTY_ARRAY; } const methodBadge = method === 'upscale_model' ? 'Model' : 'Latent'; - const badges = [ + return [ methodBadge, `${scale}x`, `${Math.round(strength * 100)}%`, `${finalDimensions.width}x${finalDimensions.height}`, ]; - - if (method === 'upscale_model' && upscaleModel) { - badges.push(upscaleModel.name); - } - - return badges; } ); @@ -180,7 +236,7 @@ const ParamHrfEnabled = memo(() => { ); return ( - + {t('hrf.enableHrf')} @@ -209,11 +265,16 @@ const ParamHrfMethod = memo(() => { {t('hrf.upscaleMethod')} - - - @@ -376,13 +437,13 @@ const ParamHrfUpscaleModel = memo(() => { }); return ( - + {t('upscaling.upscaleModel')} - + { const dispatch = useAppDispatch(); const { t } = useTranslation(); const tileControlNetModel = useAppSelector(selectHrfTileControlNetModelConfig); - const currentBaseModel = useAppSelector(selectBase); + const generateBaseModel = useAppSelector(selectBase); + const hrfModel = useAppSelector(selectHrfModel); + const currentBaseModel = hrfModel?.base ?? generateBaseModel; const [modelConfigs, { isLoading }] = useControlNetModels(); const onChange = useCallback( @@ -461,47 +524,51 @@ const ParamHrfTileControlNetModel = memo(() => { ParamHrfTileControlNetModel.displayName = 'ParamHrfTileControlNetModel'; -const ParamHrfStructure = memo(() => { +const ParamHrfTileControlWeight = memo(() => { const dispatch = useAppDispatch(); - const structure = useAppSelector(selectHrfStructure); + const tileControlWeight = useAppSelector(selectHrfTileControlWeight); const { t } = useTranslation(); const onChange = useCallback( (v: number) => { - dispatch(setHrfStructure(v)); + dispatch(setHrfTileControlWeight(v)); }, [dispatch] ); return ( - - {t('upscaling.structure')} + + {t('hrf.tileControlWeight')} ); }); -ParamHrfStructure.displayName = 'ParamHrfStructure'; +ParamHrfTileControlWeight.displayName = 'ParamHrfTileControlWeight'; const ParamHrfTileControlEnd = memo(() => { const dispatch = useAppDispatch(); @@ -637,6 +704,314 @@ const ParamHrfTileOverlap = memo(() => { ParamHrfTileOverlap.displayName = 'ParamHrfTileOverlap'; +const ParamHrfSteps = memo(() => { + const dispatch = useAppDispatch(); + const hrfSteps = useAppSelector(selectHrfSteps); + const generateSteps = useAppSelector(selectSteps); + const { t } = useTranslation(); + + const onToggle = useCallback( + (event: ChangeEvent) => { + dispatch(setHrfSteps(event.target.checked ? generateSteps : null)); + }, + [dispatch, generateSteps] + ); + + const onChange = useCallback( + (v: number) => { + dispatch(setHrfSteps(v)); + }, + [dispatch] + ); + + const isCustom = hrfSteps !== null; + + return ( + + + {t('hrf.steps')} + + + + + + + ); +}); + +ParamHrfSteps.displayName = 'ParamHrfSteps'; + +const ParamHrfModel = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const selectedModelConfig = useAppSelector(selectHrfModelConfig); + const currentBaseModel = useAppSelector(selectBase); + + const filter = useCallback( + (model: MainOrExternalModelConfig) => { + return ( + !isExternalApiModelConfig(model) && model.base === currentBaseModel && ['sd-1', 'sdxl'].includes(model.base) + ); + }, + [currentBaseModel] + ); + const [modelConfigs, { isLoading }] = useMainModels(filter); + + const onChange = useCallback( + (model: MainOrExternalModelConfig | null) => { + dispatch(setHrfModel(model && !isExternalApiModelConfig(model) ? model : null)); + }, + [dispatch] + ); + + return ( + + + {t('hrf.model')} + + + + ); +}); + +ParamHrfModel.displayName = 'ParamHrfModel'; + +const ParamHrfLoraMode = memo(() => { + const dispatch = useAppDispatch(); + const mode = useAppSelector(selectHrfLoraMode); + const { t } = useTranslation(); + + const onClickReuseGenerate = useCallback(() => { + dispatch(setHrfLoraMode('reuse_generate')); + }, [dispatch]); + + const onClickNone = useCallback(() => { + dispatch(setHrfLoraMode('none')); + }, [dispatch]); + + const onClickDedicated = useCallback(() => { + dispatch(setHrfLoraMode('dedicated')); + }, [dispatch]); + + return ( + + + {t('hrf.loraMode')} + + + + + + + + ); +}); + +ParamHrfLoraMode.displayName = 'ParamHrfLoraMode'; + +const ParamHrfLoraSelect = memo(() => { + const dispatch = useAppDispatch(); + const [modelConfigs, { isLoading }] = useLoRAModels(); + const { t } = useTranslation(); + const addedLoRAModelKeys = useAppSelector(selectHrfLoRAModelKeys); + const currentBaseModel = useAppSelector(selectBase); + const hrfModel = useAppSelector(selectHrfModel); + const hrfBase = hrfModel?.base ?? currentBaseModel; + + const compatibleLoRAs = useMemo(() => { + if (!hrfBase) { + return EMPTY_ARRAY; + } + return modelConfigs.filter((model) => model.base === hrfBase); + }, [hrfBase, modelConfigs]); + + const getIsDisabled = useCallback( + (model: LoRAModelConfig): boolean => { + return addedLoRAModelKeys.includes(model.key); + }, + [addedLoRAModelKeys] + ); + + const onChange = useCallback( + (model: LoRAModelConfig | null) => { + if (!model) { + return; + } + dispatch(hrfLoraAdded({ model })); + }, + [dispatch] + ); + + const placeholder = useMemo(() => { + if (isLoading) { + return t('common.loading'); + } + if (compatibleLoRAs.length === 0) { + return hrfBase ? t('models.noCompatibleLoRAs') : t('models.selectModel'); + } + return t('hrf.addDedicatedLora'); + }, [compatibleLoRAs.length, hrfBase, isLoading, t]); + + return ( + + + {t('hrf.dedicatedLoras')} + + + + ); +}); + +ParamHrfLoraSelect.displayName = 'ParamHrfLoraSelect'; + +const HrfLoRAList = memo(() => { + const ids = useAppSelector(selectHrfLoRAIds); + + if (!ids.length) { + return null; + } + + return ( + + {ids.map((id) => ( + + ))} + + ); +}); + +HrfLoRAList.displayName = 'HrfLoRAList'; + +const HrfLoRACard = memo((props: { id: string }) => { + const selectLoRA = useMemo(() => buildSelectHrfLoRA(props.id), [props.id]); + const lora = useAppSelector(selectLoRA); + + if (!lora) { + return null; + } + return ; +}); + +HrfLoRACard.displayName = 'HrfLoRACard'; + +const HrfLoRAContent = memo(({ lora }: { lora: LoRA }) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const { data: loraConfig } = useGetModelConfigQuery(lora.model.key); + + const onChange = useCallback( + (v: number) => { + dispatch(hrfLoraWeightChanged({ id: lora.id, weight: v })); + }, + [dispatch, lora.id] + ); + + const onToggle = useCallback(() => { + dispatch(hrfLoraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled })); + }, [dispatch, lora.id, lora.isEnabled]); + + const onRemove = useCallback(() => { + dispatch(hrfLoraDeleted({ id: lora.id })); + }, [dispatch, lora.id]); + + return ( + + + + + {loraConfig?.name ?? lora.model.key.substring(0, 8)} + + + + } + /> + + + + + + + + + + + ); +}); + +HrfLoRAContent.displayName = 'HrfLoRAContent'; + export const HighResFixSettingsAccordion = memo(() => { const { t } = useTranslation(); const badges = useAppSelector(selectBadges); @@ -644,10 +1019,15 @@ export const HighResFixSettingsAccordion = memo(() => { const method = useAppSelector(selectHrfMethod); const modelSupportsHrf = useAppSelector(selectModelSupportsHrf); const isRefinerModelSelected = useAppSelector(selectIsRefinerModelSelected); + const hrfLoraMode = useAppSelector(selectHrfLoraMode); const { isOpen, onToggle } = useStandaloneAccordionToggle({ id: 'high-res-fix-settings-generate-tab', defaultIsOpen: false, }); + const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({ + id: 'high-res-fix-settings-generate-tab-advanced', + defaultIsOpen: false, + }); const parsedMethod = zHrfMethod.parse(method); @@ -657,28 +1037,50 @@ export const HighResFixSettingsAccordion = memo(() => { return ( - - - {enabled && ( - - - - - {parsedMethod === 'latent' ? ( - - ) : ( + + + + {enabled && ( + <> + + + + {parsedMethod === 'upscale_model' && ( + <> + + + + )} + + )} + + + {enabled && ( + + + + {parsedMethod === 'latent' && } + {parsedMethod === 'upscale_model' && ( + <> + + + + + + + + + )} + + {parsedMethod === 'upscale_model' && hrfLoraMode === 'dedicated' && ( <> - - - - - - + + )} - - )} - + + + )} ); }); diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 9655f861f15..df169605331 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -7717,10 +7717,16 @@ export type components = { hrf_tile_controlnet_model?: components["schemas"]["ModelIdentifierField"] | null; /** * Hrf Structure - * @description The high resolution fix tile ControlNet structure value. + * @description Legacy high resolution fix tile ControlNet structure value. * @default null */ hrf_structure?: number | null; + /** + * Hrf Tile Control Weight + * @description The high resolution fix tile ControlNet control weight. + * @default null + */ + hrf_tile_control_weight?: number | null; /** * Hrf Tile Control End * @description The high resolution fix tile ControlNet end step percentage. @@ -7739,6 +7745,30 @@ export type components = { * @default null */ hrf_tile_overlap?: number | null; + /** + * Hrf Steps + * @description The number of steps used for the high resolution fix refinement pass. + * @default null + */ + hrf_steps?: number | null; + /** + * Hrf Model + * @description The optional model override used for the high resolution fix refinement pass. + * @default null + */ + hrf_model?: components["schemas"]["ModelIdentifierField"] | null; + /** + * Hrf Lora Mode + * @description The LoRA mode used for the high resolution fix refinement pass. + * @default null + */ + hrf_lora_mode?: string | null; + /** + * Hrf Loras + * @description The dedicated LoRAs used for the high resolution fix refinement pass. + * @default null + */ + hrf_loras?: components["schemas"]["LoRAMetadataField"][] | null; /** * Positive Style Prompt * @description The positive style prompt parameter From b5120b4044d17cb01758de028db3258d0e7dd76b Mon Sep 17 00:00:00 2001 From: Astra orion <13394741+AsuraAce@users.noreply.github.com> Date: Tue, 5 May 2026 18:02:40 +0200 Subject: [PATCH 5/9] Polish Generate Upscale metadata recall --- invokeai/frontend/web/public/locales/en.json | 5 +- .../ImageMetadataActions.test.tsx | 21 ++++++ .../ImageMetadataActions.tsx | 6 ++ .../src/features/metadata/parsing.test.tsx | 44 +++++++++++++ .../web/src/features/metadata/parsing.tsx | 35 +++++++++- .../graph/generation/addHighResFix.test.ts | 15 ++++- .../util/graph/generation/addHighResFix.ts | 10 ++- .../HighResFixSettingsAccordion.tsx | 66 +++++++++---------- .../ParametersPanelGenerate.tsx | 2 +- 9 files changed, 155 insertions(+), 49 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 6dbccdb247a..f55faef6c6a 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -350,8 +350,9 @@ "options_withCount_other": "{{count}} options" }, "hrf": { - "hrf": "Hi-Res Fix", - "enableHrf": "Enable High Resolution Fix", + "hrf": "Upscale", + "enableHrf": "Enable", + "enableUpscale": "Enable Upscale", "scale": "Scale", "strength": "Denoise Strength", "upscaleMethod": "Upscale Method", diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.test.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.test.tsx index c7433153825..362fc05d511 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.test.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.test.tsx @@ -21,4 +21,25 @@ describe('ImageMetadataActions', () => { expect(handlers).toContain(ImageMetadataHandlers.QwenImageQuantization); expect(handlers).toContain(ImageMetadataHandlers.QwenImageShift); }); + + it('includes HRF refinement metadata handlers in the recall parameters UI', () => { + const element = (ImageMetadataActions as unknown as { type: (props: { metadata: unknown }) => unknown }).type({ + metadata: { model: { key: 'test' } }, + }) as { + props: { + children: Array<{ props?: { handler?: unknown } }>; + }; + }; + + const handlers = element.props.children + .map((child) => child.props?.handler) + .filter((handler): handler is unknown => handler !== undefined); + + expect(handlers).toContain(ImageMetadataHandlers.HrfTileControlWeight); + expect(handlers).toContain(ImageMetadataHandlers.HrfTileControlEnd); + expect(handlers).toContain(ImageMetadataHandlers.HrfSteps); + expect(handlers).toContain(ImageMetadataHandlers.HrfModel); + expect(handlers).toContain(ImageMetadataHandlers.HrfLoraMode); + expect(handlers).toContain(ImageMetadataHandlers.HrfLoRAs); + }); }); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 579af481627..d3bc3ca0da9 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -57,8 +57,14 @@ export const ImageMetadataActions = memo((props: Props) => { + + + + + + diff --git a/invokeai/frontend/web/src/features/metadata/parsing.test.tsx b/invokeai/frontend/web/src/features/metadata/parsing.test.tsx index dbe8c3af80a..1a775fb95a2 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.test.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.test.tsx @@ -181,6 +181,50 @@ describe('ImageMetadataHandlers — Klein recall gating', () => { }); describe('HRF LoRAs', () => { + it('parses dedicated HRF LoRAs against the image metadata model base', async () => { + currentBase = 'sd-1'; + const mainModel = fakeModel('main', 'sdxl'); + const hrfLora = fakeModel('lora', 'sdxl'); + resolvedModels = { + [hrfLora.key]: hrfLora, + }; + const store = makeStore(); + + const parsed = await ImageMetadataHandlers.HrfLoRAs.parse( + { + model: mainModel, + hrf_loras: [{ model: hrfLora, weight: 0.6 }], + }, + store + ); + + expect(parsed).toEqual([ + expect.objectContaining({ model: expect.objectContaining({ key: hrfLora.key }), weight: 0.6 }), + ]); + }); + + it('filters HRF LoRAs that do not match the metadata HRF model base', async () => { + currentBase = 'sdxl'; + const mainModel = fakeModel('main', 'sdxl'); + const hrfModel = fakeModel('main', 'sd-1'); + const hrfLora = fakeModel('lora', 'sdxl'); + resolvedModels = { + [hrfLora.key]: hrfLora, + }; + const store = makeStore(); + + const parsed = await ImageMetadataHandlers.HrfLoRAs.parse( + { + model: mainModel, + hrf_model: hrfModel, + hrf_loras: [{ model: hrfLora, weight: 0.6 }], + }, + store + ); + + expect(parsed).toEqual([]); + }); + it('recalls dedicated HRF LoRAs after the recalled main model changes base', async () => { currentBase = 'sd-1'; const mainModel = fakeModel('main', 'sdxl'); diff --git a/invokeai/frontend/web/src/features/metadata/parsing.tsx b/invokeai/frontend/web/src/features/metadata/parsing.tsx index 96d4636d46b..ca63d653c39 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.tsx @@ -1537,13 +1537,14 @@ const HrfLoRAs: CollectionMetadataHandler = { assert(isArray(rawArray)); const loras: LoRA[] = []; + const effectiveHrfBase = getHrfLoRACompatibilityBase(metadata, store); for (const rawItem of rawArray) { try { const rawIdentifier = getProperty(rawItem, 'model'); const identifier = await parseModelIdentifier(rawIdentifier, store, 'lora'); assert(identifier.type === 'lora'); - assert(isCompatibleWithMainModel(identifier, store)); + assert(isCompatibleWithBase(identifier, effectiveHrfBase)); const weight = getProperty(rawItem, 'weight'); @@ -2226,13 +2227,43 @@ const parseModelIdentifier = async (raw: unknown, store: AppStore, type: ModelTy }; const isCompatibleWithMainModel = (candidate: ModelIdentifierField, store: AppStore) => { - const base = selectBase(store.getState()); + return isCompatibleWithBase(candidate, selectBase(store.getState())); +}; + +const isCompatibleWithBase = (candidate: ModelIdentifierField, base: string | null | undefined) => { if (!base) { return true; } return candidate.base === base; }; +const getMetadataModelBase = (metadata: unknown, path: string): string | null => { + const raw = getProperty(metadata, path); + if (!raw) { + return null; + } + + const identifierResult = zModelIdentifierField.safeParse(raw); + if (identifierResult.success) { + return identifierResult.data.base; + } + + const oldIdentifierResult = zModelIdentifier.safeParse(raw); + if (oldIdentifierResult.success) { + return oldIdentifierResult.data.base_model; + } + + return null; +}; + +const getHrfLoRACompatibilityBase = (metadata: unknown, store: AppStore): string | null | undefined => { + return ( + getMetadataModelBase(metadata, 'hrf_model') ?? + getMetadataModelBase(metadata, 'model') ?? + selectBase(store.getState()) + ); +}; + const throwIfImageDoesNotExist = async (name: string, store: AppStore): Promise => { try { await store.dispatch(imagesApi.endpoints.getImageDTO.initiate(name, { subscribe: false })).unwrap(); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts index 7cb5f48131f..511d25fe41c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts @@ -216,8 +216,8 @@ describe('addHighResFix', () => { destination: { node_id: 'l2i', field: 'latents' }, }); expect(g.getMetadataNode()).toMatchObject({ - width: 1024, - height: 1024, + width: 512, + height: 512, hrf_enabled: true, hrf_method: 'latent', hrf_strength: 0.35, @@ -451,7 +451,18 @@ describe('addHighResFix', () => { source: { node_id: tiledDenoise.id, field: 'latents' }, destination: { node_id: 'l2i', field: 'latents' }, }); + const metadataNode = g.getMetadataNode(); + expect(graph.edges).not.toContainEqual({ + source: { node_id: spandrel.id, field: 'width' }, + destination: { node_id: metadataNode.id, field: 'width' }, + }); + expect(graph.edges).not.toContainEqual({ + source: { node_id: spandrel.id, field: 'height' }, + destination: { node_id: metadataNode.id, field: 'height' }, + }); expect(g.getMetadataNode()).toMatchObject({ + width: 512, + height: 512, hrf_enabled: true, hrf_method: 'upscale_model', hrf_upscale_model: { key: 'upscale' }, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts index f941757d0f7..1b9d8acb0db 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts @@ -667,8 +667,8 @@ const addLatentHighResFix = ({ } g.upsertMetadata({ - width: finalDimensions.width, - height: finalDimensions.height, + width: params.dimensions.width, + height: params.dimensions.height, hrf_enabled: true, hrf_method: 'latent', hrf_strength: params.hrfStrength, @@ -815,8 +815,8 @@ const addUpscaleModelHighResFix = ({ g, state, denoise, l2i, noise, seed }: AddH }); g.upsertMetadata({ - width: finalDimensions.width, - height: finalDimensions.height, + width: params.dimensions.width, + height: params.dimensions.height, hrf_enabled: true, hrf_method: 'upscale_model', hrf_strength: params.hrfStrength, @@ -832,8 +832,6 @@ const addUpscaleModelHighResFix = ({ g, state, denoise, l2i, noise, seed }: AddH hrf_lora_mode: params.hrfLoraMode, hrf_loras: getHrfLoRAMetadata(state), }); - g.addEdgeToMetadata(spandrelAutoscale, 'width', 'width'); - g.addEdgeToMetadata(spandrelAutoscale, 'height', 'height'); return l2i; }; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx index b9bb71536f7..1aa01d4c170 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx @@ -240,7 +240,7 @@ const ParamHrfEnabled = memo(() => { {t('hrf.enableHrf')} - + ); }); @@ -785,6 +785,7 @@ const ParamHrfModel = memo(() => { modelConfigs={modelConfigs} selectedModelConfig={selectedModelConfig ?? undefined} onChange={onChange} + grouped allowEmpty placeholder={t('hrf.reuseGenerateModel')} noOptionsText={currentBaseModel ? t('hrf.noCompatibleModels') : t('models.selectModel')} @@ -1015,7 +1016,6 @@ HrfLoRAContent.displayName = 'HrfLoRAContent'; export const HighResFixSettingsAccordion = memo(() => { const { t } = useTranslation(); const badges = useAppSelector(selectBadges); - const enabled = useAppSelector(selectHrfEnabled); const method = useAppSelector(selectHrfMethod); const modelSupportsHrf = useAppSelector(selectModelSupportsHrf); const isRefinerModelSelected = useAppSelector(selectIsRefinerModelSelected); @@ -1037,50 +1037,44 @@ export const HighResFixSettingsAccordion = memo(() => { return ( - + - {enabled && ( + + + + {parsedMethod === 'upscale_model' && ( <> - - - - {parsedMethod === 'upscale_model' && ( - <> - - - - )} + + )} - {enabled && ( - - - - {parsedMethod === 'latent' && } - {parsedMethod === 'upscale_model' && ( - <> - - - - - - - - - )} - - {parsedMethod === 'upscale_model' && hrfLoraMode === 'dedicated' && ( + + + + {parsedMethod === 'latent' && } + {parsedMethod === 'upscale_model' && ( <> - - + + + + + + + )} - - - )} + + {parsedMethod === 'upscale_model' && hrfLoraMode === 'dedicated' && ( + <> + + + + )} + + ); }); diff --git a/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelGenerate.tsx b/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelGenerate.tsx index 1b1d4bfa492..8a14b55413d 100644 --- a/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelGenerate.tsx +++ b/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelGenerate.tsx @@ -44,8 +44,8 @@ export const ParametersPanelGenerate = memo(() => { - + {isSDXL && } {!isCogview4 && !isExternal && } {isExternal && } From 04e79a2a3d34552267591bc47bebffeaf2838933 Mon Sep 17 00:00:00 2001 From: Astra orion <13394741+AsuraAce@users.noreply.github.com> Date: Wed, 6 May 2026 13:38:48 +0200 Subject: [PATCH 6/9] fix(ui): refine Generate Upscale gating --- .../controlLayers/store/paramsSlice.test.ts | 141 +++++++++++++++- .../controlLayers/store/paramsSlice.ts | 31 ++-- .../graph/generation/addHighResFix.test.ts | 18 +-- .../util/graph/generation/addHighResFix.ts | 4 +- .../graph/generation/buildAnimaGraph.test.ts | 1 + .../graph/generation/buildFLUXGraph.test.ts | 1 + .../generation/buildQwenImageGraph.test.ts | 1 + .../features/queue/store/readiness.test.ts | 18 ++- .../web/src/features/queue/store/readiness.ts | 12 +- .../HighResFixSettingsAccordion.tsx | 150 +++++++++++------- 10 files changed, 263 insertions(+), 114 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts index 2e057fd0d25..c8172322014 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts @@ -138,14 +138,41 @@ describe('paramsSlice selectors for external models', () => { expect(selectModelSupportsDimensions.resultFunc(model, config)).toBe(true); }); - it('returns false for HRF support on external models', () => { - const config = createExternalConfig({ - modes: ['txt2img'], - supports_reference_images: false, + it('supports HRF only for SD1.5 and SDXL models', () => { + expect( + selectModelSupportsHrf.resultFunc({ + key: 'sd1', + hash: 'h', + name: 'SD1', + base: 'sd-1', + type: 'main', + }) + ).toBe(true); + expect( + selectModelSupportsHrf.resultFunc({ + key: 'sdxl', + hash: 'h', + name: 'SDXL', + base: 'sdxl', + type: 'main', + }) + ).toBe(true); + + const unsupportedBases = ['sd-2', 'sd-3', 'flux', 'flux2', 'anima', 'cogview4', 'qwen-image', 'z-image'] as const; + unsupportedBases.forEach((base) => { + expect( + selectModelSupportsHrf.resultFunc({ + key: base, + hash: 'h', + name: base, + base, + type: 'main', + }) + ).toBe(false); }); - const model = buildExternalModelIdentifier(config); - expect(selectModelSupportsHrf.resultFunc(model)).toBe(false); + const config = createExternalConfig({ modes: ['txt2img'], supports_reference_images: false }); + expect(selectModelSupportsHrf.resultFunc(buildExternalModelIdentifier(config))).toBe(false); }); }); @@ -155,6 +182,15 @@ describe('paramsSlice HRF selectors', () => { }); it('supports upscale-model HRF only for SD1.5 and SDXL models', () => { + expect( + selectModelSupportsHrfUpscaleModel.resultFunc({ + key: 'sd1', + hash: 'h', + name: 'SD1', + base: 'sd-1', + type: 'main', + }) + ).toBe(true); expect( selectModelSupportsHrfUpscaleModel.resultFunc({ key: 'sdxl', @@ -229,6 +265,99 @@ describe('paramsSlice HRF reducers', () => { expect(nextState.hrfLoras).toHaveLength(1); expect(nextState.hrfLoras[0]?.model.key).toBe('sd1-lora'); }); + + it('preserves HRF settings when switching through an unsupported base and back', () => { + const state = getInitialParamsState(); + const sdxlModel = { key: 'sdxl', hash: 'h', name: 'SDXL', base: 'sdxl', type: 'main' } as const; + const externalModel = buildExternalModelIdentifier( + createExternalConfig({ modes: ['txt2img'], supports_reference_images: false }) + ); + state.model = sdxlModel; + state.hrfEnabled = true; + state.hrfMethod = 'upscale_model'; + state.hrfModel = { key: 'hrf-sdxl', hash: 'h', name: 'HRF SDXL', base: 'sdxl', type: 'main' }; + state.hrfTileControlNetModel = { + key: 'tile-sdxl', + hash: 'h', + name: 'Tile SDXL', + base: 'sdxl', + type: 'controlnet', + }; + state.hrfLoraMode = 'dedicated'; + state.hrfLoras = [ + { + id: 'sdxl-lora', + isEnabled: true, + model: { key: 'sdxl-lora', hash: 'h', name: 'SDXL LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ] as ParamsState['hrfLoras']; + + const unsupportedState = paramsSliceConfig.slice.reducer( + state, + modelChanged({ model: externalModel, previousModel: sdxlModel }) + ); + const returnedState = paramsSliceConfig.slice.reducer( + unsupportedState, + modelChanged({ model: sdxlModel, previousModel: externalModel }) + ); + + expect(unsupportedState.hrfEnabled).toBe(true); + expect(unsupportedState.hrfModel?.key).toBe('hrf-sdxl'); + expect(unsupportedState.hrfTileControlNetModel?.key).toBe('tile-sdxl'); + expect(unsupportedState.hrfLoraMode).toBe('dedicated'); + expect(unsupportedState.hrfLoras).toHaveLength(1); + expect(returnedState.hrfEnabled).toBe(true); + expect(returnedState.hrfModel?.key).toBe('hrf-sdxl'); + expect(returnedState.hrfTileControlNetModel?.key).toBe('tile-sdxl'); + expect(returnedState.hrfLoraMode).toBe('dedicated'); + expect(returnedState.hrfLoras[0]?.model.key).toBe('sdxl-lora'); + }); + + it('cleans preserved HRF settings when returning to a different supported base', () => { + const state = getInitialParamsState(); + const sdxlModel = { key: 'sdxl', hash: 'h', name: 'SDXL', base: 'sdxl', type: 'main' } as const; + const sd1Model = { key: 'sd1', hash: 'h', name: 'SD1', base: 'sd-1', type: 'main' } as const; + const fluxModel = { key: 'flux', hash: 'h', name: 'FLUX', base: 'flux', type: 'main' } as const; + state.model = sdxlModel; + state.hrfEnabled = true; + state.hrfMethod = 'upscale_model'; + state.hrfModel = { key: 'hrf-sdxl', hash: 'h', name: 'HRF SDXL', base: 'sdxl', type: 'main' }; + state.hrfTileControlNetModel = { + key: 'tile-sdxl', + hash: 'h', + name: 'Tile SDXL', + base: 'sdxl', + type: 'controlnet', + }; + state.hrfLoraMode = 'dedicated'; + state.hrfLoras = [ + { + id: 'sdxl-lora', + isEnabled: true, + model: { key: 'sdxl-lora', hash: 'h', name: 'SDXL LoRA', base: 'sdxl', type: 'lora' }, + weight: 0.6, + }, + ] as ParamsState['hrfLoras']; + + const unsupportedState = paramsSliceConfig.slice.reducer( + state, + modelChanged({ model: fluxModel, previousModel: sdxlModel }) + ); + const sd1State = paramsSliceConfig.slice.reducer( + unsupportedState, + modelChanged({ model: sd1Model, previousModel: fluxModel }) + ); + + expect(unsupportedState.hrfModel?.key).toBe('hrf-sdxl'); + expect(unsupportedState.hrfTileControlNetModel?.key).toBe('tile-sdxl'); + expect(unsupportedState.hrfLoras).toHaveLength(1); + expect(sd1State.hrfEnabled).toBe(true); + expect(sd1State.hrfModel).toBeNull(); + expect(sd1State.hrfTileControlNetModel).toBeNull(); + expect(sd1State.hrfLoraMode).toBe('dedicated'); + expect(sd1State.hrfLoras).toEqual([]); + }); }); describe('paramsSlice HRF migrations', () => { diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index ab812ca641c..a7823f4d829 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -247,24 +247,21 @@ const slice = createSlice({ const model = result.data; state.model = model; - if (model?.base === 'external') { - state.hrfEnabled = false; - state.hrfModel = null; - } - // If the model base changes (e.g. SD1.5 -> SDXL), we need to change a few things if (model === null || previousModel?.base === model.base) { return; } - if (state.hrfModel?.base !== model.base) { - state.hrfModel = null; - } - if (state.hrfTileControlNetModel?.base !== model.base) { - state.hrfTileControlNetModel = null; + if (isHrfSupportedBase(model.base)) { + if (state.hrfModel?.base !== model.base) { + state.hrfModel = null; + } + if (state.hrfTileControlNetModel?.base !== model.base) { + state.hrfTileControlNetModel = null; + } + const effectiveHrfBase = state.hrfModel?.base ?? model.base; + state.hrfLoras = state.hrfLoras.filter((lora) => lora.model.base === effectiveHrfBase); } - const effectiveHrfBase = state.hrfModel?.base ?? model.base; - state.hrfLoras = state.hrfLoras.filter((lora) => lora.model.base === effectiveHrfBase); applyClipSkip(state, model, state.clipSkip); }, @@ -1047,20 +1044,20 @@ export const selectModelSupportsDimensions = createSelector(selectModel, selectM } return true; }); +export const isHrfSupportedBase = (base: BaseModelType | null | undefined): boolean => + base === 'sd-1' || base === 'sdxl'; + export const selectModelSupportsHrf = createSelector(selectModel, (model) => { if (!model) { return false; } - if (model.base === 'external') { - return false; - } - return true; + return isHrfSupportedBase(model.base); }); export const selectModelSupportsHrfUpscaleModel = createSelector(selectModel, (model) => { if (!model) { return false; } - return model.base === 'sd-1' || model.base === 'sdxl'; + return isHrfSupportedBase(model.base); }); export const selectSeedControl = createSelector(selectModelConfig, (modelConfig) => { if (modelConfig && isExternalApiModelConfig(modelConfig)) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts index 511d25fe41c..24acc9d9f1e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.test.ts @@ -311,32 +311,20 @@ describe('addHighResFix', () => { expect(g.getMetadataNode()).toMatchObject({ hrf_steps: 12 }); }); - it('reroutes transformer txt2img graphs through latent resize and a final-size second denoise pass', () => { + it('preserves unsupported model-family graphs and writes disabled metadata', () => { const { g, seed, denoise, l2i } = buildTransformerGraph(); addHighResFix({ g, state: buildState({ base: 'sd-3' }), generationMode: 'txt2img', denoise, l2i, seed }); const graph = g.getGraph(); const nodes = Object.values(graph.nodes); - const resize = nodes.find((node) => node.type === 'lresize'); - const hrfDenoise = nodes.find((node) => node.id.startsWith('hrf_sd3_denoise')); - - if (!hrfDenoise) { - throw new Error('Expected HRF SD3 denoise node'); - } - - expect(resize).toMatchObject({ type: 'lresize', width: 1024, height: 1024, mode: 'bilinear' }); - expect(hrfDenoise).toMatchObject({ type: 'sd3_denoise', width: 1024, height: 1024, denoising_end: 1 }); - expect((hrfDenoise as { denoising_start: number }).denoising_start).toBeCloseTo(1 - 0.35 ** 0.2); + expect(nodes.some((node) => node.type === 'lresize')).toBe(false); expect(nodes.some((node) => node.id.startsWith('hrf_noise'))).toBe(false); expect(graph.edges).toContainEqual({ - source: { node_id: 'seed', field: 'value' }, - destination: { node_id: hrfDenoise.id, field: 'seed' }, - }); - expect(graph.edges).not.toContainEqual({ source: { node_id: 'sd3_denoise', field: 'latents' }, destination: { node_id: 'sd3_l2i', field: 'latents' }, }); + expect(g.getMetadataNode()).toMatchObject({ hrf_enabled: false }); }); it('clones SDXL conditioning with final HRF dimensions for the second pass', () => { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts index 1b9d8acb0db..f7dbb860a47 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addHighResFix.ts @@ -1,7 +1,7 @@ import type { RootState } from 'app/store/store'; import { roundDownToMultiple } from 'common/util/roundDownToMultiple'; import { getPrefixedId } from 'features/controlLayers/konva/util'; -import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; +import { isHrfSupportedBase, selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import type { GenerationMode, LoRA } from 'features/controlLayers/store/types'; import type { BaseModelType } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common'; @@ -91,7 +91,7 @@ const shouldApplyHighResFix = (state: RootState, generationMode: GenerationMode) generationMode === 'txt2img' && params.hrfEnabled && model !== null && - model.base !== 'external' && + isHrfSupportedBase(model.base) && !params.refinerModel ); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.test.ts index 0afe04ed13b..bc05a75802f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildAnimaGraph.test.ts @@ -34,6 +34,7 @@ const defaultParams: { let params = { ...defaultParams }; vi.mock('features/controlLayers/store/paramsSlice', () => ({ + isHrfSupportedBase: vi.fn((base) => base === 'sd-1' || base === 'sdxl'), selectMainModelConfig: vi.fn(() => model), selectParamsSlice: vi.fn(() => params), selectAnimaVaeModel: vi.fn(() => animaVaeModel), diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts index 5b9f3d0a468..f2f1c65d082 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts @@ -112,6 +112,7 @@ const mockParams = { }; vi.mock('features/controlLayers/store/paramsSlice', () => ({ + isHrfSupportedBase: vi.fn((base) => base === 'sd-1' || base === 'sdxl'), selectMainModelConfig: vi.fn(() => currentModel), selectParamsSlice: vi.fn(() => mockParams), selectKleinVaeModel: vi.fn(() => currentKleinVae), diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildQwenImageGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildQwenImageGraph.test.ts index 3a5c2cde344..c49b518542a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildQwenImageGraph.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildQwenImageGraph.test.ts @@ -58,6 +58,7 @@ const refImagesSlice = { }; vi.mock('features/controlLayers/store/paramsSlice', () => ({ + isHrfSupportedBase: vi.fn((base) => base === 'sd-1' || base === 'sdxl'), selectMainModelConfig: vi.fn(() => model), selectParamsSlice: vi.fn(() => params), })); diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts index 575f3a9612a..41ba098ac17 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts @@ -293,31 +293,35 @@ describe('FLUX.2 Klein readiness checks – generate tab', () => { }); describe('High Resolution Fix readiness checks - generate tab', () => { - it('errors when HRF is enabled for external models', () => { + it('ignores stale HRF state for external models', () => { const reasons = getReasonsWhyCannotEnqueueGenerateTab( buildGenerateTabArg({ model: externalModel, hrfEnabled: true }) ); - expect(hasHrfExternalReason(reasons)).toBe(true); + expect(hasHrfExternalReason(reasons)).toBe(false); }); it('errors when HRF is enabled with SDXL Refiner', () => { const reasons = getReasonsWhyCannotEnqueueGenerateTab( - buildGenerateTabArg({ hrfEnabled: true, refinerModel: { key: 'refiner' } }) + buildGenerateTabArg({ model: sdxlModel, hrfEnabled: true, refinerModel: { key: 'refiner' } }) ); expect(hasHrfRefinerReason(reasons)).toBe(true); }); - it('errors when upscale-model HRF is enabled for unsupported model bases', () => { + it('ignores stale HRF state for unsupported model bases', () => { const reasons = getReasonsWhyCannotEnqueueGenerateTab( buildGenerateTabArg({ model: flux2DiffusersModel, hrfEnabled: true, hrfMethod: 'upscale_model', - hrfUpscaleModel: upscaleModel, - hrfTileControlNetModel: tileControlNetModel, + hrfUpscaleModel: null, + hrfTileControlNetModel: null, + hrfModel: { key: 'anima', hash: 'h', name: 'Anima', base: 'anima', type: 'main' }, }) ); - expect(hasHrfUpscaleModelBaseReason(reasons)).toBe(true); + expect(hasHrfUpscaleModelBaseReason(reasons)).toBe(false); + expect(hasHrfUpscaleModelMissingReason(reasons)).toBe(false); + expect(hasHrfTileControlNetMissingReason(reasons)).toBe(false); + expect(hasHrfModelOverrideBaseMismatchReason(reasons)).toBe(false); }); it('errors when upscale-model HRF is missing required models', () => { diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index e5c1452bae2..fbe214bb98f 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -8,7 +8,7 @@ import { useAssertSingleton } from 'common/hooks/useAssertSingleton'; import { debounce, groupBy, upperFirst } from 'es-toolkit/compat'; import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; import { selectAddedLoRAs } from 'features/controlLayers/store/lorasSlice'; -import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; +import { isHrfSupportedBase, selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; import type { CanvasState, LoRA, ParamsState, RefImagesState } from 'features/controlLayers/store/types'; @@ -358,23 +358,17 @@ export const getReasonsWhyCannotEnqueueGenerateTab = (arg: { }); } - if (params.hrfEnabled) { - if (params.model?.base === 'external' || (model && isExternalApiModelConfig(model))) { - reasons.push({ content: i18n.t('parameters.invoke.hrfExternalModelUnsupported') }); - } + if (params.hrfEnabled && model && !isExternalApiModelConfig(model) && isHrfSupportedBase(model.base)) { if (params.refinerModel) { reasons.push({ content: i18n.t('parameters.invoke.hrfRefinerUnsupported') }); } if (params.hrfMethod === 'upscale_model') { const hrfBase = params.hrfModel?.base ?? (model && !isExternalApiModelConfig(model) ? model.base : params.model?.base); - if (model && !isExternalApiModelConfig(model) && !['sd-1', 'sdxl'].includes(model.base)) { - reasons.push({ content: i18n.t('parameters.invoke.hrfUpscaleModelBaseUnsupported') }); - } if (params.hrfModel) { if (params.hrfModel.base === 'external') { reasons.push({ content: i18n.t('parameters.invoke.hrfModelOverrideExternalUnsupported') }); - } else if (!['sd-1', 'sdxl'].includes(params.hrfModel.base)) { + } else if (!isHrfSupportedBase(params.hrfModel.base)) { reasons.push({ content: i18n.t('parameters.invoke.hrfModelOverrideBaseUnsupported') }); } else if (model && !isExternalApiModelConfig(model) && params.hrfModel.base !== model.base) { reasons.push({ content: i18n.t('parameters.invoke.hrfModelOverrideBaseMismatch') }); diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx index 1aa01d4c170..d7730c53018 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx @@ -177,6 +177,10 @@ const formControlProps = { w: 'full', }; +type DisabledProps = { + isDisabled?: boolean; +}; + const selectHrfTileControlNetModelConfig = createSelector( selectModelConfigsQuery, selectHrfTileControlNetModel, @@ -247,7 +251,7 @@ const ParamHrfEnabled = memo(() => { ParamHrfEnabled.displayName = 'ParamHrfEnabled'; -const ParamHrfMethod = memo(() => { +const ParamHrfMethod = memo(({ isDisabled = false }: DisabledProps) => { const dispatch = useAppDispatch(); const method = useAppSelector(selectHrfMethod); const { t } = useTranslation(); @@ -266,7 +270,13 @@ const ParamHrfMethod = memo(() => { {t('hrf.upscaleMethod')} - @@ -284,7 +295,7 @@ const ParamHrfMethod = memo(() => { ParamHrfMethod.displayName = 'ParamHrfMethod'; -const ParamHrfScale = memo(() => { +const ParamHrfScale = memo(({ isDisabled = false }: DisabledProps) => { const dispatch = useAppDispatch(); const scale = useAppSelector(selectHrfScale); const { t } = useTranslation(); @@ -310,6 +321,7 @@ const ParamHrfScale = memo(() => { fineStep={SCALE_CONSTRAINTS.fineStep} onChange={onChange} marks={[SCALE_CONSTRAINTS.sliderMin, SCALE_CONSTRAINTS.initial, SCALE_CONSTRAINTS.sliderMax]} + isDisabled={isDisabled} /> { step={SCALE_CONSTRAINTS.coarseStep} fineStep={SCALE_CONSTRAINTS.fineStep} onChange={onChange} + isDisabled={isDisabled} /> ); @@ -326,7 +339,7 @@ const ParamHrfScale = memo(() => { ParamHrfScale.displayName = 'ParamHrfScale'; -const ParamHrfStrength = memo(() => { +const ParamHrfStrength = memo(({ isDisabled = false }: DisabledProps) => { const dispatch = useAppDispatch(); const strength = useAppSelector(selectHrfStrength); const { t } = useTranslation(); @@ -352,6 +365,7 @@ const ParamHrfStrength = memo(() => { fineStep={STRENGTH_CONSTRAINTS.fineStep} onChange={onChange} marks={[STRENGTH_CONSTRAINTS.sliderMin, STRENGTH_CONSTRAINTS.initial, STRENGTH_CONSTRAINTS.sliderMax]} + isDisabled={isDisabled} /> { step={STRENGTH_CONSTRAINTS.coarseStep} fineStep={STRENGTH_CONSTRAINTS.fineStep} onChange={onChange} + isDisabled={isDisabled} /> ); @@ -368,7 +383,7 @@ const ParamHrfStrength = memo(() => { ParamHrfStrength.displayName = 'ParamHrfStrength'; -const ParamHrfLatentInterpolationMode = memo(() => { +const ParamHrfLatentInterpolationMode = memo(({ isDisabled = false }: DisabledProps) => { const dispatch = useAppDispatch(); const mode = useAppSelector(selectHrfLatentInterpolationMode); const { t } = useTranslation(); @@ -402,14 +417,14 @@ const ParamHrfLatentInterpolationMode = memo(() => { {t('hrf.latentInterpolationMode')} - + ); }); ParamHrfLatentInterpolationMode.displayName = 'ParamHrfLatentInterpolationMode'; -const ParamHrfUpscaleModel = memo(() => { +const ParamHrfUpscaleModel = memo(({ isDisabled = false }: DisabledProps) => { const { t } = useTranslation(); const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels(); const model = useAppSelector(selectHrfUpscaleModel); @@ -437,7 +452,7 @@ const ParamHrfUpscaleModel = memo(() => { }); return ( - + {t('upscaling.upscaleModel')} @@ -450,7 +465,7 @@ const ParamHrfUpscaleModel = memo(() => { options={options} onChange={onChange} noOptionsMessage={noOptionsMessage} - isDisabled={options.length === 0} + isDisabled={isDisabled || options.length === 0} /> @@ -461,7 +476,7 @@ const ParamHrfUpscaleModel = memo(() => { ParamHrfUpscaleModel.displayName = 'ParamHrfUpscaleModel'; -const ParamHrfTileControlNetModel = memo(() => { +const ParamHrfTileControlNetModel = memo(({ isDisabled = false }: DisabledProps) => { const dispatch = useAppDispatch(); const { t } = useTranslation(); const tileControlNetModel = useAppSelector(selectHrfTileControlNetModelConfig); @@ -495,15 +510,11 @@ const ParamHrfTileControlNetModel = memo(() => { }, [currentBaseModel] ); + const isMissingModel = !filteredModelConfigs.length; + const isInvalid = !isDisabled && isMissingModel; return ( - + {t('upscaling.tileControl')} @@ -515,8 +526,8 @@ const ParamHrfTileControlNetModel = memo(() => { getIsOptionDisabled={getIsOptionDisabled} placeholder={t('common.placeholderSelectAModel')} noOptionsText={t('upscaling.missingTileControlNetModel')} - isDisabled={isLoading || !filteredModelConfigs.length} - isInvalid={!filteredModelConfigs.length} + isDisabled={isDisabled || isLoading || isMissingModel} + isInvalid={isInvalid} /> ); @@ -524,7 +535,7 @@ const ParamHrfTileControlNetModel = memo(() => { ParamHrfTileControlNetModel.displayName = 'ParamHrfTileControlNetModel'; -const ParamHrfTileControlWeight = memo(() => { +const ParamHrfTileControlWeight = memo(({ isDisabled = false }: DisabledProps) => { const dispatch = useAppDispatch(); const tileControlWeight = useAppSelector(selectHrfTileControlWeight); const { t } = useTranslation(); @@ -554,6 +565,7 @@ const ParamHrfTileControlWeight = memo(() => { TILE_CONTROL_WEIGHT_CONSTRAINTS.initial, TILE_CONTROL_WEIGHT_CONSTRAINTS.sliderMax, ]} + isDisabled={isDisabled} /> { step={TILE_CONTROL_WEIGHT_CONSTRAINTS.coarseStep} fineStep={TILE_CONTROL_WEIGHT_CONSTRAINTS.fineStep} onChange={onChange} + isDisabled={isDisabled} /> ); @@ -570,7 +583,7 @@ const ParamHrfTileControlWeight = memo(() => { ParamHrfTileControlWeight.displayName = 'ParamHrfTileControlWeight'; -const ParamHrfTileControlEnd = memo(() => { +const ParamHrfTileControlEnd = memo(({ isDisabled = false }: DisabledProps) => { const dispatch = useAppDispatch(); const tileControlEnd = useAppSelector(selectHrfTileControlEnd); const { t } = useTranslation(); @@ -600,6 +613,7 @@ const ParamHrfTileControlEnd = memo(() => { TILE_CONTROL_END_CONSTRAINTS.initial, TILE_CONTROL_END_CONSTRAINTS.sliderMax, ]} + isDisabled={isDisabled} /> { step={TILE_CONTROL_END_CONSTRAINTS.coarseStep} fineStep={TILE_CONTROL_END_CONSTRAINTS.fineStep} onChange={onChange} + isDisabled={isDisabled} /> ); @@ -616,7 +631,7 @@ const ParamHrfTileControlEnd = memo(() => { ParamHrfTileControlEnd.displayName = 'ParamHrfTileControlEnd'; -const ParamHrfTileSize = memo(() => { +const ParamHrfTileSize = memo(({ isDisabled = false }: DisabledProps) => { const dispatch = useAppDispatch(); const tileSize = useAppSelector(selectHrfTileSize); const { t } = useTranslation(); @@ -642,6 +657,7 @@ const ParamHrfTileSize = memo(() => { fineStep={TILE_SIZE_CONSTRAINTS.fineStep} onChange={onChange} marks={[TILE_SIZE_CONSTRAINTS.sliderMin, TILE_SIZE_CONSTRAINTS.initial, TILE_SIZE_CONSTRAINTS.sliderMax]} + isDisabled={isDisabled} /> { step={TILE_SIZE_CONSTRAINTS.coarseStep} fineStep={TILE_SIZE_CONSTRAINTS.fineStep} onChange={onChange} + isDisabled={isDisabled} /> ); @@ -658,7 +675,7 @@ const ParamHrfTileSize = memo(() => { ParamHrfTileSize.displayName = 'ParamHrfTileSize'; -const ParamHrfTileOverlap = memo(() => { +const ParamHrfTileOverlap = memo(({ isDisabled = false }: DisabledProps) => { const dispatch = useAppDispatch(); const tileOverlap = useAppSelector(selectHrfTileOverlap); const { t } = useTranslation(); @@ -688,6 +705,7 @@ const ParamHrfTileOverlap = memo(() => { TILE_OVERLAP_CONSTRAINTS.initial, TILE_OVERLAP_CONSTRAINTS.sliderMax, ]} + isDisabled={isDisabled} /> { step={TILE_OVERLAP_CONSTRAINTS.coarseStep} fineStep={TILE_OVERLAP_CONSTRAINTS.fineStep} onChange={onChange} + isDisabled={isDisabled} /> ); @@ -704,7 +723,7 @@ const ParamHrfTileOverlap = memo(() => { ParamHrfTileOverlap.displayName = 'ParamHrfTileOverlap'; -const ParamHrfSteps = memo(() => { +const ParamHrfSteps = memo(({ isDisabled = false }: DisabledProps) => { const dispatch = useAppDispatch(); const hrfSteps = useAppSelector(selectHrfSteps); const generateSteps = useAppSelector(selectSteps); @@ -732,7 +751,7 @@ const ParamHrfSteps = memo(() => { {t('hrf.steps')} - + { step={STEPS_CONSTRAINTS.coarseStep} fineStep={STEPS_CONSTRAINTS.fineStep} onChange={onChange} - isDisabled={!isCustom} + isDisabled={isDisabled || !isCustom} w={24} flexShrink={0} /> @@ -752,7 +771,7 @@ const ParamHrfSteps = memo(() => { ParamHrfSteps.displayName = 'ParamHrfSteps'; -const ParamHrfModel = memo(() => { +const ParamHrfModel = memo(({ isDisabled = false }: DisabledProps) => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const selectedModelConfig = useAppSelector(selectHrfModelConfig); @@ -776,7 +795,7 @@ const ParamHrfModel = memo(() => { ); return ( - + {t('hrf.model')} @@ -789,7 +808,7 @@ const ParamHrfModel = memo(() => { allowEmpty placeholder={t('hrf.reuseGenerateModel')} noOptionsText={currentBaseModel ? t('hrf.noCompatibleModels') : t('models.selectModel')} - isDisabled={isLoading || !modelConfigs.length} + isDisabled={isDisabled || isLoading || !modelConfigs.length} /> ); @@ -797,7 +816,7 @@ const ParamHrfModel = memo(() => { ParamHrfModel.displayName = 'ParamHrfModel'; -const ParamHrfLoraMode = memo(() => { +const ParamHrfLoraMode = memo(({ isDisabled = false }: DisabledProps) => { const dispatch = useAppDispatch(); const mode = useAppSelector(selectHrfLoraMode); const { t } = useTranslation(); @@ -825,10 +844,17 @@ const ParamHrfLoraMode = memo(() => { minW={0} colorScheme={mode === 'reuse_generate' ? 'invokeBlue' : undefined} onClick={onClickReuseGenerate} + isDisabled={isDisabled} > {t('hrf.reuseGenerateLoras')} - @@ -846,7 +873,7 @@ const ParamHrfLoraMode = memo(() => { ParamHrfLoraMode.displayName = 'ParamHrfLoraMode'; -const ParamHrfLoraSelect = memo(() => { +const ParamHrfLoraSelect = memo(({ isDisabled = false }: DisabledProps) => { const dispatch = useAppDispatch(); const [modelConfigs, { isLoading }] = useLoRAModels(); const { t } = useTranslation(); @@ -890,7 +917,7 @@ const ParamHrfLoraSelect = memo(() => { }, [compatibleLoRAs.length, hrfBase, isLoading, t]); return ( - + {t('hrf.dedicatedLoras')} @@ -904,6 +931,7 @@ const ParamHrfLoraSelect = memo(() => { placeholder={placeholder} getIsOptionDisabled={getIsDisabled} noOptionsText={hrfBase ? t('models.noCompatibleLoRAs') : t('models.selectModel')} + isDisabled={isDisabled} /> ); @@ -911,7 +939,7 @@ const ParamHrfLoraSelect = memo(() => { ParamHrfLoraSelect.displayName = 'ParamHrfLoraSelect'; -const HrfLoRAList = memo(() => { +const HrfLoRAList = memo(({ isDisabled = false }: DisabledProps) => { const ids = useAppSelector(selectHrfLoRAIds); if (!ids.length) { @@ -921,7 +949,7 @@ const HrfLoRAList = memo(() => { return ( {ids.map((id) => ( - + ))} ); @@ -929,22 +957,23 @@ const HrfLoRAList = memo(() => { HrfLoRAList.displayName = 'HrfLoRAList'; -const HrfLoRACard = memo((props: { id: string }) => { +const HrfLoRACard = memo((props: { id: string } & DisabledProps) => { const selectLoRA = useMemo(() => buildSelectHrfLoRA(props.id), [props.id]); const lora = useAppSelector(selectLoRA); if (!lora) { return null; } - return ; + return ; }); HrfLoRACard.displayName = 'HrfLoRACard'; -const HrfLoRAContent = memo(({ lora }: { lora: LoRA }) => { +const HrfLoRAContent = memo(({ lora, isDisabled = false }: { lora: LoRA } & DisabledProps) => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const { data: loraConfig } = useGetModelConfigQuery(lora.model.key); + const isWeightDisabled = isDisabled || !lora.isEnabled; const onChange = useCallback( (v: number) => { @@ -965,17 +994,18 @@ const HrfLoRAContent = memo(({ lora }: { lora: LoRA }) => { - + {loraConfig?.name ?? lora.model.key.substring(0, 8)} - + } + isDisabled={isDisabled} /> @@ -991,7 +1021,7 @@ const HrfLoRAContent = memo(({ lora }: { lora: LoRA }) => { fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep} marks={[-1, 0, 1, 2]} defaultValue={DEFAULT_LORA_WEIGHT_CONFIG.initial} - isDisabled={!lora.isEnabled} + isDisabled={isWeightDisabled} /> { w={20} flexShrink={0} defaultValue={DEFAULT_LORA_WEIGHT_CONFIG.initial} - isDisabled={!lora.isEnabled} + isDisabled={isWeightDisabled} /> @@ -1016,6 +1046,7 @@ HrfLoRAContent.displayName = 'HrfLoRAContent'; export const HighResFixSettingsAccordion = memo(() => { const { t } = useTranslation(); const badges = useAppSelector(selectBadges); + const hrfEnabled = useAppSelector(selectHrfEnabled); const method = useAppSelector(selectHrfMethod); const modelSupportsHrf = useAppSelector(selectModelSupportsHrf); const isRefinerModelSelected = useAppSelector(selectIsRefinerModelSelected); @@ -1030,6 +1061,7 @@ export const HighResFixSettingsAccordion = memo(() => { }); const parsedMethod = zHrfMethod.parse(method); + const isDisabled = !hrfEnabled; if (!modelSupportsHrf || isRefinerModelSelected) { return null; @@ -1040,37 +1072,39 @@ export const HighResFixSettingsAccordion = memo(() => { - - - + + + + + {parsedMethod === 'upscale_model' && ( <> - - + + )} - - {parsedMethod === 'latent' && } + + {parsedMethod === 'latent' && } {parsedMethod === 'upscale_model' && ( <> - - - - - - - + + + + + + + )} {parsedMethod === 'upscale_model' && hrfLoraMode === 'dedicated' && ( <> - - + + )} From b72ff08619347ce69c696b12f2de77113a096a20 Mon Sep 17 00:00:00 2001 From: Astra orion <13394741+AsuraAce@users.noreply.github.com> Date: Fri, 8 May 2026 09:12:11 +0200 Subject: [PATCH 7/9] docs(ui): clarify Generate Upscale tooltips --- invokeai/frontend/web/public/locales/en.json | 112 ++++++++++++++++++ .../InformationalPopover/constants.ts | 16 +++ .../HighResFixSettingsAccordion.tsx | 32 ++--- 3 files changed, 144 insertions(+), 16 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index f55faef6c6a..5b61c585cdd 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2237,6 +2237,118 @@ "Generate high quality images at a larger resolution than optimal for the model. Generally used to prevent duplication in the generated image." ] }, + "hrfEnabled": { + "heading": "Enable Upscale", + "paragraphs": [ + "Adds an optional second pass that upscales the base image and refines details at the larger size.", + "When disabled, the Upscale settings are saved but are not applied to new generations." + ] + }, + "hrfMethod": { + "heading": "Upscale Method", + "paragraphs": [ + "Latent resizes the image in latent space before refinement. It is faster and does not require an upscale model.", + "Upscale Model enlarges the image with a selected upscale model before refinement, which can preserve more texture and detail." + ] + }, + "hrfScale": { + "heading": "Scale", + "paragraphs": [ + "Controls final output size as a multiplier of the base Width and Height.", + "Higher scale makes a larger image and uses more time and VRAM. Lower scale stays closer to the base size." + ] + }, + "hrfDenoisingStrength": { + "heading": "Denoise Strength", + "paragraphs": [ + "Controls how much the refinement pass can change the upscaled image.", + "Lower strength preserves more of the upscaled image. Higher strength allows more new detail and more prompt influence." + ] + }, + "hrfLatentInterpolation": { + "heading": "Latent Interpolation", + "paragraphs": [ + "Controls how Latent Upscale resizes the latent image before refinement.", + "Smoother modes such as Bicubic and Bilinear can soften transitions. Nearest modes preserve harder edges but can look blockier." + ] + }, + "hrfUpscaleModel": { + "heading": "Upscale Model", + "paragraphs": [ + "Selects the image-to-image upscale model used to enlarge the base image before refinement.", + "Different upscale models may work better for different image types, such as photos, illustration, or line art." + ] + }, + "hrfTileControl": { + "heading": "Tile Control", + "paragraphs": [ + "Selects a compatible Tile or Union ControlNet model used during Upscale Model refinement.", + "Only same-base Tile-style ControlNet models are shown. Tile guidance helps preserve structure and reduce drift after the image is enlarged." + ] + }, + "hrfTileControlWeight": { + "heading": "Tile Control Weight", + "paragraphs": [ + "Controls how strongly Tile Control guides the refinement pass.", + "Higher weight preserves structure more strongly. Lower weight gives the model and prompt more freedom to change details." + ] + }, + "hrfTileControlEnd": { + "heading": "Tile Control Duration", + "paragraphs": [ + "Controls how long Tile Control remains active during the refinement pass.", + "Higher values keep tile guidance active for more of the pass. Lower values use tile guidance early, then let refinement finish more freely." + ] + }, + "hrfTileSize": { + "heading": "Tile Size", + "paragraphs": [ + "Controls the tile size used during Upscale Model refinement.", + "Larger tiles can improve consistency but use more VRAM. Smaller tiles use less VRAM and may be safer on memory-limited systems." + ] + }, + "hrfTileOverlap": { + "heading": "Tile Overlap", + "paragraphs": [ + "Controls how much adjacent tiles overlap during Upscale Model refinement.", + "Higher overlap can reduce visible seams but uses more VRAM and time. Lower overlap is faster but may show more tile boundaries." + ] + }, + "hrfModel": { + "heading": "Refinement Model", + "paragraphs": [ + "Selects an optional model used only for the Upscale refinement pass.", + "Reuse Base Model keeps the Generate model for refinement. A dedicated refinement model must use the same base family." + ] + }, + "hrfLoraMode": { + "heading": "Refinement LoRAs", + "paragraphs": [ + "Controls which LoRAs are applied during the Upscale refinement pass.", + "Reuse applies the Generate LoRAs again. None disables LoRAs for refinement. Dedicated lets you select LoRAs used only for refinement." + ] + }, + "hrfLoraSelect": { + "heading": "Dedicated Refinement LoRAs", + "paragraphs": [ + "Selects LoRAs that apply only during the Upscale refinement pass.", + "Dedicated refinement LoRAs can add detail or style after upscaling without affecting the base image generation." + ] + }, + "hrfLoraWeight": { + "heading": "Refinement LoRA Weight", + "paragraphs": [ + "Controls how strongly this dedicated LoRA affects the Upscale refinement pass.", + "Higher weight increases the LoRA's impact on refined details. Lower weight makes the effect more subtle." + ] + }, + "hrfSteps": { + "heading": "Refinement Steps", + "paragraphs": [ + "Controls the number of denoising steps used for the Upscale refinement pass.", + "When custom refinement steps are off, Upscale reuses the Generate step count. Higher values can add detail but take longer." + ] + }, "paramIterations": { "heading": "Iterations", "paragraphs": [ diff --git a/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts b/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts index 95fa75cfa32..93b581aaa00 100644 --- a/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts +++ b/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts @@ -6,7 +6,23 @@ export type Feature = | 'fluxDypePreset' | 'fluxDypeScale' | 'fluxDypeExponent' + | 'hrfDenoisingStrength' + | 'hrfEnabled' | 'hrf' + | 'hrfLatentInterpolation' + | 'hrfLoraMode' + | 'hrfLoraSelect' + | 'hrfLoraWeight' + | 'hrfMethod' + | 'hrfModel' + | 'hrfScale' + | 'hrfSteps' + | 'hrfTileControl' + | 'hrfTileControlEnd' + | 'hrfTileControlWeight' + | 'hrfTileOverlap' + | 'hrfTileSize' + | 'hrfUpscaleModel' | 'paramNegativeConditioning' | 'paramPositiveConditioning' | 'paramScheduler' diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx index d7730c53018..1a2a9fb3973 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/HighResFixSettingsAccordion/HighResFixSettingsAccordion.tsx @@ -241,7 +241,7 @@ const ParamHrfEnabled = memo(() => { return ( - + {t('hrf.enableHrf')} @@ -266,7 +266,7 @@ const ParamHrfMethod = memo(({ isDisabled = false }: DisabledProps) => { return ( - + {t('hrf.upscaleMethod')} @@ -309,7 +309,7 @@ const ParamHrfScale = memo(({ isDisabled = false }: DisabledProps) => { return ( - + {t('hrf.scale')} { return ( - + {t('hrf.strength')} - + {t('hrf.latentInterpolationMode')} @@ -453,7 +453,7 @@ const ParamHrfUpscaleModel = memo(({ isDisabled = false }: DisabledProps) => { return ( - + {t('upscaling.upscaleModel')} @@ -515,7 +515,7 @@ const ParamHrfTileControlNetModel = memo(({ isDisabled = false }: DisabledProps) return ( - + {t('upscaling.tileControl')} - + {t('hrf.tileControlWeight')} { return ( - + {t('hrf.tileControlEnd')} { return ( - + {t('upscaling.tileSize')} { return ( - + {t('upscaling.tileOverlap')} { return ( - + {t('hrf.steps')} @@ -796,7 +796,7 @@ const ParamHrfModel = memo(({ isDisabled = false }: DisabledProps) => { return ( - + {t('hrf.model')} { return ( - + {t('hrf.loraMode')} @@ -918,7 +918,7 @@ const ParamHrfLoraSelect = memo(({ isDisabled = false }: DisabledProps) => { return ( - + {t('hrf.dedicatedLoras')} - + Date: Fri, 8 May 2026 09:18:39 +0200 Subject: [PATCH 8/9] docs(ui): refine Upscale LoRA tooltip wording --- invokeai/frontend/web/public/locales/en.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 5b61c585cdd..baa51c36f17 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2325,7 +2325,7 @@ "heading": "Refinement LoRAs", "paragraphs": [ "Controls which LoRAs are applied during the Upscale refinement pass.", - "Reuse applies the Generate LoRAs again. None disables LoRAs for refinement. Dedicated lets you select LoRAs used only for refinement." + "Reuse applies the base generation LoRAs again. None disables LoRAs for refinement. Dedicated lets you select LoRAs used only for refinement." ] }, "hrfLoraSelect": { From 189696133b7d6a3426e2f91545f6c96dafe83024 Mon Sep 17 00:00:00 2001 From: Astra orion <13394741+AsuraAce@users.noreply.github.com> Date: Fri, 8 May 2026 10:08:03 +0200 Subject: [PATCH 9/9] docs(ui): refine Upscale refinement copy --- invokeai/frontend/web/public/locales/en.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 911bc8f11fd..da515108716 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -365,8 +365,8 @@ "loraMode": "Refinement LoRAs", "reuseGenerateLoras": "Reuse", "noLoras": "None", - "dedicatedLoras": "Dedicated", - "addDedicatedLora": "Add Dedicated LoRA", + "dedicatedLoras": "LoRAs", + "addDedicatedLora": "Add LoRA", "noCompatibleModels": "No Compatible Models", "latent": "Latent", "upscaleModelMethod": "Upscale Model", @@ -2325,7 +2325,7 @@ "heading": "Refinement Model", "paragraphs": [ "Selects an optional model used only for the Upscale refinement pass.", - "Reuse Base Model keeps the Generate model for refinement. A dedicated refinement model must use the same base family." + "Reuse Base Model keeps the base generation model for refinement. Choosing another model lets you use a compatible refinement model for the Upscale pass." ] }, "hrfLoraMode": {