Skip to content

Improve handling of jlroots#2978

Open
wsmoses wants to merge 8 commits intomainfrom
ihjlroot
Open

Improve handling of jlroots#2978
wsmoses wants to merge 8 commits intomainfrom
ihjlroot

Conversation

@wsmoses
Copy link
Member

@wsmoses wsmoses commented Feb 21, 2026

No description provided.

@github-actions
Copy link
Contributor

github-actions bot commented Feb 21, 2026

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/absint.jl b/src/absint.jl
index 4afd14a7..d959e1fe 100644
--- a/src/absint.jl
+++ b/src/absint.jl
@@ -943,7 +943,7 @@ function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool, String}
 
         if larg !== nothing
             if (isa(larg, LLVM.ConstantArray) || isa(larg, LLVM.ConstantDataArray)) && eltype(value_type(larg)) == LLVM.IntType(8)
-	        return (true, String(map(Base.Fix1(convert, UInt8), collect(larg)[1:(end-1)])))
+            return (true, String(map(Base.Fix1(convert, UInt8), collect(larg)[1:(end - 1)])))
             end
 
         end
diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl
index ce1f5cc1..ed330d26 100644
--- a/src/compiler/validation.jl
+++ b/src/compiler/validation.jl
@@ -849,7 +849,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
             ofn = LLVM.parent(LLVM.parent(inst))
             mod = LLVM.parent(ofn)
 
-	    ops = collect(operands(inst))[1:LLVM.API.LLVMGetNumArgOperands(inst)]
+            ops = collect(operands(inst))[1:LLVM.API.LLVMGetNumArgOperands(inst)]
             @assert length(ops) == 2
             flib = ops[1]
             fname = ops[2]
@@ -1090,7 +1090,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
                 legal, flibty, byref = abs_typeof(operands(inst)[offset+1])
                 if legal
                     tys = Union{Type, Core.TypeofVararg}[flibty]
-		    for op in collect(operands(inst))[start+1:LLVM.API.LLVMGetNumArgOperands(inst)]
+                    for op in collect(operands(inst))[(start + 1):LLVM.API.LLVMGetNumArgOperands(inst)]
                         legal, typ, byref2 = abs_typeof(op, true)
                         if !legal
                             typ = Any
@@ -1245,7 +1245,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
             legal, flibty, byref = abs_typeof(operands(inst)[offset])
             if legal
                 tys = Union{Type, Core.TypeofVararg}[flibty]
-		for op in collect(operands(inst))[start:LLVM.API.LLVMGetNumArgOperands(inst)]
+                for op in collect(operands(inst))[start:LLVM.API.LLVMGetNumArgOperands(inst)]
                     legal, typ, byref2 = abs_typeof(op, true)
                     if !legal
                         typ = Any
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index 5e503865..bf16f552 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -477,7 +477,7 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module)
                     if isa(cur, LLVM.CallInst) &&
                        isa(LLVM.called_operand(cur), LLVM.Function)
                         legalc = true
-			for (i, ci) in enumerate(operands(cur)[1:LLVM.API.LLVMGetNumArgOperands(cur)])
+                        for (i, ci) in enumerate(operands(cur)[1:LLVM.API.LLVMGetNumArgOperands(cur)])
                             if ci == prev
                                 nocapture = false
                                 readonly = false
@@ -1432,7 +1432,7 @@ function fix_decayaddr!(mod::LLVM.Module)
                    intr == LLVM.Intrinsic("llvm.memmove").id ||
                    intr == LLVM.Intrinsic("llvm.memset").id
                     newvs = LLVM.Value[]
-		    for (i, v) in enumerate(operands(st)[1:LLVM.API.LLVMGetNumArgOperands(st)])
+                    for (i, v) in enumerate(operands(st)[1:LLVM.API.LLVMGetNumArgOperands(st)])
                         if v == inst
                             LLVM.API.LLVMSetOperand(st, i - 1, operands(inst)[1])
                             push!(newvs, operands(inst)[1])
@@ -1491,7 +1491,7 @@ function fix_decayaddr!(mod::LLVM.Module)
                 else
                     EnumAttribute("sret")
                 end)
-		for (i, v) in enumerate(operands(st)[1:LLVM.API.LLVMGetNumArgOperands(st)])
+                for (i, v) in enumerate(operands(st)[1:LLVM.API.LLVMGetNumArgOperands(st)])
                     if v == inst
                         readnone = false
                         readonly = false
@@ -1817,7 +1817,7 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String})
         un = un::LLVM.CallInst
 
         # Passing the fn as an argument is not permitted
