-
Notifications
You must be signed in to change notification settings - Fork 82
Open
Description
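When new thunks are generated, Enzyme fails to merge constants that have the same name and the same value into the extended module. Instead, linking fails with a duplicate-symbol error (in the reproducer below, `.const.array.data.5` is defined twice).
I expect constants with the same name and the same value to be merged. If constants with the same name but different values are encountered, an error should be raised.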
using Enzyme, Clang_jll, Libdl
# Textual LLVM IR (typed-pointer syntax) for the kernel `func`. The value naming
# (`%excinfo`, `%arg.arr.N`) suggests it was emitted by Numba.
# NOTE(review): confirm the origin.
#
# `@.const.array.data.5` packs six little-endian doubles (datalayout "e-"):
#   0.0, 0.25, 0.25, 0.75, 0.75, 1.0
# These are the segment bounds listed as `tspans` further below. This global's
# symbol is the one reported as a duplicate definition when Enzyme links the
# second thunk.
#
# How `func`'s parameters are used by the body (it always returns i32 0):
#   %retptr        double* : output slot; the scalar result is stored here
#   %excinfo               : readnone, never used
#   %arg.t         double  : evaluation point, compared against the bounds
#   %arg.arr.0/1   i8*     : readnone, never used
#   %arg.arr.2/3   i64     : not referenced in the body
#   %arg.arr.4     double* : data pointer, indexed directly (no stride applied)
#   %arg.arr.5.0   i64     : length; used to wrap negative indices and to pick
#                            the last element when the computed index is 300
#   %arg.arr.6.0   i64     : not referenced in the body
const FUNC_LLVM_IR = raw"""
; ModuleID = '<stdin>'
source_filename = "<string>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-conda-linux-gnu"
@.const.array.data.5 = hidden unnamed_addr constant [48 x i8] c"\00\00\00\00\00\00\00\00\00\00\00\00\00\00\D0?\00\00\00\00\00\00\D0?\00\00\00\00\00\00\E8?\00\00\00\00\00\00\E8?\00\00\00\00\00\00\F0?", align 8
; Function Attrs: argmemonly nofree norecurse nosync nounwind
define i32 @func(double* noalias nocapture writeonly %retptr, { i8*, i32, i8*, i8*, i32 }** noalias nocapture readnone %excinfo, double %arg.t, i8* nocapture readnone %arg.arr.0, i8* nocapture readnone %arg.arr.1, i64 %arg.arr.2, i64 %arg.arr.3, double* nocapture readonly %arg.arr.4, i64 %arg.arr.5.0, i64 %arg.arr.6.0) local_unnamed_addr #0 {
B0.endif.endif:
%.449 = fcmp oeq double %arg.t, 1.000000e+00
%.241 = fcmp ult double %arg.t, 0.000000e+00
%.338 = fcmp uge double %arg.t, 2.500000e-01
%not.or.cond = or i1 %.241, %.338
%_ind.5.1 = sext i1 %not.or.cond to i64
%_ind.4.1 = select i1 %.449, i64 2, i64 %_ind.5.1
%.241.1 = fcmp oge double %arg.t, 2.500000e-01
%.338.1 = fcmp olt double %arg.t, 7.500000e-01
%or.cond6 = and i1 %.241.1, %.338.1
%_ind.5.1.1 = select i1 %or.cond6, i64 1, i64 %_ind.4.1
%_ind.4.1.1 = select i1 %.449, i64 2, i64 %_ind.5.1.1
%.466 = icmp eq i64 %_ind.4.1.1, -1
br i1 %.466, label %common.ret, label %B162
common.ret: ; preds = %B162, %B0.endif.endif
%.842.sink = phi double [ 0.000000e+00, %B0.endif.endif ], [ %.842, %B162 ]
store double %.842.sink, double* %retptr, align 8
ret i32 0
B162: ; preds = %B0.endif.endif
%0 = shl nuw nsw i64 %_ind.4.1.1, 1
%.560 = or i64 %0, 1
%.561 = getelementptr double, double* bitcast ([48 x i8]* @.const.array.data.5 to double*), i64 %.560
%.562 = load double, double* %.561, align 8
%.650 = getelementptr double, double* bitcast ([48 x i8]* @.const.array.data.5 to double*), i64 %0
%.651 = load double, double* %.650, align 8
%.657 = fsub double %.562, %.651
%.740 = getelementptr double, double* bitcast ([48 x i8]* @.const.array.data.5 to double*), i64 %0
%.741 = load double, double* %.740, align 8
%.748 = fsub double %arg.t, %.741
%.752 = fmul double %.748, 1.000000e+02
%.756 = fdiv double %.752, %.657
%.757 = fptosi double %.756 to i64
%.762 = mul nuw nsw i64 %_ind.4.1.1, 100
%.765 = add nsw i64 %.762, %.757
%.771 = icmp eq i64 %.765, 300
%.826 = icmp slt i64 %.765, 0
%.827 = select i1 %.826, i64 %arg.arr.5.0, i64 0
%.828 = add i64 %.827, %.765
%.789 = add i64 %arg.arr.5.0, -1
%.828.sink = select i1 %.771, i64 %.789, i64 %.828
%.841 = getelementptr double, double* %arg.arr.4, i64 %.828.sink
%.842 = load double, double* %.841, align 8
br label %common.ret
}
attributes #0 = { argmemonly nofree norecurse nosync nounwind }
"""
# Build step: compile the textual IR in `FUNC_LLVM_IR` (fed on stdin) into a
# shared library. The IR uses typed pointers (`double*`, `i8*`), hence
# `-Xclang -no-opaque-pointers`; `-fPIC -shared` make the output dlopen-able.
sopath = "./func.so"
let clang_cmd = `$(clang()) -x ir - -Xclang -no-opaque-pointers -O3 -fPIC -fembed-bitcode -shared -o $(sopath)`
    run(pipeline(clang_cmd; stdin = IOBuffer(FUNC_LLVM_IR)))
end

# Load step: open the freshly built library and resolve the raw address of
# `func`, which `func_ccall` invokes through `ccall`.
lib = Libdl.dlopen(sopath)
const fptr = Libdl.dlsym(lib, :func)
"""
    func_ccall(t::Float64, arr::AbstractVector{Float64}) -> Float64

Call the natively compiled `func` (through the global function pointer `fptr`)
at evaluation point `t`, letting it read values from `arr`, and return the
scalar it writes to its output slot.

The native signature passes the array as a flattened struct: two unused
pointers, then `nitems` and `itemsize` as `i64`, the data pointer, and the
shape and strides (also `i64`). The integer arguments are therefore declared
`Int64` rather than `Clong`, because `Clong` is only 32 bits on Windows.

# Throws
- `ArgumentError` if `arr` is not unit-stride. The compiled kernel indexes the
  data pointer directly and ignores the strides argument, so any other layout
  would be read incorrectly without an error.
- `ErrorException` if the native call returns a non-zero status.
"""
function func_ccall(t::Float64, arr::AbstractVector{Float64})
    stride(arr, 1) == 1 ||
        throw(ArgumentError("func_ccall requires a contiguous vector (stride 1)"))
    nitems = length(arr)
    bitsize = Base.elsize(arr)
    # Keep `arr` rooted while its raw data pointer is in use by native code.
    result = GC.@preserve arr begin
        res = Ref{Float64}()
        # `excinfo` is declared `readnone` in the IR, so a null pointer suffices.
        status = ccall(fptr, Cint,
            (Ref{Cdouble}, Ptr{Ptr{Cvoid}},
             Cdouble, Ptr{Cvoid}, Ptr{Cvoid},
             Int64, Int64, Ptr{Cdouble}, Int64, Int64),
            res, C_NULL, t, C_NULL, C_NULL, nitems, bitsize,
            Base.unsafe_convert(Ptr{Cdouble}, arr), nitems, bitsize)
        status == 0 || error("returned non-zero status: $status")
        res[]
    end
    return result
end
# .const.array.data.5 corresponds to the following
# tspans = [[0.0, 0.25], [0.25, 0.75], [0.75, 1.0]]
const GRID_SIZE = 100
const SEG_NUM = 3
# The kernel multiplies the segment index by 100 and treats index 300 as the
# last element, so `a` is sized to SEG_NUM * GRID_SIZE entries.
const a = rand(SEG_NUM * GRID_SIZE)
# In reverse mode Enzyme accumulates the gradient into the shadow of a
# `Duplicated` argument, so the shadow must start at zero. `similar(a)` would
# leave it holding uninitialized memory.
ad = zero(a)
autodiff(Reverse, func_ccall, Active, Active(1.0), Const(a))
# returns ((0.0, nothing),)
autodiff(Reverse, func_ccall, Active, Const(1.0), Duplicated(a, ad))
# ERROR: LLVM error: Duplicate definition of symbol '.const.array.data.5'
# Stacktrace:
# [1] macro expansion
# @ ~/.julia/packages/LLVM/iza6e/src/executionengine/utils.jl:28 [inlined]
# [2] add!
# @ ~/.julia/packages/LLVM/iza6e/src/orc.jl:434 [inlined]
# [3] add!(mod::LLVM.Module)
# @ Enzyme.Compiler.JIT ~/projects/qruise/Enzyme.jl/src/compiler/orcv2.jl:264
# [4] _link(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, edges::Vector{…}, adjoint_name::String, primal_name::Union{…}, TapeType::Any, prepost::String)
# @ Enzyme.Compiler ~/projects/qruise/Enzyme.jl/src/compiler.jl:5919
# [5] cached_compilation
# @ ~/projects/qruise/Enzyme.jl/src/compiler.jl:6012 [inlined]
# [6] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, edges::Vector{…})
# @ Enzyme.Compiler ~/projects/qruise/Enzyme.jl/src/compiler.jl:6127
# [7] thunk_generator(world::UInt64, source::Union{…}, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type, strongzero::Type)
# @ Enzyme.Compiler ~/projects/qruise/Enzyme.jl/src/compiler.jl:6271
# [8] autodiff
# @ ~/projects/qruise/Enzyme.jl/src/Enzyme.jl:502 [inlined]
# [9] autodiff(::ReverseMode{…}, ::typeof(func_ccall), ::Type{…}, ::Const{…}, ::Duplicated{…})
# @ Enzyme ~/projects/qruise/Enzyme.jl/src/Enzyme.jl:542
# [10] top-level scope
# @ REPL[14]:1
# Some type information was truncated. Use `show(err)` to see complete types.
Package information
julia> versioninfo()
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 12 × AMD Ryzen 5 7640U w/ Radeon 760M Graphics
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver4)
Threads: 1 default, 0 interactive, 1 GC (on 12 virtual cores)
julia> pkgversion(Enzyme)
v"0.13.93"
Metadata
Assignees
Labels
No labels