Skip to content

Improve dup kwarg error and add utility#2761

Merged
wsmoses merged 20 commits intomainfrom
kwarge
Nov 10, 2025
Merged

Improve dup kwarg error and add utility#2761
wsmoses merged 20 commits intomainfrom
kwarge

Conversation

@wsmoses
Copy link
Member

@wsmoses wsmoses commented Nov 9, 2025

No description provided.

@wsmoses wsmoses requested a review from vchuravy November 9, 2025 23:37
@github-actions
Copy link
Contributor

github-actions bot commented Nov 9, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/ext/EnzymeChainRulesCoreExt.jl b/ext/EnzymeChainRulesCoreExt.jl
index af8d60e4..219eb3e4 100644
--- a/ext/EnzymeChainRulesCoreExt.jl
+++ b/ext/EnzymeChainRulesCoreExt.jl
@@ -204,23 +204,27 @@ function Enzyme._import_rrule(fn, tys...)
             shadow, byref = if !EnzymeRules.needs_shadow(config)
                 nothing, Val(false)
             elseif !Enzyme.Compiler.guaranteed_nonactive(Core.Typeof(res))
-                (if EnzymeRules.width(config) == 1
-                    Ref(Enzyme.make_zero(res))
-                else
-                    ntuple(Val(EnzymeRules.width(config))) do j
-                        Base.@_inline_meta
+                (
+                    if EnzymeRules.width(config) == 1
                         Ref(Enzyme.make_zero(res))
-                    end
-                end, Val(true))
-            else 
-                (if EnzymeRules.width(config) == 1
+                    else
+                        ntuple(Val(EnzymeRules.width(config))) do j
+                            Base.@_inline_meta
+                            Ref(Enzyme.make_zero(res))
+                        end
+                    end, Val(true),
+                )
+            else
+                (
+                    if EnzymeRules.width(config) == 1
                     Enzyme.make_zero(res)
                 else
                     ntuple(Val(EnzymeRules.width(config))) do j
                         Base.@_inline_meta
                         Enzyme.make_zero(res)
                     end
-                end, Val(false))
+                    end, Val(false),
+                )
             end
 
             cache = (shadow, pullback, byref)
diff --git a/lib/EnzymeCore/src/easyrules.jl b/lib/EnzymeCore/src/easyrules.jl
index a8ddeada..e6941b54 100644
--- a/lib/EnzymeCore/src/easyrules.jl
+++ b/lib/EnzymeCore/src/easyrules.jl
@@ -479,7 +479,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
 
             PT = EnzymeRules.primal_type(config, ($(esc(:RTA))).parameters[1])
             ST = EnzymeRules.shadow_type(config, ($(esc(:RTA))).parameters[1])
-            AugmentedReturnType = :(EnzymeRules.AugmentedReturn{$PT,$ST,typeof(cache)})
+            AugmentedReturnType = :(EnzymeRules.AugmentedReturn{$PT, $ST, typeof(cache)})
 
             genres = if needs_primal(config)
                 if needs_shadow(config)
diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl
index 2d5259e0..7284abab 100644
--- a/lib/EnzymeCore/src/rules.jl
+++ b/lib/EnzymeCore/src/rules.jl
@@ -133,8 +133,8 @@ Compute the expected primal return type given a reverse mode config and return a
 """
 @inline primal_type(config::FwdConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing
 @inline primal_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing
-@inline primal_type(config::Type{<:FwdConfig}, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing
-@inline primal_type(config::Type{<:RevConfig}, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing
+@inline primal_type(config::Type{<:FwdConfig}, ::Type{<:Annotation{RT}}) where {RT} = needs_primal(config) ? RT : Nothing
+@inline primal_type(config::Type{<:RevConfig}, ::Type{<:Annotation{RT}}) where {RT} = needs_primal(config) ? RT : Nothing
 
 """
     shadow_type(::FwdConfig, ::Type{<:Annotation{RT}})
@@ -146,8 +146,8 @@ Compute the expected shadow return type given a reverse mode config and return a
 """
 @inline shadow_type(config::FwdConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing
 @inline shadow_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing
-@inline shadow_type(config::Type{<:FwdConfig}, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing
-@inline shadow_type(config::Type{<:RevConfig}, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing
+@inline shadow_type(config::Type{<:FwdConfig}, ::Type{<:Annotation{RT}}) where {RT} = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing
+@inline shadow_type(config::Type{<:RevConfig}, ::Type{<:Annotation{RT}}) where {RT} = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing
 
 
 """
@@ -173,7 +173,7 @@ Otherwise, just return the shadows.
     needsPrimal = EnzymeRules.needs_primal(C)
     needsShadow = EnzymeRules.needs_shadow(C)
     width = EnzymeRules.width(C)
-    if !needsShadow
+    return if !needsShadow
         if needsPrimal
             return RealRt
         else
@@ -184,14 +184,14 @@ Otherwise, just return the shadows.
         if !needsPrimal
             ST = RealRt
             if width != 1
-                ST = NTuple{Int(width),ST}
+                ST = NTuple{Int(width), ST}
             end
             return ST
         else
             ST = if width == 1
                 Duplicated{RealRt}
             else
-                BatchDuplicated{RealRt,Int(width)}
+                BatchDuplicated{RealRt, Int(width)}
             end
             return ST
         end
@@ -218,17 +218,17 @@ struct AugmentedReturn{PrimalType,ShadowType,TapeType}
     tape::TapeType
 end
 
