Skip to content

Commit 9e15c75

Browse files
styusuffacebook-github-bot
authored andcommitted
Control for when output from model is a scalar or a 1D tensor
Summary: This is to make sure that we control for when the output is not a 2D tensor Differential Revision: D69876980
1 parent 0146332 commit 9e15c75

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

captum/testing/helpers/basic_models.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,14 @@ def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
485485

486486
lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
487487

488-
return torch.cat((lin2_out, lin3_out), dim=1)
488+
output_tensors = torch.cat((lin2_out, lin3_out), dim=1)
489+
490+
# we return a dictionary of tensors as an output to test the case
491+
# where an output accessor is required
492+
return {
493+
"task {}".format(i + 1): output_tensors[:, i]
494+
for i in range(output_tensors.shape[1])
495+
}
489496

490497

491498
class MultiRelu(nn.Module):

0 commit comments

Comments
 (0)