Commit 0c6a39e

sarahtranfb authored and facebook-github-bot committed
Avoid unnecessary tensor construction when creating input masks for permutation/ablation (#1527)
Summary:
Pull Request resolved: #1527
Study: https://docs.google.com/spreadsheets/d/1GyNJJBrNkazGOyJQLv00QV4phX2R3488oNgVPT17qzU/edit?gid=0#gid=0

We saw a regression in the new logic introduced in D69531512 with one of the models, for both the permutation and ablation methods, potentially due to large sparse features. vivekmig suggested we can avoid creating all of these zero tensors by tracking, for each feature id, which input tensors actually contain it.

Differential Revision: D71057703
1 parent 4ca5c2c commit 0c6a39e
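The core idea of the change: instead of materializing an all-zeros mask for every input tensor a feature does not touch, precompute which tensors each feature id actually appears in and hand back None for the rest. Below is a minimal standalone sketch of that idea; build_feature_idx_to_tensor_idx and masks_for_feature are illustrative names, not the Captum internals shown in the diff.

from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor

def build_feature_idx_to_tensor_idx(
    feature_masks: Tuple[Tensor, ...],
) -> Dict[int, List[int]]:
    # Map each feature id to the indices of the mask tensors that contain it.
    mapping: Dict[int, List[int]] = {}
    for i, mask in enumerate(feature_masks):
        for feature_idx in torch.unique(mask):
            mapping.setdefault(int(feature_idx.item()), []).append(i)
    return mapping

def masks_for_feature(
    feature_idx: int,
    feature_masks: Tuple[Tensor, ...],
    feature_idx_to_tensor_idx: Dict[int, List[int]],
) -> Tuple[Optional[Tensor], ...]:
    # Only tensors listed for this feature get a boolean mask; every other
    # position gets None instead of a freshly allocated zeros tensor.
    tensor_idxs = feature_idx_to_tensor_idx[feature_idx]
    return tuple(
        (mask == feature_idx) if i in tensor_idxs else None
        for i, mask in enumerate(feature_masks)
    )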

2 files changed: +50 −21


captum/attr/_core/feature_ablation.py (+37 −13)

@@ -3,7 +3,18 @@
 # pyre-strict

 import math
-from typing import Any, Callable, cast, Generator, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)

 import torch
 from captum._utils.common import (
@@ -465,13 +476,21 @@ def _attribute_with_cross_tensor_feature_masks(
         attrib_type: dtype,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
+        feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
+        for i, mask in enumerate(formatted_feature_mask):
+            for feature_idx in torch.unique(mask):
+                if feature_idx.item() not in feature_idx_to_tensor_idx.keys():
+                    feature_idx_to_tensor_idx[feature_idx.item()] = []
+                feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+
         for (
             current_inputs,
             current_mask,
         ) in self._ablation_generator(
             formatted_inputs,
             baselines,
             formatted_feature_mask,
+            feature_idx_to_tensor_idx,
             **kwargs,
         ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
@@ -511,11 +530,12 @@ def _ablation_generator(
         inputs: Tuple[Tensor, ...],
         baselines: BaselineType,
         input_mask: Tuple[Tensor, ...],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
         **kwargs: Any,
     ) -> Generator[
         Tuple[
             Tuple[Tensor, ...],
-            Tuple[Tensor, ...],
+            Tuple[Optional[Tensor], ...],
         ],
         None,
         None,
@@ -531,7 +551,11 @@
         for feature_idx in unique_feature_ids:
             ablated_inputs, current_masks = (
                 self._construct_ablated_input_across_tensors(
-                    inputs, input_mask, baselines, feature_idx
+                    inputs,
+                    input_mask,
+                    baselines,
+                    feature_idx,
+                    feature_idx_to_tensor_idx[feature_idx],
                 )
             )
             yield ablated_inputs, current_masks
@@ -542,18 +566,17 @@ def _construct_ablated_input_across_tensors(
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
         feature_idx: int,
-    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
+        tensor_idxs: List[int],
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:

         ablated_inputs = []
-        current_masks = []
+        current_masks: List[Optional[Tensor]] = []
         for i, input_tensor in enumerate(inputs):
-            mask = input_mask[i]
-            tensor_mask = mask == feature_idx
-            if not tensor_mask.any():
+            if i not in tensor_idxs:
                 ablated_inputs.append(input_tensor)
-                current_masks.append(torch.zeros_like(tensor_mask))
+                current_masks.append(None)
                 continue
-            tensor_mask = tensor_mask.to(input_tensor.device).long()
+            tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
             baseline = baselines[i] if isinstance(baselines, tuple) else baselines
             if isinstance(baseline, torch.Tensor):
                 baseline = baseline.reshape(
@@ -1173,7 +1196,7 @@ def _process_ablated_out(
     def _process_ablated_out_full(
         self,
         modified_eval: Tensor,
-        current_mask: Tuple[Tensor, ...],
+        current_mask: Tuple[Optional[Tensor], ...],
         flattened_initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
@@ -1195,9 +1218,10 @@ def _process_ablated_out_full(

         if self.use_weights:
             for weight, mask in zip(weights, current_mask):
-                weight += mask.float().sum(dim=0)
+                if mask is not None:
+                    weight += mask.float().sum(dim=0)
         for i, mask in enumerate(current_mask):
-            if inputs[i].numel() == 0:
+            if mask is None or inputs[i].numel() == 0:
                 continue
             eval_diff = eval_diff.reshape(
                 eval_diff_shape + (inputs[i].dim() - 1) * (1,)
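A hypothetical run of the sketch above shows what the ablation path now avoids: feature id 7 appears only in the second mask, so the first entry comes back as None rather than a zeros_like tensor. Downstream, _process_ablated_out_full treats a None mask as "this tensor was untouched" and skips both the weight accumulation and the attribution update for it.

masks = (torch.tensor([[0, 1]]), torch.tensor([[7, 7]]))
mapping = build_feature_idx_to_tensor_idx(masks)   # {0: [0], 1: [0], 7: [1]}
per_tensor = masks_for_feature(7, masks, mapping)  # (None, tensor([[True, True]]))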

captum/attr/_core/feature_permutation.py (+13 −8)

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

 # pyre-strict
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
@@ -26,15 +26,15 @@ def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:


 def _permute_features_across_tensors(
-    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Tensor, ...]
+    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Optional[Tensor], ...]
 ) -> Tuple[Tensor, ...]:
     """
     Permutes features across multiple input tensors using the corresponding
     feature masks.
     """
     permuted_outputs = []
     for input_tensor, feature_mask in zip(inputs, feature_masks):
-        if not feature_mask.any():
+        if feature_mask is None or not feature_mask.any():
             permuted_outputs.append(input_tensor)
             continue
         n = input_tensor.size(0)
@@ -103,7 +103,7 @@ def __init__(
         forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]],
         perm_func: Callable[[Tensor, Tensor], Tensor] = _permute_feature,
         perm_func_cross_tensor: Callable[
-            [Tuple[Tensor, ...], Tuple[Tensor, ...]], Tuple[Tensor, ...]
+            [Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]], Tuple[Tensor, ...]
         ] = _permute_features_across_tensors,
     ) -> None:
         r"""
@@ -392,9 +392,14 @@ def _construct_ablated_input_across_tensors(
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
         feature_idx: int,
-    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
-        feature_masks = tuple(
-            (mask == feature_idx).to(inputs[0].device) for mask in input_mask
-        )
+        tensor_idxs: List[int],
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
+        current_masks: List[Optional[Tensor]] = []
+        for i, mask in enumerate(input_mask):
+            if i in tensor_idxs:
+                current_masks.append((mask == feature_idx).to(inputs[0].device))
+            else:
+                current_masks.append(None)
+        feature_masks = tuple(current_masks)
         permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks)
         return permuted_outputs, feature_masks
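The permutation path uses the same sentinel convention: a None entry in feature_masks means "leave this tensor alone", so _permute_features_across_tensors passes it through without allocating anything. A hypothetical call, with made-up shapes (the exact permutation applied to masked regions is elided in the diff above):

x1 = torch.arange(6.0).reshape(3, 2)  # mask entry is None: returned as-is
x2 = torch.arange(3.0).reshape(3, 1)  # mask selects the feature: rows permuted
feature_masks = (None, torch.ones(3, 1, dtype=torch.bool))
# _permute_features_across_tensors((x1, x2), feature_masks) would return
# x1 unchanged and a batch-permuted copy of x2; no zeros tensor is built
# for x1 along the way.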
