@@ -34,39 +34,113 @@ _mtl(T::Type, w::AbstractVector) = MtlVector{T}(convert(Vector{T}, w))
"""
    FixedEffectLinearMapMetal{T}

Linear-map state for fixed effects whose per-effect arrays are held in parallel vectors:
entry `i` of `fes`, `scales` and `caches` all belong to the same fixed effect.
"""
mutable struct FixedEffectLinearMapMetal{T} <: AbstractFixedEffectLinearMap{T}
    # fixed effects (the constructor stores the `_mtl`-converted copies here)
    fes::Vector{<:FixedEffect}
    # one scale vector per fixed effect
    scales::Vector{<:AbstractVector}
    # one cache entry per fixed effect; the constructor stores whatever
    # `bucketize_refs` returns, hence the unconstrained element type
    caches::Vector
    # threadgroup size handed to the kernel launches
    nthreads::Int
end
4040
"""
    bucketize_refs(refs::Vector, K::Int, T)

Build the cache tuple `(cache, perm, offsets)` for one fixed effect with `K` levels.

`cache` is always a zero-filled device vector of eltype `T` and length `length(refs)`.

When there are few levels with many observations each (`0 < K < 100_000` and at least
16 observations per level on average), the observations are grouped by level with a
counting sort:

- `perm` lists the observation indices grouped by level (stable within a level);
- `offsets` has `K + 1` entries; level `k` owns `perm[offsets[k]:(offsets[k + 1] - 1)]`.

Otherwise `perm` and `offsets` are one-element zero placeholders.
"""
function bucketize_refs(refs::Vector, K::Int, T)
    N = length(refs)
    cache = Metal.zeros(T, N)
    # `0 < K` guards the integer division: `N ÷ 0` would throw a DivideError.
    if 0 < K < 100_000 && N ÷ K >= 16
        perm, offsets = _bucket_layout(refs, K)
        return cache, MtlArray(perm), MtlArray(offsets)
    else
        return cache, Metal.zeros(Int32, 1), Metal.zeros(Int32, 1)
    end
end

# Host-side counting sort of the observation indices by level; returns
# `(perm, offsets)` as `Vector{Int32}`s.
function _bucket_layout(refs::Vector, K::Int)
    N = length(refs)

    # Deliberately bounds-checked: this pass validates that every ref lies in 1:K,
    # which is what makes the `@inbounds` passes below safe.
    counts = zeros(Int, K)
    for r in refs
        counts[r] += 1
    end

    # Exclusive prefix sum, shifted to 1-based positions.
    offsets = Vector{Int}(undef, K + 1)
    offsets[1] = 1
    @inbounds for k in 1:K
        offsets[k + 1] = offsets[k] + counts[k]
    end

    # Next free slot of every bucket (slicing already allocates a fresh copy).
    heads = offsets[1:K]
    perm = Vector{Int32}(undef, N)
    @inbounds for i in eachindex(refs)
        r = refs[i]
        p = heads[r]
        perm[p] = i
        heads[r] = p + 1
    end

    return perm, Int32.(offsets)
end
66+
"""
    FixedEffectLinearMapMetal{T}(fes::Vector{<:FixedEffect}, nthreads)

Construct the map from host-side fixed effects: each one is converted with `_mtl`, gets a
zero-initialised scale vector of length `fe.n`, and a cache built by `bucketize_refs`.
"""
function FixedEffectLinearMapMetal{T}(fes::Vector{<:FixedEffect}, nthreads) where {T}
    device_fes = map(fe -> _mtl(T, fe), fes)
    scales = map(fe -> Metal.zeros(T, fe.n), fes)
    # The caches are derived from the original `fes`, not from the converted copies.
    caches = map(fe -> bucketize_refs(fe.refs, fe.n, T), fes)
    return FixedEffectLinearMapMetal{T}(device_fes, scales, caches, nthreads)
end
4773
"""
    FixedEffects.gather!(fecoef, refs, α, y, cache, nthreads)

Accumulate `α * y[i] * cache[1][i]` into `fecoef[refs[i]]` on the device.

`cache` is the tuple produced by `bucketize_refs`. When it carries a bucket layout, one
threadgroup per coefficient runs `gather_kernel_bin!`; otherwise one thread per
observation runs `gather_kernel!`, which uses atomic adds.
"""
function FixedEffects.gather!(
    fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache, nthreads::Integer
)
    K = length(fecoef)
    # Dispatch on what the cache actually contains instead of re-deriving the
    # heuristic: `bucketize_refs` stores `K + 1` offsets when it built buckets and a
    # one-element placeholder otherwise. This also avoids dividing by `K` when
    # `fecoef` is empty.
    if K > 0 && length(cache[3]) == K + 1
        Metal.@sync @metal threads=nthreads groups=K gather_kernel_bin!(
            fecoef, refs, α, y, cache[1], cache[2], cache[3], Val(nthreads)
        )
    else
        nblocks = cld(length(y), nthreads)
        Metal.@sync @metal threads=nthreads groups=nblocks gather_kernel!(
            fecoef, refs, α, y, cache[1]
        )
    end
