
Commit 98fce82 (parent f9305aa)

fix 37425fb

Things to understand:

- subscripted generic basic types (e.g. `list[int]`) are `types.GenericAlias`;
- subscripted generic classes are `typing._GenericAlias`;
- neither can be used with `isinstance()`;
- `get_origin` is the cleanest way to check for this.
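For illustration (not part of the commit), here is a standalone sketch of those four points; `Box` is a hypothetical stand-in for any user-defined generic class:

```python
import types
from typing import Generic, TypeVar, get_origin

T = TypeVar("T")

class Box(Generic[T]):  # hypothetical user-defined generic class
    pass

# Subscripted generic basic types are types.GenericAlias...
assert type(list[int]) is types.GenericAlias
# ...while subscripted generic classes are typing._GenericAlias (a private type).
assert type(Box[int]).__name__ == "_GenericAlias"

# Neither can be used with isinstance():
try:
    isinstance([], list[int])
except TypeError as e:
    print(e)  # isinstance() argument 2 cannot be a parameterized generic

# get_origin() is the cleanest check: it returns the unsubscripted origin
# for a subscripted generic and None for a plain class.
assert get_origin(list[int]) is list
assert get_origin(Box[int]) is Box
assert get_origin(list) is None
```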

3 files changed, +18 -7 lines


src/refiners/fluxion/adapters/lora.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Generic, TypeVar, cast
+from typing import Any, Generic, Iterator, TypeVar, cast
 
 from torch import Tensor, device as Device, dtype as DType
 from torch.nn import Parameter as TorchParameter
@@ -385,20 +385,25 @@ def __init__(self, target: fl.WeightedModule, /, *loras: Lora[Any]) -> None:
         with self.setup_adapter(target):
             super().__init__(target, *loras)
 
+    @property
+    def lora_layers(self) -> Iterator[Lora[Any]]:
+        """The LoRA layers."""
+        return cast(Iterator[Lora[Any]], self.layers(Lora))
+
     @property
     def names(self) -> list[str]:
         """The names of the LoRA layers."""
-        return [lora.name for lora in self.layers(Lora[Any])]
+        return [lora.name for lora in self.lora_layers]
 
     @property
     def loras(self) -> dict[str, Lora[Any]]:
         """The LoRA layers indexed by name."""
-        return {lora.name: lora for lora in self.layers(Lora[Any])}
+        return {lora.name: lora for lora in self.lora_layers}
 
     @property
     def scales(self) -> dict[str, float]:
         """The scales of the LoRA layers indexed by names."""
-        return {lora.name: lora.scale for lora in self.layers(Lora[Any])}
+        return {lora.name: lora.scale for lora in self.lora_layers}
 
     @scales.setter
     def scale(self, values: dict[str, float]) -> None:
```
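The new `lora_layers` property captures the pattern this commit applies at every call site: pass the bare class to `layers()` at runtime (the subscripted `Lora[Any]` is no longer accepted as a predicate, per the `chain.py` change below) and `cast` the iterator for the type checker. A hedged usage sketch, assuming the import paths match the file paths in this commit and that `refiners.fluxion.layers` is importable as `fl`:

```python
from typing import Any, Iterator, cast

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import Lora

def collect_lora_scales(chain: fl.Chain) -> dict[str, float]:
    # Runtime: pass the bare class; Lora[Any] would now raise in walk().
    # Typing: cast so the checker knows each item has .name and .scale.
    loras = cast(Iterator[Lora[Any]], chain.layers(Lora))
    return {lora.name: lora.scale for lora in loras}
```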

src/refiners/fluxion/layers/chain.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -3,7 +3,7 @@
 import sys
 import traceback
 from collections import defaultdict
-from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, overload
+from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload
 
 import torch
 from torch import Tensor, cat, device as Device, dtype as DType
@@ -349,6 +349,10 @@ def walk(
         Yields:
             Each module that matches the predicate.
         """
+
+        if get_origin(predicate) is not None:
+            raise ValueError(f"subscripted generics cannot be used as predicates")
+
         if isinstance(predicate, type):
             # if the predicate is a Module type
             # build a predicate function that matches the type
```
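With this guard, callers get an explicit error instead of a confusing failure when they pass a subscripted generic as a predicate. A hedged sketch of the caller-visible behavior (assumes `fl.Linear` exists with an `(in_features, out_features)` signature; `list()` forces iteration in case `walk` is a generator):

```python
from typing import Any

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import Lora

chain = fl.Chain(fl.Linear(1, 1))  # minimal chain for illustration

list(chain.walk(Lora))       # OK: bare class predicate (no matches here)
list(chain.walk(Lora[Any]))  # ValueError: subscripted generics cannot be used as predicates
```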

src/refiners/foundationals/latent_diffusion/lora.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Iterator, cast
 from warnings import warn
 
 from torch import Tensor
@@ -193,7 +193,9 @@ def update_scales(self, scales: dict[str, float], /) -> None:
     @property
     def loras(self) -> list[Lora[Any]]:
         """List of all the LoRA layers managed by the SDLoraManager."""
-        return list(self.unet.layers(Lora[Any])) + list(self.clip_text_encoder.layers(Lora[Any]))
+        unet_layers = cast(Iterator[Lora[Any]], self.unet.layers(Lora))
+        text_encoder_layers = cast(Iterator[Lora[Any]], self.clip_text_encoder.layers(Lora))
+        return [*unet_layers, *text_encoder_layers]
 
     @property
     def names(self) -> list[str]:
```
