Skip to content

Conversation

@wsmoses
Copy link
Member

@wsmoses wsmoses commented Nov 14, 2025

I haven't tested yet whether this works; I just did a sweep through parts of the code.

Custom derivatives also need a sweep.

@github-actions
Copy link
Contributor

github-actions bot commented Nov 14, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/absint.jl b/src/absint.jl
index 6e098fba..cce53e79 100644
--- a/src/absint.jl
+++ b/src/absint.jl
@@ -1,7 +1,7 @@
 # Abstractly interpret julia from LLVM
 
 # Return (bool if could interpret, julia object interpreted to)
-function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false, istracked::Bool=false)::Tuple{Bool, Any}
+function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false, istracked::Bool = false)::Tuple{Bool, Any}
     if (value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked)) || (value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Derived)) || istracked
         ce, _ = get_base_and_offset(arg; offsetAllowed = false, inttoptr = true)
         if isa(ce, GlobalVariable)
@@ -443,8 +443,8 @@ function abs_typeof(
         end
         # Type tag is arg 3
         if nm == "jl_alloc_genericmemory_unchecked" ||
-		nm == "ijl_alloc_genericmemory_unchecked"
-	    vals = absint(operands(arg)[3], partial, true)
+                nm == "ijl_alloc_genericmemory_unchecked"
+            vals = absint(operands(arg)[3], partial, true)
             return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
         end
         # Type tag is arg 1
diff --git a/src/compiler.jl b/src/compiler.jl
index 1eeb1438..59db45cf 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -1071,7 +1071,7 @@ end
     return
 end
 
-function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLVM.Function}, job, edges, run_enzyme, mode::API.CDerivativeMode)::Tuple{Dict{String,LLVM.API.LLVMLinkage}, HandlerState}
+function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLVM.Function}, job, edges, run_enzyme, mode::API.CDerivativeMode)::Tuple{Dict{String, LLVM.API.LLVMLinkage}, HandlerState}
 
     for f in functions(mod)
         if startswith(LLVM.name(f), "japi3") || startswith(LLVM.name(f), "japi1")
@@ -1088,7 +1088,7 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV
         dl = string(LLVM.datalayout(LLVM.parent(f)))
 
         expectLen = (sret !== nothing) + (returnRoots !== nothing)
-	for (source_typ, _) in rooted_argument_list(mi.specTypes.parameters)
+        for (source_typ, _) in rooted_argument_list(mi.specTypes.parameters)
             if isghostty(source_typ) || Core.Compiler.isconstType(source_typ)
                 continue
             end
@@ -1111,19 +1111,19 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV
         world = enzyme_extract_world(f)
 
         if expectLen != length(parameters(f))
-		msg = sprint() do io::IO
-		    println(io, "expectLen != length(parameters(f))")
-		    println(io, string(f))
-		    println(io, "expectLen=", string(expectLen))
-		    println(io, "swiftself=", string(swiftself))
-		    println(io, "sret=", string(sret))
-		    println(io, "returnRoots=", string(returnRoots))
-		    println(io, "mi.specTypes.parameters=", string(mi.specTypes.parameters))
-		    println(io, "retRemoved=", string(retRemoved))
-		    println(io, "parmsRemoved=", string(parmsRemoved))
-		    println(io, "rooted_argument_list=", string(rooted_argument_list(mi.specTypes.parameters)))
-		end
-		throw(CallingConventionMismatchError{String}(msg, mi, world))
+            msg = sprint() do io::IO
+                println(io, "expectLen != length(parameters(f))")
+                println(io, string(f))
+                println(io, "expectLen=", string(expectLen))
+                println(io, "swiftself=", string(swiftself))
+                println(io, "sret=", string(sret))
+                println(io, "returnRoots=", string(returnRoots))
+                println(io, "mi.specTypes.parameters=", string(mi.specTypes.parameters))
+                println(io, "retRemoved=", string(retRemoved))
+                println(io, "parmsRemoved=", string(parmsRemoved))
+                println(io, "rooted_argument_list=", string(rooted_argument_list(mi.specTypes.parameters)))
+            end
+            throw(CallingConventionMismatchError{String}(msg, mi, world))
         end
 
         jlargs = classify_arguments(
@@ -1230,7 +1230,7 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV
         #=lowerConvention=#true,
         #=loweredArgs=#Set{Int}(),
         #=boxedArgs=#Set{Int}(),
-	#=removedRoots=#Set{Int}(),
+        #=removedRoots=# Set{Int}(),
         #=fnsToInject=#Tuple{Symbol,Type}[],
     )
 
@@ -1713,10 +1713,10 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie
 	return
     end
     @static if VERSION >= v"1.11"
-	if Ty <: GenericMemory
-	    # TODO throw(AssertionError("What the heck is happening, why are we gc.alloca'ing memory, $(string(V)) $Ty"))
-	    return
-	end
+        if Ty <: GenericMemory
+            # TODO throw(AssertionError("What the heck is happening, why are we gc.alloca'ing memory, $(string(V)) $Ty"))
+            return
+        end
     end
 
     if mode == API.DEM_ForwardMode && (used || idx != 0)
@@ -2410,7 +2410,7 @@ end
 function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error::Bool = true)
     ty = nothing
     byref = nothing
-    for fattr in collect(parameter_attributes(fn, idx)   )
+    for fattr in collect(parameter_attributes(fn, idx))
         if isa(fattr, LLVM.StringAttribute)
             if kind(fattr) == "enzymejl_parmtype"
                 ptr = reinterpret(Ptr{Cvoid}, parse(UInt, LLVM.value(fattr)))
@@ -2451,7 +2451,7 @@ function enzyme!(
     @nospecialize(expectedTapeType::Type),
     loweredArgs::Set{Int},
     boxedArgs::Set{Int},
-    removedRoots::Set{Int},
+        removedRoots::Set{Int},
 )
     if DumpPreEnzyme[]
         API.EnzymeDumpModuleRef(mod.ref)
