Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion ai_diffusion/backend/comfy_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,13 @@ def normalize(x):
@overload
def add_cached(self, class_type: str, output_count: Literal[1], **inputs) -> Output: ...

@overload
def add_cached(self, class_type: str, output_count: Literal[2], **inputs) -> Output2: ...

@overload
def add_cached(self, class_type: str, output_count: Literal[3], **inputs) -> Output3: ...

def add_cached(self, class_type: str, output_count: Literal[1, 3], **inputs):
def add_cached(self, class_type: str, output_count: Literal[1, 2, 3], **inputs):
key = class_type + str(inputs)
result = self._cache.get(key, None)
if result is None:
Expand Down Expand Up @@ -785,6 +788,30 @@ def apply_controlnet_inpainting(
)
)

def apply_anima_lllite(
    self,
    model: Output,
    lllite_name: str,
    image: Output,
    strength: float = 1.0,
    range: tuple[float, float] = (0.0, 1.0),
    mask: Output | None = None,
):
    """Apply a ControlNet-LLLite control layer to `model` using the ETN control nodes.

    Two nodes are involved: an "ETN_control_load" node (created through
    `add_cached`, so an identical load request can be served from the cache)
    and an "ETN_control_apply" node that combines its outputs with the
    control image.

    Args:
        model: Model output the control weights are loaded against.
        lllite_name: Value passed as the `weights` input of the load node.
        image: Control image output.
        strength: Value passed as the `strength` input of the apply node.
        range: `(start, end)` pair passed as `start_percent` / `end_percent`.
            The name shadows the builtin but is kept so existing callers
            that pass it by keyword keep working.
        mask: Optional mask output; the `mask` input is only set when given.

    Returns:
        The single output of the "ETN_control_apply" node.
    """
    # The load node has two outputs: the model to continue with and the
    # control-net handle consumed by the apply node.
    model, control_net = self.add_cached(
        "ETN_control_load", 2, model=model, weights=lllite_name
    )
    inputs = {
        "model": model,
        "control_net": control_net,
        "image": image,
        "strength": strength,
        "start_percent": range[0],
        "end_percent": range[1],
    }
    # Leave the input out entirely when there is no mask rather than passing None.
    if mask is not None:
        inputs["mask"] = mask
    return self.add("ETN_control_apply", 1, **inputs)

def set_controlnet_type(self, controlnet: Output, mode: ControlMode):
match mode:
case ControlMode.pose:
Expand Down
17 changes: 16 additions & 1 deletion ai_diffusion/backend/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def resolve(self, checkpoint: str):

@property
def has_controlnet_inpaint(self) -> bool:
    """True if this architecture is in the set treated as having an inpaint control-net.

    Anima is part of the set; a matching Anima inpaint control-net resource
    pattern is registered alongside the other architectures.
    """
    return self in (Arch.sd15, Arch.flux, Arch.zimage, Arch.qwen, Arch.anima)

@property
def supports_regions(self):
Expand Down Expand Up @@ -373,6 +373,14 @@ def text(self):

