Skip to content

Fix hessian sret type #2824

Open
wsmoses wants to merge 8 commits into main from
hconv
Open

Fix hessian sret type #2824
wsmoses wants to merge 8 commits into main from
hconv

Conversation

@wsmoses
Copy link
Member

@wsmoses wsmoses commented Dec 1, 2025

@github-actions
Copy link
Contributor

github-actions bot commented Dec 1, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/compiler.jl b/src/compiler.jl
index d25ddc22..ce7f9969 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -1080,15 +1080,15 @@ end
                     EnumAttribute("willreturn"),
                     EnumAttribute("nosync"),
                     EnumAttribute("nofree"),
-           	    StringAttribute("enzyme_preserve_primal", "*"),
+            StringAttribute("enzyme_preserve_primal", "*"),
 		      ]
     else
         LLVM.Attribute[EnumAttribute("memory", NoEffects.data), StringAttribute("enzyme_shouldrecompute"),
                     EnumAttribute("willreturn"),
                     EnumAttribute("nosync"),
-		    EnumAttribute("nofree"),
-           	    StringAttribute("enzyme_preserve_primal", "*"),
-		    ]
+            EnumAttribute("nofree"),
+            StringAttribute("enzyme_preserve_primal", "*"),
+        ]
     end
     handleCustom(state, custom, k_name, llvmfn, name, attrs)
     return
@@ -6765,11 +6765,11 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri
             for f in functions(mod)
                 for i in 1:length(parameters(f))
                     for a in collect(parameter_attributes(f, i))
-                       if kind(a) == "enzyme_sret"
-                           API.EnzymeDumpValueRef(f)
-                       end
-                       @assert kind(a) != "enzyme_sret"
-                       @assert kind(a) != "enzyme_sret_v"
+                        if kind(a) == "enzyme_sret"
+                            API.EnzymeDumpValueRef(f)
+                        end
+                        @assert kind(a) != "enzyme_sret"
+                        @assert kind(a) != "enzyme_sret_v"
                     end
                 end
             end
@@ -6779,7 +6779,7 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri
             if DumpPrePostOpt[]
                 API.EnzymeDumpModuleRef(mod.ref)
             end
-            post_optimize!(mod, JIT.get_tm(); callconv=false)
+            post_optimize!(mod, JIT.get_tm(); callconv = false)
             if DumpPostOpt[]
                 API.EnzymeDumpModuleRef(mod.ref)
             end
diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl
index ec2e8d34..6fbe8b6b 100644
--- a/src/compiler/optimize.jl
+++ b/src/compiler/optimize.jl
@@ -392,7 +392,7 @@ const DumpPostCallConv = Ref(false)
 function fixup_callconv!(mod::LLVM.Module, tm::LLVM.TargetMachine)
     addr13NoAlias(mod)
     
-    removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)
+    removeDeadArgs!(mod, tm, #=post_gc_fixup=# false)
 
     memcpy_sret_split!(mod)
     # if we did the move_sret_tofrom_roots, we will have loaded out of the sret, then stored into the rooted.
@@ -448,17 +448,17 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool
     if callconv
         fixup_callconv!(mod, tm)
     end
-    
+
     for f in functions(mod)
-	if isempty(blocks(f))
-		continue
-	end
-	if has_fn_attr(f, StringAttribute("enzyme_preserve_primal"))
-	     delete!(LLVM.function_attributes(f), StringAttribute("enzyme_preserve_primal"))
-	end
+        if isempty(blocks(f))
+            continue
+        end
+        if has_fn_attr(f, StringAttribute("enzyme_preserve_primal"))
+            delete!(LLVM.function_attributes(f), StringAttribute("enzyme_preserve_primal"))
+        end
     end
 
-    removeDeadArgs!(mod, tm, #=post_gc_fixup=#true)
+    removeDeadArgs!(mod, tm, #=post_gc_fixup=# true)
 
     @dispose pb = NewPMPassBuilder() begin
         registerEnzymeAndPassPipeline!(pb)
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index 5d488bb9..ec7c12db 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -2619,16 +2619,16 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup
             )
                 for u in LLVM.uses(fn)
                     u = LLVM.user(u)
-		    if !isa(u, LLVM.CallInst)
-                    	# TODO investigate if the inttoptr store that comes from reference caller poses an issue.
-			continue
-			msg = sprint() do io
-			   println(io, "Unknown user of fn: ", string(u))
-			   println(io, "fn: ", string(fn))
-			   println(io, "mod: ", string(LLVM.parent(fn)))
-			end
-			throw(AssertionError(msg))
-		    end
+                        if !isa(u, LLVM.CallInst)
+                            # TODO investigate if the inttoptr store that comes from reference caller poses an issue.
+                            continue
+                            msg = sprint() do io
+                                println(io, "Unknown user of fn: ", string(u))
+                                println(io, "fn: ", string(fn))
+                                println(io, "mod: ", string(LLVM.parent(fn)))
+                            end
+                            throw(AssertionError(msg))
+                        end
                     B = IRBuilder()
                     nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u))
                     position!(B, nextInst)
