Skip to content

Commit dea1012

Browse files
authored
Custom rules: handle sret union return (#2871)
* Custom rules: handle sret union return * fix * more test reduction * fewer rt act err * more * bump jll * fix local build * libc free as free * cleanup
1 parent 89305f0 commit dea1012

8 files changed

Lines changed: 229 additions & 38 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5, 0.6"
4242
CEnum = "0.4, 0.5"
4343
ChainRulesCore = "1"
4444
EnzymeCore = "0.8.16"
45-
Enzyme_jll = "0.0.232"
45+
Enzyme_jll = "0.0.234"
4646
GPUArraysCore = "0.1.6, 0.2"
4747
GPUCompiler = "1.6.2"
4848
LLVM = "9.1"

deps/build_local.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ isdir(scratch_dir) && rm(scratch_dir; recursive=true)
8888

8989
# Build!
9090
@info "Building" source_dir scratch_dir LLVM_DIR BUILD_TYPE
91-
run(`cmake -DLLVM_DIR=$(LLVM_DIR) -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) -DENZYME_EXTERNAL_SHARED_LIB=ON -B$(scratch_dir) -S$(source_dir)`)
91+
run(`cmake -DLLVM_DIR=$(LLVM_DIR) -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) -DENZYME_EXTERNAL_SHARED_LIB=ON -DENZYME_ENABLE_BENCHMARKS=OFF -B$(scratch_dir) -S$(source_dir)`)
9292

9393
if BCLoad
9494
run(`cmake --build $(scratch_dir) --parallel $(Sys.CPU_THREADS) -t Enzyme-$(LLVM_VER_MAJOR) EnzymeBCLoad-$(LLVM_VER_MAJOR)`)

src/llvm/attributes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,7 @@ function annotate!(mod::LLVM.Module)
536536
end
537537
end
538538

539+
539540
for fname in (
540541
"jl_f_getfield",
541542
"ijl_f_getfield",

src/rules/customrules.jl

Lines changed: 116 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,6 +1333,15 @@ end
13331333
return true
13341334
end
13351335

1336+
"""
    sret_union_tape_type(aug_RT)

Return the `Union` of the tape types associated with `aug_RT`.

`EnzymeRules.tape_type` is applied to every type that `for_each_uniontype_small`
visits in `aug_RT`, and the collected results are combined into a single
`Union{...}` type.
"""
function sret_union_tape_type(@nospecialize(aug_RT))
    tape_types = Type[]
    for_each_uniontype_small(aug_RT) do member
        # One tape type per visited member of the augmented return type.
        push!(tape_types, EnzymeRules.tape_type(member))
    end
    return Union{tape_types...}
end
1344+
13361345
function enzyme_custom_common_rev(
13371346
forward::Bool,
13381347
B::LLVM.IRBuilder,
@@ -1424,6 +1433,11 @@ function enzyme_custom_common_rev(
14241433
!(aug_RT isa Union) &&
14251434
!(aug_RT === Union{})
14261435
TapeT = EnzymeRules.tape_type(aug_RT)
1436+
elseif (
1437+
aug_RT <: EnzymeRules.AugmentedReturn ||
1438+
aug_RT <: EnzymeRules.AugmentedReturnFlexShadow
1439+
) && is_sret_union(aug_RT)
1440+
TapeT = sret_union_tape_type(aug_RT)
14271441
elseif (aug_RT isa UnionAll) &&
14281442
(aug_RT <: EnzymeRules.AugmentedReturn) && hasfield(typeof(aug_RT.body), :name) &&
14291443
aug_RT.body.name == EnzymeCore.EnzymeRules.AugmentedReturn.body.body.body.name
@@ -1557,10 +1571,8 @@ function enzyme_custom_common_rev(
15571571
sret_union = is_sret_union(miRT)
15581572

15591573
if sret_union
1560-
bt = GPUCompiler.backtrace(orig)
1561-
msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
1562-
emit_error(B, orig, (msg2, final_mi, world), UnionSretReturnException{miRT})
1563-
return tapeV
1574+
@assert sret !== nothing
1575+
@assert returnRoots === nothing
15641576
end
15651577

15661578
if !forward
@@ -1749,11 +1761,15 @@ function enzyme_custom_common_rev(
17491761
end
17501762

17511763
if sret !== nothing
1752-
sret_lty = convert(LLVMType, eltype(sret))
1753-
if VERSION >= v"1.12" && returnRoots !== nothing
1754-
dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig))))
1755-
sret_lty = LLVM.ArrayType(LLVM.Int8Type(), LLVM.sizeof(dl, sret_lty))
1756-
end
1764+
sret_lty = if sret_union
1765+
LLVM.ArrayType(LLVM.Int8Type(), union_alloca_type(miRT))
1766+
else
1767+
convert(LLVMType, eltype(sret))
1768+
end
1769+
if VERSION >= v"1.12" && returnRoots !== nothing
1770+
dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig))))
1771+
sret_lty = LLVM.ArrayType(LLVM.Int8Type(), LLVM.sizeof(dl, sret_lty))
1772+
end
17571773
sret = alloca!(alloctx, sret_lty)
17581774
pushfirst!(args, sret)
17591775
if returnRoots !== nothing
@@ -1845,7 +1861,75 @@ function enzyme_custom_common_rev(
18451861
return tapeV
18461862
end
18471863

1848-
if sret !== nothing
1864+
sret_union_tape = nothing
1865+
1866+
if sret_union && forward
1867+
1868+
ShadT = RealRt
1869+
if width != 1
1870+
ShadT = NTuple{Int(width),RealRt}
1871+
end
1872+
ST = EnzymeRules.AugmentedReturn{
1873+
needsPrimal ? RealRt : Nothing,
1874+
needsShadowJL ? ShadT : Nothing,
1875+
TapeT,
1876+
}
1877+
if ST != EnzymeRules.augmented_rule_return_type(C, RT, TapeT)
1878+
throw(AssertionError("Unexpected augmented rule return computation\nST = $ST\nER = $(EnzymeRules.augmented_rule_return_type(C, RT, TapeT))\nC = $C\nRT = $RT\nTapeT = $TapeT"))
1879+
end
1880+
if !(aug_RT <: EnzymeRules.AugmentedReturnFlexShadow) && !(aug_RT <: EnzymeRules.AugmentedReturn{
1881+
needsPrimal ? RealRt : Nothing,
1882+
needsShadowJL ? ShadT : Nothing})
1883+
1884+
bt = GPUCompiler.backtrace(orig)
1885+
msg2 = sprint(Base.Fix2(Base.show_backtrace, bt))
1886+
emit_error(B, orig, (msg2, ami, world), AugmentedRuleReturnError{C, RT, aug_RT})
1887+
return tapeV
1888+
end
1889+
1890+
if ST != EnzymeRules.augmented_rule_return_type(C, RT, TapeT)
1891+
throw(AssertionError("Unexpected augmented rule return computation\nST = $ST\nER = $(EnzymeRules.augmented_rule_return_type(C, RT, TapeT))\nC = $C\nRT = $RT\nTapeT = $TapeT"))
1892+
end
1893+
1894+
cur = nothing
1895+
cur_size = nothing
1896+
cur_offset = nothing
1897+
1898+
counter = 1
1899+
1900+
idxv = extract_value!(B, res, 1)
1901+
1902+
function inner(@nospecialize(aug_RT::Type))
1903+
jlrettype = EnzymeRules.tape_type(aug_RT)
1904+
if cur_size == nothing
1905+
cur_size = sizeof(jlrettype)
1906+
elseif cur_size != sizeof(jlrettype)
1907+
same_size = false
1908+
end
1909+
1910+
if cur === nothing
1911+
cur = unsafe_to_llvm(B, jlrettype)
1912+
cur_size = LLVM.ConstantInt(sizeof(jlrettype))
1913+
cur_offset = LLVM.ConstantInt(fieldoffset(aug_RT, 3))
1914+
else
1915+
cmpv = icmp!(B, LLVM.API.LLVMIntEQ, idxv, LLVM.ConstantInt(value_type(idxv), counter))
1916+
cur = select!(B, cmpv, unsafe_to_llvm(B, jlrettype), cur)
1917+
cur_size = select!(B, cmpv, LLVM.ConstantInt(sizeof(jlrettype)), cur_size)
1918+
cur_offset = select!(B, cmpv, LLVM.ConstantInt(fieldoffset(aug_RT, 3)), cur_offset)
1919+
end
1920+
1921+
counter += 1
1922+
return
1923+
end
1924+
for_each_uniontype_small(inner, miRT)
1925+
1926+
sret_union_tape = emit_allocobj!(B, cur, cur_size, false)
1927+
T_int8 = LLVM.Int8Type()
1928+
memcpy!(B, bitcast!(B, sret_union_tape, LLVM.PointerType(T_int8, Tracked)), 0, gep!(B, T_int8, bitcast!(B, sret, LLVM.PointerType(T_int8)), LLVM.Value[cur_offset]), 0, cur_size)
1929+
1930+
res = sret
1931+
1932+
elseif sret !== nothing
18491933
sty = sret_ty(llvmf, 1+swiftself)
18501934
if LLVM.version().major >= 12
18511935
attr = TypeAttribute("sret", sty)
@@ -1857,16 +1941,17 @@ function enzyme_custom_common_rev(
18571941
LLVM.API.LLVMAttributeIndex(1 + swiftself),
18581942
attr,
18591943
)
1860-
if returnRoots !== nothing
1861-
LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(2 + swiftself), StringAttribute("enzymejl_returnRoots", string(length(eltype(returnRoots0).parameters[1]))))
1862-
end
1863-
if returnRoots !== nothing && VERSION >= v"1.12"
1864-
res = recombine_value_ptr!(B, sty, sret, returnRoots; must_cache=true)
1865-
else
1866-
res = load!(B, sty, sret)
1867-
API.SetMustCache!(res)
1868-
end
1944+
if returnRoots !== nothing
1945+
LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(2 + swiftself), StringAttribute("enzymejl_returnRoots", string(length(eltype(returnRoots0).parameters[1]))))
1946+
end
1947+
if returnRoots !== nothing && VERSION >= v"1.12"
1948+
res = recombine_value_ptr!(B, sty, sret, returnRoots; must_cache=true)
1949+
else
1950+
res = load!(B, sty, sret)
1951+
API.SetMustCache!(res)
1952+
end
18691953
end
1954+
18701955
if swiftself
18711956
attr = EnumAttribute("swiftself")
18721957
LLVM.API.LLVMAddCallSiteAttribute(
@@ -1953,7 +2038,7 @@ function enzyme_custom_common_rev(
19532038
},
19542039
)
19552040
if StructTy != LLVM.VoidType()
1956-
load!(
2041+
lresV = load!(
19572042
B,
19582043
StructTy,
19592044
bitcast!(
@@ -1962,6 +2047,8 @@ function enzyme_custom_common_rev(
19622047
LLVM.PointerType(StructTy, addrspace(value_type(res))),
19632048
),
19642049
)
2050+
API.SetMustCache!(lresV)
2051+
lresV
19652052
else
19662053
res
19672054
end
@@ -1973,19 +2060,19 @@ function enzyme_custom_common_rev(
19732060
if needsPrimal
19742061
@assert !isghostty(RealRt)
19752062
normalV = extract_value!(B, resV, idx)
1976-
_, prim_sret, prim_roots = get_return_info(RealRt)
2063+
_, prim_sret, prim_roots = get_return_info(RealRt)
19772064
if prim_sret !== nothing
19782065
val = new_from_original(gutils, operands(orig)[1])
19792066

1980-
if prim_roots !== nothing && VERSION >= v"1.12"
2067+
if prim_roots !== nothing && VERSION >= v"1.12"
19812068
extract_nonjlvalues_into!(B, value_type(normalV), val, normalV)
19822069

19832070
rval = new_from_original(gutils, operands(orig)[2])
19842071

1985-
extract_roots_from_value!(B, normalV, rval)
1986-
else
2072+
extract_roots_from_value!(B, normalV, rval)
2073+
else
19872074
store!(B, normalV, val)
1988-
end
2075+
end
19892076
else
19902077
@assert value_type(normalV) == value_type(orig)
19912078
normalV = normalV.ref
@@ -2030,7 +2117,10 @@ function enzyme_custom_common_rev(
20302117
end
20312118
end
20322119
if needsTape
2033-
tapeV0 = if abstract
2120+
2121+
tapeV0 = if sret_union
2122+
sret_union_tape
2123+
elseif abstract
20342124
emit_nthfield!(B, res, LLVM.ConstantInt(2))
20352125
else
20362126
extract_value!(B, res, idx)

src/rules/llvmrules.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,8 @@ end
454454
real_ops = collect(operands(orig))[1:end-1]
455455
ops = [new_from_original(gutils, o) for o in real_ops]
456456

457+
shadowin = invert_pointer(gutils, real_ops[1], B)
458+
457459
batch_call_same_with_inverted_arg_if_active!(
458460
B,
459461
gutils,
@@ -1948,9 +1950,19 @@ end
19481950
if is_constant_value(gutils, orig)
19491951
return true
19501952
end
1951-
err = emit_error(B, orig, "Enzyme: unhandled forward for jl_get_binding_or_error")
1953+
19521954
newo = new_from_original(gutils, orig)
1953-
API.moveBefore(newo, err, B)
1955+
cmp = icmp!(B, LLVM.API.LLVMIntNE, newo, LLVM.null(value_type(newo)))
1956+
1957+
err = emit_error(
1958+
B,
1959+
orig,
1960+
"Enzyme: unhandled forward for jl_get_binding_or_error",
1961+
EnzymeRuntimeException,
1962+
cmp
1963+
)
1964+
1965+
API.moveBefore(newo, cmp, B)
19541966

19551967
if unsafe_load(shadowR) != C_NULL
19561968
valTys = API.CValueType[]
@@ -1980,13 +1992,20 @@ end
19801992
if is_constant_value(gutils, orig)
19811993
return true
19821994
end
1995+
1996+
newo = new_from_original(gutils, orig)
1997+
1998+
cmp = icmp!(B, LLVM.API.LLVMIntNE, newo, LLVM.null(value_type(newo)))
1999+
19832000
err = emit_error(
19842001
B,
19852002
orig,
19862003
"Enzyme: unhandled augmented forward for jl_get_binding_or_error",
2004+
EnzymeRuntimeException,
2005+
cmp
19872006
)
1988-
newo = new_from_original(gutils, orig)
1989-
API.moveBefore(newo, err, B)
2007+
API.moveBefore(newo, cmp, B)
2008+
19902009
if unsafe_load(shadowR) != C_NULL
19912010
valTys = API.CValueType[]
19922011
args = LLVM.Value[]

src/typeutils/jltypes.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,48 @@ function inline_roots_type(@nospecialize(T::Type))::Int
5252
end
5353
end
5454

55+
"""
    non_rooted_types(typ::DataType)

Walk the fields of `typ` and return a `Vector{Type}` of the field types that are
collected along the way.

The LLVM lowering of `typ` must contain tracked pointers (`count != 0`) without
being entirely tracked (`!all`) and without derived pointers (`!derived`); these
are checked with `@assert`.

Fields are visited in declaration order, descending into nested `DataType`
fields before moving on to the next sibling. For each field:

- ghost types (per `isghostty`) are collected;
- fields whose field descriptor has `isptr` set are skipped;
- `Union`-typed fields are collected;
- any other `DataType` field is descended into;
- anything else throws an `AssertionError`.
"""
function non_rooted_types(@nospecialize(typ::DataType))
    lRT = convert(LLVMType, typ)
    pointer_info = CountTrackedPointers(lRT)
    @assert !pointer_info.derived
    @assert !pointer_info.all
    @assert pointer_info.count != 0

    collected = Type[]
    worklist = DataType[typ]

    while !isempty(worklist)
        cur = popfirst!(worklist)
        fielddesc = Base.DataTypeFieldDesc(cur)

        # Nested datatypes found in `cur`, kept in field order.
        children = DataType[]
        for idx in 1:fieldcount(cur)
            styp = typed_fieldtype(cur, idx)
            if isghostty(styp)
                push!(collected, styp)
            elseif fielddesc[idx].isptr
                # Pointer field: neither collected nor descended into.
            elseif styp isa Union
                push!(collected, styp)
            elseif styp isa DataType
                push!(children, styp)
            else
                throw(AssertionError("Non inner datatype: styp=$styp cur=$cur, typ=$typ lRT=$(string(lRT))"))
            end
        end

        # Put the children at the front, in order, so they are processed
        # before any previously queued sibling.
        prepend!(worklist, children)
    end

    return collected
end
96+
5597
function equivalent_rooted_type(@nospecialize(typ::DataType))
5698
lRT = convert(LLVMType, typ)
5799
tracked = CountTrackedPointers(lRT)

test/advanced.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ end
740740
@inbounds w[1] * x[1]
741741
end
742742

743-
@static if VERSION < v"1.11-"
743+
@static if VERSION < v"1.11-" || VERSION >= v"1.12"
744744
Enzyme.autodiff(Reverse, inactiveArg, Active, Duplicated(w, dw), Const(x), Const(false))
745745

746746
@test x ≈ [3.0]
@@ -762,7 +762,7 @@ end
762762
res
763763
end
764764

765-
@static if VERSION < v"1.11-"
765+
@static if VERSION < v"1.11-" || VERSION >= v"1.12"
766766
dw = Enzyme.autodiff(Reverse, loss, Active, Active(1.0), Const(x), Const(false))[1]
767767

768768
else
@@ -1321,11 +1321,16 @@ end
13211321

13221322
f_union(cond, x) = cond ? x : 0
13231323
g_union(cond, x) = f_union(cond, x) * x
1324-
if sizeof(Int) == sizeof(Int64)
1325-
@test_throws Enzyme.Compiler.IllegalTypeAnalysisException autodiff(Reverse, g_union, Active, Const(true), Active(1.0))
1326-
else
1327-
@test_throws Enzyme.Compiler.IllegalTypeAnalysisException autodiff(Reverse, g_union, Active, Const(true), Active(1.0f0))
1324+
1325+
# This only works as a test in < 1.12 as we actually optimize away the issue in later LLVM's
1326+
if VERSION < v"1.12"
1327+
if sizeof(Int) == sizeof(Int64)
1328+
@test_throws Enzyme.Compiler.IllegalTypeAnalysisException autodiff(Reverse, g_union, Active, Const(true), Active(1.0))
1329+
else
1330+
@test_throws Enzyme.Compiler.IllegalTypeAnalysisException autodiff(Reverse, g_union, Active, Const(true), Active(1.0f0))
1331+
end
13281332
end
1333+
13291334
# TODO: Add test for NoShadowException
13301335
end
13311336

0 commit comments

Comments
 (0)