@@ -418,7 +418,7 @@ def forward(self, input1, input2, input3=None):
418
418
return self .linear2 (self .relu (self .linear1 (embeddings ))).sum (1 )
419
419
420
420
421
- class GradientUnsupportedLayerOutput (nn .Module ):
421
+ class PassThroughLayerOutput (nn .Module ):
422
422
"""
423
423
This layer is used to test the case where the model returns a layer that
424
424
is not supported by the gradient computation.
@@ -428,10 +428,8 @@ def __init__(self) -> None:
428
428
super ().__init__ ()
429
429
430
430
@no_type_check
431
- def forward (
432
- self , unsupported_layer_output : PassThroughOutputType
433
- ) -> PassThroughOutputType :
434
- return unsupported_layer_output
431
+ def forward (self , output : PassThroughOutputType ) -> PassThroughOutputType :
432
+ return output
435
433
436
434
437
435
class BasicModel_GradientLayerAttribution (nn .Module ):
@@ -456,7 +454,7 @@ def __init__(
456
454
457
455
self .relu = nn .ReLU (inplace = inplace )
458
456
self .relu_alt = nn .ReLU (inplace = False )
459
- self .unsupportedLayer = GradientUnsupportedLayerOutput ()
457
+ self .unsupported_layer = PassThroughLayerOutput ()
460
458
461
459
self .linear2 = nn .Linear (4 , 2 )
462
460
self .linear2 .weight = nn .Parameter (torch .ones (2 , 4 ))
@@ -466,6 +464,8 @@ def __init__(
466
464
self .linear3 .weight = nn .Parameter (torch .ones (2 , 4 ))
467
465
self .linear3 .bias = nn .Parameter (torch .tensor ([- 1.0 , 1.0 ]))
468
466
467
+ self .int_layer = PassThroughLayerOutput () # sample layer with an int ouput
468
+
469
469
@no_type_check
470
470
def forward (
471
471
self , x : Tensor , add_input : Optional [Tensor ] = None
@@ -476,7 +476,7 @@ def forward(
476
476
lin1_out_alt = self .linear1_alt (lin0_out )
477
477
478
478
if self .unsupported_layer_output is not None :
479
- self .unsupportedLayer (self .unsupported_layer_output )
479
+ self .unsupported_layer (self .unsupported_layer_output )
480
480
# unsupportedLayer is unused in the forward func.
481
481
self .relu_alt (
482
482
lin1_out_alt
@@ -485,9 +485,10 @@ def forward(
485
485
relu_out = self .relu (lin1_out )
486
486
lin2_out = self .linear2 (relu_out )
487
487
488
- lin3_out = self .linear3 (lin1_out_alt ).to (torch .int64 )
488
+ lin3_out = self .linear3 (lin1_out_alt )
489
+ int_output = self .int_layer (lin3_out .to (torch .int64 ))
489
490
490
- output_tensors = torch .cat ((lin2_out , lin3_out ), dim = 1 )
491
+ output_tensors = torch .cat ((lin2_out , int_output ), dim = 1 )
491
492
492
493
# we return a dictionary of tensors as an output to test the case
493
494
# where an output accessor is required
0 commit comments