Skip to content

Commit e733275

Browse files
authored
Add diag argument to copytri! (#679)
1 parent efdc7cf commit e733275

File tree

4 files changed

+46
-42
lines changed

4 files changed

+46
-42
lines changed

.buildkite/pipeline.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ steps:
4747
julia:
4848
- "1.10"
4949
- "1.12"
50+
- "1.13"
5051
soft_fail:
5152
- exit_status: 3
5253

@@ -99,6 +100,7 @@ steps:
99100
julia:
100101
- "1.10"
101102
- "1.12"
103+
- "1.13"
102104
soft_fail:
103105
- exit_status: 3
104106

@@ -151,6 +153,7 @@ steps:
151153
julia:
152154
- "1.10"
153155
- "1.12"
156+
- "1.13"
154157
soft_fail:
155158
- exit_status: 3
156159

@@ -203,6 +206,7 @@ steps:
203206
julia:
204207
- "1.10"
205208
- "1.12"
209+
- "1.13"
206210
soft_fail:
207211
- exit_status: 3
208212

.github/workflows/Test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
strategy:
2121
fail-fast: false
2222
matrix:
23-
version: ['1.10', '1.11', '1.12', 'nightly']
23+
version: ['1.10', '1.11', '1.12', '1.13-nightly', 'nightly']
2424
os: [ubuntu-latest, macOS-latest, windows-latest]
2525
steps:
2626
- uses: actions/checkout@v6
@@ -97,7 +97,7 @@ jobs:
9797
- uses: actions/checkout@v6
9898
- uses: julia-actions/setup-julia@v2
9999
with:
100-
version: '1.11'
100+
version: '1.12'
101101
- uses: julia-actions/cache@v2
102102
- name: Run tests
103103
run: |

src/host/linalg.jl

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -107,48 +107,48 @@ end
107107

108108
## copy upper triangle to lower and vice versa
109109

110-
function LinearAlgebra.copytri!(A::AbstractGPUMatrix, uplo::AbstractChar, conjugate::Bool=false)
111-
n = LinearAlgebra.checksquare(A)
112-
if uplo == 'U' && conjugate
113-
@kernel function U_conj!(_A)
114-
I = @index(Global, Cartesian)
115-
i, j = Tuple(I)
116-
if j > i
117-
@inbounds _A[j,i] = conj(_A[i,j])
110+
function LinearAlgebra.copytri!(A::AbstractGPUMatrix, uplo::AbstractChar, conjugate::Bool = false, diag::Bool = false)
111+
n = LinearAlgebra.checksquare(A)
112+
if uplo == 'U' && conjugate
113+
@kernel function U_conj!(_A)
114+
I = @index(Global, Cartesian)
115+
i, j = Tuple(I)
116+
if j + diag > i
117+
@inbounds _A[j,i] = conj(_A[i,j])
118+
end
118119
end
119-
end
120-
U_conj!(get_backend(A))(A; ndrange = size(A))
121-
elseif uplo == 'U' && !conjugate
122-
@kernel function U_noconj!(_A)
123-
I = @index(Global, Cartesian)
124-
i, j = Tuple(I)
125-
if j > i
126-
@inbounds _A[j,i] = _A[i,j]
120+
U_conj!(get_backend(A))(A; ndrange = size(A))
121+
elseif uplo == 'U' && !conjugate
122+
@kernel function U_noconj!(_A)
123+
I = @index(Global, Cartesian)
124+
i, j = Tuple(I)
125+
if j + diag > i
126+
@inbounds _A[j,i] = _A[i,j]
127+
end
127128
end
128-
end
129-
U_noconj!(get_backend(A))(A; ndrange = size(A))
130-
elseif uplo == 'L' && conjugate
131-
@kernel function L_conj!(_A)
132-
I = @index(Global, Cartesian)
133-
i, j = Tuple(I)
134-
if j > i
135-
@inbounds _A[i,j] = conj(_A[j,i])
129+
U_noconj!(get_backend(A))(A; ndrange = size(A))
130+
elseif uplo == 'L' && conjugate
131+
@kernel function L_conj!(_A)
132+
I = @index(Global, Cartesian)
133+
i, j = Tuple(I)
134+
if j + diag > i
135+
@inbounds _A[i,j] = conj(_A[j,i])
136+
end
136137
end
137-
end
138-
L_conj!(get_backend(A))(A; ndrange = size(A))
139-
elseif uplo == 'L' && !conjugate
140-
@kernel function L_noconj!(_A)
141-
I = @index(Global, Cartesian)
142-
i, j = Tuple(I)
143-
if j > i
144-
@inbounds _A[i,j] = _A[j,i]
138+
L_conj!(get_backend(A))(A; ndrange = size(A))
139+
elseif uplo == 'L' && !conjugate
140+
@kernel function L_noconj!(_A)
141+
I = @index(Global, Cartesian)
142+
i, j = Tuple(I)
143+
if j + diag > i
144+
@inbounds _A[i,j] = _A[j,i]
145+
end
145146
end
146-
end
147-
L_noconj!(get_backend(A))(A; ndrange = size(A))
148-
else
149-
throw(ArgumentError("uplo argument must be 'U' (upper) or 'L' (lower), got $uplo"))
150-
end
151-
A
147+
L_noconj!(get_backend(A))(A; ndrange = size(A))
148+
else
149+
throw(ArgumentError("uplo argument must be 'U' (upper) or 'L' (lower), got $uplo"))
150+
end
151+
A
152152
end
153153

154154
## copy a triangular part of a matrix to another matrix

test/testsuite/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,15 @@
7575

7676
@testset "triangular" begin
7777
@testset "copytri!" begin
78-
@testset for eltya in (Float32, Float64, ComplexF32, ComplexF64), uplo in ('U', 'L'), conjugate in (true, false)
78+
@testset for eltya in (Float32, Float64, ComplexF32, ComplexF64), uplo in ('U', 'L'), conjugate in (true, false), diag in (true, false)
7979
if !(eltya in eltypes)
8080
continue
8181
end
8282
n = 128
8383
areal = randn(n,n)/2
8484
aimg = randn(n,n)/2
8585
a = convert(Matrix{eltya}, eltya <: Complex ? complex.(areal, aimg) : areal)
86-
@test compare(x -> LinearAlgebra.copytri!(x, uplo, conjugate), AT, a)
86+
@test compare(x -> LinearAlgebra.copytri!(x, uplo, conjugate, diag), AT, a)
8787
end
8888
end
8989

0 commit comments

Comments
 (0)