@@ -2485,7 +2485,7 @@ function enzyme!(
     end
 
     seen = TypeTreeTable()
-    
+
     seen_roots = 0
 
     for (i, T) in enumerate(TT.parameters)
@@ -2500,32 +2500,32 @@ function enzyme!(
             end
             continue
         end
-	isboxed = (i + seen_roots) in boxedArgs
-	inline_root = false
-	
+        isboxed = (i + seen_roots) in boxedArgs
+        inline_root = false
 
-	if inline_roots_type(eltype(T)) != 0
-	   # This is already after lower_convention
-	   seen_roots += 1
-	   if false
-	       inline_root = true
-	   end
-	end
+
+        if inline_roots_type(eltype(T)) != 0
+            # This is already after lower_convention
+            seen_roots += 1
+            if false
+                inline_root = true
+            end
+        end
 
         if T <: Const
             push!(args_activity, API.DFT_CONSTANT)
-	    if inline_root
-               push!(args_activity, API.DFT_CONSTANT)
-	    end
+            if inline_root
+                push!(args_activity, API.DFT_CONSTANT)
+            end
         elseif T <: Active
             if isboxed
-	    	@assert !inline_root
+                @assert !inline_root
                 push!(args_activity, API.DFT_DUP_ARG)
             else
                 push!(args_activity, API.DFT_OUT_DIFF)
-	        if inline_root
-                   push!(args_activity, API.DFT_CONSTANT)
-	        end
+                if inline_root
+                    push!(args_activity, API.DFT_CONSTANT)
+                end
             end
         elseif T <: Duplicated ||
                T <: BatchDuplicated ||
@@ -2533,14 +2533,14 @@ function enzyme!(
                T <: MixedDuplicated ||
                T <: BatchMixedDuplicated
             push!(args_activity, API.DFT_DUP_ARG)
-	    if inline_root
-               push!(args_activity, API.DFT_DUP_ARG)
-	    end
+            if inline_root
+                push!(args_activity, API.DFT_DUP_ARG)
+            end
         elseif T <: DuplicatedNoNeed || T <: BatchDuplicatedNoNeed
             push!(args_activity, API.DFT_DUP_NONEED)
-	    if inline_root
-               push!(args_activity, API.DFT_DUP_ARG)
-	    end
+            if inline_root
+                push!(args_activity, API.DFT_DUP_ARG)
+            end
         else
             error("illegal annotation type $T")
         end
@@ -2553,16 +2553,16 @@ function enzyme!(
         push!(args_typeInfo, typeTree)
         push!(uncacheable_args, modifiedBetween[i])
         push!(args_known_values, API.IntList())
-	if inline_root
-           typeTree = typetree(Any, ctx, dl, seen)
-           push!(args_typeInfo, typeTree)
-           push!(uncacheable_args, modifiedBetween[i])
-           push!(args_known_values, API.IntList())
-	end
+        if inline_root
+            typeTree = typetree(Any, ctx, dl, seen)
+            push!(args_typeInfo, typeTree)
+            push!(uncacheable_args, modifiedBetween[i])
+            push!(args_known_values, API.IntList())
+        end
     end
     if length(uncacheable_args) != length(collect(parameters(primalf)))
                 msg = sprint() do io
-		    println(io, "length(uncacheable_args) != length(collect(parameters(primalf))) ")
+            println(io, "length(uncacheable_args) != length(collect(parameters(primalf))) ")
 		    println(io, "TT=", TT)
                     println(io, "modifiedBetween=", modifiedBetween)
 		    println(io, "uncacheable_args=", uncacheable_args)
@@ -2930,10 +2930,10 @@ function create_abi_wrapper(
         isboxed = GPUCompiler.deserves_argbox(source_typ)
         llvmT = isboxed ? T_prjlvalue : convert(LLVMType, source_typ)
         push!(T_wrapperargs, llvmT)
-	arg_roots = inline_roots_type(source_typ)
-	if arg_rooting && arg_roots != 0
-	   push!(T_wrapperargs, convert(LLVMType, AnyArray(arg_roots)))
-	end
+        arg_roots = inline_roots_type(source_typ)
+        if arg_rooting && arg_roots != 0
+            push!(T_wrapperargs, convert(LLVMType, AnyArray(arg_roots)))
+        end
 
         if T <: Const || T <: BatchDuplicatedFunc
             if is_adjoint && i != 1
@@ -2952,19 +2952,19 @@ function create_abi_wrapper(
             end
         elseif T <: Duplicated || T <: DuplicatedNoNeed || T <: BatchDuplicated || T <: BatchDuplicatedNoNeed
             push!(T_wrapperargs, LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmT)))
-	    arg_roots = inline_roots_type(source_typ)
-	    if arg_rooting && arg_roots != 0
-	       push!(T_wrapperargs, convert(LLVMType, AnyArray(width * arg_roots)))
-	    end
+            arg_roots = inline_roots_type(source_typ)
+            if arg_rooting && arg_roots != 0
+                push!(T_wrapperargs, convert(LLVMType, AnyArray(width * arg_roots)))
+            end
             if is_adjoint && i != 1
                 push!(ActiveRetTypes, Nothing)
             end
         elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
             push!(T_wrapperargs, LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue)))
-	    arg_roots = inline_roots_type(source_typ)
-	    if arg_rooting && arg_roots != 0
-	       push!(T_wrapperargs, convert(LLVMType, AnyArray(width * arg_roots)))
-	    end
+            arg_roots = inline_roots_type(source_typ)
+            if arg_rooting && arg_roots != 0
+                push!(T_wrapperargs, convert(LLVMType, AnyArray(width * arg_roots)))
+            end
             if is_adjoint && i != 1
                 push!(ActiveRetTypes, Nothing)
             end
@@ -3016,10 +3016,10 @@ function create_abi_wrapper(
                 ),
             )
             push!(T_wrapperargs, dretTy)
-	    arg_roots = inline_roots_type(actualRetType)
-	    if arg_rooting && arg_roots != 0
-	       push!(T_wrapperargs, convert(LLVMType, AnyArray(width * arg_roots)))
-	    end
+            arg_roots = inline_roots_type(actualRetType)
+            if arg_rooting && arg_roots != 0
+                push!(T_wrapperargs, convert(LLVMType, AnyArray(width * arg_roots)))
+            end
         end
     end
 
@@ -3151,10 +3151,10 @@ function create_abi_wrapper(
             tape = LLVM.LLVMType(tape)
             jltape = convert(LLVM.LLVMType, Compiler.tape_type(tape); allow_boxed = true)
             push!(T_wrapperargs, jltape)
-	    arg_roots = inline_roots_type(tape)
-	    if arg_rooting && arg_roots != 0
-	       push!(T_wrapperargs, convert(LLVMType, AnyArray(arg_roots)))
-	    end
+            arg_roots = inline_roots_type(tape)
+            if arg_rooting && arg_roots != 0
+                push!(T_wrapperargs, convert(LLVMType, AnyArray(arg_roots)))
+            end
         else
             needs_tape = false
         end
@@ -3217,16 +3217,16 @@ function create_abi_wrapper(
 
         convty = convert(LLVMType, T′; allow_boxed = true)
 
-	arg_roots = inline_roots_type(T′)
+        arg_roots = inline_roots_type(T′)
 
         if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType))
             @assert Base.isconcretetype(T′)
             al0 = al = emit_allocobj!(builder, Base.RefValue{T′}, "mixedparameter")
-	    parm = params[i]
-	    if arg_rooting && arg_roots != 0
-		parm = recombine_value!(builder, parm, params[i+1])
-		i += 1
-	    end
+            parm = params[i]
+            if arg_rooting && arg_roots != 0
+                parm = recombine_value!(builder, parm, params[i + 1])
+                i += 1
+            end
             al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al))))
             store!(builder, parm, al)
             emit_writebarrier!(builder, get_julia_inner_types(builder, al0, parm))
@@ -3238,14 +3238,14 @@ function create_abi_wrapper(
 
         i += 1
         if T <: Const
-	    if arg_rooting && arg_roots != 0
-		 push(realparms, params[i])
-		 i += 1
-	    end
+            if arg_rooting && arg_roots != 0
+                push(realparms, params[i])
+                i += 1
+            end
         elseif T <: Active
             isboxed = GPUCompiler.deserves_argbox(T′)
             if isboxed
-		@assert arg_roots == 0
+                @assert arg_roots == 0
                 if is_split
                     msg = sprint() do io
                         println(
@@ -3282,55 +3282,55 @@ function create_abi_wrapper(
                     0,
                 )                                            #=align=#
             end
-	    if arg_rooting &&arg_roots != 0
-		 push(realparms, params[i])
-		 i += 1
-	    end
+            if arg_rooting &&arg_roots != 0
+                push(realparms, params[i])
+                i += 1
+            end
             activeNum += 1
         elseif T <: Duplicated || T <: DuplicatedNoNeed || T <: BatchDuplicated || T <: BatchDuplicatedNoNeed
-	    # Enzyme expects, arg, darg, root, droot
-	    # Julia expects   arg, root, darg, droot
-	    # We already pushed arg
-	    # now params[i] refers to root
-	    isboxed = (T <: BatchDuplicated || T <: BatchDuplicatedNoNeed) && GPUCompiler.deserves_argbox(NTuple{width,T′})
-	    darg = nothing
-	    root = nothing
-	    droot = nothing
-	    if arg_rooting &&arg_roots != 0
-		 root = params[i]
-		 darg = params[i+1]
-		 droot = params[i+2]
-		 i += 3
-	    else
-		 darg = params[i]
-		 i += 1
-	    end
-
-	    if isboxed
-	        darg = load!(builder, convert(LLVMType, NTuple{width,T′}), darg)
-	    end
-	    push!(realparms, darg)
-	    if arg_roots != 0
-		push!(realparms, root)
-		push!(realparms, droot)
-	    end
+            # Enzyme expects, arg, darg, root, droot
+            # Julia expects   arg, root, darg, droot
+            # We already pushed arg
+            # now params[i] refers to root
+            isboxed = (T <: BatchDuplicated || T <: BatchDuplicatedNoNeed) && GPUCompiler.deserves_argbox(NTuple{width, T′})
+            darg = nothing
+            root = nothing
+            droot = nothing
+            if arg_rooting &&arg_roots != 0
+                root = params[i]
+                darg = params[i + 1]
+                droot = params[i + 2]
+                i += 3
+            else
+                darg = params[i]
+                i += 1
+            end
+
+            if isboxed
+                darg = load!(builder, convert(LLVMType, NTuple{width, T′}), darg)
+            end
+            push!(realparms, darg)
+            if arg_roots != 0
+                push!(realparms, root)
+                push!(realparms, droot)
+            end
         elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
-	    # Enzyme expects, arg, [w x darg], root, droot
-	    # Julia expects   arg, root, darg, droot
-	    # We already pushed arg
-	    # now params[i] referrs to root
-	    darg = nothing
-	    root = nothing
-	    droot = nothing
-	    if arg_rooting && arg_roots != 0
-		 root = params[i]
-		 darg = params[i+1]
-		 droot = params[i+2]
-		 i += 3
-	    else
-		 darg = params[i]
-		 i += 1
-	    end
+            # Enzyme expects, arg, [w x darg], root, droot
+            # Julia expects   arg, root, darg, droot
+            # We already pushed arg
+            # now params[i] referrs to root
+            darg = nothing
+            root = nothing
+            droot = nothing
+            if arg_rooting && arg_roots != 0
+                root = params[i]
+                darg = params[i + 1]
+                droot = params[i + 2]
+                i += 3
+            else
+                darg = params[i]
+                i += 1
+            end
 
             if T <: BatchMixedDuplicated
                 @assert Base.isconcretetype(T′)
@@ -3362,16 +3362,16 @@ function create_abi_wrapper(
             end
 
             push!(realparms, ival)
-	    
-	    if arg_rooting && arg_roots != 0
-		push!(realparms, root)
-		push!(realparms, droot)
-	    end
+
+            if arg_rooting && arg_roots != 0
+                push!(realparms, root)
+                push!(realparms, droot)
+            end
         elseif T <: BatchDuplicatedFunc
-	    # TODO handle this
-	    if arg_rooting
-		 @assert arg_roots == 0
-	    end
+            # TODO handle this
+            if arg_rooting
+                @assert arg_roots == 0
+            end
             Func = get_func(T)
             funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, Func, Tuple{}, world)
             llvmf = nested_codegen!(Mode, mod, funcspec, world)
@@ -3709,7 +3709,7 @@ function create_abi_wrapper(
     end
 
     if returnRoots
-       move_sret_tofrom_roots!(builder, jltype, sret, root_ty, rootRet, SRetPointerToRootPointer)
+        move_sret_tofrom_roots!(builder, jltype, sret, root_ty, rootRet, SRetPointerToRootPointer)
     end
     if T_ret != T_void
         ret!(builder, load!(builder, T_ret, sret))
@@ -3772,125 +3772,128 @@ function fixup_metadata!(f::LLVM.Function)
     end
 end
 
-@enum(SRetRootMovement,
+@enum(
+    SRetRootMovement,
     SRetPointerToRootPointer = 0,
     SRetValueToRootPointer = 1,
     RootPointerToSRetValue = 2,
     RootPointerToSRetPointer = 3
-   )
+)
 
 function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, sret::LLVM.Value, root_ty::LLVM.LLVMType, rootRet::LLVM.Value, direction::SRetRootMovement)
-        count = 0
-        todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
-	    Cuint[],
+    count = 0
+    todo = Tuple{Vector{Cuint}, LLVM.LLVMType}[
+        (
+            Cuint[],
             jltype,
-        )]
-	function to_llvm(lst::Vector{Cuint})
-	    vals = LLVM.Value[]
-	    push!(vals, LLVM.ConstantInt(LLVM.IntType(64), 0))
-	    for i in lst
-	       push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
-	    end
-	    return vals
-	end
+        ),
+    ]
+    function to_llvm(lst::Vector{Cuint})
+        vals = LLVM.Value[]
+        push!(vals, LLVM.ConstantInt(LLVM.IntType(64), 0))
+        for i in lst
+            push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
+        end
+        return vals
+    end
 
-	extracted = LLVM.Value[]
-
-	val = sret
-	# TODO check that we perform this in the same order that extraction happens within julia
-	# aka bfs/etc
-        while length(todo) != 0
-            path, ty = popfirst!(todo)
-            if isa(ty, LLVM.PointerType)
-		if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
-                  loc = inbounds_gep!(
-                      builder,
-                      root_ty,
-                      rootRet,
-		      to_llvm(Cuint[count]),
-		     )
-		end
-                
-		if direction == SRetPointerToRootPointer
-		    outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
-		    outloc = load!(builder, ty, outloc)
-                    store!(builder, outloc, loc)
-		elseif direction == SRetValueToRootPointer
-		    outloc = Enzyme.API.e_extract_value!(builder, sret, path)
-                    store!(builder, outloc, loc)
-		elseif direction == RootPointerToSRetValue
-		    loc = load!(builder, ty, loc)
-		    sret = Enzyme.API.e_insert_value!(builder, sret, loc, path)
-		elseif direction == RootPointerToSRetPointer
-		    outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
-		    loc = load!(builder, ty, loc)
-		    push!(extracted, loc)
-                    store!(builder, loc, outloc)
-		else
-		    @assert false "Unhandled direction"
-		end
-                
-		count += 1
-                continue
+    extracted = LLVM.Value[]
+
+    val = sret
+    # TODO check that we perform this in the same order that extraction happens within julia
+    # aka bfs/etc
+    while length(todo) != 0
+        path, ty = popfirst!(todo)
+        if isa(ty, LLVM.PointerType)
+            if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
+                loc = inbounds_gep!(
+                    builder,
+                    root_ty,
+                    rootRet,
+                    to_llvm(Cuint[count]),
+                )
             end
-            if isa(ty, LLVM.ArrayType)
-                if any_jltypes(ty)
-                    for i = 1:length(ty)
-                        npath = copy(path)
-			push!(npath, i - 1)
-                        push!(todo, (npath, eltype(ty)))
-                    end
+
+            if direction == SRetPointerToRootPointer
+                outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
+                outloc = load!(builder, ty, outloc)
+                store!(builder, outloc, loc)
+            elseif direction == SRetValueToRootPointer
+                outloc = Enzyme.API.e_extract_value!(builder, sret, path)
+                store!(builder, outloc, loc)
+            elseif direction == RootPointerToSRetValue
+                loc = load!(builder, ty, loc)
+                sret = Enzyme.API.e_insert_value!(builder, sret, loc, path)
+            elseif direction == RootPointerToSRetPointer
+                outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
+                loc = load!(builder, ty, loc)
+                push!(extracted, loc)
+                store!(builder, loc, outloc)
+            else
+                @assert false "Unhandled direction"
+            end
+
+            count += 1
+            continue
+        end
+        if isa(ty, LLVM.ArrayType)
+            if any_jltypes(ty)
+                for i in 1:length(ty)
+                    npath = copy(path)
+                    push!(npath, i - 1)
+                    push!(todo, (npath, eltype(ty)))
                 end
-                continue
             end
-            if isa(ty, LLVM.VectorType)
-                if any_jltypes(ty)
-                    for i = 1:size(ty)
-                        npath = copy(path)
-			push!(npath, i - 1)
-                        push!(todo, (npath, eltype(ty)))
-                    end
+            continue
+        end
+        if isa(ty, LLVM.VectorType)
+            if any_jltypes(ty)
+                for i in 1:size(ty)
+                    npath = copy(path)
+                    push!(npath, i - 1)
+                    push!(todo, (npath, eltype(ty)))
                 end
-                continue
             end
-            if isa(ty, LLVM.StructType)
-                for (i, t) in enumerate(LLVM.elements(ty))
-                    if any_jltypes(t)
-                        npath = copy(path)
-			push!(npath, i - 1)
-                        push!(todo, (npath, t))
-                    end
+            continue
+        end
+        if isa(ty, LLVM.StructType)
+            for (i, t) in enumerate(LLVM.elements(ty))
+                if any_jltypes(t)
+                    npath = copy(path)
+                    push!(npath, i - 1)
+                    push!(todo, (npath, t))
                 end
-                continue
             end
+            continue
         end
+    end
 
-	if direction == RootPointerToSRetPointer	        
-	    obj = get_base_and_offset(sret)[1]
-	    @assert length(extracted) > 0
-	    emit_writebarrier!(builder, LLVM.Value[obj, extracted...])
-	end
-        tracked = CountTrackedPointers(jltype)
-        @assert count == tracked.count
-	return val
+    if direction == RootPointerToSRetPointer
+        obj = get_base_and_offset(sret)[1]
+        @assert length(extracted) > 0
+        emit_writebarrier!(builder, LLVM.Value[obj, extracted...])
+    end
+    tracked = CountTrackedPointers(jltype)
+    @assert count == tracked.count
+    return val
 end
 
 function recombine_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, roots::LLVM.Value)
-   jltype = value_type(sret)
-   tracked = CountTrackedPointers(jltype)
-   @assert tracked.count > 0
-   @assert !tracked.all
-   root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
-   move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, RootPointerToSRetValue)
+    jltype = value_type(sret)
+    tracked = CountTrackedPointers(jltype)
+    @assert tracked.count > 0
+    @assert !tracked.all
+    root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
+    return move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, RootPointerToSRetValue)
 end
 
 function extract_roots_from_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, roots::LLVM.Value)
-   jltype = value_type(sret)
-   tracked = CountTrackedPointers(jltype)
-   @assert tracked.count > 0
-   @assert !tracked.all
-   root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
-   move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, SRetValueToRootPointer)
+    jltype = value_type(sret)
+    tracked = CountTrackedPointers(jltype)
+    @assert tracked.count > 0
+    @assert !tracked.all
+    root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
+    return move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, SRetValueToRootPointer)
 end
 
 
