Skip to content

Ensure relocatable global variables#2810

Merged
wsmoses merged 40 commits into main from
relocg
Nov 30, 2025
Merged

Ensure relocatable global variables#2810
wsmoses merged 40 commits into main from
relocg

Conversation

@wsmoses
Copy link
Member

@wsmoses wsmoses commented Nov 25, 2025

No description provided.

@github-actions
Copy link
Contributor

github-actions bot commented Nov 25, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/absint.jl b/src/absint.jl
index 06a06d3d..3484f35e 100644
--- a/src/absint.jl
+++ b/src/absint.jl
@@ -6,11 +6,11 @@
 const JL_MAX_TAGS = 64 # see `enum jl_small_typeof_tags` in julia.h
 
 function unbind(@nospecialize(val))
-   if val isa Core.Binding
-       return val.value
-   else
-       return val
-   end
+    if val isa Core.Binding
+        return val.value
+    else
+        return val
+    end
 end
 
 function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false, istracked::Bool=false, typetag::Bool=false)::Tuple{Bool, Any}
@@ -28,7 +28,7 @@ function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false, istracked
                     return (true, v)
                 end
             end
-	    @assert !startswith(gname, "ejl_inserted") "Could not find ejl_inserted variable in map $gname"
+            @assert !startswith(gname, "ejl_inserted") "Could not find ejl_inserted variable in map $gname"
         end
         if isa(ce, LLVM.LoadInst)
             gv = operands(ce)[1]
