Skip to content

Commit 5dede02

Browse files
authored
If the lowered return of active really can't be inferred, just assume… (#2390)
* If the lowered return of active really can't be inferred, just assume float64 and emit a cast error * set art * fix * fixup * fix
1 parent 55c039a commit 5dede02

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

src/Enzyme.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -344,11 +344,11 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
344344
@inline function autodiff(
345345
mode::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten},
346346
f::FA,
347-
::Type{A},
347+
::Type{A0},
348348
args::Vararg{Annotation,Nargs},
349349
) where {
350350
FA<:Annotation,
351-
A<:Annotation,
351+
A0<:Annotation,
352352
ReturnPrimal,
353353
RuntimeActivity,
354354
RABI<:ABI,
@@ -369,13 +369,14 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
369369

370370
FTy = Core.Typeof(f.val)
371371

372-
rt = if A isa UnionAll
373-
Compiler.primal_return_type(Reverse, FTy, tt)
372+
rt, A = if A0 isa UnionAll
373+
rt0 = Compiler.primal_return_type(Reverse, FTy, tt)
374+
rt0, A0{rt0}
374375
else
375-
eltype(A)
376+
eltype(A0), A0
376377
end
377378

378-
if A <: Active
379+
if A0 <: Active
379380
if (!allocatedinline(rt) || rt isa Union) && rt != Union{}
380381
forward, adjoint = autodiff_thunk(
381382
ReverseModeSplit{
@@ -401,11 +402,11 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
401402
return adjoint(f, args..., tape)
402403
end
403404
end
404-
elseif A <: Duplicated ||
405-
A <: DuplicatedNoNeed ||
406-
A <: BatchDuplicated ||
407-
A <: BatchDuplicatedNoNeed ||
408-
A <: BatchDuplicatedFunc
405+
elseif A0 <: Duplicated ||
406+
A0 <: DuplicatedNoNeed ||
407+
A0 <: BatchDuplicated ||
408+
A0 <: BatchDuplicatedNoNeed ||
409+
A0 <: BatchDuplicatedFunc
409410
throw(ErrorException("Duplicated Returns not yet handled"))
410411
end
411412

@@ -415,7 +416,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
415416
Val(0)
416417
end
417418

418-
if (A <: Active && rt <: Complex) && rt != Union{}
419+
if (A0 <: Active && rt <: Complex) && rt != Union{}
419420
if Holomorphic
420421
seen = IdDict()
421422
seen2 = IdDict()
@@ -497,7 +498,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
497498
Val(RuntimeActivity),
498499
) #=ShadowInit=#
499500

500-
if A <: Active
501+
if A0 <: Active
501502
args = (args..., Compiler.default_adjoint(rt))
502503
end
503504
thunk(f, args...)

src/compiler.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2714,10 +2714,15 @@ function lower_convention(
27142714
returnRoots = returnRoots !== nothing
27152715

27162716
loweredReturn = RetActivity <: Active && (actualRetType === Any)
2717+
expected_RT = Nothing
27172718
if loweredReturn
27182719
@assert !sret
27192720
@assert !returnRoots
2720-
RT = convert(LLVMType, eltype(RetActivity))
2721+
expected_RT = eltype(RetActivity)
2722+
if expected_RT === Any
2723+
expected_RT = Float64
2724+
end
2725+
RT = convert(LLVMType, expected_RT)
27212726
end
27222727

27232728
# TODO removed implications
@@ -3170,7 +3175,7 @@ function lower_convention(
31703175
return_attributes(wrapper_f),
31713176
StringAttribute(
31723177
"enzymejl_parmtype",
3173-
string(convert(UInt, unsafe_to_pointer(eltype(RetActivity)))),
3178+
string(convert(UInt, unsafe_to_pointer(expected_RT))),
31743179
),
31753180
)
31763181
push!(
@@ -3181,7 +3186,7 @@ function lower_convention(
31813186
),
31823187
)
31833188
ty = emit_jltypeof!(builder, res)
3184-
cmp = icmp!(builder, LLVM.API.LLVMIntEQ, ty, unsafe_to_llvm(builder, eltype(RetActivity)))
3189+
cmp = icmp!(builder, LLVM.API.LLVMIntEQ, ty, unsafe_to_llvm(builder, expected_RT))
31853190
cmpret = BasicBlock(wrapper_f, "ret")
31863191
failure = BasicBlock(wrapper_f, "fail")
31873192
br!(builder, cmp, cmpret, failure)
@@ -3194,7 +3199,7 @@ function lower_convention(
31943199

31953200
position!(builder, failure)
31963201

3197-
emit_error(builder, nothing, "Expected return type of primal to be "*string(eltype(RetActivity))*" but did not find a value of that type")
3202+
emit_error(builder, nothing, "Expected return type of primal to be "*string(expected_RT)*" but did not find a value of that type")
31983203
unreachable!(builder)
31993204
else
32003205
push!(
@@ -3428,7 +3433,7 @@ function lower_convention(
34283433
end
34293434
throw(LLVM.LLVMException(msg))
34303435
end
3431-
return wrapper_f, returnRoots, boxedArgs, loweredArgs, loweredReturn ? eltype(RetActivity) : actualRetType
3436+
return wrapper_f, returnRoots, boxedArgs, loweredArgs, loweredReturn ? expected_RT : actualRetType
34323437
end
34333438

34343439
using Random

0 commit comments

Comments
 (0)