@@ -3975,54 +3978,54 @@ function lower_convention(
     removedRoots = Set{Int}()
 
     function is_mixed(idx::Int)
-	if TT === nothing
-	   return false
-	end
-	if idx > length(TT.parameters)
-	   throw(AssertionError("TT=$TT, args=$args idx=$idx"))
-	end
-	return (
-                   TT.parameters[idx] <: MixedDuplicated ||
-                   TT.parameters[idx] <: BatchMixedDuplicated
-               ) &&
-               run_enzyme
+        if TT === nothing
+            return false
+        end
+        if idx > length(TT.parameters)
+            throw(AssertionError("TT=$TT, args=$args idx=$idx"))
+        end
+        return (
+            TT.parameters[idx] <: MixedDuplicated ||
+                TT.parameters[idx] <: BatchMixedDuplicated
+        ) &&
+            run_enzyme
     end
 
     for arg in args
         typ = arg.codegen.typ
-	
-	if arg.rooted_typ !== nothing
-
-	   # There cannot exist a root arg if the original arg was boxed
-	   @assert !GPUCompiler.deserves_argbox(arg.rooted_typ)
-	   
-	   # There only can exist a rooting if the original argument was a bits_ref
-	   @assert arg.rooted_cc == GPUCompiler.BITS_REF
-	   
-	   # If the original arg exists and was lowered to be a bits_ref, we will destroy
-	   # the extra rooted arg and recombine with the bits_ref
-	   if (arg.arg_i - 1) in loweredArgs
-	        push!(removedRoots, arg.arg_i)
-		continue
-	   end
-	   
-	   # If we are raising an argument to mixed, we will still destroy the extra rooted
-	   # arg and recombine with the bits ref
-	   if (arg.arg_i - 1) in boxedArgs
-		@assert is_mixed(arg.arg_jl_i)
-	        push!(removedRoots, arg.arg_i)
-		continue
-	   end
-
-	   @assert false "Unhandled rooted arg condition"
-	end
 
-	if GPUCompiler.deserves_argbox(arg.typ)
+        if arg.rooted_typ !== nothing
+
+            # There cannot exist a root arg if the original arg was boxed
+            @assert !GPUCompiler.deserves_argbox(arg.rooted_typ)
+
+            # There only can exist a rooting if the original argument was a bits_ref
+            @assert arg.rooted_cc == GPUCompiler.BITS_REF
+
+            # If the original arg exists and was lowered to be a bits_ref, we will destroy
+            # the extra rooted arg and recombine with the bits_ref
+            if (arg.arg_i - 1) in loweredArgs
+                push!(removedRoots, arg.arg_i)
+                continue
+            end
+
+            # If we are raising an argument to mixed, we will still destroy the extra rooted
+            # arg and recombine with the bits ref
+            if (arg.arg_i - 1) in boxedArgs
+                @assert is_mixed(arg.arg_jl_i)
+                push!(removedRoots, arg.arg_i)
+                continue
+            end
+
+            @assert false "Unhandled rooted arg condition"
+        end
+
+        if GPUCompiler.deserves_argbox(arg.typ)
             push!(boxedArgs, arg.arg_i)
             push!(wrapper_types, typ)
             push!(wrapper_attrs, LLVM.Attribute[])
         elseif arg.cc != GPUCompiler.BITS_REF
-	    if is_mixed(arg.arg_jl_i)
+            if is_mixed(arg.arg_jl_i)
                 push!(boxedArgs, arg.arg_i)
                 push!(raisedArgs, arg.arg_i)
                 push!(wrapper_types, LLVM.PointerType(typ, Derived))
@@ -4033,7 +4036,7 @@ function lower_convention(
             end
         else
             # bits ref, and not boxed
-	    if is_mixed(arg.arg_jl_i)
+            if is_mixed(arg.arg_jl_i)
                 push!(boxedArgs, arg.arg_i)
                 push!(wrapper_types, typ)
                 push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")])
@@ -4106,22 +4109,22 @@ function lower_convention(
             end
             for arg in args
                 parm = ops[arg.codegen.i]
-		if arg.arg_i in removedRoots
-		    if arg.rooted_arg_i in loweredArgs
-		        nops[end] = recombine_value!(builder, nops[end], parm)
-		    elseif arg.rooted_arg_i in raisedArgs
-			jltype = convert(LLVMType, arg.rooted_typ)
-			tracked = CountTrackedPointers(jltype)
-			@assert tracked.count > 0
-			@assert !tracked.all
-			root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
-			move_sret_tofrom_roots!(builder, jltype, nops[end], root_ty, parm, RootPointerToSRetPointer)
-		    else
-			@assert false
-		    end
-		elseif (arg.arg_i) in removedRoots && (arg.rooted_arg_i in loweredArgs || arg)
-		    continue
-		elseif arg.arg_i in loweredArgs
+                if arg.arg_i in removedRoots
+                    if arg.rooted_arg_i in loweredArgs
+                        nops[end] = recombine_value!(builder, nops[end], parm)
+                    elseif arg.rooted_arg_i in raisedArgs
+                        jltype = convert(LLVMType, arg.rooted_typ)
+                        tracked = CountTrackedPointers(jltype)
+                        @assert tracked.count > 0
+                        @assert !tracked.all
+                        root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
+                        move_sret_tofrom_roots!(builder, jltype, nops[end], root_ty, parm, RootPointerToSRetPointer)
+                    else
+                        @assert false
+                    end
+                elseif (arg.arg_i) in removedRoots && (arg.rooted_arg_i in loweredArgs || arg)
+                    continue
+                elseif arg.arg_i in loweredArgs
                     push!(nops, load!(builder, convert(LLVMType, arg.typ), parm))
                 elseif arg.arg_i in raisedArgs
                     obj = emit_allocobj!(builder, arg.typ, "raisedArg")
@@ -4131,10 +4134,10 @@ function lower_convention(
                         LLVM.PointerType(value_type(parm), addrspace(value_type(obj))),
                     )
                     store!(builder, parm, bc)
-		    if !(arg.arg_i in removedRoots)
+                    if !(arg.arg_i in removedRoots)
                         emit_writebarrier!(builder, get_julia_inner_types(builder, obj, parm))
-		    end
-		    addr = addrspacecast!(
+                    end
+                    addr = addrspacecast!(
                         builder,
                         bc,
                         LLVM.PointerType(value_type(parm), Derived),
@@ -4178,7 +4181,7 @@ function lower_convention(
         wrapper_args = Vector{LLVM.Value}()
 
         sretPtr = nothing
-	retRootPtr = nothing
+        retRootPtr = nothing
         dl = string(LLVM.datalayout(LLVM.parent(entry_f)))
         if sret
             if !in(0, parmsRemoved)
@@ -4213,39 +4216,39 @@ function lower_convention(
         end
 
         # perform argument conversions
-	wrapper_idx = 1
+        wrapper_idx = 1
         for arg in args
             parm = parameters(entry_f)[arg.codegen.i]
-	    if arg.arg_i in removedRoots
-	    	wrapparm = parameters(wrapper_f)[wrapper_idx - 1]
-		root_ty = convert(LLVMType, arg.typ)
-		ptr = alloca!(builder, root_ty, LLVM.name(parm)*".innerparm")
+            if arg.arg_i in removedRoots
+                wrapparm = parameters(wrapper_f)[wrapper_idx - 1]
+                root_ty = convert(LLVMType, arg.typ)
+                ptr = alloca!(builder, root_ty, LLVM.name(parm) * ".innerparm")
                 if TT !== nothing && TT.parameters[arg.arg_jl_i] <: Const
                     metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[])
                 end
-                
+
                 ctx = LLVM.context(entry_f)
-		typeTree = copy(typetree(arg.typ, ctx, dl, seen))
+                typeTree = copy(typetree(arg.typ, ctx, dl, seen))
                 merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
                 only!(typeTree, -1)
                 metadata(ptr)["enzyme_type"] = to_md(typeTree, ctx)
-	
-		if arg.arg_i-1 in loweredArgs
-		   extract_roots_from_value!(builder, wrapparm, ptr)
-		else
-	           @assert (arg.arg_i - 1) in boxedArgs
-		   @assert is_mixed(arg.arg_jl_i) 
-		   jltype = convert(LLVMType, arg.rooted_typ)
-		   move_sret_tofrom_roots!(builder, jltype, wrapparm, root_ty, ptr, SRetPointerToRootPointer)
-	        end
+
+                if arg.arg_i - 1 in loweredArgs
+                    extract_roots_from_value!(builder, wrapparm, ptr)
+                else
+                    @assert (arg.arg_i - 1) in boxedArgs
+                    @assert is_mixed(arg.arg_jl_i)
+                    jltype = convert(LLVMType, arg.rooted_typ)
+                    move_sret_tofrom_roots!(builder, jltype, wrapparm, root_ty, ptr, SRetPointerToRootPointer)
+                end
 
                 push!(wrapper_args, ptr)
-		continue
-	    end
+                continue
+            end
 
-	    wrapparm = parameters(wrapper_f)[wrapper_idx]
-	    wrapper_idx += 1
-	    if arg.arg_i in loweredArgs
+            wrapparm = parameters(wrapper_f)[wrapper_idx]
+            wrapper_idx += 1
+            if arg.arg_i in loweredArgs
                 # copy the argument value to a stack slot, and reference it.
                 ty = value_type(parm)
                 if !isa(ty, LLVM.PointerType)
@@ -4285,7 +4288,7 @@ function lower_convention(
                     ),
                 )
                 push!(
-		    parameter_attributes(wrapper_f, wrapper_idx - 1),
+                    parameter_attributes(wrapper_f, wrapper_idx - 1),
                     StringAttribute(
                         "enzymejl_parmtype",
                         string(convert(UInt, unsafe_to_pointer(arg.typ))),
@@ -4306,7 +4309,7 @@ function lower_convention(
                 merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
                 only!(typeTree, -1)
                 push!(
-		    parameter_attributes(wrapper_f, wrapper_idx - 1),
+                    parameter_attributes(wrapper_f, wrapper_idx - 1),
                     StringAttribute(
                         "enzyme_type",
                         string(typeTree),
@@ -4330,7 +4333,7 @@ function lower_convention(
                 push!(wrapper_args, wrapparm)
                 for attr in collect(parameter_attributes(entry_f, arg.codegen.i))
                     push!(
-			  parameter_attributes(wrapper_f, wrapper_idx - 1),
+                        parameter_attributes(wrapper_f, wrapper_idx - 1),
                         attr,
                     )
                 end
@@ -4467,13 +4470,13 @@ function lower_convention(
                         string(UInt(GPUCompiler.BITS_REF)),
                     ),
                 )
-		res = load!(builder, RT, sretPtr)
-		@static if VERSION >= v"1.12"
-            	   if returnRoots
-		     res = recombine_value!(builder, res, retRootPtr)
-		   end
-		end
-		ret!(builder, res)
+                res = load!(builder, RT, sretPtr)
+                @static if VERSION >= v"1.12"
+                    if returnRoots
+                        res = recombine_value!(builder, res, retRootPtr)
+                    end
+                end
+                ret!(builder, res)
             end
         elseif LLVM.return_type(entry_ft) == LLVM.VoidType()
             ret!(builder)
@@ -4627,25 +4630,25 @@ function lower_convention(
             println(io, string(wrapper_f))
             println(
                 io,
-		"TT=$TT\n",
+                "TT=$TT\n",
                 "parmsRemoved=",
                 parmsRemoved,
                 "\nretRemoved=",
                 retRemoved,
                 "\nprargs=",
                 prargs,
-		"\nreturnRoots=",
-		returnRoots,
-		"\nboxedArgs=",
-		boxedArgs,
-		"\nloweredArgs=",
-		loweredArgs,
-		"\nraisedArgs=",
-		raisedArgs,
-		"\nremovedRoots=",
-		removedRoots,
-		"\nloweredReturn=",
-		loweredReturn
+                "\nreturnRoots=",
+                returnRoots,
+                "\nboxedArgs=",
+                boxedArgs,
+                "\nloweredArgs=",
+                loweredArgs,
+                "\nraisedArgs=",
+                raisedArgs,
+                "\nremovedRoots=",
+                removedRoots,
+                "\nloweredReturn=",
+                loweredReturn
             )
             println(io, "Broken lower convention")
         end
@@ -4736,8 +4739,8 @@ function lower_convention(
     LLVM.@dispose pb = NewPMPassBuilder() begin
         add!(pb, NewPMModulePassManager()) do mpm
             # Kill the temporary staging function
-	    add!(mpm, GlobalDCEPass())
-	    add!(mpm, GlobalOptPass())
+            add!(mpm, GlobalDCEPass())
+            add!(mpm, GlobalOptPass())
         end
         LLVM.run!(pb, mod)
     end
@@ -4963,7 +4966,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
             end
         end
         GPUCompiler.@safe_warn "Using fallback BLAS replacements for ($found), performance may be degraded"
-	run!(GlobalOptPass(), mod)
+        run!(GlobalOptPass(), mod)
     end
 
     custom, state = set_module_types!(interp, mod, primalf, job, edges, params.run_enzyme, mode)
@@ -5267,16 +5270,16 @@ end
                         )
                     )
 
-			 size = Compiler.datatype_layoutsize(jTy)
+                        size = Compiler.datatype_layoutsize(jTy)
                         if offset < size && isa(sz, LLVM.ConstantInt) && size - offset >= convert(Int, sz)
                             lim = convert(Int, sz)
                             md = to_fullmd(jTy, offset, lim)
                             @assert byref == GPUCompiler.BITS_REF ||
                                     byref == GPUCompiler.MUT_REF
                             metadata(inst)["enzyme_truetype"] = md
-			elseif byref == GPUCompiler.BITS_VALUE && jTy <: Ptr && eltype(jTy) == Any
-			    # Todo generalize this
-			    md = to_fullmd(jTy, 0, sizeof(Ptr{Cvoid}))
+                        elseif byref == GPUCompiler.BITS_VALUE && jTy <: Ptr && eltype(jTy) == Any
+                            # Todo generalize this
+                            md = to_fullmd(jTy, 0, sizeof(Ptr{Cvoid}))
                             metadata(inst)["enzyme_truetype"] = md
                         end
                     end
@@ -5394,9 +5397,9 @@ end
                                nm == "ijl_new_array" ||
                                nm == "jl_new_array" ||
                                nm == "jl_alloc_genericmemory" ||
-                               nm == "ijl_alloc_genericmemory" ||
-			       nm == "jl_alloc_genericmemory_unchecked" ||
-			       nm == "ijl_alloc_genericmemory_unchecked"
+                                    nm == "ijl_alloc_genericmemory" ||
+                                    nm == "jl_alloc_genericmemory_unchecked" ||
+                                    nm == "ijl_alloc_genericmemory_unchecked"
                                 continue
                             end
                             if is_readonly(called)
@@ -5470,7 +5473,7 @@ end
             expectedTapeType,
             loweredArgs,
             boxedArgs,
-	    removedRoots,
+            removedRoots,
         )
         toremove = String[]
         # Inline the wrapper
@@ -5871,7 +5874,7 @@ const DumpLLVMCall = Ref(false)
             error("Return type `$rrt` not marked Const, but is ghost or const type.")
         end
 
-	needs_rooting = false
+    needs_rooting = false
 
         sret_types = Type[]  # Julia types of all returned variables
         # By ref values we create and need to preserve
@@ -6139,14 +6142,14 @@ const DumpLLVMCall = Ref(false)
         end
 
         # calls fptr
-	llvmtys = LLVMType[]
-	for x in types
-	   push!(llvmtys, convert(LLVMType, x; allow_boxed = true))
-	   arg_roots = inline_roots_type(x)
-	   if needs_rooting && arg_roots != 0
-	       push!(llvmtys, convert(LLVMType, AnyArray(3)))
-	   end
-	end
+        llvmtys = LLVMType[]
+        for x in types
+            push!(llvmtys, convert(LLVMType, x; allow_boxed = true))
+            arg_roots = inline_roots_type(x)
+            if needs_rooting && arg_roots != 0
+                push!(llvmtys, convert(LLVMType, AnyArray(3)))
+            end
+        end
 
         T_void = convert(LLVMType, Nothing)
 
@@ -6209,11 +6212,11 @@ const DumpLLVMCall = Ref(false)
             tape = callparams[end]
             if TapeType <: EnzymeTapeToLoad
                 llty = Compiler.from_tape_type(eltype(TapeType))
-	        
-		arg_roots = inline_roots_type(llty)
-	        if needs_rooting && arg_roots != 0
-		   throw(AssertionError("Should check about rooted tape calling conv"))
-	        end
+
+                arg_roots = inline_roots_type(llty)
+                if needs_rooting && arg_roots != 0
+                    throw(AssertionError("Should check about rooted tape calling conv"))
+                end
 
                 tape = bitcast!(
                     builder,
@@ -6226,13 +6229,13 @@ const DumpLLVMCall = Ref(false)
 
             else
                 llty = Compiler.from_tape_type(TapeType)
-	        arg_roots = inline_roots_type(llty)
-	        if needs_rooting && arg_roots != 0
-		   tape = callparams[end-1]
-	        end
-		if value_type(tape) != llty
-		   throw(AssertionError("MisMatched Tape type, expected $(string(value_type(tape))) found $(string(llty)) from $TapeType arg_roots=$arg_roots"))
-		end
+                arg_roots = inline_roots_type(llty)
+                if needs_rooting && arg_roots != 0
+                    tape = callparams[end - 1]
+                end
+                if value_type(tape) != llty
+                    throw(AssertionError("MisMatched Tape type, expected $(string(value_type(tape))) found $(string(llty)) from $TapeType arg_roots=$arg_roots"))
+                end
             end
         end
 
@@ -6275,9 +6278,9 @@ const DumpLLVMCall = Ref(false)
         end
         reinsert_gcmarker!(llvm_f)
 
-	if DumpLLVMCall[]
-	   API.EnzymeDumpModuleRef(mod.ref)
-	end
+        if DumpLLVMCall[]
+            API.EnzymeDumpModuleRef(mod.ref)
+        end
 
         ir = string(mod)
         fn = LLVM.name(llvm_f)
diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl
index 13518d4a..744a2d10 100644
--- a/src/compiler/optimize.jl
+++ b/src/compiler/optimize.jl
@@ -158,16 +158,16 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
         run!(pb, mod, tm)
     end
     end # middle_optimize!
-    
-    run!(GCInvariantVerifierPass(strong=false), mod)
+
+    run!(GCInvariantVerifierPass(strong = false), mod)
 
     middle_optimize!()
-    
-    run!(GCInvariantVerifierPass(strong=false), mod)
-    
+
+    run!(GCInvariantVerifierPass(strong = false), mod)
+
     middle_optimize!(true)
-    
-    run!(GCInvariantVerifierPass(strong=false), mod)
+
+    run!(GCInvariantVerifierPass(strong = false), mod)
 
     # Globalopt is separated as it can delete functions, which invalidates the Julia hardcoded pointers to
     # known functions
@@ -185,20 +185,20 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
         end
         run!(pb, mod, tm)
     end
-    
-    run!(GCInvariantVerifierPass(strong=false), mod)
-    
+
+    run!(GCInvariantVerifierPass(strong = false), mod)
+
     removeDeadArgs!(mod, tm)
-    
-    run!(GCInvariantVerifierPass(strong=false), mod)
+
+    run!(GCInvariantVerifierPass(strong = false), mod)
 
     detect_writeonly!(mod)
-    
-    run!(GCInvariantVerifierPass(strong=false), mod)
-    
+
+    run!(GCInvariantVerifierPass(strong = false), mod)
+
     nodecayed_phis!(mod)
-                
-    run!(GCInvariantVerifierPass(strong=false), mod)
+
+    return run!(GCInvariantVerifierPass(strong = false), mod)
 end
 
 function addOptimizationPasses!(mpm::LLVM.NewPMPassManager)
diff --git a/src/errors.jl b/src/errors.jl
index 81a39046..22f553cf 100644
--- a/src/errors.jl
+++ b/src/errors.jl
@@ -184,11 +184,11 @@ function Base.showerror(io::IO, ece::CallingConventionMismatchError)
     println(io)
 
 
-    if VERBOSE_ERRORS[]
+    return if VERBOSE_ERRORS[]
         if ece.backtrace isa Cstring
-	   Base.println(io, Base.unsafe_string(ece.backtrace))
+            Base.println(io, Base.unsafe_string(ece.backtrace))
         else
-	   Base.println(io, ece.backtrace)
+            Base.println(io, ece.backtrace)
         end
     else
         print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n")
@@ -1252,41 +1252,41 @@ else
 		    end
                 end
 end
-                
-		if isa(cur, LLVM.LoadInst)
-                    larg, off = get_base_and_offset(operands(cur)[1])
-		    if off == 0 && isa(larg, LLVM.AllocaInst)
-			 legal = true
-			 for u in LLVM.uses(larg)
-			    u = LLVM.user(u)
-			    if isa(u, LLVM.LoadInst)
-				continue
-			    end
-			    if isa(u, LLVM.CallInst) && isa(called_operand(u), LLVM.Function)
-			       intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(u))
-			       if intr == LLVM.Intrinsic("llvm.lifetime.start").id || intr == LLVM.Intrinsic("llvm.lifetime.end").id || LLVM.name(called_operand(u)) == "llvm.enzyme.lifetime_end" || LLVM.name(called_operand(u)) ==
- "llvm.enzyme.lifetime_start"
-				    continue
-			       end
-			    end
-			    if isa(u, LLVM.StoreInst)
-				 v = operands(u)[1]
-				 if v == larg
-				    legal = false;
-				    break
-				 end
-				 if v isa ConstantInt && convert(Int, v) == -1
-				    continue
-				 end
-			    end
-			    legal = false
-			    break
-			 end
-			 if legal
-			    return make_batched(ncur, prevbb)
-			 end
-		    end
-		end
+
+            if isa(cur, LLVM.LoadInst)
+                larg, off = get_base_and_offset(operands(cur)[1])
+                if off == 0 && isa(larg, LLVM.AllocaInst)
+                    legal = true
+                    for u in LLVM.uses(larg)
+                        u = LLVM.user(u)
+                        if isa(u, LLVM.LoadInst)
+                            continue
+                        end
+                        if isa(u, LLVM.CallInst) && isa(called_operand(u), LLVM.Function)
+                            intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(u))
+                            if intr == LLVM.Intrinsic("llvm.lifetime.start").id || intr == LLVM.Intrinsic("llvm.lifetime.end").id || LLVM.name(called_operand(u)) == "llvm.enzyme.lifetime_end" || LLVM.name(called_operand(u)) ==
+                                    "llvm.enzyme.lifetime_start"
+                                continue
+                            end
+                        end
+                        if isa(u, LLVM.StoreInst)
+                            v = operands(u)[1]
+                            if v == larg
+                                legal = false
+                                break
+                            end
+                            if v isa ConstantInt && convert(Int, v) == -1
+                                continue
+                            end
+                        end
+                        legal = false
+                        break
+                    end
+                    if legal
+                        return make_batched(ncur, prevbb)
+                    end
+                end
+            end
 
             legal, TT, byref = abs_typeof(cur, true)
 
diff --git a/src/gradientutils.jl b/src/gradientutils.jl
index dec9f7d5..8a9d4fe5 100644
--- a/src/gradientutils.jl
+++ b/src/gradientutils.jl
@@ -317,7 +317,7 @@ function batch_call_same_with_inverted_arg_if_active!(
     args::Vector{<:LLVM.Value},
     valTys::Vector{API.CValueType},
     lookup::Bool;
-    need_result = true,
+        need_result = true,
     kwargs...
 )
 
@@ -340,7 +340,7 @@ function batch_call_same_with_inverted_arg_if_active!(
                 end
             end
         end
-        res = call_same_with_inverted_arg_if_active!(B, gutils, orig, args2, valTys, lookup; need_result, kwargs..., movebefore=idx == 1)
+        res = call_same_with_inverted_arg_if_active!(B, gutils, orig, args2, valTys, lookup; need_result, kwargs..., movebefore = idx == 1)
         if shadow === nothing
             continue
         end
diff --git a/src/jlrt.jl b/src/jlrt.jl
index b806ac50..ebf12393 100644
--- a/src/jlrt.jl
+++ b/src/jlrt.jl
@@ -889,9 +889,9 @@ end
 function emit_type_layout_elsz!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value))
 	legal, JTy = absint(ty)
 	if legal
-	    @assert JTy isa Type
-	    res = Compiler.datatype_layoutsize(JTy)
-	    return LLVM.ConstantInt(res)
+        @assert JTy isa Type
+        res = Compiler.datatype_layoutsize(JTy)
+        return LLVM.ConstantInt(res)
 	end
 
 	ty = emit_layout_of_type!(B, ty)
@@ -957,20 +957,20 @@ function get_memory_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value))
         end
 
         if nm in (
-            "jl_alloc_genericmemory",
-            "ijl_alloc_genericmemory",
-        )
-                res = operands(array)[2]
+                "jl_alloc_genericmemory",
+                "ijl_alloc_genericmemory",
+            )
+            res = operands(array)[2]
                 return res
         end
         if nm in (
-	     "jl_alloc_genericmemory_unchecked",
-	     "ijl_alloc_genericmemory_unchecked",
-	    )
-	        # This is number of bytes not number of elements
-		res = get_memory_size(B, array)
-		es = get_memory_elsz(B, array)
-		return udiv!(B, res, es)
+                "jl_alloc_genericmemory_unchecked",
+                "ijl_alloc_genericmemory_unchecked",
+            )
+            # This is number of bytes not number of elements
+            res = get_memory_size(B, array)
+            es = get_memory_elsz(B, array)
+            return udiv!(B, res, es)
         end
     end
     ST = get_memory_struct()
@@ -992,23 +992,23 @@ end
 
 # nel - number of elements
 #
-@static if VERSION >= v"1.11" 
-function get_memory_nbytes(B::LLVM.IRBuilder, memty::Type{<:Memory}, nel::LLVM.Value)
-    elsz = LLVM.ConstantInt(Compiler.datatype_layoutsize(memty))
-    isboxed = Base.datatype_arrayelem(memty) == 1
-    isunion = Base.datatype_arrayelem(memty) == 2
-
-    if isboxed
-        elsz = LLVM.ConstantInt(sizeof(Ptr{Cvoid}))
-    end
-    nbytes = LLVM.mul!(B, nel, elsz)
+@static if VERSION >= v"1.11"
+    function get_memory_nbytes(B::LLVM.IRBuilder, memty::Type{<:Memory}, nel::LLVM.Value)
+        elsz = LLVM.ConstantInt(Compiler.datatype_layoutsize(memty))
+        isboxed = Base.datatype_arrayelem(memty) == 1
+        isunion = Base.datatype_arrayelem(memty) == 2
+
+        if isboxed
+            elsz = LLVM.ConstantInt(sizeof(Ptr{Cvoid}))
+        end
+        nbytes = LLVM.mul!(B, nel, elsz)
 
-    if isunion
-        # an extra byte for each isbits union memory element, stored at m->ptr + m->length
-	nbytes = LLVM.add!(B, nbytes, nel)
+        if isunion
+            # an extra byte for each isbits union memory element, stored at m->ptr + m->length
+            nbytes = LLVM.add!(B, nbytes, nel)
+        end
+        return nbytes
     end
-    return nbytes
-end
 end
 
 function get_memory_nbytes(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value))
@@ -1019,12 +1019,12 @@ function get_memory_nbytes(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value))
             nm = LLVM.name(fn)
         end
         if nm in (
-	     "jl_alloc_genericmemory_unchecked",
-	     "ijl_alloc_genericmemory_unchecked",
-	    )
-	        # This is number of bytes not number of elements
-                res = operands(array)[2]
-		return res
+                "jl_alloc_genericmemory_unchecked",
+                "ijl_alloc_genericmemory_unchecked",
+            )
+            # This is number of bytes not number of elements
+            res = operands(array)[2]
+            return res
         end
     end
     nel = get_memory_len(B, array)
