Skip to content

Commit dbd5372

Browse files
authored
Implicit AD: Corrected NestedAD error checking, added error tests and custom error displays (#518)
* Fixed NestedAD check * Implement AD error checks * Add custom showerror functions for NestedADError and MultipleTagError * fixed error with AD error checking tests
1 parent 8c95503 commit dbd5372

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

src/base/errors.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,14 @@ struct NestedADError <: Exception
9292
msg::String
9393
end
9494
NestedADError() = NestedADError("")
95+
function Base.showerror(io::IO, e::NestedADError)
96+
print(io, "NestedADError: ", e.msg)
97+
end
9598

9699
struct MultipleTagError <: Exception
97100
msg::String
98101
end
99-
MultipleTagError() = MultipleTagError("")
102+
MultipleTagError() = MultipleTagError("")
103+
function Base.showerror(io::IO, e::MultipleTagError)
104+
print(io, "MultipleTagError: ", e.msg)
105+
end

src/methods/differentials.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ function nested_ad_check(a::A) where A
280280
AT = eltype(a)
281281
if AT <: ForwardDiff.Dual
282282
V = ForwardDiff.valtype(AT)
283-
V isa ForwardDiff.Dual && throw(NestedADError("Found nested Duals of type $AT. This is currently not supported in implicit differentiation."))
283+
V <: ForwardDiff.Dual && throw(NestedADError("Found nested Duals of type $AT. This is currently not supported in implicit differentiation."))
284284
end
285285
return nothing
286286
end
@@ -319,4 +319,4 @@ end
319319
function implicit_ad_check(a::Tuple)
320320
nested_ad_check(a)
321321
multiple_tag_ad_check(a)
322-
end
322+
end

test/test_differentials.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,29 @@ end
8282
@test all(ddp[2] .≈ dp)
8383
@test all(ddp[1] .≈ d2p)
8484
end
85+
86+
@testset "AD Error Checks" begin
87+
import ForwardDiff
88+
import Clapeyron: MultipleTagError, NestedADError, __gradients_for_root_finders
89+
f_test(x,tups) = tups[1]*tups[2]*x # just a generic function to check error handling
90+
tag1 = ForwardDiff.Tag{:tag1,Float64}
91+
tag2 = ForwardDiff.Tag{:tag2,Float64}
92+
Tdual1 = ForwardDiff.Dual{tag1,Float64,1}
93+
Tdual2 = ForwardDiff.Dual{tag2,Float64,1}
94+
parts = ForwardDiff.Partials{1,Float64}((1.0,))
95+
theta1,theta2 = 0.5,2.0;
96+
x = 0.0
97+
x_dual = ForwardDiff.Dual{tag1,Float64,1}(x,parts)
98+
tups_primal = (theta1,theta2)
99+
tups_Dual2 = (Tdual1(theta1,parts),Tdual2(theta2,parts))
100+
tups_nestedDual = (ForwardDiff.Dual{tag2,Tdual1,1}(Tdual1(theta1,parts)),theta2)
101+
# Test multiple tag error
102+
@test_throws MultipleTagError __gradients_for_root_finders(x,tups_Dual2,tups_primal,f_test)
103+
# Test nested dual error
104+
@test_throws NestedADError __gradients_for_root_finders(x,tups_nestedDual,tups_primal,f_test)
105+
# Test dual as x error
106+
@test_throws ErrorException __gradients_for_root_finders(x_dual,(theta1,theta2),tups_primal,f_test)
107+
end
85108
end
86109

87110

0 commit comments

Comments
 (0)