Skip to content

Commit b39ad90

Browse files
Use sharedarrays on Mtl (#75)
* Change nthreads default to nothing in solver * Update AbstractFixedEffectSolver to use optional nthreads * Allow optional nthreads parameter in AbstractFixedEffectSolver * Change nthreads parameter to 'nothing' in solver functions * Bump version from 2.5.2 to 2.6.0 * Update method and double_precision arguments in functions * Fix capitalization of 'Metal' in method argument * Update benchmark_Metal.jl * Reformat function signatures for consistency * Decrease maxiter by 1 in lsmr! call * safer to use Int for big arrays and not more costly * Update MetalExt.jl * better to do chunkis of 100_000 even if it means more threads than Threads.nthreads * Update SolverCPU.jl * used shared arrays for Metal * Update Project.toml * Update AbstractFixedEffectSolver.jl * Update MetalExt.jl * rmv nthreads
1 parent d2c3e5c commit b39ad90

File tree

6 files changed

+89
-81
lines changed

6 files changed

+89
-81
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
name = "FixedEffects"
22
uuid = "c8885935-8500-56a7-9867-7708b20db0eb"
3-
version = "2.6.0"
3+
version = "2.7.0"
4+
45

56
[deps]
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

ext/CUDAExt.jl

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module CUDAExt
22
using FixedEffects, CUDA
3-
using FixedEffects: FixedEffectCoefficients, AbstractWeights, UnitWeights, LinearAlgebra, Adjoint, mul!, rmul!, lsmr!, AbstractFixedEffectLinearMap
3+
using FixedEffects: FixedEffectCoefficients, AbstractWeights, UnitWeights, LinearAlgebra, Adjoint, mul!, rmul!, lsmr!, AbstractFixedEffectLinearMap, copy_internal!
44
CUDA.allowscalar(false)
55

66
##############################################################################
@@ -36,17 +36,17 @@ mutable struct FixedEffectLinearMapCUDA{T} <: AbstractFixedEffectLinearMap{T}
3636
fes::Vector{<:FixedEffect}
3737
scales::Vector{<:AbstractVector}
3838
caches::Vector{<:AbstractVector}
39-
nthreads::Int
4039
end
4140

42-
function FixedEffectLinearMapCUDA{T}(fes::Vector{<:FixedEffect}, nthreads) where {T}
41+
function FixedEffectLinearMapCUDA{T}(fes::Vector{<:FixedEffect}) where {T}
4342
fes = [_cu(T, fe) for fe in fes]
4443
scales = [CUDA.zeros(T, fe.n) for fe in fes]
4544
caches = [CUDA.zeros(T, length(fes[1].interaction)) for fe in fes]
46-
return FixedEffectLinearMapCUDA{T}(fes, scales, caches, nthreads)
45+
return FixedEffectLinearMapCUDA{T}(fes, scales, caches)
4746
end
4847

49-
function FixedEffects.gather!(fecoef::CuVector, refs::CuVector, α::Number, y::CuVector, cache::CuVector, nthreads::Integer)
48+
function FixedEffects.gather!(fecoef::CuVector, refs::CuVector, α::Number, y::CuVector, cache::CuVector)
49+
nthreads = 256
5050
nblocks = cld(length(y), nthreads)
5151
@cuda threads=nthreads blocks=nblocks gather_kernel!(fecoef, refs, α, y, cache)
5252
end
@@ -61,7 +61,8 @@ function gather_kernel!(fecoef, refs, α, y, cache)
6161
end
6262
end
6363

64-
function FixedEffects.scatter!(y::CuVector, α::Number, fecoef::CuVector, refs::CuVector, cache::CuVector, nthreads::Integer)
64+
function FixedEffects.scatter!(y::CuVector, α::Number, fecoef::CuVector, refs::CuVector, cache::CuVector)
65+
nthreads = 256
6566
nblocks = cld(length(y), nthreads)
6667
@cuda threads=nthreads blocks=nblocks scatter_kernel!(y, α, fecoef, refs, cache)
6768
end
@@ -101,11 +102,7 @@ function FixedEffects.AbstractFixedEffectSolver{T}(fes::Vector{<:FixedEffect}, w
101102
end
102103

103104
function FixedEffects.AbstractFixedEffectSolver{T}(fes::Vector{<:FixedEffect}, weights::AbstractWeights, ::Type{Val{:CUDA}}, nthreads = nothing) where {T}
104-
if nthreads === nothing
105-
nthreads = 256
106-
end
107-
nthreads = prevpow(2, nthreads)
108-
m = FixedEffectLinearMapCUDA{T}(fes, nthreads)
105+
m = FixedEffectLinearMapCUDA{T}(fes)
109106
b = CUDA.zeros(T, length(weights))
110107
r = CUDA.zeros(T, length(weights))
111108
x = FixedEffectCoefficients([CUDA.zeros(T, fe.n) for fe in fes])
@@ -120,15 +117,16 @@ end
120117
function FixedEffects.update_weights!(feM::FixedEffectSolverCUDA{T}, weights::AbstractWeights) where {T}
121118
copyto!(feM.weights, _cu(T, weights))
122119
for (scale, fe) in zip(feM.m.scales, feM.m.fes)
123-
scale!(scale, fe.refs, fe.interaction, feM.weights, feM.m.nthreads)
120+
scale!(scale, fe.refs, fe.interaction, feM.weights)
124121
end
125122
for (cache, scale, fe) in zip(feM.m.caches, feM.m.scales, feM.m.fes)
126-
cache!(cache, fe.refs, fe.interaction, feM.weights, scale, feM.m.nthreads)
123+
cache!(cache, fe.refs, fe.interaction, feM.weights, scale)
127124
end
128125
return feM
129126
end
130127

131-
function scale!(scale::CuVector, refs::CuVector, interaction::CuVector, weights::CuVector, nthreads::Integer)
128+
function scale!(scale::CuVector, refs::CuVector, interaction::CuVector, weights::CuVector)
129+
nthreads = 256
132130
nblocks = cld(length(refs), nthreads)
133131
fill!(scale, 0)
134132
@cuda threads=nthreads blocks=nblocks scale_kernel!(scale, refs, interaction, weights)
@@ -145,7 +143,8 @@ function scale_kernel!(scale, refs, interaction, weights)
145143
end
146144
end
147145

148-
function cache!(cache::CuVector, refs::CuVector, interaction::CuVector, weights::CuVector, scale::CuVector, nthreads::Integer)
146+
function cache!(cache::CuVector, refs::CuVector, interaction::CuVector, weights::CuVector, scale::CuVector)
147+
nthreads = 256
149148
nblocks = cld(length(cache), nthreads)
150149
@cuda threads=nthreads blocks=nblocks cache!_kernel!(cache, refs, interaction, weights, scale)
151150
end
@@ -160,6 +159,15 @@ function cache!_kernel!(cache, refs, interaction, weights, scale)
160159
end
161160
end
162161

162+
function FixedEffects.copy_internal!(feM::FixedEffectSolverCUDA, field::Symbol, r::AbstractVector)
163+
copyto!(feM.tmp, r)
164+
copyto!(getfield(feM, field), feM.tmp)
165+
end
166+
167+
function FixedEffects.copy_internal!(r::AbstractVector, feM::FixedEffectSolverCUDA, field::Symbol)
168+
copyto!(feM.tmp, getfield(feM, field))
169+
copyto!(r, feM.tmp)
170+
end
163171

164172

165173
end

ext/MetalExt.jl

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module MetalExt
22
using FixedEffects, Metal
3-
using FixedEffects: FixedEffectCoefficients, AbstractWeights, UnitWeights, LinearAlgebra, Adjoint, mul!, rmul!, lsmr!, AbstractFixedEffectLinearMap
3+
using FixedEffects: FixedEffectCoefficients, AbstractWeights, UnitWeights, LinearAlgebra, Adjoint, mul!, rmul!, lsmr!, AbstractFixedEffectLinearMap, copy_internal!
44
Metal.allowscalar(false)
55

66
##############################################################################
@@ -35,50 +35,53 @@ mutable struct FixedEffectLinearMapMetal{T} <: AbstractFixedEffectLinearMap{T}
3535
fes::Vector{<:FixedEffect}
3636
scales::Vector{<:AbstractVector}
3737
caches::Vector
38-
nthreads::Int
3938
end
4039

4140
function bucketize_refs(refs::Vector, n::Int)
4241
# count the number of obs per group
43-
counts = zeros(Int, n)
44-
@inbounds for r in refs
45-
counts[r] += 1
46-
end
42+
counts = zeros(Int, n)
43+
@inbounds for r in refs
44+
counts[r] += 1
45+
end
4746
# offsets is vcat(1, cumsum(counts))
48-
offsets = Vector{Int}(undef, n + 1)
47+
offsets_mtl = Metal.@sync Metal.zeros(Int, n + 1; storage = Metal.SharedStorage)
48+
offsets = unsafe_wrap(Array{Int}, offsets_mtl, size(offsets_mtl))
4949
offsets[1] = 1
5050
@inbounds for k in 1:n
5151
offsets[k+1] = offsets[k] + counts[k]
5252
end
53+
54+
perm_mtl = Metal.@sync Metal.zeros(Int, length(refs); storage = Metal.SharedStorage)
55+
perm = unsafe_wrap(Array{Int}, perm_mtl, size(perm_mtl))
5356
next = offsets[1:n]
54-
perm = Vector{Int}(undef, length(refs))
5557
@inbounds for i in eachindex(refs)
5658
r = refs[i]
5759
p = next[r]
5860
perm[p] = i
5961
next[r] = p + 1
6062
end
61-
return perm, offsets
63+
return perm_mtl, offsets_mtl
6264
end
6365

64-
function FixedEffectLinearMapMetal{T}(fes::Vector{<:FixedEffect}, nthreads) where {T}
66+
function FixedEffectLinearMapMetal{T}(fes::Vector{<:FixedEffect}) where {T}
6567
fes2 = [_mtl(T, fe) for fe in fes]
6668
scales = [Metal.zeros(T, fe.n) for fe in fes]
67-
caches = [[Metal.zeros(T, length(fe.refs)), Metal.zeros(Int, 1), Metal.zeros(Int, 1)] for fe in fes]
69+
caches = [Any[Metal.zeros(T, length(fe.refs)), Metal.zeros(Int, 1), Metal.zeros(Int, 1)] for fe in fes]
6870
Threads.@threads for i in 1:length(fes)
6971
refs = fes[i].refs
7072
n = fes[i].n
7173
if n < min(100_000, div(length(refs), 16))
7274
out = bucketize_refs(refs, n)
73-
caches[i][2] = MtlArray(out[1])
74-
caches[i][3] = MtlArray(out[2])
75+
caches[i][2] = out[1]
76+
caches[i][3] = out[2]
7577
end
7678
end
77-
return FixedEffectLinearMapMetal{T}(fes2, scales, caches, nthreads)
79+
return FixedEffectLinearMapMetal{T}(fes2, scales, caches)
7880
end
7981

80-
function FixedEffects.gather!(fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache::Vector, nthreads::Integer)
82+
function FixedEffects.gather!(fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache::Vector)
8183
n = length(fecoef)
84+
nthreads = Int(device().maxThreadsPerThreadgroup.width)
8285
if n < min(100_000, div(length(refs), 16))
8386
Metal.@sync @metal threads=nthreads groups=n gather_kernel_bin!(fecoef, refs, α, y, cache[1], cache[2], cache[3], Val(nthreads))
8487
else
@@ -138,7 +141,8 @@ function gather_kernel!(fecoef, refs, α, y, cache)
138141
return nothing
139142
end
140143

141-
function FixedEffects.scatter!(y::MtlVector, α::Number, fecoef::MtlVector, refs::MtlVector, cache::Vector, nthreads::Integer)
144+
function FixedEffects.scatter!(y::MtlVector, α::Number, fecoef::MtlVector, refs::MtlVector, cache::Vector)
145+
nthreads = Int(device().maxThreadsPerThreadgroup.width)
142146
nblocks = cld(length(y), nthreads)
143147
Metal.@sync @metal threads=nthreads groups=nblocks scatter_kernel!(y, α, fecoef, refs, cache[1])
144148
end
@@ -168,40 +172,36 @@ mutable struct FixedEffectSolverMetal{T} <: FixedEffects.AbstractFixedEffectSolv
168172
v::FixedEffectCoefficients{<: AbstractVector{T}}
169173
h::FixedEffectCoefficients{<: AbstractVector{T}}
170174
hbar::FixedEffectCoefficients{<: AbstractVector{T}}
171-
tmp::Vector{T} # used to convert AbstractVector to Vector{T}
172175
fes::Vector{<:FixedEffect}
173176
end
177+
174178

175179
function FixedEffects.AbstractFixedEffectSolver{T}(fes::Vector{<:FixedEffect}, weights::AbstractWeights, ::Type{Val{:Metal}}, nthreads = nothing) where {T}
176-
if nthreads === nothing
177-
nthreads = Int(device().maxThreadsPerThreadgroup.width)
178-
end
179-
nthreads = prevpow(2, nthreads)
180-
m = FixedEffectLinearMapMetal{T}(fes, nthreads)
181-
b = Metal.zeros(T, length(weights))
182-
r = Metal.zeros(T, length(weights))
180+
m = FixedEffectLinearMapMetal{T}(fes)
181+
b = Metal.zeros(T, length(weights); storage = Metal.SharedStorage)
182+
r = Metal.zeros(T, length(weights); storage = Metal.SharedStorage)
183183
x = FixedEffectCoefficients([Metal.zeros(T, fe.n) for fe in fes])
184184
v = FixedEffectCoefficients([Metal.zeros(T, fe.n) for fe in fes])
185185
h = FixedEffectCoefficients([Metal.zeros(T, fe.n) for fe in fes])
186186
hbar = FixedEffectCoefficients([Metal.zeros(T, fe.n) for fe in fes])
187-
tmp = zeros(T, length(weights))
188-
feM = FixedEffectSolverMetal{T}(m, Metal.zeros(T, length(weights)), b, r, x, v, h, hbar, tmp, fes)
187+
feM = FixedEffectSolverMetal{T}(m, Metal.zeros(T, length(weights)), b, r, x, v, h, hbar, fes)
189188
FixedEffects.update_weights!(feM, weights)
190189
end
191190

192191

193192
function FixedEffects.update_weights!(feM::FixedEffectSolverMetal{T}, weights::AbstractWeights) where {T}
194193
copyto!(feM.weights, _mtl(T, weights))
195194
for (scale, fe) in zip(feM.m.scales, feM.m.fes)
196-
scale!(scale, fe.refs, fe.interaction, feM.weights, feM.m.nthreads)
195+
scale!(scale, fe.refs, fe.interaction, feM.weights)
197196
end
198197
for (cache, scale, fe) in zip(feM.m.caches, feM.m.scales, feM.m.fes)
199-
cache!(cache, fe.refs, fe.interaction, feM.weights, scale, feM.m.nthreads)
198+
cache!(cache, fe.refs, fe.interaction, feM.weights, scale)
200199
end
201200
return feM
202201
end
203202

204-
function scale!(scale::MtlVector, refs::MtlVector, interaction::MtlVector, weights::MtlVector, nthreads::Integer)
203+
function scale!(scale::MtlVector, refs::MtlVector, interaction::MtlVector, weights::MtlVector)
204+
nthreads = Int(device().maxThreadsPerThreadgroup.width)
205205
nblocks = cld(length(refs), nthreads)
206206
fill!(scale, 0)
207207
Metal.@sync @metal threads=nthreads groups=nblocks scale_kernel!(scale, refs, interaction, weights)
@@ -224,7 +224,8 @@ function inv_kernel!(scale, T)
224224
return nothing
225225
end
226226

227-
function cache!(cache, refs::MtlVector, interaction::MtlVector, weights::MtlVector, scale::MtlVector, nthreads::Integer)
227+
function cache!(cache, refs::MtlVector, interaction::MtlVector, weights::MtlVector, scale::MtlVector)
228+
nthreads = Int(device().maxThreadsPerThreadgroup.width)
228229
nblocks = cld(length(cache[1]), nthreads)
229230
Metal.@sync @metal threads=nthreads groups=nblocks cache!_kernel!(cache[1], refs, interaction, weights, scale)
230231
end
@@ -237,5 +238,17 @@ function cache!_kernel!(cache, refs, interaction, weights, scale)
237238
return nothing
238239
end
239240

241+
function FixedEffects.copy_internal!(feM::FixedEffectSolverMetal{T}, field::Symbol, r::AbstractVector) where {T}
242+
synchronize()
243+
feM_r = unsafe_wrap(Array{T}, getfield(feM, field), size(getfield(feM, field)))
244+
copyto!(feM_r, r)
245+
end
246+
247+
function FixedEffects.copy_internal!(r::AbstractVector, feM::FixedEffectSolverMetal{T}, field::Symbol) where {T}
248+
synchronize()
249+
feM_r = unsafe_wrap(Array{T}, getfield(feM, field), size(getfield(feM, field)))
250+
copyto!(r, feM_r)
251+
end
252+
240253

241254
end

src/AbstractFixedEffectLinearMap.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function LinearAlgebra.mul!(fecoefs::FixedEffectCoefficients,
2929
fem = adjoint(Cfem)
3030
rmul!(fecoefs, β)
3131
for (fecoef, fe, cache) in zip(fecoefs.x, fem.fes, fem.caches)
32-
gather!(fecoef, fe.refs, α, y, cache, fem.nthreads)
32+
gather!(fecoef, fe.refs, α, y, cache)
3333
end
3434
return fecoefs
3535
end
@@ -38,7 +38,7 @@ function LinearAlgebra.mul!(y::AbstractVector, fem::AbstractFixedEffectLinearMap
3838
fecoefs::FixedEffectCoefficients, α::Number, β::Number)
3939
rmul!(y, β)
4040
for (fecoef, fe, cache) in zip(fecoefs.x, fem.fes, fem.caches)
41-
scatter!(y, α, fecoef, fe.refs, cache, fem.nthreads)
41+
scatter!(y, α, fecoef, fe.refs, cache)
4242
end
4343
return y
4444
end

src/AbstractFixedEffectSolver.jl

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
##
77
##############################################################################
88
abstract type AbstractFixedEffectSolver{T} end
9-
works_with_view(::AbstractFixedEffectSolver) = false
109

1110
"""
1211
`solve_residuals!(y, fes, w; method = :cpu, double_precision = method == :cpu, tol = 1e-8, maxiter = 10000)`
@@ -43,22 +42,17 @@ function solve_residuals!(y::Union{AbstractVector{<: Real}, AbstractMatrix{<: Re
4342
nthreads = nothing)
4443
any((length(fe) != size(y, 1) for fe in fes)) && throw("FixedEffects must have the same length as y")
4544
any(ismissing.(fes)) && throw("FixedEffects must not have missing values")
46-
feM = AbstractFixedEffectSolver{double_precision ? Float64 : Float32}(fes, w, Val{method}, nthreads)
45+
feM = AbstractFixedEffectSolver{double_precision ? Float64 : Float32}(fes, w, Val{method})
4746
solve_residuals!(y, feM; maxiter = maxiter, tol = tol)
4847
end
4948

5049

5150

5251
function solve_residuals!(r::AbstractVector{<:Real}, feM::AbstractFixedEffectSolver{T}; tol::Real = sqrt(eps(T)), maxiter::Integer = 100_000) where {T}
5352
# One cannot copy view of Vector (r) on GPU, so first collect the vector
54-
if works_with_view(feM)
55-
copyto!(feM.r, r)
56-
else
57-
copyto!(feM.tmp, r)
58-
copyto!(feM.r, feM.tmp)
59-
end
53+
copy_internal!(feM, :r, r)
6054
if !(feM.weights isa UnitWeights)
61-
feM.r .*= sqrt.(feM.weights)
55+
feM.r .*= sqrt.(feM.weights)
6256
end
6357
copyto!(feM.b, feM.r)
6458
mul!(feM.x, feM.m', feM.b, 1, 0)
@@ -71,12 +65,7 @@ function solve_residuals!(r::AbstractVector{<:Real}, feM::AbstractFixedEffectSol
7165
if !(feM.weights isa UnitWeights)
7266
feM.r ./= sqrt.(feM.weights)
7367
end
74-
if works_with_view(feM)
75-
copyto!(r, feM.r)
76-
else
77-
copyto!(feM.tmp, feM.r)
78-
copyto!(r, feM.tmp)
79-
end
68+
copy_internal!(r, feM, :r)
8069
return r, iter, converged
8170
end
8271
@@ -160,18 +149,13 @@ function solve_coefficients!(y::AbstractVector{<: Number}, fes::AbstractVector{<
160149
nthreads = nothing)
161150
any(ismissing.(fes)) && throw("Some FixedEffect has a missing value for reference or interaction")
162151
any((length(fe) != length(y) for fe in fes)) && throw("FixedEffects must have the same length as y")
163-
feM = AbstractFixedEffectSolver{double_precision ? Float64 : Float32}(fes, w, Val{method}, nthreads)
152+
feM = AbstractFixedEffectSolver{double_precision ? Float64 : Float32}(fes, w, Val{method})
164153
solve_coefficients!(y, feM; maxiter = maxiter, tol = tol)
165154
end
166155
167156
function FixedEffects.solve_coefficients!(r::AbstractVector, feM::AbstractFixedEffectSolver{T}; tol::Real = sqrt(eps(T)), maxiter::Integer = 100_000) where {T}
168157
# One cannot copy view of Vector (r) on GPU, so first collect the vector
169-
if works_with_view(feM)
170-
copyto!(feM.b, r)
171-
else
172-
copyto!(feM.tmp, r)
173-
copyto!(feM.b, feM.tmp)
174-
end
158+
copy_internal!(feM, :b, r)
175159
if !(feM.weights isa UnitWeights)
176160
feM.b .*= sqrt.(feM.weights)
177161
end

0 commit comments

Comments
 (0)