Skip to content

Commit 0e2ef84

Browse files
committed
Add a dispatch for LinearAlgebra.norm2
`norm(@view x[..], 2)` was previously leading to a call of `LinearAlgebra.generic_norm2` which led to a scalar indexing. This catches such cuda subarray norm2 calls earlier. Inf-norm and p-norm with cuda subarrays still lead to the following dispatches: ```julia LinearAlgebra.generic_normInf(x) = float(mapreduce(norm, max, x)) LinearAlgebra.generic_norm1(x) = mapreduce(float ∘ norm, +, x) ``` I am not sure if there is a better way to dispatch the above. should resolve #2280
1 parent f5100a1 commit 0e2ef84

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

Diff for: lib/cublas/linalg.jl

+4
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ function LinearAlgebra.norm(x::DenseCuArray{<:Union{Float16, ComplexF16, CublasF
138138
end
139139
end
140140

141+
function LinearAlgebra.norm2(x::SubArray{T,N,P} where {T<:Union{Float16, ComplexF16, CublasFloat}, N, P<:DenseCuArray{<:T}})
142+
return nrm2(x)
143+
end
144+
141145
LinearAlgebra.BLAS.asum(x::StridedCuArray{<:CublasFloat}) = asum(length(x), x)
142146

143147
function LinearAlgebra.axpy!(alpha::Number, x::StridedCuArray{T}, y::StridedCuArray{T}) where T<:Union{Float16, ComplexF16, CublasFloat}

Diff for: test/libraries/cublas.jl

+11
Original file line numberDiff line numberDiff line change
@@ -1767,6 +1767,17 @@ end
17671767
@view(p[reshape(1:(out*inn),out,inn)]) * x
17681768
end
17691769
end
1770+
1771+
@testset "nrm2 with strided inputs" begin # JuliaGPU/CUDA.jl#2280
1772+
cudaTypes = (Float16, ComplexF16, CublasFloat)
1773+
for CT in cudaTypes
1774+
x = rand(CT, 10, 10, 10)
1775+
dx = CuArray(x)
1776+
dx_ = @view dx[3:6, 1:5, :]
1777+
x_ = @view x[3:6, 1:5, :]
1778+
@test norm(dx_, 2) norm(x_, 2)
1779+
end
1780+
end
17701781
end
17711782

17721783
############################################################################################

0 commit comments

Comments
 (0)