6
6
from time import sleep
7
7
from typing import cast , Dict , Optional
8
8
9
+ from captum .insights import AttributionVisualizer
10
+
9
11
from captum .log import log_usage
10
12
from flask import Flask , jsonify , render_template , request
11
13
from flask .wrappers import Response
@@ -44,7 +46,7 @@ def attribute() -> Response:
44
46
r = cast (Dict , request .get_json (force = True ))
45
47
return jsonify (
46
48
namedtuple_to_dict (
47
- visualizer ._calculate_attribution_from_cache ( # type: ignore
49
+ cast ( AttributionVisualizer , visualizer ) ._calculate_attribution_from_cache (
48
50
r ["inputIndex" ], r ["modelIndex" ], r ["labelIndex" ]
49
51
)
50
52
)
@@ -54,15 +56,15 @@ def attribute() -> Response:
54
56
@app .route ("/fetch" , methods = ["POST" ])
55
57
def fetch () -> Response :
56
58
# force=True needed, see comment for "/attribute" route above
57
- visualizer ._update_config (request .get_json (force = True )) # type: ignore
58
- visualizer_output = visualizer .visualize () # type: ignore
59
+ cast ( AttributionVisualizer , visualizer ) ._update_config (request .get_json (force = True ))
60
+ visualizer_output = cast ( AttributionVisualizer , visualizer ) .visualize ()
59
61
clean_output = namedtuple_to_dict (visualizer_output )
60
62
return jsonify (clean_output )
61
63
62
64
63
65
@app .route ("/init" )
64
66
def init () -> Response :
65
- return jsonify (visualizer .get_insights_config ()) # type: ignore
67
+ return jsonify (cast ( AttributionVisualizer , visualizer ) .get_insights_config ())
66
68
67
69
68
70
@app .route ("/" )
0 commit comments