LoRA silent loading/unloading; fix LoRA loading exception when no weights are present #42

Open · wants to merge 1 commit into main

flux_pipeline.py (5 additions, 3 deletions)
@@ -26,6 +26,7 @@
 config.cache_size_limit = 10000000000
 ind_config.shape_padding = True
+config.suppress_errors = True
 import platform

 from loguru import logger
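
One added line here is about error handling rather than progress bars: assuming `config` is the `torch._dynamo` config object imported at the top of flux_pipeline.py, setting `suppress_errors = True` makes Dynamo log compilation failures and fall back to eager execution instead of raising. A minimal sketch of the same setting, outside this file:

```python
import torch
import torch._dynamo as dynamo

# With suppress_errors enabled, a failure inside torch.compile is logged and the
# affected region falls back to eager mode rather than raising an exception.
dynamo.config.suppress_errors = True

@torch.compile
def double(x):
    return x * 2

print(double(torch.ones(2)))  # runs compiled if possible, eager otherwise
```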
@@ -153,6 +154,7 @@ def load_lora(
         lora_path: Union[str, OrderedDict[str, torch.Tensor]],
         scale: float,
         name: Optional[str] = None,
+        silent=False
     ):
         """
         Loads a LoRA checkpoint into the Flux flow transformer.
@@ -165,16 +167,16 @@ def load_lora(
             scale (float): Scaling factor for the LoRA weights.
             name (str): Name of the LoRA checkpoint, optionally can be left as None, since it only acts as an identifier.
         """
-        self.model.load_lora(path=lora_path, scale=scale, name=name)
+        self.model.load_lora(path=lora_path, scale=scale, name=name, silent=silent)

-    def unload_lora(self, path_or_identifier: str):
+    def unload_lora(self, path_or_identifier: str, silent=False):
         """
         Unloads the LoRA checkpoint from the Flux flow transformer.

         Args:
             path_or_identifier (str): Path to the LoRA checkpoint or the name given to the LoRA checkpoint when it was loaded.
         """
-        self.model.unload_lora(path_or_identifier=path_or_identifier)
+        self.model.unload_lora(path_or_identifier=path_or_identifier, silent=silent)

     @torch.inference_mode()
     def compile(self):
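
For context, a hedged usage sketch of the new flag from the pipeline side. The config path, LoRA filename, and the `load_pipeline_from_config_path` constructor are assumptions for illustration; the `silent` keyword is the one added in this PR:

```python
from flux_pipeline import FluxPipeline

# Assumed constructor and file paths -- adjust to your setup.
pipe = FluxPipeline.load_pipeline_from_config_path("configs/config-dev-fp8.json")

# Default behaviour: tqdm progress bars are shown while fusing the LoRA.
pipe.load_lora("loras/my_lora.safetensors", scale=1.0, name="my_lora")

# With silent=True the flag is forwarded to the model and the bars are disabled,
# both when re-fusing at a new scale and when unloading.
pipe.load_lora("loras/my_lora.safetensors", scale=0.6, name="my_lora", silent=True)
pipe.unload_lora("my_lora", silent=True)
```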

lora_loading.py (13 additions, 17 deletions)
@@ -181,9 +181,7 @@ def convert_diffusers_to_flux_transformer_checkpoint(
                 dtype = sample_component_A.dtype
                 device = sample_component_A.device
             else:
-                logger.info(
-                    f"Skipping layer {i} since no LoRA weight is available for {sample_component_A_key}"
-                )
+                logger.info(f"Skipping layer {i} since no LoRA weight is available for {sample_component_A_key}")
                 temp_dict[f"{component}"] = [None, None]

         if device is not None:
@@ -344,30 +342,26 @@ def convert_diffusers_to_flux_transformer_checkpoint(
         shape_qkv_a = None
         shape_qkv_b = None
         # Q, K, V, mlp
-        q_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_A.weight")
-        q_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_B.weight")
+        q_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_A.weight", None)
+        q_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_B.weight", None)
         if q_A is not None and q_B is not None:
             has_q = True
             shape_qkv_a = q_A.shape
             shape_qkv_b = q_B.shape
-        k_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_A.weight")
-        k_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_B.weight")
+        k_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_A.weight", None)
+        k_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_B.weight", None)
         if k_A is not None and k_B is not None:
             has_k = True
             shape_qkv_a = k_A.shape
             shape_qkv_b = k_B.shape
-        v_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_A.weight")
-        v_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_B.weight")
+        v_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_A.weight", None)
+        v_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_B.weight", None)
         if v_A is not None and v_B is not None:
             has_v = True
             shape_qkv_a = v_A.shape
             shape_qkv_b = v_B.shape
-        mlp_A = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}proj_mlp.lora_A.weight"
-        )
-        mlp_B = diffusers_state_dict.pop(
-            f"{prefix}{block_prefix}proj_mlp.lora_B.weight"
-        )
+        mlp_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}proj_mlp.lora_A.weight", None)
+        mlp_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}proj_mlp.lora_B.weight", None)
         if mlp_A is not None and mlp_B is not None:
             has_mlp = True
             shape_qkv_a = mlp_A.shape
@@ -637,6 +631,7 @@ def apply_lora_to_model(
     lora_path: str | StateDict,
     lora_scale: float = 1.0,
     return_lora_resolved: bool = False,
+    silent=False
 ) -> Flux:
     has_guidance = model.params.guidance_embed
     logger.info(f"Loading LoRA weights for {lora_path}")
@@ -675,7 +670,7 @@
             ]
         )
     )
-    for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
+    for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab), disable=silent):
         module = get_module_for_key(key, model)
         weight, is_f8, dtype = extract_weight_from_linear(module)
         lora_sd = get_lora_for_key(key, lora_weights)
@@ -697,6 +692,7 @@ def remove_lora_from_module(
     model: Flux,
     lora_path: str | StateDict,
     lora_scale: float = 1.0,
+    silent=False
 ):
     has_guidance = model.params.guidance_embed
     logger.info(f"Loading LoRA weights for {lora_path}")
@@ -737,7 +733,7 @@
         )
     )

-    for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab)):
+    for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab), disable=silent):
         module = get_module_for_key(key, model)
         weight, is_f8, dtype = extract_weight_from_linear(module)
         lora_sd = get_lora_for_key(key, lora_weights)
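
The other half of the change is the switch from `dict.pop(key)` to `dict.pop(key, None)`: a LoRA checkpoint that only trains some of the Q/K/V/mlp projections no longer raises `KeyError`, because the missing pairs come back as `None` and the existing `is not None` guards skip them. A minimal, self-contained sketch of the pattern (the toy state dict and key names below are illustrative, not the real checkpoint layout):

```python
from typing import Dict, Optional, Tuple

# Toy LoRA state dict: only the Q projection carries weights.
state_dict: Dict[str, list] = {
    "attn.to_q.lora_A.weight": [[0.1]],
    "attn.to_q.lora_B.weight": [[0.2]],
}

def pop_pair(sd: Dict[str, list], prefix: str) -> Tuple[Optional[list], Optional[list]]:
    # pop with a default returns None for missing keys instead of raising KeyError
    return sd.pop(f"{prefix}.lora_A.weight", None), sd.pop(f"{prefix}.lora_B.weight", None)

for proj in ("attn.to_q", "attn.to_k", "attn.to_v", "proj_mlp"):
    lora_a, lora_b = pop_pair(state_dict, proj)
    if lora_a is not None and lora_b is not None:
        print(f"fusing {proj}")    # only projections the LoRA actually trains
    else:
        print(f"skipping {proj}")  # previously sd.pop(key) raised KeyError here
```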

modules/flux_model.py (6 additions, 6 deletions)
@@ -628,7 +628,7 @@ def has_lora(self, identifier: str):
             if lora.path == identifier or lora.name == identifier:
                 return True

-    def load_lora(self, path: str, scale: float, name: str = None):
+    def load_lora(self, path: str, scale: float, name: str = None, silent=False):
         from lora_loading import (
             LoraWeights,
             apply_lora_to_model,
@@ -642,23 +642,23 @@ def load_lora(self, path: str, scale: float, name: str = None):
                     f"Lora {lora.name} already loaded with same scale - ignoring!"
                 )
             else:
-                remove_lora_from_module(self, lora, lora.scale)
-                apply_lora_to_model(self, lora, scale)
+                remove_lora_from_module(self, lora, lora.scale, silent=silent)
+                apply_lora_to_model(self, lora, scale, silent=silent)
                 for idx, lora_ in enumerate(self.loras):
                     if lora_.path == lora.path:
                         self.loras[idx].scale = scale
                         break
         else:
-            _, lora = apply_lora_to_model(self, path, scale, return_lora_resolved=True)
+            _, lora = apply_lora_to_model(self, path, scale, return_lora_resolved=True, silent=silent)
             self.loras.append(LoraWeights(lora, path, name, scale))

-    def unload_lora(self, path_or_identifier: str):
+    def unload_lora(self, path_or_identifier: str, silent=False):
         from lora_loading import remove_lora_from_module

         removed = False
         for idx, lora_ in enumerate(list(self.loras)):
             if lora_.path == path_or_identifier or lora_.name == path_or_identifier:
-                remove_lora_from_module(self, lora_.weights, lora_.scale)
+                remove_lora_from_module(self, lora_.weights, lora_.scale, silent=silent)
                 self.loras.pop(idx)
                 removed = True
                 break
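
Taken together, the `silent` flag travels from `FluxPipeline.load_lora`/`unload_lora` through `Flux.load_lora`/`unload_lora` into `apply_lora_to_model`/`remove_lora_from_module`, where it ends up as tqdm's `disable` argument. A minimal sketch of that last step, independent of the Flux classes (the function and key names are made up for illustration):

```python
from tqdm import tqdm

def fuse_keys(keys: list[str], silent: bool = False) -> None:
    # disable=True turns tqdm into a pass-through iterator, so no bar is printed
    for key in tqdm(keys, desc="Applying LoRA", total=len(keys), disable=silent):
        pass  # per-key fusing work would happen here

keys = [f"double_blocks.{i}.img_attn.qkv" for i in range(3)]
fuse_keys(keys)               # shows a progress bar
fuse_keys(keys, silent=True)  # prints nothing
```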