-@inline function AugmentedReturn{PrimalType,ShadowType}(primal, shadow, cache) where {PrimalType, ShadowType}
-    AT = AugmentedReturn{PrimalType,ShadowType, typeof(cache)}
+@inline function AugmentedReturn{PrimalType, ShadowType}(primal, shadow, cache) where {PrimalType, ShadowType}
+    AT = AugmentedReturn{PrimalType, ShadowType, typeof(cache)}
     return AT(primal, shadow, cache)
 end
 
 @inline primal_type(::Type{<:AugmentedReturn{PrimalType}}) where {PrimalType} = PrimalType
 @inline primal_type(::AugmentedReturn{PrimalType}) where {PrimalType} = PrimalType
-@inline shadow_type(::Type{<:AugmentedReturn{<:Any,ShadowType}}) where {ShadowType} = ShadowType
-@inline shadow_type(::AugmentedReturn{<:Any,ShadowType}) where {ShadowType} = ShadowType
-@inline tape_type(::Type{<:AugmentedReturn{<:Any,<:Any,TapeType}}) where {TapeType} = TapeType
-@inline tape_type(::AugmentedReturn{<:Any,<:Any,TapeType}) where {TapeType} = TapeType
+@inline shadow_type(::Type{<:AugmentedReturn{<:Any, ShadowType}}) where {ShadowType} = ShadowType
+@inline shadow_type(::AugmentedReturn{<:Any, ShadowType}) where {ShadowType} = ShadowType
+@inline tape_type(::Type{<:AugmentedReturn{<:Any, <:Any, TapeType}}) where {TapeType} = TapeType
+@inline tape_type(::AugmentedReturn{<:Any, <:Any, TapeType}) where {TapeType} = TapeType
 struct AugmentedReturnFlexShadow{PrimalType,ShadowType,TapeType}
     primal::PrimalType
     shadow::ShadowType
@@ -461,10 +461,12 @@ This function is currently considered internal/experimental and may not respect
 """
 function inactive_kwarg end
 
-function is_inactive_kwarg_from_sig(@nospecialize(TT);
-                              world::UInt=Base.get_world_counter(),
-                              method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
-                              caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
+function is_inactive_kwarg_from_sig(
+        @nospecialize(TT);
+        world::UInt = Base.get_world_counter(),
+        method_table::Union{Nothing, Core.Compiler.MethodTableView} = nothing,
+        caller::Union{Nothing, Core.MethodInstance, Core.Compiler.MethodLookupResult} = nothing
+    )
     return isapplicable(inactive_kwarg, TT; world, method_table, caller)
 end
 
@@ -478,10 +480,12 @@ This function is currently considered internal/experimental and may not respect
 """
 function inactive_arg end
 
-function is_inactive_arg_from_sig(@nospecialize(TT);
-                              world::UInt=Base.get_world_counter(),
-                              method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
-                              caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
+function is_inactive_arg_from_sig(
+        @nospecialize(TT);
+        world::UInt = Base.get_world_counter(),
+        method_table::Union{Nothing, Core.Compiler.MethodTableView} = nothing,
+        caller::Union{Nothing, Core.MethodInstance, Core.Compiler.MethodLookupResult} = nothing
+    )
     return isapplicable(inactive_arg, TT; world, method_table, caller)
 end
 
diff --git a/src/compiler.jl b/src/compiler.jl
index cec43e9e..c6340bf9 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -410,7 +410,7 @@ const JuliaEnzymeNameMap = Dict{String,Any}(
     "enz_no_derivative_exc" => EnzymeNoDerivativeError{Nothing, Nothing},
     "enz_no_derivative_mi_exc" => EnzymeNoDerivativeError{Core.MethodInstance, UInt},
     "enz_non_const_kwarg_exc" => NonConstantKeywordArgException,
-    "enz_callconv_mismatch_exc"=> CallingConventionMismatchError,
+    "enz_callconv_mismatch_exc" => CallingConventionMismatchError,
     "enz_illegal_ta_exc" => IllegalTypeAnalysisException,
     "enz_illegal_first_pointer_exc" => IllegalFirstPointerException,
     "enz_internal_exc" => EnzymeInternalError,
@@ -6318,12 +6318,12 @@ function thunk_generator(world::UInt, source::Union{Method, LineNumberNode}, @no
         rev_sig = Tuple{typeof(EnzymeRules.reverse), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Union{Type{<:Enzyme.EnzymeCore.Annotation}, Enzyme.EnzymeCore.Active}, Any, Vararg{Enzyme.EnzymeCore.Annotation}}
         add_edge!(edges, rev_sig)
     end
-    
+
     for gen_sig in (
-        Tuple{typeof(EnzymeRules.inactive), Vararg{Any}},
+            Tuple{typeof(EnzymeRules.inactive), Vararg{Any}},
         Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}},
-        Tuple{typeof(EnzymeRules.inactive_arg), Vararg{Any}},
-        Tuple{typeof(EnzymeRules.inactive_kwarg), Vararg{Any}},
+            Tuple{typeof(EnzymeRules.inactive_arg), Vararg{Any}},
+            Tuple{typeof(EnzymeRules.inactive_kwarg), Vararg{Any}},
         Tuple{typeof(EnzymeRules.noalias), Vararg{Any}},
         Tuple{typeof(EnzymeRules.inactive_type), Type},
     )
diff --git a/src/errors.jl b/src/errors.jl
index cd6aba94..aec96f5a 100644
--- a/src/errors.jl
+++ b/src/errors.jl
@@ -11,7 +11,7 @@ abstract type EnzymeError <: Base.Exception end
 
 abstract type CompilationException <: EnzymeError end
 
-function pretty_print_mi(mi, io=stdout; digit_align_width = 1)
+function pretty_print_mi(mi, io = stdout; digit_align_width = 1)
     spec = mi.specTypes.parameters
     ft = spec[1]
     arg_types_param = spec[2:end]
@@ -26,7 +26,7 @@ function pretty_print_mi(mi, io=stdout; digit_align_width = 1)
     end
 
     Base.show_signature_function(io, ft)
-    Base.show_tuple_as_call(io, :function, Tuple{arg_types_param...}; hasfirst=false, kwargs = isempty(kwargs) ? nothing : kwargs)
+    Base.show_tuple_as_call(io, :function, Tuple{arg_types_param...}; hasfirst = false, kwargs = isempty(kwargs) ? nothing : kwargs)
 
     m = mi.def
 
@@ -43,12 +43,12 @@ function pretty_print_mi(mi, io=stdout; digit_align_width = 1)
     end
 
     # module & file, re-using function from errorshow.jl
-    Base.print_module_path_file(io, Base.parentmodule(m), string(file), line; modulecolor, digit_align_width)
+    return Base.print_module_path_file(io, Base.parentmodule(m), string(file), line; modulecolor, digit_align_width)
 end
 
 using InteractiveUtils
 
-function code_typed_helper(mi::Core.MethodInstance, world::UInt, mode::Enzyme.API.CDerivativeMode = Enzyme.API.DEM_ReverseModeCombined; interactive::Bool=false, kwargs...)
+function code_typed_helper(mi::Core.MethodInstance, world::UInt, mode::Enzyme.API.CDerivativeMode = Enzyme.API.DEM_ReverseModeCombined; interactive::Bool = false, kwargs...)
     CT = @static if VERSION >= v"1.11.0-DEV.1552"
         EnzymeCacheToken(
             typeof(DefaultCompilerTarget()),
@@ -71,10 +71,10 @@ function code_typed_helper(mi::Core.MethodInstance, world::UInt, mode::Enzyme.AP
     interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true)
 
     sig = mi.specTypes  # XXX: can we just use the method instance?
-    if interactive
+    return if interactive
         # call Cthulhu without introducing a dependency on Cthulhu
         mod = get(Base.loaded_modules, Cthulhu, nothing)
-        mod===nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.")
+        mod === nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.")
         descend_code_typed = getfield(mod, :descend_code_typed)
         descend_code_typed(sig; interp, kwargs...)
     else
@@ -120,7 +120,7 @@ function Base.showerror(io::IO, ece::EnzymeRuntimeExceptionMI)
     )
     println(io)
     msg = Base.unsafe_string(ece.msg)
-    print(io, msg, '\n')
+    return print(io, msg, '\n')
 end
 
 abstract type CustomRuleError <: Base.Exception end
@@ -148,7 +148,7 @@ function Base.showerror(io::IO, ece::NonConstantKeywordArgException)
     println(io)
     pretty_print_mi(ece.mi, io)
     println(io)
-    Base.println(io, Base.unsafe_string(ece.backtrace))
+    return Base.println(io, Base.unsafe_string(ece.backtrace))
 end
 
 struct CallingConventionMismatchError <: CustomRuleError
@@ -174,7 +174,7 @@ function Base.showerror(io::IO, ece::CallingConventionMismatchError)
     )
     println(io)
 
-    Base.println(io, Base.unsafe_string(ece.backtrace))
+    return Base.println(io, Base.unsafe_string(ece.backtrace))
 end
 
 InteractiveUtils.code_typed(ece::CallingConventionMismatchError; kwargs...) = code_typed_helper(ece.mi, ece.world; kwargs...)
@@ -236,7 +236,7 @@ function Base.showerror(io::IO, ece::ForwardRuleReturnError{C, RT, fwd_RT}) wher
     elseif EnzymeRules.needs_primal(C) && !EnzymeRules.needs_shadow(C)
         if fwd_RT <: BatchDuplicated || fwd_RT <: Duplicated
             hint = "Shadow was not requested, you should only return the primal"
-        elseif fwd_RT <: (NTuple{N, <:RealRt} where N)
+        elseif fwd_RT <: (NTuple{N, <:RealRt} where {N})
             hint = "You appear to be returning a tuple of shadows, but only the primal was requested"
         elseif fwd_RT <: RealRt
             hint = "Expected the abstract type $RealRt for primal, you returned $(fwd_RT). Even though $(fwd_RT) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
@@ -249,7 +249,7 @@ function Base.showerror(io::IO, ece::ForwardRuleReturnError{C, RT, fwd_RT}) wher
         if fwd_RT <: BatchDuplicated || fwd_RT <: Duplicated
             hint = "Primal was not requested, you should only return the shadow"
         elseif width == 1
-            if fwd_RT <: (NTuple{N, <:RealRt} where N)
+            if fwd_RT <: (NTuple{N, <:RealRt} where {N})
                 hint = "You look to be returning a tuple of shadows, when the batch size is 1"
             elseif fwd_RT <: RealRt
                 hint = "Expected the abstract type $RealRt for shadow, you returned $(fwd_RT). Even though $(fwd_RT) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
