Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ concurrency:
jobs:
test:
timeout-minutes: 120
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ matrix.libEnzyme }} libEnzyme - assertions=${{ matrix.assertions }} - ${{ github.event_name }}
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ matrix.libEnzyme }} libEnzyme - assertions=${{ matrix.assertions }} llvm_args=${{ matrix.llvm_args }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
Expand All @@ -47,6 +47,8 @@ jobs:
- ubuntu-24.04
- macOS-latest
- windows-latest
llvm_args:
- ''
arch:
- default
assertions:
Expand All @@ -72,6 +74,21 @@ jobs:
libEnzyme: packaged
version: '1.11'
assertions: true

- os: ubuntu-24.04
arch: default
libEnzyme: packaged
version: '1.10'
assertions: true
llvm_args: '--opaque-pointers'

- os: ubuntu-24.04
arch: default
libEnzyme: packaged
version: '1.11'
assertions: false
llvm_args: '--opaque-pointers'

- os: ubuntu-24.04
arch: default
libEnzyme: packaged
Expand Down Expand Up @@ -148,6 +165,7 @@ jobs:
test_args: ${{ env.runtest_test_args }}
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
JULIA_LLVM_ARGS: ${{ matrix.llvm_args }}
- uses: julia-actions/julia-processcoverage@v1
if: matrix.version != 'nightly' || steps.run_tests.outcome == 'success'
- uses: codecov/codecov-action@v5
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.13.101"
version = "0.13.102"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand Down Expand Up @@ -42,7 +42,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5, 0.6"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.16"
Enzyme_jll = "0.0.211"
Enzyme_jll = "0.0.213"
GPUArraysCore = "0.1.6, 0.2"
GPUCompiler = "1.6.2"
LLVM = "9.1"
Expand Down
13 changes: 13 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,19 @@ function guess_activity end
mutable struct EnzymeContext
end

struct OpaquePointerError
msg::String
end

function Base.showerror(io::IO, ece::OpaquePointerError)
if isdefined(Base.Experimental, :show_error_hints)
Base.Experimental.show_error_hints(io, ece)
end
print(io, "OpaquePointerError: Enzyme execution failed to handle opaque pointers, with the following information:\n")
print(io, ece.msg, '\n')
end


include("logic.jl")
include("analyses/type.jl")
include("typetree.jl")
Expand Down
61 changes: 44 additions & 17 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import Enzyme:
add_edge!
using Enzyme

import Enzyme: OpaquePointerError

import EnzymeCore
import EnzymeCore: EnzymeRules, ABI, FFIABI, DefaultABI

Expand Down Expand Up @@ -511,8 +513,8 @@ function prepare_llvm(interp, mod::LLVM.Module, job, meta)

RT = return_type(interp, mi)

_, _, returnRoots = get_return_info(RT)
returnRoots = returnRoots !== nothing
_, _, returnRoots0 = get_return_info(RT)
returnRoots = returnRoots0 !== nothing

attributes = function_attributes(llvmfn)
push!(
Expand All @@ -526,8 +528,19 @@ function prepare_llvm(interp, mod::LLVM.Module, job, meta)
if EnzymeRules.has_easy_rule_from_sig(Interpreter.simplify_kw(mi.specTypes); job.world)
push!(attributes, LLVM.StringAttribute("enzyme_LocalReadOnlyOrThrow"))
end

if is_sret_union(RT)
attr = StringAttribute("enzymejl_sret_union_bytes", string(union_alloca_type(RT)))
push!(parameter_attributes(llvmfn, 1), attr)
for u in LLVM.uses(llvmfn)
u = LLVM.user(u)
@assert isa(u, LLVM.CallInst)
LLVM.API.LLVMAddCallSiteAttribute(u, LLVM.API.LLVMAttributeIndex(1), attr)
end
end

if returnRoots
attr = StringAttribute("enzymejl_returnRoots", "")
attr = StringAttribute("enzymejl_returnRoots", string(length(eltype(returnRoots0).parameters[1])))
push!(parameter_attributes(llvmfn, 2), attr)
for u in LLVM.uses(llvmfn)
u = LLVM.user(u)
Expand Down Expand Up @@ -3779,7 +3792,13 @@ function lower_convention(
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")])
else
push!(wrapper_types, eltype(typ))

elty = convert(LLVMType, arg.typ)
if !LLVM.is_opaque(typ)
@assert elty == eltype(typ)
end

push!(wrapper_types, elty)
push!(wrapper_attrs, LLVM.Attribute[])
push!(loweredArgs, arg.arg_i)
end
Expand Down Expand Up @@ -3865,7 +3884,9 @@ function lower_convention(
res = call!(builder, LLVM.function_type(wrapper_f), wrapper_f, nops)
callconv!(res, callconv(wrapper_f))
if sret
@assert value_type(res) == eltype(value_type(ops[1]))
if !LLVM.is_opaque(value_type(ops[1]))
@assert value_type(res) == eltype(value_type(ops[1]))
end
store!(builder, res, ops[1])
else
LLVM.replace_uses!(ci, res)
Expand Down Expand Up @@ -3899,7 +3920,7 @@ function lower_convention(
if !in(0, parmsRemoved)
sretPtr = alloca!(
builder,
eltype(value_type(parameters(entry_f)[1])),
sret_ty(entry_f, 1),
"innersret",
)
ctx = LLVM.context(entry_f)
Expand All @@ -3916,7 +3937,7 @@ function lower_convention(
if returnRoots && !in(1, parmsRemoved)
retRootPtr = alloca!(
builder,
eltype(value_type(parameters(entry_f)[1+sret])),
sret_ty(entry_f, 1+sret),
"innerreturnroots",
)
# retRootPtr = alloca!(builder, parameters(wrapper_f)[1])
Expand All @@ -3941,7 +3962,13 @@ function lower_convention(
),
)
end
ptr = alloca!(builder, eltype(ty), LLVM.name(parm) * ".innerparm")

elty = convert(LLVMType, arg.typ)
if !LLVM.is_opaque(ty)
@assert elty == eltype(ty)
end

ptr = alloca!(builder, elty, LLVM.name(parm) * ".innerparm")
if TT !== nothing && TT.parameters[arg.arg_i] <: Const
metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[])
end
Expand All @@ -3954,7 +3981,7 @@ function lower_convention(
if LLVM.addrspace(ty) != 0
ptr = addrspacecast!(builder, ptr, ty)
end
@assert eltype(ty) == value_type(wrapparm)
@assert elty == value_type(wrapparm)
store!(builder, wrapparm, ptr)
push!(wrapper_args, ptr)
push!(
Expand Down Expand Up @@ -4566,10 +4593,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
found = String[]
if bitcode_replacement() &&
API.EnzymeBitcodeReplacement(mod, disableFallback, found) != 0
ModulePassManager() do pm
instruction_combining!(pm)
LLVM.run!(pm, mod)
end
run!(InstCombinePass(), mod)
toremove = String[]
for f in functions(mod)
if !has_fn_attr(f, EnumAttribute("alwaysinline"))
Expand Down Expand Up @@ -4696,10 +4720,10 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
end
end

_, _, returnRoots = get_return_info(rt)
returnRoots = returnRoots !== nothing
_, _, returnRoots0 = get_return_info(rt)
returnRoots = returnRoots0 !== nothing
if returnRoots
attr = StringAttribute("enzymejl_returnRoots", "")
attr = StringAttribute("enzymejl_returnRoots", string(length(eltype(returnRoots0).parameters[1])))
push!(parameter_attributes(wrapper_f, 2), attr)
LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(2), attr)
end
Expand Down Expand Up @@ -5869,7 +5893,10 @@ end
EnumAttribute("sret")
end
LLVM.API.LLVMAddCallSiteAttribute(r, LLVM.API.LLVMAttributeIndex(1), attr)
r = load!(builder, eltype(value_type(callparams[1])), callparams[1])
if !LLVM.is_opaque(value_type(callparams[1]))
@assert eltype(value_type(callparams[1])) == jltype
end
r = load!(builder, jltype, callparams[1])
end

if T_ret != T_void
Expand Down
6 changes: 5 additions & 1 deletion src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,10 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
method_table = Core.Compiler.method_table(interp)
bt = backtrace(inst)
dest = called_operand(inst)

if isa(dest, LLVM.PHIInst) && all(Base.Fix1(==, operands(dest)[1]), operands(dest))
dest = operands(dest)[1]
end
if isa(dest, LLVM.ConstantExpr) && opcode(dest) == LLVM.API.LLVMIntToPtr && isa(operands(dest)[1], LLVM.ConstantExpr) && opcode(operands(dest)[1]) == LLVM.API.LLVMPtrToInt
dest = operands(operands(dest)[1])[1]
end
Expand Down Expand Up @@ -1144,7 +1148,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
else
false, nothing
end

lfn = nothing
if found
lfn = replaceWith
Expand Down
Loading
Loading