Skip to content

Commit d20df21

Browse files
authored
Update XAIBase to v3.0.0 and document heatmap_overlay (#162)
This is a breaking change since heatmap now requires an explicit import of either VisionHeatmaps.jl or TextHeatmaps.jl: * Update XAIBase to v3.0.0 * Import VisionHeatmaps in docs * Document `heatmap_overlay`
1 parent 974a934 commit d20df21

15 files changed

+60
-55
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ExplainableAI"
22
uuid = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
33
authors = ["Adrian Hill <[email protected]>"]
4-
version = "0.7.0"
4+
version = "0.8.0-DEV"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -16,6 +16,6 @@ Distributions = "0.25"
1616
Random = "1"
1717
Reexport = "1"
1818
Statistics = "1"
19-
XAIBase = "1.2"
19+
XAIBase = "3"
2020
Zygote = "0.6"
2121
julia = "1.6"

docs/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,7 @@ ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
1010
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
1111
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
1212
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
13+
VisionHeatmaps = "27106da1-f8bc-4ca8-8c66-9b8289f1e035"
14+
15+
[compat]
16+
VisionHeatmaps = "1.4"

docs/make.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ makedocs(;
2525
modules=[XAIBase, ExplainableAI],
2626
authors="Adrian Hill",
2727
sitename="ExplainableAI.jl",
28-
format=Documenter.HTML(; prettyurls=get(ENV, "CI", "false") == "true", assets=String[]),
28+
format=Documenter.HTML(;
29+
prettyurls=get(ENV, "CI", "false") == "true",
30+
size_threshold=300_000,
31+
assets=String[],
32+
),
2933
#! format: off
3034
pages=[
3135
"Home" => "index.md",

docs/src/api.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,14 @@ All methods in ExplainableAI.jl work by calling `analyze` on an input and an ana
33
```@docs
44
analyze
55
Explanation
6-
heatmap
76
```
87

8+
For heatmapping functionality, take a look at either
9+
[VisionHeatmaps.jl](https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/) or
10+
[TextHeatmaps.jl](https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/).
11+
Both provide `heatmap` methods for visualizing explanations,
12+
either for images or text, respectively.
13+
914
# Analyzers
1015
```@docs
1116
Gradient
@@ -27,3 +32,4 @@ InterpolationAugmentation
2732
# Index
2833
```@index
2934
```
35+

docs/src/literate/augmentations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# We build on the basics shown in the [*Getting started*](@ref docs-getting-started) section
88
# and start out by loading the same pre-trained LeNet5 model and MNIST input data:
99
using ExplainableAI
10+
using VisionHeatmaps
1011
using Flux
1112

1213
using BSON # hide

docs/src/literate/example.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ expl.val
6868

6969
# ## Heatmapping basics
7070
# Since the array `expl.val` is not very informative at first sight,
71-
# we can visualize `Explanation`s by computing a [`heatmap`](@ref):
71+
# we can visualize `Explanation`s by computing a `heatmap` using either
72+
# [VisionHeatmaps.jl](https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/) or
73+
# [TextHeatmaps.jl](https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/).
74+
using VisionHeatmaps
75+
7276
heatmap(expl)
7377

7478
# If we are only interested in the heatmap, we can combine analysis and heatmapping
@@ -92,7 +96,7 @@ heatmap(expl)
9296

9397
#md # !!! note
9498
#md #
95-
#md # The output neuron can also be specified when calling [`heatmap`](@ref):
99+
#md # The output neuron can also be specified when calling `heatmap`:
96100
#md # ```julia
97101
#md # heatmap(input, analyzer, 5)
98102
#md # ```

docs/src/literate/heatmapping.jl

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
# # [Heatmapping](@id docs-heatmapping)
22
# Since numerical explanations are not very informative at first sight,
3-
# we can visualize them by computing a [`heatmap`](@ref).
3+
# we can visualize them by computing a `heatmap`, using either
4+
# [VisionHeatmaps.jl](https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/) or
5+
# [TextHeatmaps.jl](https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/).
6+
#
47
# This page showcases different options and preset for heatmapping,
58
# building on the basics shown in the [*Getting started*](@ref docs-getting-started) section.
69
#
710
# We start out by loading the same pre-trained LeNet5 model and MNIST input data:
811
using ExplainableAI
12+
using VisionHeatmaps
913
using Flux
1014

1115
using BSON # hide
@@ -19,10 +23,10 @@ index = 10
1923
x, y = MNIST(Float32, :test)[10]
2024
input = reshape(x, 28, 28, 1, :)
2125

22-
convert2image(MNIST, x)
26+
img = convert2image(MNIST, x)
2327

2428
# ## Automatic heatmap presets
25-
# The function [`heatmap`](@ref) automatically applies common presets for each method.
29+
# The function `heatmap` automatically applies common presets for each method.
2630
#
2731
# Since [`InputTimesGradient`](@ref) computes attributions,
2832
# heatmaps are shown in a blue-white-red color scheme.
@@ -35,7 +39,7 @@ heatmap(input, analyzer)
3539

3640
# ## Custom heatmap settings
3741
# ### Color schemes
38-
# We can partially or fully override presets by passing keyword arguments to [`heatmap`](@ref).
42+
# We can partially or fully override presets by passing keyword arguments to `heatmap`.
3943
# For example, we can use a custom color scheme from ColorSchemes.jl using the keyword argument `colorscheme`:
4044
using ColorSchemes
4145

@@ -89,20 +93,32 @@ heatmap(expl; rangescale=:centered, colorscheme=:inferno)
8993
#-
9094
heatmap(expl; rangescale=:extrema, colorscheme=:inferno)
9195

92-
# For the full list of `heatmap` keyword arguments, refer to the [`heatmap`](@ref) documentation.
96+
# For the full list of `heatmap` keyword arguments, refer to the `heatmap` documentation.
97+
98+
# ## [Heatmap overlays](@id overlay)
99+
# Heatmaps can be overlaid onto the input image using the `heatmap_overlay` function
100+
# from VisionHeatmaps.jl.
101+
# This can be useful for visualizing the relevance of specific regions of the input:
102+
heatmap_overlay(expl, img)
103+
104+
# The alpha value of the heatmap can be adjusted using the `alpha` keyword argument:
105+
heatmap_overlay(expl, img; alpha=0.3)
106+
107+
# All previously discussed keyword arguments for `heatmap` can also be used with `heatmap_overlay`:
108+
heatmap_overlay(expl, img; alpha=0.7, colorscheme=:inferno, rangescale=:extrema)
93109

94110
# ## [Heatmapping batches](@id docs-heatmapping-batches)
95111
# Heatmapping also works with input batches.
96-
# Let's demonstrate this by using a batch of 100 images from the MNIST dataset:
97-
xs, ys = MNIST(Float32, :test)[1:100]
112+
# Let's demonstrate this by using a batch of 25 images from the MNIST dataset:
113+
xs, ys = MNIST(Float32, :test)[1:25]
98114
batch = reshape(xs, 28, 28, 1, :); # reshape to WHCN format
99115

100-
# The [`heatmap`](@ref) function automatically recognizes
116+
# The `heatmap` function automatically recognizes
101117
# that the explanation is batched and returns a `Vector` of images:
102118
heatmaps = heatmap(batch, analyzer)
103119

104120
# Image.jl's `mosaic` function can used to display them in a grid:
105-
mosaic(heatmaps; nrow=10)
121+
mosaic(heatmaps; nrow=5)
106122

107123
# When heatmapping batches, the mapping to the color scheme is applied per sample.
108124
# For example, `rangescale=:extrema` will normalize each heatmap
@@ -113,12 +129,12 @@ mosaic(heatmaps; nrow=10)
113129
# `heatmap` can be called with the keyword-argument `process_batch=true`:
114130
expl = analyze(batch, analyzer)
115131
heatmaps = heatmap(expl; process_batch=true)
116-
mosaic(heatmaps; nrow=10)
132+
mosaic(heatmaps; nrow=5)
117133

118134
# This can be useful when comparing heatmaps for fixed output neurons:
119135
expl = analyze(batch, analyzer, 7) # explain digit "6"
120136
heatmaps = heatmap(expl; process_batch=true)
121-
mosaic(heatmaps; nrow=10)
137+
mosaic(heatmaps; nrow=5)
122138

123139
#md # !!! note "Output type consistency"
124140
#md #

src/bibliography.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# Gradient methods:
22
const REF_SMILKOV_SMOOTHGRAD = "Smilkov et al., *SmoothGrad: removing noise by adding noise*"
33
const REF_SUNDARARAJAN_AXIOMATIC = "Sundararajan et al., *Axiomatic Attribution for Deep Networks*"
4-
const REF_SELVARAJU_GRADCAM = "Selvaraju et al., *Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization*"
4+
const REF_SELVARAJU_GRADCAM = "Selvaraju et al., *Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization*"

src/gradcam.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct GradCAM{F,A} <: AbstractXAIMethod
1919
feature_layers::F
2020
adaptation_layers::A
2121
end
22-
function (analyzer::GradCAM)(input, ns::AbstractNeuronSelector)
22+
function (analyzer::GradCAM)(input, ns::AbstractOutputSelector)
2323
A = analyzer.feature_layers(input) # feature map
2424
feature_map_size = size(A, 1) * size(A, 2)
2525

src/gradient.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function gradient_wrt_input(model, input, ns::AbstractNeuronSelector)
1+
function gradient_wrt_input(model, input, ns::AbstractOutputSelector)
22
output, back = Zygote.pullback(model, input)
33
output_indices = ns(output)
44

@@ -19,7 +19,7 @@ struct Gradient{M} <: AbstractXAIMethod
1919
Gradient(model) = new{typeof(model)}(model)
2020
end
2121

22-
function (analyzer::Gradient)(input, ns::AbstractNeuronSelector)
22+
function (analyzer::Gradient)(input, ns::AbstractOutputSelector)
2323
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
2424
return Explanation(grad, output, output_indices, :Gradient, :sensitivity, nothing)
2525
end
@@ -35,7 +35,7 @@ struct InputTimesGradient{M} <: AbstractXAIMethod
3535
InputTimesGradient(model) = new{typeof(model)}(model)
3636
end
3737

38-
function (analyzer::InputTimesGradient)(input, ns::AbstractNeuronSelector)
38+
function (analyzer::InputTimesGradient)(input, ns::AbstractOutputSelector)
3939
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
4040
attr = input .* grad
4141
return Explanation(

src/input_augmentation.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
44
Neuron selector that passes through an augmented neuron selection.
55
"""
6-
struct AugmentationSelector{I} <: AbstractNeuronSelector
6+
struct AugmentationSelector{I} <: AbstractOutputSelector
77
indices::I
88
end
99
(s::AugmentationSelector)(out) = s.indices
@@ -103,7 +103,7 @@ function NoiseAugmentation(analyzer, n, σ::Real=0.1f0, args...)
103103
return NoiseAugmentation(analyzer, n, Normal(0.0f0, Float32(σ)^2), args...)
104104
end
105105

106-
function (aug::NoiseAugmentation)(input, ns::AbstractNeuronSelector)
106+
function (aug::NoiseAugmentation)(input, ns::AbstractOutputSelector)
107107
# Regular forward pass of model
108108
output = aug.analyzer.model(input)
109109
output_indices = ns(output)
@@ -142,7 +142,7 @@ struct InterpolationAugmentation{A<:AbstractXAIMethod} <: AbstractXAIMethod
142142
end
143143

144144
function (aug::InterpolationAugmentation)(
145-
input, ns::AbstractNeuronSelector; input_ref=zero(input)
145+
input, ns::AbstractOutputSelector; input_ref=zero(input)
146146
)
147147
size(input) != size(input_ref) &&
148148
throw(ArgumentError("Input reference size doesn't match input size."))

test/Project.toml

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,3 @@ Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1010
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1111
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
1212

13-
[compat]
14-
Aqua = "0.8"
15-
Distributions = "0.25"
16-
Flux = "0.13, 0.14"
17-
JLD2 = "0.4"
18-
ReferenceTests = "0.10"
19-
Suppressor = "0.2"
20-
XAIBase = "1.2"

0 commit comments

Comments
 (0)