Skip to content

Commit 294904d

Browse files
committed
opaque ptr err
1 parent 6a573ae commit 294904d

File tree

4 files changed

+56
-15
lines changed

4 files changed

+56
-15
lines changed

src/Enzyme.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,19 @@ function guess_activity end
135135
mutable struct EnzymeContext
136136
end
137137

138+
struct OpaquePointerError
139+
msg::String
140+
end
141+
142+
function Base.showerror(io::IO, ece::OpaquePointerError)
143+
if isdefined(Base.Experimental, :show_error_hints)
144+
Base.Experimental.show_error_hints(io, ece)
145+
end
146+
print(io, "OpaquePointerError: Enzyme execution failed to handle opaque pointers, with the following information:\n")
147+
print(io, ece.msg, '\n')
148+
end
149+
150+
138151
include("logic.jl")
139152
include("analyses/type.jl")
140153
include("typetree.jl")

src/compiler.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ import Enzyme:
3232
add_edge!
3333
using Enzyme
3434

35+
import Enzyme: OpaquePointerError
36+
3537
import EnzymeCore
3638
import EnzymeCore: EnzymeRules, ABI, FFIABI, DefaultABI
3739

@@ -3779,7 +3781,13 @@ function lower_convention(
37793781
push!(wrapper_types, typ)
37803782
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")])
37813783
else
3782-
push!(wrapper_types, eltype(typ))
3784+
3785+
elty = convert(LLVMType, arg.typ)
3786+
if !LLVM.is_opaque(typ)
3787+
@assert elty == eltype(typ)
3788+
end
3789+
3790+
push!(wrapper_types, elty)
37833791
push!(wrapper_attrs, LLVM.Attribute[])
37843792
push!(loweredArgs, arg.arg_i)
37853793
end
@@ -3941,7 +3949,13 @@ function lower_convention(
39413949
),
39423950
)
39433951
end
3944-
ptr = alloca!(builder, eltype(ty), LLVM.name(parm) * ".innerparm")
3952+
3953+
elty = convert(LLVMType, arg.typ)
3954+
if !LLVM.is_opaque(ty)
3955+
@assert elty == eltype(ty)
3956+
end
3957+
3958+
ptr = alloca!(builder, elty, LLVM.name(parm) * ".innerparm")
39453959
if TT !== nothing && TT.parameters[arg.arg_i] <: Const
39463960
metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[])
39473961
end

src/errors.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -82,18 +82,6 @@ function code_typed_helper(mi::Core.MethodInstance, world::UInt, mode::Enzyme.AP
8282
end
8383
end
8484

85-
struct OpaquePointerError <: EnzymeError
86-
msg::String
87-
end
88-
89-
function Base.showerror(io::IO, ece::OpaquePointerError)
90-
if isdefined(Base.Experimental, :show_error_hints)
91-
Base.Experimental.show_error_hints(io, ece)
92-
end
93-
print(io, "OpaquePointerError: Enzyme execution failed to handle opaque pointers, with the following information:\n")
94-
print(io, ece.msg, '\n')
95-
end
96-
9785
struct EnzymeRuntimeException <: EnzymeError
9886
msg::Cstring
9987
end

src/utils.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,33 @@ export typed_fieldoffset
500500

501501
# returns the inner type of an sret/enzyme_sret/enzyme_sret_v
502502
function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
503-
return eltype(LLVM.value_type(LLVM.parameters(fn)[idx]))
503+
504+
vt = LLVM.value_type(LLVM.parameters(fn)[idx])
505+
506+
507+
for attr in collect(parameter_attributes(fn, idx))
508+
ekind = LLVM.kind(attr)
509+
510+
if ekind == "sret"
511+
return value(attr)
512+
end
513+
514+
if ekind == "enzyme_sret" || ekind == "enzyme_sret_v"
515+
if LLVM.is_opaque(vt)
516+
msg = sprint() do io
517+
println(io, "Failed to get sret type of function\n")
518+
println(io, "idx = ", string(idx))
519+
println(io, "vt = ", string(vt))
520+
println(io, "fn = ", string(fn))
521+
end
522+
throw(OpaquePointerError(msg))
523+
end
524+
525+
return eltype(vt)
526+
end
527+
end
528+
529+
throw(AssertionError("Function requesting sret type was not an sret"))
504530
end
505531

506532
export sret_ty

0 commit comments

Comments
 (0)