Skip to content

Commit 223192d

Browse files
sarahtranfb authored and
facebook-github-bot committed
Support multiple perturbations per eval when masking across tensors (#1530)
Summary: Running multiple perturbations per eval was supported in the old path (when constructing ablated inputs over each input tensor individually) to improve compute efficiency by optionally passing multiple perturbed inputs to the model forward function. Differential Revision: D71435704
1 parent b899732 commit 223192d

File tree

3 files changed

+246
-72
lines changed

3 files changed

+246
-72
lines changed

captum/attr/_core/feature_ablation.py

+153-62
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,12 @@
2828
)
2929
from captum._utils.exceptions import FeatureAblationFutureError
3030
from captum._utils.progress import progress, SimpleProgress
31-
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
31+
from captum._utils.typing import (
32+
BaselineTupleType,
33+
BaselineType,
34+
TargetType,
35+
TensorOrTupleOfTensorsGeneric,
36+
)
3237
from captum.attr._utils.attribution import PerturbationAttribution
3338
from captum.attr._utils.common import _format_input_baseline
3439
from captum.log import log_usage
@@ -353,10 +358,12 @@ def attribute(
353358
formatted_feature_mask,
354359
attr_progress,
355360
flattened_initial_eval,
361+
initial_eval,
356362
n_outputs,
357363
total_attrib,
358364
weights,
359365
attrib_type,
366+
perturbations_per_eval,
360367
**kwargs,
361368
)
362369
else:
@@ -470,10 +477,12 @@ def _attribute_with_cross_tensor_feature_masks(
470477
formatted_feature_mask: Tuple[Tensor, ...],
471478
attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
472479
flattened_initial_eval: Tensor,
480+
initial_eval: Tensor,
473481
n_outputs: int,
474482
total_attrib: List[Tensor],
475483
weights: List[Tensor],
476484
attrib_type: dtype,
485+
perturbations_per_eval: int,
477486
**kwargs: Any,
478487
) -> Tuple[List[Tensor], List[Tensor]]:
479488
feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
@@ -482,17 +491,78 @@ def _attribute_with_cross_tensor_feature_masks(
482491
if feature_idx.item() not in feature_idx_to_tensor_idx:
483492
feature_idx_to_tensor_idx[feature_idx.item()] = []
484493
feature_idx_to_tensor_idx[feature_idx.item()].append(i)
494+
all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
495+
496+
additional_args_repeated: object
497+
if perturbations_per_eval > 1:
498+
# Repeat features and additional args for batch size.
499+
all_features_repeated = tuple(
500+
torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0)
501+
for j in range(len(formatted_inputs))
502+
)
503+
additional_args_repeated = (
504+
_expand_additional_forward_args(
505+
formatted_additional_forward_args, perturbations_per_eval
506+
)
507+
if formatted_additional_forward_args is not None
508+
else None
509+
)
510+
target_repeated = _expand_target(target, perturbations_per_eval)
511+
else:
512+
all_features_repeated = formatted_inputs
513+
additional_args_repeated = formatted_additional_forward_args
514+
target_repeated = target
515+
num_examples = formatted_inputs[0].shape[0]
516+
517+
current_additional_args: object
518+
if isinstance(baselines, tuple):
519+
reshaped = False
520+
reshaped_baselines: list[Union[Tensor, int, float]] = []
521+
for baseline in baselines:
522+
if isinstance(baseline, Tensor):
523+
reshaped = True
524+
reshaped_baselines.append(
525+
baseline.reshape((1,) + tuple(baseline.shape))
526+
)
527+
else:
528+
reshaped_baselines.append(baseline)
529+
baselines = tuple(reshaped_baselines) if reshaped else baselines
530+
for i in range(0, len(all_feature_idxs), perturbations_per_eval):
531+
current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
532+
current_num_ablated_features = min(
533+
perturbations_per_eval, len(current_feature_idxs)
534+
)
535+
536+
# Store appropriate inputs and additional args based on batch size.
537+
if current_num_ablated_features != perturbations_per_eval:
538+
current_additional_args = (
539+
_expand_additional_forward_args(
540+
formatted_additional_forward_args, current_num_ablated_features
541+
)
542+
if formatted_additional_forward_args is not None
543+
else None
544+
)
545+
current_target = _expand_target(target, current_num_ablated_features)
546+
expanded_inputs = tuple(
547+
feature_repeated[0 : current_num_ablated_features * num_examples]
548+
for feature_repeated in all_features_repeated
549+
)
550+
else:
551+
current_additional_args = additional_args_repeated
552+
current_target = target_repeated
553+
expanded_inputs = all_features_repeated
554+
555+
current_inputs, current_masks = (
556+
self._construct_ablated_input_across_tensors(
557+
expanded_inputs,
558+
formatted_feature_mask,
559+
baselines,
560+
current_feature_idxs,
561+
feature_idx_to_tensor_idx,
562+
current_num_ablated_features,
563+
)
564+
)
485565

486-
for (
487-
current_inputs,
488-
current_mask,
489-
) in self._ablation_generator(
490-
formatted_inputs,
491-
baselines,
492-
formatted_feature_mask,
493-
feature_idx_to_tensor_idx,
494-
**kwargs,
495-
):
496566
# modified_eval has (n_feature_perturbed * n_outputs) elements
497567
# shape:
498568
# agg mode: (*initial_eval.shape)
@@ -501,8 +571,8 @@ def _attribute_with_cross_tensor_feature_masks(
501571
modified_eval = _run_forward(
502572
self.forward_func,
503573
current_inputs,
504-
target,
505-
formatted_additional_forward_args,
574+
current_target,
575+
current_additional_args,
506576
)
507577

508578
if attr_progress is not None:
@@ -515,75 +585,65 @@ def _attribute_with_cross_tensor_feature_masks(
515585

516586
total_attrib, weights = self._process_ablated_out_full(
517587
modified_eval,
518-
current_mask,
588+
current_masks,
519589
flattened_initial_eval,
520-
formatted_inputs,
590+
initial_eval,
591+
current_inputs,
521592
n_outputs,
593+
num_examples,
522594
total_attrib,
523595
weights,
524596
attrib_type,
597+
perturbations_per_eval,
525598
)
526599
return total_attrib, weights
527600

528-
def _ablation_generator(
529-
self,
530-
inputs: Tuple[Tensor, ...],
531-
baselines: BaselineType,
532-
input_mask: Tuple[Tensor, ...],
533-
feature_idx_to_tensor_idx: Dict[int, List[int]],
534-
**kwargs: Any,
535-
) -> Generator[
536-
Tuple[
537-
Tuple[Tensor, ...],
538-
Tuple[Optional[Tensor], ...],
539-
],
540-
None,
541-
None,
542-
]:
543-
if isinstance(baselines, torch.Tensor):
544-
baselines = baselines.reshape((1,) + tuple(baselines.shape))
545-
546-
# Process one feature per time, rather than processing every input tensor
547-
for feature_idx in feature_idx_to_tensor_idx.keys():
548-
ablated_inputs, current_masks = (
549-
self._construct_ablated_input_across_tensors(
550-
inputs,
551-
input_mask,
552-
baselines,
553-
feature_idx,
554-
feature_idx_to_tensor_idx[feature_idx],
555-
)
556-
)
557-
yield ablated_inputs, current_masks
558-
559601
def _construct_ablated_input_across_tensors(
560602
self,
561603
inputs: Tuple[Tensor, ...],
562604
input_mask: Tuple[Tensor, ...],
563605
baselines: BaselineType,
564-
feature_idx: int,
565-
tensor_idxs: List[int],
606+
feature_idxs: List[int],
607+
feature_idx_to_tensor_idx: Dict[int, List[int]],
608+
current_num_ablated_features: int,
566609
) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
567-
568610
ablated_inputs = []
569611
current_masks: List[Optional[Tensor]] = []
612+
tensor_idxs = {
613+
tensor_idx
614+
for sublist in (
615+
feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
616+
)
617+
for tensor_idx in sublist
618+
}
619+
570620
for i, input_tensor in enumerate(inputs):
571621
if i not in tensor_idxs:
572622
ablated_inputs.append(input_tensor)
573623
current_masks.append(None)
574624
continue
575-
tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
625+
tensor_mask = []
626+
ablated_input = input_tensor.clone()
576627
baseline = baselines[i] if isinstance(baselines, tuple) else baselines
577-
if isinstance(baseline, torch.Tensor):
578-
baseline = baseline.reshape(
579-
(1,) * (input_tensor.dim() - baseline.dim()) + tuple(baseline.shape)
628+
for j, feature_idx in enumerate(feature_idxs):
629+
original_input_size = (
630+
input_tensor.shape[0] // current_num_ablated_features
580631
)
581-
assert baseline is not None, "baseline must be provided"
582-
ablated_input = (
583-
input_tensor * (1 - tensor_mask).to(input_tensor.dtype)
584-
) + (baseline * tensor_mask.to(input_tensor.dtype))
632+
start_idx = j * original_input_size
633+
end_idx = (j + 1) * original_input_size
634+
635+
mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
636+
if mask.ndim == 0:
637+
mask = mask.reshape((1,) * input_tensor.dim())
638+
tensor_mask.append(mask)
639+
640+
assert baseline is not None, "baseline must be provided"
641+
ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * (
642+
1 - mask
643+
) + (baseline * mask.to(input_tensor.dtype))
644+
current_masks.append(torch.stack(tensor_mask, dim=0))
585645
ablated_inputs.append(ablated_input)
586-
current_masks.append(tensor_mask)
646+
587647
return tuple(ablated_inputs), tuple(current_masks)
588648

589649
def _initial_eval_to_processed_initial_eval_fut(
@@ -784,7 +844,7 @@ def _attribute_progress_setup(
784844
formatted_inputs, feature_mask, **kwargs
785845
)
786846
total_forwards = (
787-
int(sum(feature_counts))
847+
math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
788848
if enable_cross_tensor_attribution
789849
else sum(
790850
math.ceil(count / perturbations_per_eval) for count in feature_counts
@@ -1194,13 +1254,46 @@ def _process_ablated_out_full(
11941254
modified_eval: Tensor,
11951255
current_mask: Tuple[Optional[Tensor], ...],
11961256
flattened_initial_eval: Tensor,
1257+
initial_eval: Tensor,
11971258
inputs: TensorOrTupleOfTensorsGeneric,
11981259
n_outputs: int,
1260+
num_examples: int,
11991261
total_attrib: List[Tensor],
12001262
weights: List[Tensor],
12011263
attrib_type: dtype,
1264+
perturbations_per_eval: int,
12021265
) -> Tuple[List[Tensor], List[Tensor]]:
12031266
modified_eval = self._parse_forward_out(modified_eval)
1267+
# if perturbations_per_eval > 1, the output shape must grow with
1268+
# input and not be aggregated
1269+
current_batch_size = inputs[0].shape[0]
1270+
1271+
# number of perturbation, which is not the same as
1272+
# perturbations_per_eval when not enough features to perturb
1273+
n_perturb = current_batch_size / num_examples
1274+
if perturbations_per_eval > 1 and not self._is_output_shape_valid:
1275+
1276+
current_output_shape = modified_eval.shape
1277+
1278+
# use initial_eval as the forward of perturbations_per_eval = 1
1279+
initial_output_shape = initial_eval.shape
1280+
1281+
assert (
1282+
# check if the output is not a scalar
1283+
current_output_shape
1284+
and initial_output_shape
1285+
# check if the output grow in same ratio, i.e., not agg
1286+
and current_output_shape[0] == n_perturb * initial_output_shape[0]
1287+
), (
1288+
"When perturbations_per_eval > 1, forward_func's output "
1289+
"should be a tensor whose 1st dim grow with the input "
1290+
f"batch size: when input batch size is {num_examples}, "
1291+
f"the output shape is {initial_output_shape}; "
1292+
f"when input batch size is {current_batch_size}, "
1293+
f"the output shape is {current_output_shape}"
1294+
)
1295+
1296+
self._is_output_shape_valid = True
12041297

12051298
# reshape the leading dim for n_feature_perturbed
12061299
# flatten each feature's eval outputs into 1D of (n_outputs)
@@ -1209,9 +1302,6 @@ def _process_ablated_out_full(
12091302
eval_diff = flattened_initial_eval - modified_eval
12101303
eval_diff_shape = eval_diff.shape
12111304

1212-
# append the shape of one input example
1213-
# to make it broadcastable to mask
1214-
12151305
if self.use_weights:
12161306
for weight, mask in zip(weights, current_mask):
12171307
if mask is not None:
@@ -1224,6 +1314,7 @@ def _process_ablated_out_full(
12241314
)
12251315
eval_diff = eval_diff.to(total_attrib[i].device)
12261316
total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)
1317+
12271318
return total_attrib, weights
12281319

12291320
def _fut_tuple_to_accumulate_fut_list(

captum/attr/_core/feature_permutation.py

+36-10
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22

33
# pyre-strict
4-
from typing import Any, Callable, List, Optional, Tuple, Union
4+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
55

66
import torch
77
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
@@ -391,15 +391,41 @@ def _construct_ablated_input_across_tensors(
391391
inputs: Tuple[Tensor, ...],
392392
input_mask: Tuple[Tensor, ...],
393393
baselines: BaselineType,
394-
feature_idx: int,
395-
tensor_idxs: List[int],
394+
feature_idxs: List[int],
395+
feature_idx_to_tensor_idx: Dict[int, List[int]],
396+
current_num_ablated_features: int,
396397
) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
397398
current_masks: List[Optional[Tensor]] = []
398-
for i, mask in enumerate(input_mask):
399-
if i in tensor_idxs:
400-
current_masks.append((mask == feature_idx).to(inputs[0].device))
401-
else:
399+
tensor_idxs = {
400+
tensor_idx
401+
for sublist in (
402+
feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
403+
)
404+
for tensor_idx in sublist
405+
}
406+
permuted_inputs = []
407+
for i, input_tensor in enumerate(inputs):
408+
if i not in tensor_idxs:
402409
current_masks.append(None)
403-
feature_masks = tuple(current_masks)
404-
permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks)
405-
return permuted_outputs, feature_masks
410+
permuted_inputs.append(input_tensor)
411+
continue
412+
tensor_mask = []
413+
permuted_input = input_tensor.clone()
414+
for j, feature_idx in enumerate(feature_idxs):
415+
original_input_size = (
416+
input_tensor.shape[0] // current_num_ablated_features
417+
)
418+
start_idx = j * original_input_size
419+
end_idx = (j + 1) * original_input_size
420+
421+
mask = (input_mask[i] == feature_idx).to(input_tensor.device).bool()
422+
if mask.ndim == 0:
423+
mask = mask.reshape((1,) * input_tensor.dim())
424+
tensor_mask.append(mask)
425+
permuted_input[start_idx:end_idx] = self.perm_func(
426+
input_tensor[start_idx:end_idx], mask
427+
)
428+
current_masks.append(torch.stack(tensor_mask, dim=0))
429+
permuted_inputs.append(permuted_input)
430+
431+
return tuple(permuted_inputs), tuple(current_masks)

0 commit comments

Comments
 (0)