Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ai_diffusion/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def find(self, id: ResourceId):
if result := self.resources.get(id.string):
return result
# Fallback to epsilon model if v-prediction model not found
if id.arch is Arch.illu_v:
if id.arch in (Arch.illu_v, Arch.illu_rf):
if result := self.resources.get(id._replace(arch=Arch.illu).string):
return result
# Search for architecture-agnostic model
Expand Down
6 changes: 4 additions & 2 deletions ai_diffusion/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _update_is_supported(self):
if client := root.connection.client_if_connected:
models = client.models.for_arch(self._model.arch)

if self.mode.is_ip_adapter and models.arch in [Arch.illu, Arch.illu_v]:
if self.mode.is_ip_adapter and models.arch in [Arch.illu, Arch.illu_v, Arch.illu_rf]:
resid = resource_id(ResourceKind.clip_vision, Arch.illu, "ip_adapter")
has_clip_vision = client.models.resources.get(resid, None) is not None
if not has_clip_vision:
Expand Down Expand Up @@ -177,7 +177,9 @@ def _update_is_supported(self):
model = models.find_control(self.mode)
self.has_range = model == models.control.find(self.mode, True)
if model is None:
search_arch = Arch.illu if models.arch is Arch.illu_v else models.arch
search_arch = (
Arch.illu if models.arch in (Arch.illu_v, Arch.illu_rf) else models.arch
)
search_path = (
resources.search_path(ResourceKind.controlnet, search_arch, self.mode)
or resources.search_path(ResourceKind.model_patch, search_arch, self.mode)
Expand Down
18 changes: 13 additions & 5 deletions ai_diffusion/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class Arch(Enum):
flux2_9b = "Flux 2 Klein 9B"
illu = "Illustrious"
illu_v = "Illustrious v-prediction"
illu_rf = "Illustrious Rectified Flow"
chroma = "Chroma"
qwen = "Qwen"
qwen_e = "Qwen Edit"
Expand All @@ -103,6 +104,8 @@ class Arch(Enum):
@staticmethod
def from_string(string: str, model_type: str = "eps", filename: str | None = None):
filename = filename.lower() if filename else ""
if filename == "chenkinnoobxlv02_v02.safetensors":
return Arch.illu_rf
if string == "sd15":
return Arch.sd15
if string == "sdxl" and model_type == "v-prediction":
Expand All @@ -123,6 +126,8 @@ def from_string(string: str, model_type: str = "eps", filename: str | None = Non
return Arch.illu
if string == "illu_v":
return Arch.illu_v
if string == "illu_rf":
return Arch.illu_rf
if string == "chroma":
return Arch.chroma
if string == "qwen-image" and "edit" in filename:
Expand Down Expand Up @@ -175,19 +180,19 @@ def has_controlnet_inpaint(self):

@property
def supports_regions(self):
return self in [Arch.sd15, Arch.sdxl, Arch.illu, Arch.illu_v]
return self in [Arch.sd15, Arch.sdxl, Arch.illu, Arch.illu_v, Arch.illu_rf]

@property
def supports_lcm(self):
return self in [Arch.sd15, Arch.sdxl]

@property
def supports_clip_skip(self):
return self in [Arch.sd15, Arch.sdxl, Arch.illu, Arch.illu_v]
return self in [Arch.sd15, Arch.sdxl, Arch.illu, Arch.illu_v, Arch.illu_rf]

@property
def supports_attention_guidance(self):
return self in [Arch.sd15, Arch.sdxl, Arch.illu, Arch.illu_v]
return self in [Arch.sd15, Arch.sdxl, Arch.illu, Arch.illu_v, Arch.illu_rf]

@property
def supports_cfg(self):
Expand All @@ -204,7 +209,7 @@ def supports_edit(self): # includes text-to-image models that can also edit
@property
def is_sdxl_like(self):
# illustrious technically uses sdxl architecture, but has a separate ecosystem
return self in [Arch.sdxl, Arch.illu, Arch.illu_v]
return self in [Arch.sdxl, Arch.illu, Arch.illu_v, Arch.illu_rf]

@property
def is_flux_like(self):
Expand All @@ -223,7 +228,7 @@ def text_encoders(self):
match self:
case Arch.sd15:
return ["clip_l"]
case Arch.sdxl | Arch.illu | Arch.illu_v:
case Arch.sdxl | Arch.illu | Arch.illu_v | Arch.illu_rf:
return ["clip_l", "clip_g"]
case Arch.sd3:
return ["clip_l", "clip_g"]
Expand Down Expand Up @@ -257,6 +262,7 @@ def list():
Arch.flux2_9b,
Arch.illu,
Arch.illu_v,
Arch.illu_rf,
Arch.chroma,
Arch.qwen,
Arch.qwen_e,
Expand Down Expand Up @@ -721,6 +727,7 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
resource_id(ResourceKind.controlnet, Arch.sdxl, ControlMode.universal): ["union-sdxl", "xinsirunion"],
resource_id(ResourceKind.controlnet, Arch.illu, ControlMode.universal): ["union-sdxl", "xinsirunion"],
resource_id(ResourceKind.controlnet, Arch.illu_v, ControlMode.universal): ["union-sdxl", "xinsirunion"],
resource_id(ResourceKind.controlnet, Arch.illu_rf, ControlMode.universal): ["union-sdxl", "xinsirunion"],
resource_id(ResourceKind.controlnet, Arch.flux, ControlMode.universal): ["flux.1-dev-controlnet-union-pro-2.0", "flux.1-dev-controlnet-union-pro", "flux.1-dev-controlnet-union", "flux1devcontrolnetunion"],
resource_id(ResourceKind.controlnet, Arch.qwen, ControlMode.universal): ["qwen-image-instantx-controlnet-union"],
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.scribble): ["control_v11p_sd15_scribble", "control_lora_rank128_v11p_sd15_scribble"],
Expand Down Expand Up @@ -798,6 +805,7 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
resource_id(ResourceKind.vae, Arch.sdxl, "default"): ["sdxl_vae"],
resource_id(ResourceKind.vae, Arch.illu, "default"): ["sdxl_vae"],
resource_id(ResourceKind.vae, Arch.illu_v, "default"): ["sdxl_vae"],
resource_id(ResourceKind.vae, Arch.illu_rf, "default"): ["sdxl_vae"],
resource_id(ResourceKind.vae, Arch.sd3, "default"): ["sd3"],
resource_id(ResourceKind.vae, Arch.flux, "default"): ["flux-", "flux_", "flux/", "flux1", "ae.s"],
resource_id(ResourceKind.vae, Arch.flux_k, "default"): ["flux-", "flux_", "flux/", "flux1", "ae.s"],
Expand Down
2 changes: 1 addition & 1 deletion ai_diffusion/ui/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def __init__(self, server: Server, parent=None):
),
"illu": PackageGroupWidget(
_("Illustrious/NoobAI XL models"),
[m for m in optional_models if m.arch in [Arch.illu, Arch.illu_v]],
[m for m in optional_models if m.arch in [Arch.illu, Arch.illu_v, Arch.illu_rf]],
is_checkable=True,
is_expanded=False,
parent=self,
Expand Down
2 changes: 1 addition & 1 deletion ai_diffusion/ui/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ def model_name(id: ResourceId, with_file=False):
text += "</ul>"
text += _("Detected base models:") + "\n<ul>"
for arch, missing in res.missing.items():
if arch in [Arch.all, Arch.illu_v]:
if arch in [Arch.all, Arch.illu_v, Arch.illu_rf]:
continue
text += f"<li><b>{arch.value}</b>: "
if len(missing) == 0:
Expand Down
2 changes: 1 addition & 1 deletion ai_diffusion/ui/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ def _show_builtin_info(self, style: Style):
def _enable_checkpoint_advanced(self):
arch = resolve_arch(self.current_style, root.connection.client_if_connected)
if arch.is_sdxl_like:
valid_archs = (Arch.auto, Arch.sdxl, Arch.illu, Arch.illu_v)
valid_archs = (Arch.auto, Arch.sdxl, Arch.illu, Arch.illu_v, Arch.illu_rf)
elif arch.is_flux_like:
valid_archs = (Arch.auto, Arch.flux, Arch.flux_k)
elif arch.is_qwen_like:
Expand Down
2 changes: 1 addition & 1 deletion ai_diffusion/ui/theme.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def checkpoint_icon(arch: Arch, format: FileFormat | None = None, client: Client
return icon("sd-version-flux-2")
elif arch is Arch.illu:
return icon("sd-version-illu")
elif arch is Arch.illu_v:
elif arch in (Arch.illu_v, Arch.illu_rf):
return icon("sd-version-illu-v")
elif arch is Arch.chroma:
return icon("sd-version-chroma")
Expand Down
4 changes: 2 additions & 2 deletions ai_diffusion/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
match arch:
case Arch.sd15:
clip = w.load_clip(te["clip_l"], "stable_diffusion")
case Arch.sdxl | Arch.illu | Arch.illu_v:
case Arch.sdxl | Arch.illu | Arch.illu_v | Arch.illu_rf:
clip = w.load_dual_clip(te["clip_g"], te["clip_l"], type="sdxl")
case Arch.sd3:
if te.find("t5"):
Expand Down Expand Up @@ -188,7 +188,7 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
else:
model, clip = w.load_lora(model, clip, lora.name, lora.strength, lora.strength)

if arch is Arch.sd3:
if arch is Arch.sd3 or arch is Arch.illu_rf:
model = w.model_sampling_sd3(model)

if checkpoint.v_prediction_zsnr:
Expand Down