Skip to content

Commit fe340ef

Browse files
fukc-gihtub and fukc authored
Add initial Krea 2 model support (#2541)
* Add initial Krea 2 model support
* ruff format

---------

Co-authored-by: fukc <fukc@github.com>
1 parent 657e79c commit fe340ef

7 files changed

Lines changed: 48 additions & 1 deletion

File tree

ai_diffusion/backend/comfy_client.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -748,7 +748,17 @@ def find_model(model_list: Sequence[str], id: ResourceId):
748748

749749
def _find_text_encoder_models(model_list: Sequence[str]):
750750
kind = ResourceKind.text_encoder
751-
tes = ["clip_l", "clip_g", "t5", "qwen", "qwen_3_06b", "qwen_3_4b", "qwen_3_8b", "ministral"]
751+
tes = [
752+
"clip_l",
753+
"clip_g",
754+
"t5",
755+
"qwen",
756+
"qwen_3_06b",
757+
"qwen_3_4b",
758+
"qwen_3_8b",
759+
"qwen3vl_4b",
760+
"ministral",
761+
]
752762
return {
753763
resource_id(kind, Arch.all, te): _find_model(model_list, kind, Arch.all, te) for te in tes
754764
}

ai_diffusion/backend/resources.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,7 @@ class Arch(Enum):
9898
anima = "Anima"
9999
zimage = "Z-Image"
100100
ernie = "ERNIE Image"
101+
krea2 = "Krea 2"
101102

102103
auto = "Automatic"
103104
all = "All"
@@ -142,6 +143,8 @@ def from_string(string: str, model_type: str = "eps", filename: str | None = Non
142143
return Arch.zimage
143144
if string in {"ernie-image", "ernie_image"}:
144145
return Arch.ernie
146+
if string == "krea2":
147+
return Arch.krea2
145148
return None
146149

147150
@staticmethod
@@ -249,6 +252,8 @@ def text_encoders(self):
249252
return ["qwen_3_4b"]
250253
case Arch.ernie:
251254
return ["ministral"]
255+
case Arch.krea2:
256+
return ["qwen3vl_4b"]
252257
raise ValueError(f"Unsupported architecture: {self}")
253258

254259
@staticmethod
@@ -271,6 +276,7 @@ def list():
271276
Arch.anima,
272277
Arch.zimage,
273278
Arch.ernie,
279+
Arch.krea2,
274280
]
275281

276282

@@ -818,6 +824,7 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
818824
resource_id(ResourceKind.text_encoder, Arch.all, "qwen_3_8b"): ["qwen_3_8b", "qwen3-8b", "qwen3_8b"],
819825
resource_id(ResourceKind.text_encoder, Arch.all, "qwen_3_06b"): ["qwen_3_06b", "qwen3-06b", "qwen3_06b"],
820826
resource_id(ResourceKind.text_encoder, Arch.all, "ministral"): ["ministral-3-3b", "ministral"],
827+
resource_id(ResourceKind.text_encoder, Arch.all, "qwen3vl_4b"): ["qwen3vl_4b"],
821828
resource_id(ResourceKind.vae, Arch.sd15, "default"): ["vae-ft-mse-840000-ema"],
822829
resource_id(ResourceKind.vae, Arch.sdxl, "default"): ["sdxl_vae"],
823830
resource_id(ResourceKind.vae, Arch.illu, "default"): ["sdxl_vae"],
@@ -835,6 +842,7 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
835842
resource_id(ResourceKind.vae, Arch.anima, "default"): ["qwen_image"],
836843
resource_id(ResourceKind.vae, Arch.zimage, "default"): ["z-image", "flux-", "flux_", "flux/", "flux1", "ae.s"],
837844
resource_id(ResourceKind.vae, Arch.ernie, "default"): ["flux2"],
845+
resource_id(ResourceKind.vae, Arch.krea2, "default"): ["qwen_image"],
838846
}
839847
# fmt: on
840848

@@ -870,6 +878,8 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
870878
ResourceId(ResourceKind.vae, Arch.flux2_9b, "default"),
871879
ResourceId(ResourceKind.text_encoder, Arch.ernie, "ministral"),
872880
ResourceId(ResourceKind.vae, Arch.ernie, "default"),
881+
ResourceId(ResourceKind.text_encoder, Arch.krea2, "qwen3vl_4b"),
882+
ResourceId(ResourceKind.vae, Arch.krea2, "default"),
873883
}
874884

875885
recommended_resource_ids = [

ai_diffusion/backend/workflow.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -175,6 +175,8 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
175175
clip = w.load_clip(te["qwen_3_4b"], type="lumina2")
176176
case Arch.ernie:
177177
clip = w.load_clip(te["ministral"], type="flux2")
178+
case Arch.krea2:
179+
clip = w.load_clip(te["qwen3vl_4b"], type="krea2")
178180
case _:
179181
raise RuntimeError(f"No text encoder for model architecture {arch.name}")
180182

Lines changed: 11 additions & 0 deletions
Loading
Lines changed: 11 additions & 0 deletions
Loading

ai_diffusion/ui/theme.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,8 @@ def checkpoint_icon(arch: Arch, format: FileFormat | None = None, client: Client
8686
return icon("sd-version-anima")
8787
elif arch is Arch.ernie:
8888
return icon("sd-version-ernie")
89+
elif arch is Arch.krea2:
90+
return icon("sd-version-krea2")
8991
else:
9092
log.warning(f"Unresolved SD version {arch}, cannot fetch icon")
9193
return icon("warning")

tests/test_resources.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@ def test_resource_ids_exist():
5555
Arch.flux2_9b,
5656
Arch.anima,
5757
Arch.ernie,
58+
Arch.krea2,
5859
):
5960
continue # no model downloads yet
6061
model = res.find_resource(resource_id)

0 commit comments

Comments
 (0)