diff --git a/ext/EnzymeChainRulesCoreExt.jl b/ext/EnzymeChainRulesCoreExt.jl index ce3b1abd61..af8d60e4b4 100644 --- a/ext/EnzymeChainRulesCoreExt.jl +++ b/ext/EnzymeChainRulesCoreExt.jl @@ -159,11 +159,13 @@ function Enzyme._import_rrule(fn, tys...) # ) : nothing # end) + ptys = [] for (i, ty) in enumerate(tys) push!(nothings, :(nothing)) val = Symbol("arg_$i") TA = Symbol("AN_$i") e = :($val::$TA) + push!(ptys, :(::$(esc(ty)))) push!(anns, :($TA <: Annotation{<:$(esc(ty))})) push!(vals, val) push!(exprs, e) @@ -184,16 +186,14 @@ function Enzyme._import_rrule(fn, tys...) end) end - quote + EnzymeRules.has_easy_rule(::$(esc(fn)), $(ptys...)) = true + function EnzymeRules.augmented_primal(config, fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)} $(valtys...) - - res, pullback = if RetAnnotation <: Const - (fn.val($(primals...); kwargs...), nothing) - else - $ChainRulesCore.rrule(fn.val, $(primals...); kwargs...) - end + + @assert !(RetAnnotation <: Const) + res, pullback = $ChainRulesCore.rrule(fn.val, $(primals...); kwargs...) primal = if EnzymeRules.needs_primal(config) res @@ -201,29 +201,42 @@ function Enzyme._import_rrule(fn, tys...) nothing end - shadow = if !EnzymeRules.needs_shadow(config) - nothing - else - if EnzymeRules.width(config) == 1 + shadow, byref = if !EnzymeRules.needs_shadow(config) + nothing, Val(false) + elseif !Enzyme.Compiler.guaranteed_nonactive(Core.Typeof(res)) + (if EnzymeRules.width(config) == 1 + Ref(Enzyme.make_zero(res)) + else + ntuple(Val(EnzymeRules.width(config))) do j + Base.@_inline_meta + Ref(Enzyme.make_zero(res)) + end + end, Val(true)) + else + (if EnzymeRules.width(config) == 1 Enzyme.make_zero(res) else ntuple(Val(EnzymeRules.width(config))) do j Base.@_inline_meta Enzyme.make_zero(res) end - end + end, Val(false)) end - return EnzymeRules.AugmentedReturn(primal, shadow, (shadow, pullback)) + cache = (shadow, pullback, byref) + return EnzymeRules.augmented_rule_return_type(config, RetAnnotation){typeof(cache)}(primal, shadow, cache) end function EnzymeRules.reverse(config, fn::FA, ::Type{RetAnnotation}, tape::TapeTy, $(exprs...); kwargs...) where {RetAnnotation, TapeTy, FA<:Annotation{<:$(esc(fn))}, $(anns...)} if !(RetAnnotation <: Const) - shadow, pullback = tape + shadow, pullback, byref = tape tcomb = ntuple(Val(EnzymeRules.width(config))) do batch_i Base.@_inline_meta shad = EnzymeRules.width(config) == 1 ? shadow : shadow[batch_i] + if byref === Val(true) + shad = shad[] + end res = pullback(shad) for (cr, en) in zip(res, (fn, $(vals...),)) diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index a599346b09..2810105f0a 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.15" +version = "0.8.16" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/easyrules.jl b/lib/EnzymeCore/src/easyrules.jl index e93f2283ac..a8ddeada22 100644 --- a/lib/EnzymeCore/src/easyrules.jl +++ b/lib/EnzymeCore/src/easyrules.jl @@ -236,7 +236,11 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, input_names end if !seen - push!(gensetup, Expr(:(=), outexpr, nothing)) + ST = $(esc(:RT)) + if ST <: Tuple + ST = ST.parameters[o] + end + push!(gensetup, Expr(:(=), outexpr, ST)) seen = true end @@ -473,21 +477,25 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names end push!(gensetup, Expr(:(=), :cache, Expr(:tuple, caches...))) + PT = EnzymeRules.primal_type(config, ($(esc(:RTA))).parameters[1]) + ST = EnzymeRules.shadow_type(config, ($(esc(:RTA))).parameters[1]) + AugmentedReturnType = :(EnzymeRules.AugmentedReturn{$PT,$ST,typeof(cache)}) + genres = if needs_primal(config) if needs_shadow(config) if width(config) == 1 - Expr(:call, EnzymeRules.AugmentedReturn, :Ω, :dΩ, :cache) + Expr(:call, AugmentedReturnType, :Ω, :dΩ, :cache) else - Expr(:call, EnzymeRules.AugmentedReturn, :Ω, :dΩ, :cache) + Expr(:call, AugmentedReturnType, :Ω, :dΩ, :cache) end else - Expr(:call, EnzymeRules.AugmentedReturn, :Ω, nothing, :cache) + Expr(:call, AugmentedReturnType, :Ω, nothing, :cache) end else if needs_shadow(config) - Expr(:call, EnzymeRules.AugmentedReturn, nothing, :dΩ, :cache) + Expr(:call, AugmentedReturnType, nothing, :dΩ, :cache) else - Expr(:call, EnzymeRules.AugmentedReturn, nothing, nothing, :cache) + Expr(:call, AugmentedReturnType, nothing, nothing, :cache) end end @@ -613,7 +621,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names if !seen if inp_types[inum] <: Active - push!(gensetup, Expr(:(=), inexpr, nothing)) + push!(gensetup, Expr(:(=), inexpr, eltype(inp_types[inum]))) else dexpr = Expr(:call, getfield, Symbol(inp_names[inum]), 2) if W != 1 diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 174a6456d5..2d5259e026 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -7,7 +7,7 @@ export FwdConfig, FwdConfigWidth export AugmentedReturn import ..EnzymeCore: needs_primal export needs_primal, needs_shadow, width, overwritten, runtime_activity -export primal_type, shadow_type, tape_type, easy_scalar_rule +export primal_type, shadow_type, tape_type, easy_scalar_rule, forward_rule_return_type, augmented_rule_return_type import Base: unwrapva, isvarargtype, unwrap_unionall, rewrap_unionall @@ -25,6 +25,8 @@ Valid types for `RT` are: - [`EnzymeCore.Duplicated`](@ref) - [`EnzymeCore.DuplicatedNoNeed`](@ref) - [`EnzymeCore.Const`](@ref) + +The return from this function must be a type matching [`forward_rule_return_type`](@ref) when given the `fwdconfig` and `RT`. """ function forward end @@ -124,20 +126,79 @@ between). """ primal_type(::FwdConfig, ::Type{<:Annotation{RT}}) primal_type(::RevConfig, ::Type{<:Annotation{RT}}) + primal_type(::Type{<:FwdConfig}, ::Type{<:Annotation{RT}}) + primal_type(::Type{<:RevConfig}, ::Type{<:Annotation{RT}}) Compute the expected primal return type given a reverse mode config and return activity """ @inline primal_type(config::FwdConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing @inline primal_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing +@inline primal_type(config::Type{<:FwdConfig}, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing +@inline primal_type(config::Type{<:RevConfig}, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing """ shadow_type(::FwdConfig, ::Type{<:Annotation{RT}}) shadow_type(::RevConfig, ::Type{<:Annotation{RT}}) + shadow_type(::Type{<:FwdConfig}, ::Type{<:Annotation{RT}}) + shadow_type(::Type{<:RevConfig}, ::Type{<:Annotation{RT}}) Compute the expected shadow return type given a reverse mode config and return activity """ @inline shadow_type(config::FwdConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing @inline shadow_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing +@inline shadow_type(config::Type{<:FwdConfig}, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing +@inline shadow_type(config::Type{<:RevConfig}, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing + + +""" + forward_rule_return_type(C::FwdConfig, RT::Type{<:Annotation}) + forward_rule_return_type(::Type{<:FwdConfig}, RT::Type{<:Annotation}) + +Compute the expected result type of a custom forward rule, given the configuration `C` and return activity and type `RT`. + +Consider `RealRt` as the original return type of the rule, accessible as `eltype(RT)`. The return type can be computed as follows: + +If the shadow isn't needed, return the original result (of type `RealRt`) if requested by the config ([`needs_primal`](@ref)), otherwise nothing. + +Otherwise, first construct a shadow return. + If the [`width`](@ref) is one, the shadow is the same type as the primal (`RealRt`). + If the [`width`](@ref) is not one, the shadow is a tuple containing `width` of the original return types (`NTuple{width,RealRt}`). + +Finally, if both the primal and shadow are requested, return a [`EnzymeCore.Duplicated`](@ref) or [`EnzymeCore.BatchDuplicated`](@ref) of primal and shadows. +Otherwise, just return the shadows. + +""" +@inline function forward_rule_return_type(C::Type{<:FwdConfig}, RT::Type{<:Annotation}) + RealRt = eltype(RT) + needsPrimal = EnzymeRules.needs_primal(C) + needsShadow = EnzymeRules.needs_shadow(C) + width = EnzymeRules.width(C) + if !needsShadow + if needsPrimal + return RealRt + else + return Nothing + end + else + @assert !(RT <: Const) + if !needsPrimal + ST = RealRt + if width != 1 + ST = NTuple{Int(width),ST} + end + return ST + else + ST = if width == 1 + Duplicated{RealRt} + else + BatchDuplicated{RealRt,Int(width)} + end + return ST + end + end +end +@inline forward_rule_return_type(::FCT, RT::Type{<:Annotation}) where {FCT <: FwdConfig} = forward_rule_return_type(FCT, RT) + """ AugmentedReturn(primal, shadow, tape) @@ -156,12 +217,18 @@ struct AugmentedReturn{PrimalType,ShadowType,TapeType} shadow::ShadowType tape::TapeType end -@inline primal_type(::Type{AugmentedReturn{PrimalType,ShadowType,TapeType}}) where {PrimalType,ShadowType,TapeType} = PrimalType -@inline primal_type(::AugmentedReturn{PrimalType,ShadowType,TapeType}) where {PrimalType,ShadowType,TapeType} = PrimalType -@inline shadow_type(::Type{AugmentedReturn{PrimalType,ShadowType,TapeType}}) where {PrimalType,ShadowType,TapeType} = ShadowType -@inline shadow_type(::AugmentedReturn{PrimalType,ShadowType,TapeType}) where {PrimalType,ShadowType,TapeType} = ShadowType -@inline tape_type(::Type{AugmentedReturn{PrimalType,ShadowType,TapeType}}) where {PrimalType,ShadowType,TapeType} = TapeType -@inline tape_type(::AugmentedReturn{PrimalType,ShadowType,TapeType}) where {PrimalType,ShadowType,TapeType} = TapeType + +@inline function AugmentedReturn{PrimalType,ShadowType}(primal, shadow, cache) where {PrimalType, ShadowType} + AT = AugmentedReturn{PrimalType,ShadowType, typeof(cache)} + return AT(primal, shadow, cache) +end + +@inline primal_type(::Type{<:AugmentedReturn{PrimalType}}) where {PrimalType} = PrimalType +@inline primal_type(::AugmentedReturn{PrimalType}) where {PrimalType} = PrimalType +@inline shadow_type(::Type{<:AugmentedReturn{<:Any,ShadowType}}) where {ShadowType} = ShadowType +@inline shadow_type(::AugmentedReturn{<:Any,ShadowType}) where {ShadowType} = ShadowType +@inline tape_type(::Type{<:AugmentedReturn{<:Any,<:Any,TapeType}}) where {TapeType} = TapeType +@inline tape_type(::AugmentedReturn{<:Any,<:Any,TapeType}) where {TapeType} = TapeType struct AugmentedReturnFlexShadow{PrimalType,ShadowType,TapeType} primal::PrimalType shadow::ShadowType @@ -176,6 +243,13 @@ end """ augmented_primal(::RevConfig, func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...) +Code to run during the original forward pass through the code. Any additional data can be saved from forward to +reverse pass. This may be required as arguments might be overwritten before the reverse pass is run. + +This should compute and mutate the same values as the original function (if requested). + +It should also return a shadow data structure to hold derivatives (if requested). + Must return an [`AugmentedReturn`](@ref) type. * The primal must be the same type of the original return if [`needs_primal(config)`](@ref needs_primal), otherwise nothing. @@ -183,6 +257,8 @@ Must return an [`AugmentedReturn`](@ref) type. If width is 1, the shadow should be the same type of the original return. If the width is greater than 1, the shadow should be `NTuple{original return, width}`. * The tape can be any type (including `Nothing`), and is preserved for the reverse call. + +See [`augmented_rule_return_type`](@ref) for more information. """ function augmented_primal end @@ -196,6 +272,66 @@ as `Type{Duplicated{T}}`, etc. """ function reverse end + +""" + augmented_rule_return_type(C::RevConfig, RT::Type{<:Annotation}, cache::Any) + augmented_rule_return_type(::Type{<:RevConfig}, RT::Type{<:Annotation}, CacheType::Type) + augmented_rule_return_type(C::RevConfig, RT::Type{<:Annotation}) + augmented_rule_return_type(::Type{<:RevConfig}, RT::Type{<:Annotation}) + +Compute the expected result type of a custom augmented forward pass rule, given the configuration `C` return activity and type `RT`, and cache `cache`. +Alternatively, this can be called with the configuration type, return activity, and cache type. + +Consider `RealRt` as the original return type of the rule, accessible as `eltype(RT)`. The return type can be computed as follows: + +We must return a struct of type [`AugmentedReturn`](@ref), which has three elements (and corresponding type parameter). + +The first element is the primal type, which is the original result (of type `RealRt`) if requested by the config ([`needs_primal`](@ref)), otherwise nothing. + +The second element is the shadow type, if requested by the config ([`needs_shadow`](@ref), otherwise nothing. If requested, the shadow is of type: + If the [`width`](@ref) is one, the shadow is the same type as the primal (`RealRt`). + If the [`width`](@ref) is not one, the shadow is a tuple containing `width` of the original return types (`NTuple{width,RealRt}`). + +The third element is user defined, whatever type the cache is you want to save from forward to reverse pass. In this case, it +will be determined by `cache`, or `CacheType`. + +If a cache type is not provided a unionall will be returned + +""" +@inline function augmented_rule_return_type(C::Type{<:RevConfig}, RT::Type{<:Annotation}) + RealRt = eltype(RT) + + PrimalType = if EnzymeRules.needs_primal(C) + RealRt + else + Nothing + end + + ShadowType = if EnzymeRules.needs_shadow(C) + if EnzymeRules.width(C) == 1 + RealRt + else + NTuple{EnzymeRules.width(C), RealRt} + end + else + Nothing + end + + return AugmentedReturn{PrimalType, ShadowType} +end + +@inline function augmented_rule_return_type(C::Type{<:RevConfig}, RT::Type{<:Annotation}, CacheType::Type) + return augmented_rule_return_type(C, RT){CacheType} +end + +@generated function augmented_rule_return_type(rct::RevConfig, RT::Type{<:Annotation}, cache) + return augmented_rule_return_type(rct, RT.parameters[1], cache) +end + +@generated function augmented_rule_return_type(rct::RevConfig, RT::Type{<:Annotation}) + return augmented_rule_return_type(rct, RT.parameters[1]) +end + function _annotate(@nospecialize(T)) if isvarargtype(T) VA = T @@ -316,6 +452,39 @@ function is_inactive_noinl_from_sig(@nospecialize(TT); return isapplicable(inactive_noinl, TT; world, method_table, caller) end +""" + inactive_kwarg(func::typeof(f), args...; kwargs...) + +Mark a particular function as always having inactive keyword arguments. The return does not matter, merely its declaration. + +This function is currently considered internal/experimental and may not respect semver. +""" +function inactive_kwarg end + +function is_inactive_kwarg_from_sig(@nospecialize(TT); + world::UInt=Base.get_world_counter(), + method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, + caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + return isapplicable(inactive_kwarg, TT; world, method_table, caller) +end + +""" + inactive_arg(func::typeof(f), args...; kwargs...) + +Mark a particular function as always having inactive non-keyword arguments. The return type must be a tuple of Val's whose +value is the argument marked inactive. + +This function is currently considered internal/experimental and may not respect semver. +""" +function inactive_arg end + +function is_inactive_arg_from_sig(@nospecialize(TT); + world::UInt=Base.get_world_counter(), + method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, + caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + return isapplicable(inactive_arg, TT; world, method_table, caller) +end + """ noalias(func::typeof(f), args...) diff --git a/src/analyses/activity.jl b/src/analyses/activity.jl index fe4e5f5fb2..bac295c026 100644 --- a/src/analyses/activity.jl +++ b/src/analyses/activity.jl @@ -474,13 +474,13 @@ end $(Expr(:meta, :generated, active_reg_nothrow_generator)) end -Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const(::Type{T})::Bool where {T} +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const(@nospecialize(T::Type))::Bool rt = active_reg_nothrow(T) res = rt == AnyState return res end -Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const_nongen(::Type{T}, world::UInt)::Bool where {T} +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const_nongen(@nospecialize(T::Type), world::UInt)::Bool rt = active_reg(T, world) res = rt == AnyState return res @@ -488,7 +488,7 @@ end # check if a value is guaranteed to be not contain active[register] data # (aka not either mixed or active) -Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive(::Type{T})::Bool where {T} +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive(@nospecialize(T::Type))::Bool rt = active_reg_nothrow(T) return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end diff --git a/src/compiler.jl b/src/compiler.jl index 08b3ec34d5..cec43e9e54 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -400,11 +400,21 @@ const JuliaEnzymeNameMap = Dict{String,Any}( "enz_any_array_2" => AnyArray(2), "enz_any_array_3" => AnyArray(3), "enz_runtime_exc" => EnzymeRuntimeException, + "enz_runtime_mi_exc" => EnzymeRuntimeExceptionMI, "enz_mut_exc" => EnzymeMutabilityException, - "enz_runtime_activity_exc" => EnzymeRuntimeActivityError, - "enz_no_type_exc" => EnzymeNoTypeError, + "enz_runtime_activity_exc" => EnzymeRuntimeActivityError{Nothing, Nothing}, + "enz_runtime_activity_mi_exc" => EnzymeRuntimeActivityError{Core.MethodInstance, UInt}, + "enz_no_type_exc" => EnzymeNoTypeError{Nothing, Nothing}, + "enz_no_type_mi_exc" => EnzymeNoTypeError{Core.MethodInstance, UInt}, "enz_no_shadow_exc" => EnzymeNoShadowError, - "enz_no_derivative_exc" => EnzymeNoDerivativeError, + "enz_no_derivative_exc" => EnzymeNoDerivativeError{Nothing, Nothing}, + "enz_no_derivative_mi_exc" => EnzymeNoDerivativeError{Core.MethodInstance, UInt}, + "enz_non_const_kwarg_exc" => NonConstantKeywordArgException, + "enz_callconv_mismatch_exc"=> CallingConventionMismatchError, + "enz_illegal_ta_exc" => IllegalTypeAnalysisException, + "enz_illegal_first_pointer_exc" => IllegalFirstPointerException, + "enz_internal_exc" => EnzymeInternalError, + "enz_non_scalar_return_exc" => EnzymeNonScalarReturnException, ) const JuliaGlobalNameMap = Dict{String,Any}( @@ -625,7 +635,7 @@ end name = meth.name jlmod = meth.module - julia_activity_rule(llvmfn) + julia_activity_rule(llvmfn, method_table) if has_custom_rule handleCustom( state, @@ -1379,7 +1389,7 @@ function julia_sanitize( position!(builder, bad) - emit_error(builder, nothing, sval, EnzymeNoDerivativeError) + emit_error(builder, nothing, sval, EnzymeNoDerivativeError{Nothing, Nothing}) unreachable!(builder) dispose(builder) end @@ -6309,11 +6319,11 @@ function thunk_generator(world::UInt, source::Union{Method, LineNumberNode}, @no add_edge!(edges, rev_sig) end - ina_sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Any}} - add_edge!(edges, ina_sig) - for gen_sig in ( + Tuple{typeof(EnzymeRules.inactive), Vararg{Any}}, Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}}, + Tuple{typeof(EnzymeRules.inactive_arg), Vararg{Any}}, + Tuple{typeof(EnzymeRules.inactive_kwarg), Vararg{Any}}, Tuple{typeof(EnzymeRules.noalias), Vararg{Any}}, Tuple{typeof(EnzymeRules.inactive_type), Type}, ) diff --git a/src/errors.jl b/src/errors.jl index 1e64c4b675..cd6aba9459 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -11,6 +11,77 @@ abstract type EnzymeError <: Base.Exception end abstract type CompilationException <: EnzymeError end +function pretty_print_mi(mi, io=stdout; digit_align_width = 1) + spec = mi.specTypes.parameters + ft = spec[1] + arg_types_param = spec[2:end] + f_is_function = false + kwargs = [] + if ft === typeof(Core.kwcall) && length(arg_types_param) >= 2 && arg_types_param[1] <: NamedTuple + ft = arg_types_param[2] + kwt = arg_types_param[1] + arg_types_param = arg_types_param[3:end] + keys = kwt.parameters[1]::Tuple + kwargs = Any[(keys[i], fieldtype(kwt, i)) for i in eachindex(keys)] + end + + Base.show_signature_function(io, ft) + Base.show_tuple_as_call(io, :function, Tuple{arg_types_param...}; hasfirst=false, kwargs = isempty(kwargs) ? nothing : kwargs) + + m = mi.def + + modulecolor = :light_black + tv, decls, file, line = Base.arg_decl_parts(m) + #if m.sig <: Tuple{Core.Builtin, Vararg} + # file = "none" + # line = 0 + #end + + if !(get(io, :compact, false)::Bool) # single-line mode + println(io) + digit_align_width += 4 + end + + # module & file, re-using function from errorshow.jl + Base.print_module_path_file(io, Base.parentmodule(m), string(file), line; modulecolor, digit_align_width) +end + +using InteractiveUtils + +function code_typed_helper(mi::Core.MethodInstance, world::UInt, mode::Enzyme.API.CDerivativeMode = Enzyme.API.DEM_ReverseModeCombined; interactive::Bool=false, kwargs...) + CT = @static if VERSION >= v"1.11.0-DEV.1552" + EnzymeCacheToken( + typeof(DefaultCompilerTarget()), + false, + GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# + EnzymeCompilerParams, + world, + mode == API.DEM_ForwardMode, + mode != API.DEM_ForwardMode, + true + ) + else + if mode == API.DEM_ForwardMode + GLOBAL_FWD_CACHE + else + GLOBAL_REV_CACHE + end + end + + interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true) + + sig = mi.specTypes # XXX: can we just use the method instance? + if interactive + # call Cthulhu without introducing a dependency on Cthulhu + mod = get(Base.loaded_modules, Cthulhu, nothing) + mod===nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.") + descend_code_typed = getfield(mod, :descend_code_typed) + descend_code_typed(sig; interp, kwargs...) + else + Base.code_typed_by_type(sig; interp, kwargs...) + end +end + struct EnzymeRuntimeException <: EnzymeError msg::Cstring end @@ -19,11 +90,527 @@ function Base.showerror(io::IO, ece::EnzymeRuntimeException) if isdefined(Base.Experimental, :show_error_hints) Base.Experimental.show_error_hints(io, ece) end - print(io, "Enzyme execution failed.\n") + print(io, "EnzymeRuntimeException: Enzyme execution failed.\n") msg = Base.unsafe_string(ece.msg) print(io, msg, '\n') end +struct EnzymeRuntimeExceptionMI <: EnzymeError + backtrace::Cstring + mi::Core.MethodInstance + world::UInt +end + +InteractiveUtils.code_typed(ece::EnzymeRuntimeExceptionMI; kwargs...) = code_typed_helper(ece.mi, ece.world; kwargs...) + +function Base.showerror(io::IO, ece::EnzymeRuntimeExceptionMI) + if isdefined(Base.Experimental, :show_error_hints) + Base.Experimental.show_error_hints(io, ece) + end + print(io, "EnzymeRuntimeException: Enzyme execution failed within\n") + println(io) + pretty_print_mi(ece.mi, io) + println(io) + println(io) + printstyled(io, "Hint"; bold = true, color = :cyan) + printstyled( + io, + ": catch this exception as `err` and call `code_typed(err)` to inspect the surrounding code.\n"; + color = :cyan, + ) + println(io) + msg = Base.unsafe_string(ece.msg) + print(io, msg, '\n') +end + +abstract type CustomRuleError <: Base.Exception end + +struct NonConstantKeywordArgException <: CustomRuleError + backtrace::Cstring + mi::Core.MethodInstance + world::UInt +end + +InteractiveUtils.code_typed(ece::NonConstantKeywordArgException; kwargs...) = code_typed_helper(ece.mi, ece.world; kwargs...) + +function Base.showerror(io::IO, ece::NonConstantKeywordArgException) + if isdefined(Base.Experimental, :show_error_hints) + Base.Experimental.show_error_hints(io, ece) + end + print(io, "NonConstantKeywordArgException: Custom Rule for method was passed a differentiable keyword argument. Differentiable kwargs cannot currently be specified from within the rule system.\n") + printstyled(io, "Hint"; bold = true, color = :cyan) + printstyled( + io, + ": Experimental utility Enzyme.EnzymeRules.inactive_kwarg will enable you to mark the keyword arguments as non-differentiable, if that is correct."; + color = :cyan, + ) + println(io) + println(io) + pretty_print_mi(ece.mi, io) + println(io) + Base.println(io, Base.unsafe_string(ece.backtrace)) +end + +struct CallingConventionMismatchError <: CustomRuleError + backtrace::Cstring + mi::Core.MethodInstance + world::UInt +end + +function Base.showerror(io::IO, ece::CallingConventionMismatchError) + if isdefined(Base.Experimental, :show_error_hints) + Base.Experimental.show_error_hints(io, ece) + end + print(io, "CallingConventionMismatchError: Enzyme hit an internal error trying to parse the julia calling convention from a custom rule definition:\n") + println(io) + pretty_print_mi(ece.mi, io) + println(io) + println(io) + printstyled(io, "Hint"; bold = true, color = :cyan) + printstyled( + io, + ": catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.\n"; + color = :cyan, + ) + println(io) + + Base.println(io, Base.unsafe_string(ece.backtrace)) +end + +InteractiveUtils.code_typed(ece::CallingConventionMismatchError; kwargs...) = code_typed_helper(ece.mi, ece.world; kwargs...) + +struct ForwardRuleReturnError{C, RT, fwd_RT} <: CustomRuleError + backtrace::Cstring + mi::Core.MethodInstance + world::UInt +end + +InteractiveUtils.code_typed(ece::ForwardRuleReturnError; kwargs...) = code_typed_helper(ece.mi, ece.world, Enzyme.API.DEM_ForwardMode; kwargs...) + +function Base.showerror(io::IO, ece::ForwardRuleReturnError{C, RT, fwd_RT}) where {C, RT, fwd_RT} + ExpRT = EnzymeRules.forward_rule_return_type(C, RT) + @assert ExpRT != fwd_RT + if isdefined(Base.Experimental, :show_error_hints) + Base.Experimental.show_error_hints(io, ece) + end + + RealRt = eltype(RT) + + hint = nothing + + width = EnzymeRules.width(C) + + desc = if EnzymeRules.needs_primal(C) && EnzymeRules.needs_shadow(C) + if width == 1 + if !(fwd_RT <: Duplicated) + if fwd_RT <: BatchDuplicated + hint = "For width 1, the return type should be a Duplicated, not BatchDuplicated" + elseif fwd_RT <: RealRt + hint = "Both primal and shadow need to be returned" + else + hint = "Return type should be a Duplicated" + end + elseif eltype(fwd_RT) <: RealRt + hint = "Expected the abstract type $RealRt for primal/shadow, you returned $(eltype(fwd_RT)). Even though $(eltype(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})." + else + hint = "The type within your Duplicated $(eltype(fwd_RT)) does not match the primal type $RealRt" + end + else + if !(fwd_RT <: BatchDuplicated) + if fwd_RT <: BatchDuplicated && EnzymeCore.batch_size(fwd_RT) != width + hint = "Mismatched batch size, expected batch size $width, found a BatchDuplicated of width $(EnzymeCore.batch_size(fwd_RT))" + elseif fwd_RT <: Duplicated + hint = "For width $width, the return type should be a BatchDuplicated, not a Duplicated" + elseif fwd_RT <: RealRt + hint = "Both primal and shadow need to be returned" + else + hint = "Return type should be a BatchDuplicated" + end + elseif eltype(fwd_RT) <: RealRt + hint = "Expected the abstract type $RealRt for primal/shadow, you returned $(eltype(fwd_RT)). Even though $(eltype(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})." + else + hint = "The type within your BatchDuplicated $(eltype(fwd_RT)) does not match the primal type $RealRt" + end + end + "primal and shadow configuration" + elseif EnzymeRules.needs_primal(C) && !EnzymeRules.needs_shadow(C) + if fwd_RT <: BatchDuplicated || fwd_RT <: Duplicated + hint = "Shadow was not requested, you should only return the primal" + elseif fwd_RT <: (NTuple{N, <:RealRt} where N) + hint = "You appear to be returning a tuple of shadows, but only the primal was requested" + elseif fwd_RT <: RealRt + hint = "Expected the abstract type $RealRt for primal, you returned $(fwd_RT). Even though $(fwd_RT) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})." + else + hint = "Your return type does not match the primal type $RealRt" + end + + "primal-only configuration" + elseif !EnzymeRules.needs_primal(C) && EnzymeRules.needs_shadow(C) + if fwd_RT <: BatchDuplicated || fwd_RT <: Duplicated + hint = "Primal was not requested, you should only return the shadow" + elseif width == 1 + if fwd_RT <: (NTuple{N, <:RealRt} where N) + hint = "You look to be returning a tuple of shadows, when the batch size is 1" + elseif fwd_RT <: RealRt + hint = "Expected the abstract type $RealRt for shadow, you returned $(fwd_RT). Even though $(fwd_RT) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})." + else + hint = "Your return type does not match the shadow type $RealRt" + end + else + if !(fwd_RT <: NTuple) + hint = "Configuration required batch size $width, which requires returning a tuple of shadows" + elseif !(fwd_RT <: NTuple{width, <:Any}) + hint = "Did not return a tuple of shadows of the right size, expected a tuple of size $width" + elseif eltype(fwd_RT) <: RealRt + hint = "Expected the abstract type $RealRt for each shadow in the tuple to create $ExpRT, you returned $(eltype(fwd_RT)) as the eltype of your tuple ($fwd_RT). Even though $(eltype(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})." + else + hint = "Your return type does not match the batched shadow type $ExpRT" + end + end + "shadow-only configuration" + else + @assert !EnzymeRules.needs_primal(C) && !EnzymeRules.needs_shadow(C) + + if fwd_RT <: BatchDuplicated || fwd_RT <: Duplicated + hint = "Neither primal nor shadow were requested, you should return nothing, not both the primal and shadow" + elseif fwd_RT <: (NTuple{N, <:RealRt} where N) + hint = "You appear to be returning a tuple of shadows, but neither primal nor shadow were requested" + elseif fwd_RT <: RealRt && width == 1 + hint = "You appear to be returning a primal or shadow, but neither were requested" + elseif fwd_RT <: RealRt + hint = "You appear to be returning a primal, but it was not requested" + else + hint = "You should return nothing" + end + + "neither primal nor shadow configuration" + end + + print(io, "ForwardRuleReturnError: Incorrect return type for $desc of forward custom rule with width $width of a function which returned $(eltype(RealRt)):\n") + print(io, " found : ", fwd_RT, "\n") + print(io, " expected : ", ExpRT, "\n") + println(io) + print(io, "For more information see `EnzymeRules.forward_rule_return_type`\n") + println(io) + printstyled(io, "Hint"; bold = true, color = :cyan) + printstyled( + io, + ": ", hint; + color = :cyan, + ) + println(io) + println(io) + printstyled(io, "Hint"; bold = true, color = :cyan) + printstyled( + io, + ": if the reason for the return type is unclear, you can catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.\n"; + color = :cyan, + ) + println(io) + pretty_print_mi(ece.mi, io) + println(io) + Base.println(io, Base.unsafe_string(ece.backtrace)) +end + + +struct AugmentedRuleReturnError{C, RT, aug_RT} <: CustomRuleError + backtrace::Cstring + mi::Core.MethodInstance + world::UInt +end + +InteractiveUtils.code_typed(ece::AugmentedRuleReturnError; kwargs...) = code_typed_helper(ece.mi, ece.world; kwargs...) + +function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) where {C, RT, fwd_RT} + ExpRT = EnzymeRules.augmented_rule_return_type(C, RT, Any) + @assert ExpRT != fwd_RT + if isdefined(Base.Experimental, :show_error_hints) + Base.Experimental.show_error_hints(io, ece) + end + + width = EnzymeRules.width(C) + + RealRt = eltype(RT) + + primal_found = nothing + shadow_found = nothing + + hint = nothing + + desc = if EnzymeRules.needs_primal(C) && EnzymeRules.needs_shadow(C) + if !(fwd_RT <: EnzymeRules.AugmentedReturn) + hint = "Return should be a struct of type EnzymeRules.AugmentedReturn" + elseif fwd_RT isa UnionAll && (fwd_RT.body isa UnionAll || fwd_RT.body.parameters[1] isa TypeVar || fwd_RT.body.parameters[2] isa TypeVar) + hint = "Return is a UnionAll, not a concrete type, try explicitly returning a single value of type EnzymeRules.AugmentedReturn{PrimalType, ShadowType, CacheType} as follows\n return EnzymeRules.augmented_rule_return_type(config, RA)(primal, shadow, cache)" + elseif EnzymeRules.primal_type(fwd_RT) == Nothing + hint = "Missing primal return" + elseif EnzymeRules.shadow_type(fwd_RT) == Nothing + hint = "Missing shadow return" + elseif EnzymeRules.primal_type(fwd_RT) != RealRt + if EnzymeRules.primal_type(fwd_RT) <: RealRt + hint = "Expected the abstract type $RealRt for primal, you returned $(EnzymeRules.primal_type(fwd_RT)). Even though $(EnzymeRules.primal_type(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})." + else + hint = "Mismatched primal type $(EnzymeRules.sprimal_type(fwd_RT)), expected $RealRt" + end + elseif EnzymeRules.shadow_type(fwd_RT) != RealRt + if width == 1 + if EnzymeRules.shadow_type(fwd_RT) <: RealRt + hint = "Expected the abstract type $RealRt for shadow, you returned $(EnzymeRules.shadow_type(fwd_RT)). Even though $(EnzymeRules.shadow_type(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})." + elseif shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where N) + hint = "Batch size was 1, expected a single shadow, not a tuple of shadows." + else + hint = "Mismatched shadow type $(EnzymeRules.shadow_type(fwd_RT)), expected $(EnzymeRules.shadow_type(ExpRT))." + end + else + if EnzymeRules.shadow_type(fwd_RT) <: RealRt + hint = "Batch size was $width, expected a tuple of shadows, not a single shadow." + elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where N) + hint = "Expected the abstract type $RealRt for the element shadow type (for a batched shadow type $(EnzymeRules.shadow_type(ExpRT))), you returned $(eltype(EnzymeRules.shadow_type(fwd_RT))) as the element shadow type (batched to become $(EnzymeRules.shadow_type(fwd_RT)). Even though $(eltype(EnzymeRules.shadow_type(fwd_RT))) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})." + else + hint = "Mismatched shadow type $(EnzymeRules.shadow_type(fwd_RT)), expected $(EnzymeRules.shadow_type(ExpRT))." + end + end + end + + "primal and shadow configuration" + elseif EnzymeRules.needs_primal(C) && !EnzymeRules.needs_shadow(C) + if !(fwd_RT <: EnzymeRules.AugmentedReturn) + hint = "Return should be a struct of type EnzymeRules.AugmentedReturn" + elseif EnzymeRules.primal_type(fwd_RT) == Nothing + hint = "Missing primal return" + elseif EnzymeRules.shadow_type(fwd_RT) != Nothing + hint = "Shadow return was not requested" + elseif EnzymeRules.primal_type(fwd_RT) != RealRt + if EnzymeRules.primal_type(fwd_RT) <: RealRt + hint = "Expected the abstract type $RealRt for primal, you returned $(EnzymeRules.primal_type(fwd_RT)). Even though $(EnzymeRules.primal_type(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})." + else + hint = "Mismatched primal type $(EnzymeRules.primal_type(fwd_RT)), expected $RealRt" + end + end + + "primal-only configuration" + elseif !EnzymeRules.needs_primal(C) && EnzymeRules.needs_shadow(C) + + if !(fwd_RT <: EnzymeRules.AugmentedReturn) + hint = "Return should be a struct of type EnzymeRules.AugmentedReturn" + elseif EnzymeRules.primal_type(fwd_RT) != Nothing + hint = "Primal was not requested" + elseif EnzymeRules.shadow_type(fwd_RT) != RealRt + if width == 1 + if EnzymeRules.shadow_type(fwd_RT) <: RealRt + hint = "Expected the abstract type $RealRt for shadow, you returned $(EnzymeRules.shadow_type(fwd_RT)). Even though $(EnzymeRules.shadow_type(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})." + elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where N) + hint = "Batch size was 1, expected a single shadow, not a tuple of shadows." + else + hint = "Mismatched shadow type $(EnzymeRules.shadow_type(fwd_RT)), expected $(EnzymeRules.shadow_type(ExpRT))." + end + else + if EnzymeRules.shadow_type(fwd_RT) <: RealRt + hint = "Batch size was $width, expected a tuple of shadows, not a single shadow." + elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where N) + hint = "Expected the abstract type $RealRt for the element shadow type (for a batched shadow type $(EnzymeRules.shadow_type(ExpRT))), you returned $(eltype(EnzymeRules.shadow_type(fwd_RT))) as the element shadow type (batched to become $(EnzymeRules.shadow_type(fwd_RT)). Even though $(eltype(EnzymeRules.shadow_type(fwd_RT))) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})." + else + hint = "Mismatched shadow type $(EnzymeRules.shadow_type(fwd_RT)), expected $(EnzymeRules.shadow_type(ExpRT))." + end + end + end + + "shadow-only configuration" + else + if !(fwd_RT <: EnzymeRules.AugmentedReturn) + hint = "Return should be a struct of type EnzymeRules.AugmentedReturn" + elseif EnzymeRules.primal_type(fwd_RT) != Nothing + hint = "Primal was not requested" + elseif EnzymeRules.shadow_type(fwd_RT) != Nothing + hint = "Shadow return was not requested" + end + + @assert !EnzymeRules.needs_primal(C) && !EnzymeRules.needs_shadow(C) + "neither primal nor shadow configuration" + end + + print(io, "AugmentedRuleReturnError: Incorrect return type for $desc of augmented_primal custom rule with width $width of a function which returned $(eltype(RealRt)):\n") + print(io, " found : ", fwd_RT, "\n") + print(io, " expected : ", ExpRT, "\n") + println(io) + print(io, "For more information see `EnzymeRules.augmented_rule_return_type`\n") + println(io) + if hint !== nothing + printstyled(io, "Hint"; bold = true, color = :cyan) + printstyled( + io, + ": ", hint; + color = :cyan, + ) + println(io) + end + println(io) + printstyled(io, "Hint"; bold = true, color = :cyan) + printstyled( + io, + ": if the reason for the return type is unclear, you can catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.\n"; + color = :cyan, + ) + println(io) + pretty_print_mi(ece.mi, io) + println(io) + Base.println(io, Base.unsafe_string(ece.backtrace)) +end + + +struct ReverseRuleReturnError{C, ArgAct, rev_RT} <: CustomRuleError + backtrace::Cstring + mi::Core.MethodInstance + world::UInt +end + +InteractiveUtils.code_typed(ece::ReverseRuleReturnError; kwargs...) = code_typed_helper(ece.mi, ece.world; kwargs...) + +function Base.showerror(io::IO, ece::ReverseRuleReturnError{C, ArgAct, rev_RT}) where {C, ArgAct, rev_RT} + width = EnzymeRules.width(C) + Tys = ( + A <: Active ? (width == 1 ? eltype(A) : NTuple{Int(width),eltype(A)}) : Nothing for A in ArgAct.parameters + ) + ExpRT = Tuple{Tys...} + @assert ExpRT != rev_RT + if isdefined(Base.Experimental, :show_error_hints) + Base.Experimental.show_error_hints(io, ece) + end + + hint = nothing + + if !(rev_RT <: Tuple) + hint = "Return type should be a tuple with one element for each argument" + elseif length(rev_RT.parameters) != length(ExpRT.parameters) + hint = "Returned tuple should have one result for each argument, had $(length(rev_RT.parameters)) elements, expected $(length(ExpRT.parameters))" + else + for i in 1:length(ArgAct.parameters) + if ExpRT.parameters[i] == rev_RT.parameters[i] + continue + end + if ExpRT.parameters[i] === Nothing + hint = "Tuple return mismatch at index $i, argument of type $(ArgAct.parameters[i]) corresponds to return of nothing (only Active inputs have returns)" + break + end + + if rev_RT.parameters[i] === Nothing + hint = "Tuple return mismatch at index $i, argument of type $(ArgAct.parameters[i]) corresponds to return of $(ExpRT.parameters[i]), found nothing (Active inputs have returns)" + break + end + + if width == 1 + + if rev_RT.parameters[i] <: (NTuple{N, ExpRT.parameters[i]} where N) + hint = "Tuple return mismatch at index $i, returned a tuple of results when expected just one of type $(ExpRT.parameters[i])." + break + end + + if rev_RT.parameters[i] <: ExpRT.parameters[i] + hint = "Tuple return mismatch at index $i, expected the abstract type $(ExpRT.parameters[i]), you returned $(rev_RT.parameters[i]). Even though $(rev_RT.parameters[i]) <: $(ExpRT.parameters[i]), rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})." + break + end + + else + + if !(rev_RT.parameters[i] <: NTuple) + hint = "Tuple return mismatch at index $i, returned a single result of type $(rev_RT.parameters[i]) for a batched configuration of width $width, expected an inner tuple for each batch element." + break + end + + if eltype(rev_RT.parameters[i]) <: eltype(ExpRT.parameters[i]) + hint = "Tuple return mismatch at index $i, expected the abstract type $(eltype(ExpRT.parameters[i])) (here batched to form $(ExpRT.parameters[i])), you returned $(eltype(rev_RT.parameters[i])) (batched to form $(eltype(rev_RT.parameters[i]))). Even though $(eltype(rev_RT.parameters[i])) <: $(eltype(ExpRT.parameters[i])), rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})." + break + end + end + + hint = "Tuple return mismatch at index $i, argument of type $(ArgAct.parameters[i]) corresponds to returning type $(ExpRT.parameters[i]), you returned $(rev_RT.parameters[i])." + break + end + end + @assert hint !== nothing + + print(io, "ReverseRuleReturnError: Incorrect return type for reverse custom rule with width $(EnzymeRules.width(C)):\n") + print(io, " found : ", rev_RT, "\n") + print(io, " expected : ", ExpRT, "\n") + println(io) + printstyled(io, "Hint"; bold = true, color = :cyan) + printstyled( + io, + ": ", hint; + color = :cyan, + ) + println(io) + println(io) + printstyled(io, "Hint"; bold = true, color = :cyan) + printstyled( + io, + ": if the reason for the return type is unclear, you can catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.\n"; + color = :cyan, + ) + println(io) + pretty_print_mi(ece.mi, io) + println(io) + Base.println(io, Base.unsafe_string(ece.backtrace)) +end + +struct MixedReturnException{RT} <: CustomRuleError + backtrace::Cstring + mi::Core.MethodInstance + world::UInt +end + +InteractiveUtils.code_typed(ece::MixedReturnException; kwargs...) = code_typed_helper(ece.mi, ece.world; kwargs...) + +function Base.showerror(io::IO, ece::MixedReturnException{RT}) where RT + if isdefined(Base.Experimental, :show_error_hints) + Base.Experimental.show_error_hints(io, ece) + end + print(io, "MixedReturnException: Custom Rule for method returns type $(RT), which has mixed internal activity types. This is not presently supported.\n") + print(io, "See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information.\n") + println(io) + printstyled(io, "Hint"; bold = true, color = :cyan) + printstyled( + io, + ": if the reason for the return type is unclear, you can catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.\n"; + color = :cyan, + ) + println(io) + pretty_print_mi(ece.mi, io) + println(io) + Base.println(io, Base.unsafe_string(ece.backtrace)) +end + + +struct UnionSretReturnException{RT} <: CustomRuleError + backtrace::Cstring + mi::Core.MethodInstance + world::UInt +end + +InteractiveUtils.code_typed(ece::UnionSretReturnException; kwargs...) = code_typed_helper(ece.mi, ece.world; kwargs...) + +function Base.showerror(io::IO, ece::UnionSretReturnException{RT}) where RT + if isdefined(Base.Experimental, :show_error_hints) + Base.Experimental.show_error_hints(io, ece) + end + print(io, "UnionSretReturnException: Custom Rule for method returns type $(RT), which is a union has an sret layout calling convention. This is not presently supported.\n") + print(io, "Please open an issue if you hit this.") + println(io) + printstyled(io, "Hint"; bold = true, color = :cyan) + printstyled( + io, + ": if the reason for the return type is unclear, you can catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.\n"; + color = :cyan, + ) + println(io) + println(io) + pretty_print_mi(ece.mi, io) + println(io) + Base.println(io, Base.unsafe_string(ece.backtrace)) +end + + + struct NoDerivativeException <: CompilationException msg::String ir::Union{Nothing,String} @@ -34,7 +621,7 @@ function Base.showerror(io::IO, ece::NoDerivativeException) if isdefined(Base.Experimental, :show_error_hints) Base.Experimental.show_error_hints(io, ece) end - print(io, "Enzyme compilation failed.\n") + print(io, "NoDerivativeException: Enzyme compilation failed.\n") if ece.ir !== nothing if VERBOSE_ERRORS[] print(io, "Current scope: \n") @@ -71,7 +658,7 @@ function Base.showerror(io::IO, ece::IllegalTypeAnalysisException) if isdefined(Base.Experimental, :show_error_hints) Base.Experimental.show_error_hints(io, ece) end - print(io, "Enzyme compilation failed due to illegal type analysis.\n") + print(io, "IllegalTypeAnalysisException: Enzyme compilation failed due to illegal type analysis.\n") print(io, " This usually indicates the use of a Union type, which is not fully supported with Enzyme.API.strictAliasing set to true [the default].\n") print(io, " Ideally, remove the union (which will also make your code faster), or try setting Enzyme.API.strictAliasing!(false) before any autodiff call.\n") print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n") @@ -100,43 +687,14 @@ function Base.showerror(io::IO, ece::IllegalTypeAnalysisException) end end -using InteractiveUtils - -function InteractiveUtils.code_typed(ece::IllegalTypeAnalysisException; interactive::Bool=false, kwargs...) +function InteractiveUtils.code_typed(ece::IllegalTypeAnalysisException; kwargs...) mi = ece.mi if mi === nothing throw(AssertionError("code_typed(::IllegalTypeAnalysisException; interactive::Bool=false, kwargs...) not supported for error without mi")) end world = ece.world::UInt mode = Enzyme.API.DEM_ReverseModeCombined - - CT = @static if VERSION >= v"1.11.0-DEV.1552" - EnzymeCacheToken( - typeof(DefaultCompilerTarget()), - false, - GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# - EnzymeCompilerParams, - world, - false, - true, - true - ) - else - Enzyme.Compiler.GLOBAL_REV_CACHE - end - - interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true) - - sig = mi.specTypes # XXX: can we just use the method instance? - if interactive - # call Cthulhu without introducing a dependency on Cthulhu - mod = get(Base.loaded_modules, Cthulhu, nothing) - mod===nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.") - descend_code_typed = getfield(mod, :descend_code_typed) - descend_code_typed(sig; interp, kwargs...) - else - Base.code_typed_by_type(sig; interp, kwargs...) - end + code_typed_helper(ece.mi, ece.world; kwargs...) end struct IllegalFirstPointerException <: CompilationException @@ -149,7 +707,7 @@ function Base.showerror(io::IO, ece::IllegalFirstPointerException) if isdefined(Base.Experimental, :show_error_hints) Base.Experimental.show_error_hints(io, ece) end - print(io, "Enzyme compilation failed due to an internal error (first pointer exception).\n") + print(io, "IllegalFirstPointerException: Enzyme compilation failed due to an internal error (first pointer exception).\n") print(io, " Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl\n") print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n") if VERBOSE_ERRORS[] @@ -175,7 +733,7 @@ function Base.showerror(io::IO, ece::EnzymeInternalError) if isdefined(Base.Experimental, :show_error_hints) Base.Experimental.show_error_hints(io, ece) end - print(io, "Enzyme compilation failed due to an internal error.\n") + print(io, "EnzymeInternalError: Enzyme compilation failed due to an internal error.\n") print(io, " Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl\n") print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n") if VERBOSE_ERRORS[] @@ -206,7 +764,7 @@ function Base.showerror(io::IO, ece::EnzymeMutabilityException) Base.Experimental.show_error_hints(io, ece) end msg = Base.unsafe_string(ece.msg) - print(io, msg, '\n') + print(io, "EnzymeMutabilityException: ", msg, '\n') end struct EnzymeRuntimeActivityError{MT,WT} <: EnzymeError @@ -219,11 +777,8 @@ function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) if isdefined(Base.Experimental, :show_error_hints) Base.Experimental.show_error_hints(io, ece) end - println(io, "Constant memory is stored (or returned) to a differentiable variable.") - println( - io, - "As a result, Enzyme cannot provably ensure correctness and throws this error.", - ) + println(io, "EnzymeRuntimeActivityError: Detected potential need for runtime activity.\n") + println(io, "Constant memory is stored (or returned) to a differentiable variable and correctness cannot be guaranteed with static activity analysis.") println( io, "This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).", @@ -232,24 +787,32 @@ function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) io, "If Enzyme should be able to prove this use non-differentable, open an issue!", ) + println(io) println(io, "To work around this issue, either:") println( io, - " a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or", + " a) rewrite this variable to not be conditionally active (fastest performance, slower to setup), or", ) println( io, - " b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.", + " b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.", ) + println(io) if ece.mi !== nothing - print(io, " Failure within method: ", ece.mi, "\n") + print(io, "Failure within method:\n") + println(io) + pretty_print_mi(ece.mi, io) + println(io) + println(io) + printstyled(io, "Hint"; bold = true, color = :cyan) printstyled( io, - ": catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.\nIf you have Cthulu.jl loaded you can also use `code_typed(err; interactive = true)` to interactively introspect the code.\n"; + ": catch this exception as `err` and call `code_typed(err)` to inspect the surrounding code.\n"; color = :cyan, ) end + println(io) msg = Base.unsafe_string(ece.msg) print(io, msg, '\n') end @@ -259,36 +822,7 @@ function InteractiveUtils.code_typed(ece::EnzymeRuntimeActivityError; interactiv if mi === nothing throw(AssertionError("code_typed(::EnzymeRuntimeActivityError; interactive::Bool=false, kwargs...) not supported for error without mi")) end - world = ece.world::UInt - mode = Enzyme.API.DEM_ReverseModeCombined - - CT = @static if VERSION >= v"1.11.0-DEV.1552" - EnzymeCacheToken( - typeof(DefaultCompilerTarget()), - false, - GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# - EnzymeCompilerParams, - world, - false, - true, - true - ) - else - Enzyme.Compiler.GLOBAL_REV_CACHE - end - - interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true) - - sig = mi.specTypes # XXX: can we just use the method instance? - if interactive - # call Cthulhu without introducing a dependency on Cthulhu - mod = get(Base.loaded_modules, Cthulhu, nothing) - mod===nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.") - descend_code_typed = getfield(mod, :descend_code_typed) - descend_code_typed(sig; interp, kwargs...) - else - Base.code_typed_by_type(sig; interp, kwargs...) - end + code_typed_helper(ece.mi, ece.world; kwargs...) end struct EnzymeNoTypeError{MT,WT} <: EnzymeError @@ -301,7 +835,7 @@ function Base.showerror(io::IO, ece::EnzymeNoTypeError) if isdefined(Base.Experimental, :show_error_hints) Base.Experimental.show_error_hints(io, ece) end - print(io, "Enzyme cannot statically prove the type of a value being differentiated and risks a correctness error if it gets it wrong.\n") + print(io, "EnzymeNoTypeError: Enzyme cannot statically prove the type of a value being differentiated and risks a correctness error if it gets it wrong.\n") print(io, " Generally this shouldn't occur as Enzyme records type information from julia, but may be expected if you, for example copy untyped data.\n") print(io, " or alternatively emit very large sized registers that exceed the maximum size of Enzyme's type analysis. If it seems reasonable to differentiate\n") print(io, " this code, open an issue! If the cause of the error is too large of a register, you can request Enzyme increase the size (https://enzyme.mit.edu/julia/dev/api/#Enzyme.API.maxtypeoffset!-Tuple{Any})\n") @@ -318,7 +852,7 @@ function Base.showerror(io::IO, ece::EnzymeNoTypeError) printstyled(io, "Hint"; bold = true, color = :cyan) printstyled( io, - ": catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.\nIf you have Cthulu.jl loaded you can also use `code_typed(err; interactive = true)` to interactively introspect the code.\n"; + ": catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.\n"; color = :cyan, ) end @@ -329,36 +863,7 @@ function InteractiveUtils.code_typed(ece::EnzymeNoTypeError; interactive::Bool=f if mi === nothing throw(AssertionError("code_typed(::EnzymeNoTypeError; interactive::Bool=false, kwargs...) not supported for error without mi")) end - world = ece.world::UInt - mode = Enzyme.API.DEM_ReverseModeCombined - - CT = @static if VERSION >= v"1.11.0-DEV.1552" - EnzymeCacheToken( - typeof(DefaultCompilerTarget()), - false, - GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# - EnzymeCompilerParams, - world, - false, - true, - true - ) - else - Enzyme.Compiler.GLOBAL_REV_CACHE - end - - interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true) - - sig = mi.specTypes # XXX: can we just use the method instance? - if interactive - # call Cthulhu without introducing a dependency on Cthulhu - mod = get(Base.loaded_modules, Cthulhu, nothing) - mod===nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.") - descend_code_typed = getfield(mod, :descend_code_typed) - descend_code_typed(sig; interp, kwargs...) - else - Base.code_typed_by_type(sig; interp, kwargs...) - end + code_typed_helper(ece.mi, ece.world; kwargs...) end struct EnzymeNoShadowError <: EnzymeError @@ -369,13 +874,23 @@ function Base.showerror(io::IO, ece::EnzymeNoShadowError) if isdefined(Base.Experimental, :show_error_hints) Base.Experimental.show_error_hints(io, ece) end - print(io, "Enzyme could not find shadow for value\n") + print(io, "EnzymeNoShadowError: Enzyme could not find shadow for value\n") msg = Base.unsafe_string(ece.msg) print(io, msg, '\n') end -struct EnzymeNoDerivativeError <: EnzymeError +struct EnzymeNoDerivativeError{MT,WT} <: EnzymeError msg::Cstring + mi::MT + world::WT +end + +function InteractiveUtils.code_typed(ece::EnzymeNoDerivativeError; interactive::Bool=false, kwargs...) + mi = ece.mi + if mi === nothing + throw(AssertionError("code_typed(::EnzymeNoDerivativeError; interactive::Bool=false, kwargs...) not supported for error without mi")) + end + code_typed_helper(ece.mi, ece.world; kwargs...) end function Base.showerror(io::IO, ece::EnzymeNoDerivativeError) @@ -383,7 +898,22 @@ function Base.showerror(io::IO, ece::EnzymeNoDerivativeError) Base.Experimental.show_error_hints(io, ece) end msg = Base.unsafe_string(ece.msg) - print(io, msg, '\n') + print(io, "EnzymeNoDerivativeError: ", msg, '\n') + + if ece.mi !== nothing + print(io, "Failure within method:\n") + println(io) + pretty_print_mi(ece.mi, io) + println(io) + println(io) + + printstyled(io, "Hint"; bold = true, color = :cyan) + printstyled( + io, + ": catch this exception as `err` and call `code_typed(err)` to inspect the surrounding code.\n"; + color = :cyan, + ) + end end parent_scope(val::LLVM.Function, depth = 0) = depth == 0 ? LLVM.parent(val) : val @@ -453,7 +983,8 @@ function julia_error( if occursin("No create nofree of empty function", msg) || occursin("No forward mode derivative found for", msg) || occursin("No augmented forward pass", msg) || - occursin("No reverse pass found", msg) + occursin("No reverse pass found", msg) || + occursin("Runtime Activity not yet implemented for Forward-Mode BLAS", msg) ir = nothing end if B != C_NULL @@ -477,7 +1008,31 @@ function julia_error( else data2 = nothing end - emit_error(B, nothing, msg2, EnzymeNoDerivativeError, data2) + + mi = nothing + world = nothing + + if isa(val, LLVM.Instruction) + f = LLVM.parent(LLVM.parent(val))::LLVM.Function + mi, rt = enzyme_custom_extract_mi( + f, + false, + ) #=error=# + world = enzyme_extract_world(f) + elseif isa(val, LLVM.Argument) + f = parent_scope(val)::LLVM.Function + mi, rt = enzyme_custom_extract_mi( + f, + false, + ) #=error=# + world = enzyme_extract_world(f) + end + if mi !== nothing + emit_error(B, nothing, (msg2, mi, world), EnzymeNoDerivativeError{Core.MethodInstance, UInt}, data2) + else + emit_error(B, nothing, msg2, EnzymeNoDerivativeError{Nothing, Nothing}, data2) + end + return C_NULL end throw(NoDerivativeException(msg, ir, bt)) @@ -1003,7 +1558,7 @@ end print(io, msg) println(io) if badval !== nothing - println(io, " value=" * badval) + println(io, " Julia value causing error: " * badval) else ttval = val if isa(ttval, LLVM.StoreInst) @@ -1016,7 +1571,7 @@ end API.EnzymeStringFree(st) end if illegalVal !== nothing - println(io, " llvalue=" * string(illegalVal)) + println(io, " LLVM view of erring value: " * string(illegalVal)) end if bt !== nothing Base.show_backtrace(io, bt) @@ -1076,9 +1631,14 @@ function Base.showerror(io::IO, ece::EnzymeNonScalarReturnException) if isdefined(Base.Experimental, :show_error_hints) Base.Experimental.show_error_hints(io, ece) end - println(io, "Return type of differentiated function was not a scalar as required, found ", ece.object) - println(io, "If calling Enzyme.autodiff(Reverse, f, Active, ...), try Enzyme.autodiff_thunk(Reverse, f, Duplicated, ....)") - println(io, "If calling Enzyme.gradient, try Enzyme.jacobian") + if Enzyme.Compiler.guaranteed_const(Core.Typeof(ece.object)) + println(io, "EnzymeNonScalarReturnException: Return type of active-returning differentiated function was not differentiable, found ", ece.object, " of type ", Core.Typeof(ece.object)) + println(io, "Either rewrite the autodiff call to return Const, or the function being differentiated to return an active type") + else + println(io, "EnzymeNonScalarReturnException: Return type of differentiated function was not a scalar as required, found ", ece.object, " of type ", Core.Typeof(ece.object)) + println(io, "If calling Enzyme.autodiff(Reverse, f, Active, ...), try Enzyme.autodiff_thunk(Reverse, f, Duplicated, ....)") + println(io, "If calling Enzyme.gradient, try Enzyme.jacobian") + end if length(ece.extra) != 0 print(io, ece.extra) end diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index 1ce499d91c..eccd8f5d25 100644 --- a/src/rules/activityrules.jl +++ b/src/rules/activityrules.jl @@ -1,5 +1,5 @@ -function julia_activity_rule(f::LLVM.Function) +function julia_activity_rule(f::LLVM.Function, method_table) if startswith(LLVM.name(f), "japi3") || startswith(LLVM.name(f), "japi1") return end @@ -57,6 +57,16 @@ function julia_activity_rule(f::LLVM.Function) parmsRemoved, ) + kwarg_inactive = false + + if isKWCallSignature(mi.specTypes) + if EnzymeRules.is_inactive_kwarg_from_sig(Interpreter.simplify_kw(mi.specTypes); world, method_table) + kwarg_inactive = true + end + end + + + if !Enzyme.Compiler.no_type_setting(mi.specTypes; world)[1] any_active = false for arg in jlargs @@ -69,13 +79,13 @@ function julia_activity_rule(f::LLVM.Function) typ, _ = enzyme_extract_parm_type(f, arg.codegen.i) @assert typ == arg.typ - if guaranteed_const_nongen(arg.typ, world) + if (kwarg_inactive && arg.arg_i == 2) || guaranteed_const_nongen(arg.typ, world) push!( parameter_attributes(f, arg.codegen.i), StringAttribute("enzyme_inactive"), ) - else - any_active = true + else + any_active = true end end if sret !== nothing diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 6969cd1370..680746e86e 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -3,44 +3,43 @@ import LinearAlgebra @inline add_fwd(prev, post) = recursive_add(prev, post) @generated function EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial::Union{AbstractArray,Number}, dx::Union{AbstractArray,Number}) - if partial <: Number || dx isa Number - if prev !== Nothing + if !(prev <: Type) return quote Base.@_inline_meta - add_fwd(prev, EnzymeCore.EnzymeRules.multiply_fwd_into(nothing, partial, dx)) + add_fwd(prev, EnzymeCore.EnzymeRules.multiply_fwd_into(Core.Typeof(prev), partial, dx)) end end return quote Base.@_inline_meta - partial * dx + prev(partial * dx) end end @assert partial <: AbstractArray if dx <: Number - if prev !== Nothing - return quote - Base.@_inline_meta - LinearAlgebra.axpy!(dx, partial, prev) - prev - end - else - return quote - Base.@_inline_meta - partial * dx - end - end + if !(prev <: Type) + return quote + Base.@_inline_meta + LinearAlgebra.axpy!(dx, partial, prev) + prev + end + else + return quote + Base.@_inline_meta + prev(partial * dx) + end + end end @assert dx <: AbstractArray N = ndims(partial) M = ndims(dx) if N == M - if prev !== Nothing + if !(prev <: Type) return quote Base.@_inline_meta - add_fwd(prev, EnzymeCore.EnzymeRules.multiply_fwd_into(nothing, partial, dx)) + add_fwd(prev, EnzymeCore.EnzymeRules.multiply_fwd_into(typeof(prev), partial, dx)) end end @@ -55,7 +54,7 @@ import LinearAlgebra end return quote Base.@_inline_meta - $res + prev($res) end end @@ -66,8 +65,8 @@ import LinearAlgebra end end - init = if prev === Nothing - :(prev = similar(partial, size(partial)[1:$(N-M)]...)) + init = if prev <: Type + :(prev = similar(prev, size(partial)[1:$(N-M)]...)) end idxs = Symbol[] @@ -93,7 +92,7 @@ import LinearAlgebra matp = Expr(:call, Base.reshape, matp, Expr(:call, Base.length, outp), Expr(:call, Base.length, inp)) end - outexpr = if prev === Nothing + outexpr = if prev <: Type Expr(:call, LinearAlgebra.mul!, outp, matp, inp) else Expr(:call, LinearAlgebra.mul!, outp, matp, inp, true, true) @@ -267,6 +266,10 @@ function enzyme_custom_setup_args( if isKWCall && arg.arg_i == 2 Ty = arg.typ + if EnzymeRules.is_inactive_kwarg_from_sig(Interpreter.simplify_kw(mi.specTypes); world) + activep = API.DFT_CONSTANT + end + push!(args, val) # Only constant kw arg tuple's are currently supported @@ -335,21 +338,25 @@ function enzyme_custom_setup_args( ) if value_type(val) != eltype(value_type(ptr)) if overwritten[end] + bt = GPUCompiler.backtrace(orig) + msg2 = sprint(Base.Fix2(Base.show_backtrace, bt)) emit_error( B, orig, "Enzyme: active by ref type $Ty is overwritten in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr)). " * - "As a workaround until support for this is added, try passing values as separate arguments rather than as an aggregate of type $Ty.", + "As a workaround until support for this is added, try passing values as separate arguments rather than as an aggregate of type $Ty.\n"*msg2, ) end if arty == eltype(value_type(val)) val = load!(B, arty, val) else + bt = GPUCompiler.backtrace(orig) + msg2 = sprint(Base.Fix2(Base.show_backtrace, bt)) val = LLVM.UndefValue(arty) emit_error( B, orig, - "Enzyme: active by ref type $Ty is wrong type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))", + "Enzyme: active by ref type $Ty is wrong type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))\n"*msg2, ) end end @@ -360,10 +367,12 @@ function enzyme_custom_setup_args( emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) end else + bt = GPUCompiler.backtrace(orig) + msg2 = sprint(Base.Fix2(Base.show_backtrace, bt)) emit_error( B, orig, - "Enzyme: active by ref type $Ty is wrong store type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))", + "Enzyme: active by ref type $Ty is wrong store type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))\n"*msg2, ) end @@ -570,11 +579,15 @@ function enzyme_custom_setup_ret( mode != API.DEM_ForwardMode && !guaranteed_nonactive(RealRt, world) ) - if active_reg(RealRt, world) == MixedState && B !== nothing + if active_reg(RealRt, world) == MixedState && B !== nothing + bt = GPUCompiler.backtrace(orig) + msg2 = sprint(Base.Fix2(Base.show_backtrace, bt)) + mi, _ = enzyme_custom_extract_mi(orig) emit_error( B, orig, - "Enzyme: Return type $RealRt has mixed internal activity types in evaluation of custom rule for $mi. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information", + (msg2, mi, world), + MixedReturnException{RealRt} ) end RT = Active{RealRt} @@ -616,13 +629,20 @@ end ) end + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + world = enzyme_extract_world(fn) + # TODO: don't inject the code multiple times for multiple calls fmi, (args, TT, fwd_RT, kwtup, RT, needsPrimal, RealRt, origNeedsPrimal, activity, C) = fwd_mi(orig, gutils, B) if kwtup !== nothing && kwtup <: Duplicated - @safe_debug "Non-constant keyword argument found for " TT - emit_error(B, orig, "Enzyme: Non-constant keyword argument found for " * string(TT)) + mi, _ = enzyme_custom_extract_mi(orig) + + bt = GPUCompiler.backtrace(orig) + msg2 = sprint(Base.Fix2(Base.show_backtrace, bt)) + emit_error(B, orig, (msg2, mi, world), NonConstantKeywordArgException) return false end @@ -632,9 +652,6 @@ end mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) width = get_width(gutils) - curent_bb = position(B) - fn = LLVM.parent(curent_bb) - world = enzyme_extract_world(fn) llvmf = nested_codegen!(mode, mod, fmi, world) @@ -659,16 +676,24 @@ end end if length(args) != length(parameters(llvmf)) - GPUCompiler.@safe_error "Calling convention mismatch", - args, - llvmf, - string(value_type(llvmf)), - orig, - isKWCall, - kwtup, - TT, - sret, - returnRoots + bt = GPUCompiler.backtrace(orig) + msg2 = sprint() do io + if startswith(LLVM.name(llvmf), "japi3") || startswith(LLVM.name(llvmf), "japi1") + Base.println(io, "Function uses the japi convention, which is not supported yet: ", LLVM.name(llvmf)) + else + Base.println(io, "args = ", args) + Base.println(io, "llvmf = ", string(llvmf)) + Base.println(io, "value_type(llvmf) = ", string(value_type(llvmf))) + Base.println(io, "orig = ", string(orig)) + Base.println(io, "isKWCall = ", string(isKWCall)) + Base.println(io, "kwtup = ", string(kwtup)) + Base.println(io, "TT = ", string(TT)) + Base.println(io, "sret = ", string(sret)) + Base.println(io, "returnRoots = ", string(returnRoots)) + end + Base.show_backtrace(io, bt) + end + emit_error(B, orig, (msg2, fmi, world), CallingConventionMismatchError) return false end @@ -714,23 +739,22 @@ end shadowV = C_NULL normalV = C_NULL + ExpRT = EnzymeRules.forward_rule_return_type(C, RT) + if ExpRT != fwd_RT + bt = GPUCompiler.backtrace(orig) + msg2 = sprint(Base.Fix2(Base.show_backtrace, bt)) + emit_error( + B, + orig, + (msg2, fmi, world), + ForwardRuleReturnError{C, RT, fwd_RT} + ) + return false + end + if RT <: Const if needsPrimal - if RealRt != fwd_RT - emit_error( - B, - orig, - "Enzyme: incorrect return type of const primal-only forward custom rule - $C " * - (string(RT)) * - " " * - string(activity) * - " want just return type " * - string(RealRt) * - " found " * - string(fwd_RT), - ) - return false - end + @assert RealRt == fwd_RT if get_return_info(RealRt)[2] !== nothing val = new_from_original(gutils, operands(orig)[1]) store!(B, res, val) @@ -738,19 +762,7 @@ end normalV = res.ref end else - if Nothing != fwd_RT - emit_error( - B, - orig, - "Enzyme: incorrect return type of const no-primal forward custom rule - $C " * - (string(RT)) * - " " * - string(activity) * - " want just return type Nothing found " * - string(fwd_RT), - ) - return false - end + @assert Nothing == fwd_RT end else if !needsPrimal @@ -758,21 +770,7 @@ end if width != 1 ST = NTuple{Int(width),ST} end - if ST != fwd_RT - emit_error( - B, - orig, - "Enzyme: incorrect return type of shadow-only forward custom rule - $C " * - (string(RT)) * - " " * - string(activity) * - " want just shadow type " * - string(ST) * - " found " * - string(fwd_RT), - ) - return false - end + @assert ST == fwd_RT if get_return_info(RealRt)[2] !== nothing dval_ptr = invert_pointer(gutils, operands(orig)[1], B) for idx = 1:width @@ -789,21 +787,7 @@ end else BatchDuplicated{RealRt,Int(width)} end - if ST != fwd_RT - emit_error( - B, - orig, - "Enzyme: incorrect return type of prima/shadow forward custom rule - $C " * - (string(RT)) * - " " * - string(activity) * - " want just shadow type " * - string(ST) * - " found " * - string(fwd_RT), - ) - return false - end + @assert ST == fwd_RT if get_return_info(RealRt)[2] !== nothing val = new_from_original(gutils, operands(orig)[1]) store!(B, extract_value!(B, res, 0), val) @@ -1104,12 +1088,10 @@ function enzyme_custom_common_rev( ) aug_RT = return_type(interp, ami) if kwtup !== nothing && kwtup <: Duplicated - @safe_debug "Non-constant keyword argument found for " augprimal_TT - emit_error( - B, - orig, - "Enzyme: Non-constant keyword argument found for " * string(augprimal_TT), - ) + mi, _ = enzyme_custom_extract_mi(orig) + bt = GPUCompiler.backtrace(orig) + msg2 = sprint(Base.Fix2(Base.show_backtrace, bt)) + emit_error(B, orig, (msg2, mi, world), NonConstantKeywordArgException) return C_NULL end @@ -1152,10 +1134,13 @@ function enzyme_custom_common_rev( llvmf = nothing applicablefn = true + final_mi = nothing + if forward llvmf = nested_codegen!(mode, mod, ami, world) @assert llvmf !== nothing rev_RT = nothing + final_mi = ami else tt = copy(activity) if isKWCall @@ -1196,6 +1181,7 @@ function enzyme_custom_common_rev( rmi = rmi::Core.MethodInstance rev_RT = rev_RT::Type llvmf = nested_codegen!(mode, mod, rmi, world) + final_mi = rmi end push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) @@ -1227,15 +1213,9 @@ function enzyme_custom_common_rev( sret_union = is_sret_union(miRT) if sret_union - emit_error( - B, - orig, - "Enzyme: Augmented forward pass custom rule " * - string(augprimal_TT) * - " had a union sret of type " * - string(miRT) * - " which is not currently supported", - ) + bt = GPUCompiler.backtrace(orig) + msg2 = sprint(Base.Fix2(Base.show_backtrace, bt)) + emit_error(B, orig, (msg2, final_mi, world), UnionSretReturnException{miRT}) return tapeV end @@ -1314,6 +1294,7 @@ function enzyme_custom_common_rev( else llety = convert(LLVMType, eltype(RT); allow_boxed = true) ptr_val = invert_pointer(gutils, operands(orig)[1+!isghostty(funcTy)], B) + ptr_val = lookup_value(gutils, ptr_val, B) val = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llety))) for idx = 1:width ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx - 1) @@ -1376,17 +1357,26 @@ function enzyme_custom_common_rev( end if length(args) != length(parameters(llvmf)) - GPUCompiler.@safe_error "Calling convention mismatch", - args, - llvmf, - orig, - isKWCall, - kwtup, - augprimal_TT, - rev_TT, - fn, - sret, - returnRoots + bt = GPUCompiler.backtrace(orig) + msg2 = sprint() do io + if startswith(LLVM.name(llvmf), "japi3") || startswith(LLVM.name(llvmf), "japi1") + Base.println(io, "Function uses the japi convention, which is not supported yet: ", LLVM.name(llvmf)) + else + Base.println(io, "args = ", args) + Base.println(io, "llvmf = ", string(llvmf)) + Base.println(io, "value_type(llvmf) = ", string(value_type(llvmf))) + Base.println(io, "orig = ", string(orig)) + Base.println(io, "isKWCall = ", string(isKWCall)) + Base.println(io, "kwtup = ", string(kwtup)) + Base.println(io, "augprimal_TT = ", string(augprimal_TT)) + Base.println(io, "rev_TT = ", string(rev_TT)) + Base.println(io, "fn = ", string(fn)) + Base.println(io, "sret = ", string(sret)) + Base.println(io, "returnRoots = ", string(returnRoots)) + end + Base.show_backtrace(io, bt) + end + emit_error(B, orig, (msg2, final_mi, world), CallingConventionMismatchError) return tapeV end @@ -1482,6 +1472,20 @@ function enzyme_custom_common_rev( needsShadowJL ? ShadT : Nothing, TapeT, } + if ST != EnzymeRules.augmented_rule_return_type(C, RT, TapeT) + throw(AssertionError("Unexpected augmented rule return computation\nST = $ST\nER = $(EnzymeRules.augmented_rule_return_type(C, RT, TapeT))\nC = $C\nRT = $RT\nTapeT = $TapeT")) + end + if !(aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && !(aug_RT <: EnzymeRules.AugmentedReturn{ + needsPrimal ? RealRt : Nothing, + needsShadowJL ? ShadT : Nothing}) + + bt = GPUCompiler.backtrace(orig) + msg2 = sprint(Base.Fix2(Base.show_backtrace, bt)) + emit_error(B, orig, (msg2, ami, world), AugmentedRuleReturnError{C, RT, aug_RT}) + return tapeV + end + + if aug_RT != ST if aug_RT <: EnzymeRules.AugmentedReturnFlexShadow if convert(LLVMType, EnzymeRules.shadow_type(aug_RT); allow_boxed = true) != @@ -1517,22 +1521,7 @@ function enzyme_custom_common_rev( if aug_RT <: abs abstract = true else - ST = EnzymeRules.AugmentedReturn{ - needsPrimal ? RealRt : Nothing, - needsShadowJL ? ShadT : Nothing, - Any, - } - emit_error( - B, - orig, - "Enzyme: Augmented forward pass custom rule " * - string(augprimal_TT) * - " return type mismatch, expected " * - string(ST) * - " found " * - string(aug_RT), - ) - return tapeV + @assert false end end @@ -1612,16 +1601,9 @@ function enzyme_custom_common_rev( ) ST = Tuple{Tys...} if rev_RT != ST - emit_error( - B, - orig, - "Enzyme: Reverse pass custom rule " * - string(rev_TT) * - " return type mismatch, expected " * - string(ST) * - " found " * - string(rev_RT), - ) + bt = GPUCompiler.backtrace(orig) + msg2 = sprint(Base.Fix2(Base.show_backtrace, bt)) + emit_error(B, orig, (msg2, rmi, world), ReverseRuleReturnError{C, Tuple{activity[2+isKWCall:end]...,}, rev_RT}) return tapeV end if length(actives) >= 1 && diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 93cac3818b..79b1e1ffdf 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -165,11 +165,25 @@ include("parallelrules.jl") end end - err = emit_error( - B, - orig, - "Enzyme: jl_call calling convention not implemented in forward for " * string(orig), - ) + pf = LLVM.parent(LLVM.parent(orig))::LLVM.Function + mi, _ = enzyme_custom_extract_mi(pf, false) #=error=# + world = enzyme_extract_world(pf) + + if mi !== nothing + err = emit_error( + B, + orig, + ("Enzyme: jl_call calling convention not implemented in forward for " * string(orig), mi, world), + EnzymeRuntimeExceptionMI + ) + else + err = emit_error( + B, + orig, + "Enzyme: jl_call calling convention not implemented in forward for " * string(orig), + EnzymeRuntimeException + ) + end newo = new_from_original(gutils, orig) @@ -244,12 +258,26 @@ end end end - err = emit_error( - B, - orig, - "Enzyme: jl_call calling convention not implemented in aug_forward for " * - string(orig), - ) + pf = LLVM.parent(LLVM.parent(orig))::LLVM.Function + mi, _ = enzyme_custom_extract_mi(pf, false) #=error=# + world = enzyme_extract_world(pf) + + if mi !== nothing + err = emit_error( + B, + orig, + ("Enzyme: jl_call calling convention not implemented in aug_forward for " * string(orig), mi, world), + EnzymeRuntimeExceptionMI + ) + else + err = emit_error( + B, + orig, + "Enzyme: jl_call calling convention not implemented in aug_forward for " * string(orig), + EnzymeRuntimeException + ) + end + newo = new_from_original(gutils, orig) API.moveBefore(newo, err, B) @@ -330,11 +358,25 @@ end end end - emit_error( - B, - orig, - "Enzyme: jl_call calling convention not implemented in reverse for " * string(orig), - ) + pf = LLVM.parent(LLVM.parent(orig))::LLVM.Function + mi, _ = enzyme_custom_extract_mi(pf, false) #=error=# + world = enzyme_extract_world(pf) + + if mi !== nothing + err = emit_error( + B, + orig, + ("Enzyme: jl_call calling convention not implemented in reverse for " * string(orig), mi, world), + EnzymeRuntimeExceptionMI + ) + else + err = emit_error( + B, + orig, + "Enzyme: jl_call calling convention not implemented in reverse for " * string(orig), + EnzymeRuntimeException + ) + end return nothing end diff --git a/test/rules/inactive_kwrules.jl b/test/rules/inactive_kwrules.jl new file mode 100644 index 0000000000..fc0c731c3b --- /dev/null +++ b/test/rules/inactive_kwrules.jl @@ -0,0 +1,115 @@ +module InactiveKWRules + +using Enzyme +using Enzyme.EnzymeRules +using Test + +import .EnzymeRules: forward, augmented_primal, reverse + +function f_kw(out; tmp=[2.0, 0.0]) + out[1] *= tmp[1] + tmp[2] += 1 + nothing +end + +function forward(config, ::Const{typeof(f_kw)}, ::Type{<:Const}, x::Duplicated; kwargs...) + f_kw(x.val; kwargs...) + f_kw(x.dval; kwargs...) + return nothing +end + +function augmented_primal(config, ::Const{typeof(f_kw)}, ::Type{<:Const}, x::Duplicated; kwargs...) + f_kw(x.val; kwargs...) + return EnzymeRules.AugmentedReturn(nothing, nothing, nothing) +end + +function reverse(config, ::Const{typeof(f_kw)}, ::Type{<:Const}, tape, x::Duplicated; kwargs...) + f_kw(x.dval; kwargs...) + return (nothing,) +end + +function g_kw(out) + tmp=[2.0, 0.0] + f_kw(out; tmp) + nothing +end + +function h_kw(out, tmp) + f_kw(out; tmp) + nothing +end + +@testset "Forward Inactive allocated kwarg error" begin + x = [2.7] + dx = [3.1] + @test_throws Enzyme.Compiler.NonConstantKeywordArgException autodiff(Forward, g_kw, Duplicated(x, dx)) +end + +@testset "Reverse Inactive allocated kwarg error" begin + x = [2.7] + dx = [3.1] + @test_throws Enzyme.Compiler.NonConstantKeywordArgException autodiff(Reverse, g_kw, Duplicated(x, dx)) +end + +@testset "Forward Inactive arg kwarg error" begin + x = [2.7] + dx = [3.1] + + tmp = [2.0, 0.0] + dtmp = [7.1, 9.4] + @test_throws Enzyme.Compiler.NonConstantKeywordArgException autodiff(Forward, h_kw, Duplicated(x, dx), Duplicated(tmp, dtmp)) +end + +@testset "Reverse Inactive arg kwarg error" begin + x = [2.7] + dx = [3.1] + + tmp = [2.0, 0.0] + dtmp = [7.1, 9.4] + @test_throws Enzyme.Compiler.NonConstantKeywordArgException autodiff(Forward, h_kw, Duplicated(x, dx), Duplicated(tmp, dtmp)) +end + +Enzyme.EnzymeRules.inactive_kwarg(::typeof(f_kw), out; tmp=[2.0]) = nothing + +@testset "Forward Inactive allocated kwarg success" begin + x = [2.7] + dx = [3.1] + autodiff(Forward, g_kw, Duplicated(x, dx)) + @test x ≈ [2.7 * 2.0] + @test dx ≈ [3.1 * 2.0] +end + +@testset "Reverse Inactive allocated kwarg success" begin + x = [2.7] + dx = [3.1] + autodiff(Reverse, g_kw, Duplicated(x, dx)) + @test x ≈ [2.7 * 2.0] + @test dx ≈ [3.1 * 2.0] +end + +@testset "Forward Inactive arg kwarg success" begin + x = [2.7] + dx = [3.1] + + tmp = [2.0, 0.0] + dtmp = [7.1, 9.4] + autodiff(Forward, h_kw, Duplicated(x, dx), Duplicated(tmp, dtmp)) + + @test x ≈ [2.7 * 2.0] + @test dx ≈ [3.1 * 2.0] +end + +@testset "Reverse Inactive arg kwarg success" begin + x = [2.7] + dx = [3.1] + + tmp = [2.0, 0.0] + dtmp = [7.1, 9.4] + autodiff(Reverse, h_kw, Duplicated(x, dx), Duplicated(tmp, dtmp)) + + @test x ≈ [2.7 * 2.0] + @test dx ≈ [3.1 * 2.0] +end + +end # InactiveKWRules + diff --git a/test/rules/kwrrules.jl b/test/rules/kwrrules.jl index 044273d44f..7e6af86505 100644 --- a/test/rules/kwrrules.jl +++ b/test/rules/kwrrules.jl @@ -108,7 +108,7 @@ end # Test that this errors due to missing kwargs in rule definition g4(x, y) = f_kw4(x; y) @test autodiff(Reverse, g4, Active(2.0), Const(42.0))[1][1] ≈ 42004.0 -@test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Reverse, g4, Active(2.0), Active(42.0))[1] +@test_throws Enzyme.Compiler.NonConstantKeywordArgException autodiff(Reverse, g4, Active(2.0), Active(42.0))[1] struct Closure2 v::Vector{Float64} diff --git a/test/rules/kwrules.jl b/test/rules/kwrules.jl index 9761c23510..a005c65fd7 100644 --- a/test/rules/kwrules.jl +++ b/test/rules/kwrules.jl @@ -56,7 +56,7 @@ end # Test that this errors due to missing kwargs in rule definition g4(x, y) = f_kw4(x; y) @test autodiff(Forward, g4, Duplicated(2.0, 1.0), Const(42.0))[1] ≈ 42004.0 -@test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Forward, g4, Duplicated(2.0, 1.0), Duplicated(42.0, 1.0))[1] +@test_throws Enzyme.Compiler.NonConstantKeywordArgException autodiff(Forward, g4, Duplicated(2.0, 1.0), Duplicated(42.0, 1.0))[1] end # KWForwardRules diff --git a/test/rules/mixederror.jl b/test/rules/mixederror.jl new file mode 100644 index 0000000000..128f4403f6 --- /dev/null +++ b/test/rules/mixederror.jl @@ -0,0 +1,93 @@ +module MixedRuleError + +using Enzyme +using Enzyme.EnzymeRules +using Test + +using Enzyme, LinearAlgebra + +function handle_infinities(workfunc, f, s) + s1, s2 = first(s), last(s) + inf1, inf2 = isinf(s1), isinf(s2) + if inf1 || inf2 + if inf1 && inf2 # x = t / (1 - t^2) + return workfunc( + function (t) + t2 = t * t + den = 1 / (1 - t2) + return f(oneunit(s1) * t * den) * (1 + t2) * den * den * oneunit(s1) + end, + map(s) do x + isinf(x) ? copysign(one(x), x) : 2x / (oneunit(x) + hypot(oneunit(x), 2x)) + end, + t -> oneunit(s1) * t / (1 - t^2), + ) + else + (s0, si) = inf1 ? (s2, s1) : (s1, s2) + if si < zero(si) # x = s0 - t / (1 - t) + return workfunc( + function (t) + den = 1 / (1 - t) + return f(s0 - oneunit(s1) * t * den) * den * den * oneunit(s1) + end, + reverse(map(s) do x + 1 / (1 + oneunit(x) / (s0 - x)) + end), + t -> s0 - oneunit(s1) * t / (1 - t), + ) + else # x = s0 + t / (1 - t) + return workfunc( + function (t) + den = 1 / (1 - t) + return f(s0 + oneunit(s1) * t * den) * den * den * oneunit(s1) + end, + map(s) do x + 1 / (1 + oneunit(x) / (x - s0)) + end, + t -> s0 + oneunit(s1) * t / (1 - t), + ) + end + end + end + return workfunc(f, s, identity) +end + +outer(f, xs...) = handle_infinities((f_, xs_, _) -> inner(f_, xs_), f, xs) + +function inner(f::F, xs) where {F} # remove type annotation => problem solved + s = sum(f, xs) + return (s, norm(s)) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, ::Const{typeof(inner)}, ::Type, f, xs +) + true_primal = inner(f.val, xs.val) + primal = EnzymeRules.needs_primal(config) ? true_primal : nothing + shadow = if EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + make_zero(true_primal) + else + ntuple(_ -> make_zero(true_primal), Val(EnzymeRules.width(config))) + end + else + nothing + end + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeRules.reverse( + ::EnzymeRules.RevConfig, ::Const{typeof(inner)}, shadow::Active, tape, f, xs +) + return ((f isa Active) ? f : nothing, (xs isa Active) ? xs : nothing) +end + +F_good(x) = outer(y -> [cos(x * y)], 0.0, 1.0)[1][1] +F_bad(x) = outer(y -> [cos(y)], 0.0, x)[1][1] + +@testset "Mixed Return Rule Error" begin + @test_throws Enzyme.Compiler.MixedReturnException autodiff(Reverse, F_good, Active(0.3)) + @test_throws Enzyme.Compiler.MixedReturnException autodiff(Reverse, F_bad, Active(0.3)) +end + +end # MixedRuleError \ No newline at end of file diff --git a/test/rules/rules.jl b/test/rules/rules.jl index b306c353fb..4ed0f483f8 100644 --- a/test/rules/rules.jl +++ b/test/rules/rules.jl @@ -131,7 +131,7 @@ end @test Enzyme.autodiff(Forward, h, Duplicated(3.0, 1.0)) == (6000.0,) @test Enzyme.autodiff(ForwardWithPrimal, h, Duplicated(3.0, 1.0)) == (60.0, 9.0) @test Enzyme.autodiff(Forward, h2, Duplicated(3.0, 1.0)) == (1080.0,) - @test_throws Enzyme.Compiler.EnzymeRuntimeException Enzyme.autodiff(Forward, h3, Duplicated(3.0, 1.0)) + @test_throws Enzyme.Compiler.ForwardRuleReturnError Enzyme.autodiff(Forward, h3, Duplicated(3.0, 1.0)) end foo(x) = 2x;