Skip to content

Commit cf719ae

Browse files
authored
Tested GPU support for Gradient, InputTimesGradient, SmoothGrad, IntegratedGradients (#184)
* Add ProgressMeter to NoiseAugmentation * Add GPU tests * GPU friendly implementation of SmoothGrad * Test `IntegratedGradients` as well * Ignore JET failures on pre-release
1 parent ca359ed commit cf719ae

File tree

9 files changed

+78
-16
lines changed

9 files changed

+78
-16
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
version:
2323
- 'lts'
2424
- '1'
25-
- 'pre'
25+
# - 'pre'
2626
steps:
2727
- uses: actions/checkout@v4
2828
- uses: julia-actions/setup-julia@v2

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# ExplainableAI.jl
22

3+
## Version `v0.10.2`
4+
- ![Feature][badge-feature] Tested GPU support for `Gradient`, `InputTimesGradient`, `SmoothGrad`, `IntegratedGradients` ([#184])
5+
- ![Feature][badge-feature] `NoiseAugmentation`s show a progress meter by default. Turn off via `show_progress=false` ([#184])
6+
37
## Version `v0.10.1`
48
- ![Bugfix][badge-bugfix] Fix bug in `NoiseAugmentation` constructor ([#183])
59

@@ -227,6 +231,7 @@ Performance improvements:
227231
[VisionHeatmaps]: https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/
228232
[TextHeatmaps]: https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/
229233

234+
[#184]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/184
230235
[#183]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/183
231236
[#180]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/180
232237
[#179]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/179

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
name = "ExplainableAI"
22
uuid = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
33
authors = ["Adrian Hill <[email protected]>"]
4-
version = "0.10.1"
4+
version = "0.10.2-DEV"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
99
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
10+
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1011
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1213
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -16,6 +17,7 @@ XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
1617
ADTypes = "1"
1718
DifferentiationInterface = "0.6"
1819
Distributions = "0.25"
20+
ProgressMeter = "1.10.4"
1921
Random = "<0.0.1, 1"
2022
Reexport = "1"
2123
Statistics = "<0.0.1, 1"

src/ExplainableAI.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import XAIBase: call_analyzer
66

77
using Base.Iterators
88
using Distributions: Distribution, Sampleable, Normal
9-
using Random: AbstractRNG, GLOBAL_RNG
9+
using Random: AbstractRNG, GLOBAL_RNG, rand!
10+
using ProgressMeter: Progress, next!
1011

1112
# Automatic differentiation
1213
using ADTypes: AbstractADType, AutoZygote

src/input_augmentation.jl

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,28 @@ e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
2222
## Keyword arguments
2323
- `rng::AbstractRNG`: Specify the random number generator that is used to sample noise from the `distribution`.
2424
Defaults to `GLOBAL_RNG`.
25+
- `show_progress:Bool`: Show progress meter while sampling augmentations. Defaults to `true`.
2526
"""
2627
struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
2728
AbstractXAIMethod
2829
analyzer::A
2930
n::Int
3031
distribution::D
3132
rng::R
33+
show_progress::Bool
3234

3335
function NoiseAugmentation(
34-
analyzer::A, n::Int, distribution::D, rng::R=GLOBAL_RNG
36+
analyzer::A, n::Int, distribution::D, rng::R=GLOBAL_RNG, show_progress=true
3537
) where {A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG}
3638
n < 1 && throw(ArgumentError("Number of samples `n` needs to be larger than zero."))
37-
return new{A,D,R}(analyzer, n, distribution, rng)
39+
return new{A,D,R}(analyzer, n, distribution, rng, show_progress)
3840
end
3941
end
40-
function NoiseAugmentation(analyzer, n::Int, std::T=1.0f0, rng=GLOBAL_RNG) where {T<:Real}
42+
function NoiseAugmentation(
43+
analyzer, n::Int, std::T=1.0f0, rng=GLOBAL_RNG, show_progress=true
44+
) where {T<:Real}
4145
distribution = Normal(zero(T), std^2)
42-
return NoiseAugmentation(analyzer, n, distribution, rng)
46+
return NoiseAugmentation(analyzer, n, distribution, rng, show_progress)
4347
end
4448

4549
function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector; kwargs...)
@@ -48,17 +52,21 @@ function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector
4852
output_indices = ns(output)
4953
output_selector = AugmentationSelector(output_indices)
5054

55+
p = Progress(aug.n; desc="Sampling NoiseAugmentation...", enabled=aug.show_progress)
56+
5157
# First augmentation
52-
input_aug = similar(input)
53-
input_aug = sample_noise!(input_aug, input, aug)
54-
expl_aug = aug.analyzer(input_aug, output_selector)
58+
noisy_input = similar(input)
59+
noisy_input = sample_noise!(noisy_input, input, aug)
60+
expl_aug = aug.analyzer(noisy_input, output_selector)
5561
sum_val = expl_aug.val
62+
next!(p)
5663

5764
# Further augmentations
5865
for _ in 2:(aug.n)
59-
input_aug = sample_noise!(input_aug, input, aug)
60-
expl_aug = aug.analyzer(input_aug, output_selector)
61-
sum_val += expl_aug.val
66+
noisy_input = sample_noise!(noisy_input, input, aug)
67+
expl_aug = aug.analyzer(noisy_input, output_selector)
68+
sum_val .+= expl_aug.val
69+
next!(p)
6270
end
6371

6472
# Average explanation
@@ -72,7 +80,9 @@ end
7280
function sample_noise!(
7381
out::A, input::A, aug::NoiseAugmentation
7482
) where {T,A<:AbstractArray{T}}
75-
out .= input .+ rand(aug.rng, aug.distribution, size(input))
83+
out = rand!(aug.rng, aug.distribution, out)
84+
out .+= input
85+
return out
7686
end
7787

7888
"""
@@ -114,9 +124,9 @@ function call_analyzer(
114124
# Further augmentations
115125
input_delta = (input - input_ref) / (aug.n - 1)
116126
for _ in 1:(aug.n)
117-
input_aug += input_delta
127+
input_aug .+= input_delta
118128
expl_aug = aug.analyzer(input_aug, output_selector)
119-
sum_val += expl_aug.val
129+
sum_val .+= expl_aug.val
120130
end
121131

122132
# Average gradients and compute explanation

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
44
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
55
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
66
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
7+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
78
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
89
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
1012
PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01"
1113
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1214
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
0 Bytes
Binary file not shown.

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ using JET
3333
@info "Testing analyzers on batches..."
3434
include("test_batches.jl")
3535
end
36+
@testset "GPU tests" begin
37+
include("test_gpu.jl")
38+
end
3639
@testset "Benchmark correctness" begin
3740
@info "Testing whether benchmarks are up-to-date..."
3841
include("test_benchmarks.jl")

test/test_gpu.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using ExplainableAI
2+
using Test
3+
4+
using Flux
5+
using Metal, JLArrays
6+
7+
if Metal.functional()
8+
@info "Using Metal as GPU device"
9+
device = mtl # use Apple Metal locally
10+
else
11+
@info "Using JLArrays as GPU device"
12+
device = jl # use JLArrays to fake GPU array
13+
end
14+
15+
model = Chain(Dense(10 => 32, relu), Dense(32 => 5))
16+
input = rand(Float32, 10, 8)
17+
@test_nowarn model(input)
18+
19+
model_gpu = device(model)
20+
input_gpu = device(input)
21+
@test_nowarn model_gpu(input_gpu)
22+
23+
analyzer_types = (Gradient, SmoothGrad, InputTimesGradient, IntegratedGradients)
24+
25+
@testset "Run analyzer (CPU)" begin
26+
@testset "$A" for A in analyzer_types
27+
analyzer = A(model)
28+
expl = analyze(input, analyzer)
29+
@test expl isa Explanation
30+
end
31+
end
32+
33+
@testset "Run analyzer (GPU)" begin
34+
@testset "$A" for A in analyzer_types
35+
analyzer_gpu = A(model_gpu)
36+
expl = analyze(input_gpu, analyzer_gpu)
37+
@test expl isa Explanation
38+
end
39+
end

0 commit comments

Comments
 (0)