-	for op in collect(operands(un))[1:LLVM.API.LLVMGetNumArgOperands(un)]
+        for op in collect(operands(un))[1:LLVM.API.LLVMGetNumArgOperands(un)]
             if op == fn
                 return false
             end
@@ -1963,7 +1963,7 @@ function propagate_returned!(mod::LLVM.Module)
                             illegalUse = true
                             break
                         end
-                        ops = collect(operands(un))[1:LLVM.API.LLVMGetNumArgOperands(un)]
+                            ops = collect(operands(un))[1:LLVM.API.LLVMGetNumArgOperands(un)]
                         bad = false
                         for op in ops
                             if op == fn
@@ -2075,7 +2075,7 @@ function propagate_returned!(mod::LLVM.Module)
                             illegalUse = true
                             break
                         end
-			ops = collect(operands(un))[1:LLVM.API.LLVMGetNumArgOperands(un)]
+                            ops = collect(operands(un))[1:LLVM.API.LLVMGetNumArgOperands(un)]
                         bad = false
                         for op in ops
                             if op == fn
@@ -2171,7 +2171,7 @@ function propagate_returned!(mod::LLVM.Module)
                     illegalUse = true
                     continue
                 end
-		ops = collect(operands(un))[1:LLVM.API.LLVMGetNumArgOperands(un)]
+                ops = collect(operands(un))[1:LLVM.API.LLVMGetNumArgOperands(un)]
                 bad = false
                 for op in ops
                     if op == fn
diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl
index 72233aac..a1bf75dd 100644
--- a/src/rules/jitrules.jl
+++ b/src/rules/jitrules.jl
@@ -1810,7 +1810,7 @@ function generic_setup(
     mode = get_mode(gutils)
     mod = LLVM.parent(LLVM.parent(LLVM.parent(orig)))
 
-    ops = collect(operands(orig))[start+firstconst:LLVM.API.LLVMGetNumArgOperands(orig)]
+    ops = collect(operands(orig))[(start + firstconst):LLVM.API.LLVMGetNumArgOperands(orig)]
 
     T_int8 = LLVM.Int8Type()
     T_jlvalue = LLVM.StructType(LLVMType[])
@@ -1940,11 +1940,11 @@ function generic_setup(
     T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
 
     for v in vals
-       if value_type(v) != T_prjlvalue
-          throw(AssertionError("Illegal generic_setup, expected all arguments to by jlvaluet, found $(string(v)), within $(vals), orig=$(string(orig))"))
-       end
+        if value_type(v) != T_prjlvalue
+            throw(AssertionError("Illegal generic_setup, expected all arguments to by jlvaluet, found $(string(v)), within $(vals), orig=$(string(orig))"))
+        end
     end
-    
+
     cal = emit_apply_generic!(B, vals)
 
     debug_from_orig!(gutils, cal, orig)
diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl
index bdc7c532..da3051cd 100644
--- a/src/rules/llvmrules.jl
+++ b/src/rules/llvmrules.jl
@@ -1732,7 +1732,7 @@ end
     width = get_width(gutils)
 
     args = LLVM.Value[]
-    for a in origops[1:LLVM.API.LLVMGetNumArgOperands(orig)-1]
+    for a in origops[1:(LLVM.API.LLVMGetNumArgOperands(orig) - 1)]
         v = invert_pointer(gutils, a, B)
         push!(args, v)
     end
@@ -1842,7 +1842,7 @@ end
             UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))))
         for idx = 1:width
             vargs = LLVM.Value[]
-	    for a in args[1:end-1]
+            for a in args[1:(end - 1)]
                 push!(vargs, extract_value!(B, a, idx - 1))
             end
             push!(vargs, args[end])
diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl
index de614420..fbf7c36b 100644
--- a/src/rules/typeunstablerules.jl
+++ b/src/rules/typeunstablerules.jl
@@ -405,9 +405,9 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR)
     world = enzyme_extract_world(LLVM.parent(position(B)))
 
     @assert is_constant_value(gutils, origops[offset])
-    icvs = [is_constant_value(gutils, v) for v in origops[offset+1:LLVM.API.LLVMGetNumArgOperands(orig)]]
-    abs_partial = [abs_typeof(v, true) for v in origops[offset+1:LLVM.API.LLVMGetNumArgOperands(orig)]]
-    abs = [abs_typeof(v) for v in origops[offset+1:LLVM.API.LLVMGetNumArgOperands(orig)]]
+    icvs = [is_constant_value(gutils, v) for v in origops[(offset + 1):LLVM.API.LLVMGetNumArgOperands(orig)]]
+    abs_partial = [abs_typeof(v, true) for v in origops[(offset + 1):LLVM.API.LLVMGetNumArgOperands(orig)]]
+    abs = [abs_typeof(v) for v in origops[(offset + 1):LLVM.API.LLVMGetNumArgOperands(orig)]]
 
     @assert length(icvs) == length(abs)
     for (icv, (found_partial, typ_partial, byref_partial), (found, typ, byref)) in
