Skip to content

Commit cbadb62

Browse files
vchuravygdalle
andauthored
Add common error supertype (#2352) (#2361)
(cherry picked from commit 1ad6620) Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent db2ac1b commit cbadb62

File tree

1 file changed

+46
-8
lines changed

1 file changed

+46
-8
lines changed

src/errors.jl

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
const VERBOSE_ERRORS = Ref(false)
22

3-
abstract type CompilationException <: Base.Exception end
3+
"""
4+
EnzymeError
45
5-
struct EnzymeRuntimeException <: Base.Exception
6+
Common supertype for Enzyme-specific errors.
7+
8+
This type is made public so that downstream packages can add custom [error hints](https://docs.julialang.org/en/v1/base/base/#Base.Experimental.register_error_hint) for the most common exceptions thrown by Enzyme.
9+
"""
10+
abstract type EnzymeError <: Base.Exception end
11+
12+
abstract type CompilationException <: EnzymeError end
13+
14+
struct EnzymeRuntimeException <: EnzymeError
615
msg::Cstring
716
end
817

918
function Base.showerror(io::IO, ece::EnzymeRuntimeException)
19+
if isdefined(Base.Experimental, :show_error_hints)
20+
Base.Experimental.show_error_hints(io, ece)
21+
end
1022
print(io, "Enzyme execution failed.\n")
1123
msg = Base.unsafe_string(ece.msg)
1224
print(io, msg, '\n')
@@ -19,6 +31,9 @@ struct NoDerivativeException <: CompilationException
1931
end
2032

2133
function Base.showerror(io::IO, ece::NoDerivativeException)
34+
if isdefined(Base.Experimental, :show_error_hints)
35+
Base.Experimental.show_error_hints(io, ece)
36+
end
2237
print(io, "Enzyme compilation failed.\n")
2338
if ece.ir !== nothing
2439
if VERBOSE_ERRORS[]
@@ -51,6 +66,9 @@ struct IllegalTypeAnalysisException <: CompilationException
5166
end
5267

5368
function Base.showerror(io::IO, ece::IllegalTypeAnalysisException)
69+
if isdefined(Base.Experimental, :show_error_hints)
70+
Base.Experimental.show_error_hints(io, ece)
71+
end
5472
print(io, "Enzyme compilation failed due to illegal type analysis.\n")
5573
print(io, " This usually indicates the use of a Union type, which is not fully supported with Enzyme.API.strictAliasing set to true [the default].\n")
5674
print(io, " Ideally, remove the union (which will also make your code faster), or try setting Enzyme.API.strictAliasing!(false) before any autodiff call.\n")
@@ -78,6 +96,9 @@ struct IllegalFirstPointerException <: CompilationException
7896
end
7997

8098
function Base.showerror(io::IO, ece::IllegalFirstPointerException)
99+
if isdefined(Base.Experimental, :show_error_hints)
100+
Base.Experimental.show_error_hints(io, ece)
101+
end
81102
print(io, "Enzyme compilation failed due to an internal error (first pointer exception).\n")
82103
print(io, " Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl\n")
83104
print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n")
@@ -101,6 +122,9 @@ struct EnzymeInternalError <: CompilationException
101122
end
102123

103124
function Base.showerror(io::IO, ece::EnzymeInternalError)
125+
if isdefined(Base.Experimental, :show_error_hints)
126+
Base.Experimental.show_error_hints(io, ece)
127+
end
104128
print(io, "Enzyme compilation failed due to an internal error.\n")
105129
print(io, " Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl\n")
106130
print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n")
@@ -123,20 +147,26 @@ function Base.showerror(io::IO, ece::EnzymeInternalError)
123147
end
124148
end
125149

126-
struct EnzymeMutabilityException <: Base.Exception
150+
struct EnzymeMutabilityException <: EnzymeError
127151
msg::Cstring
128152
end
129153

130154
function Base.showerror(io::IO, ece::EnzymeMutabilityException)
155+
if isdefined(Base.Experimental, :show_error_hints)
156+
Base.Experimental.show_error_hints(io, ece)
157+
end
131158
msg = Base.unsafe_string(ece.msg)
132159
print(io, msg, '\n')
133160
end
134161

135-
struct EnzymeRuntimeActivityError <: Base.Exception
162+
struct EnzymeRuntimeActivityError <: EnzymeError
136163
msg::Cstring
137164
end
138165

139166
function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError)
167+
if isdefined(Base.Experimental, :show_error_hints)
168+
Base.Experimental.show_error_hints(io, ece)
169+
end
140170
println(io, "Constant memory is stored (or returned) to a differentiable variable.")
141171
println(
142172
io,
@@ -163,31 +193,40 @@ function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError)
163193
print(io, msg, '\n')
164194
end
165195

166-
struct EnzymeNoTypeError <: Base.Exception
196+
struct EnzymeNoTypeError <: EnzymeError
167197
msg::Cstring
168198
end
169199

170200
function Base.showerror(io::IO, ece::EnzymeNoTypeError)
201+
if isdefined(Base.Experimental, :show_error_hints)
202+
Base.Experimental.show_error_hints(io, ece)
203+
end
171204
print(io, "Enzyme cannot deduce type\n")
172205
msg = Base.unsafe_string(ece.msg)
173206
print(io, msg, '\n')
174207
end
175208

176-
struct EnzymeNoShadowError <: Base.Exception
209+
struct EnzymeNoShadowError <: EnzymeError
177210
msg::Cstring
178211
end
179212

180213
function Base.showerror(io::IO, ece::EnzymeNoShadowError)
214+
if isdefined(Base.Experimental, :show_error_hints)
215+
Base.Experimental.show_error_hints(io, ece)
216+
end
181217
print(io, "Enzyme could not find shadow for value\n")
182218
msg = Base.unsafe_string(ece.msg)
183219
print(io, msg, '\n')
184220
end
185221

186-
struct EnzymeNoDerivativeError <: Base.Exception
222+
struct EnzymeNoDerivativeError <: EnzymeError
187223
msg::Cstring
188224
end
189225

190226
function Base.showerror(io::IO, ece::EnzymeNoDerivativeError)
227+
if isdefined(Base.Experimental, :show_error_hints)
228+
Base.Experimental.show_error_hints(io, ece)
229+
end
191230
msg = Base.unsafe_string(ece.msg)
192231
print(io, msg, '\n')
193232
end
@@ -779,4 +818,3 @@ end
779818
end
780819
throw(AssertionError("Unknown errtype"))
781820
end
782-

0 commit comments

Comments
 (0)