import torch
from diffusers import DiffusionPipeline
from typing_extensions import cast

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag as tags
@@ -91,10 +92,8 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
9192 target_modules = smash_config ["target_modules" ]
9293
9394 if target_modules is None :
94- target_modules = self .get_model_dependent_hyperparameter_defaults (
95- model ,
96- smash_config
97- )
95+ target_modules = self .get_model_dependent_hyperparameter_defaults (model , smash_config )["target_modules" ]
96+ target_modules = cast (TARGET_MODULES_TYPE , target_modules )
9897
9998 def apply_sage_attn (
10099 root_name : str | None ,
@@ -154,7 +153,7 @@ def get_model_dependent_hyperparameter_defaults(
154153 self ,
155154 model : Any ,
156155 smash_config : SmashConfigPrefixWrapper ,
157- ) -> TARGET_MODULES_TYPE :
156+ ) -> dict [ str , Any ] :
158157 """
159158 Provide default `target_modules` targeting all transformer modules.
160159
@@ -178,5 +177,5 @@ def get_model_dependent_hyperparameter_defaults(
178177 # SageAttn might also be applicable to other modules but could significantly decrease model quality.
179178 include = ["transformer*" ]
180179 exclude = []
181-
182- return {"include " : include , "exclude" : exclude }
180+ target_modules = { "include" : include , "exclude" : exclude }
181+ return {"target_modules " : target_modules }