9696
"""
    set_reverse_block!(gutils::GradientUtils, block::LLVM.BasicBlock)

Forward `gutils` and `block` to `API.EnzymeGradientUtilsSetReverseBlock` and wrap the
handle it returns as an `LLVM.BasicBlock`.
"""
function set_reverse_block!(gutils::GradientUtils, block::LLVM.BasicBlock)
    blockref = API.EnzymeGradientUtilsSetReverseBlock(gutils, block)
    return LLVM.BasicBlock(blockref)
end
100+
"""
    get_or_insert_conditional_execute!(fn::LLVM.Function; postprocess=nothing, cmpidx::Int=1)

Return (creating it on first use) a wrapper around `fn` named
`"julia.enzyme.conditionally_execute." * LLVM.name(fn)` in the module that contains `fn`.

The wrapper's parameter list is, in order:

1. a value of `fn`'s return type (only when `fn` does not return `void`),
2. one extra value of the same type as `fn`'s parameter number `cmpidx`,
3. all of `fn`'s own parameters.

The wrapper compares the extra value (2) with its own copy of parameter `cmpidx` using an
integer equality comparison. On one branch it calls `fn` with the trailing parameters,
optionally runs `postprocess(builder, call, params)`, and returns the call's result; on
the other branch it returns the leading value (1), or nothing for `void` functions.

# Keywords
- `postprocess`: optional callback `(builder, call, params)` invoked right after the
  call to `fn` has been emitted.
- `cmpidx`: 1-based index (into `fn`'s parameters) of the argument that is compared.
"""
function get_or_insert_conditional_execute!(
    fn::LLVM.Function;
    postprocess=nothing,
    cmpidx::Int=1,
)
    FT0 = LLVM.function_type(fn)
    RT = LLVM.return_type(FT0)
    void_rt = RT == LLVM.VoidType()

    # Build the wrapper's parameter types. Julia vectors are 1-based, so new leading
    # entries are inserted at index 1 (index 0 throws a BoundsError).
    ptys = LLVM.parameters(FT0)
    insert!(ptys, 1, ptys[cmpidx])
    if !void_rt
        insert!(ptys, 1, RT)
    end

    FT = LLVM.FunctionType(RT, ptys; vararg=LLVM.isvararg(FT0))
    mod = LLVM.parent(fn)

    # Keep `fn` bound to the wrapped function: it is still needed as the callee below.
    condfn, _ = get_function!(
        mod,
        "julia.enzyme.conditionally_execute." * LLVM.name(fn),
        FT,
    )

    # Only emit a body the first time the wrapper is requested.
    if isempty(blocks(condfn))
        let builder = IRBuilder()
            entry = BasicBlock(condfn, "entry")
            good = BasicBlock(condfn, "good")
            bad = BasicBlock(condfn, "bad")
            position!(builder, entry)
            parms = collect(parameters(condfn))

            # Number of leading slots taken by the pre-computed result (0 or 1).
            off = void_rt ? 0 : 1

            # NOTE(review): the call is performed when the two values compare EQUAL.
            # Confirm this polarity matches the "only if active" intent of the callers.
            cmp = icmp!(
                builder,
                LLVM.API.LLVMIntEQ,
                parms[1 + off],
                parms[1 + cmpidx + off],
            )
            br!(builder, cmp, good, bad)

            # good: forward the original parameters to the wrapped function.
            position!(builder, good)
            rparms = parms[(2 + off):end]
            res = call!(builder, FT0, fn, rparms)
            callconv!(res, callconv(fn))
            if postprocess !== nothing
                postprocess(builder, res, rparms)
            end
            if void_rt
                ret!(builder)
            else
                ret!(builder, res)
            end

            # bad: skip the call and hand back the value supplied by the caller.
            position!(builder, bad)
            if void_rt
                ret!(builder)
            else
                ret!(builder, parms[1])
            end
        end
        push!(function_attributes(condfn), EnumAttribute("alwaysinline"))
    end
    return condfn
end
149+
"""
    call_same_with_inverted_arg_if_active!(B, gutils, orig, args, valTys, lookup;
                                           postprocess=nothing, cmpidx=1)

Helper function for llvm-level rule generation. Re-emits the call `orig` with `args`
and inverted bundles, unless operand `cmpidx` of `orig` is known to be constant.

Behavior:

- If `is_constant_value(gutils, operand[cmpidx])` holds, nothing is emitted and
  `nothing` is returned.
- If runtime activity is disabled, the call is emitted through
  `call_samefunc_with_inverted_bundles!`, `postprocess(B, call, args)` is run when
  given, and the new call is returned.
- Otherwise the call is routed through the wrapper produced by
  `get_or_insert_conditional_execute!`, which receives the new-function counterpart of
  operand `cmpidx` (and of `orig` itself when it is non-void) as extra leading
  arguments. `postprocess` is handed to the wrapper generator in this case.

# Arguments
- `B`: builder used to emit the call.
- `gutils`: gradient utilities for the function being differentiated.
- `orig`: the original call instruction; its callee must be an `LLVM.Function` when
  runtime activity is enabled.
- `args`, `valTys`: arguments for the new call and their value kinds; must have equal
  length. With runtime activity, `valTys[cmpidx]` must be `API.VT_Shadow`.
- `lookup`: forwarded to `call_samefunc_with_inverted_bundles!`; the runtime-activity
  path always passes `false`.

Returns the emitted call as an `LLVM.Value`, or `nothing` if the operand is constant.
"""
function call_same_with_inverted_arg_if_active!(
    B::LLVM.IRBuilder,
    gutils::GradientUtils,
    orig::LLVM.CallInst,
    args::Vector{<:LLVM.Value},
    valTys::Vector{API.CValueType},
    lookup::Bool;
    postprocess=nothing,
    cmpidx::Int=1,
)
    @assert length(args) == length(valTys)

    origops = collect(operands(orig))
    if is_constant_value(gutils, origops[cmpidx])
        return nothing
    end

    if !get_runtime_activity(gutils)
        res = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, lookup)
        callconv!(res, callconv(orig))
        debug_from_orig!(gutils, res, orig)
        if postprocess !== nothing
            postprocess(B, res, args)
        end
        return res
    end

    # Runtime activity: work on copies so the caller's vectors are left untouched.
    valTys = copy(valTys)
    @assert valTys[cmpidx] == API.VT_Shadow
    valTys[cmpidx] = API.VT_Both

    # Leading arguments must mirror the wrapper's signature:
    # [value of orig (non-void only), value of operand cmpidx, args...]
    args = copy(args)
    insert!(args, 1, new_from_original(gutils, origops[cmpidx]))
    if value_type(orig) != LLVM.VoidType()
        insert!(args, 1, new_from_original(gutils, orig))
    end

    condfn = get_or_insert_conditional_execute!(
        LLVM.called_operand(orig)::LLVM.Function;
        postprocess,
        cmpidx,
    )

    # `condfn` is already the function to call; it is not a call instruction, so it is
    # passed directly rather than through `called_operand`.
    res = LLVM.Value(
        API.EnzymeGradientUtilsCallWithInvertedBundles(
            gutils,
            condfn,
            LLVM.function_type(condfn),
            args,
            length(args),
            orig,
            valTys,
            length(valTys),
            B,
            false, #= lookup =#
        ),
    )
    callconv!(res, callconv(orig))
    debug_from_orig!(gutils, res, orig)
    return res
end
217+
218+
"""
    batch_call_same_with_inverted_arg_if_active!(B, gutils, orig, args, valTys, lookup;
                                                 kwargs...)

Helper function for llvm-level rule generation. Batched counterpart of
`call_same_with_inverted_arg_if_active!`: performs one call per batch lane
(`get_width(gutils)` lanes) and assembles the per-lane results.

For a width greater than one, every argument whose entry in `valTys` is
`API.VT_Shadow` is replaced, per lane, by `extract_value!` of that lane (0-based
index) before the call is made. All `kwargs` are forwarded unchanged.

Returns
- for width 1: whatever the single underlying call returned;
- for width > 1 and a non-void `orig`: an aggregate of the shadow type with each lane's
  result inserted, or `nothing` as soon as any lane returned `nothing`;
- for width > 1 and a void `orig`: `nothing`.
"""
function batch_call_same_with_inverted_arg_if_active!(
    B::LLVM.IRBuilder,
    gutils::GradientUtils,
    orig::LLVM.CallInst,
    args::Vector{<:LLVM.Value},
    valTys::Vector{API.CValueType},
    lookup::Bool;
    kwargs...,
)
    # The width must be known before the shadow type can be computed.
    width = get_width(gutils)

    void_rt = value_type(orig) == LLVM.VoidType()
    shadow = if void_rt
        nothing
    else
        ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))
        LLVM.UndefValue(ST)
    end

    for idx in 1:width
        args2 = args
        if width > 1
            # Pick this lane's element out of every batched shadow argument.
            args2 = copy(args)
            for i in eachindex(valTys)
                if valTys[i] == API.VT_Shadow
                    args2[i] = extract_value!(B, args2[i], idx - 1)
                end
            end
        end

        res = call_same_with_inverted_arg_if_active!(
            B, gutils, orig, args2, valTys, lookup; kwargs...
        )

        if width == 1
            shadow = res
        elseif res === nothing || shadow === nothing
            # One missing lane (or a void call) means there is no aggregate to return.
            shadow = nothing
        else
            shadow = insert_value!(B, shadow, res, idx - 1)
        end
    end

    return shadow
end
0 commit comments