From a91c1a35e847a2eda7e8670b79e79c7dc4f7dff3 Mon Sep 17 00:00:00 2001 From: Quoding Date: Thu, 15 Aug 2024 15:32:33 -0400 Subject: [PATCH] Add support for tuple of layers (instead of strictly list) in LayerLRP --- captum/_utils/typing.py | 3 +++ captum/attr/_core/layer/layer_lrp.py | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py index d9ac6304c8..b11390ceb2 100644 --- a/captum/_utils/typing.py +++ b/captum/_utils/typing.py @@ -18,6 +18,9 @@ # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool) ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module]) +ModuleOrModuleListOrModuleTuple = TypeVar( + "ModuleOrModuleListOrModuleTuple", Module, List[Module], Tuple[Module] +) TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]] BaselineType = Union[None, Tensor, int, float, Tuple[Union[Tensor, int, float], ...]] diff --git a/captum/attr/_core/layer/layer_lrp.py b/captum/attr/_core/layer/layer_lrp.py index 7bd2721328..7cc8846328 100644 --- a/captum/attr/_core/layer/layer_lrp.py +++ b/captum/attr/_core/layer/layer_lrp.py @@ -17,6 +17,7 @@ from captum._utils.typing import ( Literal, ModuleOrModuleList, + ModuleOrModuleListOrModuleTuple, TargetType, TensorOrTupleOfTensorsGeneric, ) @@ -264,7 +265,7 @@ def attribute( if return_convergence_delta: delta: Union[Tensor, List[Tensor]] - if isinstance(self.layer, list): + if isinstance(self.layer, list) or isinstance(self.layer, tuple): delta = [] for relevance_layer in relevances: delta.append( @@ -305,7 +306,7 @@ def _get_single_output_relevance(self, layer, output): # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def _get_output_relevance(self, output): - if isinstance(self.layer, list): + if isinstance(self.layer, list) or isinstance(self.layer, tuple): relevances = [] for layer in self.layer: relevances.append(self._get_single_output_relevance(layer, output))