Skip to content

Commit 657e79c

Browse files
mtsmfmcodex
authored andcommitted
Preserve style architecture in custom workflows
Co-authored-by: Codex <codex@openai.com>
1 parent 0528508 commit 657e79c

2 files changed

Lines changed: 50 additions & 3 deletions

File tree

ai_diffusion/model/custom_workflow.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,14 @@
2222

2323
from .. import eventloop
2424
from ..backend.api import CustomStyleInput, InpaintContext, WorkflowInput
25-
from ..backend.client import ClientModels, ClientOutput, JobInfoOutput, OutputBatchMode, TextOutput
25+
from ..backend.client import (
26+
ClientModels,
27+
ClientOutput,
28+
JobInfoOutput,
29+
OutputBatchMode,
30+
TextOutput,
31+
resolve_arch,
32+
)
2633
from ..backend.comfy_workflow import ComfyNode, ComfyWorkflow
2734
from ..backend.workflow import sampling_from_style
2835
from ..image import Bounds, Image, Mask
@@ -591,8 +598,10 @@ def collect_parameters(
591598
use_live_sampling = True
592599
else: # auto
593600
use_live_sampling = is_live
601+
style_models = style.get_models(models.checkpoints)
602+
style_models.version = resolve_arch(style, models)
594603
params[md.name] = CustomStyleInput(
595-
style.get_models(models.checkpoints),
604+
style_models,
596605
sampling_from_style(style, 1.0, use_live_sampling),
597606
style.style_prompt,
598607
style.negative_prompt,

tests/test_custom_workflow.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ai_diffusion.backend.comfy_workflow import ComfyNode, ComfyObjectInfo, ComfyWorkflow, Output
2626
from ai_diffusion.backend.resources import Arch
2727
from ai_diffusion.image import Bounds, Extent, Image, ImageCollection, Mask
28+
from ai_diffusion.layer import LayerManager
2829
from ai_diffusion.model.connection import Connection, ConnectionState
2930
from ai_diffusion.model.custom_workflow import (
3031
CustomParam,
@@ -36,7 +37,7 @@
3637
workflow_parameters,
3738
)
3839
from ai_diffusion.model.jobs import Job, JobKind, JobParams, JobQueue
39-
from ai_diffusion.style import Style
40+
from ai_diffusion.style import Style, Styles
4041
from ai_diffusion.util import PluginError
4142

4243
from .config import test_dir
@@ -368,6 +369,43 @@ def test_parameters():
368369
]
369370

370371

372+
def test_collect_parameters_preserves_style_architecture():
373+
graph = {
374+
"1": {
375+
"class_type": "ETN_KritaStyle",
376+
"inputs": {"name": "style", "sampler_preset": "auto"},
377+
}
378+
}
379+
connection = create_mock_connection({"connection1": graph})
380+
workflows = WorkflowCollection(connection)
381+
workspace = CustomWorkspace(workflows, dummy_generate, JobQueue())
382+
383+
styles = Styles.list()
384+
style = styles.create("anima-test.json")
385+
try:
386+
style.architecture = Arch.anima
387+
style.checkpoints = ["checkpoint.safetensors"]
388+
workspace.params["style"] = style.filename
389+
390+
models = ClientModels()
391+
models.checkpoints = {
392+
"checkpoint.safetensors": CheckpointInfo("checkpoint.safetensors", Arch.anima)
393+
}
394+
395+
params = workspace.collect_parameters(
396+
layers=LayerManager(None),
397+
bounds=Bounds(0, 0, 1, 1),
398+
models=models,
399+
is_live=False,
400+
is_animation=False,
401+
)
402+
403+
assert isinstance(params["style"], CustomStyleInput)
404+
assert params["style"].models.version is Arch.anima
405+
finally:
406+
styles.delete(style)
407+
408+
371409
def test_parameter_order():
372410
params = [
373411
CustomParam(ParamKind.number_int, "Ant", 4, 0, 10),

0 commit comments

Comments
 (0)