Skip to content
37 changes: 36 additions & 1 deletion src/compiler/orcv2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,35 @@ function absolute_symbol_materialization(name, ptr)
return LLVM.absolute_symbols(Ref(gv))
end

const hnd_string_map = Dict{String, Ref{Ptr{Cvoid}}}()
const hnd_string_map = Dict{String,Ref{Ptr{Cvoid}}}()

# These are special (external or private) global
# constants that should not be incremented (renamed)
const global_var_prefixes = ("ejl_enz_", "ejl_jl_", "enz_exception", "_j_const_", "jl_", "_j_str")

# store non-special external global constants
# They will be incremented to produce new names
const glob_vars_maps = Dict{String,Int}()

# TODO: may not be necessary once the core
# issue is found regarding linkage type of
# private constants is change to external
# when ccall is used.
function rename_global!(glob_var)
_name = LLVM.name(glob_var)
if any(startswith.(_name, global_var_prefixes))
else
if haskey(glob_vars_maps, _name)
glob_vars_maps[_name] += 1
_new_name = _name * ".$(glob_vars_maps[_name])"
LLVM.name!(glob_var, _new_name)
glob_vars_maps[_new_name] = 1
else
glob_vars_maps[_name] = 1
end

end
end

function fix_ptr_lookup(name)
if startswith(name, "ejlstr\$") || startswith(name, "ejlptr\$")
Expand Down Expand Up @@ -258,6 +286,13 @@ function add!(mod)
replace_uses!(f, ptr)
Compiler.eraseInst(mod, f)
end

# rename non-special global constants that
# have external linkage (modified by the ccall
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alternate solution, can we just mark all defined symbols during the lto import phase as internal?

we should already do so for functions

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x/ref

linkage!(g, LLVM.API.LLVMExternalLinkage)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worked like a charm! All the tests passed locally. This also explains why something like

@.const.array.data.5 = private unnamed_addr constant [48 x i8] c"\00\00\00\00\00\00\00\00\00\00\00\00\00\00\D0?\00\00\00\00\00\00\D0?\00\00\00\00\00\00\E8?\00\00\00\00\00\00\E8?\00\00\00\00\00\00\F0?", align 8

turns into

@.const.array.data.5 = dso_local unnamed_addr constant [48 x i8] c"\00\00\00\00\00\00\00\00\00\00\00\00\00\00\D0?\00\00\00\00\00\00\D0?\00\00\00\00\00\00\E8?\00\00\00\00\00\00\E8?\00\00\00\00\00\00\F0?", align 8

What was the initial thought behind changing the linkage to external in the first place?

# execution path)
for glob_var in collect(globals(mod))
rename_global!(glob_var)
end
lljit = jit[].jit
jd = LLVM.JITDylib(lljit)
tsm = move_to_threadsafe(mod)
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Clang_jll = "0ee61d77-7f21-5576-8119-9fcc46b10100"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
Expand All @@ -12,6 +13,7 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
Expand Down
72 changes: 72 additions & 0 deletions test/global_constants.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using Enzyme
using Libdl
using Test

const LLVM_IR = raw"""
; ModuleID = '<stdin>'
source_filename = "<string>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-conda-linux-gnu"


@A = private unnamed_addr constant [3 x double] [double 1.000000e+00, double 2.000000e+00, double 3.000000e+00], align 8

define double @func(double %x, double %y, i64 %n) {
entry:
%ptr = getelementptr inbounds [3 x double], [3 x double]* @A, i64 0, i64 %n
%aval = load double, double* %ptr, align 8
%prod = fmul double %x, %aval
%sum = fadd double %prod, %y
ret double %sum
}
"""

tmp_dir = tempdir()
tmp_so_file = joinpath(tmp_dir, "func.so")

run(
pipeline(
`clang -x ir - -Xclang -no-opaque-pointers -O3 -fPIC -fembed-bitcode -shared -o $(tmp_so_file)`;
stdin=IOBuffer(LLVM_IR)
)
);
lib = Libdl.dlopen(tmp_so_file);
const fptr = Libdl.dlsym(lib, :func);


function func_llvm(x::Float64, y::Float64, n::Int)
n >= 0 && n <= 2 || throw("0 ≤ n ≤ 2")
Base.llvmcall((LLVM_IR, "func"), Cdouble,
Tuple{Cdouble,Cdouble,Clong},
x, y, n
)
end;


function func_ccall(x::Float64, y::Float64, n::Int)
n >= 0 && n <= 2 || throw("0 ≤ n ≤ 2")
ccall(fptr, Cdouble,
(Cdouble, Cdouble, Clong),
x, y, n
)
end;

@testset "Rename external global constants ccall" begin

x = 2.0
y = 1.0
n = 2
A = [1.0, 2.0, 3.0]

@test func_llvm(x, y, n) == func_ccall(x, y, n)
@test func_llvm(x, y, n) == x * A[n+1] + y
@test func_ccall(x, y, n) == x * A[n+1] + y



@test gradient(Reverse, func_llvm, Const(x), y, Const(n)) == (nothing, 1.0, nothing)
@test gradient(Reverse, func_llvm, x, Const(y), Const(n)) == (3.0, nothing, nothing)

@test gradient(Reverse, func_ccall, Const(x), y, Const(n)) == (nothing, 1.0, nothing)
@test gradient(Reverse, func_ccall, x, Const(y), Const(n)) == (3.0, nothing, nothing)
end