@@ -32,8 +32,6 @@ class SparsityModifierBase(Modifier):
3232 owl_lmbda : float | None = None
3333
3434 # data pipeline arguments
35- sequential_update : bool | None = False # deprecated
36- sequential_targets : str | list [str ] | None = None
3735 targets : str | list [str ] = ["Linear" ]
3836 ignore : list [str ] = Field (default_factory = list )
3937
@@ -44,17 +42,6 @@ class SparsityModifierBase(Modifier):
4442 _target_layers : dict [str , torch .nn .Module ] = PrivateAttr (default_factory = dict )
4543 _module_sparsities : dict [torch .nn .Module , str ] = PrivateAttr (default_factory = dict )
4644
47- @field_validator ("sequential_update" , mode = "before" )
48- def validate_sequential_update (cls , value : bool ) -> bool :
49- if not value :
50- warnings .warn (
51- "`sequential_update=False` is no longer supported, setting "
52- "sequential_update=True" ,
53- DeprecationWarning ,
54- )
55-
56- return True
57-
5845 @field_validator ("sparsity_profile" , mode = "before" )
5946 def validate_sparsity_profile (cls , value : str | None ) -> bool :
6047 if value is None :
@@ -111,12 +98,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
11198 dataloader : torch .utils .data .DataLoader = state .data .calib
11299
113100 # infer module and sequential targets
114- # Note: only pass sequential_targets from kwargs, not the full kwargs dict
115- # which may contain 'model' and cause duplicate argument errors
116- self .sequential_targets = self ._infer_sequential_targets (
117- model , sequential_targets = kwargs .get ("sequential_targets" )
118- )
119- layers = dict (match_named_modules (model , self .sequential_targets ))
101+ sequential_targets = model ._get_no_split_modules ("auto" )
102+ layers = dict (match_named_modules (model , sequential_targets ))
120103 self ._target_layers = dict (
121104 match_named_modules (model , self .targets )
122105 ) # layers containing targets
@@ -194,33 +177,6 @@ def on_end(self, state: State, event: Event, **kwargs):
194177 self .ended_ = True
195178 self .remove_hooks ()
196179
197- def _infer_sequential_targets (
198- self , model : torch .nn .Module , ** kwargs
199- ) -> str | list [str ]:
200- targets_from_kwargs = kwargs .get ("sequential_targets" )
201-
202- # Validate that sequential_targets is not provided from both sources
203- if self .sequential_targets is not None and targets_from_kwargs is not None :
204- raise ValueError (
205- "sequential_targets was provided both in the modifier config and in "
206- "oneshot() dataset_args. Please provide sequential_targets in only "
207- "one location to avoid conflicts."
208- )
209-
210- match self .sequential_targets :
211- case None :
212- # Check if sequential_targets was passed via kwargs (from dataset_args)
213- if targets_from_kwargs is not None :
214- if isinstance (targets_from_kwargs , str ):
215- return [targets_from_kwargs ]
216- return targets_from_kwargs
217- # Fall back to auto-inference
218- return get_no_split_params (model )
219- case str ():
220- return [self .sequential_targets ]
221- case _:
222- return self .sequential_targets
223-
224180 def _infer_owl_layer_sparsity (
225181 self ,
226182 model : torch .nn .Module ,
0 commit comments