@@ -1333,6 +1333,15 @@ end
13331333 return true
13341334end
13351335
"""
    sret_union_tape_type(aug_RT)

Return the `Union` of `EnzymeRules.tape_type(T)` taken over every type `T`
that `for_each_uniontype_small` visits for `aug_RT`.

If no type is visited, the result is `Union{}`.
"""
function sret_union_tape_type(@nospecialize(aug_RT))
    tape_types = Type[]
    # Collect one tape type per visited member, in visitation order.
    for_each_uniontype_small(aug_RT) do member
        push!(tape_types, EnzymeRules.tape_type(member))
    end
    return Union{tape_types...}
end
1344+
13361345function enzyme_custom_common_rev (
13371346 forward:: Bool ,
13381347 B:: LLVM.IRBuilder ,
@@ -1424,6 +1433,11 @@ function enzyme_custom_common_rev(
14241433 ! (aug_RT isa Union) &&
14251434 ! (aug_RT === Union{})
14261435 TapeT = EnzymeRules. tape_type (aug_RT)
1436+ elseif (
1437+ aug_RT <: EnzymeRules.AugmentedReturn ||
1438+ aug_RT <: EnzymeRules.AugmentedReturnFlexShadow
1439+ ) && is_sret_union (aug_RT)
1440+ TapeT = sret_union_tape_type (aug_RT)
14271441 elseif (aug_RT isa UnionAll) &&
14281442 (aug_RT <: EnzymeRules.AugmentedReturn ) && hasfield (typeof (aug_RT. body), :name ) &&
14291443 aug_RT. body. name == EnzymeCore. EnzymeRules. AugmentedReturn. body. body. body. name
@@ -1557,10 +1571,8 @@ function enzyme_custom_common_rev(
15571571 sret_union = is_sret_union (miRT)
15581572
15591573 if sret_union
1560- bt = GPUCompiler. backtrace (orig)
1561- msg2 = sprint (Base. Fix2 (Base. show_backtrace, bt))
1562- emit_error (B, orig, (msg2, final_mi, world), UnionSretReturnException{miRT})
1563- return tapeV
1574+ @assert sret != = nothing
1575+ @assert returnRoots === nothing
15641576 end
15651577
15661578 if ! forward
@@ -1749,11 +1761,15 @@ function enzyme_custom_common_rev(
17491761 end
17501762
17511763 if sret != = nothing
1752- sret_lty = convert (LLVMType, eltype (sret))
1753- if VERSION >= v " 1.12" && returnRoots != = nothing
1754- dl = LLVM. datalayout (LLVM. parent (LLVM. parent (LLVM. parent (orig))))
1755- sret_lty = LLVM. ArrayType (LLVM. Int8Type (), LLVM. sizeof (dl, sret_lty))
1756- end
1764+ sret_lty = if sret_union
1765+ LLVM. ArrayType (LLVM. Int8Type (), union_alloca_type (miRT))
1766+ else
1767+ convert (LLVMType, eltype (sret))
1768+ end
1769+ if VERSION >= v " 1.12" && returnRoots != = nothing
1770+ dl = LLVM. datalayout (LLVM. parent (LLVM. parent (LLVM. parent (orig))))
1771+ sret_lty = LLVM. ArrayType (LLVM. Int8Type (), LLVM. sizeof (dl, sret_lty))
1772+ end
17571773 sret = alloca! (alloctx, sret_lty)
17581774 pushfirst! (args, sret)
17591775 if returnRoots != = nothing
@@ -1845,7 +1861,75 @@ function enzyme_custom_common_rev(
18451861 return tapeV
18461862 end
18471863
1848- if sret != = nothing
1864+ sret_union_tape = nothing
1865+
1866+ if sret_union && forward
1867+
1868+ ShadT = RealRt
1869+ if width != 1
1870+ ShadT = NTuple{Int (width),RealRt}
1871+ end
1872+ ST = EnzymeRules. AugmentedReturn{
1873+ needsPrimal ? RealRt : Nothing,
1874+ needsShadowJL ? ShadT : Nothing,
1875+ TapeT,
1876+ }
1877+ if ST != EnzymeRules. augmented_rule_return_type (C, RT, TapeT)
1878+ throw (AssertionError (" Unexpected augmented rule return computation\n ST = $ST \n ER = $(EnzymeRules. augmented_rule_return_type (C, RT, TapeT)) \n C = $C \n RT = $RT \n TapeT = $TapeT " ))
1879+ end
1880+ if ! (aug_RT <: EnzymeRules.AugmentedReturnFlexShadow ) && ! (aug_RT <: EnzymeRules.AugmentedReturn {
1881+ needsPrimal ? RealRt : Nothing,
1882+ needsShadowJL ? ShadT : Nothing})
1883+
1884+ bt = GPUCompiler. backtrace (orig)
1885+ msg2 = sprint (Base. Fix2 (Base. show_backtrace, bt))
1886+ emit_error (B, orig, (msg2, ami, world), AugmentedRuleReturnError{C, RT, aug_RT})
1887+ return tapeV
1888+ end
1889+
1890+ if ST != EnzymeRules. augmented_rule_return_type (C, RT, TapeT)
1891+ throw (AssertionError (" Unexpected augmented rule return computation\n ST = $ST \n ER = $(EnzymeRules. augmented_rule_return_type (C, RT, TapeT)) \n C = $C \n RT = $RT \n TapeT = $TapeT " ))
1892+ end
1893+
1894+ cur = nothing
1895+ cur_size = nothing
1896+ cur_offset = nothing
1897+
1898+ counter = 1
1899+
1900+ idxv = extract_value! (B, res, 1 )
1901+
# Callback invoked once per type visited by `for_each_uniontype_small`.
#
# For the visited type `aug_RT` it looks up the tape type with
# `EnzymeRules.tape_type` and updates the captured variables `cur`,
# `cur_size` and `cur_offset`:
#   * on the first call they are seeded with the values for this type;
#   * on later calls a `select!` keyed on `idxv == counter` picks between the
#     values for this type and the previously accumulated ones.
# `counter` is incremented on every call. `B`, `idxv`, `cur`, `cur_size`,
# `cur_offset` and `counter` all come from the enclosing scope.
function inner(@nospecialize(aug_RT::Type))
    jlrettype = EnzymeRules.tape_type(aug_RT)
    # Identity comparison is the correct test against `nothing`.
    if cur_size === nothing
        cur_size = sizeof(jlrettype)
    elseif cur_size != sizeof(jlrettype)
        # NOTE(review): after the first call `cur_size` holds an LLVM value
        # (see the assignments below), so comparing it with an `Int` looks
        # unintended. `same_size` is not defined in the visible code;
        # presumably it belongs to the enclosing function. Confirm.
        same_size = false
    end

    if cur === nothing
        # First visited type: seed the accumulators with constants.
        cur = unsafe_to_llvm(B, jlrettype)
        cur_size = LLVM.ConstantInt(sizeof(jlrettype))
        # Byte offset of field 3 of `aug_RT` (presumably the tape field; verify).
        cur_offset = LLVM.ConstantInt(fieldoffset(aug_RT, 3))
    else
        # Later types: keep the previous value unless `idxv` equals `counter`.
        cmpv = icmp!(
            B,
            LLVM.API.LLVMIntEQ,
            idxv,
            LLVM.ConstantInt(value_type(idxv), counter),
        )
        cur = select!(B, cmpv, unsafe_to_llvm(B, jlrettype), cur)
        cur_size = select!(B, cmpv, LLVM.ConstantInt(sizeof(jlrettype)), cur_size)
        cur_offset = select!(
            B,
            cmpv,
            LLVM.ConstantInt(fieldoffset(aug_RT, 3)),
            cur_offset,
        )
    end

    counter += 1
    return
end
1924+ for_each_uniontype_small (inner, miRT)
1925+
1926+ sret_union_tape = emit_allocobj! (B, cur, cur_size, false )
1927+ T_int8 = LLVM. Int8Type ()
1928+ memcpy! (B, bitcast! (B, sret_union_tape, LLVM. PointerType (T_int8, Tracked)), 0 , gep! (B, T_int8, bitcast! (B, sret, LLVM. PointerType (T_int8)), LLVM. Value[cur_offset]), 0 , cur_size)
1929+
1930+ res = sret
1931+
1932+ elseif sret != = nothing
18491933 sty = sret_ty (llvmf, 1 + swiftself)
18501934 if LLVM. version (). major >= 12
18511935 attr = TypeAttribute (" sret" , sty)
@@ -1857,16 +1941,17 @@ function enzyme_custom_common_rev(
18571941 LLVM. API. LLVMAttributeIndex (1 + swiftself),
18581942 attr,
18591943 )
1860- if returnRoots != = nothing
1861- LLVM. API. LLVMAddCallSiteAttribute (res, LLVM. API. LLVMAttributeIndex (2 + swiftself), StringAttribute (" enzymejl_returnRoots" , string (length (eltype (returnRoots0). parameters[1 ]))))
1862- end
1863- if returnRoots != = nothing && VERSION >= v " 1.12"
1864- res = recombine_value_ptr! (B, sty, sret, returnRoots; must_cache= true )
1865- else
1866- res = load! (B, sty, sret)
1867- API. SetMustCache! (res)
1868- end
1944+ if returnRoots != = nothing
1945+ LLVM. API. LLVMAddCallSiteAttribute (res, LLVM. API. LLVMAttributeIndex (2 + swiftself), StringAttribute (" enzymejl_returnRoots" , string (length (eltype (returnRoots0). parameters[1 ]))))
1946+ end
1947+ if returnRoots != = nothing && VERSION >= v " 1.12"
1948+ res = recombine_value_ptr! (B, sty, sret, returnRoots; must_cache= true )
1949+ else
1950+ res = load! (B, sty, sret)
1951+ API. SetMustCache! (res)
1952+ end
18691953 end
1954+
18701955 if swiftself
18711956 attr = EnumAttribute (" swiftself" )
18721957 LLVM. API. LLVMAddCallSiteAttribute (
@@ -1953,7 +2038,7 @@ function enzyme_custom_common_rev(
19532038 },
19542039 )
19552040 if StructTy != LLVM. VoidType ()
1956- load! (
2041+ lresV = load! (
19572042 B,
19582043 StructTy,
19592044 bitcast! (
@@ -1962,6 +2047,8 @@ function enzyme_custom_common_rev(
19622047 LLVM. PointerType (StructTy, addrspace (value_type (res))),
19632048 ),
19642049 )
2050+ API. SetMustCache! (lresV)
2051+ lresV
19652052 else
19662053 res
19672054 end
@@ -1973,19 +2060,19 @@ function enzyme_custom_common_rev(
19732060 if needsPrimal
19742061 @assert ! isghostty (RealRt)
19752062 normalV = extract_value! (B, resV, idx)
1976- _, prim_sret, prim_roots = get_return_info (RealRt)
2063+ _, prim_sret, prim_roots = get_return_info (RealRt)
19772064 if prim_sret != = nothing
19782065 val = new_from_original (gutils, operands (orig)[1 ])
19792066
1980- if prim_roots != = nothing && VERSION >= v " 1.12"
2067+ if prim_roots != = nothing && VERSION >= v " 1.12"
19812068 extract_nonjlvalues_into! (B, value_type (normalV), val, normalV)
19822069
19832070 rval = new_from_original (gutils, operands (orig)[2 ])
19842071
1985- extract_roots_from_value! (B, normalV, rval)
1986- else
2072+ extract_roots_from_value! (B, normalV, rval)
2073+ else
19872074 store! (B, normalV, val)
1988- end
2075+ end
19892076 else
19902077 @assert value_type (normalV) == value_type (orig)
19912078 normalV = normalV. ref
@@ -2030,7 +2117,10 @@ function enzyme_custom_common_rev(
20302117 end
20312118 end
20322119 if needsTape
2033- tapeV0 = if abstract
2120+
2121+ tapeV0 = if sret_union
2122+ sret_union_tape
2123+ elseif abstract
20342124 emit_nthfield! (B, res, LLVM. ConstantInt (2 ))
20352125 else
20362126 extract_value! (B, res, idx)
0 commit comments