Skip to content

Commit d4b259c

Browse files
authored
fix: align sage_attn default hyperparameter return type with refactor (#540)
1 parent bd5f322 commit d4b259c

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/pruna/algorithms/sage_attn.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import torch
2121
from diffusers import DiffusionPipeline
22+
from typing_extensions import cast
2223

2324
from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
2425
from pruna.algorithms.base.tags import AlgorithmTag as tags
@@ -91,10 +92,8 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
9192
target_modules = smash_config["target_modules"]
9293

9394
if target_modules is None:
94-
target_modules = self.get_model_dependent_hyperparameter_defaults(
95-
model,
96-
smash_config
97-
)
95+
target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config)["target_modules"]
96+
target_modules = cast(TARGET_MODULES_TYPE, target_modules)
9897

9998
def apply_sage_attn(
10099
root_name: str | None,
@@ -154,7 +153,7 @@ def get_model_dependent_hyperparameter_defaults(
154153
self,
155154
model: Any,
156155
smash_config: SmashConfigPrefixWrapper,
157-
) -> TARGET_MODULES_TYPE:
156+
) -> dict[str, Any]:
158157
"""
159158
Provide default `target_modules` targeting all transformer modules.
160159
@@ -178,5 +177,5 @@ def get_model_dependent_hyperparameter_defaults(
178177
# SageAttn might also be applicable to other modules but could significantly decrease model quality.
179178
include = ["transformer*"]
180179
exclude = []
181-
182-
return {"include": include, "exclude": exclude}
180+
target_modules = {"include": include, "exclude": exclude}
181+
return {"target_modules": target_modules}

0 commit comments

Comments
 (0)