
Commit 885ed76 ("Fixes")
1 parent 26c5309

22 files changed, +52 -52 lines

captum/_utils/common.py (+2 -2)

@@ -3,7 +3,7 @@
 from enum import Enum
 from functools import reduce
 from inspect import signature
-from typing import Any, Callable, cast, Dict, List, overload, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, overload, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -683,7 +683,7 @@ def _extract_device(


 def _reduce_list(
-    val_list: List[TupleOrTensorOrBoolGeneric],
+    val_list: Sequence[TupleOrTensorOrBoolGeneric],
     red_func: Callable[[List], Any] = torch.cat,
 ) -> TupleOrTensorOrBoolGeneric:
     """

captum/attr/_core/deep_lift.py (+2 -6)

@@ -582,9 +582,7 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
-        baselines: Union[
-            TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
-        ],
+        baselines: Union[BaselineType, Callable[..., TensorOrTupleOfTensorsGeneric]],
         target: TargetType = None,
         additional_forward_args: Any = None,
         return_convergence_delta: Literal[False] = False,
@@ -595,9 +593,7 @@ def attribute(
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
-        baselines: Union[
-            TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
-        ],
+        baselines: Union[BaselineType, Callable[..., TensorOrTupleOfTensorsGeneric]],
         target: TargetType = None,
         additional_forward_args: Any = None,
         *,

captum/influence/_core/tracincp_fast_rand_proj.py (+3 -5)

@@ -867,7 +867,7 @@ class TracInCPFastRandProj(TracInCPFast):
     def __init__(
         self,
         model: Module,
-        final_fc_layer: Union[Module, str],
+        final_fc_layer: Module,
         train_dataset: Union[Dataset, DataLoader],
         checkpoints: Union[str, List[str], Iterator],
         checkpoints_load_func: Callable = _load_flexible_state_dict,
@@ -884,11 +884,9 @@ def __init__(

             model (torch.nn.Module): An instance of pytorch model. This model should
                 define all of its layers as attributes of the model.
-            final_fc_layer (torch.nn.Module or str): The last fully connected layer in
+            final_fc_layer (torch.nn.Module): The last fully connected layer in
                 the network for which gradients will be approximated via fast random
-                projection method. Can be either the layer module itself, or the
-                fully qualified name of the layer if it is a defined attribute of
-                the passed `model`.
+                projection method.
             train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader):
                 In the `influence` method, we compute the influence score of
                 training examples on examples in a test batch.

captum/insights/attr_vis/server.py (+2 -2)

@@ -4,7 +4,7 @@
 import socket
 import threading
 from time import sleep
-from typing import Optional
+from typing import cast, Dict, Optional

 from captum.log import log_usage
 from flask import Flask, jsonify, render_template, request
@@ -41,7 +41,7 @@ def namedtuple_to_dict(obj):
 def attribute() -> Response:
     # force=True needed for Colab notebooks, which doesn't use the correct
     # Content-Type header when forwarding requests through the Colab proxy
-    r = request.get_json(force=True)
+    r = cast(Dict, request.get_json(force=True))
     return jsonify(
         namedtuple_to_dict(
             visualizer._calculate_attribution_from_cache(  # type: ignore
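
Note: cast(Dict, ...) only informs the type checker; it performs no runtime conversion. It is the usual way to narrow a value whose declared return type is optional or loosely typed, as request.get_json is in Flask's type stubs. A small illustrative sketch with a hypothetical helper:

from typing import Dict, Optional, cast

def read_payload() -> Optional[Dict]:
    # stand-in for an API whose declared return type allows None
    return {"inputs": [1, 2, 3]}

payload = cast(Dict, read_payload())  # narrows the type for mypy; no runtime check is added
print(payload["inputs"])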

setup.cfg (+1 -1)

@@ -1,7 +1,7 @@
 [flake8]
 # E203: black and flake8 disagree on whitespace before ':'
 # W503: black and flake8 disagree on how to place operators
-ignore = E203, W503
+ignore = E203, W503, E704
 max-line-length = 88
 exclude =
     build, dist, tutorials, website
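
Note: E704 is flake8's "statement on same line as def" check. It fires on one-line stub bodies such as typing overloads, which black may format onto a single line, so it is suppressed here alongside the other black-compatibility codes. A hedged sketch of code that would trigger it:

from typing import Union, overload

@overload
def scale(x: int) -> int: ...  # flake8 reports E704 here unless the code is ignored
@overload
def scale(x: float) -> float: ...

def scale(x: Union[int, float]) -> Union[int, float]:
    return x * 2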

tests/attr/helpers/conductance_reference.py (+6 -5)

@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-from typing import Optional, Tuple
+from typing import cast, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -11,6 +11,7 @@
 from captum.attr._utils.attribution import LayerAttribution
 from captum.attr._utils.common import _reshape_and_sum
 from torch import Tensor
+from torch.utils.hooks import RemovableHandle

 """
 Note: This implementation of conductance follows the procedure described in the original
@@ -55,7 +56,7 @@ def forward_hook(module, inp, out):
         # The hidden layer tensor is assumed to have dimension (num_hidden, ...)
         # where the product of the dimensions >= 1 correspond to the total
         # number of hidden neurons in the layer.
-        layer_size = tuple(saved_tensor.size())[1:]
+        layer_size = tuple(cast(Tensor, saved_tensor).size())[1:]
         layer_units = int(np.prod(layer_size))

         # Remove unnecessary forward hook.
@@ -101,12 +102,12 @@ def forward_hook_register_back(module, inp, out):
         input_grads = torch.autograd.grad(torch.unbind(output), expanded_input)

         # Remove backwards hook
-        back_hook.remove()
+        cast(RemovableHandle, back_hook).remove()

         # Remove duplicates in gradient with respect to hidden layer,
         # choose one for each layer_units indices.
         output_mid_grads = torch.index_select(
-            saved_grads,
+            cast(Tensor, saved_grads),
             0,
             torch.tensor(range(0, input_grads[0].shape[0], layer_units)),
         )
@@ -115,7 +116,7 @@ def forward_hook_register_back(module, inp, out):
     def attribute(
         self,
         inputs,
-        baselines: Optional[int] = None,
+        baselines: Union[None, int, Tensor] = None,
         target=None,
         n_steps: int = 500,
         method: str = "riemann_trapezoid",

tests/attr/layer/test_layer_lrp.py (+1)

@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# mypy: ignore-errors

 from typing import Any, Tuple

tests/attr/models/test_pytext.py (+1 -1)

@@ -43,7 +43,7 @@ def __init__(self) -> None:


 class TestWordEmbeddings(unittest.TestCase):
-    def setUp(self) -> Optional[NoReturn]:
+    def setUp(self) -> None:
         if not HAS_PYTEXT:
             return self.skipTest("Skip the test since PyText is not installed")

tests/attr/test_class_summarizer.py (+3 -1)

@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+from typing import List
+
 import torch
 from captum.attr import ClassSummarizer, CommonStats
 from tests.helpers.basic import BaseTest
@@ -45,7 +47,7 @@ def test_classes(self) -> None:
             ((3, 2, 10, 3), (1,)),
             # ((20,),),
         ]
-        list_of_classes = [
+        list_of_classes: List[List] = [
             list(range(100)),
             ["%d" % i for i in range(100)],
             list(range(300, 400)),

tests/attr/test_guided_grad_cam.py (+2 -2)

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

 import unittest
-from typing import Any
+from typing import Any, List, Tuple, Union

 import torch
 from captum._utils.typing import TensorOrTupleOfTensorsGeneric
@@ -107,7 +107,7 @@ def _guided_grad_cam_test_assert(
         model: Module,
         target_layer: Module,
         test_input: TensorOrTupleOfTensorsGeneric,
-        expected: Tensor,
+        expected: Union[Tensor, List, Tuple],
         additional_input: Any = None,
         interpolate_mode: str = "nearest",
         attribute_to_layer_input: bool = False,

tests/attr/test_input_layer_wrapper.py (+2 -1)

@@ -27,6 +27,7 @@
     BasicModel_MultiLayer_TrueMultiInput,
     MixedKwargsAndArgsModule,
 )
+from torch.nn import Module

 layer_methods_to_test_with_equiv = [
     # layer_method, equiv_method, whether or not to use multiple layers
@@ -115,7 +116,7 @@ def layer_method_with_input_layer_patches(
         assertTensorTuplesAlmostEqual(self, a1, real_attributions)

     def forward_eval_layer_with_inputs_helper(
-        self, model: ModelInputWrapper, inputs_to_test
+        self, model: Module, inputs_to_test
     ) -> None:
         # hard coding for simplicity
         # 0 if using args, 1 if using kwargs

tests/attr/test_lime.py (+1 -1)

@@ -494,7 +494,7 @@ def _lime_test_assert(
         model: Callable,
         test_input: TensorOrTupleOfTensorsGeneric,
         expected_attr,
-        expected_coefs_only: Optional[Tensor] = None,
+        expected_coefs_only: Union[None, List, Tensor] = None,
         feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
         additional_input: Any = None,
         perturbations_per_eval: Tuple[int, ...] = (1,),

tests/attr/test_stat.py (+2 -1)

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import random
+from typing import Callable, List

 import torch
 from captum.attr import Max, Mean, Min, MSE, StdDev, Sum, Summarizer, Var
@@ -140,7 +141,7 @@ def test_stats_random_data(self) -> None:
             "sum",
             "mse",
         ]
-        gt_fns = [
+        gt_fns: List[Callable] = [
             torch.mean,
             lambda x: torch.var(x, unbiased=False),
             lambda x: torch.var(x, unbiased=True),
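
Note: the explicit List[Callable] annotation (like List[List] in test_class_summarizer.py) fixes the element type of a heterogeneous list up front; without it, mypy may infer an element type from the mix of torch builtins and lambdas that makes later calls fail to type-check. A minimal sketch of the pattern, not taken from the test:

from typing import Callable, List

import torch

# The annotation keeps mypy from inferring an unhelpful element type
# from the mixed entries below.
gt_fns: List[Callable] = [
    torch.mean,
    lambda x: torch.var(x, unbiased=False),
]
for fn in gt_fns:
    print(fn(torch.arange(4.0)))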

tests/helpers/basic.py (+3 -7)

@@ -2,7 +2,7 @@
 import copy
 import random
 import unittest
-from typing import Callable
+from typing import Callable, List, Tuple, Union

 import numpy as np
 import torch
@@ -20,9 +20,7 @@ def copy_args(*args, **kwargs):
     return copy_args


-def assertTensorAlmostEqual(
-    test, actual: Tensor, expected: Tensor, delta: float = 0.0001, mode: str = "sum"
-) -> None:
+def assertTensorAlmostEqual(test, actual, expected, delta=0.0001, mode="sum"):
     assert isinstance(actual, torch.Tensor), (
         "Actual parameter given for " "comparison must be a tensor."
     )
@@ -60,9 +58,7 @@ def assertTensorAlmostEqual(
     raise ValueError("Mode for assertion comparison must be one of `max` or `sum`.")


-def assertTensorTuplesAlmostEqual(
-    test, actual, expected, delta: float = 0.0001, mode: str = "sum"
-) -> None:
+def assertTensorTuplesAlmostEqual(test, actual, expected, delta=0.0001, mode="sum"):
     if isinstance(expected, tuple):
         assert len(actual) == len(
             expected

tests/helpers/basic_models.py (+1 -1)

@@ -44,7 +44,7 @@ def __init__(self) -> None:
         super().__init__()

     def forward(self, input: int):
-        input = 1 - F.relu(1 - input)
+        input = 1 - F.relu(torch.tensor(1 - input))
         return input


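Note: F.relu is typed (and implemented) for tensors, so with the declared int parameter the expression 1 - input fails mypy; wrapping it in torch.tensor produces a tensor that relu accepts. A quick sketch of the resulting behaviour, assuming an integer input as in this model's forward(self, input: int):

import torch
import torch.nn.functional as F

x = 3  # plain Python int, standing in for the annotated forward() argument
out = 1 - F.relu(torch.tensor(1 - x))  # relu(-2) == 0, so out == tensor(1)
print(out)
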
tests/influence/_core/test_tracin_validation.py (+1 -1)

@@ -77,7 +77,7 @@ def test_tracincp_fast_rand_proj_inputs(self) -> None:
         ):
             TracInCPFast(
                 net,
-                "invalid_layer",
+                "invalid_layer",  # type: ignore
                 train_dataset,
                 tmpdir,
                 loss_fn=nn.MSELoss(),

tests/robust/test_FGSM.py (+1 -1)

@@ -188,7 +188,7 @@ def _FGSM_assert(
         inputs: TensorOrTupleOfTensorsGeneric,
         target: Any,
         epsilon: float,
-        answer: Union[TensorLikeList, Tuple[TensorLikeList, ...]],
+        answer: Union[List, Tuple[List, ...]],
         targeted: bool = False,
         additional_inputs: Any = None,
         lower_bound: float = float("-inf"),

tests/robust/test_attack_comparator.py (+1 -1)

@@ -202,7 +202,7 @@ def test_attack_comparator_with_additional_args(self) -> None:
         attack_comp.reset()
         self.assertEqual(len(attack_comp.summary()), 0)

-    def _compare_results(self, obtained: Tensor, expected) -> None:
+    def _compare_results(self, obtained, expected) -> None:
         if isinstance(expected, dict):
             self.assertIsInstance(obtained, dict)
             for key in expected:

tests/utils/models/linear_models/_test_linear_classifier.py (+4 -4)

@@ -1,6 +1,6 @@
 import argparse
 import random
-from typing import Optional
+from typing import cast, Optional

 import captum._utils.models.linear_model.model as pytorch_model_module
 import numpy as np
@@ -98,9 +98,9 @@ def compare_to_sk_learn(
         o_pytorch["l1_reg"] = alpha * pytorch_h.norm(p=1, dim=-1)
         o_sklearn["l1_reg"] = alpha * sklearn_h.norm(p=1, dim=-1)

-    rel_diff = (sum(o_sklearn.values()) - sum(o_pytorch.values())) / abs(
-        sum(o_sklearn.values())
-    )
+    rel_diff = cast(
+        np.ndarray, (sum(o_sklearn.values()) - sum(o_pytorch.values()))
+    ) / abs(sum(o_sklearn.values()))
     return (
         {
             "objective_rel_diff": rel_diff.tolist(),

tests/utils/test_helpers.py (+1 -1)

@@ -7,7 +7,7 @@
 class HelpersTest(BaseTest):
     def test_assert_tensor_almost_equal(self) -> None:
         with self.assertRaises(AssertionError) as cm:
-            assertTensorAlmostEqual(self, [[1.0]], [[1.0]])
+            assertTensorAlmostEqual(self, [[1.0]], [[1.0]])  # type: ignore
         self.assertEqual(
             cm.exception.args,
             ("Actual parameter given for comparison must be a tensor.",),

tests/utils/test_linear_model.py (+4 -4)

@@ -12,11 +12,11 @@
 from torch import Tensor


-def _evaluate(test_data, classifier) -> Dict[str, float]:
+def _evaluate(test_data, classifier) -> Dict[str, Tensor]:
     classifier.eval()

-    l1_loss = 0.0
-    l2_loss = 0.0
+    l1_loss = torch.tensor(0.0)
+    l2_loss = torch.tensor(0.0)
     n = 0
     l2_losses = []
     with torch.no_grad():
@@ -67,7 +67,7 @@ def train_and_compare(
     model_type,
     xs,
     ys,
-    expected_loss: Tensor,
+    expected_loss: Union[int, float, Tensor],
     expected_reg: Union[float, Tensor] = 0.0,
     expected_hyperplane: Optional[Tensor] = None,
     norm_hyperplane: bool = True,

tests/utils/test_sample_gradient.py (+8 -4)

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

 import unittest
-from typing import Callable, Tuple
+from typing import Callable, List, Tuple

 import torch
 from captum._utils.gradient import apply_gradient_requirements
@@ -110,7 +110,11 @@ def test_sample_grads_layer_modules(self) -> None:

         # possible candidates for `layer_modules`, which are the modules whose
         # parameters we want to compute sample grads for
-        layer_moduless = [[model.conv1], [model.fc1], [model.conv1, model.fc1]]
+        layer_moduless: List[List[Module]] = [
+            [model.conv1],
+            [model.fc1],
+            [model.conv1, model.fc1],
+        ]
         # hard coded all modules we want to check
         all_modules = [model.conv1, model.fc1]

@@ -135,10 +139,10 @@ def test_sample_grads_layer_modules(self) -> None:
                 # So, check that we did calculate sample grads for the desired
                 # layers via the above checking approach.
                 for parameter in module.parameters():
-                    assert not isinstance(parameter.sample_grad, int)
+                    assert not isinstance(parameter.sample_grad, int)  # type: ignore
             else:
                 # For the layers we do not want sample grads for, their
                 # `sample_grad` should still be 0, since they should not have been
                 # over-written.
                 for parameter in module.parameters():
-                    assert parameter.sample_grad == 0
+                    assert parameter.sample_grad == 0  # type: ignore
