Skip to content

1.12 the adventure continues#2764

Merged
wsmoses merged 26 commits intomainfrom
ac
Nov 12, 2025
Merged

1.12 the adventure continues#2764
wsmoses merged 26 commits intomainfrom
ac

Conversation

@wsmoses
Copy link
Member

@wsmoses wsmoses commented Nov 10, 2025

No description provided.

@github-actions
Copy link
Contributor

github-actions bot commented Nov 10, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/Enzyme.jl b/src/Enzyme.jl
index d93bc0b6..17ec9f1a 100644
--- a/src/Enzyme.jl
+++ b/src/Enzyme.jl
@@ -144,7 +144,7 @@ function Base.showerror(io::IO, ece::OpaquePointerError)
         Base.Experimental.show_error_hints(io, ece)
     end
     print(io, "OpaquePointerError: Enzyme execution failed to handle opaque pointers, with the following information:\n")
-    print(io, ece.msg, '\n')
+    return print(io, ece.msg, '\n')
 end
 
 
diff --git a/src/compiler.jl b/src/compiler.jl
index 432a6c5c..66d46360 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -3937,7 +3937,7 @@ function lower_convention(
             if returnRoots && !in(1, parmsRemoved)
                 retRootPtr = alloca!(
                     builder,
-                    sret_ty(entry_f, 1+sret),
+                    sret_ty(entry_f, 1 + sret),
                     "innerreturnroots",
                 )
                 # retRootPtr = alloca!(builder, parameters(wrapper_f)[1])
diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl
index 9b120731..c4d8832b 100644
--- a/src/compiler/validation.jl
+++ b/src/compiler/validation.jl
@@ -616,9 +616,9 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
     method_table = Core.Compiler.method_table(interp)
     bt = backtrace(inst)
     dest = called_operand(inst)
-    
+
     if isa(dest, LLVM.PHIInst) && all(Base.Fix1(==, operands(dest)[1]), operands(dest))
-       dest = operands(dest)[1]
+        dest = operands(dest)[1]
     end
     if isa(dest, LLVM.ConstantExpr) && opcode(dest) == LLVM.API.LLVMIntToPtr && isa(operands(dest)[1], LLVM.ConstantExpr) && opcode(operands(dest)[1]) == LLVM.API.LLVMPtrToInt
        dest = operands(operands(dest)[1])[1] 
@@ -1148,7 +1148,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
 		    else
 			false, nothing
 		    end
-	
+
                     lfn = nothing
                     if found 
                         lfn = replaceWith
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index a41033c5..69799cf1 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -753,11 +753,11 @@ function nodecayed_phis!(mod::LLVM.Module)
                                     v2 = operands(v)[1]
                                     if addrspace(value_type(v2)) == 0
                                         if addr == 13 && isa(v, LLVM.ConstantExpr)
-					    PT = if LLVM.is_opaque(value_type(v))
-						LLVM.PointerType(10)
-					    else
-						LLVM.PointerType(eltype(value_type(v)), 10)
-					    end
+                                            PT = if LLVM.is_opaque(value_type(v))
+                                                LLVM.PointerType(10)
+                                            else
+                                                LLVM.PointerType(eltype(value_type(v)), 10)
+                                            end
                                             v2 = const_addrspacecast(
                                                 operands(v)[1],
                                                 PT
@@ -917,12 +917,12 @@ function nodecayed_phis!(mod::LLVM.Module)
                                 undeforpoison |= isa(v, LLVM.PoisonValue)
                             end
                             if undeforpoison
-				PT = if LLVM.is_opaque(value_type(v))
-				   LLVM.PointerType(10)
-				else
-				   LLVM.PointerType(eltype(value_type(v)), 10)
-				end
-				return LLVM.UndefValue(PT), offset, addr == 13
+                                PT = if LLVM.is_opaque(value_type(v))
+                                    LLVM.PointerType(10)
+                                else
+                                    LLVM.PointerType(eltype(value_type(v)), 10)
+                                end
+                                return LLVM.UndefValue(PT), offset, addr == 13
                             end
 
                             if isa(v, LLVM.PHIInst) && !hasload && haskey(goffsets, v)
@@ -1241,7 +1241,7 @@ function fix_decayaddr!(mod::LLVM.Module)
                 mayread = false
                 maywrite = false
                 sret = true
-		sret_elty = nothing
+                sret_elty = nothing
                 sretkind = kind(if LLVM.version().major >= 12
                     TypeAttribute("sret", LLVM.Int32Type())
                 else
@@ -1255,11 +1255,11 @@ function fix_decayaddr!(mod::LLVM.Module)
                         t_sret = false
                         for a in collect(parameter_attributes(fop, i))
                             if kind(a) == sretkind
-				sret_elty = sret_ty(fop, i)
+                                sret_elty = sret_ty(fop, i)
                                 t_sret = true
                             end
                             if kind(a) == kind(StringAttribute("enzyme_sret"))
-				sret_elty = sret_ty(fop, i)
+                                sret_elty = sret_ty(fop, i)
                                 t_sret = true
                             end
                             # if kind(a) == kind(StringAttribute("enzyme_sret_v"))
@@ -1300,7 +1300,7 @@ function fix_decayaddr!(mod::LLVM.Module)
                     throw(AssertionError(msg))
                 end
 
-		@assert sret_elty !== nothing
+                @assert sret_elty !== nothing
                 if temp === nothing
                     nb = IRBuilder()
                     position!(nb, first(instructions(first(blocks(f)))))
@@ -1447,11 +1447,11 @@ function prop_global!(g::LLVM.GlobalVariable)
                     end
                 end
             end
-	    if value_type(var) != value_type(res)
-		al = alloca!(B, value_type(res))
-		store!(B, res, al)
-		res = load!(B, value_type(var), al)
-	    end
+            if value_type(var) != value_type(res)
+                al = alloca!(B, value_type(res))
+                store!(B, res, al)
+                res = load!(B, value_type(var), al)
+            end
             replace_uses!(var, res)
             eraseInst(LLVM.parent(var), var)
             continue
@@ -1667,10 +1667,10 @@ function propagate_returned!(mod::LLVM.Module)
                 changed = true
             end
             has_user = false
-	    for u in LLVM.uses(fn)
-		has_user = true
-		break
-	    end
+            for u in LLVM.uses(fn)
+                has_user = true
+                break
+            end
             attrs = collect(function_attributes(fn))
             prevent = any(
                 kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for
@@ -1681,8 +1681,8 @@ function propagate_returned!(mod::LLVM.Module)
             # end
             argn = nothing
             toremove = Int64[]
-	    # Don't bother with functions we're about to delete anyways
-	    if has_user
+            # Don't bother with functions we're about to delete anyways
+            if has_user
             for (i, arg) in enumerate(parameters(fn))
                 if any(
                     kind(attr) == kind(EnumAttribute("returned")) for
@@ -1726,7 +1726,7 @@ function propagate_returned!(mod::LLVM.Module)
                         if !isa(ops[i], LLVM.AllocaInst) && !isa(ops[i], LLVM.UndefValue) && !isa(ops[i], LLVM.PoisonValue)
                             illegalUse = true
                             break
-                        end
+                            end
                         seenfn = false
                         todo = LLVM.Instruction[]
                         if isa(ops[i], LLVM.AllocaInst)
@@ -1793,20 +1793,20 @@ function propagate_returned!(mod::LLVM.Module)
 
                         position!(B, first(instructions(first(blocks(fn)))))
 
-                        has_use = false
-                        for _ in LLVM.uses(arg)
-                            has_use = true
-                            break
-                        end
+                            has_use = false
+                            for _ in LLVM.uses(arg)
+                                has_use = true
+                                break
+                            end
 
-                        if has_use
-                            argeltype = sret_ty(fn, i)
-                            al = alloca!(B, argeltype)
-                            if value_type(al) != value_type(arg)
-                                al = addrspacecast!(B, al, value_type(arg))
+                            if has_use
+                                argeltype = sret_ty(fn, i)
+                                al = alloca!(B, argeltype)
+                                if value_type(al) != value_type(arg)
+                                    al = addrspacecast!(B, al, value_type(arg))
+                                end
+                                LLVM.replace_uses!(arg, al)
                             end
-                            LLVM.replace_uses!(arg, al)
-                        end
                     end
                 end
 
@@ -1907,7 +1907,7 @@ function propagate_returned!(mod::LLVM.Module)
 			end
 		end
             end
-	    end
+            end
             illegalUse = !(
                 linkage(fn) == LLVM.API.LLVMInternalLinkage ||
                 linkage(fn) == LLVM.API.LLVMPrivateLinkage
diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl
index d568fa2a..8a9541a0 100644
--- a/src/rules/customrules.jl
+++ b/src/rules/customrules.jl
@@ -301,7 +301,7 @@ function enzyme_custom_setup_args(
                         LLVM.ConstantInt(LLVM.IntType(32), 0),
                     ],
                 )
-                
+
                 if !is_opaque(value_type(ptr))
                     @assert eltype(value_type(ptr)) == arty
                 end
@@ -1452,7 +1452,7 @@ function enzyme_custom_common_rev(
     end
 
     if sret !== nothing
-        sty = sret_ty(llvmf, 1+swiftself)
+        sty = sret_ty(llvmf, 1 + swiftself)
         if LLVM.version().major >= 12
             attr = TypeAttribute("sret", sty)
         else
diff --git a/src/typeutils/conversion.jl b/src/typeutils/conversion.jl
index 1dbcb893..b090667a 100644
--- a/src/typeutils/conversion.jl
+++ b/src/typeutils/conversion.jl
@@ -26,7 +26,7 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool}
         if 10 <= addrspace <= 12
             return Any, true
         elseif LLVM.is_opaque(LLVM.PointerType(Type))
-            return Core.LLVMPtr{Cvoid,Int(addrspace)}, false
+            return Core.LLVMPtr{Cvoid, Int(addrspace)}, false
         else
             e = LLVM.API.LLVMGetElementType(Type)
             tkind2 = LLVM.API.LLVMGetTypeKind(e)
diff --git a/src/utils.jl b/src/utils.jl
index 137d7460..206fafd9 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -503,11 +503,13 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
 
     vt = LLVM.value_type(LLVM.parameters(fn)[idx])
 
-    sretkind = LLVM.kind(if LLVM.version().major >= 12
-        LLVM.TypeAttribute("sret", LLVM.Int32Type())
-    else
-        LLVM.EnumAttribute("sret")
-    end)
+    sretkind = LLVM.kind(
+        if LLVM.version().major >= 12
+            LLVM.TypeAttribute("sret", LLVM.Int32Type())
+        else
+            LLVM.EnumAttribute("sret")
+        end
+    )
 
 
     enzymejl_parmtype_ref = nothing
@@ -537,7 +539,7 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
 
         if ekind == "enzymejl_returnRoots"
             nroots = parse(Int, LLVM.value(attr))
-    
+
             T_jlvalue = LLVM.StructType(LLVM.LLVMType[])
             T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
 
@@ -549,13 +551,13 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
         end
 
         if ekind == "enzyme_sret"
-	    ety = parse(UInt, LLVM.value(attr))
-	    ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety)
-	    ety = LLVM.LLVMType(ety)
+            ety = parse(UInt, LLVM.value(attr))
+            ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety)
+            ety = LLVM.LLVMType(ety)
             if !LLVM.is_opaque(vt)
-		@assert ety == eltype(vt)
+                @assert ety == eltype(vt)
             end
-        
+
             return ety
         end
 
diff --git a/test/rules/internal_rules.jl b/test/rules/internal_rules.jl
index 93c2fc53..51d7ec0f 100644
--- a/test/rules/internal_rules.jl
+++ b/test/rules/internal_rules.jl
@@ -108,7 +108,7 @@ end
     res = Enzyme.autodiff(Forward, f1, BatchDuplicated(0.1, (1.0, 2.0)))
     @test res[1][1] ≈ 375.0
     @test res[1][2] ≈ 750.0
-    
+
     @test Enzyme.autodiff(Forward, f2, BatchDuplicated(0.1, (1.0, 2.0))) ==
         ((var"1" = 25.0, var"2" = 50.0),)
     @test Enzyme.autodiff(Forward, f3, BatchDuplicated(0.1, (1.0, 2.0))) ==

@giordano giordano added the Julia v1.12 Related to compatibility with Julia v1.12 label Nov 10, 2025
@github-actions
Copy link
Contributor

github-actions bot commented Nov 10, 2025

Benchmark Results

main 7b851c3... main / 7b851c3...
basics/make_zero/namedtuple 0.0549 ± 0.004 μs 0.0566 ± 0.0049 μs 0.969 ± 0.11
basics/make_zero/struct 0.248 ± 0.0064 μs 0.26 ± 0.012 μs 0.955 ± 0.049
basics/overhead 4.88 ± 0.01 ns 4.89 ± 0.01 ns 0.998 ± 0.0029
basics/remake_zero!/namedtuple 0.231 ± 0.011 μs 0.231 ± 0.011 μs 1 ± 0.068
basics/remake_zero!/struct 0.232 ± 0.01 μs 0.233 ± 0.013 μs 0.993 ± 0.07
fold_broadcast/multidim_sum_bcast/1D 10.6 ± 0.28 μs 10.6 ± 0.46 μs 0.994 ± 0.051
fold_broadcast/multidim_sum_bcast/2D 12.9 ± 0.34 μs 12.9 ± 0.35 μs 1 ± 0.038
time_to_load 1.28 ± 0.012 s 1.3 ± 0.0044 s 0.986 ± 0.0096

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/19282144045/artifacts/4537588880.

@codecov
Copy link

codecov bot commented Nov 11, 2025

Codecov Report

❌ Patch coverage is 80.13245% with 30 lines in your changes missing coverage. Please review.
✅ Project coverage is 68.95%. Comparing base (6b30dda) to head (7b851c3).
⚠️ Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
src/utils.jl 79.16% 10 Missing ⚠️
src/llvm/transforms.jl 79.06% 9 Missing ⚠️
src/Enzyme.jl 0.00% 5 Missing ⚠️
src/rules/customrules.jl 86.36% 3 Missing ⚠️
src/compiler.jl 96.55% 1 Missing ⚠️
src/compiler/validation.jl 50.00% 1 Missing ⚠️
src/typeutils/conversion.jl 50.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2764      +/-   ##
==========================================
+ Coverage   68.91%   68.95%   +0.04%     
==========================================
  Files          58       58              
  Lines       19861    19961     +100     
==========================================
+ Hits        13688    13765      +77     
- Misses       6173     6196      +23     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@wsmoses wsmoses merged commit f7ec291 into main Nov 12, 2025
51 of 54 checks passed
@wsmoses wsmoses deleted the ac branch November 12, 2025 02:04
This was referenced Nov 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Julia v1.12 Related to compatibility with Julia v1.12

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants