Skip to content

Commit d39a8be

Browse files
styusuffacebook-github-bot
authored andcommitted
Control for when output from model is a scalar or a 1D tensor (pytorch#1521)
Summary: This is to make sure that we control for when the output is not a 2D tensor We also include an output accessor that parses a dictionary model output to get final output. Differential Revision: D69876980
1 parent 4ca5c2c commit d39a8be

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

captum/testing/helpers/basic_models.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# pyre-strict
44

5-
from typing import no_type_check, Optional, Tuple, Union
5+
from typing import Dict, no_type_check, Optional, Tuple, Union
66

77
import torch
88
import torch.nn as nn
@@ -467,7 +467,9 @@ def __init__(
467467
self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
468468

469469
@no_type_check
470-
def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
470+
def forward(
471+
self, x: Tensor, add_input: Optional[Tensor] = None
472+
) -> Dict[str, Tensor]:
471473
input = x if add_input is None else x + add_input
472474
lin0_out = self.linear0(input)
473475
lin1_out = self.linear1(lin0_out)
@@ -485,7 +487,14 @@ def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
485487

486488
lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
487489

488-
return torch.cat((lin2_out, lin3_out), dim=1)
490+
output_tensors = torch.cat((lin2_out, lin3_out), dim=1)
491+
492+
# we return a dictionary of tensors as an output to test the case
493+
# where an output accessor is required
494+
return {
495+
"task {}".format(i + 1): output_tensors[:, i]
496+
for i in range(output_tensors.shape[1])
497+
}
489498

490499

491500
class MultiRelu(nn.Module):

0 commit comments

Comments
 (0)