@@ -373,7 +373,7 @@ function abs_typeof(
             end
             for (k, v) in JuliaEnzymeNameMap
                 if gname == "ejl_" * k
-		    return (true, Core.Typeof(unbind(v)), GPUCompiler.BITS_REF)
+                    return (true, Core.Typeof(unbind(v)), GPUCompiler.BITS_REF)
                 end
             end
         end
@@ -464,14 +464,14 @@ function abs_typeof(
                 nm == "jl_gc_alloc_typed" ||
                 nm == "ijl_gc_alloc_typed"
             vals = absint(operands(arg)[3], partial, false, #=typetag=#true)
-	    @assert !(vals[2] isa Core.Binding)
+            @assert !(vals[2] isa Core.Binding)
             return (vals[1], vals[2], vals[1] ? GPUCompiler.BITS_REF : nothing)
         end
         # Type tag is arg 3
         if nm == "jl_alloc_genericmemory_unchecked" ||
 		nm == "ijl_alloc_genericmemory_unchecked"
 	    vals = absint(operands(arg)[3], partial, true, #=typetag=#true)
-	    @assert !(vals[2] isa Core.Binding)
+            @assert !(vals[2] isa Core.Binding)
             return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
         end
         # Type tag is arg 1
@@ -486,13 +486,13 @@ function abs_typeof(
                 nm == "jl_alloc_genericmemory" ||
                 nm == "ijl_alloc_genericmemory"
             vals = absint(operands(arg)[1], partial, false, #=typetag=#true)
-	    @assert !(vals[2] isa Core.Binding)
+            @assert !(vals[2] isa Core.Binding)
             return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
         end
 
         if nm == "jl_new_structt" || nm == "ijl_new_structt"
             vals = absint(operands(arg)[1], partial, false, #=typetag=#true)
-	    @assert !(vals[2] isa Core.Binding)
+            @assert !(vals[2] isa Core.Binding)
             return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
         end
 
@@ -511,7 +511,7 @@ function abs_typeof(
             if nm == "jl_new_structv" || nm == "ijl_new_structv"
                 @assert index == 2
                 vals = absint(operands(arg)[index], partial, false, #=typetag=#true)
-	    	@assert !(vals[2] isa Core.Binding)
+                @assert !(vals[2] isa Core.Binding)
                 return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
             end
 
@@ -545,11 +545,11 @@ function abs_typeof(
             if nm == "jl_f__apply_iterate" || nm == "ijl_f__apply_iterate"
                 index += 1
                 legal, iterfn = absint(operands(arg)[index])
-	    	iterfn = unbind(iterfn)
+                iterfn = unbind(iterfn)
                 index += 1
                 if legal && iterfn == Base.iterate
                     legal0, combfn = absint(operands(arg)[index])
-		    combfn = unbind(combfn)
+                    combfn = unbind(combfn)
                     index += 1
                     if legal0 && combfn == Core.apply_type && partial
                         return (true, Type, GPUCompiler.BITS_REF)
@@ -887,7 +887,7 @@ function abs_typeof(
 
     legal, val = absint(arg, partial)
     if legal
-	val = unbind(val)
+        val = unbind(val)
         return (true, Core.Typeof(val), GPUCompiler.BITS_REF)
     end
     return (false, nothing, nothing)
diff --git a/src/api.jl b/src/api.jl
index 75d7b0a8..93ccfeac 100644
--- a/src/api.jl
+++ b/src/api.jl
@@ -1502,11 +1502,11 @@ function EnzymeDumpModuleRef(mod)
 end
 
 function EnzymeDumpValueRef(mod)
-    ccall((:EnzymeDumpValueRef, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), mod)
+    return ccall((:EnzymeDumpValueRef, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), mod)
 end
 
 function EnzymeDumpTypeRef(mod)
-    ccall((:EnzymeDumpTypeRef, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), mod)
+    return ccall((:EnzymeDumpTypeRef, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef,), mod)
 end
 
 EnzymeComputeByteOffsetOfGEP(B, V, T) = LLVM.Value(
diff --git a/src/compiler.jl b/src/compiler.jl
index 78b160d0..ee1b2cd7 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -1694,10 +1694,10 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie
 				index += 1
 				found = Any[]
 				legal, Ty = absint(operands(arg)[index], partial)
-				Ty = unbind(Ty)
+                            Ty = unbind(Ty)
 				if legal && Ty == NTuple
 				   legal, Ty = absint(operands(arg)[index+2])
-				   Ty = unbind(Ty)
+                                Ty = unbind(Ty)
 				   if legal
 					# count should represent {the total size in bytes, the aligned size of each element}
 					B = LLVM.IRBuilder()
@@ -5402,8 +5402,8 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
 @static if VERSION < v"1.11-"
 else    
                     legal2, obj = absint(inst)
-		    obj = unbind(obj)
-		    if legal2 && is_memory_instance(obj)
+                        obj = unbind(obj)
+                        if legal2 && is_memory_instance(obj)
                         metadata(inst)["nonnull"] = MDNode(LLVM.Metadata[])
                     end
 end
@@ -5632,7 +5632,7 @@ end
                         string(cur)
                     slegal, foundv = absint(cur)
                     if slegal
-		    	foundv = unbind(foundv)
+                        foundv = unbind(foundv)
                         resstr *= "of type " * string(foundv)
                     end
                     emit_error(builder, user, resstr, EnzymeMutabilityException)
@@ -6468,7 +6468,7 @@ const DumpLLVMCall = Ref(false)
         end
         reinsert_gcmarker!(llvm_f)
 
-	Enzyme.Compiler.JIT.prepare!(mod)
+        Enzyme.Compiler.JIT.prepare!(mod)
 	if DumpLLVMCall[]
 	   API.EnzymeDumpModuleRef(mod.ref)
 	end
@@ -6594,7 +6594,7 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri
             end
         else
             propagate_returned!(mod)
-	    Compiler.JIT.prepare!(mod)
+            Compiler.JIT.prepare!(mod)
         end
         mstr
     else
diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl
index 0ccb9dfe..c77e7788 100644
--- a/src/compiler/optimize.jl
+++ b/src/compiler/optimize.jl
@@ -74,13 +74,13 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
 
     function middle_optimize!(second_stage=false)
     @dispose pb = NewPMPassBuilder() begin
-        registerEnzymeAndPassPipeline!(pb)
+            registerEnzymeAndPassPipeline!(pb)
         add!(pb, NewPMAAManager()) do aam
             add!(aam, ScopedNoAliasAA())
             add!(aam, TypeBasedAA())
             add!(aam, BasicAA())
         end
-        add!(pb, NewPMModulePassManager()) do mpm
+            add!(pb, NewPMModulePassManager()) do mpm
             add!(mpm, CPUFeaturesPass()) # why is this duplicated?
 
             add!(mpm, NewPMFunctionPassManager()) do fpm
diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl
index da9ec7c7..df747497 100644
--- a/src/compiler/orcv2.jl
+++ b/src/compiler/orcv2.jl
@@ -187,22 +187,22 @@ function prepare!(mod)
     end
     for g in collect(globals(mod))
         if !startswith(LLVM.name(g), "ejl_inserted\$")
-           continue
+            continue
         end
         _, ogname, load1, initaddr = split(LLVM.name(g), "\$")
 
         load1 = load1 == "true"
-            initaddr = parse(UInt, initaddr)
+        initaddr = parse(UInt, initaddr)
         ptr = Base.reinterpret(Ptr{Ptr{Cvoid}}, initaddr)
         if load1
-           ptr = Base.unsafe_load(ptr, :unordered)
+            ptr = Base.unsafe_load(ptr, :unordered)
         end
-                
+
         obj = Base.unsafe_pointer_to_objref(ptr)
-	
+
         # Let's try a de-bind for 1.10 lux
         if isa(obj, Core.Binding)
-           ptr = Compiler.unsafe_to_ptr(obj.value)
+            ptr = Compiler.unsafe_to_ptr(obj.value)
         end
 
         ptr = reinterpret(UInt, ptr)
@@ -212,6 +212,7 @@ function prepare!(mod)
         replace_uses!(g, ptr)
         Compiler.eraseInst(mod, g)
     end
+    return
 end
 
 function get_trampoline(job)
@@ -258,7 +259,7 @@ function get_trampoline(job)
                 Compiler.eraseInst(mod, other_func)
             end
 
-	    prepare!(mod)
+            prepare!(mod)
             tsm = move_to_threadsafe(mod)
 
             il = LLVM.IRCompileLayer(lljit)
diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl
index 701e8e4d..96fb7ac5 100644
--- a/src/compiler/utils.jl
+++ b/src/compiler/utils.jl
@@ -317,7 +317,7 @@ function declare_pgcstack!(mod::LLVM.Module)
     )
 end
 
-function emit_pgcstack(B::LLVM.IRBuilder, name::String="")
+function emit_pgcstack(B::LLVM.IRBuilder, name::String = "")
     curent_bb = position(B)
     fn = LLVM.parent(curent_bb)
     mod = LLVM.parent(fn)
@@ -353,20 +353,20 @@ function reinsert_gcmarker!(func::LLVM.Function, @nospecialize(PB::Union{Nothing
         context(LLVM.parent(func))
         B = IRBuilder()
         entry_bb = first(blocks(func))
-	if PB !== nothing && LLVM.name(Base.position(PB)) == "allocsForInversion"
-	    B = PB
-	elseif !isempty(instructions(entry_bb))
-	    if PB === nothing || Base.position(PB) != entry_bb 
-		    position!(B, first(instructions(entry_bb)))
-	    else
-		    B = PB
-	    end
+        if PB !== nothing && LLVM.name(Base.position(PB)) == "allocsForInversion"
+            B = PB
+        elseif !isempty(instructions(entry_bb))
+            if PB === nothing || Base.position(PB) != entry_bb
+                position!(B, first(instructions(entry_bb)))
+            else
+                B = PB
+            end
         else
-	    if PB === nothing || Base.position(PB) != entry_bb 
-               position!(B, entry_bb)
-	    else
-	       B = PB
-	    end
+            if PB === nothing || Base.position(PB) != entry_bb
+                position!(B, entry_bb)
+            else
+                B = PB
+            end
         end
         emit_pgcstack(B, "newly_emitted_pgc_stack")
     else
diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl
index 1138d8b1..ed7da86f 100644
--- a/src/compiler/validation.jl
+++ b/src/compiler/validation.jl
@@ -239,77 +239,77 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
         inst = LLVM.Instruction(iter)
         iter = LLVM.API.LLVMGetNextInstruction(iter)
 
-	if isa(value_type(inst), LLVM.PointerType) && addrspace(value_type(inst)) == Tracked
-	    	inst0, _ = get_base_and_offset(inst; offsetAllowed=false, inttoptr=true)
-		if isa(inst0, LLVM.LoadInst) && addrspace(value_type(operands(inst0)[1])) == 0
-		   addr = operands(inst0)[1]
-	    	   addr, off = get_base_and_offset(addr; offsetAllowed=true, inttoptr=true)
-			gname = nothing
-			load1 = false
-			if isa(addr, LLVM.GlobalVariable) && haskey(metadata(addr), "julia.constgv")
-				paddr = addr
-				addr = LLVM.initializer(paddr)
-				gname = LLVM.name(paddr)*"\$false"
-				addr, _ = get_base_and_offset(addr; offsetAllowed=false, inttoptr=true)
-			elseif isa(addr, LLVM.LoadInst)
-			   paddr = operands(addr)[1]
-			   if isa(paddr, LLVM.GlobalVariable) && haskey(metadata(paddr), "julia.constgv")
-				addr = LLVM.initializer(paddr)
-				gname = LLVM.name(paddr)*"\$true"
-				addr, _ = get_base_and_offset(addr; offsetAllowed=false, inttoptr=true)
-				load1 = true
-			   end
-			elseif isa(addr, LLVM.ConstantInt)
-			    gname = string(convert(UInt, addr))*"\$true"
-			    load1 = true
-			end
+            if isa(value_type(inst), LLVM.PointerType) && addrspace(value_type(inst)) == Tracked
+                inst0, _ = get_base_and_offset(inst; offsetAllowed = false, inttoptr = true)
+                if isa(inst0, LLVM.LoadInst) && addrspace(value_type(operands(inst0)[1])) == 0
+                    addr = operands(inst0)[1]
+                    addr, off = get_base_and_offset(addr; offsetAllowed = true, inttoptr = true)
+                    gname = nothing
+                    load1 = false
+                    if isa(addr, LLVM.GlobalVariable) && haskey(metadata(addr), "julia.constgv")
+                        paddr = addr
+                        addr = LLVM.initializer(paddr)
+                        gname = LLVM.name(paddr) * "\$false"
+                        addr, _ = get_base_and_offset(addr; offsetAllowed = false, inttoptr = true)
+                    elseif isa(addr, LLVM.LoadInst)
+                        paddr = operands(addr)[1]
+                        if isa(paddr, LLVM.GlobalVariable) && haskey(metadata(paddr), "julia.constgv")
+                            addr = LLVM.initializer(paddr)
+                            gname = LLVM.name(paddr) * "\$true"
+                            addr, _ = get_base_and_offset(addr; offsetAllowed = false, inttoptr = true)
+                            load1 = true
+                        end
+                    elseif isa(addr, LLVM.ConstantInt)
+                        gname = string(convert(UInt, addr)) * "\$true"
+                        load1 = true
+                    end
 
-			if isa(addr, LLVM.ConstantInt)
-			
-			initaddr = convert(UInt, addr) + off
-			if gname isa String
-			    gname = gname *"\$$initaddr"
-			end
-			ptr = Base.reinterpret(Ptr{Ptr{Cvoid}}, initaddr)
-			if load1
-			ptr = Base.unsafe_load(ptr, :unordered)
-			if ptr == C_NULL
-				continue
-			end
-			end
-			obj = Base.unsafe_pointer_to_objref(ptr)
-	    		if obj === nothing
-				continue
-			end
-			obj0 = obj
-
-			# TODO we can use this to make it properly relocatable
-			if isa(obj, Core.Binding)
-			   obj = obj.value
-			   if gname === nothing
-				obj0 = obj
-			   end
-			end
+                    if isa(addr, LLVM.ConstantInt)
 
-			# We really don't want to mess with the atomic baked in loads here
-			#if obj isa Base.ReentrantLock
-			#   continue
-			#end
-
-			b = IRBuilder()
-			position!(b, inst)
-			newf = unsafe_to_llvm(b, obj0; insert_name_if_not_exists=gname) 
-			replace_uses!(inst, newf)
-			LLVM.API.LLVMInstructionEraseFromParent(inst)
-			continue
-		end
-		end
-	    end
+                        initaddr = convert(UInt, addr) + off
+                        if gname isa String
+                            gname = gname * "\$$initaddr"
+                        end
+                        ptr = Base.reinterpret(Ptr{Ptr{Cvoid}}, initaddr)
+                        if load1
+                            ptr = Base.unsafe_load(ptr, :unordered)
+                            if ptr == C_NULL
+                                continue
+                            end
+                        end
+                        obj = Base.unsafe_pointer_to_objref(ptr)
+                        if obj === nothing
+                            continue
+                        end
+                        obj0 = obj
+
+                        # TODO we can use this to make it properly relocatable
+                        if isa(obj, Core.Binding)
+                            obj = obj.value
+                            if gname === nothing
+                                obj0 = obj
+                            end
+                        end
+
+                        # We really don't want to mess with the atomic baked in loads here
+                        #if obj isa Base.ReentrantLock
+                        #   continue
+                        #end
+
+                        b = IRBuilder()
+                        position!(b, inst)
+                        newf = unsafe_to_llvm(b, obj0; insert_name_if_not_exists = gname)
+                        replace_uses!(inst, newf)
+                        LLVM.API.LLVMInstructionEraseFromParent(inst)
+                        continue
+                    end
+                end
+            end
         if isa(inst, LLVM.CallInst)
             push!(calls, inst)
             # remove illegal invariant.load and jtbaa_const invariants
         elseif isa(inst, LLVM.LoadInst) 
-	    fn_got, _ = get_base_and_offset(operands(inst)[1]; offsetAllowed=false, inttoptr=false)
+                fn_got, _ = get_base_and_offset(operands(inst)[1]; offsetAllowed = false, inttoptr = false)
             fname = String(name(fn_got))
             match_ = match(r"^jlplt_(.*)_\d+_got$", fname)
 
@@ -852,11 +852,11 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
             flib = ops[1]
             fname = ops[2]
 
-	    if isa(flib, LLVM.ConstantExpr) || isa(flib, LLVM.GlobalVariable)
-		legal, flib2 = absint(flib)
-		if legal
-		    flib = unbind(flib2)
-		end
+            if isa(flib, LLVM.ConstantExpr) || isa(flib, LLVM.GlobalVariable)
+                legal, flib2 = absint(flib)
+                if legal
+                    flib = unbind(flib2)
+                end
             end
             if isa(flib, GlobalRef) && isdefined(flib.mod, flib.name)
                 flib = getfield(flib.mod, flib.name)
@@ -1006,7 +1006,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
                 iteroff = 2
 
                 legal, iterlib = absint(operands(inst)[iteroff+1])
-		iterlib = unbind(iterlib)
+                iterlib = unbind(iterlib)
                 if legal && iterlib == Base.iterate
                     legal, GT, byref = abs_typeof(operands(inst)[4+1], true)
                     funcoff = 3
@@ -1096,7 +1096,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
                         push!(tys, typ)
                     end
                     legal, flib = absint(operands(inst)[offset+1])
-		    flib = unbind(flib)
+                    flib = unbind(flib)
                     if legal && isa(flib, Core.MethodInstance)
                         if !Base.isvarargtype(flib.specTypes.parameters[end])
                             @assert length(tys) == length(flib.specTypes.parameters)
@@ -1251,7 +1251,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
                     push!(tys, typ)
                 end
                 legal, flib = absint(operands(inst)[offset+1])
-		flib = unbind(flib)
+                flib = unbind(flib)
                 if legal && isa(flib, Core.MethodInstance)
                     if !Base.isvarargtype(flib.specTypes.parameters[end])
                         if length(tys) != length(flib.specTypes.parameters)
diff --git a/src/errors.jl b/src/errors.jl
index e89da601..027bcfd0 100644
--- a/src/errors.jl
+++ b/src/errors.jl
@@ -1068,18 +1068,18 @@ function julia_error(
                 print(io, "Current scope: \n")
                 print(io, ir)
             end
-	    legal, obj = absint(val)
-	    if legal
-		obj0 = obj
-		obj = unbind(obj)
-	        println(io, "\nValue of type: ", Core.Typeof(obj))
-		println(io ,  " of value    : ", obj)
-		if obj0 isa Core.Binding
-		println(io ,  " binding     : ", obj0)	    
-		end
-		println(io)
-	    end
-	    if !isa(val, LLVM.Argument) && !isa(val, LLVM.GlobalVariable) 
+            legal, obj = absint(val)
+            if legal
+                obj0 = obj
+                obj = unbind(obj)
+                println(io, "\nValue of type: ", Core.Typeof(obj))
+                println(io, " of value    : ", obj)
+                if obj0 isa Core.Binding
+                    println(io, " binding     : ", obj0)
+                end
+                println(io)
+            end
+            if !isa(val, LLVM.Argument) && !isa(val, LLVM.GlobalVariable)
                 print(io, "\n Inverted pointers: \n")
                 ip = API.EnzymeGradientUtilsInvertedPointersToString(gutils)
                 sval = Base.unsafe_string(ip)
@@ -1302,15 +1302,15 @@ function julia_error(
                 end
 
                 legal2, obj = absint(cur)
-		obj0 = obj
+                obj0 = obj
                 # Only do so for the immediate operand/etc to a phi, since otherwise we will make multiple
                 if legal2
-		   obj = unbind(obj)
-		   if is_memory_instance(obj) || (obj isa Core.SimpleVector && length(obj) == 0)
-			return make_batched(ncur, prevbb)
-		   end
+                    obj = unbind(obj)
+                    if is_memory_instance(obj) || (obj isa Core.SimpleVector && length(obj) == 0)
+                        return make_batched(ncur, prevbb)
+                    end
                    if active_reg(TT, world) == ActiveState &&
-		     ( isa(cur, LLVM.ConstantExpr) || isa(cur, LLVM.GlobalVariable)) &&
+                            (isa(cur, LLVM.ConstantExpr) || isa(cur, LLVM.GlobalVariable)) &&
                    cur == data2
                     if width == 1
                         if mode == API.DEM_ForwardMode
@@ -1351,8 +1351,8 @@ else
                     larg, off = get_base_and_offset(operands(cur)[1])
                     if isa(larg, LLVM.LoadInst)
                         legal2, obj = absint(larg)
-			obj = unbind(obj)
-			if legal2 && is_memory_instance(obj)
+                            obj = unbind(obj)
+                            if legal2 && is_memory_instance(obj)
                             return make_batched(ncur, prevbb)
                         end
                     end
@@ -1361,10 +1361,10 @@ end
 
                 badval = if legal2
                     sv = string(obj) * " of type" * " " * string(TT)
-		    if obj0 isa Core.Binding
-			sv = sv *" binded at "*string(obj0)
-		    end
-		    sv
+                    if obj0 isa Core.Binding
+                        sv = sv * " binded at " * string(obj0)
+                    end
+                    sv
                 else
                     "Unknown object of type" * " " * string(TT)
                 end
@@ -1525,7 +1525,7 @@ end
                 end
             end
            
-	    if isa(cur, LLVM.LoadInst) || isa(cur, LLVM.BitCastInst) || isa(cur, LLVM.AddrSpaceCastInst) || (isa(cur, LLVM.GetElementPtrInst) && all(Base.Fix2(isa, LLVM.ConstantInt), operands(cur)[2:end])) || (isa(cur,LLVM.ConstantExpr) &&  opcode(cur) in (LLVM.API.LLVMBitCast, LLVM.API.LLVMAddrSpaceCast, LLVM.API.LLVMGetElementPtr))
+            if isa(cur, LLVM.LoadInst) || isa(cur, LLVM.BitCastInst) || isa(cur, LLVM.AddrSpaceCastInst) || (isa(cur, LLVM.GetElementPtrInst) && all(Base.Fix2(isa, LLVM.ConstantInt), operands(cur)[2:end])) || (isa(cur, LLVM.ConstantExpr) &&  opcode(cur) in (LLVM.API.LLVMBitCast, LLVM.API.LLVMAddrSpaceCast, LLVM.API.LLVMGetElementPtr))
                 lhs = make_replacement(operands(cur)[1], prevbb)
                 if illegal
                     return ncur
diff --git a/src/jlrt.jl b/src/jlrt.jl
index ad18cc90..00a4e324 100644
--- a/src/jlrt.jl
+++ b/src/jlrt.jl
@@ -20,12 +20,12 @@ function emit_allocobj!(
     T_pint8 = LLVM.PointerType(T_int8)
 
     pgcstack = reinsert_gcmarker!(fn, B)
-    bc = bitcast!(B, pgcstack, T_ppjlvalue, LLVM.name(pgcstack)*"_bc")
-    
+    bc = bitcast!(B, pgcstack, T_ppjlvalue, LLVM.name(pgcstack) * "_bc")
+
     ct = inbounds_gep!(
         B,
         T_pjlvalue,
-	bc,
+        bc,
         [LLVM.ConstantInt(current_task_offset())],
     )
 
@@ -457,7 +457,7 @@ function emit_apply_type!(B::LLVM.IRBuilder, @nospecialize(Ty::Type), args::Vect
     for arg in args
         slegal, foundv = absint(arg)
         if slegal
-	    push!(found, unbind(foundv))
+            push!(found, unbind(foundv))
         else
             legal = false
             break
@@ -512,7 +512,7 @@ function emit_tuple!(B::LLVM.IRBuilder, args::Vector{LLVM.Value})::LLVM.Value
     for arg in args
         slegal, foundv = absint(arg)
         if slegal
-	    push!(found, unbind(foundv))
+            push!(found, unbind(foundv))
         else
             legal = false
             break
@@ -874,7 +874,7 @@ function emit_layout_of_type!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value))
 	ls = get_layout_struct()
 	lptr = LLVM.PointerType(ls, 10)
 	if legal
-		JTy = unbind(JTy)
+        JTy = unbind(JTy)
 		return LLVM.const_inttoptr(LLVM.ConstantInt(Base.reinterpret(UInt, JTy.layout)), lptr)
 	end
 	@assert !isa(ty, LLVM.ConstantExpr)
@@ -893,7 +893,7 @@ end
 function emit_type_layout_elsz!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value))
 	legal, JTy = absint(ty)
 	if legal
-	    JTy = unbind(JTy)
+        JTy = unbind(JTy)
 	    @assert JTy isa Type
 	    res = Compiler.datatype_layoutsize(JTy)
 	    return LLVM.ConstantInt(res)
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index 5cecd431..7725027a 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -442,7 +442,7 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module)
 	@static if VERSION < v"1.11-"
 	else    
 			    legal2, obj = absint(src)
-			    if legal2 && is_memory_instance(unbind(obj)) 
+                            if legal2 && is_memory_instance(unbind(obj))
 				metadata(src)["nonnull"] = MDNode(LLVM.Metadata[])
 			    end
 	end
diff --git a/src/sugar.jl b/src/sugar.jl
index f244d1be..5ede8791 100644
--- a/src/sugar.jl
+++ b/src/sugar.jl
@@ -107,9 +107,9 @@ end
 
         LLVM.position!(builder, exit)
         LLVM.ret!(builder, obj)
-	
+
         Compiler.reinsert_gcmarker!(llvm_f)
-	Compiler.JIT.prepare!(mod)
+        Compiler.JIT.prepare!(mod)
 
         string(mod)
     end
diff --git a/src/typeutils/jltypes.jl b/src/typeutils/jltypes.jl
index c2877660..7beaf3e4 100644
--- a/src/typeutils/jltypes.jl
+++ b/src/typeutils/jltypes.jl
@@ -412,13 +412,13 @@ end
 @inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated
 
 @inline function is_memory_instance(@nospecialize(obj))
-   @static if VERSION < v"1.11"
-	return false
-   else
-	if obj isa Memory
-	   return obj == typeof(obj).instance
+    @static if VERSION < v"1.11"
+        return false
+    else
+        if obj isa Memory
+            return obj == typeof(obj).instance
         end
-	return false
-   end
+        return false
+    end
 end
 
diff --git a/src/utils.jl b/src/utils.jl
index c0ca94de..1a05e158 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -92,7 +92,7 @@ end
 export unsafe_to_ptr
 
 # This mimicks literal_pointer_val / literal_pointer_val_slot
-function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val); insert_name_if_not_exists::Union{String, Nothing}=nothing)::LLVM.Value
+function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val); insert_name_if_not_exists::Union{String, Nothing} = nothing)::LLVM.Value
     T_jlvalue = LLVM.StructType(LLVM.LLVMType[])
     T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
     T_prjlvalue_UT = LLVM.PointerType(T_jlvalue)
@@ -106,65 +106,65 @@ function unsafe_to_llvm(B::LLVM.IRBuilder, @nospecialize(val); insert_name_if_no
             end
         end
     end
-    
+
     function setup_global(k, v)
-	    k0 = k
+        k0 = k
             mod = LLVM.parent(LLVM.parent(LLVM.position(B)))
             globs = LLVM.globals(mod)
             if Base.haskey(globs, "ejl_" * k)
                 return globs["ejl_"*k]
             end
-        
-	force_inactive = false
-	if insert_name_if_not_exists isa String
-	    k = "inserted\$"*insert_name_if_not_exists
-            if !haskey(Compiler.JuliaEnzymeNameMap, k)
-		 Compiler.JuliaEnzymeNameMap[k] = val
-	    end
-	    # Since the legacy behavior was to force inactive for global constants, we retain that here (for now)
-	    force_inactive = true
-	end
 
-            if Base.haskey(globs, "ejl_" * k)
-                return globs["ejl_"*k]
+        force_inactive = false
+        if insert_name_if_not_exists isa String
+            k = "inserted\$" * insert_name_if_not_exists
+            if !haskey(Compiler.JuliaEnzymeNameMap, k)
+                Compiler.JuliaEnzymeNameMap[k] = val
             end
+            # Since the legacy behavior was to force inactive for global constants, we retain that here (for now)
+            force_inactive = true
+        end
+
+        if Base.haskey(globs, "ejl_" * k)
+            return globs["ejl_" * k]
+        end
 
             gv = LLVM.GlobalVariable(mod, T_jlvalue, "ejl_" * k, Tracked)
 
             API.SetMD(gv, "enzyme_ta_norecur", LLVM.MDNode(LLVM.Metadata[]))
-            inactive = force_inactive || Enzyme.Compiler.is_memory_instance(v)
-	    if !inactive && v isa Core.SimpleVector && length(v) == 0
-		inactive = true
-	    end
-	    if !inactive && world isa UInt
+        inactive = force_inactive || Enzyme.Compiler.is_memory_instance(v)
+        if !inactive && v isa Core.SimpleVector && length(v) == 0
+            inactive = true
+        end
+        if !inactive && world isa UInt
                 legal, jTy, byref = Compiler.abs_typeof(gv, true)
                 if legal
                     curent_bb = position(B)
                     fn = LLVM.parent(curent_bb)
-		    state = Enzyme.Compiler.active_reg(jTy, world)
-		    inactive = state == Enzyme.Compiler.AnyState ||state == Enzyme.Compiler.ActiveState
+                state = Enzyme.Compiler.active_reg(jTy, world)
+                inactive = state == Enzyme.Compiler.AnyState ||state == Enzyme.Compiler.ActiveState
                 end
             end
-	    if inactive
-		API.SetMD(gv, "enzyme_inactive", LLVM.MDNode(LLVM.Metadata[]))
-	    end
+        if inactive
+            API.SetMD(gv, "enzyme_inactive", LLVM.MDNode(LLVM.Metadata[]))
+        end
             return gv
     end
 
     for (k, v) in Compiler.JuliaGlobalNameMap
         if v === val
-	    return setup_global(k, v)
+            return setup_global(k, v)
         end
     end
 
     for (k, v) in Compiler.JuliaEnzymeNameMap
         if v === val
-	    return setup_global(k, v)
+            return setup_global(k, v)
         end
     end
 
     if insert_name_if_not_exists !== nothing
-	return setup_global(insert_name_if_not_exists, val)
+        return setup_global(insert_name_if_not_exists, val)
     end
 
     # XXX: This prevents code from being runtime relocatable
diff --git a/test/locks.jl b/test/locks.jl
index 9e5daf67..8106a1f8 100644
--- a/test/locks.jl
+++ b/test/locks.jl
@@ -4,9 +4,9 @@ using Test
 const my_cache_lock = ReentrantLock()
 
 function my_lock()
-       lock(my_cache_lock);
-       unlock(my_cache_lock);
-       return nothing
+    lock(my_cache_lock)
+    unlock(my_cache_lock)
+    return nothing
 end
 
 @testset "Lock forward" begin

@github-actions
Copy link
Contributor

github-actions bot commented Nov 25, 2025

Benchmark Results

main 3aae7c6... main / 3aae7c6...
basics/make_zero/namedtuple 0.0528 ± 0.0026 μs 0.0527 ± 0.0027 μs 1 ± 0.071
basics/make_zero/struct 0.265 ± 0.0059 μs 0.265 ± 0.0065 μs 1 ± 0.033
basics/overhead 4.65 ± 0.01 ns 5.57 ± 0.021 ns 0.834 ± 0.0036
basics/remake_zero!/namedtuple 0.237 ± 0.0082 μs 0.239 ± 0.0095 μs 0.99 ± 0.052
basics/remake_zero!/struct 0.238 ± 0.0092 μs 0.24 ± 0.0087 μs 0.995 ± 0.053
fold_broadcast/multidim_sum_bcast/1D 10.3 ± 0.48 μs 10.3 ± 0.27 μs 1 ± 0.054
fold_broadcast/multidim_sum_bcast/2D 12.1 ± 0.27 μs 12.2 ± 0.25 μs 0.998 ± 0.03
time_to_load 1.03 ± 0.0099 s 1.04 ± 0.021 s 0.989 ± 0.022

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/19791982609/artifacts/4714998493.


get_tm() = tm[]
get_jit() = jit[].jit
get_dylib() = dylib[]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would just define it as:

get_dylib() = JITDylib(lljit)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this actually used?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was used [I tried to do the llvm.define when we added to our list, but that led to issues], so now I'm just doing a pass over the IR to do the rewrite.

removed since

@wsmoses
Copy link
Member Author

wsmoses commented Nov 25, 2025

@vchuravy interestingly [on a debug build] I can't repro the ci failures. any luck on your side?

continue
end
_, ogname, load1, initaddr = split(LLVM.name(g), "\$")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vchuravy here we have the original name of the global variable that julia emitted it as [as well as the initial address/etc]. If there's any way we can look up a portable version of this from the name, we have our relocatable code. E.g. if I could cglobal(:jl_global_7353) [where just the name was consistent but the runtime pointer differed], we'd be done.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we can't rely on names. Names are currently session local and in the future compilation local

On the Julia side this roughly looks like:

We are building a list of pointers (GOT-like) that will be initialized during the loading process. We create a side table of Julia Object => Index in GOT, which will be serialized during the saving of a package image; upon loading, we deserialize that table (basically an IdDict) and populate the GOT.

Now this might be something GPUCompiler doesn't yet fully expose.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah if we can somehow get access to that dict directly we could replace this with the correct object per session.

We actually store all the objects right now in our own dict, but it then begs the question of how we keep that dict relocatable to different sessions

@wsmoses
Copy link
Member Author

wsmoses commented Nov 26, 2025

@vchuravy @gbaraldi @oscardssmith are there any jl_global's passed as a ptr{ptr{jlvaluet}} or ptr{jlvaluet} that are different to load depending on what thread you're on?

@codecov
Copy link

codecov bot commented Nov 26, 2025

Codecov Report

❌ Patch coverage is 86.09626% with 26 lines in your changes missing coverage. Please review.
✅ Project coverage is 67.83%. Comparing base (30e0519) to head (3aae7c6).
⚠️ Report is 7 commits behind head on main.

Files with missing lines Patch % Lines
src/errors.jl 43.47% 13 Missing ⚠️
src/api.jl 0.00% 4 Missing ⚠️
src/compiler/validation.jl 94.33% 3 Missing ⚠️
src/absint.jl 85.71% 2 Missing ⚠️
src/compiler/utils.jl 83.33% 2 Missing ⚠️
src/jlrt.jl 80.00% 1 Missing ⚠️
src/utils.jl 96.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2810      +/-   ##
==========================================
+ Coverage   67.79%   67.83%   +0.03%     
==========================================
  Files          58       58              
  Lines       20723    20801      +78     
==========================================
+ Hits        14050    14111      +61     
- Misses       6673     6690      +17     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@vchuravy
Copy link
Member

jl_global's passed as a ptr{ptr{jlvaluet}} or ptr{jlvaluet} that are different to load depending on what thread you're on?

I don't understand the question? The only object that is different to load is the ptls

@wsmoses
Copy link
Member Author

wsmoses commented Nov 26, 2025

Resolved-ish, but a question I have is: some jlvaluet are passed as a pointer to the value, others as a pointer to the pointer.

Sometimes these aren't initialized and the inner ptr is null and/or the final value is nothing. I've reverse engineered that behavior and we don't load such values, but it would be good to confirm.

@vchuravy
Copy link
Member

@vchuravy
Copy link
Member

Sometimes these aren't initialized and the inner ptr is null and/or the final value is nothing. I've reverse engineered that behavior and we don't load such values, but it would be good to confirm.

Could those be jl_binding_t? https://github.com/JuliaLang/julia/blob/ee6bb20bc86ec6fc9905e2769293ce1da4c0c4cf/src/aotcompile.cpp#L128

@wsmoses
Copy link
Member Author

wsmoses commented Nov 26, 2025

No they are not and we explicitly check for those (which appear to only happen on 1.10)

@wsmoses
Copy link
Member Author

wsmoses commented Nov 29, 2025

okay figured it out, we needed to add an atomic to the load

@wsmoses wsmoses merged commit 84f9728 into main Nov 30, 2025
55 of 56 checks passed
@wsmoses wsmoses deleted the relocg branch November 30, 2025 02:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants