Conversation
Contributor
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/ext/EnzymeChainRulesCoreExt.jl b/ext/EnzymeChainRulesCoreExt.jl
index af8d60e4..219eb3e4 100644
--- a/ext/EnzymeChainRulesCoreExt.jl
+++ b/ext/EnzymeChainRulesCoreExt.jl
@@ -204,23 +204,27 @@ function Enzyme._import_rrule(fn, tys...)
shadow, byref = if !EnzymeRules.needs_shadow(config)
nothing, Val(false)
elseif !Enzyme.Compiler.guaranteed_nonactive(Core.Typeof(res))
- (if EnzymeRules.width(config) == 1
- Ref(Enzyme.make_zero(res))
- else
- ntuple(Val(EnzymeRules.width(config))) do j
- Base.@_inline_meta
+ (
+ if EnzymeRules.width(config) == 1
Ref(Enzyme.make_zero(res))
- end
- end, Val(true))
- else
- (if EnzymeRules.width(config) == 1
+ else
+ ntuple(Val(EnzymeRules.width(config))) do j
+ Base.@_inline_meta
+ Ref(Enzyme.make_zero(res))
+ end
+ end, Val(true),
+ )
+ else
+ (
+ if EnzymeRules.width(config) == 1
Enzyme.make_zero(res)
else
ntuple(Val(EnzymeRules.width(config))) do j
Base.@_inline_meta
Enzyme.make_zero(res)
end
- end, Val(false))
+ end, Val(false),
+ )
end
cache = (shadow, pullback, byref)
diff --git a/lib/EnzymeCore/src/easyrules.jl b/lib/EnzymeCore/src/easyrules.jl
index a8ddeada..e6941b54 100644
--- a/lib/EnzymeCore/src/easyrules.jl
+++ b/lib/EnzymeCore/src/easyrules.jl
@@ -479,7 +479,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
PT = EnzymeRules.primal_type(config, ($(esc(:RTA))).parameters[1])
ST = EnzymeRules.shadow_type(config, ($(esc(:RTA))).parameters[1])
- AugmentedReturnType = :(EnzymeRules.AugmentedReturn{$PT,$ST,typeof(cache)})
+ AugmentedReturnType = :(EnzymeRules.AugmentedReturn{$PT, $ST, typeof(cache)})
genres = if needs_primal(config)
if needs_shadow(config)
diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl
index 2d5259e0..7284abab 100644
--- a/lib/EnzymeCore/src/rules.jl
+++ b/lib/EnzymeCore/src/rules.jl
@@ -133,8 +133,8 @@ Compute the expected primal return type given a reverse mode config and return a
"""
@inline primal_type(config::FwdConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing
@inline primal_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing
-@inline primal_type(config::Type{<:FwdConfig}, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing
-@inline primal_type(config::Type{<:RevConfig}, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing
+@inline primal_type(config::Type{<:FwdConfig}, ::Type{<:Annotation{RT}}) where {RT} = needs_primal(config) ? RT : Nothing
+@inline primal_type(config::Type{<:RevConfig}, ::Type{<:Annotation{RT}}) where {RT} = needs_primal(config) ? RT : Nothing
"""
shadow_type(::FwdConfig, ::Type{<:Annotation{RT}})
@@ -146,8 +146,8 @@ Compute the expected shadow return type given a reverse mode config and return a
"""
@inline shadow_type(config::FwdConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing
@inline shadow_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing
-@inline shadow_type(config::Type{<:FwdConfig}, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing
-@inline shadow_type(config::Type{<:RevConfig}, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing
+@inline shadow_type(config::Type{<:FwdConfig}, ::Type{<:Annotation{RT}}) where {RT} = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing
+@inline shadow_type(config::Type{<:RevConfig}, ::Type{<:Annotation{RT}}) where {RT} = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing
"""
@@ -173,7 +173,7 @@ Otherwise, just return the shadows.
needsPrimal = EnzymeRules.needs_primal(C)
needsShadow = EnzymeRules.needs_shadow(C)
width = EnzymeRules.width(C)
- if !needsShadow
+ return if !needsShadow
if needsPrimal
return RealRt
else
@@ -184,14 +184,14 @@ Otherwise, just return the shadows.
if !needsPrimal
ST = RealRt
if width != 1
- ST = NTuple{Int(width),ST}
+ ST = NTuple{Int(width), ST}
end
return ST
else
ST = if width == 1
Duplicated{RealRt}
else
- BatchDuplicated{RealRt,Int(width)}
+ BatchDuplicated{RealRt, Int(width)}
end
return ST
end
@@ -218,17 +218,17 @@ struct AugmentedReturn{PrimalType,ShadowType,TapeType}
tape::TapeType
end
-@inline function AugmentedReturn{PrimalType,ShadowType}(primal, shadow, cache) where {PrimalType, ShadowType}
- AT = AugmentedReturn{PrimalType,ShadowType, typeof(cache)}
+@inline function AugmentedReturn{PrimalType, ShadowType}(primal, shadow, cache) where {PrimalType, ShadowType}
+ AT = AugmentedReturn{PrimalType, ShadowType, typeof(cache)}
return AT(primal, shadow, cache)
end
@inline primal_type(::Type{<:AugmentedReturn{PrimalType}}) where {PrimalType} = PrimalType
@inline primal_type(::AugmentedReturn{PrimalType}) where {PrimalType} = PrimalType
-@inline shadow_type(::Type{<:AugmentedReturn{<:Any,ShadowType}}) where {ShadowType} = ShadowType
-@inline shadow_type(::AugmentedReturn{<:Any,ShadowType}) where {ShadowType} = ShadowType
-@inline tape_type(::Type{<:AugmentedReturn{<:Any,<:Any,TapeType}}) where {TapeType} = TapeType
-@inline tape_type(::AugmentedReturn{<:Any,<:Any,TapeType}) where {TapeType} = TapeType
+@inline shadow_type(::Type{<:AugmentedReturn{<:Any, ShadowType}}) where {ShadowType} = ShadowType
+@inline shadow_type(::AugmentedReturn{<:Any, ShadowType}) where {ShadowType} = ShadowType
+@inline tape_type(::Type{<:AugmentedReturn{<:Any, <:Any, TapeType}}) where {TapeType} = TapeType
+@inline tape_type(::AugmentedReturn{<:Any, <:Any, TapeType}) where {TapeType} = TapeType
struct AugmentedReturnFlexShadow{PrimalType,ShadowType,TapeType}
primal::PrimalType
shadow::ShadowType
@@ -461,10 +461,12 @@ This function is currently considered internal/experimental and may not respect
"""
function inactive_kwarg end
-function is_inactive_kwarg_from_sig(@nospecialize(TT);
- world::UInt=Base.get_world_counter(),
- method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
- caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
+function is_inactive_kwarg_from_sig(
+ @nospecialize(TT);
+ world::UInt = Base.get_world_counter(),
+ method_table::Union{Nothing, Core.Compiler.MethodTableView} = nothing,
+ caller::Union{Nothing, Core.MethodInstance, Core.Compiler.MethodLookupResult} = nothing
+ )
return isapplicable(inactive_kwarg, TT; world, method_table, caller)
end
@@ -478,10 +480,12 @@ This function is currently considered internal/experimental and may not respect
"""
function inactive_arg end
-function is_inactive_arg_from_sig(@nospecialize(TT);
- world::UInt=Base.get_world_counter(),
- method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
- caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
+function is_inactive_arg_from_sig(
+ @nospecialize(TT);
+ world::UInt = Base.get_world_counter(),
+ method_table::Union{Nothing, Core.Compiler.MethodTableView} = nothing,
+ caller::Union{Nothing, Core.MethodInstance, Core.Compiler.MethodLookupResult} = nothing
+ )
return isapplicable(inactive_arg, TT; world, method_table, caller)
end
diff --git a/src/compiler.jl b/src/compiler.jl
index cec43e9e..c6340bf9 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -410,7 +410,7 @@ const JuliaEnzymeNameMap = Dict{String,Any}(
"enz_no_derivative_exc" => EnzymeNoDerivativeError{Nothing, Nothing},
"enz_no_derivative_mi_exc" => EnzymeNoDerivativeError{Core.MethodInstance, UInt},
"enz_non_const_kwarg_exc" => NonConstantKeywordArgException,
- "enz_callconv_mismatch_exc"=> CallingConventionMismatchError,
+ "enz_callconv_mismatch_exc" => CallingConventionMismatchError,
"enz_illegal_ta_exc" => IllegalTypeAnalysisException,
"enz_illegal_first_pointer_exc" => IllegalFirstPointerException,
"enz_internal_exc" => EnzymeInternalError,
@@ -6318,12 +6318,12 @@ function thunk_generator(world::UInt, source::Union{Method, LineNumberNode}, @no
rev_sig = Tuple{typeof(EnzymeRules.reverse), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Union{Type{<:Enzyme.EnzymeCore.Annotation}, Enzyme.EnzymeCore.Active}, Any, Vararg{Enzyme.EnzymeCore.Annotation}}
add_edge!(edges, rev_sig)
end
-
+
for gen_sig in (
- Tuple{typeof(EnzymeRules.inactive), Vararg{Any}},
+ Tuple{typeof(EnzymeRules.inactive), Vararg{Any}},
Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}},
- Tuple{typeof(EnzymeRules.inactive_arg), Vararg{Any}},
- Tuple{typeof(EnzymeRules.inactive_kwarg), Vararg{Any}},
+ Tuple{typeof(EnzymeRules.inactive_arg), Vararg{Any}},
+ Tuple{typeof(EnzymeRules.inactive_kwarg), Vararg{Any}},
Tuple{typeof(EnzymeRules.noalias), Vararg{Any}},
Tuple{typeof(EnzymeRules.inactive_type), Type},
)
diff --git a/src/errors.jl b/src/errors.jl
index cd6aba94..aec96f5a 100644
--- a/src/errors.jl
+++ b/src/errors.jl
@@ -11,7 +11,7 @@ abstract type EnzymeError <: Base.Exception end
abstract type CompilationException <: EnzymeError end
-function pretty_print_mi(mi, io=stdout; digit_align_width = 1)
+function pretty_print_mi(mi, io = stdout; digit_align_width = 1)
spec = mi.specTypes.parameters
ft = spec[1]
arg_types_param = spec[2:end]
@@ -26,7 +26,7 @@ function pretty_print_mi(mi, io=stdout; digit_align_width = 1)
end
Base.show_signature_function(io, ft)
- Base.show_tuple_as_call(io, :function, Tuple{arg_types_param...}; hasfirst=false, kwargs = isempty(kwargs) ? nothing : kwargs)
+ Base.show_tuple_as_call(io, :function, Tuple{arg_types_param...}; hasfirst = false, kwargs = isempty(kwargs) ? nothing : kwargs)
m = mi.def
@@ -43,12 +43,12 @@ function pretty_print_mi(mi, io=stdout; digit_align_width = 1)
end
# module & file, re-using function from errorshow.jl
- Base.print_module_path_file(io, Base.parentmodule(m), string(file), line; modulecolor, digit_align_width)
+ return Base.print_module_path_file(io, Base.parentmodule(m), string(file), line; modulecolor, digit_align_width)
end
using InteractiveUtils
-function code_typed_helper(mi::Core.MethodInstance, world::UInt, mode::Enzyme.API.CDerivativeMode = Enzyme.API.DEM_ReverseModeCombined; interactive::Bool=false, kwargs...)
+function code_typed_helper(mi::Core.MethodInstance, world::UInt, mode::Enzyme.API.CDerivativeMode = Enzyme.API.DEM_ReverseModeCombined; interactive::Bool = false, kwargs...)
CT = @static if VERSION >= v"1.11.0-DEV.1552"
EnzymeCacheToken(
typeof(DefaultCompilerTarget()),
@@ -71,10 +71,10 @@ function code_typed_helper(mi::Core.MethodInstance, world::UInt, mode::Enzyme.AP
interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true)
sig = mi.specTypes # XXX: can we just use the method instance?
- if interactive
+ return if interactive
# call Cthulhu without introducing a dependency on Cthulhu
mod = get(Base.loaded_modules, Cthulhu, nothing)
- mod===nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.")
+ mod === nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.")
descend_code_typed = getfield(mod, :descend_code_typed)
descend_code_typed(sig; interp, kwargs...)
else
@@ -120,7 +120,7 @@ function Base.showerror(io::IO, ece::EnzymeRuntimeExceptionMI)
)
println(io)
msg = Base.unsafe_string(ece.msg)
- print(io, msg, '\n')
+ return print(io, msg, '\n')
end
abstract type CustomRuleError <: Base.Exception end
@@ -148,7 +148,7 @@ function Base.showerror(io::IO, ece::NonConstantKeywordArgException)
println(io)
pretty_print_mi(ece.mi, io)
println(io)
- Base.println(io, Base.unsafe_string(ece.backtrace))
+ return Base.println(io, Base.unsafe_string(ece.backtrace))
end
struct CallingConventionMismatchError <: CustomRuleError
@@ -174,7 +174,7 @@ function Base.showerror(io::IO, ece::CallingConventionMismatchError)
)
println(io)
- Base.println(io, Base.unsafe_string(ece.backtrace))
+ return Base.println(io, Base.unsafe_string(ece.backtrace))
end
InteractiveUtils.code_typed(ece::CallingConventionMismatchError; kwargs...) = code_typed_helper(ece.mi, ece.world; kwargs...)
@@ -236,7 +236,7 @@ function Base.showerror(io::IO, ece::ForwardRuleReturnError{C, RT, fwd_RT}) wher
elseif EnzymeRules.needs_primal(C) && !EnzymeRules.needs_shadow(C)
if fwd_RT <: BatchDuplicated || fwd_RT <: Duplicated
hint = "Shadow was not requested, you should only return the primal"
- elseif fwd_RT <: (NTuple{N, <:RealRt} where N)
+ elseif fwd_RT <: (NTuple{N, <:RealRt} where {N})
hint = "You appear to be returning a tuple of shadows, but only the primal was requested"
elseif fwd_RT <: RealRt
hint = "Expected the abstract type $RealRt for primal, you returned $(fwd_RT). Even though $(fwd_RT) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
@@ -249,7 +249,7 @@ function Base.showerror(io::IO, ece::ForwardRuleReturnError{C, RT, fwd_RT}) wher
if fwd_RT <: BatchDuplicated || fwd_RT <: Duplicated
hint = "Primal was not requested, you should only return the shadow"
elseif width == 1
- if fwd_RT <: (NTuple{N, <:RealRt} where N)
+ if fwd_RT <: (NTuple{N, <:RealRt} where {N})
hint = "You look to be returning a tuple of shadows, when the batch size is 1"
elseif fwd_RT <: RealRt
hint = "Expected the abstract type $RealRt for shadow, you returned $(fwd_RT). Even though $(fwd_RT) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
@@ -273,11 +273,11 @@ function Base.showerror(io::IO, ece::ForwardRuleReturnError{C, RT, fwd_RT}) wher
if fwd_RT <: BatchDuplicated || fwd_RT <: Duplicated
hint = "Neither primal nor shadow were requested, you should return nothing, not both the primal and shadow"
- elseif fwd_RT <: (NTuple{N, <:RealRt} where N)
+ elseif fwd_RT <: (NTuple{N, <:RealRt} where {N})
hint = "You appear to be returning a tuple of shadows, but neither primal nor shadow were requested"
elseif fwd_RT <: RealRt && width == 1
hint = "You appear to be returning a primal or shadow, but neither were requested"
- elseif fwd_RT <: RealRt
+ elseif fwd_RT <: RealRt
hint = "You appear to be returning a primal, but it was not requested"
else
hint = "You should return nothing"
@@ -309,7 +309,7 @@ function Base.showerror(io::IO, ece::ForwardRuleReturnError{C, RT, fwd_RT}) wher
println(io)
pretty_print_mi(ece.mi, io)
println(io)
- Base.println(io, Base.unsafe_string(ece.backtrace))
+ return Base.println(io, Base.unsafe_string(ece.backtrace))
end
@@ -345,7 +345,7 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
elseif EnzymeRules.primal_type(fwd_RT) == Nothing
hint = "Missing primal return"
elseif EnzymeRules.shadow_type(fwd_RT) == Nothing
- hint = "Missing shadow return"
+ hint = "Missing shadow return"
elseif EnzymeRules.primal_type(fwd_RT) != RealRt
if EnzymeRules.primal_type(fwd_RT) <: RealRt
hint = "Expected the abstract type $RealRt for primal, you returned $(EnzymeRules.primal_type(fwd_RT)). Even though $(EnzymeRules.primal_type(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
@@ -356,7 +356,7 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
if width == 1
if EnzymeRules.shadow_type(fwd_RT) <: RealRt
hint = "Expected the abstract type $RealRt for shadow, you returned $(EnzymeRules.shadow_type(fwd_RT)). Even though $(EnzymeRules.shadow_type(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
- elseif shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where N)
+ elseif shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where {N})
hint = "Batch size was 1, expected a single shadow, not a tuple of shadows."
else
hint = "Mismatched shadow type $(EnzymeRules.shadow_type(fwd_RT)), expected $(EnzymeRules.shadow_type(ExpRT))."
@@ -364,7 +364,7 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
else
if EnzymeRules.shadow_type(fwd_RT) <: RealRt
hint = "Batch size was $width, expected a tuple of shadows, not a single shadow."
- elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where N)
+ elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where {N})
hint = "Expected the abstract type $RealRt for the element shadow type (for a batched shadow type $(EnzymeRules.shadow_type(ExpRT))), you returned $(eltype(EnzymeRules.shadow_type(fwd_RT))) as the element shadow type (batched to become $(EnzymeRules.shadow_type(fwd_RT)). Even though $(eltype(EnzymeRules.shadow_type(fwd_RT))) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
else
hint = "Mismatched shadow type $(EnzymeRules.shadow_type(fwd_RT)), expected $(EnzymeRules.shadow_type(ExpRT))."
@@ -375,11 +375,11 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
"primal and shadow configuration"
elseif EnzymeRules.needs_primal(C) && !EnzymeRules.needs_shadow(C)
if !(fwd_RT <: EnzymeRules.AugmentedReturn)
- hint = "Return should be a struct of type EnzymeRules.AugmentedReturn"
+ hint = "Return should be a struct of type EnzymeRules.AugmentedReturn"
elseif EnzymeRules.primal_type(fwd_RT) == Nothing
hint = "Missing primal return"
elseif EnzymeRules.shadow_type(fwd_RT) != Nothing
- hint = "Shadow return was not requested"
+ hint = "Shadow return was not requested"
elseif EnzymeRules.primal_type(fwd_RT) != RealRt
if EnzymeRules.primal_type(fwd_RT) <: RealRt
hint = "Expected the abstract type $RealRt for primal, you returned $(EnzymeRules.primal_type(fwd_RT)). Even though $(EnzymeRules.primal_type(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
@@ -392,14 +392,14 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
elseif !EnzymeRules.needs_primal(C) && EnzymeRules.needs_shadow(C)
if !(fwd_RT <: EnzymeRules.AugmentedReturn)
- hint = "Return should be a struct of type EnzymeRules.AugmentedReturn"
+ hint = "Return should be a struct of type EnzymeRules.AugmentedReturn"
elseif EnzymeRules.primal_type(fwd_RT) != Nothing
hint = "Primal was not requested"
elseif EnzymeRules.shadow_type(fwd_RT) != RealRt
if width == 1
if EnzymeRules.shadow_type(fwd_RT) <: RealRt
hint = "Expected the abstract type $RealRt for shadow, you returned $(EnzymeRules.shadow_type(fwd_RT)). Even though $(EnzymeRules.shadow_type(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
- elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where N)
+ elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where {N})
hint = "Batch size was 1, expected a single shadow, not a tuple of shadows."
else
hint = "Mismatched shadow type $(EnzymeRules.shadow_type(fwd_RT)), expected $(EnzymeRules.shadow_type(ExpRT))."
@@ -407,7 +407,7 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
else
if EnzymeRules.shadow_type(fwd_RT) <: RealRt
hint = "Batch size was $width, expected a tuple of shadows, not a single shadow."
- elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where N)
+ elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where {N})
hint = "Expected the abstract type $RealRt for the element shadow type (for a batched shadow type $(EnzymeRules.shadow_type(ExpRT))), you returned $(eltype(EnzymeRules.shadow_type(fwd_RT))) as the element shadow type (batched to become $(EnzymeRules.shadow_type(fwd_RT)). Even though $(eltype(EnzymeRules.shadow_type(fwd_RT))) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
else
hint = "Mismatched shadow type $(EnzymeRules.shadow_type(fwd_RT)), expected $(EnzymeRules.shadow_type(ExpRT))."
@@ -418,7 +418,7 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
"shadow-only configuration"
else
if !(fwd_RT <: EnzymeRules.AugmentedReturn)
- hint = "Return should be a struct of type EnzymeRules.AugmentedReturn"
+ hint = "Return should be a struct of type EnzymeRules.AugmentedReturn"
elseif EnzymeRules.primal_type(fwd_RT) != Nothing
hint = "Primal was not requested"
elseif EnzymeRules.shadow_type(fwd_RT) != Nothing
@@ -454,7 +454,7 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
println(io)
pretty_print_mi(ece.mi, io)
println(io)
- Base.println(io, Base.unsafe_string(ece.backtrace))
+ return Base.println(io, Base.unsafe_string(ece.backtrace))
end
@@ -469,7 +469,7 @@ InteractiveUtils.code_typed(ece::ReverseRuleReturnError; kwargs...) = code_typed
function Base.showerror(io::IO, ece::ReverseRuleReturnError{C, ArgAct, rev_RT}) where {C, ArgAct, rev_RT}
width = EnzymeRules.width(C)
Tys = (
- A <: Active ? (width == 1 ? eltype(A) : NTuple{Int(width),eltype(A)}) : Nothing for A in ArgAct.parameters
+ A <: Active ? (width == 1 ? eltype(A) : NTuple{Int(width), eltype(A)}) : Nothing for A in ArgAct.parameters
)
ExpRT = Tuple{Tys...}
@assert ExpRT != rev_RT
@@ -500,7 +500,7 @@ function Base.showerror(io::IO, ece::ReverseRuleReturnError{C, ArgAct, rev_RT})
if width == 1
- if rev_RT.parameters[i] <: (NTuple{N, ExpRT.parameters[i]} where N)
+ if rev_RT.parameters[i] <: (NTuple{N, ExpRT.parameters[i]} where {N})
hint = "Tuple return mismatch at index $i, returned a tuple of results when expected just one of type $(ExpRT.parameters[i])."
break
end
@@ -550,7 +550,7 @@ function Base.showerror(io::IO, ece::ReverseRuleReturnError{C, ArgAct, rev_RT})
println(io)
pretty_print_mi(ece.mi, io)
println(io)
- Base.println(io, Base.unsafe_string(ece.backtrace))
+ return Base.println(io, Base.unsafe_string(ece.backtrace))
end
struct MixedReturnException{RT} <: CustomRuleError
@@ -561,7 +561,7 @@ end
InteractiveUtils.code_typed(ece::MixedReturnException; kwargs...) = code_typed_helper(ece.mi, ece.world; kwargs...)
-function Base.showerror(io::IO, ece::MixedReturnException{RT}) where RT
+function Base.showerror(io::IO, ece::MixedReturnException{RT}) where {RT}
if isdefined(Base.Experimental, :show_error_hints)
Base.Experimental.show_error_hints(io, ece)
end
@@ -577,7 +577,7 @@ function Base.showerror(io::IO, ece::MixedReturnException{RT}) where RT
println(io)
pretty_print_mi(ece.mi, io)
println(io)
- Base.println(io, Base.unsafe_string(ece.backtrace))
+ return Base.println(io, Base.unsafe_string(ece.backtrace))
end
@@ -589,7 +589,7 @@ end
InteractiveUtils.code_typed(ece::UnionSretReturnException; kwargs...) = code_typed_helper(ece.mi, ece.world; kwargs...)
-function Base.showerror(io::IO, ece::UnionSretReturnException{RT}) where RT
+function Base.showerror(io::IO, ece::UnionSretReturnException{RT}) where {RT}
if isdefined(Base.Experimental, :show_error_hints)
Base.Experimental.show_error_hints(io, ece)
end
@@ -606,11 +606,10 @@ function Base.showerror(io::IO, ece::UnionSretReturnException{RT}) where RT
println(io)
pretty_print_mi(ece.mi, io)
println(io)
- Base.println(io, Base.unsafe_string(ece.backtrace))
+ return Base.println(io, Base.unsafe_string(ece.backtrace))
end
-
struct NoDerivativeException <: CompilationException
msg::String
ir::Union{Nothing,String}
@@ -694,7 +693,7 @@ function InteractiveUtils.code_typed(ece::IllegalTypeAnalysisException; kwargs..
end
world = ece.world::UInt
mode = Enzyme.API.DEM_ReverseModeCombined
- code_typed_helper(ece.mi, ece.world; kwargs...)
+ return code_typed_helper(ece.mi, ece.world; kwargs...)
end
struct IllegalFirstPointerException <: CompilationException
@@ -764,7 +763,7 @@ function Base.showerror(io::IO, ece::EnzymeMutabilityException)
Base.Experimental.show_error_hints(io, ece)
end
msg = Base.unsafe_string(ece.msg)
- print(io, "EnzymeMutabilityException: ", msg, '\n')
+ return print(io, "EnzymeMutabilityException: ", msg, '\n')
end
struct EnzymeRuntimeActivityError{MT,WT} <: EnzymeError
@@ -822,7 +821,7 @@ function InteractiveUtils.code_typed(ece::EnzymeRuntimeActivityError; interactiv
if mi === nothing
throw(AssertionError("code_typed(::EnzymeRuntimeActivityError; interactive::Bool=false, kwargs...) not supported for error without mi"))
end
- code_typed_helper(ece.mi, ece.world; kwargs...)
+ return code_typed_helper(ece.mi, ece.world; kwargs...)
end
struct EnzymeNoTypeError{MT,WT} <: EnzymeError
@@ -863,7 +862,7 @@ function InteractiveUtils.code_typed(ece::EnzymeNoTypeError; interactive::Bool=f
if mi === nothing
throw(AssertionError("code_typed(::EnzymeNoTypeError; interactive::Bool=false, kwargs...) not supported for error without mi"))
end
- code_typed_helper(ece.mi, ece.world; kwargs...)
+ return code_typed_helper(ece.mi, ece.world; kwargs...)
end
struct EnzymeNoShadowError <: EnzymeError
@@ -879,18 +878,18 @@ function Base.showerror(io::IO, ece::EnzymeNoShadowError)
print(io, msg, '\n')
end
-struct EnzymeNoDerivativeError{MT,WT} <: EnzymeError
+struct EnzymeNoDerivativeError{MT, WT} <: EnzymeError
msg::Cstring
mi::MT
world::WT
end
-function InteractiveUtils.code_typed(ece::EnzymeNoDerivativeError; interactive::Bool=false, kwargs...)
+function InteractiveUtils.code_typed(ece::EnzymeNoDerivativeError; interactive::Bool = false, kwargs...)
mi = ece.mi
if mi === nothing
throw(AssertionError("code_typed(::EnzymeNoDerivativeError; interactive::Bool=false, kwargs...) not supported for error without mi"))
end
- code_typed_helper(ece.mi, ece.world; kwargs...)
+ return code_typed_helper(ece.mi, ece.world; kwargs...)
end
function Base.showerror(io::IO, ece::EnzymeNoDerivativeError)
@@ -900,7 +899,7 @@ function Base.showerror(io::IO, ece::EnzymeNoDerivativeError)
msg = Base.unsafe_string(ece.msg)
print(io, "EnzymeNoDerivativeError: ", msg, '\n')
- if ece.mi !== nothing
+ return if ece.mi !== nothing
print(io, "Failure within method:\n")
println(io)
pretty_print_mi(ece.mi, io)
@@ -983,8 +982,8 @@ function julia_error(
if occursin("No create nofree of empty function", msg) ||
occursin("No forward mode derivative found for", msg) ||
occursin("No augmented forward pass", msg) ||
- occursin("No reverse pass found", msg) ||
- occursin("Runtime Activity not yet implemented for Forward-Mode BLAS", msg)
+ occursin("No reverse pass found", msg) ||
+ occursin("Runtime Activity not yet implemented for Forward-Mode BLAS", msg)
ir = nothing
end
if B != C_NULL
diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl
index eccd8f5d..d1bf77ab 100644
--- a/src/rules/activityrules.jl
+++ b/src/rules/activityrules.jl
@@ -66,7 +66,6 @@ function julia_activity_rule(f::LLVM.Function, method_table)
end
-
if !Enzyme.Compiler.no_type_setting(mi.specTypes; world)[1]
any_active = false
for arg in jlargs
@@ -84,8 +83,8 @@ function julia_activity_rule(f::LLVM.Function, method_table)
parameter_attributes(f, arg.codegen.i),
StringAttribute("enzyme_inactive"),
)
- else
- any_active = true
+ else
+ any_active = true
end
end
if sret !== nothing
diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl
index 680746e8..0e710633 100644
--- a/src/rules/customrules.jl
+++ b/src/rules/customrules.jl
@@ -2,7 +2,7 @@ import LinearAlgebra
@inline add_fwd(prev, post) = recursive_add(prev, post)
-@generated function EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial::Union{AbstractArray,Number}, dx::Union{AbstractArray,Number})
+@generated function EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial::Union{AbstractArray, Number}, dx::Union{AbstractArray, Number})
if partial <: Number || dx isa Number
if !(prev <: Type)
return quote
@@ -19,17 +19,17 @@ import LinearAlgebra
@assert partial <: AbstractArray
if dx <: Number
if !(prev <: Type)
- return quote
- Base.@_inline_meta
- LinearAlgebra.axpy!(dx, partial, prev)
- prev
- end
- else
- return quote
- Base.@_inline_meta
- prev(partial * dx)
- end
- end
+ return quote
+ Base.@_inline_meta
+ LinearAlgebra.axpy!(dx, partial, prev)
+ prev
+ end
+ else
+ return quote
+ Base.@_inline_meta
+ prev(partial * dx)
+ end
+ end
end
@assert dx <: AbstractArray
N = ndims(partial)
@@ -66,7 +66,7 @@ import LinearAlgebra
end
init = if prev <: Type
- :(prev = similar(prev, size(partial)[1:$(N-M)]...))
+ :(prev = similar(prev, size(partial)[1:$(N - M)]...))
end
idxs = Symbol[]
@@ -344,7 +344,7 @@ function enzyme_custom_setup_args(
B,
orig,
"Enzyme: active by ref type $Ty is overwritten in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr)). " *
- "As a workaround until support for this is added, try passing values as separate arguments rather than as an aggregate of type $Ty.\n"*msg2,
+ "As a workaround until support for this is added, try passing values as separate arguments rather than as an aggregate of type $Ty.\n" * msg2,
)
end
if arty == eltype(value_type(val))
@@ -356,7 +356,7 @@ function enzyme_custom_setup_args(
emit_error(
B,
orig,
- "Enzyme: active by ref type $Ty is wrong type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))\n"*msg2,
+ "Enzyme: active by ref type $Ty is wrong type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))\n" * msg2,
)
end
end
@@ -372,7 +372,7 @@ function enzyme_custom_setup_args(
emit_error(
B,
orig,
- "Enzyme: active by ref type $Ty is wrong store type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))\n"*msg2,
+ "Enzyme: active by ref type $Ty is wrong store type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))\n" * msg2,
)
end
@@ -579,9 +579,9 @@ function enzyme_custom_setup_ret(
mode != API.DEM_ForwardMode &&
!guaranteed_nonactive(RealRt, world)
)
- if active_reg(RealRt, world) == MixedState && B !== nothing
+ if active_reg(RealRt, world) == MixedState && B !== nothing
bt = GPUCompiler.backtrace(orig)
- msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
+ msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
mi, _ = enzyme_custom_extract_mi(orig)
emit_error(
B,
@@ -742,7 +742,7 @@ end
ExpRT = EnzymeRules.forward_rule_return_type(C, RT)
if ExpRT != fwd_RT
bt = GPUCompiler.backtrace(orig)
- msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
+ msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
emit_error(
B,
orig,
@@ -1475,9 +1475,12 @@ function enzyme_custom_common_rev(
if ST != EnzymeRules.augmented_rule_return_type(C, RT, TapeT)
throw(AssertionError("Unexpected augmented rule return computation\nST = $ST\nER = $(EnzymeRules.augmented_rule_return_type(C, RT, TapeT))\nC = $C\nRT = $RT\nTapeT = $TapeT"))
end
- if !(aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && !(aug_RT <: EnzymeRules.AugmentedReturn{
- needsPrimal ? RealRt : Nothing,
- needsShadowJL ? ShadT : Nothing})
+ if !(aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && !(
+ aug_RT <: EnzymeRules.AugmentedReturn{
+ needsPrimal ? RealRt : Nothing,
+ needsShadowJL ? ShadT : Nothing,
+ }
+ )
bt = GPUCompiler.backtrace(orig)
msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
@@ -1603,7 +1606,7 @@ function enzyme_custom_common_rev(
if rev_RT != ST
bt = GPUCompiler.backtrace(orig)
msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
- emit_error(B, orig, (msg2, rmi, world), ReverseRuleReturnError{C, Tuple{activity[2+isKWCall:end]...,}, rev_RT})
+ emit_error(B, orig, (msg2, rmi, world), ReverseRuleReturnError{C, Tuple{activity[(2 + isKWCall):end]...}, rev_RT})
return tapeV
end
if length(actives) >= 1 &&
diff --git a/test/rules/inactive_kwrules.jl b/test/rules/inactive_kwrules.jl
index fc0c731c..62092abe 100644
--- a/test/rules/inactive_kwrules.jl
+++ b/test/rules/inactive_kwrules.jl
@@ -6,10 +6,10 @@ using Test
import .EnzymeRules: forward, augmented_primal, reverse
-function f_kw(out; tmp=[2.0, 0.0])
+function f_kw(out; tmp = [2.0, 0.0])
out[1] *= tmp[1]
tmp[2] += 1
- nothing
+ return nothing
end
function forward(config, ::Const{typeof(f_kw)}, ::Type{<:Const}, x::Duplicated; kwargs...)
@@ -29,14 +29,14 @@ function reverse(config, ::Const{typeof(f_kw)}, ::Type{<:Const}, tape, x::Duplic
end
function g_kw(out)
- tmp=[2.0, 0.0]
+ tmp = [2.0, 0.0]
f_kw(out; tmp)
- nothing
+ return nothing
end
function h_kw(out, tmp)
f_kw(out; tmp)
- nothing
+ return nothing
end
@testset "Forward Inactive allocated kwarg error" begin
@@ -69,7 +69,7 @@ end
@test_throws Enzyme.Compiler.NonConstantKeywordArgException autodiff(Forward, h_kw, Duplicated(x, dx), Duplicated(tmp, dtmp))
end
-Enzyme.EnzymeRules.inactive_kwarg(::typeof(f_kw), out; tmp=[2.0]) = nothing
+Enzyme.EnzymeRules.inactive_kwarg(::typeof(f_kw), out; tmp = [2.0]) = nothing
@testset "Forward Inactive allocated kwarg success" begin
x = [2.7]
diff --git a/test/rules/mixederror.jl b/test/rules/mixederror.jl
index 128f4403..dfb71b07 100644
--- a/test/rules/mixederror.jl
+++ b/test/rules/mixederror.jl
@@ -30,9 +30,11 @@ function handle_infinities(workfunc, f, s)
den = 1 / (1 - t)
return f(s0 - oneunit(s1) * t * den) * den * den * oneunit(s1)
end,
- reverse(map(s) do x
- 1 / (1 + oneunit(x) / (s0 - x))
- end),
+ reverse(
+ map(s) do x
+ 1 / (1 + oneunit(x) / (s0 - x))
+ end
+ ),
t -> s0 - oneunit(s1) * t / (1 - t),
)
else # x = s0 + t / (1 - t)
@@ -60,8 +62,8 @@ function inner(f::F, xs) where {F} # remove type annotation => problem solved
end
function EnzymeRules.augmented_primal(
- config::EnzymeRules.RevConfig, ::Const{typeof(inner)}, ::Type, f, xs
-)
+ config::EnzymeRules.RevConfig, ::Const{typeof(inner)}, ::Type, f, xs
+ )
true_primal = inner(f.val, xs.val)
primal = EnzymeRules.needs_primal(config) ? true_primal : nothing
shadow = if EnzymeRules.needs_shadow(config)
@@ -77,8 +79,8 @@ function EnzymeRules.augmented_primal(
end
function EnzymeRules.reverse(
- ::EnzymeRules.RevConfig, ::Const{typeof(inner)}, shadow::Active, tape, f, xs
-)
+ ::EnzymeRules.RevConfig, ::Const{typeof(inner)}, shadow::Active, tape, f, xs
+ )
return ((f isa Active) ? f : nothing, (xs isa Active) ? xs : nothing)
end
@@ -90,4 +92,4 @@ F_bad(x) = outer(y -> [cos(y)], 0.0, x)[1][1]
@test_throws Enzyme.Compiler.MixedReturnException autodiff(Reverse, F_bad, Active(0.3))
end
-end # MixedRuleError
\ No newline at end of file
+end # MixedRuleError
diff --git a/test/rules/rules.jl b/test/rules/rules.jl
index 4ed0f483..51a1a24e 100644
--- a/test/rules/rules.jl
+++ b/test/rules/rules.jl
@@ -131,7 +131,7 @@ end
@test Enzyme.autodiff(Forward, h, Duplicated(3.0, 1.0)) == (6000.0,)
@test Enzyme.autodiff(ForwardWithPrimal, h, Duplicated(3.0, 1.0)) == (60.0, 9.0)
@test Enzyme.autodiff(Forward, h2, Duplicated(3.0, 1.0)) == (1080.0,)
- @test_throws Enzyme.Compiler.ForwardRuleReturnError Enzyme.autodiff(Forward, h3, Duplicated(3.0, 1.0))
+ @test_throws Enzyme.Compiler.ForwardRuleReturnError Enzyme.autodiff(Forward, h3, Duplicated(3.0, 1.0))
end
foo(x) = 2x; |
This was referenced Nov 10, 2025
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2761 +/- ##
==========================================
- Coverage 70.30% 68.89% -1.42%
==========================================
Files 58 58
Lines 19391 19861 +470
==========================================
+ Hits 13633 13683 +50
- Misses 5758 6178 +420 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.