Skip to content

Commit 3d4c1fd

Browse files
committed
Fix type issues
1 parent 885ed76 commit 3d4c1fd

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

captum/insights/attr_vis/server.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from time import sleep
77
from typing import cast, Dict, Optional
88

9+
from captum.insights import AttributionVisualizer
10+
911
from captum.log import log_usage
1012
from flask import Flask, jsonify, render_template, request
1113
from flask.wrappers import Response
@@ -44,7 +46,7 @@ def attribute() -> Response:
4446
r = cast(Dict, request.get_json(force=True))
4547
return jsonify(
4648
namedtuple_to_dict(
47-
visualizer._calculate_attribution_from_cache( # type: ignore
49+
cast(AttributionVisualizer, visualizer)._calculate_attribution_from_cache(
4850
r["inputIndex"], r["modelIndex"], r["labelIndex"]
4951
)
5052
)
@@ -54,15 +56,15 @@ def attribute() -> Response:
5456
@app.route("/fetch", methods=["POST"])
5557
def fetch() -> Response:
5658
# force=True needed, see comment for "/attribute" route above
57-
visualizer._update_config(request.get_json(force=True)) # type: ignore
58-
visualizer_output = visualizer.visualize() # type: ignore
59+
cast(AttributionVisualizer, visualizer)._update_config(request.get_json(force=True))
60+
visualizer_output = cast(AttributionVisualizer, visualizer).visualize()
5961
clean_output = namedtuple_to_dict(visualizer_output)
6062
return jsonify(clean_output)
6163

6264

6365
@app.route("/init")
6466
def init() -> Response:
65-
return jsonify(visualizer.get_insights_config()) # type: ignore
67+
return jsonify(cast(AttributionVisualizer, visualizer).get_insights_config())
6668

6769

6870
@app.route("/")

scripts/install_via_pip.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ while getopts 'ndfv:' flag; do
1212
d) DEPLOY=true ;;
1313
f) FRAMEWORKS=true ;;
1414
v) CHOSEN_TORCH_VERSION=${OPTARG};;
15-
*) echo "usage: $0 [-n] [-d] [-f] [-v version] [-m install_mode]" >&2
15+
*) echo "usage: $0 [-n] [-d] [-f] [-v version]" >&2
1616
exit 1 ;;
1717
esac
1818
done

tests/helpers/basic_models.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ class BasicModel(nn.Module):
4343
def __init__(self) -> None:
4444
super().__init__()
4545

46-
def forward(self, input: int):
47-
input = 1 - F.relu(torch.tensor(1 - input))
46+
def forward(self, input: Tensor):
47+
input = 1 - F.relu(1 - input)
4848
return input
4949

5050

0 commit comments

Comments
 (0)