
Improved Tiling for Matrix-Vector Products #131

@BrendonChau

For stochastic trace estimation, we need to evaluate terms like $\mathbf{G}^\top \mathbf{x}$ and $\mathbf{G}\mathbf{x}$ efficiently, where $\mathbf{G}$ is a SnpArray. While looking into SnpLinAlg, I noticed a lot of memory allocations. These are presumably caused by Julia's `Base.@threads` and by pointer allocations in the tiling procedure. It turns out that multithreading only within each tile evaluation, rather than over several tiles, avoids these allocations. For $\mathbf{Gx}$ it also yields a substantial performance improvement.
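Both products appear in stochastic trace estimation. As an illustration, a Hutchinson estimator with Rademacher probes (±1 entries, like the `rand(rng, (-1.0, 1.0), m)` vectors used in the benchmarks below) can estimate $\operatorname{tr}\big((\mathbf{G}\mathbf{G}^\top)^2\big)$. Each probe needs one $\mathbf{G}^\top\mathbf{z}$ and one $\mathbf{G}\mathbf{w}$, since $\mathbf{z}^\top(\mathbf{G}\mathbf{G}^\top)^2\mathbf{z} = \lVert \mathbf{G}\mathbf{G}^\top\mathbf{z}\rVert^2$.

This is a minimal sketch under stated assumptions. The particular trace is chosen only for illustration, and the function name `hutchinson_trace_AAt2` is hypothetical. It uses a dense matrix with `LinearAlgebra.mul!`; for a SnpArray, the two `mul!` calls are where kernels like the ones below would go.

```julia
using LinearAlgebra, Random

# Hutchinson estimator of tr((A*A')^2) with Rademacher probes z of length m.
# Each probe costs one transposed product w = A'z and one product v = A*w,
# because z'(AA')^2 z = ‖A*A'*z‖².
function hutchinson_trace_AAt2(A::AbstractMatrix{T}, nprobes::Integer;
                               rng::AbstractRNG = Random.default_rng()) where {T<:AbstractFloat}
    m, d = size(A)
    z = Vector{T}(undef, m)  # Rademacher probe (entries ±1)
    w = Vector{T}(undef, d)  # holds A'z
    v = Vector{T}(undef, m)  # holds A*(A'z)
    acc = zero(T)
    for _ in 1:nprobes
        for i in eachindex(z)
            z[i] = rand(rng, Bool) ? one(T) : -one(T)
        end
        mul!(w, transpose(A), z)  # G'x step (e.g. Gt_mul_x! for a SnpArray)
        mul!(v, A, w)             # Gx step (e.g. G_mul_x! for a SnpArray)
        acc += dot(v, v)
    end
    return acc / nprobes
end

A = [2.0 0.0 0.0; 0.0 3.0 0.0; 0.0 0.0 1.0]
est = hutchinson_trace_AAt2(A, 10)  # → 98.0 (exact here, because A*A' is diagonal)
```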

First, the code defining the functions `Gt_mul_x!`, `Gt_mul_x_tiled!`, `G_mul_x!`, and `G_mul_x_tiled!`, plus the per-tile helpers `Gt_mul_x_chunk!` and `G_mul_x_chunk!`.
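Written out, all four functions multiply by the standardized genotype matrix (called $\mathbf{G}$ in the text, $\mathbf{Z}$ here) rather than by the raw genotypes. This is my reading of the kernels, assuming `ADDITIVE_MODEL` with missing genotypes imputed by the column mean, which is what the branch-free expression implements:

$$
z_{ik} = \begin{cases} (g_{ik} - \mu_k)\,\sigma_k^{-1}, & g_{ik} \text{ observed},\\ 0, & g_{ik} \text{ missing},\end{cases}
\qquad \mu_k = 2\rho_k, \quad \sigma_k = \sqrt{2\rho_k(1-\rho_k)},
$$

$$
(\mathbf{Z}^\top \mathbf{x})_k = \sigma_k^{-1} \sum_{i=1}^{m} \tilde g_{ik}\, x_i,
\qquad
(\mathbf{Z}\mathbf{x})_i = \sum_{k=1}^{d} \tilde g_{ik}\, \sigma_k^{-1}\, x_k,
$$

Here $\tilde g_{ik} = g_{ik} - \mu_k$ for observed genotypes and $0$ for missing ones.

In $\mathbf{Z}^\top \mathbf{x}$ the factor $\sigma_k^{-1}$ depends only on the output index, so `Gt_mul_x!` can apply `σinv` once at the end. In $\mathbf{Z}\mathbf{x}$ it multiplies $x_k$ inside the sum, which is why `G_mul_x!` folds `σinv[k]` into the inner loop.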

using StatsBase, Statistics, Plots, LinearAlgebra, BenchmarkTools, Random
using SnpArrays, DelimitedFiles, Glob, StaticArrays
using Plots, VectorizationBase, LoopVectorization, Polyester

function Gt_mul_x!(
    y::AbstractVector{T},
    G::SnpArray,
    x::AbstractVector{Tx},
    μ::AbstractVector{T},
    σinv::AbstractVector{T}
    ) where {T<:AbstractFloat, Tx<:Union{Signed, AbstractFloat}}
    m, d = size(G)

    K, rem = divrem(m, 4)
    fill!(y, zero(T))
    # Parallelize over the columns of G
    @turbo thread=true for k in axes(G, 2)
        yk = zero(T)
        for j in 1:K
            for i in 1:4
                # index into the correct value in compressed SNP array
                Gijk = (G.data[j, k] >> ((i - 1) << 1)) & 0x03
                # index into the corresponding value in x
                xval = x[(j - 1) << 2 + i]
                # assume ADDITIVE_MODEL is being used
                # yk += xval * ifelse(isone(Gijk), zero(T), T((Gijk >= 0x02) * (Gijk - 0x01)) - μ[k])
                # Branch-free version
                yk += xval * (T((Gijk >= 0x02) * (Gijk - 0x01)) - !isone(Gijk) * μ[k])
            end
        end
        y[k] = yk
    end

    if rem > 0
        @turbo thread=true for k in axes(G, 2)
            yk = zero(T)
            for i in 1:rem
                Gijk = (G.data[K + 1, k] >> ((i - 1) << 1)) & 0x03
                xval = x[K << 2 + i]
                # yk += xval * ifelse(isone(Gijk), zero(T), T((Gijk >= 0x02) * (Gijk - 0x01)) - μ[k])
                # Branch-free version
                yk += xval * (T((Gijk >= 0x02) * (Gijk - 0x01)) - !isone(Gijk) * μ[k])
            end
            y[k] += yk
        end
    end

    @turbo for j in axes(G, 2)
        y[j] = y[j] * σinv[j]
    end
    return y
end

function Gt_mul_x_chunk!(
    y::AbstractVector{T},
    G::AbstractMatrix{UInt8},
    x::AbstractVector{Tx},
    μ::AbstractVector{T}
    ) where {T<:AbstractFloat, Tx<:Union{Signed, AbstractFloat}}
    @turbo thread=true for k in axes(G, 2)
        yk = zero(T)
        for j in axes(G, 1)
            for i in 1:4
                # index into the correct value in compressed SNP array
                Gijk = (G[j, k] >> ((i - 1) << 1)) & 0x03
                # index into the corresponding value in x
                xval = x[(j - 1) << 2 + i]
                # assume ADDITIVE_MODEL is being used
                # yk += xval * ifelse(isone(Gijk), zero(T), T((Gijk >= 0x02) * (Gijk - 0x01)) - μ[k])
                # Branch-free version
                # * almost identical performance to ifelse
                # * (Gijk >= 0x02) * (Gijk - 0x01) should compile to 
                #   saturating arithmetic, which has a built-in LLVM function call
                yk += xval * (T((Gijk >= 0x02) * (Gijk - 0x01)) - !isone(Gijk) * μ[k])
            end
        end
        y[k] += yk
    end
end

function Gt_mul_x_tiled!(
    y::AbstractVector{T},
    G::SnpArray,
    x::AbstractVector{Tx},
    μ::AbstractVector{T},
    σinv::AbstractVector{T};
    chunkwidth::Int = 1024
    ) where {T<:AbstractFloat, Tx<:Union{Signed, AbstractFloat}}
    m, d = size(G)
    mpacked = size(G.data, 1)
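    # each packed byte holds 4 rows, so row tiles must align with byte boundaries
    chunkwidth % 4 == 0 || throw(ArgumentError("chunkwidth must be a multiple of 4"))
    packedchunkwidth = chunkwidth >> 2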

    # need K and rem to handle remainder rows
    K, rem = divrem(m, 4)

    # Slice up row iteration space
    nrowchunks, rowchunkrem = divrem(m, chunkwidth)
    # Slice up col iteration space
    ncolchunks, colchunkrem = divrem(d, chunkwidth)
    # Initialize output vector to zero, 
    # * could be modified to agree with 5-argument mul!
    fill!(y, zero(T))
    @inbounds for col in 1:ncolchunks
        colstart = (col - 1) * chunkwidth
        _y = view(y, (colstart + 1):(col * chunkwidth))
        _μ = view(μ, (colstart + 1):(col * chunkwidth))
        for l in 1:nrowchunks
            rowstart       = (l - 1) * chunkwidth
            packedrowstart = (l - 1) * packedchunkwidth
            _x      = view(x, (rowstart + 1):(l * chunkwidth))
            _G_data = view(G.data, (packedrowstart + 1):(packedrowstart + packedchunkwidth), (colstart + 1):(col * chunkwidth))
            Gt_mul_x_chunk!(_y, _G_data, _x, _μ)
        end
        if rowchunkrem > 0
            rowstart       = nrowchunks * chunkwidth
            packedrowstart = nrowchunks * packedchunkwidth
            # If `m` is a multiple of 4, can just load the entire remainder view
            if rem == 0
                _x      = view(x, (rowstart + 1):m)
                _G_data = view(G.data, (packedrowstart + 1):mpacked, (colstart + 1):(col * chunkwidth))
                Gt_mul_x_chunk!(_y, _G_data, _x, _μ)
            else
                # Taking a view excluding the remainder rows of x and G
                _x      = view(x, (rowstart + 1):(K << 2))
                _G_data = view(G.data, (packedrowstart + 1):K, (colstart + 1):(col * chunkwidth))
                Gt_mul_x_chunk!(_y, _G_data, _x, _μ)
                # Handling remainder rows
                @turbo thread=true for k in 1:chunkwidth
                    yk = zero(T)
                    for i in 1:rem
                        Gijk = (G.data[K + 1, colstart + k] >> ((i - 1) << 1)) & 0x03
                        xval = x[K << 2 + i]
                        yk += xval * (T((Gijk >= 0x02) * (Gijk - 0x01)) - !isone(Gijk) * _μ[k])
                    end
                    _y[k] += yk
                end
            end
        end
    end
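    # Handling remainder columns when `d` is not a multiple of `chunkwidth`
    @inbounds if colchunkrem > 0
        colstart = ncolchunks * chunkwidth
        _y = view(y, (colstart + 1):d)
        _μ = view(μ, (colstart + 1):d)
        for l in 1:nrowchunks
            rowstart       = (l - 1) * chunkwidth
            packedrowstart = (l - 1) * packedchunkwidth
            _x      = view(x, (rowstart + 1):(l * chunkwidth))
            _G_data = view(G.data, (packedrowstart + 1):(packedrowstart + packedchunkwidth), (colstart + 1):d)
            Gt_mul_x_chunk!(_y, _G_data, _x, _μ)
        end
        if rowchunkrem > 0
            rowstart       = nrowchunks * chunkwidth
            packedrowstart = nrowchunks * packedchunkwidth
            # Full bytes only; when `m` is a multiple of 4 this is the whole remainder
            _x      = view(x, (rowstart + 1):(K << 2))
            _G_data = view(G.data, (packedrowstart + 1):K, (colstart + 1):d)
            Gt_mul_x_chunk!(_y, _G_data, _x, _μ)
            # Handling remainder rows (the partially filled last byte)
            if rem > 0
                @turbo thread=true for k in eachindex(_y)
                    yk = zero(T)
                    for i in 1:rem
                        Gijk = (G.data[K + 1, colstart + k] >> ((i - 1) << 1)) & 0x03
                        xval = x[K << 2 + i]
                        yk += xval * (T((Gijk >= 0x02) * (Gijk - 0x01)) - !isone(Gijk) * _μ[k])
                    end
                    _y[k] += yk
                end
            end
        end
    end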
    @turbo for j in axes(G, 2)
        y[j] = y[j] * σinv[j]
    end
    return y
end

function G_mul_x!(
    y::AbstractVector{T},
    G::SnpArray,
    x::AbstractVector{Tx},
    μ::AbstractVector{T},
    σinv::AbstractVector{T}
    ) where {T<:AbstractFloat, Tx<:Union{Signed, AbstractFloat}}
    m, d = size(G)

    K, rem = divrem(m, 4)
    fill!(y, zero(T))
    # Parallelize over the columns of G
    @turbo thread=true for k in axes(G, 2)
        for j in 1:K
            for i in 1:4
                # index into the correct value in compressed SNP array
                Gijk = (G.data[j, k] >> ((i - 1) << 1)) & 0x03
                # assume ADDITIVE_MODEL is being used
                # y[(j - 1) << 2 + i] += x[k] * ifelse(isone(Gijk), zero(T), T((Gijk >= 0x02) * (Gijk - 0x01)) - μ[k]) * σinv[k]
                y[(j - 1) << 2 + i] += x[k] * (T((Gijk >= 0x02) * (Gijk - 0x01)) - !isone(Gijk) * μ[k]) * σinv[k]
            end
        end
    end

    if rem > 0
        @turbo thread=true for k in axes(G, 2)
            for i in 1:rem
                Gijk = (G.data[K + 1, k] >> ((i - 1) << 1)) & 0x03
                # y[K << 2 + i] += x[k] * ifelse(isone(Gijk), zero(T), T((Gijk >= 0x02) * (Gijk - 0x01)) - μ[k]) * σinv[k]
                y[K << 2 + i] += x[k] * (T((Gijk >= 0x02) * (Gijk - 0x01)) - !isone(Gijk) * μ[k]) * σinv[k]
            end
        end
    end
    return y
end

function G_mul_x_chunk!(
    y::AbstractVector{T},
    G::AbstractMatrix{UInt8},
    x::AbstractVector{Tx},
    μ::AbstractVector{T},
    σinv::AbstractVector{T}
    ) where {T<:AbstractFloat, Tx<:Union{Signed, AbstractFloat}}
    @turbo thread=true for k in axes(G, 2)
        for j in axes(G, 1)
            for i in 1:4
                # index into the correct value in compressed SNP array
                Gijk = (G[j, k] >> ((i - 1) << 1)) & 0x03
                y[((j - 1) << 2) + i] += x[k] * (T((Gijk >= 0x02) * (Gijk - 0x01)) - !isone(Gijk) * μ[k]) * σinv[k]
            end
        end
    end
end

function G_mul_x_tiled!(
    y::AbstractVector{T},
    G::SnpArray,
    x::AbstractVector{Tx},
    μ::AbstractVector{T},
    σinv::AbstractVector{T};
    chunkwidth::Int = 1024
    ) where {T<:AbstractFloat, Tx<:Union{Signed, AbstractFloat}}
    m, d = size(G)
    mpacked = size(G.data, 1)
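    # each packed byte holds 4 rows, so row tiles must align with byte boundaries
    chunkwidth % 4 == 0 || throw(ArgumentError("chunkwidth must be a multiple of 4"))
    packedchunkwidth = chunkwidth >> 2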

    # need K and rem to handle remainder rows
    K, rem = divrem(m, 4)

    # Slice up row iteration space
    nrowchunks, rowchunkrem = divrem(m, chunkwidth)
    # Slice up col iteration space
    ncolchunks, colchunkrem = divrem(d, chunkwidth)
    # Initialize output vector to zero, 
    # * could be modified to agree with 5-argument mul!
    fill!(y, zero(T))
    # multithreading only on inner tile loop avoids pointer allocations using @turbo thread=true
    # a more sophisticated way might use buffers and multi-level parallelism
    @inbounds for col in 1:ncolchunks
        colstart = (col - 1) * chunkwidth
        _x    = view(x,    (colstart + 1):(col * chunkwidth))
        _μ    = view(μ,    (colstart + 1):(col * chunkwidth))
        _σinv = view(σinv, (colstart + 1):(col * chunkwidth))
        for l in 1:nrowchunks
            rowstart       = (l - 1) * chunkwidth
            packedrowstart = (l - 1) * packedchunkwidth
            _y      = view(y, (rowstart + 1):(l * chunkwidth))
            _G_data = view(G.data, (packedrowstart + 1):(packedrowstart + packedchunkwidth), (colstart + 1):(col * chunkwidth))
            G_mul_x_chunk!(_y, _G_data, _x, _μ, _σinv)
        end
        if rowchunkrem > 0
            rowstart       = nrowchunks * chunkwidth
            packedrowstart = nrowchunks * packedchunkwidth
            # If `m` is a multiple of 4, can just load the entire remainder view
            if rem == 0
                _y      = view(y, (rowstart + 1):m)
                _G_data = view(G.data, (packedrowstart + 1):mpacked, (colstart + 1):(col * chunkwidth))
                G_mul_x_chunk!(_y, _G_data, _x, _μ, _σinv)
            else
                # TODO: pretty ugly, is there a better/less verbose way of handling remainders?
                # Taking a view excluding the remainder rows of y and G
                _y      = view(y, (rowstart + 1):(K << 2))
                _G_data = view(G.data, (packedrowstart + 1):K, (colstart + 1):(col * chunkwidth))
                G_mul_x_chunk!(_y, _G_data, _x, _μ, _σinv)
                # Handling remainder rows
                @turbo thread=true for k in 1:chunkwidth
                    for i in 1:rem
                        # index into the correct value in compressed SNP array
                        Gijk = (G.data[K + 1, colstart + k] >> ((i - 1) << 1)) & 0x03
                        y[K << 2 + i] += _x[k] * (T((Gijk >= 0x02) * (Gijk - 0x01)) - !isone(Gijk) * _μ[k]) * _σinv[k]
                    end
                end
            end
        end
    end
    # Handling remainder columns when `d` is not a multiple of `chunkwidth`
    @inbounds if colchunkrem > 0
        colstart = ncolchunks * chunkwidth
        _x    = view(x,    (colstart + 1):d)
        _μ    = view(μ,    (colstart + 1):d)
        _σinv = view(σinv, (colstart + 1):d)
        for l in 1:nrowchunks
            packedrowstart = (l - 1) * packedchunkwidth
            _y      = view(y, ((l - 1) * chunkwidth + 1):(l * chunkwidth))
            _G_data = view(G.data, (packedrowstart + 1):(packedrowstart + packedchunkwidth), (colstart + 1):d)
            G_mul_x_chunk!(_y, _G_data, _x, _μ, _σinv)
        end
        if rowchunkrem > 0
            packedrowstart = nrowchunks * packedchunkwidth
            rowstart       = nrowchunks * chunkwidth
            if rem == 0
                _y      = view(y, (rowstart + 1):m)
                _G_data = view(G.data, (packedrowstart + 1):mpacked, (colstart + 1):d)
                G_mul_x_chunk!(_y, _G_data, _x, _μ, _σinv)
            else
                _y      = view(y, (rowstart + 1):(K << 2))
                _G_data = view(G.data, (packedrowstart + 1):K, (colstart + 1):d)
                G_mul_x_chunk!(_y, _G_data, _x, _μ, _σinv)
                # Handling remainder rows
                @turbo thread=true for k in eachindex(_x)
                    for i in 1:rem
                        # index into the correct value in compressed SNP array
                        Gijk = (G.data[K + 1, colstart + k] >> ((i - 1) << 1)) & 0x03
                        y[K << 2 + i] += _x[k] * (T((Gijk >= 0x02) * (Gijk - 0x01)) - !isone(Gijk) * _μ[k]) * _σinv[k]
                    end
                end
            end
        end
    end
    return y
end
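A genotype code takes only four values, so the branch-free expression shared by all four kernels can be checked exhaustively against the commented-out `ifelse` form. The sketch below is plain Julia with no SnpArrays or LoopVectorization dependency. It assumes the SnpArrays 2-bit encoding under `ADDITIVE_MODEL` (0x00 → 0, 0x01 → missing, 0x02 → 1, 0x03 → 2). The helper names `centered_branchfree`, `centered_ifelse`, and `unpack_centered` are hypothetical.

`unpack_centered` decodes one packed column with the same shift-and-mask indexing as the kernels. That makes it a slow reference for checking the tiled results on small inputs.

```julia
# Assumed encoding (SnpArrays, ADDITIVE_MODEL): 0x00 => 0, 0x01 => missing, 0x02 => 1, 0x03 => 2.

# Branch-free centered value used in the kernels (missing => 0, i.e. mean imputation)
centered_branchfree(g::UInt8, μ::T) where {T<:AbstractFloat} =
    T((g >= 0x02) * (g - 0x01)) - !isone(g) * μ

# The commented-out `ifelse` form from the kernels
centered_ifelse(g::UInt8, μ::T) where {T<:AbstractFloat} =
    ifelse(isone(g), zero(T), T((g >= 0x02) * (g - 0x01)) - μ)

# Decode column k of a packed byte matrix into m centered values.
# Row r lives in byte j = (r - 1) >> 2 + 1 at slot i = (r - 1) & 3 + 1,
# the inverse of the kernels' row index (j - 1) << 2 + i.
function unpack_centered(data::AbstractMatrix{UInt8}, m::Integer, k::Integer, μk::T) where {T<:AbstractFloat}
    out = Vector{T}(undef, m)
    for r in 1:m
        j, i = ((r - 1) >> 2) + 1, ((r - 1) & 3) + 1
        g = (data[j, k] >> ((i - 1) << 1)) & 0x03
        out[r] = centered_branchfree(g, μk)
    end
    return out
end

# 0x93 packs codes 0x03, 0x00, 0x01, 0x02 (rows 1-4, lowest bits first); 0x02 is row 5
unpack_centered(reshape(UInt8[0x93, 0x02], 2, 1), 5, 1, 1.0)  # → [1.0, -1.0, 0.0, 0.0, 0.0]
```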

Next, setting up the data for the simulation:

EUR = SnpArray(SnpArrays.datadir("EUR_subset.bed"))
EUR_5 = [EUR; EUR; EUR; EUR; EUR]
EUR_5_5 = [EUR_5 EUR_5 EUR_5 EUR_5 EUR_5]

G = EUR_5_5;
n, q = size(G)

# Second row of `G_counts` is the number of missing genotypes (code 0x01) per column
G_counts     = counts(G, dims = 1)
n_nonmissing = n .- G_counts[2, :]

# Allele Frequencies
ρ = (G_counts[3, :] + 2 * G_counts[4, :]) ./ (2 * n_nonmissing)
# Expected Values
μ = 2 * ρ
# Standard Deviations
σ = sqrt.(2 .* (1 .- ρ) .* ρ)
σinv = inv.(σ)

# Truncate to 2^18 columns so that `d` is a multiple of `chunkwidth`
m = n # 1895
d = 1 << 18 # 262_144

G_ = SnpArray(undef, m, d);
copyto!(G_.data, view(G.data, :, 1:d));

ρ_ = ρ[1:d]
μ_ = μ[1:d]
σinv_ = σinv[1:d]
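The standardization above can be traced on a single toy column. The sketch below mirrors the `counts`-based formulas for one column of codes, assuming the same 2-bit encoding; `snp_moments` is a hypothetical helper. It returns new names so that it does not overwrite `ρ`, `μ`, `σ` from the setup. If a column is monomorphic (ρ equal to 0 or 1), σ is 0 and the corresponding `σinv` entry becomes `Inf`.

```julia
# Allele frequency, mean and sd for one column of 2-bit codes
# (assumed encoding: 0x00, 0x01 = missing, 0x02, 0x03).
function snp_moments(codes::AbstractVector{UInt8})
    c = [count(==(code), codes) for code in 0x00:0x03]  # like one column of counts(G, dims = 1)
    n_nonmissing = length(codes) - c[2]                 # row 2 counts missing genotypes
    ρ = (c[3] + 2 * c[4]) / (2 * n_nonmissing)          # allele frequency
    μ = 2ρ                                              # expected additive dosage
    σ = sqrt(2 * (1 - ρ) * ρ)                           # sd of a Binomial(2, ρ) dosage
    return ρ, μ, σ
end

# observed dosages are 0, 1, 2, 1, 0 (one missing), whose mean 0.8 equals μ
ρ1, μ1, σ1 = snp_moments(UInt8[0x00, 0x02, 0x03, 0x01, 0x02, 0x00])
```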

First, we can look at $\mathbf{G}^\top \mathbf{x}$,

rng = MersenneTwister(1234)

xvec = rand(rng, (-1.0, 1.0), m);
resq = Vector{Float64}(undef, d);
@benchmark Gt_mul_x!($resq, $G_, $xvec, $μ_, $σinv_) evals=1

BenchmarkTools.Trial: 120 samples with 1 evaluation.
 Range (min … max):  38.016 ms … 117.112 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     39.044 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   41.954 ms ±  10.720 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ██▂                                                           
  ████▇▅▁▅▁▅▅▆▁▅▁▁▁▁▁▅▁▁▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▁▁▁▁▁▁▁▁▁▁▅▁▁▁▁▁▁▅ ▅
  38 ms         Histogram: log(frequency) by time      91.6 ms <

 Memory estimate: 0 bytes, allocs estimate: 0.

resq2 = similar(resq);
@benchmark Gt_mul_x_tiled!($resq2, $G_, $xvec, $μ_, $σinv_) evals=1

BenchmarkTools.Trial: 109 samples with 1 evaluation.
 Range (min … max):  40.673 ms … 58.873 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     45.255 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   45.881 ms ±  4.235 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▂█                                                           
  ██▆▇▅▁▃▃▇▅▆██▇▆█▆▇▆█▇▆▃▇▃▅▃▃▃▁▁▃▃▆▁▃▃▁▃▁▁▁▁▃▃▃▃▁▃▁▁▁▁▁▃▃▁▁▃ ▃
  40.7 ms         Histogram: frequency by time        58.6 ms <

 Memory estimate: 0 bytes, allocs estimate: 0.

then $\mathbf{Gx}$,

xvecq = rand(rng, (-1.0, 1.0), d);
resn = Vector{Float64}(undef, m);
@benchmark G_mul_x!($resn, $G_, $xvecq, $μ_, $σinv_) evals=1

BenchmarkTools.Trial: 16 samples with 1 evaluation.
 Range (min … max):  322.233 ms … 391.568 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     327.874 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   332.096 ms ±  16.457 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▃ █▃   ▃█                                                      
  █▇██▁▁▇██▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁
  322 ms           Histogram: frequency by time          392 ms <

 Memory estimate: 0 bytes, allocs estimate: 0.

resn2 = similar(resn);
@benchmark G_mul_x_tiled!($resn2, $G_, $xvecq, $μ_, $σinv_; chunkwidth=1024) evals=1

BenchmarkTools.Trial: 85 samples with 1 evaluation.
 Range (min … max):  54.705 ms … 72.373 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     58.710 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   59.256 ms ±  4.076 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █ ▂▄                    ▃                                    
  █▃██▇▅▁▃▃▁▃▅▁▅▅▅▃▅▃▆▆▁▅▃█▃▅▅▃▃▇▃▃▁▁▃▁▁▃▃▁▁▁▁▁▃▁▁▁▁▃▁▃▁▁▁▁▁▅ ▁
  54.7 ms         Histogram: frequency by time        70.2 ms <

 Memory estimate: 0 bytes, allocs estimate: 0.

Finally, the same operations using SnpLinAlg:

Gsla_ = SnpLinAlg{Float64}(G_; model=ADDITIVE_MODEL, center=true, scale=true, impute=true);

resq_sla = similar(resq);
@benchmark mul!($resq_sla, transpose($Gsla_), $xvec) evals=1

BenchmarkTools.Trial: 119 samples with 1 evaluation.
 Range (min … max):  40.871 ms …  44.785 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     42.233 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   42.141 ms ± 589.852 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                           ▃▅█▅                                 
  ▃▁▁▃▃▁▃▄▇▄▄▄▄▄▅▄▅▇▅▄▁▅▅▅▆████▇▅▃▃▁▄▄▁▃▁▁▃▁▁▁▃▁▁▁▁▁▁▃▁▁▁▃▁▁▃▃ ▃
  40.9 ms         Histogram: frequency by time           44 ms <

 Memory estimate: 103.50 KiB, allocs estimate: 1428.

resn_sla = similar(resn);
@benchmark mul!($resn_sla, $Gsla_, $xvecq) evals=1

BenchmarkTools.Trial: 16 samples with 1 evaluation.
 Range (min … max):  321.255 ms … 325.775 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     321.770 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   322.199 ms ±   1.182 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▁ ▁█▁▁▁▁▁ ▁      ▁        ▁     ▁                          ▁  
  ██▁███████▁█▁▁▁▁▁▁█▁▁▁▁▁▁▁▁█▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  321 ms           Histogram: frequency by time          326 ms <

 Memory estimate: 221.59 KiB, allocs estimate: 3268.

For $\mathbf{G}^\top \mathbf{x}$, tiling provides no improvement over simply using `@turbo thread=true` (median 45.3 ms tiled vs 39.0 ms untiled). SnpLinAlg pays some overhead for its use of Base threads and tiled views. It is slightly slower (median 42.2 ms) and allocates 103.50 KiB in 1428 allocations.

For $\mathbf{Gx}$, by contrast, processing the tiles sequentially and limiting multithreading to the computation within a single tile gives an over 5x speedup: a median of 58.7 ms vs 327.9 ms for `G_mul_x!` and 321.8 ms for SnpLinAlg. Even better, by not parallelizing over tiles we avoid all heap allocations.
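The sequential tiling scheme can be shown without the packing details. Much of the remainder handling that the tiled functions spell out branch by branch can be folded into the loop bounds. Clamping each tile range with `min` simply makes the final row and column tiles shorter.

This is a minimal sketch on a dense `Matrix{Float64}`, with no bit packing and no threading; `tiled_mul!` is a hypothetical name. For the packed layout, the row tile width would still need to be a multiple of 4 so each tile starts on a byte boundary. The partially filled last byte (`rem > 0`) would also still need its own loop.

```julia
# Sequentially tiled y = A*x on a dense matrix (no packing, no threading).
# Each tile range is clamped with `min`, so the last row/column tiles are
# just shorter and no separate remainder branches are needed.
function tiled_mul!(y::AbstractVector{T}, A::AbstractMatrix{T}, x::AbstractVector{T};
                    tile::Integer = 1024) where {T<:AbstractFloat}
    m, d = size(A)
    (length(y) == m && length(x) == d) || throw(DimensionMismatch("size mismatch"))
    fill!(y, zero(T))
    for colstart in 1:tile:d                       # tiles are visited one at a time
        cols = colstart:min(colstart + tile - 1, d)
        for rowstart in 1:tile:m
            rows = rowstart:min(rowstart + tile - 1, m)
            # per-tile kernel: in the functions above this is G_mul_x_chunk!,
            # the only place `@turbo thread=true` is applied
            @inbounds for k in cols
                xk = x[k]
                for i in rows
                    y[i] += A[i, k] * xk
                end
            end
        end
    end
    return y
end
```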

I suspect we might see a similar improvement for $\mathbf{G}^\top \mathbf{X}_1$ and $\mathbf{GX}_2$, where $\mathbf{X}_1$ and $\mathbf{X}_2$ are matrices. However, implementing that could be very tedious because of the tile remainder chunks.
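For the matrix case, a tiled $\mathbf{G}^\top \mathbf{X}_1$ with `min`-clamped ranges needs no extra remainder branches for the new dimension. This is a dense-matrix sketch under the same assumptions (no packing, no threading); `tiled_tmul!` is a hypothetical name. Each tile of the matrix is reused for every column of `X`, which is where a speedup over repeated matrix-vector calls might come from.

```julia
# Sequentially tiled Y = A' * X for a matrix right-hand side (dense, no packing).
# A is m × d, X is m × p, Y is d × p; the last tile in each dimension is clamped with `min`.
function tiled_tmul!(Y::AbstractMatrix{T}, A::AbstractMatrix{T}, X::AbstractMatrix{T};
                     tile::Integer = 256) where {T<:AbstractFloat}
    m, d = size(A)
    p = size(X, 2)
    (size(X, 1) == m && size(Y) == (d, p)) || throw(DimensionMismatch("size mismatch"))
    fill!(Y, zero(T))
    for colstart in 1:tile:d
        cols = colstart:min(colstart + tile - 1, d)
        for rowstart in 1:tile:m
            rows = rowstart:min(rowstart + tile - 1, m)
            # the same tile of A is used for every column of X before moving on
            @inbounds for c in 1:p, k in cols
                acc = zero(T)
                for i in rows
                    acc += A[i, k] * X[i, c]
                end
                Y[k, c] += acc
            end
        end
    end
    return Y
end
```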

I ran these simulations on my M1 Max MacBook; please let me know whether the improvement is comparable on other hardware. A server processor such as a 32-core Xeon might behave differently.
