Skip to content

Commit f978a17

Browse files
Patch 10 (#70)
* Update CUDAExt.jl * Update CUDAExt.jl * Update Project.toml * Update ci.yml * improve CUDA, MetalExt, and fix predict bug * Bump version from 2.4.1 to 2.5.0 * Update AbstractFixedEffectSolver.jl * avoid 0 in scale * Update Project.toml * Update runtests.jl * requires 1.8 * try 1.9
1 parent 5f7ba6b commit f978a17

File tree

8 files changed

+126
-56
lines changed

8 files changed

+126
-56
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
version:
20-
- '1.6'
20+
- '1.9'
2121
- '1'
2222
# - 'nightly'
2323
os:

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FixedEffects"
22
uuid = "c8885935-8500-56a7-9867-7708b20db0eb"
3-
version = "2.4.1"
3+
version = "2.5.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -11,7 +11,7 @@ GroupedArrays = "6407cd72-fade-4a84-8a1e-56e431fc1533"
1111
[compat]
1212
StatsBase = "0.33, 0.34"
1313
GroupedArrays = "0.3"
14-
julia = "1.6"
14+
julia = "1.9"
1515

1616
[extensions]
1717
CUDAExt = "CUDA"
@@ -24,10 +24,11 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
2424
[extras]
2525
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
2626
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
27+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
2728
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2829
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
2930
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3031

3132
[targets]
32-
test = ["CategoricalArrays", "CUDA", "Pkg", "PooledArrays", "Test"]
33+
test = ["CategoricalArrays", "CUDA", "Metal", "Pkg", "PooledArrays", "Test"]
3334

benchmarks/benchmark_Metal.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@ id2 = rand(1:K, N)
77
fes = [FixedEffect(id1), FixedEffect(id2)]
88
x = Float32.(rand(N))
99

10+
11+
# What takes time here is the second fixed effect, where K is very small, so a lot of threads want to write to the same location. In that case, it would probably be better to pre-compute a permutation for each fixed effect once, and then use as many groups as permutations, etc.
12+
13+
1014
# simple problem
11-
@time solve_residuals!(deepcopy(x), fes)
15+
@time solve_residuals!(deepcopy(x), fes; double_precision = false)
1216
# 0.654833 seconds (1.99 k allocations: 390.841 MiB, 3.71% gc time)
13-
@time solve_residuals!(deepcopy(x), fes; method = :Metal)
17+
@time solve_residuals!(deepcopy(x), fes; double_precision = false, method = :Metal)
1418
# 0.298326 seconds (129.08 k allocations: 79.208 MiB)
1519
@time solve_residuals!([x x x x], fes)
1620
# 1.604061 seconds (1.25 M allocations: 416.364 MiB, 4.21% gc time, 30.57% compilation time)

ext/CUDAExt.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,12 @@ function FixedEffects.gather!(fecoef::CuVector, refs::CuVector, α::Number, y::C
5252
end
5353

5454
function gather_kernel!(fecoef, refs, α, y, cache)
55-
index = (blockIdx().x - 1) * blockDim().x + threadIdx().x
55+
index = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
5656
stride = blockDim().x * gridDim().x
57-
@inbounds for i = index:stride:length(y)
57+
i = index
58+
@inbounds while i <= length(y)
5859
CUDA.@atomic fecoef[refs[i]] += α * y[i] * cache[i]
60+
i += stride
5961
end
6062
end
6163

@@ -65,10 +67,12 @@ function FixedEffects.scatter!(y::CuVector, α::Number, fecoef::CuVector, refs::
6567
end
6668

6769
function scatter_kernel!(y, α, fecoef, refs, cache)
68-
index = (blockIdx().x - 1) * blockDim().x + threadIdx().x
70+
index = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
6971
stride = blockDim().x * gridDim().x
70-
@inbounds for i = index:stride:length(y)
72+
i = index
73+
@inbounds while i <= length(y)
7174
y[i] += α * fecoef[refs[i]] * cache[i]
75+
i += stride
7276
end
7377
end
7478

@@ -124,14 +128,16 @@ function scale!(scale::CuVector, refs::CuVector, interaction::CuVector, weights:
124128
nblocks = cld(length(refs), nthreads)
125129
fill!(scale, 0)
126130
@cuda threads=nthreads blocks=nblocks scale_kernel!(scale, refs, interaction, weights)
127-
map!(x -> x > 0 ? 1 / sqrt(x) : 0, scale, scale)
131+
map!(x -> x > 0 ? 1 / sqrt(x) : zero(eltype(scale)), scale, scale)
128132
end
129133

130134
function scale_kernel!(scale, refs, interaction, weights)
131-
index = (blockIdx().x - 1) * blockDim().x + threadIdx().x
135+
index = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
132136
stride = blockDim().x * gridDim().x
133-
@inbounds for i = index:stride:length(interaction)
137+
i = index
138+
@inbounds while i <= length(interaction)
134139
CUDA.@atomic scale[refs[i]] += abs2(interaction[i]) * weights[i]
140+
i += stride
135141
end
136142
end
137143

@@ -141,10 +147,12 @@ function cache!(cache::CuVector, refs::CuVector, interaction::CuVector, weights:
141147
end
142148

143149
function cache!_kernel!(cache, refs, interaction, weights, scale)
144-
index = (blockIdx().x - 1) * blockDim().x + threadIdx().x
150+
index = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
145151
stride = blockDim().x * gridDim().x
146-
@inbounds for i = index:stride:length(cache)
152+
i = index
153+
@inbounds while i <= length(cache)
147154
cache[i] = interaction[i] * sqrt(weights[i]) * scale[refs[i]]
155+
i += stride
148156
end
149157
end
150158

ext/MetalExt.jl

Lines changed: 93 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,39 +34,113 @@ _mtl(T::Type, w::AbstractVector) = MtlVector{T}(convert(Vector{T}, w))
3434
mutable struct FixedEffectLinearMapMetal{T} <: AbstractFixedEffectLinearMap{T}
3535
fes::Vector{<:FixedEffect}
3636
scales::Vector{<:AbstractVector}
37-
caches::Vector{<:AbstractVector}
37+
caches::Vector
3838
nthreads::Int
3939
end
4040

41+
function bucketize_refs(refs::Vector, K::Int, T)
42+
if K < 100_000 && (length(refs) ÷ K >= 16)
43+
N = length(refs)
44+
counts = zeros(Int, K)
45+
@inbounds for r in refs
46+
counts[r] += 1
47+
end
48+
offsets = Vector{Int}(undef, K+1)
49+
offsets[1] = 1
50+
@inbounds for k in 1:K
51+
offsets[k+1] = offsets[k] + counts[k]
52+
end
53+
next = copy(offsets[1:K]) # write heads
54+
perm = Vector{UInt32}(undef, N)
55+
@inbounds for i in 1:N
56+
r = refs[i]
57+
p = next[r]
58+
perm[p] = i
59+
next[r] = p + 1
60+
end
61+
return Metal.zeros(T, length(refs)), MtlArray(Int32.(perm)), MtlArray(Int32.(offsets))
62+
else
63+
return Metal.zeros(T, length(refs)), Metal.zeros(Int32, 1), Metal.zeros(Int32, 1)
64+
end
65+
end
66+
4167
function FixedEffectLinearMapMetal{T}(fes::Vector{<:FixedEffect}, nthreads) where {T}
42-
fes = [_mtl(T, fe) for fe in fes]
68+
fes2 = [_mtl(T, fe) for fe in fes]
4369
scales = [Metal.zeros(T, fe.n) for fe in fes]
44-
caches = [Metal.zeros(T, length(fes[1].interaction)) for fe in fes]
45-
return FixedEffectLinearMapMetal{T}(fes, scales, caches, nthreads)
70+
caches = [bucketize_refs(fe.refs, fe.n, T) for fe in fes]
71+
return FixedEffectLinearMapMetal{T}(fes2, scales, caches, nthreads)
4672
end
4773

48-
function FixedEffects.gather!(fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache::MtlVector, nthreads::Integer)
49-
nblocks = cld(length(y), nthreads)
50-
Metal.@sync @metal threads=nthreads groups=nblocks gather_kernel!(fecoef, refs, α, y, cache)
74+
function FixedEffects.gather!(fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache, nthreads::Integer)
75+
K = length(fecoef)
76+
if K < 100_000 && (length(refs) ÷ K >= 16)
77+
Metal.@sync @metal threads=nthreads groups=K gather_kernel_bin!(fecoef, refs, α, y, cache[1], cache[2], cache[3], Val(nthreads))
78+
else
79+
nblocks = cld(length(y), nthreads)
80+
Metal.@sync @metal threads=nthreads groups=nblocks gather_kernel!(fecoef, refs, α, y, cache[1])
81+
end
82+
end
83+
84+
function gather_kernel_bin!(fecoef, refs, α, y, cache, perm, offsets, ::Val{NT}) where {NT}
85+
k = threadgroup_position_in_grid().x # 1..K (Julia-style indexing)
86+
tid = thread_position_in_threadgroup().x # 1..nthreads
87+
nt = threads_per_threadgroup().x # nthreads
88+
89+
# threadgroup scratch
90+
T = eltype(fecoef)
91+
shared = Metal.MtlThreadGroupArray(T, NT) # threadgroup-local array
92+
93+
start = @inbounds offsets[k]
94+
stop = @inbounds offsets[k+1] - Int32(1)
95+
96+
acc = zero(T)
97+
98+
# each thread walks its portion of the bucket
99+
j = start + Int32(tid - 1)
100+
while j <= stop
101+
i = @inbounds perm[j]
102+
@inbounds acc += (α * y[i] * cache[i])
103+
j += Int32(nt)
104+
end
105+
106+
@inbounds shared[tid] = acc
107+
Metal.threadgroup_barrier(Metal.MemoryFlagThreadGroup) # sync + tg fence
108+
109+
# tree reduction in shared memory
110+
offset = Int32(nt ÷ UInt32(2))
111+
while offset > 0
112+
if tid <= offset
113+
@inbounds shared[tid] += shared[tid + offset]
114+
end
115+
Metal.threadgroup_barrier(Metal.MemoryFlagThreadGroup)
116+
offset ÷= Int32(2)
117+
end
118+
119+
# one write per coefficient (no atomics needed if groups == K and 1 group per k)
120+
if tid == UInt32(1)
121+
@inbounds fecoef[k] += shared[1]
122+
end
123+
124+
return nothing
51125
end
52126

53127
function gather_kernel!(fecoef, refs, α, y, cache)
54128
i = thread_position_in_grid_1d()
55129
if i <= length(refs)
56-
Metal.atomic_fetch_add_explicit(pointer(fecoef, refs[i]), α * y[i] * cache[i])
130+
@inbounds Metal.atomic_fetch_add_explicit(pointer(fecoef, refs[i]), α * y[i] * cache[i])
57131
end
58132
return nothing
59133
end
60134

61-
function FixedEffects.scatter!(y::MtlVector, α::Number, fecoef::MtlVector, refs::MtlVector, cache::MtlVector, nthreads::Integer)
135+
function FixedEffects.scatter!(y::MtlVector, α::Number, fecoef::MtlVector, refs::MtlVector, cache, nthreads::Integer)
62136
nblocks = cld(length(y), nthreads)
63-
Metal.@sync @metal threads=nthreads groups=nblocks scatter_kernel!(y, α, fecoef, refs, cache)
137+
Metal.@sync @metal threads=nthreads groups=nblocks scatter_kernel!(y, α, fecoef, refs, cache[1])
64138
end
65139

66140
function scatter_kernel!(y, α, fecoef, refs, cache)
67141
i = thread_position_in_grid_1d()
68142
if i <= length(y)
69-
y[i] += α * fecoef[refs[i]] * cache[i]
143+
@inbounds y[i] += α * fecoef[refs[i]] * cache[i]
70144
end
71145
return nothing
72146
end
@@ -121,34 +195,34 @@ function scale!(scale::MtlVector, refs::MtlVector, interaction::MtlVector, weigh
121195
nblocks = cld(length(refs), nthreads)
122196
fill!(scale, 0)
123197
Metal.@sync @metal threads=nthreads groups=nblocks scale_kernel!(scale, refs, interaction, weights)
124-
Metal.@sync @metal threads=nthreads groups=nblocks inv_kernel!(scale)
198+
Metal.@sync @metal threads=nthreads groups=nblocks inv_kernel!(scale, eltype(scale))
125199
end
126200

127201
function scale_kernel!(scale, refs, interaction, weights)
128202
i = thread_position_in_grid_1d()
129203
if i <= length(refs)
130-
Metal.atomic_fetch_add_explicit(pointer(scale, refs[i]), interaction[i]^2 * weights[i])
204+
@inbounds Metal.atomic_fetch_add_explicit(pointer(scale, refs[i]), interaction[i]^2 * weights[i])
131205
end
132206
return nothing
133207
end
134208

135-
function inv_kernel!(scale)
209+
function inv_kernel!(scale, T)
136210
i = thread_position_in_grid_1d()
137211
if i <= length(scale)
138-
scale[i] = (scale[i] > 0) ? (1 / sqrt(scale[i])) : 0.0
212+
@inbounds scale[i] = (scale[i] > 0) ? (1 / sqrt(scale[i])) : zero(T)
139213
end
140214
return nothing
141215
end
142216

143-
function cache!(cache::MtlVector, refs::MtlVector, interaction::MtlVector, weights::MtlVector, scale::MtlVector, nthreads::Integer)
144-
nblocks = cld(length(cache), nthreads)
145-
Metal.@sync @metal threads=nthreads groups=nblocks cache!_kernel!(cache, refs, interaction, weights, scale)
217+
function cache!(cache, refs::MtlVector, interaction::MtlVector, weights::MtlVector, scale::MtlVector, nthreads::Integer)
218+
nblocks = cld(length(cache[1]), nthreads)
219+
Metal.@sync @metal threads=nthreads groups=nblocks cache!_kernel!(cache[1], refs, interaction, weights, scale)
146220
end
147221

148222
function cache!_kernel!(cache, refs, interaction, weights, scale)
149223
i = thread_position_in_grid_1d()
150224
if i <= length(cache)
151-
cache[i] = interaction[i] * sqrt(weights[i]) * scale[refs[i]]
225+
@inbounds cache[i] = interaction[i] * sqrt(weights[i]) * scale[refs[i]]
152226
end
153227
return nothing
154228
end

src/SolverCPU.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,6 @@ function FixedEffectLinearMapCPU{T}(fes::Vector{<:FixedEffect}, ::Type{Val{:cpu}
1717
return FixedEffectLinearMapCPU{T}(fes, scales, caches, nthreads)
1818
end
1919

20-
function LinearAlgebra.mul!(fecoefs::FixedEffectCoefficients,
21-
Cfem::Adjoint{T, FixedEffectLinearMapCPU{T}},
22-
y::AbstractVector, α::Number, β::Number) where {T}
23-
fem = adjoint(Cfem)
24-
rmul!(fecoefs, β)
25-
for (fecoef, fe, cache) in zip(fecoefs.x, fem.fes, fem.caches)
26-
gather!(fecoef, fe.refs, α, y, cache, fem.nthreads)
27-
end
28-
return fecoefs
29-
end
3020

3121
# multithreaded gather seems to be slower
3222
function gather!(fecoef::AbstractVector, refs::AbstractVector, α::Number,

test/runtests.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,7 @@
11
tests = ["types.jl", "solve.jl"]
22
println("Running tests:")
33

4-
# A work around for tests to run on older versions of Julia
5-
using Pkg
6-
if VERSION >= v"1.8"
7-
Pkg.add("Metal")
8-
using Metal
9-
end
10-
11-
using Test, StatsBase, CUDA, FixedEffects, PooledArrays, CategoricalArrays
4+
using Test, StatsBase, CUDA, Metal, FixedEffects, PooledArrays, CategoricalArrays
125

136
for test in tests
147
try

test/solve.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ method_s = [:cpu]
2525
if CUDA.functional()
2626
push!(method_s, :CUDA)
2727
end
28-
#if Metal.functional()
29-
# push!(method_s, :Metal)
30-
#end
28+
if Metal.functional()
29+
push!(method_s, :Metal)
30+
end
3131
for method in method_s
3232
println("$method Float32")
33-
local (r, iter, conv) = solve_residuals!(deepcopy(x),fes, method=method, double_precision = false)
33+
local (r, iter, conv) = solve_residuals!(deepcopy(x), fes, method=method, double_precision = false)
3434
@test Float32.(r) ≈ Float32.(r_ols)
3535
end
3636

0 commit comments

Comments
 (0)