Skip to content

1.12: Fix custom rule return rooting#2804

Merged
wsmoses merged 2 commits intomainfrom
crulerooting
Nov 24, 2025
Merged

1.12: Fix custom rule return rooting#2804
wsmoses merged 2 commits intomainfrom
crulerooting

Conversation

@wsmoses
Copy link
Member

@wsmoses wsmoses commented Nov 24, 2025

No description provided.

@github-actions
Copy link
Contributor

github-actions bot commented Nov 24, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/compiler.jl b/src/compiler.jl
index 9a4e1296..d14269e6 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -3845,18 +3845,18 @@ function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType,
         		if direction == SRetPointerToRootPointer
         		    outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
         		    outloc = load!(builder, ty, outloc)
-			    if must_cache
-		                API.SetMustCache!(outloc)
-			    end
+                if must_cache
+                    API.SetMustCache!(outloc)
+                end
                             store!(builder, outloc, loc)
         		elseif direction == SRetValueToRootPointer
         		    outloc = Enzyme.API.e_extract_value!(builder, sret, path)
                             store!(builder, outloc, loc)
         		elseif direction == RootPointerToSRetValue
         		    loc = load!(builder, ty, loc)
-			    if must_cache
-		                API.SetMustCache!(loc)
-			    end
+                if must_cache
+                    API.SetMustCache!(loc)
+                end
         		    val = Enzyme.API.e_insert_value!(builder, val, loc, path)
                 elseif direction == NullifySRetValue
                     loc = unsafe_to_llvm(builder, nothing)
@@ -3920,13 +3920,13 @@ function nullify_rooted_values!(builder::LLVM.IRBuilder, sret::LLVM.Value)
    move_sret_tofrom_roots!(builder, jltype, sret, root_ty, nothing, NullifySRetValue)
 end
 
-function recombine_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, roots::LLVM.Value; must_cache::Bool=false)::LLVM.Value
+function recombine_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, roots::LLVM.Value; must_cache::Bool = false)::LLVM.Value
    jltype = value_type(sret)
    tracked = CountTrackedPointers(jltype)
    @assert tracked.count > 0
    @assert !tracked.all "Not tracked.all, jltype ($(string(jltype)))"
    root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
-   move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, RootPointerToSRetValue; must_cache)
+    return move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, RootPointerToSRetValue; must_cache)
 end
 
 function extract_roots_from_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, roots::LLVM.Value)
@@ -3963,8 +3963,8 @@ function copy_floats_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::
             end
 
             if isa(ty, LLVM.FloatingPointType)
-		dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloc")
-		srcloc = inbounds_gep!(builder, jltype, src, to_llvm(path), "srcloc")
+            dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloc")
+            srcloc = inbounds_gep!(builder, jltype, src, to_llvm(path), "srcloc")
                 val = load!(builder, ty, srcloc)
                 st = store!(builder, val, dstloc)
                 continue
@@ -4003,63 +4003,65 @@ end
 
 function extract_nonjlvalues_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, dst::LLVM.Value, src::LLVM.Value)
     count = 0
-    todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
-	    Cuint[],
-        jltype,
-    )]
-	function to_llvm(lst::Vector{Cuint})
-	    vals = LLVM.Value[]
-	    push!(vals, LLVM.ConstantInt(LLVM.IntType(64), 0))
-	    for i in lst
-	       push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
-	    end
-	    return vals
-	end
+    todo = Tuple{Vector{Cuint}, LLVM.LLVMType}[
+        (
+            Cuint[],
+            jltype,
+        ),
+    ]
+    function to_llvm(lst::Vector{Cuint})
+        vals = LLVM.Value[]
+        push!(vals, LLVM.ConstantInt(LLVM.IntType(64), 0))
+        for i in lst
+            push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
+        end
+        return vals
+    end
 
-	extracted = LLVM.Value[]
+    extracted = LLVM.Value[]
 
     while length(todo) != 0
-            path, ty = popfirst!(todo)
+        path, ty = popfirst!(todo)
 
-            if isa(ty, LLVM.PointerType)
-                if any_jltypes(ty)
-			continue
-		end
+        if isa(ty, LLVM.PointerType)
+            if any_jltypes(ty)
+                continue
             end
+        end
 
-            if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
-                for i = 1:length(ty)
-                    npath = copy(path)
-                    push!(npath, i - 1)
-                    push!(todo, (npath, eltype(ty)))
-                end
-                continue
+        if isa(ty, LLVM.ArrayType) && any_jltypes(ty)
+            for i in 1:length(ty)
+                npath = copy(path)
+                push!(npath, i - 1)
+                push!(todo, (npath, eltype(ty)))
             end
