Skip to content

Commit 6903fbe

Browse files
committed
Fixes
1 parent e20a566 commit 6903fbe

File tree

4 files changed

+21
-6
lines changed

4 files changed

+21
-6
lines changed

captum/attr/_core/deep_lift.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,9 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
582582
def attribute(
583583
self,
584584
inputs: TensorOrTupleOfTensorsGeneric,
585-
baselines: Union[BaselineType, Callable[..., TensorOrTupleOfTensorsGeneric]],
585+
baselines: Union[
586+
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
587+
],
586588
target: TargetType = None,
587589
additional_forward_args: Any = None,
588590
return_convergence_delta: Literal[False] = False,
@@ -593,7 +595,9 @@ def attribute(
593595
def attribute(
594596
self,
595597
inputs: TensorOrTupleOfTensorsGeneric,
596-
baselines: Union[BaselineType, Callable[..., TensorOrTupleOfTensorsGeneric]],
598+
baselines: Union[
599+
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
600+
],
597601
target: TargetType = None,
598602
additional_forward_args: Any = None,
599603
*,

tests/attr/test_deeplift_classification.py

+5
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,11 @@ def softmax_classification(
155155
target: TargetType,
156156
) -> None:
157157
# TODO add test cases for multiple different layers
158+
if isinstance(attr_method, DeepLiftShap):
159+
assert isinstance(
160+
baselines, Tensor
161+
), "Non-tensor baseline not supported for DeepLiftShap"
162+
158163
model.zero_grad()
159164
attributions, delta = attr_method.attribute(
160165
input, baselines=baselines, target=target, return_convergence_delta=True

tests/helpers/basic.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ def copy_args(*args, **kwargs):
1919
return copy_args
2020

2121

22-
def assertTensorAlmostEqual(test, actual, expected, delta=0.0001, mode="sum"):
22+
def assertTensorAlmostEqual(
23+
test, actual, expected, delta: float = 0.0001, mode: str = "sum"
24+
):
2325
assert isinstance(actual, torch.Tensor), (
2426
"Actual parameter given for " "comparison must be a tensor."
2527
)
@@ -57,7 +59,9 @@ def assertTensorAlmostEqual(test, actual, expected, delta=0.0001, mode="sum"):
5759
raise ValueError("Mode for assertion comparison must be one of `max` or `sum`.")
5860

5961

60-
def assertTensorTuplesAlmostEqual(test, actual, expected, delta=0.0001, mode="sum"):
62+
def assertTensorTuplesAlmostEqual(
63+
test, actual, expected, delta: float = 0.0001, mode: str = "sum"
64+
):
6165
if isinstance(expected, tuple):
6266
assert len(actual) == len(
6367
expected

tests/robust/test_attack_comparator.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22
import collections
3-
from typing import List
3+
from typing import Dict, List, Tuple, Union
44

55
import torch
66
from captum.robust import AttackComparator, FGSM
@@ -202,7 +202,9 @@ def test_attack_comparator_with_additional_args(self) -> None:
202202
attack_comp.reset()
203203
self.assertEqual(len(attack_comp.summary()), 0)
204204

205-
def _compare_results(self, obtained, expected) -> None:
205+
def _compare_results(
206+
self, obtained: Union[Dict, Tuple, Tensor], expected: Union[Dict, Tuple, Tensor]
207+
) -> None:
206208
if isinstance(expected, dict):
207209
self.assertIsInstance(obtained, dict)
208210
for key in expected:

0 commit comments

Comments
 (0)