
Commit 215c907

styusuf authored and facebook-github-bot committed
Control for when output from model is a scalar or a 1D tensor (#1521)
Summary:
Pull Request resolved: #1521

This makes sure that we control for the case where the model output is not a 2D tensor.

* If the model output is 0D, attribution fails because LayerGradientXActivation always [assumes](https://www.internalfb.com/code/fbsource/[ffa152e31f81]/fbcode/pytorch/captum/captum/_utils/gradient.py?lines=681) that the output for a task is at least 1D (indexing output[0] on a scalar raises an index error).
* I propose we raise an assertion error when the output is 0D and ask the user to edit the output or the output accessor so that the output is more than 0D.
* If the model output is 1D, it could be of size (batch_size) when there is one target, or of size (n_targets) when there is only one observation with multiple targets or some kind of aggregated batch loss across multiple targets.
  * When it is of size (batch_size), we can assume there is just one target and get attributions without passing in a target.
  * When it is of size (n_targets), calling LayerGradientXActivation is a problem, since we would need to pass in the target parameter to get an attribution for each target.
  * We cannot pass in a target when the output is a 1D tensor: in that case LayerGradientXActivation [checks that the output is 2D](https://www.internalfb.com/code/fbsource/[ffa152e31f81]/fbcode/pytorch/captum/captum/_utils/common.py?lines=700-701).
  * The output would need to be 2D with shape (1 x n_targets); that reshaping has to happen in the output accessor or the forward function so that LayerGradientXActivation can account for it.
* We could check whether output.shape[0] == inputs.shape[0]. If so, we know the 1D tensor is for one target; if not, it is for multiple targets, and we could throw an error in the latter case to tell the user that the output needs to be 2D when attributing over multiple targets. However, this assumes too much: the assumption breaks when there are multiple targets in the 1D case but batch_size == n_targets, where we would wrongly assume there is only one target.
* I propose we keep the assumption that a 1D tensor is for one target. If the 1D tensor is actually for multiple targets, it will fail LayerGradientXActivation anyway unless it is converted to 2D.

We also include an output accessor that parses a dictionary model output into a 1D tensor for testing.

Reviewed By: vivekmig

Differential Revision: D69876980

fbshipit-source-id: 4c64410c1d0a25f2819f2da31f303f5fe710d3e1
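For reference, a minimal sketch of the dimension handling proposed above. The helper name check_task_output and its error message are hypothetical illustrations introduced here, not the actual Captum implementation.

```python
import torch
from torch import Tensor


def check_task_output(output: Tensor) -> None:
    """Hypothetical validation of a single task's model output."""
    # A 0D (scalar) output would break LayerGradientXActivation, which indexes
    # output[0]; fail early and point the user at the forward function or the
    # output accessor.
    assert output.dim() > 0, (
        "Model output for a task is a scalar (0D). Edit the forward function "
        "or the output accessor so that the task output is at least 1D."
    )
    # A 1D output is assumed to be a single target of shape (batch_size,), so
    # attributions are computed without passing a target. Outputs that pack
    # multiple targets must instead be reshaped to 2D, e.g. (1, n_targets),
    # before LayerGradientXActivation is called.


if __name__ == "__main__":
    check_task_output(torch.randn(8))         # OK: one target per example
    try:
        check_task_output(torch.tensor(1.0))  # 0D: rejected with a message
    except AssertionError as err:
        print(err)
```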
1 parent b917b2a commit 215c907

File tree

1 file changed, +12 -3 lines changed


captum/testing/helpers/basic_models.py

+12 -3
@@ -2,7 +2,7 @@

 # pyre-strict

-from typing import no_type_check, Optional, Tuple, Union
+from typing import Dict, no_type_check, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -467,7 +467,9 @@ def __init__(
         self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))

     @no_type_check
-    def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
+    def forward(
+        self, x: Tensor, add_input: Optional[Tensor] = None
+    ) -> Dict[str, Tensor]:
         input = x if add_input is None else x + add_input
         lin0_out = self.linear0(input)
         lin1_out = self.linear1(lin0_out)
@@ -485,7 +487,14 @@ def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:

         lin3_out = self.linear3(lin1_out_alt).to(torch.int64)

-        return torch.cat((lin2_out, lin3_out), dim=1)
+        output_tensors = torch.cat((lin2_out, lin3_out), dim=1)
+
+        # we return a dictionary of tensors as an output to test the case
+        # where an output accessor is required
+        return {
+            "task {}".format(i + 1): output_tensors[:, i]
+            for i in range(output_tensors.shape[1])
+        }


 class MultiRelu(nn.Module):
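The diff above makes the test model return a dictionary of per-task 1D tensors so that the output-accessor path can be exercised. Below is a hedged usage sketch of that pattern: TinyMultiTaskModel and output_accessor are hypothetical stand-ins introduced only for illustration; LayerGradientXActivation is the real Captum API, and the dict-to-tensor accessor is simply composed into the forward function passed to it.

```python
from typing import Dict

import torch
import torch.nn as nn
from captum.attr import LayerGradientXActivation


class TinyMultiTaskModel(nn.Module):
    """Illustrative stand-in for a test model that returns one 1D tensor per task."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        out = self.linear(x)  # shape: (batch_size, 2)
        return {"task {}".format(i + 1): out[:, i] for i in range(out.shape[1])}


def output_accessor(output: Dict[str, torch.Tensor]) -> torch.Tensor:
    # Parse the dictionary output down to a single task's (batch_size,) tensor,
    # so attribution can proceed without passing a target.
    return output["task 1"]


model = TinyMultiTaskModel()
lga = LayerGradientXActivation(
    lambda x: output_accessor(model(x)),  # forward_func now yields a 1D output
    model.linear,
)
# The 1D output of size (batch_size,) is treated as a single target, so no
# target argument is passed to attribute().
attributions = lga.attribute(torch.randn(4, 3))
print(attributions.shape)  # layer activations x gradients, shape (4, 2)
```

Because the accessor yields a tensor of size (batch_size,), this follows the single-target assumption from the summary; a dict entry holding multiple targets would have to be reshaped to 2D before attribution.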
