
Commit 8dd3f4a

sarahtranfb authored and facebook-github-bot committed
Support multiple perturbations per eval when masking across tensors (#1530)
Summary: This was already supported in the old path (which constructed ablated inputs over each input tensor individually); it improves compute efficiency by optionally passing multiple perturbed inputs to the model forward function in a single call.

Differential Revision: D71435704
1 parent b899732 commit 8dd3f4a
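
For context, here is a minimal usage sketch (not part of this commit) of the batched path the change enables. It assumes FeatureAblation.attribute accepts perturbations_per_eval and an enable_cross_tensor_attribution flag, as in recent Captum releases; the toy forward_func, inputs, and masks below are illustrative only.

import torch
from captum.attr import FeatureAblation

def forward_func(x1, x2):
    # Per-example scalar output, so the first output dim grows with the input
    # batch size (the new code asserts this when perturbations_per_eval > 1).
    return x1.sum(dim=1) + x2.sum(dim=1)

x1 = torch.rand(4, 3)
x2 = torch.rand(4, 5)
# Feature ids may span tensors: id 0 covers columns in both x1 and x2.
mask1 = torch.tensor([[0, 0, 1]])
mask2 = torch.tensor([[0, 2, 2, 2, 2]])

ablator = FeatureAblation(forward_func)
attrs = ablator.attribute(
    (x1, x2),
    feature_mask=(mask1, mask2),
    perturbations_per_eval=2,  # ablate two feature groups per forward call
    enable_cross_tensor_attribution=True,  # assumed flag name for the cross-tensor path
)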

File tree

3 files changed: +240 −71 lines changed

captum/attr/_core/feature_ablation.py (+147 −61)
@@ -353,10 +353,12 @@ def attribute(
                     formatted_feature_mask,
                     attr_progress,
                     flattened_initial_eval,
+                    initial_eval,
                     n_outputs,
                     total_attrib,
                     weights,
                     attrib_type,
+                    perturbations_per_eval,
                     **kwargs,
                 )
             else:
@@ -470,10 +472,12 @@ def _attribute_with_cross_tensor_feature_masks(
         formatted_feature_mask: Tuple[Tensor, ...],
         attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         n_outputs: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
         feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
@@ -482,17 +486,78 @@ def _attribute_with_cross_tensor_feature_masks(
                 if feature_idx.item() not in feature_idx_to_tensor_idx:
                     feature_idx_to_tensor_idx[feature_idx.item()] = []
                 feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+        all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
+
+        additional_args_repeated: object
+        if perturbations_per_eval > 1:
+            # Repeat features and additional args for batch size.
+            all_features_repeated = tuple(
+                torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0)
+                for j in range(len(formatted_inputs))
+            )
+            additional_args_repeated = (
+                _expand_additional_forward_args(
+                    formatted_additional_forward_args, perturbations_per_eval
+                )
+                if formatted_additional_forward_args is not None
+                else None
+            )
+            target_repeated = _expand_target(target, perturbations_per_eval)
+        else:
+            all_features_repeated = formatted_inputs
+            additional_args_repeated = formatted_additional_forward_args
+            target_repeated = target
+        num_examples = formatted_inputs[0].shape[0]
+
+        current_additional_args: object
+        if isinstance(baselines, tuple):
+            reshaped = False
+            reshaped_baselines: list[BaselineType] = []
+            for baseline in baselines:
+                if isinstance(baseline, Tensor):
+                    reshaped = True
+                    reshaped_baselines.append(
+                        baseline.reshape((1,) + tuple(baseline.shape))
+                    )
+                else:
+                    reshaped_baselines.append(baseline)
+            baselines = tuple(reshaped_baselines) if reshaped else baselines
+        for i in range(0, len(all_feature_idxs), perturbations_per_eval):
+            current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
+            current_num_ablated_features = min(
+                perturbations_per_eval, len(current_feature_idxs)
+            )
+
+            # Store appropriate inputs and additional args based on batch size.
+            if current_num_ablated_features != perturbations_per_eval:
+                current_additional_args = (
+                    _expand_additional_forward_args(
+                        formatted_additional_forward_args, current_num_ablated_features
+                    )
+                    if formatted_additional_forward_args is not None
+                    else None
+                )
+                current_target = _expand_target(target, current_num_ablated_features)
+                expanded_inputs = tuple(
+                    feature_repeated[0 : current_num_ablated_features * num_examples]
+                    for feature_repeated in all_features_repeated
+                )
+            else:
+                current_additional_args = additional_args_repeated
+                current_target = target_repeated
+                expanded_inputs = all_features_repeated
+
+            current_inputs, current_masks = (
+                self._construct_ablated_input_across_tensors(
+                    expanded_inputs,
+                    formatted_feature_mask,
+                    baselines,
+                    current_feature_idxs,
+                    feature_idx_to_tensor_idx,
+                    current_num_ablated_features,
+                )
+            )
 
-        for (
-            current_inputs,
-            current_mask,
-        ) in self._ablation_generator(
-            formatted_inputs,
-            baselines,
-            formatted_feature_mask,
-            feature_idx_to_tensor_idx,
-            **kwargs,
-        ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
             # shape:
             # agg mode: (*initial_eval.shape)
@@ -501,8 +566,8 @@ def _attribute_with_cross_tensor_feature_masks(
             modified_eval = _run_forward(
                 self.forward_func,
                 current_inputs,
-                target,
-                formatted_additional_forward_args,
+                current_target,
+                current_additional_args,
             )
 
             if attr_progress is not None:
@@ -515,75 +580,65 @@ def _attribute_with_cross_tensor_feature_masks(
 
             total_attrib, weights = self._process_ablated_out_full(
                 modified_eval,
-                current_mask,
+                current_masks,
                 flattened_initial_eval,
-                formatted_inputs,
+                initial_eval,
+                current_inputs,
                 n_outputs,
+                num_examples,
                 total_attrib,
                 weights,
                 attrib_type,
+                perturbations_per_eval,
             )
         return total_attrib, weights
 
-    def _ablation_generator(
-        self,
-        inputs: Tuple[Tensor, ...],
-        baselines: BaselineType,
-        input_mask: Tuple[Tensor, ...],
-        feature_idx_to_tensor_idx: Dict[int, List[int]],
-        **kwargs: Any,
-    ) -> Generator[
-        Tuple[
-            Tuple[Tensor, ...],
-            Tuple[Optional[Tensor], ...],
-        ],
-        None,
-        None,
-    ]:
-        if isinstance(baselines, torch.Tensor):
-            baselines = baselines.reshape((1,) + tuple(baselines.shape))
-
-        # Process one feature per time, rather than processing every input tensor
-        for feature_idx in feature_idx_to_tensor_idx.keys():
-            ablated_inputs, current_masks = (
-                self._construct_ablated_input_across_tensors(
-                    inputs,
-                    input_mask,
-                    baselines,
-                    feature_idx,
-                    feature_idx_to_tensor_idx[feature_idx],
-                )
-            )
-            yield ablated_inputs, current_masks
-
     def _construct_ablated_input_across_tensors(
         self,
         inputs: Tuple[Tensor, ...],
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
-        feature_idx: int,
-        tensor_idxs: List[int],
+        feature_idxs: List[int],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
+        current_num_ablated_features: int,
     ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
-
         ablated_inputs = []
         current_masks: List[Optional[Tensor]] = []
+        tensor_idxs = {
+            tensor_idx
+            for sublist in (
+                feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
+            )
+            for tensor_idx in sublist
+        }
+
         for i, input_tensor in enumerate(inputs):
             if i not in tensor_idxs:
                 ablated_inputs.append(input_tensor)
                 current_masks.append(None)
                 continue
-            tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
+            tensor_mask = []
+            ablated_input = input_tensor.clone()
             baseline = baselines[i] if isinstance(baselines, tuple) else baselines
-            if isinstance(baseline, torch.Tensor):
-                baseline = baseline.reshape(
-                    (1,) * (input_tensor.dim() - baseline.dim()) + tuple(baseline.shape)
+            for j, feature_idx in enumerate(feature_idxs):
+                original_input_size = (
+                    input_tensor.shape[0] // current_num_ablated_features
                 )
-            assert baseline is not None, "baseline must be provided"
-            ablated_input = (
-                input_tensor * (1 - tensor_mask).to(input_tensor.dtype)
-            ) + (baseline * tensor_mask.to(input_tensor.dtype))
+                start_idx = j * original_input_size
+                end_idx = (j + 1) * original_input_size
+
+                mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
+                if mask.ndim == 0:
+                    mask = mask.reshape((1,) * input_tensor.dim())
+                tensor_mask.append(mask)
+
+                assert baseline is not None, "baseline must be provided"
+                ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * (
+                    1 - mask
+                ) + (baseline * mask.to(input_tensor.dtype))
+            current_masks.append(torch.stack(tensor_mask, dim=0))
            ablated_inputs.append(ablated_input)
-            current_masks.append(tensor_mask)
+
         return tuple(ablated_inputs), tuple(current_masks)
 
     def _initial_eval_to_processed_initial_eval_fut(
@@ -784,7 +839,7 @@ def _attribute_progress_setup(
             formatted_inputs, feature_mask, **kwargs
         )
         total_forwards = (
-            int(sum(feature_counts))
+            math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
             if enable_cross_tensor_attribution
             else sum(
                 math.ceil(count / perturbations_per_eval) for count in feature_counts
@@ -1194,13 +1249,46 @@ def _process_ablated_out_full(
         modified_eval: Tensor,
         current_mask: Tuple[Optional[Tensor], ...],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
+        num_examples: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
     ) -> Tuple[List[Tensor], List[Tensor]]:
         modified_eval = self._parse_forward_out(modified_eval)
+        # if perturbations_per_eval > 1, the output shape must grow with
+        # input and not be aggregated
+        current_batch_size = inputs[0].shape[0]
+
+        # number of perturbation, which is not the same as
+        # perturbations_per_eval when not enough features to perturb
+        n_perturb = current_batch_size / num_examples
+        if perturbations_per_eval > 1 and not self._is_output_shape_valid:
+
+            current_output_shape = modified_eval.shape
+
+            # use initial_eval as the forward of perturbations_per_eval = 1
+            initial_output_shape = initial_eval.shape
+
+            assert (
+                # check if the output is not a scalar
+                current_output_shape
+                and initial_output_shape
+                # check if the output grow in same ratio, i.e., not agg
+                and current_output_shape[0] == n_perturb * initial_output_shape[0]
+            ), (
+                "When perturbations_per_eval > 1, forward_func's output "
+                "should be a tensor whose 1st dim grow with the input "
+                f"batch size: when input batch size is {num_examples}, "
+                f"the output shape is {initial_output_shape}; "
+                f"when input batch size is {current_batch_size}, "
+                f"the output shape is {current_output_shape}"
+            )
+
+            self._is_output_shape_valid = True
 
         # reshape the leading dim for n_feature_perturbed
         # flatten each feature's eval outputs into 1D of (n_outputs)
@@ -1209,9 +1297,6 @@ def _process_ablated_out_full(
         eval_diff = flattened_initial_eval - modified_eval
         eval_diff_shape = eval_diff.shape
 
-        # append the shape of one input example
-        # to make it broadcastable to mask
-
         if self.use_weights:
             for weight, mask in zip(weights, current_mask):
                 if mask is not None:
@@ -1224,6 +1309,7 @@ def _process_ablated_out_full(
            )
            eval_diff = eval_diff.to(total_attrib[i].device)
            total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)
+
        return total_attrib, weights
 
    def _fut_tuple_to_accumulate_fut_list(
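
The core idea in the rewritten _construct_ablated_input_across_tensors above is to tile each input perturbations_per_eval times along the batch dimension and ablate a different feature group in each tile, so a single forward call evaluates several perturbations. Below is a simplified, standalone sketch of that layout in plain PyTorch (one input tensor, zero baseline; the names are illustrative, not the Captum implementation):

import torch

def batched_ablation(input_tensor, feature_mask, feature_idxs, baseline=0.0):
    # Tile the batch once per feature group and ablate group j in tile j.
    num_examples = input_tensor.shape[0]
    expanded = torch.cat([input_tensor] * len(feature_idxs), dim=0)
    for j, feature_idx in enumerate(feature_idxs):
        start, end = j * num_examples, (j + 1) * num_examples
        mask = (feature_mask == feature_idx).long()
        expanded[start:end] = input_tensor * (1 - mask) + baseline * mask
    return expanded

x = torch.arange(12.0).reshape(3, 4)  # 3 examples, 4 features
fmask = torch.tensor([[0, 0, 1, 1]])  # two feature groups: 0 and 1
batch = batched_ablation(x, fmask, feature_idxs=[0, 1])
# batch[:3] has group 0 ablated, batch[3:] has group 1 ablated; one forward
# pass over `batch` evaluates both perturbations.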

captum/attr/_core/feature_permutation.py (+36 −10)
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
@@ -391,15 +391,41 @@ def _construct_ablated_input_across_tensors(
         inputs: Tuple[Tensor, ...],
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
-        feature_idx: int,
-        tensor_idxs: List[int],
+        feature_idxs: List[int],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
+        current_num_ablated_features: int,
     ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
         current_masks: List[Optional[Tensor]] = []
-        for i, mask in enumerate(input_mask):
-            if i in tensor_idxs:
-                current_masks.append((mask == feature_idx).to(inputs[0].device))
-            else:
+        tensor_idxs = {
+            tensor_idx
+            for sublist in (
+                feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
+            )
+            for tensor_idx in sublist
+        }
+        permuted_inputs = []
+        for i, input_tensor in enumerate(inputs):
+            if i not in tensor_idxs:
                 current_masks.append(None)
-        feature_masks = tuple(current_masks)
-        permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks)
-        return permuted_outputs, feature_masks
+                permuted_inputs.append(input_tensor)
+                continue
+            tensor_mask = []
+            permuted_input = input_tensor.clone()
+            for j, feature_idx in enumerate(feature_idxs):
+                original_input_size = (
+                    input_tensor.shape[0] // current_num_ablated_features
+                )
+                start_idx = j * original_input_size
+                end_idx = (j + 1) * original_input_size
+
+                mask = (input_mask[i] == feature_idx).to(input_tensor.device).bool()
+                if mask.ndim == 0:
+                    mask = mask.reshape((1,) * input_tensor.dim())
+                tensor_mask.append(mask)
+                permuted_input[start_idx:end_idx] = self.perm_func(
+                    input_tensor[start_idx:end_idx], mask
+                )
+            current_masks.append(torch.stack(tensor_mask, dim=0))
+            permuted_inputs.append(permuted_input)
+
+        return tuple(permuted_inputs), tuple(current_masks)
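
The feature_permutation.py change applies the same per-segment layout, but each segment's masked positions are permuted across the batch by self.perm_func instead of being replaced with a baseline. A rough stand-in for what such a permutation function does on one segment (simplified sketch; Captum's default additionally guards against drawing the identity permutation):

import torch

def permute_masked(x, mask):
    # Shuffle masked positions across the batch; leave unmasked positions unchanged.
    perm = torch.randperm(x.shape[0])
    mask = mask.to(dtype=x.dtype)
    return x[perm] * mask + x * (1 - mask)

x = torch.arange(8.0).reshape(4, 2)  # 4 examples, 2 features
mask = torch.tensor([[1, 0]])        # permute only the first feature
shuffled = permute_masked(x, mask)
# Column 0 is shuffled across examples; column 1 is unchanged.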