@@ -2665,26 +2665,26 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup
                 for u in LLVM.uses(fn)
                     u = LLVM.user(u)
                     if isa(u, LLVM.ConstantExpr)
-			for u in LLVM.uses(u)
-			   u = LLVM.user(u)
-			    if !isa(u, LLVM.CallInst)
-				continue
-			    end
-			    @assert isa(u, LLVM.CallInst)
-			    B = IRBuilder()
-			    nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u))
-			    position!(B, nextInst)
-			    inp = operands(u)[idx]
-			    cl = call!(B, funcT, sfunc, LLVM.Value[inp])
-			    if isa(value_type(inp), LLVM.PointerType)
-				LLVM.API.LLVMAddCallSiteAttribute(
-				    cl,
-				    LLVM.API.LLVMAttributeIndex(1),
-				    EnumAttribute("nocapture"),
-				)
-			    end
-			end
-			continue
+                        for u in LLVM.uses(u)
+                            u = LLVM.user(u)
+                            if !isa(u, LLVM.CallInst)
+                                continue
+                            end
+                            @assert isa(u, LLVM.CallInst)
+                            B = IRBuilder()
+                            nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u))
+                            position!(B, nextInst)
+                            inp = operands(u)[idx]
+                            cl = call!(B, funcT, sfunc, LLVM.Value[inp])
+                            if isa(value_type(inp), LLVM.PointerType)
+                                LLVM.API.LLVMAddCallSiteAttribute(
+                                    cl,
+                                    LLVM.API.LLVMAttributeIndex(1),
+                                    EnumAttribute("nocapture"),
+                                )
+                            end
+                        end
+                        continue
                     end
                     if !isa(u, LLVM.CallInst)
                         continue

@codecov
Copy link

codecov bot commented Dec 1, 2025

Codecov Report

❌ Patch coverage is 72.00000% with 14 lines in your changes missing coverage. Please review.
✅ Project coverage is 66.78%. Comparing base (91c3f4f) to head (5efc980).

Files with missing lines Patch % Lines
src/llvm/transforms.jl 38.09% 13 Missing ⚠️
src/compiler.jl 92.30% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2824      +/-   ##
==========================================
- Coverage   66.88%   66.78%   -0.11%     
==========================================
  Files          58       58              
  Lines       21315    21356      +41     
==========================================
+ Hits        14257    14262       +5     
- Misses       7058     7094      +36     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@github-actions
Copy link
Contributor

github-actions bot commented Dec 1, 2025

Benchmark Results

main 5efc980... main / 5efc980...
basics/make_zero/namedtuple 0.0591 ± 0.0052 μs 0.0603 ± 0.0043 μs 0.98 ± 0.11
basics/make_zero/struct 0.257 ± 0.013 μs 0.258 ± 0.01 μs 0.995 ± 0.063
basics/overhead 3.46 ± 0.002 ns 4.03 ± 0.001 ns 0.858 ± 0.00054
basics/remake_zero!/namedtuple 0.229 ± 0.01 μs 0.245 ± 0.012 μs 0.933 ± 0.061
basics/remake_zero!/struct 0.232 ± 0.013 μs 0.234 ± 0.014 μs 0.988 ± 0.081
fold_broadcast/multidim_sum_bcast/1D 10.9 ± 0.68 μs 10.9 ± 0.32 μs 0.998 ± 0.069
fold_broadcast/multidim_sum_bcast/2D 12.2 ± 0.37 μs 12.3 ± 0.38 μs 0.99 ± 0.042
time_to_load 1.03 ± 0.016 s 1.04 ± 0.012 s 0.994 ± 0.019

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/21129099509/artifacts/5173734216.

@wsmoses
Copy link
Member Author

wsmoses commented Dec 13, 2025

@copilot isolate a MWE of whatever test is timing out

Copy link
Contributor

Copilot AI commented Dec 13, 2025

@wsmoses I've opened a new pull request, #2848, to work on those changes. Once the pull request is ready, I'll request review from you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants