Skip to content
2 changes: 1 addition & 1 deletion src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ function try_import_llvmbc(mod::LLVM.Module, flib::String, fname::String, import
end
end
for g in globals(inmod)
linkage!(g, LLVM.API.LLVMExternalLinkage)
linkage!(g, LLVM.API.LLVMInternalLinkage)
end
# override libdevice's triple and datalayout to avoid warnings
triple!(inmod, triple(mod))
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Clang_jll = "0ee61d77-7f21-5576-8119-9fcc46b10100"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
Expand All @@ -12,6 +13,7 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
Expand Down
72 changes: 72 additions & 0 deletions test/global_constants.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using Enzyme
using Libdl
using Test

const LLVM_IR = raw"""
; ModuleID = '<stdin>'
source_filename = "<string>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-conda-linux-gnu"


@A = private unnamed_addr constant [3 x double] [double 1.000000e+00, double 2.000000e+00, double 3.000000e+00], align 8

define double @func(double %x, double %y, i64 %n) {
entry:
%ptr = getelementptr inbounds [3 x double], [3 x double]* @A, i64 0, i64 %n
%aval = load double, double* %ptr, align 8
%prod = fmul double %x, %aval
%sum = fadd double %prod, %y
ret double %sum
}
"""

tmp_dir = tempdir()
tmp_so_file = joinpath(tmp_dir, "func.so")

run(
pipeline(
`clang -x ir - -Xclang -no-opaque-pointers -O3 -fPIC -fembed-bitcode -shared -o $(tmp_so_file)`;
stdin=IOBuffer(LLVM_IR)
)
);
lib = Libdl.dlopen(tmp_so_file);
const fptr = Libdl.dlsym(lib, :func);


function func_llvm(x::Float64, y::Float64, n::Int)
n >= 0 && n <= 2 || throw("0 ≤ n ≤ 2")
Base.llvmcall((LLVM_IR, "func"), Cdouble,
Tuple{Cdouble,Cdouble,Clong},
x, y, n
)
end;


function func_ccall(x::Float64, y::Float64, n::Int)
n >= 0 && n <= 2 || throw("0 ≤ n ≤ 2")
ccall(fptr, Cdouble,
(Cdouble, Cdouble, Clong),
x, y, n
)
end;

@testset "Rename external global constants ccall" begin

x = 2.0
y = 1.0
n = 2
A = [1.0, 2.0, 3.0]

@test func_llvm(x, y, n) == func_ccall(x, y, n)
@test func_llvm(x, y, n) == x * A[n+1] + y
@test func_ccall(x, y, n) == x * A[n+1] + y



@test gradient(Reverse, func_llvm, Const(x), y, Const(n)) == (nothing, 1.0, nothing)
@test gradient(Reverse, func_llvm, x, Const(y), Const(n)) == (3.0, nothing, nothing)

@test gradient(Reverse, func_ccall, Const(x), y, Const(n)) == (nothing, 1.0, nothing)
@test gradient(Reverse, func_ccall, x, Const(y), Const(n)) == (3.0, nothing, nothing)
end
Loading