-
Notifications
You must be signed in to change notification settings - Fork 82
Calling conv part 1 #2782
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Calling conv part 1 #2782
Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/src/absint.jl b/src/absint.jl
index 6e098fba..cce53e79 100644
--- a/src/absint.jl
+++ b/src/absint.jl
@@ -1,7 +1,7 @@
# Abstractly interpret julia from LLVM
# Return (bool if could interpret, julia object interpreted to)
-function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false, istracked::Bool=false)::Tuple{Bool, Any}
+function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false, istracked::Bool = false)::Tuple{Bool, Any}
if (value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked)) || (value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Derived)) || istracked
ce, _ = get_base_and_offset(arg; offsetAllowed = false, inttoptr = true)
if isa(ce, GlobalVariable)
@@ -443,8 +443,8 @@ function abs_typeof(
end
# Type tag is arg 3
if nm == "jl_alloc_genericmemory_unchecked" ||
- nm == "ijl_alloc_genericmemory_unchecked"
- vals = absint(operands(arg)[3], partial, true)
+ nm == "ijl_alloc_genericmemory_unchecked"
+ vals = absint(operands(arg)[3], partial, true)
return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
end
# Type tag is arg 1
diff --git a/src/compiler.jl b/src/compiler.jl
index 1eeb1438..59db45cf 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -1071,7 +1071,7 @@ end
return
end
-function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLVM.Function}, job, edges, run_enzyme, mode::API.CDerivativeMode)::Tuple{Dict{String,LLVM.API.LLVMLinkage}, HandlerState}
+function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLVM.Function}, job, edges, run_enzyme, mode::API.CDerivativeMode)::Tuple{Dict{String, LLVM.API.LLVMLinkage}, HandlerState}
for f in functions(mod)
if startswith(LLVM.name(f), "japi3") || startswith(LLVM.name(f), "japi1")
@@ -1088,7 +1088,7 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV
dl = string(LLVM.datalayout(LLVM.parent(f)))
expectLen = (sret !== nothing) + (returnRoots !== nothing)
- for (source_typ, _) in rooted_argument_list(mi.specTypes.parameters)
+ for (source_typ, _) in rooted_argument_list(mi.specTypes.parameters)
if isghostty(source_typ) || Core.Compiler.isconstType(source_typ)
continue
end
@@ -1111,19 +1111,19 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV
world = enzyme_extract_world(f)
if expectLen != length(parameters(f))
- msg = sprint() do io::IO
- println(io, "expectLen != length(parameters(f))")
- println(io, string(f))
- println(io, "expectLen=", string(expectLen))
- println(io, "swiftself=", string(swiftself))
- println(io, "sret=", string(sret))
- println(io, "returnRoots=", string(returnRoots))
- println(io, "mi.specTypes.parameters=", string(mi.specTypes.parameters))
- println(io, "retRemoved=", string(retRemoved))
- println(io, "parmsRemoved=", string(parmsRemoved))
- println(io, "rooted_argument_list=", string(rooted_argument_list(mi.specTypes.parameters)))
- end
- throw(CallingConventionMismatchError{String}(msg, mi, world))
+ msg = sprint() do io::IO
+ println(io, "expectLen != length(parameters(f))")
+ println(io, string(f))
+ println(io, "expectLen=", string(expectLen))
+ println(io, "swiftself=", string(swiftself))
+ println(io, "sret=", string(sret))
+ println(io, "returnRoots=", string(returnRoots))
+ println(io, "mi.specTypes.parameters=", string(mi.specTypes.parameters))
+ println(io, "retRemoved=", string(retRemoved))
+ println(io, "parmsRemoved=", string(parmsRemoved))
+ println(io, "rooted_argument_list=", string(rooted_argument_list(mi.specTypes.parameters)))
+ end
+ throw(CallingConventionMismatchError{String}(msg, mi, world))
end
jlargs = classify_arguments(
@@ -1230,7 +1230,7 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV
#=lowerConvention=#true,
#=loweredArgs=#Set{Int}(),
#=boxedArgs=#Set{Int}(),
- #=removedRoots=#Set{Int}(),
+ #=removedRoots=# Set{Int}(),
#=fnsToInject=#Tuple{Symbol,Type}[],
)
@@ -1713,10 +1713,10 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie
return
end
@static if VERSION >= v"1.11"
- if Ty <: GenericMemory
- # TODO throw(AssertionError("What the heck is happening, why are we gc.alloca'ing memory, $(string(V)) $Ty"))
- return
- end
+ if Ty <: GenericMemory
+ # TODO throw(AssertionError("What the heck is happening, why are we gc.alloca'ing memory, $(string(V)) $Ty"))
+ return
+ end
end
if mode == API.DEM_ForwardMode && (used || idx != 0)
@@ -2410,7 +2410,7 @@ end
function enzyme_extract_parm_type(fn::LLVM.Function, idx::Int, error::Bool = true)
ty = nothing
byref = nothing
- for fattr in collect(parameter_attributes(fn, idx) )
+ for fattr in collect(parameter_attributes(fn, idx))
if isa(fattr, LLVM.StringAttribute)
if kind(fattr) == "enzymejl_parmtype"
ptr = reinterpret(Ptr{Cvoid}, parse(UInt, LLVM.value(fattr)))
@@ -2451,7 +2451,7 @@ function enzyme!(
@nospecialize(expectedTapeType::Type),
loweredArgs::Set{Int},
boxedArgs::Set{Int},
- removedRoots::Set{Int},
+ removedRoots::Set{Int},
)
if DumpPreEnzyme[]
API.EnzymeDumpModuleRef(mod.ref)
@@ -2485,7 +2485,7 @@ function enzyme!(
end
seen = TypeTreeTable()
-
+
seen_roots = 0
for (i, T) in enumerate(TT.parameters)
@@ -2500,32 +2500,32 @@ function enzyme!(
end
continue
end
- isboxed = (i + seen_roots) in boxedArgs
- inline_root = false
-
+ isboxed = (i + seen_roots) in boxedArgs
+ inline_root = false
- if inline_roots_type(eltype(T)) != 0
- # This is already after lower_convention
- seen_roots += 1
- if false
- inline_root = true
- end
- end
+
+ if inline_roots_type(eltype(T)) != 0
+ # This is already after lower_convention
+ seen_roots += 1
+ if false
+ inline_root = true
+ end
+ end
if T <: Const
push!(args_activity, API.DFT_CONSTANT)
- if inline_root
- push!(args_activity, API.DFT_CONSTANT)
- end
+ if inline_root
+ push!(args_activity, API.DFT_CONSTANT)
+ end
elseif T <: Active
if isboxed
- @assert !inline_root
+ @assert !inline_root
push!(args_activity, API.DFT_DUP_ARG)
else
push!(args_activity, API.DFT_OUT_DIFF)
- if inline_root
- push!(args_activity, API.DFT_CONSTANT)
- end
+ if inline_root
+ push!(args_activity, API.DFT_CONSTANT)
+ end
end
elseif T <: Duplicated ||
T <: BatchDuplicated ||
@@ -2533,14 +2533,14 @@ function enzyme!(
T <: MixedDuplicated ||
T <: BatchMixedDuplicated
push!(args_activity, API.DFT_DUP_ARG)
- if inline_root
- push!(args_activity, API.DFT_DUP_ARG)
- end
+ if inline_root
+ push!(args_activity, API.DFT_DUP_ARG)
+ end
elseif T <: DuplicatedNoNeed || T <: BatchDuplicatedNoNeed
push!(args_activity, API.DFT_DUP_NONEED)
- if inline_root
- push!(args_activity, API.DFT_DUP_ARG)
- end
+ if inline_root
+ push!(args_activity, API.DFT_DUP_ARG)
+ end
else
error("illegal annotation type $T")
end
@@ -2553,16 +2553,16 @@ function enzyme!(
push!(args_typeInfo, typeTree)
push!(uncacheable_args, modifiedBetween[i])
push!(args_known_values, API.IntList())
- if inline_root
- typeTree = typetree(Any, ctx, dl, seen)
- push!(args_typeInfo, typeTree)
- push!(uncacheable_args, modifiedBetween[i])
- push!(args_known_values, API.IntList())
- end
+ if inline_root
+ typeTree = typetree(Any, ctx, dl, seen)
+ push!(args_typeInfo, typeTree)
+ push!(uncacheable_args, modifiedBetween[i])
+ push!(args_known_values, API.IntList())
+ end
end
if length(uncacheable_args) != length(collect(parameters(primalf)))
msg = sprint() do io
- println(io, "length(uncacheable_args) != length(collect(parameters(primalf))) ")
+ println(io, "length(uncacheable_args) != length(collect(parameters(primalf))) ")
println(io, "TT=", TT)
println(io, "modifiedBetween=", modifiedBetween)
println(io, "uncacheable_args=", uncacheable_args)
@@ -2930,10 +2930,10 @@ function create_abi_wrapper(
isboxed = GPUCompiler.deserves_argbox(source_typ)
llvmT = isboxed ? T_prjlvalue : convert(LLVMType, source_typ)
push!(T_wrapperargs, llvmT)
- arg_roots = inline_roots_type(source_typ)
- if arg_rooting && arg_roots != 0
- push!(T_wrapperargs, convert(LLVMType, AnyArray(arg_roots)))
- end
+ arg_roots = inline_roots_type(source_typ)
+ if arg_rooting && arg_roots != 0
+ push!(T_wrapperargs, convert(LLVMType, AnyArray(arg_roots)))
+ end
if T <: Const || T <: BatchDuplicatedFunc
if is_adjoint && i != 1
@@ -2952,19 +2952,19 @@ function create_abi_wrapper(
end
elseif T <: Duplicated || T <: DuplicatedNoNeed || T <: BatchDuplicated || T <: BatchDuplicatedNoNeed
push!(T_wrapperargs, LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmT)))
- arg_roots = inline_roots_type(source_typ)
- if arg_rooting && arg_roots != 0
- push!(T_wrapperargs, convert(LLVMType, AnyArray(width * arg_roots)))
- end
+ arg_roots = inline_roots_type(source_typ)
+ if arg_rooting && arg_roots != 0
+ push!(T_wrapperargs, convert(LLVMType, AnyArray(width * arg_roots)))
+ end
if is_adjoint && i != 1
push!(ActiveRetTypes, Nothing)
end
elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
push!(T_wrapperargs, LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue)))
- arg_roots = inline_roots_type(source_typ)
- if arg_rooting && arg_roots != 0
- push!(T_wrapperargs, convert(LLVMType, AnyArray(width * arg_roots)))
- end
+ arg_roots = inline_roots_type(source_typ)
+ if arg_rooting && arg_roots != 0
+ push!(T_wrapperargs, convert(LLVMType, AnyArray(width * arg_roots)))
+ end
if is_adjoint && i != 1
push!(ActiveRetTypes, Nothing)
end
@@ -3016,10 +3016,10 @@ function create_abi_wrapper(
),
)
push!(T_wrapperargs, dretTy)
- arg_roots = inline_roots_type(actualRetType)
- if arg_rooting && arg_roots != 0
- push!(T_wrapperargs, convert(LLVMType, AnyArray(width * arg_roots)))
- end
+ arg_roots = inline_roots_type(actualRetType)
+ if arg_rooting && arg_roots != 0
+ push!(T_wrapperargs, convert(LLVMType, AnyArray(width * arg_roots)))
+ end
end
end
@@ -3151,10 +3151,10 @@ function create_abi_wrapper(
tape = LLVM.LLVMType(tape)
jltape = convert(LLVM.LLVMType, Compiler.tape_type(tape); allow_boxed = true)
push!(T_wrapperargs, jltape)
- arg_roots = inline_roots_type(tape)
- if arg_rooting && arg_roots != 0
- push!(T_wrapperargs, convert(LLVMType, AnyArray(arg_roots)))
- end
+ arg_roots = inline_roots_type(tape)
+ if arg_rooting && arg_roots != 0
+ push!(T_wrapperargs, convert(LLVMType, AnyArray(arg_roots)))
+ end
else
needs_tape = false
end
@@ -3217,16 +3217,16 @@ function create_abi_wrapper(
convty = convert(LLVMType, T′; allow_boxed = true)
- arg_roots = inline_roots_type(T′)
+ arg_roots = inline_roots_type(T′)
if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType))
@assert Base.isconcretetype(T′)
al0 = al = emit_allocobj!(builder, Base.RefValue{T′}, "mixedparameter")
- parm = params[i]
- if arg_rooting && arg_roots != 0
- parm = recombine_value!(builder, parm, params[i+1])
- i += 1
- end
+ parm = params[i]
+ if arg_rooting && arg_roots != 0
+ parm = recombine_value!(builder, parm, params[i + 1])
+ i += 1
+ end
al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al))))
store!(builder, parm, al)
emit_writebarrier!(builder, get_julia_inner_types(builder, al0, parm))
@@ -3238,14 +3238,14 @@ function create_abi_wrapper(
i += 1
if T <: Const
- if arg_rooting && arg_roots != 0
- push(realparms, params[i])
- i += 1
- end
+ if arg_rooting && arg_roots != 0
+ push(realparms, params[i])
+ i += 1
+ end
elseif T <: Active
isboxed = GPUCompiler.deserves_argbox(T′)
if isboxed
- @assert arg_roots == 0
+ @assert arg_roots == 0
if is_split
msg = sprint() do io
println(
@@ -3282,55 +3282,55 @@ function create_abi_wrapper(
0,
) #=align=#
end
- if arg_rooting &&arg_roots != 0
- push(realparms, params[i])
- i += 1
- end
+ if arg_rooting &&arg_roots != 0
+ push(realparms, params[i])
+ i += 1
+ end
activeNum += 1
elseif T <: Duplicated || T <: DuplicatedNoNeed || T <: BatchDuplicated || T <: BatchDuplicatedNoNeed
- # Enzyme expects, arg, darg, root, droot
- # Julia expects arg, root, darg, droot
- # We already pushed arg
- # now params[i] refers to root
- isboxed = (T <: BatchDuplicated || T <: BatchDuplicatedNoNeed) && GPUCompiler.deserves_argbox(NTuple{width,T′})
- darg = nothing
- root = nothing
- droot = nothing
- if arg_rooting &&arg_roots != 0
- root = params[i]
- darg = params[i+1]
- droot = params[i+2]
- i += 3
- else
- darg = params[i]
- i += 1
- end
-
- if isboxed
- darg = load!(builder, convert(LLVMType, NTuple{width,T′}), darg)
- end
- push!(realparms, darg)
- if arg_roots != 0
- push!(realparms, root)
- push!(realparms, droot)
- end
+ # Enzyme expects, arg, darg, root, droot
+ # Julia expects arg, root, darg, droot
+ # We already pushed arg
+ # now params[i] refers to root
+ isboxed = (T <: BatchDuplicated || T <: BatchDuplicatedNoNeed) && GPUCompiler.deserves_argbox(NTuple{width, T′})
+ darg = nothing
+ root = nothing
+ droot = nothing
+ if arg_rooting &&arg_roots != 0
+ root = params[i]
+ darg = params[i + 1]
+ droot = params[i + 2]
+ i += 3
+ else
+ darg = params[i]
+ i += 1
+ end
+
+ if isboxed
+ darg = load!(builder, convert(LLVMType, NTuple{width, T′}), darg)
+ end
+ push!(realparms, darg)
+ if arg_roots != 0
+ push!(realparms, root)
+ push!(realparms, droot)
+ end
elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
- # Enzyme expects, arg, [w x darg], root, droot
- # Julia expects arg, root, darg, droot
- # We already pushed arg
- # now params[i] referrs to root
- darg = nothing
- root = nothing
- droot = nothing
- if arg_rooting && arg_roots != 0
- root = params[i]
- darg = params[i+1]
- droot = params[i+2]
- i += 3
- else
- darg = params[i]
- i += 1
- end
+ # Enzyme expects, arg, [w x darg], root, droot
+ # Julia expects arg, root, darg, droot
+ # We already pushed arg
+ # now params[i] referrs to root
+ darg = nothing
+ root = nothing
+ droot = nothing
+ if arg_rooting && arg_roots != 0
+ root = params[i]
+ darg = params[i + 1]
+ droot = params[i + 2]
+ i += 3
+ else
+ darg = params[i]
+ i += 1
+ end
if T <: BatchMixedDuplicated
@assert Base.isconcretetype(T′)
@@ -3362,16 +3362,16 @@ function create_abi_wrapper(
end
push!(realparms, ival)
-
- if arg_rooting && arg_roots != 0
- push!(realparms, root)
- push!(realparms, droot)
- end
+
+ if arg_rooting && arg_roots != 0
+ push!(realparms, root)
+ push!(realparms, droot)
+ end
elseif T <: BatchDuplicatedFunc
- # TODO handle this
- if arg_rooting
- @assert arg_roots == 0
- end
+ # TODO handle this
+ if arg_rooting
+ @assert arg_roots == 0
+ end
Func = get_func(T)
funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, Func, Tuple{}, world)
llvmf = nested_codegen!(Mode, mod, funcspec, world)
@@ -3709,7 +3709,7 @@ function create_abi_wrapper(
end
if returnRoots
- move_sret_tofrom_roots!(builder, jltype, sret, root_ty, rootRet, SRetPointerToRootPointer)
+ move_sret_tofrom_roots!(builder, jltype, sret, root_ty, rootRet, SRetPointerToRootPointer)
end
if T_ret != T_void
ret!(builder, load!(builder, T_ret, sret))
@@ -3772,125 +3772,128 @@ function fixup_metadata!(f::LLVM.Function)
end
end
-@enum(SRetRootMovement,
+@enum(
+ SRetRootMovement,
SRetPointerToRootPointer = 0,
SRetValueToRootPointer = 1,
RootPointerToSRetValue = 2,
RootPointerToSRetPointer = 3
- )
+)
function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType, sret::LLVM.Value, root_ty::LLVM.LLVMType, rootRet::LLVM.Value, direction::SRetRootMovement)
- count = 0
- todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
- Cuint[],
+ count = 0
+ todo = Tuple{Vector{Cuint}, LLVM.LLVMType}[
+ (
+ Cuint[],
jltype,
- )]
- function to_llvm(lst::Vector{Cuint})
- vals = LLVM.Value[]
- push!(vals, LLVM.ConstantInt(LLVM.IntType(64), 0))
- for i in lst
- push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
- end
- return vals
- end
+ ),
+ ]
+ function to_llvm(lst::Vector{Cuint})
+ vals = LLVM.Value[]
+ push!(vals, LLVM.ConstantInt(LLVM.IntType(64), 0))
+ for i in lst
+ push!(vals, LLVM.ConstantInt(LLVM.IntType(32), i))
+ end
+ return vals
+ end
- extracted = LLVM.Value[]
-
- val = sret
- # TODO check that we perform this in the same order that extraction happens within julia
- # aka bfs/etc
- while length(todo) != 0
- path, ty = popfirst!(todo)
- if isa(ty, LLVM.PointerType)
- if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
- loc = inbounds_gep!(
- builder,
- root_ty,
- rootRet,
- to_llvm(Cuint[count]),
- )
- end
-
- if direction == SRetPointerToRootPointer
- outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
- outloc = load!(builder, ty, outloc)
- store!(builder, outloc, loc)
- elseif direction == SRetValueToRootPointer
- outloc = Enzyme.API.e_extract_value!(builder, sret, path)
- store!(builder, outloc, loc)
- elseif direction == RootPointerToSRetValue
- loc = load!(builder, ty, loc)
- sret = Enzyme.API.e_insert_value!(builder, sret, loc, path)
- elseif direction == RootPointerToSRetPointer
- outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
- loc = load!(builder, ty, loc)
- push!(extracted, loc)
- store!(builder, loc, outloc)
- else
- @assert false "Unhandled direction"
- end
-
- count += 1
- continue
+ extracted = LLVM.Value[]
+
+ val = sret
+ # TODO check that we perform this in the same order that extraction happens within julia
+ # aka bfs/etc
+ while length(todo) != 0
+ path, ty = popfirst!(todo)
+ if isa(ty, LLVM.PointerType)
+ if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
+ loc = inbounds_gep!(
+ builder,
+ root_ty,
+ rootRet,
+ to_llvm(Cuint[count]),
+ )
end
- if isa(ty, LLVM.ArrayType)
- if any_jltypes(ty)
- for i = 1:length(ty)
- npath = copy(path)
- push!(npath, i - 1)
- push!(todo, (npath, eltype(ty)))
- end
+
+ if direction == SRetPointerToRootPointer
+ outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
+ outloc = load!(builder, ty, outloc)
+ store!(builder, outloc, loc)
+ elseif direction == SRetValueToRootPointer
+ outloc = Enzyme.API.e_extract_value!(builder, sret, path)
+ store!(builder, outloc, loc)
+ elseif direction == RootPointerToSRetValue
+ loc = load!(builder, ty, loc)
+ sret = Enzyme.API.e_insert_value!(builder, sret, loc, path)
+ elseif direction == RootPointerToSRetPointer
+ outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
+ loc = load!(builder, ty, loc)
+ push!(extracted, loc)
+ store!(builder, loc, outloc)
+ else
+ @assert false "Unhandled direction"
+ end
+
+ count += 1
+ continue
+ end
+ if isa(ty, LLVM.ArrayType)
+ if any_jltypes(ty)
+ for i in 1:length(ty)
+ npath = copy(path)
+ push!(npath, i - 1)
+ push!(todo, (npath, eltype(ty)))
end
- continue
end
- if isa(ty, LLVM.VectorType)
- if any_jltypes(ty)
- for i = 1:size(ty)
- npath = copy(path)
- push!(npath, i - 1)
- push!(todo, (npath, eltype(ty)))
- end
+ continue
+ end
+ if isa(ty, LLVM.VectorType)
+ if any_jltypes(ty)
+ for i in 1:size(ty)
+ npath = copy(path)
+ push!(npath, i - 1)
+ push!(todo, (npath, eltype(ty)))
end
- continue
end
- if isa(ty, LLVM.StructType)
- for (i, t) in enumerate(LLVM.elements(ty))
- if any_jltypes(t)
- npath = copy(path)
- push!(npath, i - 1)
- push!(todo, (npath, t))
- end
+ continue
+ end
+ if isa(ty, LLVM.StructType)
+ for (i, t) in enumerate(LLVM.elements(ty))
+ if any_jltypes(t)
+ npath = copy(path)
+ push!(npath, i - 1)
+ push!(todo, (npath, t))
end
- continue
end
+ continue
end
+ end
- if direction == RootPointerToSRetPointer
- obj = get_base_and_offset(sret)[1]
- @assert length(extracted) > 0
- emit_writebarrier!(builder, LLVM.Value[obj, extracted...])
- end
- tracked = CountTrackedPointers(jltype)
- @assert count == tracked.count
- return val
+ if direction == RootPointerToSRetPointer
+ obj = get_base_and_offset(sret)[1]
+ @assert length(extracted) > 0
+ emit_writebarrier!(builder, LLVM.Value[obj, extracted...])
+ end
+ tracked = CountTrackedPointers(jltype)
+ @assert count == tracked.count
+ return val
end
function recombine_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, roots::LLVM.Value)
- jltype = value_type(sret)
- tracked = CountTrackedPointers(jltype)
- @assert tracked.count > 0
- @assert !tracked.all
- root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
- move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, RootPointerToSRetValue)
+ jltype = value_type(sret)
+ tracked = CountTrackedPointers(jltype)
+ @assert tracked.count > 0
+ @assert !tracked.all
+ root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
+ return move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, RootPointerToSRetValue)
end
function extract_roots_from_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, roots::LLVM.Value)
- jltype = value_type(sret)
- tracked = CountTrackedPointers(jltype)
- @assert tracked.count > 0
- @assert !tracked.all
- root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
- move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, SRetValueToRootPointer)
+ jltype = value_type(sret)
+ tracked = CountTrackedPointers(jltype)
+ @assert tracked.count > 0
+ @assert !tracked.all
+ root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
+ return move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, SRetValueToRootPointer)
end
@@ -3975,54 +3978,54 @@ function lower_convention(
removedRoots = Set{Int}()
function is_mixed(idx::Int)
- if TT === nothing
- return false
- end
- if idx > length(TT.parameters)
- throw(AssertionError("TT=$TT, args=$args idx=$idx"))
- end
- return (
- TT.parameters[idx] <: MixedDuplicated ||
- TT.parameters[idx] <: BatchMixedDuplicated
- ) &&
- run_enzyme
+ if TT === nothing
+ return false
+ end
+ if idx > length(TT.parameters)
+ throw(AssertionError("TT=$TT, args=$args idx=$idx"))
+ end
+ return (
+ TT.parameters[idx] <: MixedDuplicated ||
+ TT.parameters[idx] <: BatchMixedDuplicated
+ ) &&
+ run_enzyme
end
for arg in args
typ = arg.codegen.typ
-
- if arg.rooted_typ !== nothing
-
- # There cannot exist a root arg if the original arg was boxed
- @assert !GPUCompiler.deserves_argbox(arg.rooted_typ)
-
- # There only can exist a rooting if the original argument was a bits_ref
- @assert arg.rooted_cc == GPUCompiler.BITS_REF
-
- # If the original arg exists and was lowered to be a bits_ref, we will destroy
- # the extra rooted arg and recombine with the bits_ref
- if (arg.arg_i - 1) in loweredArgs
- push!(removedRoots, arg.arg_i)
- continue
- end
-
- # If we are raising an argument to mixed, we will still destroy the extra rooted
- # arg and recombine with the bits ref
- if (arg.arg_i - 1) in boxedArgs
- @assert is_mixed(arg.arg_jl_i)
- push!(removedRoots, arg.arg_i)
- continue
- end
-
- @assert false "Unhandled rooted arg condition"
- end
- if GPUCompiler.deserves_argbox(arg.typ)
+ if arg.rooted_typ !== nothing
+
+ # There cannot exist a root arg if the original arg was boxed
+ @assert !GPUCompiler.deserves_argbox(arg.rooted_typ)
+
+ # There only can exist a rooting if the original argument was a bits_ref
+ @assert arg.rooted_cc == GPUCompiler.BITS_REF
+
+ # If the original arg exists and was lowered to be a bits_ref, we will destroy
+ # the extra rooted arg and recombine with the bits_ref
+ if (arg.arg_i - 1) in loweredArgs
+ push!(removedRoots, arg.arg_i)
+ continue
+ end
+
+ # If we are raising an argument to mixed, we will still destroy the extra rooted
+ # arg and recombine with the bits ref
+ if (arg.arg_i - 1) in boxedArgs
+ @assert is_mixed(arg.arg_jl_i)
+ push!(removedRoots, arg.arg_i)
+ continue
+ end
+
+ @assert false "Unhandled rooted arg condition"
+ end
+
+ if GPUCompiler.deserves_argbox(arg.typ)
push!(boxedArgs, arg.arg_i)
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[])
elseif arg.cc != GPUCompiler.BITS_REF
- if is_mixed(arg.arg_jl_i)
+ if is_mixed(arg.arg_jl_i)
push!(boxedArgs, arg.arg_i)
push!(raisedArgs, arg.arg_i)
push!(wrapper_types, LLVM.PointerType(typ, Derived))
@@ -4033,7 +4036,7 @@ function lower_convention(
end
else
# bits ref, and not boxed
- if is_mixed(arg.arg_jl_i)
+ if is_mixed(arg.arg_jl_i)
push!(boxedArgs, arg.arg_i)
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")])
@@ -4106,22 +4109,22 @@ function lower_convention(
end
for arg in args
parm = ops[arg.codegen.i]
- if arg.arg_i in removedRoots
- if arg.rooted_arg_i in loweredArgs
- nops[end] = recombine_value!(builder, nops[end], parm)
- elseif arg.rooted_arg_i in raisedArgs
- jltype = convert(LLVMType, arg.rooted_typ)
- tracked = CountTrackedPointers(jltype)
- @assert tracked.count > 0
- @assert !tracked.all
- root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
- move_sret_tofrom_roots!(builder, jltype, nops[end], root_ty, parm, RootPointerToSRetPointer)
- else
- @assert false
- end
- elseif (arg.arg_i) in removedRoots && (arg.rooted_arg_i in loweredArgs || arg)
- continue
- elseif arg.arg_i in loweredArgs
+ if arg.arg_i in removedRoots
+ if arg.rooted_arg_i in loweredArgs
+ nops[end] = recombine_value!(builder, nops[end], parm)
+ elseif arg.rooted_arg_i in raisedArgs
+ jltype = convert(LLVMType, arg.rooted_typ)
+ tracked = CountTrackedPointers(jltype)
+ @assert tracked.count > 0
+ @assert !tracked.all
+ root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
+ move_sret_tofrom_roots!(builder, jltype, nops[end], root_ty, parm, RootPointerToSRetPointer)
+ else
+ @assert false
+ end
+ elseif (arg.arg_i) in removedRoots && (arg.rooted_arg_i in loweredArgs || arg)
+ continue
+ elseif arg.arg_i in loweredArgs
push!(nops, load!(builder, convert(LLVMType, arg.typ), parm))
elseif arg.arg_i in raisedArgs
obj = emit_allocobj!(builder, arg.typ, "raisedArg")
@@ -4131,10 +4134,10 @@ function lower_convention(
LLVM.PointerType(value_type(parm), addrspace(value_type(obj))),
)
store!(builder, parm, bc)
- if !(arg.arg_i in removedRoots)
+ if !(arg.arg_i in removedRoots)
emit_writebarrier!(builder, get_julia_inner_types(builder, obj, parm))
- end
- addr = addrspacecast!(
+ end
+ addr = addrspacecast!(
builder,
bc,
LLVM.PointerType(value_type(parm), Derived),
@@ -4178,7 +4181,7 @@ function lower_convention(
wrapper_args = Vector{LLVM.Value}()
sretPtr = nothing
- retRootPtr = nothing
+ retRootPtr = nothing
dl = string(LLVM.datalayout(LLVM.parent(entry_f)))
if sret
if !in(0, parmsRemoved)
@@ -4213,39 +4216,39 @@ function lower_convention(
end
# perform argument conversions
- wrapper_idx = 1
+ wrapper_idx = 1
for arg in args
parm = parameters(entry_f)[arg.codegen.i]
- if arg.arg_i in removedRoots
- wrapparm = parameters(wrapper_f)[wrapper_idx - 1]
- root_ty = convert(LLVMType, arg.typ)
- ptr = alloca!(builder, root_ty, LLVM.name(parm)*".innerparm")
+ if arg.arg_i in removedRoots
+ wrapparm = parameters(wrapper_f)[wrapper_idx - 1]
+ root_ty = convert(LLVMType, arg.typ)
+ ptr = alloca!(builder, root_ty, LLVM.name(parm) * ".innerparm")
if TT !== nothing && TT.parameters[arg.arg_jl_i] <: Const
metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[])
end
-
+
ctx = LLVM.context(entry_f)
- typeTree = copy(typetree(arg.typ, ctx, dl, seen))
+ typeTree = copy(typetree(arg.typ, ctx, dl, seen))
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
only!(typeTree, -1)
metadata(ptr)["enzyme_type"] = to_md(typeTree, ctx)
-
- if arg.arg_i-1 in loweredArgs
- extract_roots_from_value!(builder, wrapparm, ptr)
- else
- @assert (arg.arg_i - 1) in boxedArgs
- @assert is_mixed(arg.arg_jl_i)
- jltype = convert(LLVMType, arg.rooted_typ)
- move_sret_tofrom_roots!(builder, jltype, wrapparm, root_ty, ptr, SRetPointerToRootPointer)
- end
+
+ if arg.arg_i - 1 in loweredArgs
+ extract_roots_from_value!(builder, wrapparm, ptr)
+ else
+ @assert (arg.arg_i - 1) in boxedArgs
+ @assert is_mixed(arg.arg_jl_i)
+ jltype = convert(LLVMType, arg.rooted_typ)
+ move_sret_tofrom_roots!(builder, jltype, wrapparm, root_ty, ptr, SRetPointerToRootPointer)
+ end
push!(wrapper_args, ptr)
- continue
- end
+ continue
+ end
- wrapparm = parameters(wrapper_f)[wrapper_idx]
- wrapper_idx += 1
- if arg.arg_i in loweredArgs
+ wrapparm = parameters(wrapper_f)[wrapper_idx]
+ wrapper_idx += 1
+ if arg.arg_i in loweredArgs
# copy the argument value to a stack slot, and reference it.
ty = value_type(parm)
if !isa(ty, LLVM.PointerType)
@@ -4285,7 +4288,7 @@ function lower_convention(
),
)
push!(
- parameter_attributes(wrapper_f, wrapper_idx - 1),
+ parameter_attributes(wrapper_f, wrapper_idx - 1),
StringAttribute(
"enzymejl_parmtype",
string(convert(UInt, unsafe_to_pointer(arg.typ))),
@@ -4306,7 +4309,7 @@ function lower_convention(
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
only!(typeTree, -1)
push!(
- parameter_attributes(wrapper_f, wrapper_idx - 1),
+ parameter_attributes(wrapper_f, wrapper_idx - 1),
StringAttribute(
"enzyme_type",
string(typeTree),
@@ -4330,7 +4333,7 @@ function lower_convention(
push!(wrapper_args, wrapparm)
for attr in collect(parameter_attributes(entry_f, arg.codegen.i))
push!(
- parameter_attributes(wrapper_f, wrapper_idx - 1),
+ parameter_attributes(wrapper_f, wrapper_idx - 1),
attr,
)
end
@@ -4467,13 +4470,13 @@ function lower_convention(
string(UInt(GPUCompiler.BITS_REF)),
),
)
- res = load!(builder, RT, sretPtr)
- @static if VERSION >= v"1.12"
- if returnRoots
- res = recombine_value!(builder, res, retRootPtr)
- end
- end
- ret!(builder, res)
+ res = load!(builder, RT, sretPtr)
+ @static if VERSION >= v"1.12"
+ if returnRoots
+ res = recombine_value!(builder, res, retRootPtr)
+ end
+ end
+ ret!(builder, res)
end
elseif LLVM.return_type(entry_ft) == LLVM.VoidType()
ret!(builder)
@@ -4627,25 +4630,25 @@ function lower_convention(
println(io, string(wrapper_f))
println(
io,
- "TT=$TT\n",
+ "TT=$TT\n",
"parmsRemoved=",
parmsRemoved,
"\nretRemoved=",
retRemoved,
"\nprargs=",
prargs,
- "\nreturnRoots=",
- returnRoots,
- "\nboxedArgs=",
- boxedArgs,
- "\nloweredArgs=",
- loweredArgs,
- "\nraisedArgs=",
- raisedArgs,
- "\nremovedRoots=",
- removedRoots,
- "\nloweredReturn=",
- loweredReturn
+ "\nreturnRoots=",
+ returnRoots,
+ "\nboxedArgs=",
+ boxedArgs,
+ "\nloweredArgs=",
+ loweredArgs,
+ "\nraisedArgs=",
+ raisedArgs,
+ "\nremovedRoots=",
+ removedRoots,
+ "\nloweredReturn=",
+ loweredReturn
)
println(io, "Broken lower convention")
end
@@ -4736,8 +4739,8 @@ function lower_convention(
LLVM.@dispose pb = NewPMPassBuilder() begin
add!(pb, NewPMModulePassManager()) do mpm
# Kill the temporary staging function
- add!(mpm, GlobalDCEPass())
- add!(mpm, GlobalOptPass())
+ add!(mpm, GlobalDCEPass())
+ add!(mpm, GlobalOptPass())
end
LLVM.run!(pb, mod)
end
@@ -4963,7 +4966,7 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
end
end
GPUCompiler.@safe_warn "Using fallback BLAS replacements for ($found), performance may be degraded"
- run!(GlobalOptPass(), mod)
+ run!(GlobalOptPass(), mod)
end
custom, state = set_module_types!(interp, mod, primalf, job, edges, params.run_enzyme, mode)
@@ -5267,16 +5270,16 @@ end
)
)
- size = Compiler.datatype_layoutsize(jTy)
+ size = Compiler.datatype_layoutsize(jTy)
if offset < size && isa(sz, LLVM.ConstantInt) && size - offset >= convert(Int, sz)
lim = convert(Int, sz)
md = to_fullmd(jTy, offset, lim)
@assert byref == GPUCompiler.BITS_REF ||
byref == GPUCompiler.MUT_REF
metadata(inst)["enzyme_truetype"] = md
- elseif byref == GPUCompiler.BITS_VALUE && jTy <: Ptr && eltype(jTy) == Any
- # Todo generalize this
- md = to_fullmd(jTy, 0, sizeof(Ptr{Cvoid}))
+ elseif byref == GPUCompiler.BITS_VALUE && jTy <: Ptr && eltype(jTy) == Any
+ # Todo generalize this
+ md = to_fullmd(jTy, 0, sizeof(Ptr{Cvoid}))
metadata(inst)["enzyme_truetype"] = md
end
end
@@ -5394,9 +5397,9 @@ end
nm == "ijl_new_array" ||
nm == "jl_new_array" ||
nm == "jl_alloc_genericmemory" ||
- nm == "ijl_alloc_genericmemory" ||
- nm == "jl_alloc_genericmemory_unchecked" ||
- nm == "ijl_alloc_genericmemory_unchecked"
+ nm == "ijl_alloc_genericmemory" ||
+ nm == "jl_alloc_genericmemory_unchecked" ||
+ nm == "ijl_alloc_genericmemory_unchecked"
continue
end
if is_readonly(called)
@@ -5470,7 +5473,7 @@ end
expectedTapeType,
loweredArgs,
boxedArgs,
- removedRoots,
+ removedRoots,
)
toremove = String[]
# Inline the wrapper
@@ -5871,7 +5874,7 @@ const DumpLLVMCall = Ref(false)
error("Return type `$rrt` not marked Const, but is ghost or const type.")
end
- needs_rooting = false
+ needs_rooting = false
sret_types = Type[] # Julia types of all returned variables
# By ref values we create and need to preserve
@@ -6139,14 +6142,14 @@ const DumpLLVMCall = Ref(false)
end
# calls fptr
- llvmtys = LLVMType[]
- for x in types
- push!(llvmtys, convert(LLVMType, x; allow_boxed = true))
- arg_roots = inline_roots_type(x)
- if needs_rooting && arg_roots != 0
- push!(llvmtys, convert(LLVMType, AnyArray(3)))
- end
- end
+ llvmtys = LLVMType[]
+ for x in types
+ push!(llvmtys, convert(LLVMType, x; allow_boxed = true))
+ arg_roots = inline_roots_type(x)
+ if needs_rooting && arg_roots != 0
+ push!(llvmtys, convert(LLVMType, AnyArray(3)))
+ end
+ end
T_void = convert(LLVMType, Nothing)
@@ -6209,11 +6212,11 @@ const DumpLLVMCall = Ref(false)
tape = callparams[end]
if TapeType <: EnzymeTapeToLoad
llty = Compiler.from_tape_type(eltype(TapeType))
-
- arg_roots = inline_roots_type(llty)
- if needs_rooting && arg_roots != 0
- throw(AssertionError("Should check about rooted tape calling conv"))
- end
+
+ arg_roots = inline_roots_type(llty)
+ if needs_rooting && arg_roots != 0
+ throw(AssertionError("Should check about rooted tape calling conv"))
+ end
tape = bitcast!(
builder,
@@ -6226,13 +6229,13 @@ const DumpLLVMCall = Ref(false)
else
llty = Compiler.from_tape_type(TapeType)
- arg_roots = inline_roots_type(llty)
- if needs_rooting && arg_roots != 0
- tape = callparams[end-1]
- end
- if value_type(tape) != llty
- throw(AssertionError("MisMatched Tape type, expected $(string(value_type(tape))) found $(string(llty)) from $TapeType arg_roots=$arg_roots"))
- end
+ arg_roots = inline_roots_type(llty)
+ if needs_rooting && arg_roots != 0
+ tape = callparams[end - 1]
+ end
+ if value_type(tape) != llty
+ throw(AssertionError("MisMatched Tape type, expected $(string(value_type(tape))) found $(string(llty)) from $TapeType arg_roots=$arg_roots"))
+ end
end
end
@@ -6275,9 +6278,9 @@ const DumpLLVMCall = Ref(false)
end
reinsert_gcmarker!(llvm_f)
- if DumpLLVMCall[]
- API.EnzymeDumpModuleRef(mod.ref)
- end
+ if DumpLLVMCall[]
+ API.EnzymeDumpModuleRef(mod.ref)
+ end
ir = string(mod)
fn = LLVM.name(llvm_f)
diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl
index 13518d4a..744a2d10 100644
--- a/src/compiler/optimize.jl
+++ b/src/compiler/optimize.jl
@@ -158,16 +158,16 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
run!(pb, mod, tm)
end
end # middle_optimize!
-
- run!(GCInvariantVerifierPass(strong=false), mod)
+
+ run!(GCInvariantVerifierPass(strong = false), mod)
middle_optimize!()
-
- run!(GCInvariantVerifierPass(strong=false), mod)
-
+
+ run!(GCInvariantVerifierPass(strong = false), mod)
+
middle_optimize!(true)
-
- run!(GCInvariantVerifierPass(strong=false), mod)
+
+ run!(GCInvariantVerifierPass(strong = false), mod)
# Globalopt is separated as it can delete functions, which invalidates the Julia hardcoded pointers to
# known functions
@@ -185,20 +185,20 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
end
run!(pb, mod, tm)
end
-
- run!(GCInvariantVerifierPass(strong=false), mod)
-
+
+ run!(GCInvariantVerifierPass(strong = false), mod)
+
removeDeadArgs!(mod, tm)
-
- run!(GCInvariantVerifierPass(strong=false), mod)
+
+ run!(GCInvariantVerifierPass(strong = false), mod)
detect_writeonly!(mod)
-
- run!(GCInvariantVerifierPass(strong=false), mod)
-
+
+ run!(GCInvariantVerifierPass(strong = false), mod)
+
nodecayed_phis!(mod)
-
- run!(GCInvariantVerifierPass(strong=false), mod)
+
+ return run!(GCInvariantVerifierPass(strong = false), mod)
end
function addOptimizationPasses!(mpm::LLVM.NewPMPassManager)
diff --git a/src/errors.jl b/src/errors.jl
index 81a39046..22f553cf 100644
--- a/src/errors.jl
+++ b/src/errors.jl
@@ -184,11 +184,11 @@ function Base.showerror(io::IO, ece::CallingConventionMismatchError)
println(io)
- if VERBOSE_ERRORS[]
+ return if VERBOSE_ERRORS[]
if ece.backtrace isa Cstring
- Base.println(io, Base.unsafe_string(ece.backtrace))
+ Base.println(io, Base.unsafe_string(ece.backtrace))
else
- Base.println(io, ece.backtrace)
+ Base.println(io, ece.backtrace)
end
else
print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n")
@@ -1252,41 +1252,41 @@ else
end
end
end
-
- if isa(cur, LLVM.LoadInst)
- larg, off = get_base_and_offset(operands(cur)[1])
- if off == 0 && isa(larg, LLVM.AllocaInst)
- legal = true
- for u in LLVM.uses(larg)
- u = LLVM.user(u)
- if isa(u, LLVM.LoadInst)
- continue
- end
- if isa(u, LLVM.CallInst) && isa(called_operand(u), LLVM.Function)
- intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(u))
- if intr == LLVM.Intrinsic("llvm.lifetime.start").id || intr == LLVM.Intrinsic("llvm.lifetime.end").id || LLVM.name(called_operand(u)) == "llvm.enzyme.lifetime_end" || LLVM.name(called_operand(u)) ==
- "llvm.enzyme.lifetime_start"
- continue
- end
- end
- if isa(u, LLVM.StoreInst)
- v = operands(u)[1]
- if v == larg
- legal = false;
- break
- end
- if v isa ConstantInt && convert(Int, v) == -1
- continue
- end
- end
- legal = false
- break
- end
- if legal
- return make_batched(ncur, prevbb)
- end
- end
- end
+
+ if isa(cur, LLVM.LoadInst)
+ larg, off = get_base_and_offset(operands(cur)[1])
+ if off == 0 && isa(larg, LLVM.AllocaInst)
+ legal = true
+ for u in LLVM.uses(larg)
+ u = LLVM.user(u)
+ if isa(u, LLVM.LoadInst)
+ continue
+ end
+ if isa(u, LLVM.CallInst) && isa(called_operand(u), LLVM.Function)
+ intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(u))
+ if intr == LLVM.Intrinsic("llvm.lifetime.start").id || intr == LLVM.Intrinsic("llvm.lifetime.end").id || LLVM.name(called_operand(u)) == "llvm.enzyme.lifetime_end" || LLVM.name(called_operand(u)) ==
+ "llvm.enzyme.lifetime_start"
+ continue
+ end
+ end
+ if isa(u, LLVM.StoreInst)
+ v = operands(u)[1]
+ if v == larg
+ legal = false
+ break
+ end
+ if v isa ConstantInt && convert(Int, v) == -1
+ continue
+ end
+ end
+ legal = false
+ break
+ end
+ if legal
+ return make_batched(ncur, prevbb)
+ end
+ end
+ end
legal, TT, byref = abs_typeof(cur, true)
diff --git a/src/gradientutils.jl b/src/gradientutils.jl
index dec9f7d5..8a9d4fe5 100644
--- a/src/gradientutils.jl
+++ b/src/gradientutils.jl
@@ -317,7 +317,7 @@ function batch_call_same_with_inverted_arg_if_active!(
args::Vector{<:LLVM.Value},
valTys::Vector{API.CValueType},
lookup::Bool;
- need_result = true,
+ need_result = true,
kwargs...
)
@@ -340,7 +340,7 @@ function batch_call_same_with_inverted_arg_if_active!(
end
end
end
- res = call_same_with_inverted_arg_if_active!(B, gutils, orig, args2, valTys, lookup; need_result, kwargs..., movebefore=idx == 1)
+ res = call_same_with_inverted_arg_if_active!(B, gutils, orig, args2, valTys, lookup; need_result, kwargs..., movebefore = idx == 1)
if shadow === nothing
continue
end
diff --git a/src/jlrt.jl b/src/jlrt.jl
index b806ac50..ebf12393 100644
--- a/src/jlrt.jl
+++ b/src/jlrt.jl
@@ -889,9 +889,9 @@ end
function emit_type_layout_elsz!(B::LLVM.IRBuilder, @nospecialize(ty::LLVM.Value))
legal, JTy = absint(ty)
if legal
- @assert JTy isa Type
- res = Compiler.datatype_layoutsize(JTy)
- return LLVM.ConstantInt(res)
+ @assert JTy isa Type
+ res = Compiler.datatype_layoutsize(JTy)
+ return LLVM.ConstantInt(res)
end
ty = emit_layout_of_type!(B, ty)
@@ -957,20 +957,20 @@ function get_memory_len(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value))
end
if nm in (
- "jl_alloc_genericmemory",
- "ijl_alloc_genericmemory",
- )
- res = operands(array)[2]
+ "jl_alloc_genericmemory",
+ "ijl_alloc_genericmemory",
+ )
+ res = operands(array)[2]
return res
end
if nm in (
- "jl_alloc_genericmemory_unchecked",
- "ijl_alloc_genericmemory_unchecked",
- )
- # This is number of bytes not number of elements
- res = get_memory_size(B, array)
- es = get_memory_elsz(B, array)
- return udiv!(B, res, es)
+ "jl_alloc_genericmemory_unchecked",
+ "ijl_alloc_genericmemory_unchecked",
+ )
+ # This is number of bytes not number of elements
+ res = get_memory_size(B, array)
+ es = get_memory_elsz(B, array)
+ return udiv!(B, res, es)
end
end
ST = get_memory_struct()
@@ -992,23 +992,23 @@ end
# nel - number of elements
#
-@static if VERSION >= v"1.11"
-function get_memory_nbytes(B::LLVM.IRBuilder, memty::Type{<:Memory}, nel::LLVM.Value)
- elsz = LLVM.ConstantInt(Compiler.datatype_layoutsize(memty))
- isboxed = Base.datatype_arrayelem(memty) == 1
- isunion = Base.datatype_arrayelem(memty) == 2
-
- if isboxed
- elsz = LLVM.ConstantInt(sizeof(Ptr{Cvoid}))
- end
- nbytes = LLVM.mul!(B, nel, elsz)
+@static if VERSION >= v"1.11"
+ function get_memory_nbytes(B::LLVM.IRBuilder, memty::Type{<:Memory}, nel::LLVM.Value)
+ elsz = LLVM.ConstantInt(Compiler.datatype_layoutsize(memty))
+ isboxed = Base.datatype_arrayelem(memty) == 1
+ isunion = Base.datatype_arrayelem(memty) == 2
+
+ if isboxed
+ elsz = LLVM.ConstantInt(sizeof(Ptr{Cvoid}))
+ end
+ nbytes = LLVM.mul!(B, nel, elsz)
- if isunion
- # an extra byte for each isbits union memory element, stored at m->ptr + m->length
- nbytes = LLVM.add!(B, nbytes, nel)
+ if isunion
+ # an extra byte for each isbits union memory element, stored at m->ptr + m->length
+ nbytes = LLVM.add!(B, nbytes, nel)
+ end
+ return nbytes
end
- return nbytes
-end
end
function get_memory_nbytes(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value))
@@ -1019,12 +1019,12 @@ function get_memory_nbytes(B::LLVM.IRBuilder, @nospecialize(array::LLVM.Value))
nm = LLVM.name(fn)
end
if nm in (
- "jl_alloc_genericmemory_unchecked",
- "ijl_alloc_genericmemory_unchecked",
- )
- # This is number of bytes not number of elements
- res = operands(array)[2]
- return res
+ "jl_alloc_genericmemory_unchecked",
+ "ijl_alloc_genericmemory_unchecked",
+ )
+ # This is number of bytes not number of elements
+ res = operands(array)[2]
+ return res
end
end
nel = get_memory_len(B, array)
diff --git a/src/llvm/attributes.jl b/src/llvm/attributes.jl
index c384a2b6..b0eab84b 100644
--- a/src/llvm/attributes.jl
+++ b/src/llvm/attributes.jl
@@ -747,14 +747,14 @@ function annotate!(mod::LLVM.Module)
"ijl_gc_alloc_typed",
"jl_alloc_genericmemory",
"ijl_alloc_genericmemory",
- "jl_alloc_genericmemory_unchecked",
- "ijl_alloc_genericmemory_unchecked",
+ "jl_alloc_genericmemory_unchecked",
+ "ijl_alloc_genericmemory_unchecked",
"jl_alloc_array_1d",
"jl_alloc_array_2d",
"jl_alloc_array_3d",
"ijl_alloc_array_1d",
"ijl_alloc_array_2d",
- "ijl_alloc_array_3d",
+ "ijl_alloc_array_3d",
"ijl_new_array",
"jl_new_array"
)
@@ -808,8 +808,8 @@ function annotate!(mod::LLVM.Module)
"ijl_box_int64",
"jl_alloc_genericmemory",
"ijl_alloc_genericmemory",
- "jl_alloc_genericmemory_unchecked",
- "ijl_alloc_genericmemory_unchecked",
+ "jl_alloc_genericmemory_unchecked",
+ "ijl_alloc_genericmemory_unchecked",
"jl_alloc_array_1d",
"jl_alloc_array_2d",
"jl_alloc_array_3d",
@@ -821,7 +821,7 @@ function annotate!(mod::LLVM.Module)
"jl_genericmemory_slice",
"ijl_genericmemory_slice",
"jl_genericmemory_copy_slice",
- "ijl_genericmemory_copy_slice",
+ "ijl_genericmemory_copy_slice",
"jl_idtable_rehash",
"ijl_idtable_rehash",
"jl_f_tuple",
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index 5bd0124a..bf7416b5 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -2623,9 +2623,9 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine)
add!(fpm, AllocOptPass())
add!(fpm, SROAPass())
end
- if RunAttributor[]
+ if RunAttributor[]
add!(mpm, EnzymeAttributorPass())
- end
+ end
add!(mpm, NewPMFunctionPassManager()) do fpm
add!(fpm, EarlyCSEPass())
end
diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl
index b0094dc9..7a87f26b 100644
--- a/src/rules/activityrules.jl
+++ b/src/rules/activityrules.jl
@@ -79,7 +79,7 @@ function julia_activity_rule(f::LLVM.Function, method_table)
typ, _ = enzyme_extract_parm_type(f, arg.codegen.i)
@assert typ == arg.typ
- if (kwarg_inactive && arg.arg_i == 2) || guaranteed_const_nongen(arg.typ, world) || (arg.rooted_typ !== nothing && guaranteed_const_nongen(arg.rooted_typ, world))
+ if (kwarg_inactive && arg.arg_i == 2) || guaranteed_const_nongen(arg.typ, world) || (arg.rooted_typ !== nothing && guaranteed_const_nongen(arg.rooted_typ, world))
push!(
parameter_attributes(f, arg.codegen.i),
StringAttribute("enzyme_inactive"),
diff --git a/src/rules/allocrules.jl b/src/rules/allocrules.jl
index c2220f99..3b6e2955 100644
--- a/src/rules/allocrules.jl
+++ b/src/rules/allocrules.jl
@@ -20,26 +20,26 @@ function array_shadow_handler(
),
)
end
-
+
b = LLVM.IRBuilder(B)
orig = LLVM.Value(OrigCI)::LLVM.CallInst
nm = LLVM.name(LLVM.called_operand(orig)::LLVM.Function)
-
+
if iszeroinit(typ)
- # If already zero init we should not need to perform the initial memset.
- # However as I have yet to actually see such a type exist in the wild, I want to see
- # what triggers it.
+ # If already zero init we should not need to perform the initial memset.
+ # However as I have yet to actually see such a type exist in the wild, I want to see
+ # what triggers it.
throw(
AssertionError(
- "THERE IS A TYPE WHICH IS ZERO INIT ($typ)",
+ "THERE IS A TYPE WHICH IS ZERO INIT ($typ)",
),
)
- # Only the regular, checked version does the zero.
- if nm == "jl_alloc_genericmemory" || nm == "ijl_alloc_genericmemory"
- return C_NULL
- end
+ # Only the regular, checked version does the zero.
+ if nm == "jl_alloc_genericmemory" || nm == "ijl_alloc_genericmemory"
+ return C_NULL
+ end
end
typ = eltype(typ)
@@ -54,7 +54,7 @@ function array_shadow_handler(
end
anti = call_samefunc_with_inverted_bundles!(b, gutils, orig, vals, valTys, false) #=lookup=#
-
+
isunboxed, elsz, al = Base.uniontype_layout(typ)
isunion = typ isa Union
@@ -73,17 +73,17 @@ function array_shadow_handler(
get_memory_nbytes(b, anti)
else
arlen = get_array_len(b, anti)
- tot = LLVM.mul!(b, arlen, LLVM.ConstantInt(LLVM.value_type(arlen), elsz, false))
-
- if elsz == 1 && !isunion
- # extra byte for all julia allocated byte arrays
- tot = LLVM.add!(b, tot, LLVM.ConstantInt(LLVM.value_type(tot), 1, false))
- end
- if isunion
- # an extra byte for each isbits union array element, stored after a->maxsize
- tot = LLVM.add!(b, tot, prod)
- end
- tot
+ tot = LLVM.mul!(b, arlen, LLVM.ConstantInt(LLVM.value_type(arlen), elsz, false))
+
+ if elsz == 1 && !isunion
+ # extra byte for all julia allocated byte arrays
+ tot = LLVM.add!(b, tot, LLVM.ConstantInt(LLVM.value_type(tot), 1, false))
+ end
+ if isunion
+ # an extra byte for each isbits union array element, stored after a->maxsize
+ tot = LLVM.add!(b, tot, prod)
+ end
+ tot
end
@@ -114,7 +114,7 @@ end
"jl_alloc_array_3d", "ijl_alloc_array_3d",
"jl_new_array", "ijl_new_array",
"jl_alloc_genericmemory", "ijl_alloc_genericmemory",
- "jl_alloc_genericmemory_unchecked", "ijl_alloc_genericmemory_unchecked"
+ "jl_alloc_genericmemory_unchecked", "ijl_alloc_genericmemory_unchecked",
),
@cfunction(
array_shadow_handler,
diff --git a/src/typeutils/jltypes.jl b/src/typeutils/jltypes.jl
index 72146f2e..3026ea65 100644
--- a/src/typeutils/jltypes.jl
+++ b/src/typeutils/jltypes.jl
@@ -2,14 +2,14 @@
iszeroinit(Base.@nospecialize t) = (Base.@_total_meta; isa(t, DataType) && (t.flags & 0x0004) == 0x0004)
@static if VERSION >= v"1.11"
-const datatype_layoutsize = Base.datatype_layoutsize
+ const datatype_layoutsize = Base.datatype_layoutsize
else
-function datatype_layoutsize(dt::Base.DataType)
- Base.@_foldable_meta
- dt.layout == C_NULL && throw(Base.UndefRefError())
- size = unsafe_load(convert(Ptr{Base.DataTypeLayout}, dt.layout)).size
- return size % Int
-end
+ function datatype_layoutsize(dt::Base.DataType)
+ Base.@_foldable_meta
+ dt.layout == C_NULL && throw(Base.UndefRefError())
+ size = unsafe_load(convert(Ptr{Base.DataTypeLayout}, dt.layout)).size
+ return size % Int
+ end
end
# On 1.12+, there was a change to the calling convention where
@@ -17,38 +17,38 @@ end
# return the number of roots in the corresponding convention, or
# 0 if it does not apply https://github.com/JuliaLang/julia/pull/55767/files#diff-62cfb2606c6a323a7f26a3eddfa0bf2b819fa33e094561fee09daeb328e3a1e7
function inline_roots_type(@nospecialize(LT::LLVM.LLVMType))::Int
- @static if VERSION <= v"1.12-"
- return 0
- else
- if !(LT isa LLVM.ArrayType || LT isa LLVM.StructType)
- return 0
- end
- tracked = CountTrackedPointers(LT)
- if tracked.count > 0 && !tracked.all
- return tracked.count
- end
- return 0
- end
+ @static if VERSION <= v"1.12-"
+ return 0
+ else
+ if !(LT isa LLVM.ArrayType || LT isa LLVM.StructType)
+ return 0
+ end
+ tracked = CountTrackedPointers(LT)
+ if tracked.count > 0 && !tracked.all
+ return tracked.count
+ end
+ return 0
+ end
end
function inline_roots_type(@nospecialize(T::Type))::Int
- @static if VERSION <= v"1.12-"
- return 0
- else
- if T === Union{}
- return 0
- end
- if GPUCompiler.deserves_argbox(T)
- return 0
- end
- if Base.isabstracttype(T)
- return 0
- end
- if isghostty(T) || Core.Compiler.isconstType(T)
- return 0
- end
- LT = convert(LLVM.LLVMType, T)
- return inline_roots_type(LT)
+ @static if VERSION <= v"1.12-"
+ return 0
+ else
+ if T === Union{}
+ return 0
+ end
+ if GPUCompiler.deserves_argbox(T)
+ return 0
+ end
+ if Base.isabstracttype(T)
+ return 0
+ end
+ if isghostty(T) || Core.Compiler.isconstType(T)
+ return 0
+ end
+ LT = convert(LLVM.LLVMType, T)
+ return inline_roots_type(LT)
end
end
@@ -56,35 +56,34 @@ end
# with the AnyArray's as requisite for the new roots for the calling convention
# on 1.12
function rooted_argument_list(iterable)
- results = Tuple{Type, Union{Nothing, Type}}[]
- for T in iterable
- roots = inline_roots_type(T)
- push!(results, (T, nothing))
- if roots != 0
- push!(results, (AnyArray(roots), T))
- end
- end
- return results
+ results = Tuple{Type, Union{Nothing, Type}}[]
+ for T in iterable
+ roots = inline_roots_type(T)
+ push!(results, (T, nothing))
+ if roots != 0
+ push!(results, (AnyArray(roots), T))
+ end
+ end
+ return results
end
function split_value_into(B::LLVM.IRBuilder, val::LLVM.Value)
- LT = value_type(val)
- tracked = CountTrackedPointers(LT)
- @assert tracked.count > 0
- @assert !tracked.all
- RT = convert(LLVM.LLVMType, AnyArray(tracked.count))
- al = alloca!(B, RT)
- fdsafdsa
- return (val, al)
+ LT = value_type(val)
+ tracked = CountTrackedPointers(LT)
+ @assert tracked.count > 0
+ @assert !tracked.all
+ RT = convert(LLVM.LLVMType, AnyArray(tracked.count))
+ al = alloca!(B, RT)
+ fdsafdsa
+ return (val, al)
end
function recombine_value(B::LLVM.IRBuilder, val::LLVM.Value, roots::LLVM.Value)
- TODO
+ TODO
return val
end
-
struct RemovedParam end
# Modified from GPUCompiler classify_arguments
@@ -123,30 +122,36 @@ function classify_arguments(
last_cc = nothing
arg_jl_i = 1
for (source_i, (source_typ, rooted_typ)) in enumerate(rooted_argument_list(source_sig.parameters))
- if rooted_typ !== nothing
- arg_jl_i -= 1
- end
+ if rooted_typ !== nothing
+ arg_jl_i -= 1
+ end
if isghostty(source_typ) || Core.Compiler.isconstType(source_typ)
- push!(args, (cc = GPUCompiler.GHOST, typ = source_typ, arg_i = source_i,
- rooted_typ = rooted_typ,
- rooted_arg_i = rooted_typ === nothing ? nothing : (source_i - 1),
- rooted_cc = rooted_typ === nothing ? nothing : last_cc,
- arg_jl_i = arg_jl_i,
- ...*[Comment body truncated]* |
Benchmark Results
Benchmark PlotsA plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/19386158803/artifacts/4576044989. |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2782 +/- ##
==========================================
- Coverage 69.01% 68.52% -0.50%
==========================================
Files 58 58
Lines 19996 20335 +339
==========================================
+ Hits 13800 13934 +134
- Misses 6196 6401 +205 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
didn't test yet if worked, just did a sweep through parts of the code.
custom derivatives also need a sweep