
Commit b899732

styusuf authored and facebook-github-bot committed
Adding test for output that is a tensor of integers. Updating passthrough layer. (#1526)
Summary:
Pull Request resolved: #1526

We are adding tests for different types of unsupported and non-differentiable layer output. Here, we add a test for a layer output that is a tensor of integers. We split the cases of unsupported layer outputs from the case where a layer output is used by some tasks but not others. When a layer output is not supported (a List of Tensors or a Tensor of integers), we don't compute attributions and return None for those layers. When a layer output is not used by a task, we should output a tensor of zeros for that task.

Reviewed By: craymichael

Differential Revision: D70919347

fbshipit-source-id: 191d9d69c78bcf00fa3cbbbd5707154e0f221410
1 parent 215c907 · commit b899732
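
For context on why a tensor of integers is "unsupported": PyTorch autograd only tracks floating-point (and complex) tensors, so casting a layer output to torch.int64 detaches it from the graph, and no gradients (hence no attributions) can flow through it. A minimal sketch of this behavior (illustrative only, not part of this commit):

import torch

x = torch.ones(2, 4, requires_grad=True)
float_out = x.sum(dim=1)             # still on the autograd graph
int_out = float_out.to(torch.int64)  # integer tensors cannot carry gradients

print(float_out.requires_grad)  # True
print(int_out.requires_grad)    # False -- gradient computation stops here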

File tree

1 file changed: +10 −9 lines


captum/testing/helpers/basic_models.py (+10 −9)
@@ -418,7 +418,7 @@ def forward(self, input1, input2, input3=None):
         return self.linear2(self.relu(self.linear1(embeddings))).sum(1)
 
 
-class GradientUnsupportedLayerOutput(nn.Module):
+class PassThroughLayerOutput(nn.Module):
     """
     This layer is used to test the case where the model returns a layer that
     is not supported by the gradient computation.
@@ -428,10 +428,8 @@ def __init__(self) -> None:
         super().__init__()
 
     @no_type_check
-    def forward(
-        self, unsupported_layer_output: PassThroughOutputType
-    ) -> PassThroughOutputType:
-        return unsupported_layer_output
+    def forward(self, output: PassThroughOutputType) -> PassThroughOutputType:
+        return output
 
 
 class BasicModel_GradientLayerAttribution(nn.Module):
@@ -456,7 +454,7 @@ def __init__(
 
         self.relu = nn.ReLU(inplace=inplace)
         self.relu_alt = nn.ReLU(inplace=False)
-        self.unsupportedLayer = GradientUnsupportedLayerOutput()
+        self.unsupported_layer = PassThroughLayerOutput()
 
         self.linear2 = nn.Linear(4, 2)
         self.linear2.weight = nn.Parameter(torch.ones(2, 4))
@@ -466,6 +464,8 @@ def __init__(
         self.linear3.weight = nn.Parameter(torch.ones(2, 4))
         self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
 
+        self.int_layer = PassThroughLayerOutput()  # sample layer with an int output
+
     @no_type_check
     def forward(
         self, x: Tensor, add_input: Optional[Tensor] = None
@@ -476,7 +476,7 @@ def forward(
         lin1_out_alt = self.linear1_alt(lin0_out)
 
         if self.unsupported_layer_output is not None:
-            self.unsupportedLayer(self.unsupported_layer_output)
+            self.unsupported_layer(self.unsupported_layer_output)
        # unsupportedLayer is unused in the forward func.
        self.relu_alt(
            lin1_out_alt
@@ -485,9 +485,10 @@
         relu_out = self.relu(lin1_out)
         lin2_out = self.linear2(relu_out)
 
-        lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
+        lin3_out = self.linear3(lin1_out_alt)
+        int_output = self.int_layer(lin3_out.to(torch.int64))
 
-        output_tensors = torch.cat((lin2_out, lin3_out), dim=1)
+        output_tensors = torch.cat((lin2_out, int_output), dim=1)
 
         # we return a dictionary of tensors as an output to test the case
         # where an output accessor is required
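
As a usage sketch (an assumption about intent, not code from this commit): the point of a pass-through layer is to give layer-attribution machinery a module boundary to hook, even when the value flowing through it is non-differentiable, such as the int64 tensor above. A forward hook can still observe the output:

import torch
import torch.nn as nn

# Simplified restatement of the PassThroughLayerOutput class from this diff.
class PassThroughLayerOutput(nn.Module):
    def forward(self, output):
        return output  # identity: only exposes `output` at a module boundary

layer = PassThroughLayerOutput()
captured = []
# Layer attribution relies on forward hooks like this to capture layer outputs.
layer.register_forward_hook(lambda module, inputs, out: captured.append(out))

layer(torch.tensor([[1, 2], [3, 4]], dtype=torch.int64))
print(captured[0].dtype)  # torch.int64 -- observable via hooks, but not differentiable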
