File tree 1 file changed +10
-4
lines changed
src/sparseml/pytorch/optim
1 file changed +10
-4
lines changed Original file line number Diff line number Diff line change @@ -332,6 +332,16 @@ def score_parameters(self) -> List[Tensor]:
332
332
given by the OBS method. For the approximated Hessian inverse matrix
333
333
H^-1, scores will be W^2 / (2 * diag(H^-1))
334
334
"""
335
+
336
+ if self ._grad_buffer is None or torch .any (
337
+ torch .all (self ._grad_buffer == 0.0 , dim = 1 )
338
+ ):
339
+ # raise Exception if grad buffer is not full
340
+ raise RuntimeError (
341
+ "MFAC pruning step called, but not enough gradient samples have been "
342
+ f"collected. Expected { self ._mfac_options .num_grads } samples"
343
+ )
344
+
335
345
if self ._is_ddp :
336
346
# move all grads to one device
337
347
if self ._is_main_proc :
@@ -450,10 +460,6 @@ def get_name() -> str:
450
460
451
461
def _score_parameters (self ) -> List [Tensor ]:
452
462
# score params using MFAC and the gathered grad buffers
453
- if torch .any (torch .all (self ._grads == 0.0 , dim = 1 )):
454
- # if not all grads are captured, return magnitudes as scores
455
- return [torch .abs (param .data ) for param in self ._params ]
456
-
457
463
# gather non-pruned weights
458
464
non_pruned_weights = torch .empty (self ._grads .size (1 )).to (self ._grads .device )
459
465
weights_idx = 0
You can’t perform that action at this time.
0 commit comments