Skip to content

Commit b47bf90

Browse files
authored
Fix bug in NoiseAugmentation constructor (#183)
1 parent 32cb74e commit b47bf90

File tree

6 files changed

+37
-16
lines changed

6 files changed

+37
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ExplainableAI"
22
uuid = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
33
authors = ["Adrian Hill <[email protected]>"]
4-
version = "0.10.0"
4+
version = "0.10.1-DEV"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,10 @@ heatmap(input, analyzer, 5) # for heatmap
6969

7070
| **Analyzer** | **Heatmap for class "castle"** |**Heatmap for class "street sign"** |
7171
|:--------------------------------------------- |:------------------------------:|:----------------------------------:|
72-
| `InputTimesGradient` | ![][castle-ixg] | ![][streetsign-ixg] |
7372
| `Gradient` | ![][castle-grad] | ![][streetsign-grad] |
7473
| `SmoothGrad` | ![][castle-smoothgrad] | ![][streetsign-smoothgrad] |
7574
| `IntegratedGradients` | ![][castle-intgrad] | ![][streetsign-intgrad] |
75+
| `InputTimesGradient` | ![][castle-ixg] | ![][streetsign-ixg] |
7676

7777
> [!TIP]
7878
> The heatmaps shown above were created using a VGG-16 vision model

src/gradient.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,17 @@ function call_analyzer(
7676
end
7777

7878
"""
79-
SmoothGrad(analyzer, [n=50, std=1.0f0, rng=GLOBAL_RNG])
80-
SmoothGrad(analyzer, [n=50, distribution=Normal(0.0f0, 1.0f0), rng=GLOBAL_RNG])
79+
SmoothGrad(analyzer)
80+
SmoothGrad(analyzer, [n, std, rng]])
81+
SmoothGrad(analyzer, [n, distribution, rng])
8182
8283
Analyze model by calculating a smoothed sensitivity map.
8384
This is done by averaging sensitivity maps of a `Gradient` analyzer over random samples
84-
in a neighborhood of the input, typically by adding Gaussian noise with mean 0.
85+
in a neighborhood of the input.
86+
Defaults to 50 samples from the normal distribution with zero mean and `std=1.0f0`.
87+
88+
For optimal results, $REF_SMILKOV_SMOOTHGRAD recommends setting `std` between 10% and 20% of the input range of each sample,
89+
e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
8590
8691
# References
8792
- $REF_SMILKOV_SMOOTHGRAD

src/input_augmentation.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,13 @@ end
99
(s::AugmentationSelector)(out) = s.indices
1010

1111
"""
12-
NoiseAugmentation(analyzer, n)
13-
NoiseAugmentation(analyzer, n, std::Real)
14-
NoiseAugmentation(analyzer, n, distribution::Sampleable)
12+
NoiseAugmentation(analyzer, n, [std::Real, rng])
13+
NoiseAugmentation(analyzer, n, [distribution::Sampleable, rng])
1514
1615
A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from a scalar `distribution`.
1716
This input augmentation is then averaged to return an `Explanation`.
17+
Defaults to the normal distribution with zero mean and `std=1.0f0`.
1818
19-
Defaults to the normal distribution `Normal(0, std^2)` with `std=1.0f0`.
2019
For optimal results, $REF_SMILKOV_SMOOTHGRAD recommends setting `std` between 10% and 20% of the input range of each sample,
2120
e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
2221
@@ -32,17 +31,14 @@ struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
3231
rng::R
3332

3433
function NoiseAugmentation(
35-
analyzer::A, n::Int, distribution::D, rng::R
34+
analyzer::A, n::Int, distribution::D, rng::R=GLOBAL_RNG
3635
) where {A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG}
37-
n < 2 &&
38-
throw(ArgumentError("Number of noise samples `n` needs to be larger than one."))
36+
n < 1 && throw(ArgumentError("Number of samples `n` needs to be larger than zero."))
3937
return new{A,D,R}(analyzer, n, distribution, rng)
4038
end
4139
end
42-
function NoiseAugmentation(analyzer, n, std::T=1.0f0, rng=GLOBAL_RNG) where {T<:Real}
43-
return NoiseAugmentation(analyzer, n, Normal(zero(T), std^2), rng)
44-
end
45-
function NoiseAugmentation(analyzer, n, distribution::Sampleable, rng=GLOBAL_RNG)
40+
function NoiseAugmentation(analyzer, n::Int, std::T=1.0f0, rng=GLOBAL_RNG) where {T<:Real}
41+
distribution = Normal(zero(T), std^2)
4642
return NoiseAugmentation(analyzer, n, distribution, rng)
4743
end
4844

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ using JET
2121
end
2222
end
2323

24+
@testset "Constructors" begin
25+
@info "Testing constructors..."
26+
include("test_constructors.jl")
27+
end
2428
@testset "CNN" begin
2529
@info "Testing analyzers on CNN..."
2630
include("test_cnn.jl")

test/test_constructors.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using ExplainableAI
2+
using Test
3+
using Distributions: Normal
4+
using StableRNGs: StableRNG
5+
6+
@testset "Optional arguments" begin
7+
distribution = Normal(0.0f0, 1.0f0)
8+
rng = StableRNG(123)
9+
10+
@test_nowarn SmoothGrad(identity, 50, distribution)
11+
@test_nowarn SmoothGrad(identity, 50, distribution, rng)
12+
13+
gradient_analyzer = Gradient(identity)
14+
@test_nowarn NoiseAugmentation(gradient_analyzer, 50, distribution)
15+
@test_nowarn NoiseAugmentation(gradient_analyzer, 50, distribution, rng)
16+
end

0 commit comments

Comments
 (0)