+            continue
+        end
 
-            if isa(ty, LLVM.VectorType) && any_jltypes(ty)
-                for i = 1:size(ty)
-                    npath = copy(path)
-                    push!(npath, i - 1)
-                    push!(todo, (npath, eltype(ty)))
-                end
-                continue
+        if isa(ty, LLVM.VectorType) && any_jltypes(ty)
+            for i in 1:size(ty)
+                npath = copy(path)
+                push!(npath, i - 1)
+                push!(todo, (npath, eltype(ty)))
             end
+            continue
+        end
 
-            if isa(ty, LLVM.StructType) && any_jltypes(ty)
-                for (i, t) in enumerate(LLVM.elements(ty))
-                    npath = copy(path)
-                    push!(npath, i - 1)
-                    push!(todo, (npath, t))
-                end
-                continue
+        if isa(ty, LLVM.StructType) && any_jltypes(ty)
+            for (i, t) in enumerate(LLVM.elements(ty))
+                npath = copy(path)
+                push!(npath, i - 1)
+                push!(todo, (npath, t))
             end
-		
-	    dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloc")
-            val = Enzyme.API.e_extract_value!(builder, src, path)
-	    st = store!(builder, val, dstloc)
+            continue
         end
 
-	return nothing
+        dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstloc")
+        val = Enzyme.API.e_extract_value!(builder, src, path)
+        st = store!(builder, val, dstloc)
+    end
+
+    return nothing
 end
 
 
diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl
index a4537e70..91ca676b 100644
--- a/src/rules/customrules.jl
+++ b/src/rules/customrules.jl
@@ -971,9 +971,9 @@ end
         end
         LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1), attr)
         res = load!(B, sty, sret)
-	if returnRoots !== nothing && VERSION >= v"1.12"
-	   res = recombine_value!(B, res, returnRoots; must_cache=true)
-	end
+        if returnRoots !== nothing && VERSION >= v"1.12"
+            res = recombine_value!(B, res, returnRoots; must_cache = true)
+        end
     end
     if swiftself
         attr = EnumAttribute("swiftself")
@@ -1003,19 +1003,19 @@ end
     if RT <: Const
         if needsPrimal
             @assert RealRt == fwd_RT
-	    _, prim_sret, prim_roots = get_return_info(RealRt)
+            _, prim_sret, prim_roots = get_return_info(RealRt)
             if prim_sret !== nothing
                 val = new_from_original(gutils, operands(orig)[1])
-		
-		if prim_roots !== nothing && VERSION >= v"1.12"
+
+                if prim_roots !== nothing && VERSION >= v"1.12"
                     extract_nonjlvalues_into!(B, value_type(res), val, res)
 
                     rval = new_from_original(gutils, operands(orig)[2])
 
-		    extract_roots_from_value!(B, res, rval)
-		else
+                    extract_roots_from_value!(B, res, rval)
+                else
                     store!(B, res, val)
-		end
+                end
             else
                 normalV = res.ref
             end
@@ -1029,28 +1029,28 @@ end
                 ST = NTuple{Int(width),ST}
             end
             @assert ST == fwd_RT
-	    _, prim_sret, prim_roots = get_return_info(RealRt)
+            _, prim_sret, prim_roots = get_return_info(RealRt)
             if prim_sret !== nothing
                 dval_ptr = invert_pointer(gutils, operands(orig)[1], B)
-		
-		droots = if prim_roots !== nothing && VERSION >= v"1.12"
-		    @assert !is_constant_value(gutils, operands(orig)[2])
-		    invert_pointer(gutils, operands(orig)[2], B)
-	        end
-                
-		for idx = 1:width
+
+                droots = if prim_roots !== nothing && VERSION >= v"1.12"
+                    @assert !is_constant_value(gutils, operands(orig)[2])
+                    invert_pointer(gutils, operands(orig)[2], B)
+                end
+
+                for idx in 1:width
                     ev = (width == 1) ? dval : extract_value!(B, dval, idx - 1)
                     pev = (width == 1) ? dval_ptr : extract_value!(B, dval_ptr, idx - 1)
-			
-		    if prim_roots !== nothing && VERSION >= v"1.12"
-		        extract_nonjlvalues_into!(B, value_type(ev), pev, ev)
 
-		        rval = (width == 1) ? droots : extract_value!(B, droots, idx - 1)
+                    if prim_roots !== nothing && VERSION >= v"1.12"
+                        extract_nonjlvalues_into!(B, value_type(ev), pev, ev)
 
-		        extract_roots_from_value!(B, ev, rval)
-		    else
+                        rval = (width == 1) ? droots : extract_value!(B, droots, idx - 1)
+
+                        extract_roots_from_value!(B, ev, rval)
+                    else
                         store!(B, ev, pev)
-		    end
+                    end
                 end
             else
                 shadowV = res.ref
@@ -1062,42 +1062,42 @@ end
                 BatchDuplicated{RealRt,Int(width)}
             end
             @assert ST == fwd_RT
-	    
-	    _, prim_sret, prim_roots = get_return_info(RealRt)
+
+            _, prim_sret, prim_roots = get_return_info(RealRt)
             if prim_sret !== nothing
                 val = new_from_original(gutils, operands(orig)[1])
-                
-		res0 = extract_value!(B, res, 0)
-		if prim_roots !== nothing && VERSION >= v"1.12"
+
+                res0 = extract_value!(B, res, 0)
+                if prim_roots !== nothing && VERSION >= v"1.12"
                     extract_nonjlvalues_into!(B, value_type(res0), val, res0)
 
                     rval = new_from_original(gutils, operands(orig)[2])
 
-		    extract_roots_from_value!(B, res0, rval)
-		else
+                    extract_roots_from_value!(B, res0, rval)
+                else
                     store!(B, res0, val)
-		end
+                end
 
                 dval_ptr = invert_pointer(gutils, operands(orig)[1], B)
                 dval = extract_value!(B, res, 1)
-		
-		droots = if prim_roots !== nothing && VERSION >= v"1.12"
-		    @assert !is_constant_value(gutils, operands(orig)[2])
-		    invert_pointer(gutils, operands(orig)[2], B)
-	        end
-                
-		for idx = 1:width
+
+                droots = if prim_roots !== nothing && VERSION >= v"1.12"
+                    @assert !is_constant_value(gutils, operands(orig)[2])
+                    invert_pointer(gutils, operands(orig)[2], B)
+                end
+
+                for idx in 1:width
                     ev = (width == 1) ? dval : extract_value!(B, dval, idx - 1)
                     pev = (width == 1) ? dval_ptr : extract_value!(B, dval_ptr, idx - 1)
-		    if prim_roots !== nothing && VERSION >= v"1.12"
-		        extract_nonjlvalues_into!(B, value_type(ev), pev, ev)
+                    if prim_roots !== nothing && VERSION >= v"1.12"
+                        extract_nonjlvalues_into!(B, value_type(ev), pev, ev)
 
-		        rval = (width == 1) ? droots : extract_value!(B, droots, idx - 1)
+                        rval = (width == 1) ? droots : extract_value!(B, droots, idx - 1)
 
-		        extract_roots_from_value!(B, ev, rval)
-		    else
+                        extract_roots_from_value!(B, ev, rval)
+                    else
                         store!(B, ev, pev)
-		    end
+                    end
                 end
             else
                 normalV = extract_value!(B, res, 0).ref
@@ -1836,9 +1836,9 @@ function enzyme_custom_common_rev(
         )
         res = load!(B, sty, sret)
         API.SetMustCache!(res)
-	if returnRoots !== nothing && VERSION >= v"1.12"
-	   res = recombine_value!(B, res, returnRoots; must_cache=true)
-	end
+        if returnRoots !== nothing && VERSION >= v"1.12"
+            res = recombine_value!(B, res, returnRoots; must_cache = true)
+        end
     end
     if swiftself
         attr = EnumAttribute("swiftself")
