Skip to content

Commit 0edeb21

Browse files
authored
Merge branch 'main' into vc/string2
2 parents 11792ec + 4656760 commit 0edeb21

File tree

6 files changed

+143
-59
lines changed

6 files changed

+143
-59
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Enzyme"
22
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
4-
version = "0.13.30"
4+
version = "0.13.35"
55

66
[deps]
77
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
@@ -39,9 +39,9 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5"
3939
CEnum = "0.4, 0.5"
4040
ChainRulesCore = "1"
4141
EnzymeCore = "0.8.8"
42-
Enzyme_jll = "0.0.172"
42+
Enzyme_jll = "0.0.173"
4343
GPUArraysCore = "0.1.6, 0.2"
44-
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1"
44+
GPUCompiler = "1.3"
4545
LLVM = "6.1, 7, 8, 9"
4646
LogExpFunctions = "0.3"
4747
ObjectFile = "0.4"

src/api.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ function EnzymeCreatePrimalAndGradient(
265265
atomicAdd,
266266
)
267267
freeMemory = true
268+
subsequent_calls_may_write = mode != DEM_ReverseModeCombined
268269
ccall(
269270
(:EnzymeCreatePrimalAndGradient, libEnzyme),
270271
LLVMValueRef,
@@ -286,6 +287,7 @@ function EnzymeCreatePrimalAndGradient(
286287
LLVMTypeRef,
287288
UInt8,
288289
CFnTypeInfo,
290+
UInt8,
289291
Ptr{UInt8},
290292
Csize_t,
291293
EnzymeAugmentedReturnPtr,
@@ -308,6 +310,7 @@ function EnzymeCreatePrimalAndGradient(
308310
additionalArg,
309311
forceAnonymousTape,
310312
typeInfo,
313+
subsequent_calls_may_write,
311314
uncacheable_args,
312315
length(uncacheable_args),
313316
augmented,
@@ -331,6 +334,7 @@ function EnzymeCreateForwardDiff(
331334
)
332335
freeMemory = true
333336
aug = C_NULL
337+
subsequent_calls_may_write = false
334338
ccall(
335339
(:EnzymeCreateForwardDiff, libEnzyme),
336340
LLVMValueRef,
@@ -350,6 +354,7 @@ function EnzymeCreateForwardDiff(
350354
Cuint,
351355
LLVMTypeRef,
352356
CFnTypeInfo,
357+
UInt8,
353358
Ptr{UInt8},
354359
Csize_t,
355360
EnzymeAugmentedReturnPtr,
@@ -369,6 +374,7 @@ function EnzymeCreateForwardDiff(
369374
width,
370375
additionalArg,
371376
typeInfo,
377+
subsequent_calls_may_write,
372378
uncacheable_args,
373379
length(uncacheable_args),
374380
aug,
@@ -401,6 +407,7 @@ function EnzymeCreateAugmentedPrimal(
401407
width,
402408
atomicAdd,
403409
)
410+
subsequent_calls_may_write = true
404411
ccall(
405412
(:EnzymeCreateAugmentedPrimal, libEnzyme),
406413
EnzymeAugmentedReturnPtr,
@@ -416,6 +423,7 @@ function EnzymeCreateAugmentedPrimal(
416423
UInt8,
417424
UInt8,
418425
CFnTypeInfo,
426+
UInt8,
419427
Ptr{UInt8},
420428
Csize_t,
421429
UInt8,
@@ -434,6 +442,7 @@ function EnzymeCreateAugmentedPrimal(
434442
returnUsed,
435443
shadowReturnUsed,
436444
typeInfo,
445+
subsequent_calls_may_write,
437446
uncacheable_args,
438447
length(uncacheable_args),
439448
forceAnonymousTape,

src/compiler.jl

Lines changed: 111 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -379,10 +379,10 @@ function nested_codegen!(
379379

380380
target = DefaultCompilerTarget()
381381
params = PrimalCompilerParams(mode)
382-
job = CompilerJob(funcspec, CompilerConfig(target, params; kernel = false), world)
382+
job = CompilerJob(funcspec, CompilerConfig(target, params; kernel = false, libraries = true, toplevel = true, optimize = false, cleanup = false, only_entry = false, validate = false), world)
383383

384384
GPUCompiler.prepare_job!(job)
385-
otherMod, meta = GPUCompiler.emit_llvm(job; libraries=true, toplevel=true, optimize=false, cleanup=false, only_entry=false, validate=false)
385+
otherMod, meta = GPUCompiler.emit_llvm(job)
386386

387387
prepare_llvm(otherMod, job, meta)
388388

@@ -618,7 +618,9 @@ function create_recursive_stores(B::LLVM.IRBuilder, @nospecialize(Ty::DataType),
618618
prev = addrspacecast!(B, prev, LLVM.PointerType(LLVMType, Derived))
619619
zero_single_allocation(B, Ty, LLVMType, prev, zeroAll, LLVM.ConstantInt(T_int64, 0); atomic=true)
620620
else
621-
@assert fieldcount(Ty) != 0
621+
if fieldcount(Ty) == 0
622+
error("Error handling recursive stores for $Ty which has a fieldcount of 0")
623+
end
622624

623625
T_jlvalue = LLVM.StructType(LLVM.LLVMType[])
624626
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
@@ -839,16 +841,16 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D
839841
continue
840842
end
841843
if isa(ty, LLVM.ArrayType)
842-
subTy = if jlty isa DataType
843-
eltype(jlty)
844-
elseif !(jlty isa DataType)
845-
if eltype(ty) isa LLVM.PointerType && LLVM.addrspace(eltype(ty)) == 10
846-
Any
847-
else
848-
throw(AssertionError("jlty=$jlty ty=$ty"))
849-
end
850-
end
851844
for i = 1:length(ty)
845+
subTy = if jlty isa DataType
846+
typed_fieldtype(jlty, i)
847+
elseif !(jlty isa DataType)
848+
if eltype(ty) isa LLVM.PointerType && LLVM.addrspace(eltype(ty)) == 10
849+
Any
850+
else
851+
throw(AssertionError("jlty=$jlty ty=$ty"))
852+
end
853+
end
852854
npath = copy(path)
853855
push!(npath, LLVM.ConstantInt(LLVM.IntType(32), i - 1))
854856
push!(todo, (npath, eltype(ty), subTy))
@@ -866,7 +868,9 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D
866868
end
867869
if isa(ty, LLVM.StructType)
868870
i = 1
869-
@assert jlty isa DataType
871+
if !(jlty isa DataType)
872+
throw(AssertionError("Could not handle non datatype $jlty in zero_single_allocation $ty"))
873+
end
870874
for ii = 1:fieldcount(jlty)
871875
jlet = typed_fieldtype(jlty, ii)
872876
if isghostty(jlet) || Core.Compiler.isconstType(jlet)
@@ -2683,6 +2687,7 @@ function lower_convention(
26832687

26842688
RT = LLVM.return_type(entry_ft)
26852689

2690+
26862691
# generate the wrapper function type & definition
26872692
wrapper_types = LLVM.LLVMType[]
26882693
wrapper_attrs = Vector{LLVM.Attribute}[]
@@ -2699,6 +2704,13 @@ function lower_convention(
26992704
sret = sret !== nothing
27002705
returnRoots = returnRoots !== nothing
27012706

2707+
loweredReturn = RetActivity <: Active && (actualRetType === Any)
2708+
if loweredReturn
2709+
@assert !sret
2710+
@assert !returnRoots
2711+
RT = convert(LLVMType, eltype(RetActivity))
2712+
end
2713+
27022714
# TODO removed implications
27032715
retRemoved, parmsRemoved = removed_ret_parms(entry_f)
27042716
swiftself = has_swiftself(entry_f)
@@ -2769,8 +2781,8 @@ function lower_convention(
27692781
end
27702782
end
27712783

2772-
if length(loweredArgs) == 0 && length(raisedArgs) == 0 && !sret && !sret_union
2773-
return entry_f, returnRoots, boxedArgs, loweredArgs
2784+
if length(loweredArgs) == 0 && length(raisedArgs) == 0 && !sret && !sret_union && !loweredReturn
2785+
return entry_f, returnRoots, boxedArgs, loweredArgs, actualRetType
27742786
end
27752787

27762788
wrapper_fn = LLVM.name(entry_f)
@@ -3136,28 +3148,69 @@ function lower_convention(
31363148
ret!(builder)
31373149
else
31383150
ctx = LLVM.context(wrapper_f)
3139-
push!(
3140-
return_attributes(wrapper_f),
3141-
StringAttribute(
3142-
"enzyme_type",
3143-
string(typetree(actualRetType, ctx, dl, seen)),
3144-
),
3145-
)
3146-
push!(
3147-
return_attributes(wrapper_f),
3148-
StringAttribute(
3149-
"enzymejl_parmtype",
3150-
string(convert(UInt, unsafe_to_pointer(actualRetType))),
3151-
),
3152-
)
3153-
push!(
3154-
return_attributes(wrapper_f),
3155-
StringAttribute(
3156-
"enzymejl_parmtype_ref",
3157-
string(UInt(GPUCompiler.BITS_REF)),
3158-
),
3159-
)
3160-
ret!(builder, res)
3151+
3152+
if loweredReturn
3153+
push!(
3154+
return_attributes(wrapper_f),
3155+
StringAttribute(
3156+
"enzyme_type",
3157+
string(typetree(eltype(RetActivity), ctx, dl, seen)),
3158+
),
3159+
)
3160+
push!(
3161+
return_attributes(wrapper_f),
3162+
StringAttribute(
3163+
"enzymejl_parmtype",
3164+
string(convert(UInt, unsafe_to_pointer(eltype(RetActivity)))),
3165+
),
3166+
)
3167+
push!(
3168+
return_attributes(wrapper_f),
3169+
StringAttribute(
3170+
"enzymejl_parmtype_ref",
3171+
string(UInt(GPUCompiler.BITS_VALUE)),
3172+
),
3173+
)
3174+
ty = emit_jltypeof!(builder, res)
3175+
cmp = icmp!(builder, LLVM.API.LLVMIntEQ, ty, unsafe_to_llvm(builder, eltype(RetActivity)))
3176+
cmpret = BasicBlock(wrapper_f, "ret")
3177+
failure = BasicBlock(wrapper_f, "fail")
3178+
br!(builder, cmp, cmpret, failure)
3179+
3180+
position!(builder, cmpret)
3181+
res = bitcast!(builder, res, LLVM.PointerType(RT, addrspace(value_type(res))))
3182+
res = addrspacecast!(builder, res, LLVM.PointerType(RT, Derived))
3183+
res = load!(builder, RT, res)
3184+
ret!(builder, res)
3185+
3186+
position!(builder, failure)
3187+
3188+
emit_error(builder, nothing, "Expected return type of primal to be "*string(eltype(RetActivity))*" but did not find a value of that type")
3189+
unreachable!(builder)
3190+
else
3191+
push!(
3192+
return_attributes(wrapper_f),
3193+
StringAttribute(
3194+
"enzyme_type",
3195+
string(typetree(actualRetType, ctx, dl, seen)),
3196+
),
3197+
)
3198+
push!(
3199+
return_attributes(wrapper_f),
3200+
StringAttribute(
3201+
"enzymejl_parmtype",
3202+
string(convert(UInt, unsafe_to_pointer(actualRetType))),
3203+
),
3204+
)
3205+
push!(
3206+
return_attributes(wrapper_f),
3207+
StringAttribute(
3208+
"enzymejl_parmtype_ref",
3209+
string(UInt(GPUCompiler.BITS_REF)),
3210+
),
3211+
)
3212+
ret!(builder, res)
3213+
end
31613214
end
31623215
dispose(builder)
31633216
end
@@ -3366,7 +3419,7 @@ function lower_convention(
33663419
end
33673420
throw(LLVM.LLVMException(msg))
33683421
end
3369-
return wrapper_f, returnRoots, boxedArgs, loweredArgs
3422+
return wrapper_f, returnRoots, boxedArgs, loweredArgs, loweredReturn ? eltype(RetActivity) : actualRetType
33703423
end
33713424

33723425
using Random
@@ -3424,9 +3477,20 @@ function GPUCompiler.codegen(
34243477
if parent_job === nothing
34253478
primal_target = DefaultCompilerTarget()
34263479
primal_params = PrimalCompilerParams(mode)
3480+
config2 = CompilerConfig(
3481+
primal_target,
3482+
primal_params;
3483+
kernel = false,
3484+
libraries = true,
3485+
toplevel = toplevel,
3486+
optimize = false,
3487+
cleanup = false,
3488+
only_entry = false,
3489+
validate = false
3490+
)
34273491
primal_job = CompilerJob(
34283492
primal,
3429-
CompilerConfig(primal_target, primal_params; kernel = false),
3493+
config2,
34303494
job.world,
34313495
)
34323496
else
@@ -3437,12 +3501,18 @@ function GPUCompiler.codegen(
34373501
parent_job.config.entry_abi,
34383502
parent_job.config.name,
34393503
parent_job.config.always_inline,
3504+
libraries = true,
3505+
toplevel = toplevel,
3506+
optimize = false,
3507+
cleanup = false,
3508+
only_entry = false,
3509+
validate = false,
34403510
)
34413511
primal_job = CompilerJob(primal, config2, job.world) # TODO EnzymeInterp params, etc
34423512
end
34433513

34443514
GPUCompiler.prepare_job!(primal_job)
3445-
mod, meta = GPUCompiler.emit_llvm(primal_job; libraries=true, toplevel=toplevel, optimize=false, cleanup=false, only_entry=false, validate=false)
3515+
mod, meta = GPUCompiler.emit_llvm(primal_job)
34463516
edges = Any[]
34473517
mod_to_edges[mod] = edges
34483518

@@ -4204,7 +4274,7 @@ end
42044274
primalf, returnRoots = primalf, false
42054275

42064276
if lowerConvention
4207-
primalf, returnRoots, boxedArgs, loweredArgs = lower_convention(
4277+
primalf, returnRoots, boxedArgs, loweredArgs, actualRetType = lower_convention(
42084278
source_sig,
42094279
mod,
42104280
primalf,

src/jlrt.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,18 @@ function val_from_byref_if_mixed(B::LLVM.IRBuilder, gutils::GradientUtils, @nosp
366366
end
367367
end
368368

369-
function ref_if_mixed(val::VT) where {VT}
370-
if active_reg_inner(Core.Typeof(val), (), nothing, Val(true)) == ActiveState
371-
return Ref(val)
369+
@generated function ref_if_mixed(val::VT) where VT
370+
areg = active_reg_inner(VT, (), nothing, Val(true))
371+
if areg == ActiveState || areg == MixedState
372+
quote
373+
Base.@_inline_meta
374+
Ref(val)
375+
end
372376
else
373-
return val
377+
quote
378+
Base.@_inline_meta
379+
val
380+
end
374381
end
375382
end
376383

@@ -380,15 +387,13 @@ function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Valu
380387
if !legal
381388
legal, TT, _ = abs_typeof(val, true)
382389
act = active_reg_inner(TT, (), world)
383-
if act == AnyState
390+
if legal && act == AnyState
384391
return val
385392
end
386-
if !legal
387-
return emit_apply_generic!(B, [unsafe_to_llvm(B, ref_if_mixed), val])
388-
end
393+
return emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(B, ref_if_mixed), val])
389394
end
390395
act = active_reg_inner(TT, (), world)
391-
396+
392397
if act == ActiveState || act == MixedState
393398
obj = emit_allocobj!(B, Base.RefValue{TT})
394399
lty = convert(LLVMType, TT)

0 commit comments

Comments
 (0)