end
83+
# Bucketed gather kernel: threadgroup `k` sums `α * y[i] * cache[i]` over the
# observations listed in `perm[offsets[k]:(offsets[k + 1] - 1)]` and adds the total to
# `fecoef[k]`. `NT` is the scratch-array length and must be at least the launched
# threadgroup size. `refs` is unused; it is kept so the argument list parallels
# `gather_kernel!`.
function gather_kernel_bin!(fecoef, refs, α, y, cache, perm, offsets, ::Val{NT}) where {NT}
    k = threadgroup_position_in_grid().x      # used directly as the coefficient index
    tid = thread_position_in_threadgroup().x  # used directly as the scratch index
    nt = threads_per_threadgroup().x

    # threadgroup scratch, one slot per thread
    T = eltype(fecoef)
    shared = Metal.MtlThreadGroupArray(T, NT)

    start = @inbounds offsets[k]
    stop = @inbounds offsets[k + 1] - Int32(1)

    # each thread walks its strided portion of the bucket
    acc = zero(T)
    j = start + Int32(tid - 1)
    while j <= stop
        i = @inbounds perm[j]
        @inbounds acc += (α * y[i] * cache[i])
        j += Int32(nt)
    end

    @inbounds shared[tid] = acc
    Metal.threadgroup_barrier(Metal.MemoryFlagThreadGroup)

    # Tree reduction in the scratch array. Start from the largest power of two strictly
    # below `nt` (rounded up) and skip partners beyond `nt`, so that no partial sum is
    # dropped when the threadgroup size is not a power of two.
    stride = Int32(1)
    while stride < nt
        stride *= Int32(2)
    end
    stride ÷= Int32(2)
    while stride > 0
        if tid <= stride && tid + stride <= nt
            @inbounds shared[tid] += shared[tid + stride]
        end
        # `stride` depends only on `nt`, so every thread reaches every barrier
        Metal.threadgroup_barrier(Metal.MemoryFlagThreadGroup)
        stride ÷= Int32(2)
    end

    # one write per coefficient; no atomics needed as long as exactly one threadgroup
    # is launched per `k`
    if tid == UInt32(1)
        @inbounds fecoef[k] += shared[1]
    end

    return nothing
end
52126
# One thread per observation: atomically add `α * y[idx] * cache[idx]` to the
# coefficient slot selected by `refs[idx]`. Threads past the end of `refs` do nothing.
function gather_kernel!(fecoef, refs, α, y, cache)
    idx = thread_position_in_grid_1d()
    idx <= length(refs) || return nothing
    @inbounds begin
        contribution = α * y[idx] * cache[idx]
        Metal.atomic_fetch_add_explicit(pointer(fecoef, refs[idx]), contribution)
    end
    return nothing
end
60134
"""
    FixedEffects.scatter!(y, α, fecoef, refs, cache, nthreads)

Launch `scatter_kernel!` with one thread per element of `y`. Only the first element of
`cache` is passed to the kernel.
"""
function FixedEffects.scatter!(
    y::MtlVector, α::Number, fecoef::MtlVector, refs::MtlVector, cache, nthreads::Integer
)
    buf = cache[1]
    ngroups = cld(length(y), nthreads)
    Metal.@sync @metal threads=nthreads groups=ngroups scatter_kernel!(
        y, α, fecoef, refs, buf
    )
end
65139
# One thread per element of `y`: add `α * fecoef[refs[idx]] * cache[idx]` to `y[idx]`.
# Threads past the end of `y` do nothing.
function scatter_kernel!(y, α, fecoef, refs, cache)
    idx = thread_position_in_grid_1d()
    idx <= length(y) || return nothing
    @inbounds y[idx] += α * fecoef[refs[idx]] * cache[idx]
    return nothing
end
@@ -121,34 +195,34 @@ function scale!(scale::MtlVector, refs::MtlVector, interaction::MtlVector, weigh
121195 nblocks = cld(length(refs), nthreads)
122196 fill!(scale, 0 )
123197 Metal. @sync @metal threads= nthreads groups= nblocks scale_kernel!(scale, refs, interaction, weights)
124- Metal. @sync @metal threads= nthreads groups= nblocks inv_kernel!(scale)
198+ Metal. @sync @metal threads= nthreads groups= nblocks inv_kernel!(scale, eltype(scale) )
125199end
126200
# One thread per observation: atomically add `interaction[idx]^2 * weights[idx]` to the
# `scale` slot selected by `refs[idx]`. Threads past the end of `refs` do nothing.
function scale_kernel!(scale, refs, interaction, weights)
    idx = thread_position_in_grid_1d()
    idx <= length(refs) || return nothing
    @inbounds begin
        contribution = interaction[idx]^2 * weights[idx]
        Metal.atomic_fetch_add_explicit(pointer(scale, refs[idx]), contribution)
    end
    return nothing
end
134208
# One thread per element of `scale`: replace a positive entry by `1 / sqrt(entry)` and
# any other entry by `zero(T)`. `T` is passed in by the caller as `eltype(scale)`.
function inv_kernel!(scale, T)
    idx = thread_position_in_grid_1d()
    idx <= length(scale) || return nothing
    @inbounds begin
        s = scale[idx]
        if s > 0
            scale[idx] = 1 / sqrt(s)
        else
            scale[idx] = zero(T)
        end
    end
    return nothing
end
142216
"""
    cache!(cache, refs, interaction, weights, scale, nthreads)

Fill the first element of `cache` by launching `cache!_kernel!` with one thread per
element of that vector. The remaining elements of `cache` are left untouched.
"""
function cache!(
    cache,
    refs::MtlVector,
    interaction::MtlVector,
    weights::MtlVector,
    scale::MtlVector,
    nthreads::Integer,
)
    buf = cache[1]
    ngroups = cld(length(buf), nthreads)
    Metal.@sync @metal threads=nthreads groups=ngroups cache!_kernel!(
        buf, refs, interaction, weights, scale
    )
end
147221
# One thread per element of `cache`: store
# `interaction[idx] * sqrt(weights[idx]) * scale[refs[idx]]` in `cache[idx]`.
# Threads past the end of `cache` do nothing.
function cache!_kernel!(cache, refs, interaction, weights, scale)
    idx = thread_position_in_grid_1d()
    idx <= length(cache) || return nothing
    @inbounds cache[idx] = interaction[idx] * sqrt(weights[idx]) * scale[refs[idx]]
    return nothing
end
0 commit comments