@@ -1946,19 +1946,19 @@ function enzyme_custom_common_rev(
         if needsPrimal
             @assert !isghostty(RealRt)
             normalV = extract_value!(B, resV, idx)
-	    _, prim_sret, prim_roots = get_return_info(RealRt)
+            _, prim_sret, prim_roots = get_return_info(RealRt)
             if prim_sret !== nothing
                 val = new_from_original(gutils, operands(orig)[1])
-		
-		if prim_roots !== nothing && VERSION >= v"1.12"
+
+                if prim_roots !== nothing && VERSION >= v"1.12"
                     extract_nonjlvalues_into!(B, value_type(normalV), val, normalV)
 
                     rval = new_from_original(gutils, operands(orig)[2])
 
-		    extract_roots_from_value!(B, normalV, rval)
-		else
+                    extract_roots_from_value!(B, normalV, rval)
+                else
                     store!(B, normalV, val)
-		end
+                end
             else
                 @assert value_type(normalV) == value_type(orig)
                 normalV = normalV.ref
@@ -1969,30 +1969,30 @@ function enzyme_custom_common_rev(
             if needsShadowJL
                 @assert !isghostty(RealRt)
                 shadowV = extract_value!(B, resV, idx)
-	        _, prim_sret, prim_roots = get_return_info(RealRt)
+                _, prim_sret, prim_roots = get_return_info(RealRt)
                 if prim_sret !== nothing
                     dval = invert_pointer(gutils, operands(orig)[1], B)
 
-		    droots = if prim_roots !== nothing && VERSION >= v"1.12"
-			@assert !is_constant_value(gutils, operands(orig)[2])
-                    	invert_pointer(gutils, operands(orig)[2], B)
-		    end
+                    droots = if prim_roots !== nothing && VERSION >= v"1.12"
+                        @assert !is_constant_value(gutils, operands(orig)[2])
+                        invert_pointer(gutils, operands(orig)[2], B)
+                    end
 
-		    for idx = 1:width
+                    for idx in 1:width
                         to_store =
                             (width == 1) ? shadowV : extract_value!(B, shadowV, idx - 1)
 
                         store_ptr = (width == 1) ? dval : extract_value!(B, dval, idx - 1)
 
-			if prim_roots !== nothing && VERSION >= v"1.12"
-			    extract_nonjlvalues_into!(B, value_type(to_store), store_ptr, to_store)
+                        if prim_roots !== nothing && VERSION >= v"1.12"
+                            extract_nonjlvalues_into!(B, value_type(to_store), store_ptr, to_store)
 
                             rval = (width == 1) ? droots : extract_value!(B, droots, idx - 1)
 
-			    extract_roots_from_value!(B, to_store, rval)
-			else
+                            extract_roots_from_value!(B, to_store, rval)
+                        else
                             store!(B, to_store, store_ptr)
-			end
+                        end
                     end
                     shadowV = C_NULL
                 else

@wsmoses wsmoses changed the title Crulerooting 1.12: Fix custom rule return rooting Nov 24, 2025
@codecov
Copy link

codecov bot commented Nov 24, 2025

Codecov Report

❌ Patch coverage is 26.85185% with 79 lines in your changes missing coverage. Please review.
✅ Project coverage is 67.81%. Comparing base (04781c8) to head (24cd291).
⚠️ Report is 5 commits behind head on main.

Files with missing lines Patch % Lines
src/compiler.jl 7.84% 47 Missing ⚠️
src/rules/customrules.jl 43.85% 32 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2804      +/-   ##
==========================================
- Coverage   68.03%   67.81%   -0.22%     
==========================================
  Files          58       58              
  Lines       20628    20717      +89     
==========================================
+ Hits        14035    14050      +15     
- Misses       6593     6667      +74     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@github-actions
Copy link
Contributor

github-actions bot commented Nov 24, 2025

Benchmark Results

main 24cd291... main / 24cd291...
basics/make_zero/namedtuple 0.0514 ± 0.003 μs 0.0514 ± 0.0021 μs 1 ± 0.07
basics/make_zero/struct 0.266 ± 0.0054 μs 0.256 ± 0.0055 μs 1.04 ± 0.031
basics/overhead 4.95 ± 0.011 ns 4.03 ± 0.001 ns 1.23 ± 0.0027
basics/remake_zero!/namedtuple 0.238 ± 0.007 μs 0.243 ± 0.0063 μs 0.979 ± 0.038
basics/remake_zero!/struct 0.239 ± 0.0089 μs 0.235 ± 0.0068 μs 1.02 ± 0.048
fold_broadcast/multidim_sum_bcast/1D 10.3 ± 0.2 μs 10.4 ± 1.8 μs 0.992 ± 0.17
fold_broadcast/multidim_sum_bcast/2D 12.2 ± 0.26 μs 12.2 ± 0.24 μs 0.997 ± 0.029
time_to_load 1.24 ± 0.0056 s 1.24 ± 0.0091 s 0.995 ± 0.0086

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/19622388008/artifacts/4656161654.

@wsmoses wsmoses merged commit f9dd728 into main Nov 24, 2025
51 of 54 checks passed
@wsmoses wsmoses deleted the crulerooting branch November 24, 2025 04:58
@giordano giordano added the Julia v1.12 Related to compatibility with Julia v1.12 label Nov 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Julia v1.12 Related to compatibility with Julia v1.12

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants