Skip to content

Commit 805538e

Browse files
committed
parmtype
1 parent 9f7cb7d commit 805538e

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

src/utils.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,8 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
510510
end)
511511

512512

513+
enzymejl_parmtype_ref = nothing
514+
enzymejl_parmtype = nothing
513515

514516
for attr in collect(LLVM.parameter_attributes(fn, idx))
515517
ekind = LLVM.kind(attr)
@@ -559,9 +561,28 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
559561

560562
return eltype(vt)
561563
end
564+
565+
566+
if ekind == "enzymejl_parmtype_ref"
567+
enzymejl_parmtype_ref = GPUCompiler.ArgumentCC(parse(UInt, LLVM.value(fattr)))
568+
continue
569+
end
570+
571+
if ekind == "enzymejl_parmtype"
572+
ptr = reinterpret(Ptr{Cvoid}, parse(UInt, LLVM.value(fattr)))
573+
enzymejl_parmtype = Base.unsafe_pointer_to_objref(ptr)::Type
574+
end
575+
end
576+
577+
if enzymejl_parmtype_ref == GPUCompiler.BITS_REF && enzymejl_parmtype !== nothing
578+
res = convert(LLVMType, enzymejl_parmtype)
579+
if !LLVM.is_opaque(vt)
580+
@assert eltype(vt) == res
581+
end
582+
return res
562583
end
563584

564-
throw(AssertionError("Function requesting sret type was not an sret\nidx=$idx\nfn=$(string(fn))"))
585+
throw(AssertionError("Function requesting sret type was not an sret\nidx=$idx\nfn=$(string(fn)) enzymejl_parmtype=$enzymejl_parmtype enzymejl_parmtype_ref=$enzymejl_parmtype_ref"))
565586
end
566587

567588
export sret_ty

0 commit comments

Comments
 (0)