def can_substitute_universal(self, arch: Arch):
"""True if this control mode is covered by univeral control-net."""
if arch is Arch.anima:
return self in [
ControlMode.scribble,
Comment thread
Sen-sou marked this conversation as resolved.
ControlMode.line_art,
ControlMode.depth,
ControlMode.pose,
ControlMode.blur,
]
if arch.is_sdxl_like or arch is Arch.qwen:
return self in [
ControlMode.scribble,
Expand Down Expand Up @@ -727,17 +735,21 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
resource_id(ResourceKind.controlnet, Arch.flux, ControlMode.inpaint): ["flux.1-dev-controlnet-inpaint"],
resource_id(ResourceKind.controlnet, Arch.illu, ControlMode.inpaint): ["noobaiinpainting"],
resource_id(ResourceKind.controlnet, Arch.qwen, ControlMode.inpaint): ["qwen-image-instantx-controlnet-inpainting"],
resource_id(ResourceKind.controlnet, Arch.anima, ControlMode.inpaint): ["anima*lllite*inpaint", "lllite*anima*inpaint", "inpaint*anima*lllite", "inpaint*lllite*anima"],
resource_id(ResourceKind.controlnet, Arch.sdxl, ControlMode.universal): ["union-sdxl", "xinsirunion"],
resource_id(ResourceKind.controlnet, Arch.illu, ControlMode.universal): ["union-sdxl", "xinsirunion"],
resource_id(ResourceKind.controlnet, Arch.illu_v, ControlMode.universal): ["union-sdxl", "xinsirunion"],
resource_id(ResourceKind.controlnet, Arch.anima, ControlMode.universal): ["anima*lllite*any", "lllite*anima*any", "any*anima*lllite", "any*lllite*anima", "any*test*like", "anima*lllite*union", "lllite*anima*union", "union*anima*lllite", "union*lllite*anima", "anima*lllite*universal", "lllite*anima*universal", "universal*anima*lllite", "universal*lllite*anima"],
resource_id(ResourceKind.controlnet, Arch.flux, ControlMode.universal): ["flux.1-dev-controlnet-union-pro-2.0", "flux.1-dev-controlnet-union-pro", "flux.1-dev-controlnet-union", "flux1devcontrolnetunion"],
resource_id(ResourceKind.controlnet, Arch.qwen, ControlMode.universal): ["qwen-image-instantx-controlnet-union"],
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.scribble): ["control_v11p_sd15_scribble", "control_lora_rank128_v11p_sd15_scribble"],
resource_id(ResourceKind.controlnet, Arch.sdxl, ControlMode.scribble): ["xinsirscribble", "scribble-sdxl", "mistoline_fp16", "mistoline_rank", "control-lora-sketch-rank", "sai_xl_sketch_"],
resource_id(ResourceKind.controlnet, Arch.anima, ControlMode.scribble): ["anima*lllite*scribble", "lllite*anima*scribble", "scribble*anima*lllite", "scribble*lllite*anima"],
resource_id(ResourceKind.controlnet, Arch.illu, ControlMode.scribble): ["noob-sdxl-controlnet-scribble_pidinet", "noobaixlcontrolnet_epsscribble", "noob-sdxl-controlnet-scribble"],
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.line_art): ["control_v11p_sd15_lineart", "control_lora_rank128_v11p_sd15_lineart"],
resource_id(ResourceKind.controlnet, Arch.sdxl, ControlMode.line_art): ["xinsirscribble", "mistoline_fp16", "mistoline_rank", "scribble-sdxl", "control-lora-sketch-rank", "sai_xl_sketch_"],
resource_id(ResourceKind.controlnet, Arch.flux, ControlMode.line_art): ["mistoline_flux"],
resource_id(ResourceKind.controlnet, Arch.anima, ControlMode.line_art): ["anima*lllite*lineart", "anima*lllite*line_art", "lllite*anima*lineart", "lllite*anima*line_art", "lineart*anima*lllite", "line_art*anima*lllite", "lineart*lllite*anima", "line_art*lllite*anima"],
resource_id(ResourceKind.controlnet, Arch.illu, ControlMode.line_art): ["noob-sdxl-controlnet-lineart_anime", "noobaixlcontrolnet_epslineart", "noob-sdxl-controlnet-lineart"],
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.soft_edge): ["control_v11p_sd15_softedge", "control_lora_rank128_v11p_sd15_softedge"],
resource_id(ResourceKind.controlnet, Arch.sdxl, ControlMode.soft_edge): ["mistoline_fp16", "mistoline_rank", "xinsirscribble", "scribble-sdxl"],
Expand All @@ -750,17 +762,20 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.depth): ["control_sd15_depth_anything", "control_v11f1p_sd15_depth", "control_lora_rank128_v11f1p_sd15_depth"],
resource_id(ResourceKind.controlnet, Arch.sdxl, ControlMode.depth): ["xinsirdepth", "depth-sdxl", "control-lora-depth-rank", "sai_xl_depth_"],
resource_id(ResourceKind.controlnet, Arch.flux, ControlMode.depth): ["flux-depth"],
resource_id(ResourceKind.controlnet, Arch.anima, ControlMode.depth): ["anima*lllite*depth", "lllite*anima*depth", "depth*anima*lllite", "depth*lllite*anima"],
resource_id(ResourceKind.controlnet, Arch.illu, ControlMode.depth): ["noob-sdxl-controlnet-depth", "noobaixlcontrolnet_epsdepth"],
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.normal): ["control_v11p_sd15_normalbae", "control_lora_rank128_v11p_sd15_normalbae"],
resource_id(ResourceKind.controlnet, Arch.illu, ControlMode.normal): ["noob-sdxl-controlnet-normal", "noobaixlcontrolnet_epsnormal"],
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.pose): ["control_v11p_sd15_openpose", "control_lora_rank128_v11p_sd15_openpose"],
resource_id(ResourceKind.controlnet, Arch.sdxl, ControlMode.pose): ["xinsiropenpose", "openpose-sdxl", "control-lora-openposexl2-rank", "thibaud_xl_openpose"],
resource_id(ResourceKind.controlnet, Arch.anima, ControlMode.pose): ["anima*lllite*pose", "anima*lllite*openpose", "lllite*anima*pose", "lllite*anima*openpose", "pose*anima*lllite", "openpose*anima*lllite", "pose*lllite*anima", "openpose*lllite*anima"],
resource_id(ResourceKind.controlnet, Arch.illu, ControlMode.pose): ["noob-sdxl-controlnet-openpose", "noobaixlcontrolnet_openpose"],
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.segmentation): ["control_v11p_sd15_seg", "control_lora_rank128_v11p_sd15_seg"],
resource_id(ResourceKind.controlnet, Arch.sdxl, ControlMode.segmentation): ["sdxl_segmentation_ade20k_controlnet"],
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.blur): ["control_v11f1e_sd15_tile", "control_lora_rank128_v11f1e_sd15_tile"],
resource_id(ResourceKind.controlnet, Arch.sdxl, ControlMode.blur): ["xinsirtile", "tile-sdxl", "ttplanetsdxlcontrolnet", "ttplanet_sdxl_controlnet_tile_realistic", "ttplanet_controlnet_tile_realistic"],
resource_id(ResourceKind.controlnet, Arch.flux, ControlMode.blur): ["flux.1-dev-controlnet-upscale"],
resource_id(ResourceKind.controlnet, Arch.anima, ControlMode.blur): ["anima*lllite*tile", "anima*lllite*blur", "lllite*anima*tile", "lllite*anima*blur", "tile*anima*lllite", "blur*anima*lllite", "tile*lllite*anima", "blur*lllite*anima"],
resource_id(ResourceKind.controlnet, Arch.illu, ControlMode.blur): ["noob-sdxl-controlnet-tile", "noobaixlcontrolnet_epstile"],
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.stencil): ["control_v1p_sd15_qrcode_monster"],
resource_id(ResourceKind.controlnet, Arch.sdxl, ControlMode.stencil): ["sdxl_qrcode_monster"],
Expand Down
11 changes: 10 additions & 1 deletion ai_diffusion/backend/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,15 @@ def apply_control(
if control.mode.is_lines: # ControlNet expects white lines on black background
image = w.invert_image(image)

if models.arch is Arch.anima:
if cn_model := models.find(control.mode, allow_universal=True):
mask = control.mask.load(w) if control.mask is not None else None
model = w.apply_anima_lllite(
model, cn_model, image, control.strength, control.range, mask
)
continue
raise RuntimeError(f"Anima ControlNet-LLLite model not found for mode {control.mode}")

if cn_model := models.find(control.mode):
controlnet = w.load_controlnet(cn_model)
elif cn_model := models.find(ControlMode.universal):
Expand Down Expand Up @@ -972,7 +981,7 @@ def detect_inpaint(
)
elif arch.is_sdxl_like:
result.use_inpaint_model = strength > 0.8
elif arch in (Arch.flux, Arch.zimage):
elif arch in (Arch.flux, Arch.zimage, Arch.anima):
result.use_inpaint_model = strength == 1.0
elif arch.is_edit:
result.mode = InpaintMode.custom
Expand Down
37 changes: 36 additions & 1 deletion tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ai_diffusion.backend.cloud_client import CloudClient
from ai_diffusion.backend.comfy_client import ComfyClient
from ai_diffusion.backend.comfy_workflow import ComfyWorkflow
from ai_diffusion.backend.resources import ControlMode
from ai_diffusion.backend.resources import ControlMode, ResourceKind, resource_id
from ai_diffusion.backend.workflow import detect_inpaint
from ai_diffusion.files import File, FileCollection, FileLibrary, FileSource
from ai_diffusion.image import Bounds, Extent, Image, ImageCollection, Mask
Expand Down Expand Up @@ -223,6 +223,41 @@ def test_inpaint_params():
f = detect_inpaint(InpaintMode.fill, bounds, Arch.sd15, prompt, 1.0)
assert f.fill is FillMode.none

g = detect_inpaint(InpaintMode.fill, bounds, Arch.anima, no_cond, 1.0)
assert g.fill is FillMode.blur and g.use_inpaint_model


def test_anima_lllite_control_workflow():
    """Check that an Anima control layer is built from the ETN load/apply nodes.

    Only a universal Anima control-net file is registered, and the control
    layer uses blur mode. The conditioning must come back unchanged, the
    returned model must be the output of an "ETN_control_apply" node fed by
    the preceding "ETN_control_load" node, and no "ControlNetLoader" node may
    be created.
    """
    weights_file = "anima-lllite-anytest.safetensors"

    w = ComfyWorkflow()
    models = ClientModels()
    universal_id = resource_id(ResourceKind.controlnet, Arch.anima, ControlMode.universal)
    models.resources[universal_id] = weights_file

    cond = workflow.ConditioningOutput(workflow.Output(2, 0), workflow.Output(3, 0))
    control_image = workflow.ImageOutput(Image.create(Extent(16, 16)))
    control = workflow.Control(
        ControlMode.blur,
        control_image,
        strength=0.5,
        range=(0.1, 0.8),
    )

    model, result = workflow.apply_control(
        w,
        workflow.Output(1, 0),
        cond,
        [control],
        Extent(64, 64),
        workflow.Output(4, 0),
        models.for_arch(Arch.anima),
    )

    # Conditioning passes through untouched; only the model is patched.
    assert result == cond

    # The apply node is the returned model's node.
    apply_node = w.root[str(model.node)]
    assert apply_node["class_type"] == "ETN_control_apply"

    # The test expects the load node to have the id directly before it.
    load_id = str(model.node - 1)
    assert apply_node["inputs"]["control_net"] == [load_id, 1]
    assert w.root[load_id]["inputs"]["weights"] == weights_file

    class_types = [node["class_type"] for node in w.root.values()]
    assert "ETN_control_load" in class_types
    assert "ControlNetLoader" not in class_types


def test_prepare_lora():
models = ClientModels()
Expand Down
Loading