diff --git a/src/llvm/attributes.jl b/src/llvm/attributes.jl
index c384a2b6..b0eab84b 100644
--- a/src/llvm/attributes.jl
+++ b/src/llvm/attributes.jl
@@ -747,14 +747,14 @@ function annotate!(mod::LLVM.Module)
         "ijl_gc_alloc_typed",
         "jl_alloc_genericmemory",
         "ijl_alloc_genericmemory",
-	"jl_alloc_genericmemory_unchecked",
-	"ijl_alloc_genericmemory_unchecked",
+            "jl_alloc_genericmemory_unchecked",
+            "ijl_alloc_genericmemory_unchecked",
         "jl_alloc_array_1d",
         "jl_alloc_array_2d",
         "jl_alloc_array_3d",
         "ijl_alloc_array_1d",
         "ijl_alloc_array_2d",
-        "ijl_alloc_array_3d",
+            "ijl_alloc_array_3d",
         "ijl_new_array",
         "jl_new_array"
     )
@@ -808,8 +808,8 @@ function annotate!(mod::LLVM.Module)
         "ijl_box_int64",
         "jl_alloc_genericmemory",
         "ijl_alloc_genericmemory",
-	"jl_alloc_genericmemory_unchecked",
-	"ijl_alloc_genericmemory_unchecked",
+            "jl_alloc_genericmemory_unchecked",
+            "ijl_alloc_genericmemory_unchecked",
         "jl_alloc_array_1d",
         "jl_alloc_array_2d",
         "jl_alloc_array_3d",
@@ -821,7 +821,7 @@ function annotate!(mod::LLVM.Module)
         "jl_genericmemory_slice",
         "ijl_genericmemory_slice",
         "jl_genericmemory_copy_slice",
-        "ijl_genericmemory_copy_slice",
+            "ijl_genericmemory_copy_slice",
         "jl_idtable_rehash",
         "ijl_idtable_rehash",
         "jl_f_tuple",
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index 5bd0124a..bf7416b5 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -2623,9 +2623,9 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine)
                 add!(fpm, AllocOptPass())
                 add!(fpm, SROAPass())
             end
-	    if RunAttributor[]
+            if RunAttributor[]
                 add!(mpm, EnzymeAttributorPass())
-	    end
+            end
             add!(mpm, NewPMFunctionPassManager()) do fpm
                 add!(fpm, EarlyCSEPass())
             end
diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl
index b0094dc9..7a87f26b 100644
--- a/src/rules/activityrules.jl
+++ b/src/rules/activityrules.jl
@@ -79,7 +79,7 @@ function julia_activity_rule(f::LLVM.Function, method_table)
             typ, _ = enzyme_extract_parm_type(f, arg.codegen.i)
             @assert typ == arg.typ
 
-	    if (kwarg_inactive && arg.arg_i == 2) || guaranteed_const_nongen(arg.typ, world) || (arg.rooted_typ !== nothing && guaranteed_const_nongen(arg.rooted_typ, world))
+            if (kwarg_inactive && arg.arg_i == 2) || guaranteed_const_nongen(arg.typ, world) || (arg.rooted_typ !== nothing && guaranteed_const_nongen(arg.rooted_typ, world))
                 push!(
                     parameter_attributes(f, arg.codegen.i),
                     StringAttribute("enzyme_inactive"),
diff --git a/src/rules/allocrules.jl b/src/rules/allocrules.jl
index c2220f99..3b6e2955 100644
--- a/src/rules/allocrules.jl
+++ b/src/rules/allocrules.jl
@@ -20,26 +20,26 @@ function array_shadow_handler(
             ),
         )
     end
-    
+
 
     b = LLVM.IRBuilder(B)
     orig = LLVM.Value(OrigCI)::LLVM.CallInst
 
     nm = LLVM.name(LLVM.called_operand(orig)::LLVM.Function)
-    
+
     if iszeroinit(typ)
-	# If already zero init we should not need to perform the initial memset.
-	# However as I have yet to actually see such a type exist in the wild, I want to see
-	# what triggers it.
+        # If already zero init we should not need to perform the initial memset.
+        # However as I have yet to actually see such a type exist in the wild, I want to see
+        # what triggers it.
         throw(
             AssertionError(
-	        "THERE IS A TYPE WHICH IS ZERO INIT ($typ)",
+                "THERE IS A TYPE WHICH IS ZERO INIT ($typ)",
             ),
         )
-	# Only the regular, checked version does the zero.
-    	if nm == "jl_alloc_genericmemory" || nm == "ijl_alloc_genericmemory"
-	   return C_NULL
-	end
+        # Only the regular, checked version does the zero.
+        if nm == "jl_alloc_genericmemory" || nm == "ijl_alloc_genericmemory"
+            return C_NULL
+        end
     end
 
     typ = eltype(typ)
@@ -54,7 +54,7 @@ function array_shadow_handler(
     end
 
     anti = call_samefunc_with_inverted_bundles!(b, gutils, orig, vals, valTys, false) #=lookup=#
-    
+
     isunboxed, elsz, al = Base.uniontype_layout(typ)
 
     isunion = typ isa Union
@@ -73,17 +73,17 @@ function array_shadow_handler(
         get_memory_nbytes(b, anti)
     else
         arlen = get_array_len(b, anti)
-    	tot = LLVM.mul!(b, arlen, LLVM.ConstantInt(LLVM.value_type(arlen), elsz, false))
-    
-	if elsz == 1 && !isunion
-	   # extra byte for all julia allocated byte arrays
-	   tot = LLVM.add!(b, tot, LLVM.ConstantInt(LLVM.value_type(tot), 1, false))
-	end
-	if isunion
-	    # an extra byte for each isbits union array element, stored after a->maxsize
-	    tot = LLVM.add!(b, tot, prod)
-	end
-	tot
+        tot = LLVM.mul!(b, arlen, LLVM.ConstantInt(LLVM.value_type(arlen), elsz, false))
+
+        if elsz == 1 && !isunion
+            # extra byte for all julia allocated byte arrays
+            tot = LLVM.add!(b, tot, LLVM.ConstantInt(LLVM.value_type(tot), 1, false))
+        end
+        if isunion
+            # an extra byte for each isbits union array element, stored after a->maxsize
+            tot = LLVM.add!(b, tot, prod)
+        end
+        tot
     end
 
 
@@ -114,7 +114,7 @@ end
          "jl_alloc_array_3d", "ijl_alloc_array_3d",
          "jl_new_array", "ijl_new_array",
          "jl_alloc_genericmemory", "ijl_alloc_genericmemory",
-	 "jl_alloc_genericmemory_unchecked", "ijl_alloc_genericmemory_unchecked"
+            "jl_alloc_genericmemory_unchecked", "ijl_alloc_genericmemory_unchecked",
         ),
         @cfunction(
             array_shadow_handler,
diff --git a/src/typeutils/jltypes.jl b/src/typeutils/jltypes.jl
index 72146f2e..3026ea65 100644
--- a/src/typeutils/jltypes.jl
+++ b/src/typeutils/jltypes.jl
@@ -2,14 +2,14 @@
 iszeroinit(Base.@nospecialize t) = (Base.@_total_meta; isa(t, DataType) && (t.flags & 0x0004) == 0x0004)
 
 @static if VERSION >= v"1.11"
-const datatype_layoutsize = Base.datatype_layoutsize
+    const datatype_layoutsize = Base.datatype_layoutsize
 else
-function datatype_layoutsize(dt::Base.DataType)
-    Base.@_foldable_meta
-    dt.layout == C_NULL && throw(Base.UndefRefError())
-    size = unsafe_load(convert(Ptr{Base.DataTypeLayout}, dt.layout)).size
-    return size % Int
-end
+    function datatype_layoutsize(dt::Base.DataType)
+        Base.@_foldable_meta
+        dt.layout == C_NULL && throw(Base.UndefRefError())
+        size = unsafe_load(convert(Ptr{Base.DataTypeLayout}, dt.layout)).size
+        return size % Int
+    end
 end
 
 # On 1.12+, there was a change to the calling convention where
@@ -17,38 +17,38 @@ end
 # return the number of roots in the corresponding convention, or
 # 0 if it does not apply https://github.com/JuliaLang/julia/pull/55767/files#diff-62cfb2606c6a323a7f26a3eddfa0bf2b819fa33e094561fee09daeb328e3a1e7
 function inline_roots_type(@nospecialize(LT::LLVM.LLVMType))::Int
-   @static if VERSION <= v"1.12-"
-	return 0
-   else
-	   if !(LT isa LLVM.ArrayType || LT isa LLVM.StructType)
-		return 0
-	   end
-	   tracked = CountTrackedPointers(LT)
-	   if tracked.count > 0 && !tracked.all
-		return tracked.count
-	   end
-	   return 0
-   end
+    @static if VERSION <= v"1.12-"
+        return 0
+    else
+        if !(LT isa LLVM.ArrayType || LT isa LLVM.StructType)
+            return 0
+        end
+        tracked = CountTrackedPointers(LT)
+        if tracked.count > 0 && !tracked.all
+            return tracked.count
+        end
+        return 0
+    end
 end
 
 function inline_roots_type(@nospecialize(T::Type))::Int
-   @static if VERSION <= v"1.12-"
-	return 0
-   else
-	   if T === Union{}
-		return 0
-	   end
-	   if GPUCompiler.deserves_argbox(T)
-		return 0
-	   end
-	   if Base.isabstracttype(T)
-		return 0
-	   end
-	   if isghostty(T) || Core.Compiler.isconstType(T)
-		return 0
-	   end
-	   LT = convert(LLVM.LLVMType, T)
-	   return inline_roots_type(LT)
+    @static if VERSION <= v"1.12-"
+        return 0
+    else
+        if T === Union{}
+            return 0
+        end
+        if GPUCompiler.deserves_argbox(T)
+            return 0
+        end
+        if Base.isabstracttype(T)
+            return 0
+        end
+        if isghostty(T) || Core.Compiler.isconstType(T)
+            return 0
+        end
+        LT = convert(LLVM.LLVMType, T)
+        return inline_roots_type(LT)
     end
 end
 
@@ -56,35 +56,34 @@ end
 # with the AnyArray's as requisite for the new roots for the calling convention
 # on 1.12
 function rooted_argument_list(iterable)
-	results = Tuple{Type, Union{Nothing, Type}}[]
-	for T in iterable
-	    roots = inline_roots_type(T)
-	    push!(results, (T, nothing))
-	    if roots != 0
-	        push!(results, (AnyArray(roots), T))
-	    end
-	end
-	return results
+    results = Tuple{Type, Union{Nothing, Type}}[]
+    for T in iterable
+        roots = inline_roots_type(T)
+        push!(results, (T, nothing))
+        if roots != 0
+            push!(results, (AnyArray(roots), T))
+        end
+    end
+    return results
 end
 
 function split_value_into(B::LLVM.IRBuilder, val::LLVM.Value)
-   LT = value_type(val)
-   tracked = CountTrackedPointers(LT)
-   @assert tracked.count > 0
-   @assert !tracked.all
-   RT = convert(LLVM.LLVMType, AnyArray(tracked.count))
-   al = alloca!(B, RT)
-   fdsafdsa 
-   return (val, al)
+    LT = value_type(val)
+    tracked = CountTrackedPointers(LT)
+    @assert tracked.count > 0
+    @assert !tracked.all
+    RT = convert(LLVM.LLVMType, AnyArray(tracked.count))
+    al = alloca!(B, RT)
+    fdsafdsa
+    return (val, al)
 end
 
 function recombine_value(B::LLVM.IRBuilder, val::LLVM.Value, roots::LLVM.Value)
-	TODO
+    TODO
     return val
 end
 
 
-
 struct RemovedParam end
 
 # Modified from GPUCompiler classify_arguments
@@ -123,30 +122,36 @@ function classify_arguments(
     last_cc = nothing
     arg_jl_i = 1
     for (source_i, (source_typ, rooted_typ)) in enumerate(rooted_argument_list(source_sig.parameters))
-	if rooted_typ !== nothing
-	   arg_jl_i -= 1
-	end
+        if rooted_typ !== nothing
+            arg_jl_i -= 1
+        end
         if isghostty(source_typ) || Core.Compiler.isconstType(source_typ)
-            push!(args, (cc = GPUCompiler.GHOST, typ = source_typ, arg_i = source_i,
-			rooted_typ = rooted_typ,
-			rooted_arg_i = rooted_typ === nothing ? nothing : (source_i - 1),
-		        rooted_cc = rooted_typ === nothing ? nothing : last_cc,
-			arg_jl_i = arg_jl_i,
-		...*[Comment body truncated]*

@github-actions
Copy link
Contributor

github-actions bot commented Nov 14, 2025

Benchmark Results

| | main | e872b7e... | main / e872b7e... |
|:--|:--|:--|:--|
| basics/make_zero/namedtuple | 0.0533 ± 0.0019 μs | 0.0531 ± 0.003 μs | 1 ± 0.066 |
| basics/make_zero/struct | 0.254 ± 0.0061 μs | 0.255 ± 0.0063 μs | 0.996 ± 0.034 |
| basics/overhead | 4.34 ± 0.05 ns | 4.64 ± 0.011 ns | 0.935 ± 0.011 |
| basics/remake_zero!/namedtuple | 0.241 ± 0.0075 μs | 0.242 ± 0.0096 μs | 0.996 ± 0.05 |
| basics/remake_zero!/struct | 0.234 ± 0.0086 μs | 0.234 ± 0.009 μs | 1 ± 0.053 |
| fold_broadcast/multidim_sum_bcast/1D | 10.3 ± 0.24 μs | 10.4 ± 1.7 μs | 0.986 ± 0.16 |
| fold_broadcast/multidim_sum_bcast/2D | 12.2 ± 0.25 μs | 12.3 ± 0.28 μs | 0.996 ± 0.031 |
| time_to_load | 1.26 ± 0.0063 s | 1.27 ± 0.0075 s | 0.997 ± 0.0078 |

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/19386158803/artifacts/4576044989.

@giordano added the "Julia v1.12" label (Related to compatibility with Julia v1.12) on Nov 14, 2025
@codecov
Copy link

codecov bot commented Nov 15, 2025

Codecov Report

❌ Patch coverage is 51.69492% with 228 lines in your changes missing coverage. Please review.
✅ Project coverage is 68.52%. Comparing base (b910049) to head (e872b7e).

| Files with missing lines | Patch % | Lines |
|:--|:--|:--|
| src/compiler.jl | 49.67% | 153 Missing ⚠️ |
| src/typeutils/jltypes.jl | 48.52% | 35 Missing ⚠️ |
| src/errors.jl | 10.34% | 26 Missing ⚠️ |
| src/jlrt.jl | 78.78% | 7 Missing ⚠️ |
| src/rules/allocrules.jl | 70.58% | 5 Missing ⚠️ |
| src/absint.jl | 60.00% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2782      +/-   ##
==========================================
- Coverage   69.01%   68.52%   -0.50%     
==========================================
  Files          58       58              
  Lines       19996    20335     +339     
==========================================
+ Hits        13800    13934     +134     
- Misses       6196     6401     +205     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Julia v1.12 Related to compatibility with Julia v1.12

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants