@@ -7,6 +7,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+from captum._utils.typing import PassThroughOutputType
from torch import Tensor
from torch.futures import Future
@@ -417,6 +418,76 @@ def forward(self, input1, input2, input3=None):
        return self.linear2(self.relu(self.linear1(embeddings))).sum(1)


+class GradientUnsupportedLayerOutput(nn.Module):
+    """
+    This layer is used to test the case where the model returns a layer output
+    that is not supported by the gradient computation.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    @no_type_check
+    def forward(
+        self, unsupported_layer_output: PassThroughOutputType
+    ) -> PassThroughOutputType:
+        return unsupported_layer_output
+
+
+class BasicModel_GradientLayerAttribution(nn.Module):
+    def __init__(
+        self,
+        inplace: bool = False,
+        unsupported_layer_output: PassThroughOutputType = None,
+    ) -> None:
+        super().__init__()
+        # Linear 0 is simply identity transform
+        self.unsupported_layer_output = unsupported_layer_output
+        self.linear0 = nn.Linear(3, 3)
+        self.linear0.weight = nn.Parameter(torch.eye(3))
+        self.linear0.bias = nn.Parameter(torch.zeros(3))
+        self.linear1 = nn.Linear(3, 4)
+        self.linear1.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.linear1_alt = nn.Linear(3, 4)
+        self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.relu = nn.ReLU(inplace=inplace)
+        self.relu_alt = nn.ReLU(inplace=False)
+        self.unsupportedLayer = GradientUnsupportedLayerOutput()
+
+        self.linear2 = nn.Linear(4, 2)
+        self.linear2.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+        self.linear3 = nn.Linear(4, 2)
+        self.linear3.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+    @no_type_check
+    def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
+        input = x if add_input is None else x + add_input
+        lin0_out = self.linear0(input)
+        lin1_out = self.linear1(lin0_out)
+        lin1_out_alt = self.linear1_alt(lin0_out)
+
+        if self.unsupported_layer_output is not None:
+            self.unsupportedLayer(self.unsupported_layer_output)
+            # unsupportedLayer's output is unused in the forward func.
+        self.relu_alt(
+            lin1_out_alt
+        )  # relu_alt's output is supported but it's unused in the forward func.
+
+        relu_out = self.relu(lin1_out)
+        lin2_out = self.linear2(relu_out)
+
+        lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
+
+        return torch.cat((lin2_out, lin3_out), dim=1)
+
+
class MultiRelu(nn.Module):
    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
@@ -429,7 +500,11 @@ def forward(self, arg1: Tensor, arg2: Tensor) -> Tuple[Tensor, Tensor]:


class BasicModel_MultiLayer(nn.Module):
-    def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None:
+    def __init__(
+        self,
+        inplace: bool = False,
+        multi_input_module: bool = False,
+    ) -> None:
        super().__init__()
        # Linear 0 is simply identity transform
        self.multi_input_module = multi_input_module
@@ -461,6 +536,7 @@ def forward(
        input = x if add_input is None else x + add_input
        lin0_out = self.linear0(input)
        lin1_out = self.linear1(lin0_out)
+
        if self.multi_input_module:
            relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input))
            relu_out = relu_out1 + relu_out2
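For context, a minimal sketch (not part of the commit above, and assuming BasicModel_GradientLayerAttribution is in scope from the diffed helpers module) of how the new test model might be driven:

# Minimal usage sketch (an assumption, not part of the diff above).
# BasicModel_GradientLayerAttribution expects 3-feature inputs (linear0 is a 3x3
# identity); its output concatenates linear2's two columns with linear3's two columns.
import torch

# The value passed here is forwarded through unsupportedLayer inside forward(),
# but its output never reaches the returned tensor, which is what lets tests
# exercise a layer output that gradient computation cannot handle.
model = BasicModel_GradientLayerAttribution(unsupported_layer_output=torch.randn(2, 3))
out = model(torch.randn(2, 3))
print(out.shape)  # torch.Size([2, 4])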