Skip to content

Commit 7f84157

Browse files
styusuffacebook-github-bot
authored andcommitted
Adding test for output that is a tensor of integers. Updating passthrough layer.
Summary: We are adding tests for different types of unsupported and non-differentiable layer output. Here, we add a test for layer output that is a tensor of integers. Differential Revision: D70919347
1 parent 886f52f commit 7f84157

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

captum/testing/helpers/basic_models.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ def forward(self, input1, input2, input3=None):
418418
return self.linear2(self.relu(self.linear1(embeddings))).sum(1)
419419

420420

421-
class GradientUnsupportedLayerOutput(nn.Module):
421+
class PassThroughLayerOutput(nn.Module):
422422
"""
423423
This layer is used to test the case where the model returns a layer that
424424
is not supported by the gradient computation.
@@ -428,10 +428,8 @@ def __init__(self) -> None:
428428
super().__init__()
429429

430430
@no_type_check
431-
def forward(
432-
self, unsupported_layer_output: PassThroughOutputType
433-
) -> PassThroughOutputType:
434-
return unsupported_layer_output
431+
def forward(self, output: PassThroughOutputType) -> PassThroughOutputType:
432+
return output
435433

436434

437435
class BasicModel_GradientLayerAttribution(nn.Module):
@@ -456,7 +454,7 @@ def __init__(
456454

457455
self.relu = nn.ReLU(inplace=inplace)
458456
self.relu_alt = nn.ReLU(inplace=False)
459-
self.unsupportedLayer = GradientUnsupportedLayerOutput()
457+
self.unsupported_layer = PassThroughLayerOutput()
460458

461459
self.linear2 = nn.Linear(4, 2)
462460
self.linear2.weight = nn.Parameter(torch.ones(2, 4))
@@ -466,6 +464,8 @@ def __init__(
466464
self.linear3.weight = nn.Parameter(torch.ones(2, 4))
467465
self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
468466

467+
self.int_layer = PassThroughLayerOutput() # sample layer with an int ouput
468+
469469
@no_type_check
470470
def forward(
471471
self, x: Tensor, add_input: Optional[Tensor] = None
@@ -476,7 +476,7 @@ def forward(
476476
lin1_out_alt = self.linear1_alt(lin0_out)
477477

478478
if self.unsupported_layer_output is not None:
479-
self.unsupportedLayer(self.unsupported_layer_output)
479+
self.unsupported_layer(self.unsupported_layer_output)
480480
# unsupportedLayer is unused in the forward func.
481481
self.relu_alt(
482482
lin1_out_alt
@@ -485,9 +485,10 @@ def forward(
485485
relu_out = self.relu(lin1_out)
486486
lin2_out = self.linear2(relu_out)
487487

488-
lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
488+
lin3_out = self.linear3(lin1_out_alt)
489+
int_output = self.int_layer(lin3_out.to(torch.int64))
489490

490-
output_tensors = torch.cat((lin2_out, lin3_out), dim=1)
491+
output_tensors = torch.cat((lin2_out, int_output), dim=1)
491492

492493
# we return a dictionary of tensors as an output to test the case
493494
# where an output accessor is required

0 commit comments

Comments
 (0)