Skip to content

Commit 32b362f

Browse files
committed
Cleanup, support, and simplify runtime activity in llvmrules
1 parent 3e48bb3 commit 32b362f

File tree

3 files changed

+366
-395
lines changed

3 files changed

+366
-395
lines changed

src/gradientutils.jl

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,167 @@ end
9696

9797
function set_reverse_block!(gutils::GradientUtils, block::LLVM.BasicBlock)
9898
return LLVM.BasicBlock(API.EnzymeGradientUtilsSetReverseBlock(gutils, block))
99+
end
100+
101+
function get_or_insert_conditional_execute!(fn::LLVM.Function; postprocess=nothing, cmpidx::Int = 1)
102+
FT0 = LLVM.function_type(fn)
103+
ptys = LLVM.parameters(FT0)
104+
insert!(ptys, 0, ptys[cmpidx])
105+
106+
void_rt = LLVM.return_type(FT0) == LLVM.VoidType()
107+
if !void_rt
108+
insert!(ptys, 0, LLVM.return_type(FT0))
109+
end
110+
FT = LLVM.FunctionType(LLVM.return_type(FT0), ptys, LLVM.isvararg(FT0))
111+
mod = LLVM.parent(fn)
112+
fn, _ = get_function!(mod, "julia.enzyme.conditionally_execute." * LLVM.name(FT), FT)
113+
if isempty(blocks(fn))
114+
let builder = IRBuilder()
115+
entry = BasicBlock(fn, "entry")
116+
good = BasicBlock(fn, "good")
117+
bad = BasicBlock(fn, "bad")
118+
position!(builder, entry)
119+
parms = collect(parameters(fn))
120+
121+
cmp = icmp_eq!(builder, LLVM.API.LLVMIntEQ, parms[1 + !void_rt], parms[1 + cmpidx + !void_rt])
122+
123+
br!(builder, cmp, good, bad)
124+
125+
position!(builder, good)
126+
rparms = parms[(2+!void_rt):end]
127+
res = call!(builder, FT0, fn, rparms)
128+
callconv!(res, callconv(fn))
129+
if postprocess !== nothing
130+
postprocess(builder, res, rparms)
131+
end
132+
if void_rt
133+
ret!(builder)
134+
else
135+
ret!(builder, res)
136+
end
137+
138+
position!(builder, bad)
139+
if void_rt
140+
ret!(builder)
141+
else
142+
ret!(builder, parms[1])
143+
end
144+
end
145+
push!(function_attributes(fn), EnumAttribute("alwaysinline"))
146+
end
147+
return fn
148+
end
149+
150+
"""
151+
Helper function for llvm-level rule generation. Will call the same function with inverted bundles,
152+
if arg1 isn't active
153+
"""
154+
function call_same_with_inverted_arg_if_active!(
155+
B::LLVM.IRBuilder,
156+
gutils::GradientUtils,
157+
orig::LLVM.CallInst,
158+
args::Vector{<:LLVM.Value},
159+
valTys::Vector{API.CValueType},
160+
lookup::Bool;
161+
postprocess=nothing,
162+
cmpidx::Int = 1
163+
)
164+
@assert length(args) == length(valTys)
165+
166+
origops = collect(operands(orig))
167+
if is_constant_value(gutils, origops[cmpidx])
168+
return nothing
169+
end
170+
171+
if !get_runtime_activity(gutils)
172+
res = call_samefunc_with_inverted_bundles!(
173+
B,
174+
gutils,
175+
orig,
176+
args,
177+
valTys,
178+
lookup
179+
)
180+
callconv!(res, callconv(orig))
181+
debug_from_orig!(gutils, res, orig)
182+
if postprocess !== nothing
183+
postprocess(B, res, args)
184+
end
185+
186+
return res
187+
end
188+
189+
valTys = copy(valTys)
190+
@assert valTys[cmpidx] == API.VT_Shadow
191+
valTys[cmpidx] = API.VT_Both
192+
args = copy(args)
193+
insert!(args, 1, new_from_original(gutils, origops[cmpidx]))
194+
if value_type(orig) != LLVM.VoidType()
195+
insert!(args, 1, new_from_original(gutils, orig))
196+
end
197+
condfn = get_or_insert_conditional_execute(LLVM.called_operand(orig)::LLVM.Function; postprocess, cmpidx)
198+
199+
res = LLVM.Value(
200+
API.EnzymeGradientUtilsCallWithInvertedBundles(
201+
gutils,
202+
LLVM.called_operand(condfn),
203+
LLVM.function_type(condfn),
204+
args,
205+
length(args),
206+
orig,
207+
valTys,
208+
length(valTys),
209+
B,
210+
false,
211+
),
212+
) #=lookup=#
213+
callconv!(res, callconv(orig))
214+
debug_from_orig!(gutils, res, orig)
215+
return res
216+
end
217+
218+
219+
"""
220+
Helper function for llvm-level rule generation. Will call the same function with inverted bundles,
221+
if arg1 isn't active
222+
"""
223+
function batch_call_same_with_inverted_arg_if_active!(
224+
B::LLVM.IRBuilder,
225+
gutils::GradientUtils,
226+
orig::LLVM.CallInst,
227+
args::Vector{<:LLVM.Value},
228+
valTys::Vector{API.CValueType},
229+
lookup::Bool;
230+
kwargs...
231+
)
232+
233+
void_rt = value_type(orig) ==LLVM.VoidType()
234+
shadow = if !void_rt
235+
ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))
236+
LLVM.UndefValue(ST)
237+
end
238+
239+
width = get_width(gutils)
240+
241+
for idx in 1:width
242+
args2 = args
243+
if width > 1
244+
args2 = copy(args)
245+
for i in 1:length(valTys)
246+
if valTys[i] == API.VT_Shadow
247+
args2[i] = extract_value!(B, args2[i], idx - 1)
248+
end
249+
end
250+
end
251+
res = call_same_with_inverted_arg_if_active!(B, gutils, orig, args2, valTys, lookup; kwargs...)
252+
if width == 1
253+
shadow = res
254+
elseif res === nothing || shadow == nothing
255+
shadow = nothing
256+
else
257+
shadow = insert_value!(B, shadow, res, idx - 1)
258+
end
259+
end
260+
261+
return shadow
99262
end

0 commit comments

Comments
 (0)