@@ -273,11 +273,11 @@ function Base.showerror(io::IO, ece::ForwardRuleReturnError{C, RT, fwd_RT}) wher
 
         if fwd_RT <: BatchDuplicated || fwd_RT <: Duplicated
             hint = "Neither primal nor shadow were requested, you should return nothing, not both the primal and shadow"
-        elseif fwd_RT <: (NTuple{N, <:RealRt} where N)
+        elseif fwd_RT <: (NTuple{N, <:RealRt} where {N})
             hint = "You appear to be returning a tuple of shadows, but neither primal nor shadow were requested"
         elseif fwd_RT <: RealRt && width == 1
             hint = "You appear to be returning a primal or shadow, but neither were requested"
-        elseif fwd_RT <: RealRt 
+        elseif fwd_RT <: RealRt
             hint = "You appear to be returning a primal, but it was not requested"
         else
             hint = "You should return nothing"
@@ -309,7 +309,7 @@ function Base.showerror(io::IO, ece::ForwardRuleReturnError{C, RT, fwd_RT}) wher
     println(io)
     pretty_print_mi(ece.mi, io)
     println(io)
-    Base.println(io, Base.unsafe_string(ece.backtrace))
+    return Base.println(io, Base.unsafe_string(ece.backtrace))
 end
 
 
@@ -345,7 +345,7 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
         elseif EnzymeRules.primal_type(fwd_RT) == Nothing
             hint = "Missing primal return"
         elseif EnzymeRules.shadow_type(fwd_RT) == Nothing
-            hint = "Missing shadow return"      
+            hint = "Missing shadow return"
         elseif EnzymeRules.primal_type(fwd_RT) != RealRt
             if EnzymeRules.primal_type(fwd_RT) <: RealRt
                 hint = "Expected the abstract type $RealRt for primal, you returned $(EnzymeRules.primal_type(fwd_RT)). Even though $(EnzymeRules.primal_type(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
@@ -356,7 +356,7 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
             if width == 1
                 if EnzymeRules.shadow_type(fwd_RT) <: RealRt
                     hint = "Expected the abstract type $RealRt for shadow, you returned $(EnzymeRules.shadow_type(fwd_RT)). Even though $(EnzymeRules.shadow_type(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
-                elseif shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where N)
+                elseif shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where {N})
                     hint = "Batch size was 1, expected a single shadow, not a tuple of shadows."
                 else
                     hint = "Mismatched shadow type $(EnzymeRules.shadow_type(fwd_RT)), expected $(EnzymeRules.shadow_type(ExpRT))."
@@ -364,7 +364,7 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
             else
                 if EnzymeRules.shadow_type(fwd_RT) <: RealRt
                     hint = "Batch size was $width, expected a tuple of shadows, not a single shadow."
-                elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where N)
+                elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where {N})
                     hint = "Expected the abstract type $RealRt for the element shadow type (for a batched shadow type $(EnzymeRules.shadow_type(ExpRT))), you returned $(eltype(EnzymeRules.shadow_type(fwd_RT))) as the element shadow type (batched to become $(EnzymeRules.shadow_type(fwd_RT)). Even though $(eltype(EnzymeRules.shadow_type(fwd_RT))) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
                 else
                     hint = "Mismatched shadow type $(EnzymeRules.shadow_type(fwd_RT)), expected $(EnzymeRules.shadow_type(ExpRT))."
@@ -375,11 +375,11 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
         "primal and shadow configuration"
     elseif EnzymeRules.needs_primal(C) && !EnzymeRules.needs_shadow(C)
         if !(fwd_RT <: EnzymeRules.AugmentedReturn)
-            hint = "Return should be a struct of type EnzymeRules.AugmentedReturn"        
+            hint = "Return should be a struct of type EnzymeRules.AugmentedReturn"
         elseif EnzymeRules.primal_type(fwd_RT) == Nothing
             hint = "Missing primal return"
         elseif EnzymeRules.shadow_type(fwd_RT) != Nothing
-            hint = "Shadow return was not requested"      
+            hint = "Shadow return was not requested"
         elseif EnzymeRules.primal_type(fwd_RT) != RealRt
             if EnzymeRules.primal_type(fwd_RT) <: RealRt
                 hint = "Expected the abstract type $RealRt for primal, you returned $(EnzymeRules.primal_type(fwd_RT)). Even though $(EnzymeRules.primal_type(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
@@ -392,14 +392,14 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
     elseif !EnzymeRules.needs_primal(C) && EnzymeRules.needs_shadow(C)
 
         if !(fwd_RT <: EnzymeRules.AugmentedReturn)
-            hint = "Return should be a struct of type EnzymeRules.AugmentedReturn"        
+            hint = "Return should be a struct of type EnzymeRules.AugmentedReturn"
         elseif EnzymeRules.primal_type(fwd_RT) != Nothing
             hint = "Primal was not requested"
         elseif EnzymeRules.shadow_type(fwd_RT) != RealRt
             if width == 1
                 if EnzymeRules.shadow_type(fwd_RT) <: RealRt
                     hint = "Expected the abstract type $RealRt for shadow, you returned $(EnzymeRules.shadow_type(fwd_RT)). Even though $(EnzymeRules.shadow_type(fwd_RT)) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
-                elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where N)
+                elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where {N})
                     hint = "Batch size was 1, expected a single shadow, not a tuple of shadows."
                 else
                     hint = "Mismatched shadow type $(EnzymeRules.shadow_type(fwd_RT)), expected $(EnzymeRules.shadow_type(ExpRT))."
@@ -407,7 +407,7 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
             else
                 if EnzymeRules.shadow_type(fwd_RT) <: RealRt
                     hint = "Batch size was $width, expected a tuple of shadows, not a single shadow."
-                elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where N)
+                elseif EnzymeRules.shadow_type(fwd_RT) <: (NTuple{N, <:RealRt} where {N})
                     hint = "Expected the abstract type $RealRt for the element shadow type (for a batched shadow type $(EnzymeRules.shadow_type(ExpRT))), you returned $(eltype(EnzymeRules.shadow_type(fwd_RT))) as the element shadow type (batched to become $(EnzymeRules.shadow_type(fwd_RT)). Even though $(eltype(EnzymeRules.shadow_type(fwd_RT))) <: $RealRt, rules require an exact match (akin to how you cannot substitute Vector{Float64} in a method that takes a Vector{Real})."
                 else
                     hint = "Mismatched shadow type $(EnzymeRules.shadow_type(fwd_RT)), expected $(EnzymeRules.shadow_type(ExpRT))."
@@ -418,7 +418,7 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
         "shadow-only configuration"
     else
         if !(fwd_RT <: EnzymeRules.AugmentedReturn)
-            hint = "Return should be a struct of type EnzymeRules.AugmentedReturn"        
+            hint = "Return should be a struct of type EnzymeRules.AugmentedReturn"
         elseif EnzymeRules.primal_type(fwd_RT) != Nothing
             hint = "Primal was not requested"
         elseif EnzymeRules.shadow_type(fwd_RT) != Nothing
@@ -454,7 +454,7 @@ function Base.showerror(io::IO, ece::AugmentedRuleReturnError{C, RT, fwd_RT}) wh
     println(io)
     pretty_print_mi(ece.mi, io)
     println(io)
-    Base.println(io, Base.unsafe_string(ece.backtrace))
+    return Base.println(io, Base.unsafe_string(ece.backtrace))
 end
 
 
@@ -469,7 +469,7 @@ InteractiveUtils.code_typed(ece::ReverseRuleReturnError; kwargs...) = code_typed
 function Base.showerror(io::IO, ece::ReverseRuleReturnError{C, ArgAct, rev_RT}) where {C, ArgAct, rev_RT}
     width = EnzymeRules.width(C)
     Tys = (
-        A <: Active ? (width == 1 ? eltype(A) : NTuple{Int(width),eltype(A)}) : Nothing for A in ArgAct.parameters
+        A <: Active ? (width == 1 ? eltype(A) : NTuple{Int(width), eltype(A)}) : Nothing for A in ArgAct.parameters
     )
     ExpRT = Tuple{Tys...}
     @assert ExpRT != rev_RT
@@ -500,7 +500,7 @@ function Base.showerror(io::IO, ece::ReverseRuleReturnError{C, ArgAct, rev_RT})
 
             if width == 1
 
-                if rev_RT.parameters[i] <: (NTuple{N, ExpRT.parameters[i]} where N)
+                if rev_RT.parameters[i] <: (NTuple{N, ExpRT.parameters[i]} where {N})
                     hint = "Tuple return mismatch at index $i, returned a tuple of results when expected just one of type $(ExpRT.parameters[i])."
                     break
                 end
@@ -550,7 +550,7 @@ function Base.showerror(io::IO, ece::ReverseRuleReturnError{C, ArgAct, rev_RT})
     println(io)
     pretty_print_mi(ece.mi, io)
     println(io)
-    Base.println(io, Base.unsafe_string(ece.backtrace))
+    return Base.println(io, Base.unsafe_string(ece.backtrace))
 end
 
 struct MixedReturnException{RT} <: CustomRuleError
@@ -561,7 +561,7 @@ end
 
 InteractiveUtils.code_typed(ece::MixedReturnException; kwargs...) = code_typed_helper(ece.mi, ece.world; kwargs...)
 
-function Base.showerror(io::IO, ece::MixedReturnException{RT}) where RT
+function Base.showerror(io::IO, ece::MixedReturnException{RT}) where {RT}
     if isdefined(Base.Experimental, :show_error_hints)
         Base.Experimental.show_error_hints(io, ece)
     end
@@ -577,7 +577,7 @@ function Base.showerror(io::IO, ece::MixedReturnException{RT}) where RT
     println(io)
     pretty_print_mi(ece.mi, io)
     println(io)
-    Base.println(io, Base.unsafe_string(ece.backtrace))
+    return Base.println(io, Base.unsafe_string(ece.backtrace))
 end
 
 
@@ -589,7 +589,7 @@ end
 
 InteractiveUtils.code_typed(ece::UnionSretReturnException; kwargs...) = code_typed_helper(ece.mi, ece.world; kwargs...)
 
-function Base.showerror(io::IO, ece::UnionSretReturnException{RT}) where RT
+function Base.showerror(io::IO, ece::UnionSretReturnException{RT}) where {RT}
     if isdefined(Base.Experimental, :show_error_hints)
         Base.Experimental.show_error_hints(io, ece)
     end
@@ -606,11 +606,10 @@ function Base.showerror(io::IO, ece::UnionSretReturnException{RT}) where RT
     println(io)
     pretty_print_mi(ece.mi, io)
     println(io)
-    Base.println(io, Base.unsafe_string(ece.backtrace))
+    return Base.println(io, Base.unsafe_string(ece.backtrace))
 end
 
 
