Skip to content

Commit 0c3bcb8

Browse files
committed
continue
1 parent cdc4f86 commit 0c3bcb8

File tree

2 files changed

+44
-41
lines changed

2 files changed

+44
-41
lines changed

src/llvm/transforms.jl

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ function nodecayed_phis!(mod::LLVM.Module)
634634
push!(todo, inst)
635635
nb = IRBuilder()
636636
position!(nb, inst)
637-
el_ty = if addr == 11
637+
el_ty = if addr == 11 && !LLVM.is_opaque(ty)
638638
eltype(ty)
639639
else
640640
LLVM.StructType(LLVM.LLVMType[])
@@ -656,7 +656,7 @@ function nodecayed_phis!(mod::LLVM.Module)
656656

657657
for inst in todo
658658
ty = value_type(inst)
659-
el_ty = if addr == 11
659+
el_ty = if addr == 11 && !LLVM.is_opaque(ty)
660660
eltype(ty)
661661
else
662662
LLVM.StructType(LLVM.LLVMType[])
@@ -893,15 +893,17 @@ function nodecayed_phis!(mod::LLVM.Module)
893893
offset,
894894
API.EnzymeComputeByteOffsetOfGEP(b, v, offty),
895895
)
896-
v2 = bitcast!(
897-
b,
898-
v2,
899-
LLVM.PointerType(
900-
eltype(value_type(v)),
901-
addrspace(value_type(v2)),
902-
),
903-
)
904-
@assert eltype(value_type(v2)) == eltype(value_type(v))
896+
if !LLVM.is_opaque(value_type(v2))
897+
v2 = bitcast!(
898+
b,
899+
v2,
900+
LLVM.PointerType(
901+
eltype(value_type(v)),
902+
addrspace(value_type(v2)),
903+
),
904+
)
905+
@assert eltype(value_type(v2)) == eltype(value_type(v))
906+
end
905907
return v2, offset, skipload
906908
end
907909

@@ -1024,7 +1026,7 @@ function nodecayed_phis!(mod::LLVM.Module)
10241026
@assert hadload
10251027
end
10261028

1027-
if eltype(value_type(v)) != el_ty
1029+
if !LLVM.is_opaque(value_type(v)) && eltype(value_type(v)) != el_ty
10281030
v = bitcast!(
10291031
b,
10301032
v,
@@ -1772,14 +1774,23 @@ function propagate_returned!(mod::LLVM.Module)
17721774
eraseInst(LLVM.parent(c), c)
17731775
end
17741776
B = IRBuilder()
1777+
17751778
position!(B, first(instructions(first(blocks(fn)))))
17761779

1777-
argeltype = sret_ty(fn, i)
1778-
al = alloca!(B, argeltype)
1779-
if value_type(al) != value_type(arg)
1780-
al = addrspacecast!(B, al, value_type(arg))
1780+
has_use = false
1781+
for _ in LLVM.uses(arg)
1782+
has_use = true
1783+
break
1784+
end
1785+
1786+
if has_use
1787+
argeltype = sret_ty(fn, i)
1788+
al = alloca!(B, argeltype)
1789+
if value_type(al) != value_type(arg)
1790+
al = addrspacecast!(B, al, value_type(arg))
1791+
end
1792+
LLVM.replace_uses!(arg, al)
17811793
end
1782-
LLVM.replace_uses!(arg, al)
17831794
end
17841795
end
17851796

src/rules/customrules.jl

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,12 @@ function enzyme_custom_setup_args(
301301
LLVM.ConstantInt(LLVM.IntType(32), 0),
302302
],
303303
)
304-
if value_type(val) != eltype(value_type(ptr))
304+
305+
if !is_opaque(value_type(ptr))
306+
@assert eltype(value_type(ptr)) == arty
307+
end
308+
309+
if value_type(val) != arty
305310
val = load!(B, arty, val)
306311
end
307312
store!(B, val, ptr)
@@ -353,33 +358,20 @@ function enzyme_custom_setup_args(
353358
)
354359
end
355360

356-
msg = sprint() do io
357-
print(io, "custom rule lower failure fwd\n")
358-
print(io, "Ty = $Ty\n")
359-
print(io, "llty = $llty\n")
360-
print(io, "arty = $arty\n")
361-
print(io, "al0 = $al0\n")
362-
print(io, "ptr = $ptr\n")
363-
print(io, "val = $val\n")
364-
print(io, "arg = $arg\n")
365-
end
366-
throw(OpaquePointerError(msg))
361+
if !LLVM.is_opaque(value_type(val))
362+
if arty != eltype(value_type(val))
363+
msg = sprint() do io
364+
println(io, "Enzyme: active by ref type $Ty is wrong type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr)) arty=$arty")
365+
end
367366

368-
if arty == eltype(value_type(val))
369-
val = load!(B, arty, val)
370-
else
371-
bt = GPUCompiler.backtrace(orig)
372-
msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
373-
val = LLVM.UndefValue(arty)
374-
emit_error(
375-
B,
376-
orig,
377-
"Enzyme: active by ref type $Ty is wrong type in application of custom rule for $mi val=$(string(val)) ptr=$(string(ptr))\n"*msg2,
378-
)
367+
EnzymeInternalError(msg, ir, bt)
368+
end
379369
end
370+
371+
val = load!(B, arty, val)
380372
end
381373

382-
if eltype(value_type(ptr)) == value_type(val)
374+
if arty == value_type(val)
383375
store!(B, val, ptr)
384376
if any_jltypes(llty)
385377
emit_writebarrier!(B, get_julia_inner_types(B, al0, val))

0 commit comments

Comments
 (0)