@@ -441,10 +441,10 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR)
     shadowsin = LLVM.Value[invert_pointer(gutils, o, B) for o in origops[offset:LLVM.API.LLVMGetNumArgOperands(orig)]]
     if offset != 1
         pushfirst!(shadowsin, origops[1])
-	pushfirst!(valTys, API.VT_Primal)
+        pushfirst!(valTys, API.VT_Primal)
     end
-    
-    shadowres = batch_call_same_with_inverted_arg_if_active!(B, gutils, orig, shadowsin, valTys, false; force_run=true)
+
+    shadowres = batch_call_same_with_inverted_arg_if_active!(B, gutils, orig, shadowsin, valTys, false; force_run = true)
     unsafe_store!(shadowR, shadowres.ref)
     return true
 end
@@ -468,8 +468,8 @@ function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR)
 
     if !newstruct_common(true, true, offset, B, orig, gutils, normalR, shadowR) #=run=#
         origops = collect(operands(orig))
-        abs_partial = [abs_typeof(v, true) for v in origops[offset+1:LLVM.API.LLVMGetNumArgOperands(orig)]]
-        icvs = [is_constant_value(gutils, v) for v in origops[offset+1:LLVM.API.LLVMGetNumArgOperands(orig)]]
+        abs_partial = [abs_typeof(v, true) for v in origops[(offset + 1):LLVM.API.LLVMGetNumArgOperands(orig)]]
+        icvs = [is_constant_value(gutils, v) for v in origops[(offset + 1):LLVM.API.LLVMGetNumArgOperands(orig)]]
         emit_error(
             B,
             orig,
@@ -480,7 +480,7 @@ function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR)
             " " *
             string(abs_partial) *
             " " *
-            string([v for v in origops[offset+1:LLVM.API.LLVMGetNumArgOperands(orig)]]),
+                string([v for v in origops[(offset + 1):LLVM.API.LLVMGetNumArgOperands(orig)]]),
         )
     end
 
@@ -1027,7 +1027,7 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR)
         shadowin = invert_pointer(gutils, origops[2], B)
         if width == 1
             args = LLVM.Value[new_from_original(gutils, origops[1]), shadowin]
-	    for a in origops[3:(LLVM.API.LLVMGetNumArgOperands(orig)-(offset-1))]
+            for a in origops[3:(LLVM.API.LLVMGetNumArgOperands(orig) - (offset - 1))]
                 push!(args, new_from_original(gutils, a))
             end
             if offset != 1
@@ -1050,7 +1050,7 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR)
                     new_from_original(gutils, origops[1]),
                     shadowin_idx,
                 ]
-		for a in origops[3:(LLVM.API.LLVMGetNumArgOperands(orig)-(offset-1))]
+                for a in origops[3:(LLVM.API.LLVMGetNumArgOperands(orig) - (offset - 1))]
                     push!(args, new_from_original(gutils, a))
                 end
                 if offset != 1

@github-actions
Copy link
Contributor

github-actions bot commented Feb 21, 2026

Benchmark Results

main d6362ad... main / d6362ad...
basics/make_zero/namedtuple 0.052 ± 0.0025 μs 0.0544 ± 0.002 μs 0.956 ± 0.059
basics/make_zero/struct 0.271 ± 0.0052 μs 0.274 ± 0.0049 μs 0.989 ± 0.026
basics/overhead 4.65 ± 0.06 ns 4.95 ± 0.01 ns 0.939 ± 0.012
basics/remake_zero!/namedtuple 0.225 ± 0.0086 μs 0.224 ± 0.009 μs 1 ± 0.055
basics/remake_zero!/struct 0.228 ± 0.0078 μs 0.225 ± 0.0096 μs 1.01 ± 0.056
fold_broadcast/multidim_sum_bcast/1D 10.3 ± 0.27 μs 10.3 ± 1.7 μs 0.994 ± 0.17
fold_broadcast/multidim_sum_bcast/2D 12 ± 0.28 μs 12.1 ± 0.25 μs 0.998 ± 0.031
time_to_load 1.01 ± 0.008 s 1.03 ± 0.015 s 0.984 ± 0.017

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/22273960398/artifacts/5606129576.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant