Skip to content

Commit af53c37

Browse files
committed
unbind info
1 parent b0241b0 commit af53c37

File tree

7 files changed

+49
-12
lines changed

7 files changed

+49
-12
lines changed

src/absint.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@
55

66
const JL_MAX_TAGS = 64 # see `enum jl_small_typeof_tags` in julia.h
77

8+
function unbind(@nospecialize(val))
9+
if val isa Core.Binding
10+
return val.value
11+
else
12+
return val
13+
end
14+
end
15+
816
function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false, istracked::Bool=false, typetag::Bool=false)::Tuple{Bool, Any}
917
if (value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked)) || (value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Derived)) || istracked
1018
ce, _ = get_base_and_offset(arg; offsetAllowed = false, inttoptr = true)
@@ -455,12 +463,14 @@ function abs_typeof(
455463
nm == "jl_gc_alloc_typed" ||
456464
nm == "ijl_gc_alloc_typed"
457465
vals = absint(operands(arg)[3], partial, false, #=typetag=#true)
466+
@assert !(vals[2] isa Core.Binding)
458467
return (vals[1], vals[2], vals[1] ? GPUCompiler.BITS_REF : nothing)
459468
end
460469
# Type tag is arg 3
461470
if nm == "jl_alloc_genericmemory_unchecked" ||
462471
nm == "ijl_alloc_genericmemory_unchecked"
463472
vals = absint(operands(arg)[3], partial, true, #=typetag=#true)
473+
@assert !(vals[2] isa Core.Binding)
464474
return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
465475
end
466476
# Type tag is arg 1
@@ -475,11 +485,13 @@ function abs_typeof(
475485
nm == "jl_alloc_genericmemory" ||
476486
nm == "ijl_alloc_genericmemory"
477487
vals = absint(operands(arg)[1], partial, false, #=typetag=#true)
488+
@assert !(vals[2] isa Core.Binding)
478489
return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
479490
end
480491

481492
if nm == "jl_new_structt" || nm == "ijl_new_structt"
482493
vals = absint(operands(arg)[1], partial, false, #=typetag=#true)
494+
@assert !(vals[2] isa Core.Binding)
483495
return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
484496
end
485497

@@ -498,6 +510,7 @@ function abs_typeof(
498510
if nm == "jl_new_structv" || nm == "ijl_new_structv"
499511
@assert index == 2
500512
vals = absint(operands(arg)[index], partial, false, #=typetag=#true)
513+
@assert !(vals[2] isa Core.Binding)
501514
return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
502515
end
503516

@@ -531,9 +544,11 @@ function abs_typeof(
531544
if nm == "jl_f__apply_iterate" || nm == "ijl_f__apply_iterate"
532545
index += 1
533546
legal, iterfn = absint(operands(arg)[index])
547+
iterfn = unbind(iterfn)
534548
index += 1
535549
if legal && iterfn == Base.iterate
536550
legal0, combfn = absint(operands(arg)[index])
551+
combfn = unbind(combfn)
537552
index += 1
538553
if legal0 && combfn == Core.apply_type && partial
539554
return (true, Type, GPUCompiler.BITS_REF)
@@ -871,6 +886,7 @@ function abs_typeof(
871886

872887
legal, val = absint(arg, partial)
873888
if legal
889+
val = unbind(val)
874890
return (true, Core.Typeof(val), GPUCompiler.BITS_REF)
875891
end
876892
return (false, nothing, nothing)

src/compiler.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,8 +1694,10 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie
16941694
index += 1
16951695
found = Any[]
16961696
legal, Ty = absint(operands(arg)[index], partial)
1697+
Ty = unbind(Ty)
16971698
if legal && Ty == NTuple
16981699
legal, Ty = absint(operands(arg)[index+2])
1700+
Ty = unbind(Ty)
16991701
if legal
17001702
# count should represent {the total size in bytes, the aligned size of each element}
17011703
B = LLVM.IRBuilder()
@@ -5400,6 +5402,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
54005402
@static if VERSION < v"1.11-"
54015403
else
54025404
legal2, obj = absint(inst)
5405+
obj = unbind(obj)
54035406
if legal2 && is_memory_instance(obj)
54045407
metadata(inst)["nonnull"] = MDNode(LLVM.Metadata[])
54055408
end
@@ -5629,6 +5632,7 @@ end
56295632
string(cur)
56305633
slegal, foundv = absint(cur)
56315634
if slegal
5635+
foundv = unbind(foundv)
56325636
resstr *= "of type " * string(foundv)
56335637
end
56345638
emit_error(builder, user, resstr, EnzymeMutabilityException)
@@ -6464,6 +6468,7 @@ const DumpLLVMCall = Ref(false)
64646468
end
64656469
reinsert_gcmarker!(llvm_f)
64666470

6471+
Enzyme.Compiler.JIT.prepare!(mod)
64676472
if DumpLLVMCall[]
64686473
API.EnzymeDumpModuleRef(mod.ref)
64696474
end

src/compiler/validation.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
278278
b = IRBuilder()
279279
position!(b, inst)
280280
ccall(:jl_, Cvoid, (Any,), (gname, obj, obj0, ptr, string(addr)))
281-
newf = unsafe_to_llvm(b, obj; insert_name_if_not_exists=gname)
281+
newf = unsafe_to_llvm(b, obj0; insert_name_if_not_exists=gname)
282282
replace_uses!(inst, newf)
283283
LLVM.API.LLVMInstructionEraseFromParent(inst)
284284
continue
@@ -835,7 +835,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
835835
if isa(flib, LLVM.ConstantExpr)
836836
legal, flib2 = absint(flib)
837837
if legal
838-
flib = flib2
838+
flib = unbind(flib2)
839839
end
840840
end
841841
if isa(flib, GlobalRef) && isdefined(flib.mod, flib.name)
@@ -986,6 +986,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
986986
iteroff = 2
987987

988988
legal, iterlib = absint(operands(inst)[iteroff+1])
989+
iterlib = unbind(iterlib)
989990
if legal && iterlib == Base.iterate
990991
legal, GT, byref = abs_typeof(operands(inst)[4+1], true)
991992
funcoff = 3
@@ -1075,6 +1076,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
10751076
push!(tys, typ)
10761077
end
10771078
legal, flib = absint(operands(inst)[offset+1])
1079+
flib = unbind(flib)
10781080
if legal && isa(flib, Core.MethodInstance)
10791081
if !Base.isvarargtype(flib.specTypes.parameters[end])
10801082
@assert length(tys) == length(flib.specTypes.parameters)
@@ -1229,6 +1231,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
12291231
push!(tys, typ)
12301232
end
12311233
legal, flib = absint(operands(inst)[offset+1])
1234+
flib = unbind(flib)
12321235
if legal && isa(flib, Core.MethodInstance)
12331236
if !Base.isvarargtype(flib.specTypes.parameters[end])
12341237
if length(tys) != length(flib.specTypes.parameters)

src/errors.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,8 +1070,13 @@ function julia_error(
10701070
end
10711071
legal, obj = absint(val)
10721072
if legal
1073-
println(io, "\nValue of type: ", Core.Typeof(val))
1074-
println(io , " of value : ", val)
1073+
obj0 = obj
1074+
obj = unbind(obj)
1075+
println(io, "\nValue of type: ", Core.Typeof(obj))
1076+
println(io , " of value : ", obj)
1077+
if obj0 isa Core.Binding
1078+
println(io , " binding : ", obj0)
1079+
end
10751080
println(io)
10761081
end
10771082
if !isa(val, LLVM.Argument) && !isa(val, LLVM.GlobalVariable)
@@ -1297,9 +1302,10 @@ function julia_error(
12971302
end
12981303

12991304
legal2, obj = absint(cur)
1300-
1305+
obj0 = obj
13011306
# Only do so for the immediate operand/etc to a phi, since otherwise we will make multiple
13021307
if legal2
1308+
obj = unbind(obj)
13031309
if is_memory_instance(obj)
13041310
return make_batched(ncur, prevbb)
13051311
end
@@ -1337,9 +1343,6 @@ function julia_error(
13371343
end
13381344
end
13391345

1340-
if is_memory_instance(obj)
1341-
return make_batched(ncur, prevbb)
1342-
end
13431346
end
13441347

13451348
@static if VERSION < v"1.11-"
@@ -1348,6 +1351,7 @@ else
13481351
larg, off = get_base_and_offset(operands(cur)[1])
13491352
if isa(larg, LLVM.LoadInst)
13501353
legal2, obj = absint(larg)
1354+
obj = unbind(obj)
13511355
if legal2 && is_memory_instance(obj)
13521356
return make_batched(ncur, prevbb)
13531357
end
@@ -1356,7 +1360,10 @@ else
13561360
end
13571361

13581362
badval = if legal2
1359-
string(obj) * " of type" * " " * string(TT)
1363+
sv = string(obj) * " of type" * " " * string(TT)
1364+
if obj0 isa Core.Binding
1365+
sv = sv *" binded at "*string(obj0)
1366+
end
13601367
else
13611368
"Unknown object of type" * " " * string(TT)
13621369
end

src/jlrt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ function emit_apply_type!(B::LLVM.IRBuilder, @nospecialize(Ty::Type), args::Vect
454454
for arg in args
455455
slegal, foundv = absint(arg)
456456
if slegal
457-
push!(found, foundv)
457+
push!(found, unbind(foundv))
458458
else
459459
legal = false
460460
break
@@ -509,7 +509,7 @@ function emit_tuple!(B::LLVM.IRBuilder, args::Vector{LLVM.Value})::LLVM.Value
509509
for arg in args
510510
slegal, foundv = absint(arg)
511511
if slegal
512-
push!(found, foundv)
512+
push!(found, unbind(foundv))
513513
else
514514
legal = false
515515
break
@@ -871,6 +871,7 @@ function emit_layout_of_type!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value))
871871
ls = get_layout_struct()
872872
lptr = LLVM.PointerType(ls, 10)
873873
if legal
874+
JTy = unbind(JTy)
874875
return LLVM.const_inttoptr(LLVM.ConstantInt(Base.reinterpret(UInt, JTy.layout)), lptr)
875876
end
876877
@assert !isa(ty, LLVM.ConstantExpr)
@@ -889,6 +890,7 @@ end
889890
function emit_type_layout_elsz!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value))
890891
legal, JTy = absint(ty)
891892
if legal
893+
JTy = unbind(JTy)
892894
@assert JTy isa Type
893895
res = Compiler.datatype_layoutsize(JTy)
894896
return LLVM.ConstantInt(res)

src/llvm/transforms.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module)
442442
@static if VERSION < v"1.11-"
443443
else
444444
legal2, obj = absint(src)
445-
if legal2 && is_memory_instance(obj)
445+
if legal2 && is_memory_instance(unbind(obj))
446446
metadata(src)["nonnull"] = MDNode(LLVM.Metadata[])
447447
end
448448
end

src/rules/jitrules.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2361,7 +2361,9 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR)
23612361
end
23622362

23632363
v, isiter = absint(operands(orig)[offset+1])
2364+
isiter = unbind(isiter)
23642365
v2, istup = absint(operands(orig)[offset+2])
2366+
istup = unbind(istup)
23652367

23662368
width = get_width(gutils)
23672369

@@ -2513,6 +2515,8 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR,
25132515

25142516
v, isiter = absint(operands(orig)[offset+1])
25152517
v2, istup = absint(operands(orig)[offset+2])
2518+
isiter = unbind(isiter)
2519+
istup = unbind(istup)
25162520

25172521
width = get_width(gutils)
25182522

0 commit comments

Comments
 (0)