diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py
index 10a238561..95ef7b5e3 100644
--- a/captum/_utils/typing.py
+++ b/captum/_utils/typing.py
@@ -24,6 +24,9 @@
 # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
 TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool)
 ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
+ModuleOrModuleListOrModuleTuple = TypeVar(
+    "ModuleOrModuleListOrModuleTuple", Module, List[Module], Tuple[Module, ...]
+)
 TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
 BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]]
 BaselineType = Union[None, Tensor, int, float, BaselineTupleType]
diff --git a/captum/attr/_core/layer/layer_lrp.py b/captum/attr/_core/layer/layer_lrp.py
index ba6a73d70..d8dd68616 100644
--- a/captum/attr/_core/layer/layer_lrp.py
+++ b/captum/attr/_core/layer/layer_lrp.py
@@ -16,6 +16,7 @@
 )
 from captum._utils.typing import (
     ModuleOrModuleList,
+    ModuleOrModuleListOrModuleTuple,
     TargetType,
     TensorOrTupleOfTensorsGeneric,
 )
@@ -253,7 +254,7 @@ def attribute(

         if return_convergence_delta:
             delta: Union[Tensor, List[Tensor]]
-            if isinstance(self.layer, list):
+            if isinstance(self.layer, (list, tuple)):
                 delta = []
                 for relevance_layer in relevances:
                     delta.append(
@@ -294,7 +295,7 @@ def _get_single_output_relevance(self, layer, output):
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
     def _get_output_relevance(self, output):
-        if isinstance(self.layer, list):
+        if isinstance(self.layer, (list, tuple)):
            relevances = []
            for layer in self.layer:
                relevances.append(self._get_single_output_relevance(layer, output))
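
A minimal usage sketch of what this change enables, assuming the rest of the
PR wires ModuleOrModuleListOrModuleTuple into LayerLRP's layer annotation.
ToyNet below is a hypothetical model introduced only for illustration;
LayerLRP and its attribute() signature are the existing captum.attr API.

    import torch
    import torch.nn as nn

    from captum.attr import LayerLRP


    class ToyNet(nn.Module):
        """Hypothetical two-linear-layer model, used only for illustration."""

        def __init__(self) -> None:
            super().__init__()
            self.fc1 = nn.Linear(4, 8)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(8, 3)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fc2(self.relu(self.fc1(x)))


    model = ToyNet()
    inputs = torch.randn(2, 4)

    # With this patch applied, a tuple of layers takes the same multi-layer
    # code path as a list: relevances and convergence deltas come back as
    # one entry per layer.
    lrp = LayerLRP(model, (model.fc1, model.fc2))
    relevances, delta = lrp.attribute(
        inputs, target=0, return_convergence_delta=True
    )

The variadic constraint Tuple[Module, ...] lets tuples of any length match;
a bare Tuple[Module] would only cover one-element tuples.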