Skip to content

Commit 8502988

Browse files
authored
raise exception when not all grads have been supplied for WF pruning (#296)
1 parent 7676b27 commit 8502988

File tree

1 file changed: +10 −4 lines changed

src/sparseml/pytorch/optim/mask_pruning_scorer.py

+10 −4

Original file line number | Diff line number | Diff line change
@@ -332,6 +332,16 @@ def score_parameters(self) -> List[Tensor]:
332332
given by the OBS method. For the approximated Hessian inverse matrix
333333
H^-1, scores will be W^2 / (2 * diag(H^-1))
334334
"""
335+
336+
if self._grad_buffer is None or torch.any(
337+
torch.all(self._grad_buffer == 0.0, dim=1)
338+
):
339+
# raise Exception if grad buffer is not full
340+
raise RuntimeError(
341+
"MFAC pruning step called, but not enough gradient samples have been "
342+
f"collected. Expected {self._mfac_options.num_grads} samples"
343+
)
344+
335345
if self._is_ddp:
336346
# move all grads to one device
337347
if self._is_main_proc:
@@ -450,10 +460,6 @@ def get_name() -> str:
450460

451461
def _score_parameters(self) -> List[Tensor]:
452462
# score params using MFAC and the gathered grad buffers
453-
if torch.any(torch.all(self._grads == 0.0, dim=1)):
454-
# if not all grads are captured, return magnitudes as scores
455-
return [torch.abs(param.data) for param in self._params]
456-
457463
# gather non-pruned weights
458464
non_pruned_weights = torch.empty(self._grads.size(1)).to(self._grads.device)
459465
weights_idx = 0

0 commit comments

Comments (0)