-
 struct NoDerivativeException <: CompilationException
     msg::String
     ir::Union{Nothing,String}
@@ -694,7 +693,7 @@ function InteractiveUtils.code_typed(ece::IllegalTypeAnalysisException; kwargs..
     end
     world = ece.world::UInt
     mode = Enzyme.API.DEM_ReverseModeCombined
-    code_typed_helper(ece.mi, ece.world; kwargs...)
+    return code_typed_helper(ece.mi, ece.world; kwargs...)
 end
 
 struct IllegalFirstPointerException <: CompilationException
@@ -764,7 +763,7 @@ function Base.showerror(io::IO, ece::EnzymeMutabilityException)
         Base.Experimental.show_error_hints(io, ece)
     end
     msg = Base.unsafe_string(ece.msg)
-    print(io, "EnzymeMutabilityException: ", msg, '\n')
+    return print(io, "EnzymeMutabilityException: ", msg, '\n')
 end
 
 struct EnzymeRuntimeActivityError{MT,WT} <: EnzymeError
@@ -822,7 +821,7 @@ function InteractiveUtils.code_typed(ece::EnzymeRuntimeActivityError; interactiv
     if mi === nothing
         throw(AssertionError("code_typed(::EnzymeRuntimeActivityError; interactive::Bool=false, kwargs...) not supported for error without mi"))
     end
-    code_typed_helper(ece.mi, ece.world; kwargs...)
+    return code_typed_helper(ece.mi, ece.world; kwargs...)
 end
 
 struct EnzymeNoTypeError{MT,WT} <: EnzymeError
@@ -863,7 +862,7 @@ function InteractiveUtils.code_typed(ece::EnzymeNoTypeError; interactive::Bool=f
     if mi === nothing
         throw(AssertionError("code_typed(::EnzymeNoTypeError; interactive::Bool=false, kwargs...) not supported for error without mi"))
     end
-    code_typed_helper(ece.mi, ece.world; kwargs...)
+    return code_typed_helper(ece.mi, ece.world; kwargs...)
 end
 
 struct EnzymeNoShadowError <: EnzymeError
@@ -879,18 +878,18 @@ function Base.showerror(io::IO, ece::EnzymeNoShadowError)
     print(io, msg, '\n')
 end
 
-struct EnzymeNoDerivativeError{MT,WT} <: EnzymeError
+struct EnzymeNoDerivativeError{MT, WT} <: EnzymeError
     msg::Cstring
     mi::MT
     world::WT
 end
 
-function InteractiveUtils.code_typed(ece::EnzymeNoDerivativeError; interactive::Bool=false, kwargs...)
+function InteractiveUtils.code_typed(ece::EnzymeNoDerivativeError; interactive::Bool = false, kwargs...)
     mi = ece.mi
     if mi === nothing
         throw(AssertionError("code_typed(::EnzymeNoDerivativeError; interactive::Bool=false, kwargs...) not supported for error without mi"))
     end
-    code_typed_helper(ece.mi, ece.world; kwargs...)
+    return code_typed_helper(ece.mi, ece.world; kwargs...)
 end
 
 function Base.showerror(io::IO, ece::EnzymeNoDerivativeError)
@@ -900,7 +899,7 @@ function Base.showerror(io::IO, ece::EnzymeNoDerivativeError)
     msg = Base.unsafe_string(ece.msg)
     print(io, "EnzymeNoDerivativeError: ", msg, '\n')
 
-    if ece.mi !== nothing
+    return if ece.mi !== nothing
         print(io, "Failure within method:\n")
         println(io)
         pretty_print_mi(ece.mi, io)
@@ -983,8 +982,8 @@ function julia_error(
         if occursin("No create nofree of empty function", msg) ||
            occursin("No forward mode derivative found for", msg) ||
            occursin("No augmented forward pass", msg) ||
-           occursin("No reverse pass found", msg) ||
-           occursin("Runtime Activity not yet implemented for Forward-Mode BLAS", msg)
+                occursin("No reverse pass found", msg) ||
+                occursin("Runtime Activity not yet implemented for Forward-Mode BLAS", msg)
             ir = nothing
         end
         if B != C_NULL
diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl
index eccd8f5d..d1bf77ab 100644
--- a/src/rules/activityrules.jl
+++ b/src/rules/activityrules.jl
@@ -66,7 +66,6 @@ function julia_activity_rule(f::LLVM.Function, method_table)
     end
 
 
-
     if !Enzyme.Compiler.no_type_setting(mi.specTypes; world)[1]
         any_active = false
         for arg in jlargs
@@ -84,8 +83,8 @@ function julia_activity_rule(f::LLVM.Function, method_table)
                     parameter_attributes(f, arg.codegen.i),
                     StringAttribute("enzyme_inactive"),
                 )
-    	    else
-        		any_active = true
+            else
+                any_active = true
             end
         end
         if sret !== nothing
diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl
index 680746e8..0e710633 100644
--- a/src/rules/customrules.jl
+++ b/src/rules/customrules.jl
@@ -2,7 +2,7 @@ import LinearAlgebra
 
 @inline add_fwd(prev, post) = recursive_add(prev, post)
 
-@generated function EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial::Union{AbstractArray,Number}, dx::Union{AbstractArray,Number})
+@generated function EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial::Union{AbstractArray, Number}, dx::Union{AbstractArray, Number})
     if partial <: Number || dx isa Number
         if !(prev <: Type)
             return quote
@@ -19,17 +19,17 @@ import LinearAlgebra
     @assert partial <: AbstractArray
     if dx <: Number
         if !(prev <: Type)
-    	    return quote
-    		    Base.@_inline_meta
-    		    LinearAlgebra.axpy!(dx, partial, prev)
-    		    prev
-    	    end
-    	else
-    	    return quote
-    		    Base.@_inline_meta
-    		    prev(partial * dx)
-    	    end
-    	end
+            return quote
+                Base.@_inline_meta
+                LinearAlgebra.axpy!(dx, partial, prev)
+                prev
+            end
+        else
+            return quote
+                Base.@_inline_meta
+                prev(partial * dx)
+            end
+        end
     end
     @assert dx <: AbstractArray
     N = ndims(partial)
@@ -66,7 +66,7 @@ import LinearAlgebra
     end
 
     init = if prev <: Type
-        :(prev = similar(prev, size(partial)[1:$(N-M)]...))
+        :(prev = similar(prev, size(partial)[1:$(N - M)]...))
     end
 
     idxs = Symbol[]
@@ -344,7 +344,7 @@ function enzyme_custom_setup_args(
                             B,
                             orig,
                             "Enzyme: active by ref type $Ty is overwritten in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr)). " *
-                            "As a workaround until support for this is added, try passing values as separate arguments rather than as an aggregate of type $Ty.\n"*msg2,
+                                "As a workaround until support for this is added, try passing values as separate arguments rather than as an aggregate of type $Ty.\n" * msg2,
                         )
                     end
                     if arty == eltype(value_type(val))
@@ -356,7 +356,7 @@ function enzyme_custom_setup_args(
                         emit_error(
                             B,
                             orig,
-                            "Enzyme: active by ref type $Ty is wrong type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))\n"*msg2,
+                            "Enzyme: active by ref type $Ty is wrong type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))\n" * msg2,
                         )
                     end
                 end
@@ -372,7 +372,7 @@ function enzyme_custom_setup_args(
                     emit_error(
                         B,
                         orig,
-                        "Enzyme: active by ref type $Ty is wrong store type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))\n"*msg2,
+                        "Enzyme: active by ref type $Ty is wrong store type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))\n" * msg2,
                     )
                 end
 
@@ -579,9 +579,9 @@ function enzyme_custom_setup_ret(
         mode != API.DEM_ForwardMode &&
         !guaranteed_nonactive(RealRt, world)
     )
-        if active_reg(RealRt, world) == MixedState && B !== nothing        
+        if active_reg(RealRt, world) == MixedState && B !== nothing
             bt = GPUCompiler.backtrace(orig)
-            msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))            
+            msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
             mi, _ = enzyme_custom_extract_mi(orig)
             emit_error(
                 B,
@@ -742,7 +742,7 @@ end
     ExpRT = EnzymeRules.forward_rule_return_type(C, RT)
     if ExpRT != fwd_RT
         bt = GPUCompiler.backtrace(orig)
-        msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))            
+        msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
         emit_error(
             B,
             orig,
@@ -1475,9 +1475,12 @@ function enzyme_custom_common_rev(
         if ST != EnzymeRules.augmented_rule_return_type(C, RT, TapeT)
             throw(AssertionError("Unexpected augmented rule return computation\nST = $ST\nER = $(EnzymeRules.augmented_rule_return_type(C, RT, TapeT))\nC = $C\nRT = $RT\nTapeT = $TapeT"))
         end
-        if !(aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && !(aug_RT <: EnzymeRules.AugmentedReturn{
-            needsPrimal ? RealRt : Nothing,
-            needsShadowJL ? ShadT : Nothing})
+        if !(aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && !(
+                aug_RT <: EnzymeRules.AugmentedReturn{
+                    needsPrimal ? RealRt : Nothing,
+                    needsShadowJL ? ShadT : Nothing,
+                }
+            )
 
             bt = GPUCompiler.backtrace(orig)
             msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
@@ -1603,7 +1606,7 @@ function enzyme_custom_common_rev(
         if rev_RT != ST
             bt = GPUCompiler.backtrace(orig)
             msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
-            emit_error(B, orig, (msg2, rmi, world), ReverseRuleReturnError{C, Tuple{activity[2+isKWCall:end]...,}, rev_RT})
+            emit_error(B, orig, (msg2, rmi, world), ReverseRuleReturnError{C, Tuple{activity[(2 + isKWCall):end]...}, rev_RT})
             return tapeV
         end
         if length(actives) >= 1 &&
diff --git a/test/rules/inactive_kwrules.jl b/test/rules/inactive_kwrules.jl
index fc0c731c..62092abe 100644
--- a/test/rules/inactive_kwrules.jl
+++ b/test/rules/inactive_kwrules.jl
@@ -6,10 +6,10 @@ using Test
 
 import .EnzymeRules: forward, augmented_primal, reverse
 
-function f_kw(out; tmp=[2.0, 0.0])
+function f_kw(out; tmp = [2.0, 0.0])
     out[1] *= tmp[1]
     tmp[2] += 1
-    nothing
+    return nothing
 end
 
 function forward(config, ::Const{typeof(f_kw)}, ::Type{<:Const}, x::Duplicated; kwargs...)
@@ -29,14 +29,14 @@ function reverse(config, ::Const{typeof(f_kw)}, ::Type{<:Const}, tape, x::Duplic
 end
 
 function g_kw(out)
-    tmp=[2.0, 0.0]
+    tmp = [2.0, 0.0]
     f_kw(out; tmp)
-    nothing
+    return nothing
 end
 
 function h_kw(out, tmp)
     f_kw(out; tmp)
-    nothing
+    return nothing
 end
 
 @testset "Forward Inactive allocated kwarg error" begin
@@ -69,7 +69,7 @@ end
     @test_throws Enzyme.Compiler.NonConstantKeywordArgException autodiff(Forward, h_kw, Duplicated(x, dx), Duplicated(tmp, dtmp))
 end
 
-Enzyme.EnzymeRules.inactive_kwarg(::typeof(f_kw), out; tmp=[2.0]) = nothing
+Enzyme.EnzymeRules.inactive_kwarg(::typeof(f_kw), out; tmp = [2.0]) = nothing
 
 @testset "Forward Inactive allocated kwarg success" begin
     x = [2.7]
diff --git a/test/rules/mixederror.jl b/test/rules/mixederror.jl
index 128f4403..dfb71b07 100644
--- a/test/rules/mixederror.jl
+++ b/test/rules/mixederror.jl
@@ -30,9 +30,11 @@ function handle_infinities(workfunc, f, s)
                         den = 1 / (1 - t)
                         return f(s0 - oneunit(s1) * t * den) * den * den * oneunit(s1)
                     end,
-                    reverse(map(s) do x
-                        1 / (1 + oneunit(x) / (s0 - x))
-                    end),
+                    reverse(
+                        map(s) do x
+                            1 / (1 + oneunit(x) / (s0 - x))
+                        end
+                    ),
                     t -> s0 - oneunit(s1) * t / (1 - t),
                 )
             else # x = s0 + t / (1 - t)
@@ -60,8 +62,8 @@ function inner(f::F, xs) where {F}  # remove type annotation => problem solved
 end
 
 function EnzymeRules.augmented_primal(
-    config::EnzymeRules.RevConfig, ::Const{typeof(inner)}, ::Type, f, xs
-)
+        config::EnzymeRules.RevConfig, ::Const{typeof(inner)}, ::Type, f, xs
+    )
     true_primal = inner(f.val, xs.val)
     primal = EnzymeRules.needs_primal(config) ? true_primal : nothing
     shadow = if EnzymeRules.needs_shadow(config)
@@ -77,8 +79,8 @@ function EnzymeRules.augmented_primal(
 end
 
 function EnzymeRules.reverse(
-    ::EnzymeRules.RevConfig, ::Const{typeof(inner)}, shadow::Active, tape, f, xs
-)
+        ::EnzymeRules.RevConfig, ::Const{typeof(inner)}, shadow::Active, tape, f, xs
+    )
     return ((f isa Active) ? f : nothing, (xs isa Active) ? xs : nothing)
 end
 
@@ -90,4 +92,4 @@ F_bad(x) = outer(y -> [cos(y)], 0.0, x)[1][1]
     @test_throws Enzyme.Compiler.MixedReturnException autodiff(Reverse, F_bad, Active(0.3))
 end
 
-end # MixedRuleError
\ No newline at end of file
+end # MixedRuleError
diff --git a/test/rules/rules.jl b/test/rules/rules.jl
index 4ed0f483..51a1a24e 100644
--- a/test/rules/rules.jl
+++ b/test/rules/rules.jl
@@ -131,7 +131,7 @@ end
     @test Enzyme.autodiff(Forward, h, Duplicated(3.0, 1.0)) == (6000.0,)
     @test Enzyme.autodiff(ForwardWithPrimal, h, Duplicated(3.0, 1.0))  == (60.0, 9.0)
     @test Enzyme.autodiff(Forward, h2, Duplicated(3.0, 1.0))  == (1080.0,)
-    @test_throws Enzyme.Compiler.ForwardRuleReturnError Enzyme.autodiff(Forward, h3, Duplicated(3.0, 1.0)) 
+    @test_throws Enzyme.Compiler.ForwardRuleReturnError Enzyme.autodiff(Forward, h3, Duplicated(3.0, 1.0))
 end
 
 foo(x) = 2x;

@wsmoses wsmoses merged commit f518970 into main Nov 10, 2025
42 of 48 checks passed
@wsmoses wsmoses deleted the kwarge branch November 10, 2025 08:32
@codecov
Copy link

codecov bot commented Nov 10, 2025

Codecov Report

❌ Patch coverage is 12.74834% with 527 lines in your changes missing coverage. Please review.
✅ Project coverage is 68.89%. Comparing base (8bbf178) to head (76241bc).
⚠️ Report is 6 commits behind head on main.

Files with missing lines Patch % Lines
src/errors.jl 3.09% 376 Missing ⚠️
src/rules/customrules.jl 43.00% 57 Missing ⚠️
lib/EnzymeCore/src/rules.jl 0.00% 54 Missing ⚠️
ext/EnzymeChainRulesCoreExt.jl 10.52% 17 Missing ⚠️
lib/EnzymeCore/src/easyrules.jl 0.00% 13 Missing ⚠️
src/rules/llvmrules.jl 55.55% 8 Missing ⚠️
src/analyses/activity.jl 66.66% 1 Missing ⚠️
src/compiler.jl 50.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2761      +/-   ##
==========================================
- Coverage   70.30%   68.89%   -1.42%     
==========================================
  Files          58       58              
  Lines       19391    19861     +470     
==========================================
+ Hits        13633    13683      +50     
- Misses       5758     6178     +420     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant