Commit b917b2a

sarahtranfb authored and facebook-github-bot committed
Avoid unnecessary tensor construction when creating input masks for permutation/ablation (#1527)
Summary:
Pull Request resolved: #1527

Study: https://docs.google.com/spreadsheets/d/1GyNJJBrNkazGOyJQLv00QV4phX2R3488oNgVPT17qzU/edit?gid=0#gid=0

Saw a regression in the new logic introduced in D69531512 with one of the models, for both the permutation and ablation methods, potentially due to large sparse features. vivekmig suggested we can avoid creating all of these zero tensors.

Reviewed By: craymichael

Differential Revision: D71057703

fbshipit-source-id: 3c4acc00b82de3fff7322c4f7cf99ad87fed1d02
1 parent 4ca5c2c commit b917b2a
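The gist of the change, as a minimal standalone sketch (the helper names and toy masks below are illustrative only, not Captum's API): precompute which input tensors actually contain each feature id, and hand downstream code None instead of a freshly allocated all-zeros mask for the tensors that don't.

from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor


def build_feature_idx_to_tensor_idx(
    feature_masks: Tuple[Tensor, ...],
) -> Dict[int, List[int]]:
    # Map each feature id to the indices of the input tensors whose mask
    # mentions it, so later passes can skip every other tensor outright.
    mapping: Dict[int, List[int]] = {}
    for i, mask in enumerate(feature_masks):
        for feature_idx in torch.unique(mask):
            mapping.setdefault(int(feature_idx.item()), []).append(i)
    return mapping


def masks_for_feature(
    feature_masks: Tuple[Tensor, ...],
    feature_idx: int,
    tensor_idxs: List[int],
) -> Tuple[Optional[Tensor], ...]:
    # Tensors outside tensor_idxs get None rather than a torch.zeros_like
    # placeholder; only tensors that contain the feature get a real mask.
    return tuple(
        (mask == feature_idx) if i in tensor_idxs else None
        for i, mask in enumerate(feature_masks)
    )


masks = (torch.tensor([[0, 0, 1]]), torch.tensor([[2, 2, 2]]))
mapping = build_feature_idx_to_tensor_idx(masks)
print(mapping)                                  # {0: [0], 1: [0], 2: [1]}
print(masks_for_feature(masks, 2, mapping[2]))  # (None, tensor([[True, True, True]]))

For models with many sparse features spread across many input tensors, this skips one zeros_like allocation per (feature, tensor) pair; consumers only need to treat None as "this tensor is untouched by the current feature".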

2 files changed (+51 / -26 lines)

captum/attr/_core/feature_ablation.py (+38 / -18)
@@ -3,7 +3,18 @@
 # pyre-strict
 
 import math
-from typing import Any, Callable, cast, Generator, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import torch
 from captum._utils.common import (
@@ -465,13 +476,21 @@ def _attribute_with_cross_tensor_feature_masks(
         attrib_type: dtype,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
+        feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
+        for i, mask in enumerate(formatted_feature_mask):
+            for feature_idx in torch.unique(mask):
+                if feature_idx.item() not in feature_idx_to_tensor_idx:
+                    feature_idx_to_tensor_idx[feature_idx.item()] = []
+                feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+
         for (
             current_inputs,
             current_mask,
         ) in self._ablation_generator(
             formatted_inputs,
             baselines,
             formatted_feature_mask,
+            feature_idx_to_tensor_idx,
             **kwargs,
         ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
@@ -511,27 +530,28 @@ def _ablation_generator(
         inputs: Tuple[Tensor, ...],
         baselines: BaselineType,
         input_mask: Tuple[Tensor, ...],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
         **kwargs: Any,
     ) -> Generator[
         Tuple[
             Tuple[Tensor, ...],
-            Tuple[Tensor, ...],
+            Tuple[Optional[Tensor], ...],
        ],
        None,
        None,
    ]:
-        unique_feature_ids = torch.unique(
-            torch.cat([mask.flatten() for mask in input_mask])
-        ).tolist()
-
        if isinstance(baselines, torch.Tensor):
            baselines = baselines.reshape((1,) + tuple(baselines.shape))
 
        # Process one feature per time, rather than processing every input tensor
-        for feature_idx in unique_feature_ids:
+        for feature_idx in feature_idx_to_tensor_idx.keys():
            ablated_inputs, current_masks = (
                self._construct_ablated_input_across_tensors(
-                    inputs, input_mask, baselines, feature_idx
+                    inputs,
+                    input_mask,
+                    baselines,
+                    feature_idx,
+                    feature_idx_to_tensor_idx[feature_idx],
                )
            )
            yield ablated_inputs, current_masks
@@ -542,18 +562,17 @@ def _construct_ablated_input_across_tensors(
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
         feature_idx: int,
-    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
+        tensor_idxs: List[int],
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
 
         ablated_inputs = []
-        current_masks = []
+        current_masks: List[Optional[Tensor]] = []
         for i, input_tensor in enumerate(inputs):
-            mask = input_mask[i]
-            tensor_mask = mask == feature_idx
-            if not tensor_mask.any():
+            if i not in tensor_idxs:
                 ablated_inputs.append(input_tensor)
-                current_masks.append(torch.zeros_like(tensor_mask))
+                current_masks.append(None)
                 continue
-            tensor_mask = tensor_mask.to(input_tensor.device).long()
+            tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
             baseline = baselines[i] if isinstance(baselines, tuple) else baselines
             if isinstance(baseline, torch.Tensor):
                 baseline = baseline.reshape(
@@ -1173,7 +1192,7 @@ def _process_ablated_out(
     def _process_ablated_out_full(
         self,
         modified_eval: Tensor,
-        current_mask: Tuple[Tensor, ...],
+        current_mask: Tuple[Optional[Tensor], ...],
         flattened_initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
@@ -1195,9 +1214,10 @@ def _process_ablated_out_full(
 
         if self.use_weights:
             for weight, mask in zip(weights, current_mask):
-                weight += mask.float().sum(dim=0)
+                if mask is not None:
+                    weight += mask.float().sum(dim=0)
         for i, mask in enumerate(current_mask):
-            if inputs[i].numel() == 0:
+            if mask is None or inputs[i].numel() == 0:
                 continue
             eval_diff = eval_diff.reshape(
                 eval_diff_shape + (inputs[i].dim() - 1) * (1,)

captum/attr/_core/feature_permutation.py (+13 / -8)
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
@@ -26,15 +26,15 @@ def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
 
 
 def _permute_features_across_tensors(
-    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Tensor, ...]
+    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Optional[Tensor], ...]
 ) -> Tuple[Tensor, ...]:
     """
     Permutes features across multiple input tensors using the corresponding
     feature masks.
     """
     permuted_outputs = []
     for input_tensor, feature_mask in zip(inputs, feature_masks):
-        if not feature_mask.any():
+        if feature_mask is None or not feature_mask.any():
             permuted_outputs.append(input_tensor)
             continue
         n = input_tensor.size(0)
@@ -103,7 +103,7 @@ def __init__(
         forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]],
         perm_func: Callable[[Tensor, Tensor], Tensor] = _permute_feature,
         perm_func_cross_tensor: Callable[
-            [Tuple[Tensor, ...], Tuple[Tensor, ...]], Tuple[Tensor, ...]
+            [Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]], Tuple[Tensor, ...]
         ] = _permute_features_across_tensors,
     ) -> None:
         r"""
@@ -392,9 +392,14 @@ def _construct_ablated_input_across_tensors(
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
         feature_idx: int,
-    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
-        feature_masks = tuple(
-            (mask == feature_idx).to(inputs[0].device) for mask in input_mask
-        )
+        tensor_idxs: List[int],
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
+        current_masks: List[Optional[Tensor]] = []
+        for i, mask in enumerate(input_mask):
+            if i in tensor_idxs:
+                current_masks.append((mask == feature_idx).to(inputs[0].device))
+            else:
+                current_masks.append(None)
+        feature_masks = tuple(current_masks)
         permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks)
         return permuted_outputs, feature_masks
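
One caller-facing consequence of the signature change above: a custom perm_func_cross_tensor passed to FeaturePermutation can now receive None entries in the mask tuple. A minimal sketch of a compatible function (the function name and the blend-by-mask permutation scheme are illustrative, not from the Captum codebase; it assumes the masks broadcast against the inputs, as Captum feature masks do):

from typing import Optional, Tuple

import torch
from torch import Tensor


def my_perm_func_cross_tensor(
    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Optional[Tensor], ...]
) -> Tuple[Tensor, ...]:
    permuted = []
    for input_tensor, mask in zip(inputs, feature_masks):
        if mask is None or not mask.any():
            # This tensor does not contain the current feature: pass it through.
            permuted.append(input_tensor)
            continue
        # Shuffle the masked entries across the batch dimension.
        perm = torch.randperm(input_tensor.size(0), device=input_tensor.device)
        mask = mask.to(device=input_tensor.device, dtype=input_tensor.dtype)
        permuted.append(input_tensor[perm] * mask + input_tensor * (1 - mask))
    return tuple(permuted)

Passing it as FeaturePermutation(forward_func, perm_func_cross_tensor=my_perm_func_cross_tensor) works as before; the only new requirement is the None check.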
