@@ -89,6 +89,19 @@ macro inside(ex)
8989 end |> esc
9090end
9191
92+ # Could also use ScopedValues in Julia 1.11+
93+ using Preferences
94+ const backend = @load_preference (" backend" , " KernelAbstractions" )
95+ function set_backend (new_backend:: String )
96+ if ! (new_backend in (" SIMD" , " KernelAbstractions" ))
97+ throw (ArgumentError (" Invalid backend: \" $(new_backend) \" " ))
98+ end
99+
100+ # Set it in our runtime values, as well as saving it to disk
101+ @set_preferences! (" backend" => new_backend)
102+ @info (" New backend set; restart your Julia session for this change to take effect!" )
103+ end
104+
92105"""
93106 @loop <expr> over <I ∈ R>
94107
@@ -118,26 +131,34 @@ Note that `get_backend` is used on the _first_ variable in `expr` (`a` in this e
118131"""
119132macro loop (args... )
120133 ex,_,itr = args
121- _,I,R = itr. args; sym = []
134+ _,I,R = itr. args
135+ sym = []
122136 grab! (sym,ex) # get arguments and replace composites in `ex`
123137 setdiff! (sym,[I]) # don't want to pass I as an argument
138+ symT = symtypes (sym) # generate a list of types for each symbol
124139 @gensym (kern, kern_) # generate unique kernel function names for serial and KA execution
125- return quote
126- function $kern ($ (rep .(sym)... ),:: Val{1} )
127- @simd for $ I ∈ $ R
140+ @static if backend == " KernelAbstractions"
141+ return quote
142+ @kernel function $kern_ ($ (rep .(sym)... ),@Const (I0)) # replace composite arguments
143+ $ I = @index (Global,Cartesian)
144+ $ I += I0
128145 @fastmath @inbounds $ ex
129146 end
130- end
131- @kernel function $kern_ ($ (rep .(sym)... ),@Const (I0)) # replace composite arguments
132- $ I = @index (Global,Cartesian)
133- $ I += I0
134- @fastmath @inbounds $ ex
135- end
136- function $kern ($ (rep .(sym)... ),_)
137- $ kern_ (get_backend ($ (sym[1 ])),64 )($ (sym... ),$ R[1 ]- oneunit ($ R[1 ]),ndrange= size ($ R))
138- end
139- $ kern ($ (sym... ),Val {Threads.nthreads()} ()) # dispatch to SIMD for -t 1, or KA otherwise
140- end |> esc
147+ function $kern ($ (joinsymtype (rep .(sym),symT)... )) where {$ (symT... )}
148+ $ kern_ (get_backend ($ (sym[1 ])),64 )($ (sym... ),$ R[1 ]- oneunit ($ R[1 ]),ndrange= size ($ R))
149+ end
150+ $ kern ($ (sym... ))
151+ end |> esc
152+ else # backend == "SIMD"
153+ return quote
154+ function $kern ($ (joinsymtype (rep .(sym),symT)... )) where {$ (symT... )}
155+ @simd for $ I ∈ $ R
156+ @fastmath @inbounds $ ex
157+ end
158+ end
159+ $ kern ($ (sym... ))
160+ end |> esc
161+ end
141162end
142163function grab! (sym,ex:: Expr )
143164 ex. head == :. && return union! (sym,[ex]) # grab composite name and return
@@ -149,6 +170,10 @@ grab!(sym,ex::Symbol) = union!(sym,[ex]) # grab symbol name
149170grab! (sym,ex) = nothing
150171rep (ex) = ex
151172rep (ex:: Expr ) = ex. head == :. ? Symbol (ex. args[2 ]. value) : ex
173+ using Random
174+ symtypes (sym) = [Symbol .(Random. randstring (' A' :' Z' ,4 )) for _ in 1 : length (sym)]
175+ joinsymtype (sym:: Symbol ,symT:: Symbol ) = Expr (:(:: ), sym, symT)
176+ joinsymtype (sym,symT) = zip (sym,symT) .| > x-> joinsymtype (x... )
152177
153178using StaticArrays
154179"""
0 commit comments