Conversation
Contributor
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/src/Enzyme.jl b/src/Enzyme.jl
index d93bc0b6..17ec9f1a 100644
--- a/src/Enzyme.jl
+++ b/src/Enzyme.jl
@@ -144,7 +144,7 @@ function Base.showerror(io::IO, ece::OpaquePointerError)
Base.Experimental.show_error_hints(io, ece)
end
print(io, "OpaquePointerError: Enzyme execution failed to handle opaque pointers, with the following information:\n")
- print(io, ece.msg, '\n')
+ return print(io, ece.msg, '\n')
end
diff --git a/src/compiler.jl b/src/compiler.jl
index 432a6c5c..66d46360 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -3937,7 +3937,7 @@ function lower_convention(
if returnRoots && !in(1, parmsRemoved)
retRootPtr = alloca!(
builder,
- sret_ty(entry_f, 1+sret),
+ sret_ty(entry_f, 1 + sret),
"innerreturnroots",
)
# retRootPtr = alloca!(builder, parameters(wrapper_f)[1])
diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl
index 9b120731..c4d8832b 100644
--- a/src/compiler/validation.jl
+++ b/src/compiler/validation.jl
@@ -616,9 +616,9 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
method_table = Core.Compiler.method_table(interp)
bt = backtrace(inst)
dest = called_operand(inst)
-
+
if isa(dest, LLVM.PHIInst) && all(Base.Fix1(==, operands(dest)[1]), operands(dest))
- dest = operands(dest)[1]
+ dest = operands(dest)[1]
end
if isa(dest, LLVM.ConstantExpr) && opcode(dest) == LLVM.API.LLVMIntToPtr && isa(operands(dest)[1], LLVM.ConstantExpr) && opcode(operands(dest)[1]) == LLVM.API.LLVMPtrToInt
dest = operands(operands(dest)[1])[1]
@@ -1148,7 +1148,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
else
false, nothing
end
-
+
lfn = nothing
if found
lfn = replaceWith
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index a41033c5..69799cf1 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -753,11 +753,11 @@ function nodecayed_phis!(mod::LLVM.Module)
v2 = operands(v)[1]
if addrspace(value_type(v2)) == 0
if addr == 13 && isa(v, LLVM.ConstantExpr)
- PT = if LLVM.is_opaque(value_type(v))
- LLVM.PointerType(10)
- else
- LLVM.PointerType(eltype(value_type(v)), 10)
- end
+ PT = if LLVM.is_opaque(value_type(v))
+ LLVM.PointerType(10)
+ else
+ LLVM.PointerType(eltype(value_type(v)), 10)
+ end
v2 = const_addrspacecast(
operands(v)[1],
PT
@@ -917,12 +917,12 @@ function nodecayed_phis!(mod::LLVM.Module)
undeforpoison |= isa(v, LLVM.PoisonValue)
end
if undeforpoison
- PT = if LLVM.is_opaque(value_type(v))
- LLVM.PointerType(10)
- else
- LLVM.PointerType(eltype(value_type(v)), 10)
- end
- return LLVM.UndefValue(PT), offset, addr == 13
+ PT = if LLVM.is_opaque(value_type(v))
+ LLVM.PointerType(10)
+ else
+ LLVM.PointerType(eltype(value_type(v)), 10)
+ end
+ return LLVM.UndefValue(PT), offset, addr == 13
end
if isa(v, LLVM.PHIInst) && !hasload && haskey(goffsets, v)
@@ -1241,7 +1241,7 @@ function fix_decayaddr!(mod::LLVM.Module)
mayread = false
maywrite = false
sret = true
- sret_elty = nothing
+ sret_elty = nothing
sretkind = kind(if LLVM.version().major >= 12
TypeAttribute("sret", LLVM.Int32Type())
else
@@ -1255,11 +1255,11 @@ function fix_decayaddr!(mod::LLVM.Module)
t_sret = false
for a in collect(parameter_attributes(fop, i))
if kind(a) == sretkind
- sret_elty = sret_ty(fop, i)
+ sret_elty = sret_ty(fop, i)
t_sret = true
end
if kind(a) == kind(StringAttribute("enzyme_sret"))
- sret_elty = sret_ty(fop, i)
+ sret_elty = sret_ty(fop, i)
t_sret = true
end
# if kind(a) == kind(StringAttribute("enzyme_sret_v"))
@@ -1300,7 +1300,7 @@ function fix_decayaddr!(mod::LLVM.Module)
throw(AssertionError(msg))
end
- @assert sret_elty !== nothing
+ @assert sret_elty !== nothing
if temp === nothing
nb = IRBuilder()
position!(nb, first(instructions(first(blocks(f)))))
@@ -1447,11 +1447,11 @@ function prop_global!(g::LLVM.GlobalVariable)
end
end
end
- if value_type(var) != value_type(res)
- al = alloca!(B, value_type(res))
- store!(B, res, al)
- res = load!(B, value_type(var), al)
- end
+ if value_type(var) != value_type(res)
+ al = alloca!(B, value_type(res))
+ store!(B, res, al)
+ res = load!(B, value_type(var), al)
+ end
replace_uses!(var, res)
eraseInst(LLVM.parent(var), var)
continue
@@ -1667,10 +1667,10 @@ function propagate_returned!(mod::LLVM.Module)
changed = true
end
has_user = false
- for u in LLVM.uses(fn)
- has_user = true
- break
- end
+ for u in LLVM.uses(fn)
+ has_user = true
+ break
+ end
attrs = collect(function_attributes(fn))
prevent = any(
kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for
@@ -1681,8 +1681,8 @@ function propagate_returned!(mod::LLVM.Module)
# end
argn = nothing
toremove = Int64[]
- # Don't bother with functions we're about to delete anyways
- if has_user
+ # Don't bother with functions we're about to delete anyways
+ if has_user
for (i, arg) in enumerate(parameters(fn))
if any(
kind(attr) == kind(EnumAttribute("returned")) for
@@ -1726,7 +1726,7 @@ function propagate_returned!(mod::LLVM.Module)
if !isa(ops[i], LLVM.AllocaInst) && !isa(ops[i], LLVM.UndefValue) && !isa(ops[i], LLVM.PoisonValue)
illegalUse = true
break
- end
+ end
seenfn = false
todo = LLVM.Instruction[]
if isa(ops[i], LLVM.AllocaInst)
@@ -1793,20 +1793,20 @@ function propagate_returned!(mod::LLVM.Module)
position!(B, first(instructions(first(blocks(fn)))))
- has_use = false
- for _ in LLVM.uses(arg)
- has_use = true
- break
- end
+ has_use = false
+ for _ in LLVM.uses(arg)
+ has_use = true
+ break
+ end
- if has_use
- argeltype = sret_ty(fn, i)
- al = alloca!(B, argeltype)
- if value_type(al) != value_type(arg)
- al = addrspacecast!(B, al, value_type(arg))
+ if has_use
+ argeltype = sret_ty(fn, i)
+ al = alloca!(B, argeltype)
+ if value_type(al) != value_type(arg)
+ al = addrspacecast!(B, al, value_type(arg))
+ end
+ LLVM.replace_uses!(arg, al)
end
- LLVM.replace_uses!(arg, al)
- end
end
end
@@ -1907,7 +1907,7 @@ function propagate_returned!(mod::LLVM.Module)
end
end
end
- end
+ end
illegalUse = !(
linkage(fn) == LLVM.API.LLVMInternalLinkage ||
linkage(fn) == LLVM.API.LLVMPrivateLinkage
diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl
index d568fa2a..8a9541a0 100644
--- a/src/rules/customrules.jl
+++ b/src/rules/customrules.jl
@@ -301,7 +301,7 @@ function enzyme_custom_setup_args(
LLVM.ConstantInt(LLVM.IntType(32), 0),
],
)
-
+
if !is_opaque(value_type(ptr))
@assert eltype(value_type(ptr)) == arty
end
@@ -1452,7 +1452,7 @@ function enzyme_custom_common_rev(
end
if sret !== nothing
- sty = sret_ty(llvmf, 1+swiftself)
+ sty = sret_ty(llvmf, 1 + swiftself)
if LLVM.version().major >= 12
attr = TypeAttribute("sret", sty)
else
diff --git a/src/typeutils/conversion.jl b/src/typeutils/conversion.jl
index 1dbcb893..b090667a 100644
--- a/src/typeutils/conversion.jl
+++ b/src/typeutils/conversion.jl
@@ -26,7 +26,7 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool}
if 10 <= addrspace <= 12
return Any, true
elseif LLVM.is_opaque(LLVM.PointerType(Type))
- return Core.LLVMPtr{Cvoid,Int(addrspace)}, false
+ return Core.LLVMPtr{Cvoid, Int(addrspace)}, false
else
e = LLVM.API.LLVMGetElementType(Type)
tkind2 = LLVM.API.LLVMGetTypeKind(e)
diff --git a/src/utils.jl b/src/utils.jl
index 137d7460..206fafd9 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -503,11 +503,13 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
vt = LLVM.value_type(LLVM.parameters(fn)[idx])
- sretkind = LLVM.kind(if LLVM.version().major >= 12
- LLVM.TypeAttribute("sret", LLVM.Int32Type())
- else
- LLVM.EnumAttribute("sret")
- end)
+ sretkind = LLVM.kind(
+ if LLVM.version().major >= 12
+ LLVM.TypeAttribute("sret", LLVM.Int32Type())
+ else
+ LLVM.EnumAttribute("sret")
+ end
+ )
enzymejl_parmtype_ref = nothing
@@ -537,7 +539,7 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
if ekind == "enzymejl_returnRoots"
nroots = parse(Int, LLVM.value(attr))
-
+
T_jlvalue = LLVM.StructType(LLVM.LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
@@ -549,13 +551,13 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
end
if ekind == "enzyme_sret"
- ety = parse(UInt, LLVM.value(attr))
- ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety)
- ety = LLVM.LLVMType(ety)
+ ety = parse(UInt, LLVM.value(attr))
+ ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety)
+ ety = LLVM.LLVMType(ety)
if !LLVM.is_opaque(vt)
- @assert ety == eltype(vt)
+ @assert ety == eltype(vt)
end
-
+
return ety
end
diff --git a/test/rules/internal_rules.jl b/test/rules/internal_rules.jl
index 93c2fc53..51d7ec0f 100644
--- a/test/rules/internal_rules.jl
+++ b/test/rules/internal_rules.jl
@@ -108,7 +108,7 @@ end
res = Enzyme.autodiff(Forward, f1, BatchDuplicated(0.1, (1.0, 2.0)))
@test res[1][1] ≈ 375.0
@test res[1][2] ≈ 750.0
-
+
@test Enzyme.autodiff(Forward, f2, BatchDuplicated(0.1, (1.0, 2.0))) ==
((var"1" = 25.0, var"2" = 50.0),)
@test Enzyme.autodiff(Forward, f3, BatchDuplicated(0.1, (1.0, 2.0))) == |
Contributor
Benchmark Results
Benchmark PlotsA plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/19282144045/artifacts/4537588880. |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2764 +/- ##
==========================================
+ Coverage 68.91% 68.95% +0.04%
==========================================
Files 58 58
Lines 19861 19961 +100
==========================================
+ Hits 13688 13765 +77
- Misses 6173 6196 +23 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
This was referenced Nov 12, 2025
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.