Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 28 additions & 26 deletions src/flint/fmpz_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1656,7 +1656,7 @@ function _solve_dixon(a::ZZMatrix, b::ZZMatrix)
return z, d
end

#XU = B. only the upper triangular part of U is used
# Solve XU = B for X given U & B. Only the upper triangular part of U is used.
function AbstractAlgebra._solve_triu_left(U::ZZMatrix, b::ZZMatrix; unipotent::Bool = false)
n = ncols(U)
m = nrows(b)
Expand Down Expand Up @@ -1702,9 +1702,9 @@ function AbstractAlgebra._solve_triu_left(U::ZZMatrix, b::ZZMatrix; unipotent::B
return X
end

#UX = B, U has to be upper triangular
#I think due to the Strassen calling path, where Strasse.solve(side = :left)
#call directly AA.solve_left, this has to be in AA and cannot be independent.
# Solve UX = B for X given U & B: U has to be upper triangular.
# I think due to the Strassen calling path, where Strasse.solve(side = :left)
# call directly AA.solve_left, this has to be in AA and cannot be independent.
function AbstractAlgebra._solve_triu(U::ZZMatrix, b::ZZMatrix; side::Symbol=:left, unipotent::Bool = false)
if side == :left
return AbstractAlgebra._solve_triu_left(U, b; unipotent)
Expand All @@ -1715,20 +1715,22 @@ function AbstractAlgebra._solve_triu(U::ZZMatrix, b::ZZMatrix; side::Symbol=:lef
X = zero(b)
tmp = zero_matrix(ZZ, 1, n)
s = ZZ()
# We build up the solution column by column
GC.@preserve U b X tmp begin
for i = 1:m
for i = 1:m # i indexes the columns
tmp_ptr = mat_entry_ptr(tmp, 1, 1)
for j = 1:n
X_ptr = mat_entry_ptr(X, j, i)
set!(tmp_ptr, X_ptr)
tmp_ptr += sizeof(ZZRingElem)
end
for j = n:-1:1
# At this point tmp is full of zeroes
for j = n:-1:1 # j indexes the rows (in i-th column)
zero!(s)
tmp_ptr = mat_entry_ptr(tmp, 1, j+1)
for k = j + 1:n
U_ptr = mat_entry_ptr(U, j, k)
mul!(s, U_ptr, tmp_ptr)
addmul!(s, U_ptr, tmp_ptr)
tmp_ptr += sizeof(ZZRingElem)
# s = addmul!(s, U[j, k], tmp[k])
end
Expand All @@ -1755,35 +1757,35 @@ function AbstractAlgebra._solve_triu(U::ZZMatrix, b::ZZMatrix; side::Symbol=:lef
return X
end

#solves Ax = B for A lower triagular. if f != 0 (f is true), the diagonal
#is assumed to be 1 and not actually used.
#the upper part of A is not used/ touched.
#one cannot assert is_lower_triangular as this is used for the inplace
#lu decomposition where the matrix is full, encoding an upper triangular
#using the diagonal and a lower triangular with trivial diagonal
function AbstractAlgebra._solve_tril!(A::ZZMatrix, B::ZZMatrix, C::ZZMatrix, f::Int = 0)
# Solves Lx = B for L lower triangular. If unipotent is true, the diagonal
# is assumed to be 1 and not actually used.
# The upper part of L is not used/ touched.
# One cannot assert is_lower_triangular as this is used for the inplace
# lu decomposition where the matrix is full, encoding an upper triangular
# using the diagonal and a lower triangular with trivial diagonal
function AbstractAlgebra._solve_tril!(X::ZZMatrix, L::ZZMatrix, B::ZZMatrix; unipotent::Bool = false)

# a x u ax = u
# b c * y = v bx + cy = v
# d e f z w ....

@assert ncols(A) == ncols(C)
@assert ncols(X) == ncols(B)
s = ZZ(0)
GC.@preserve A B C begin
for i=1:ncols(A)
for j = 1:nrows(A)
t = C[j, i]
B_ptr = mat_entry_ptr(B, j, 1)
GC.@preserve X L B begin
for i=1:ncols(X)
for j = 1:nrows(X)
t = B[j, i]
L_ptr = mat_entry_ptr(L, j, 1)
for k = 1:j-1
A_ptr = mat_entry_ptr(B, k, i)
mul!(s, A_ptr, B_ptr)
B_ptr += sizeof(ZZRingElem)
X_ptr = mat_entry_ptr(X, k, i)
mul!(s, X_ptr, L_ptr)
L_ptr += sizeof(ZZRingElem)
sub!(t, t, s)
end
if f == 1
A[j,i] = t
if unipotent
X[j,i] = t
else
A[j,i] = divexact(t, B[j, j])
X[j,i] = divexact(t, L[j, j])
end
end
end
Expand Down
26 changes: 15 additions & 11 deletions test/flint/fmpz_mat-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -737,19 +737,23 @@ end
end

@testset "ZZMatrix.solve" begin
A = matrix(ZZ, 2, 2, [1,2,3,4])
# Test matrices have size (at least) 3x3 -- smaller failed to detect a bug.
U_triang = matrix(ZZ, 3, 3, [1,2,3, 0,4,5, 0,0,6])
L_triang = matrix(ZZ, 3, 3, [1,0,0, 2,3,0, 4,5,6])

@test AbstractAlgebra.Solve.matrix_normal_form_type(ZZ) === AbstractAlgebra.Solve.HermiteFormTrait()
@test AbstractAlgebra.Solve.matrix_normal_form_type(A) === AbstractAlgebra.Solve.HermiteFormTrait()

b = matrix(ZZ, 1, 2, [1, 6])
@test AbstractAlgebra._solve_triu_left(A, b) == matrix(ZZ, 1, 2, [1, 1])
b = matrix(ZZ, 2, 1, [3, 4])
@test AbstractAlgebra._solve_triu(A, b; side = :right) == matrix(ZZ, 2, 1, [1, 1])
b = matrix(ZZ, 2, 1, [1, 7])
c = similar(b)
AbstractAlgebra._solve_tril!(c, A, b)
@test c == matrix(ZZ, 2, 1, [1, 1])
@test AbstractAlgebra.Solve.matrix_normal_form_type(U_triang) === AbstractAlgebra.Solve.HermiteFormTrait()
@test AbstractAlgebra.Solve.matrix_normal_form_type(L_triang) === AbstractAlgebra.Solve.HermiteFormTrait()

X = matrix(ZZ, 3, 2, [3,1, 4,1, 5,9])
trX = transpose(X)
@test AbstractAlgebra._solve_triu_left(U_triang, trX*U_triang) == trX
@test AbstractAlgebra._solve_triu(U_triang, trX*U_triang; side = :left) == trX
@test AbstractAlgebra._solve_triu(U_triang, U_triang*X; side = :right) == X

c = similar(X)
AbstractAlgebra._solve_tril!(c, L_triang, L_triang*X)
@test c == X

S = matrix_space(ZZ, 3, 3)

Expand Down
Loading