Skip to content

Commit a84b3f8

Browse files
amontoisonglwagnerxkykaitomchor
authored
Interface to Krylov.jl solvers (#4041)
* Update dot * Add Krylov support * Interface Krylov.jl * Relax the types * Apply suggestions from code review Co-authored-by: Gregory L. Wagner <[email protected]> * Update src/Solvers/krylov.jl Co-authored-by: Gregory L. Wagner <[email protected]> * Apply suggestions from code review Co-authored-by: Gregory L. Wagner <[email protected]> * Update the version of Krylov.jl in the Project.toml * Working version with Krylov.jl * Add an import * Update field.jl and Krylov.jl * Import LinearAlgebra.dot and LinearAlgebra.norm * Add a using LinearAlgebra * Remove what is not needed for the submodule Solvers * Fix the warning related to copyto! * Add a KrylovSolver * Update krylov_solver.jl * Fix an import for the workspace * Add tests for KrylovSolver * Update src/Solvers/krylov_solver.jl * Update src/Solvers/krylov_solver.jl Co-authored-by: Xin Kai Lee <[email protected]> * Add a structure KrylovPreconditioner * Provide additional arguments to KrylovSolver.op and KrylovSolver.preconditioner * Make KrylovOperator and KrylovPreconditioner mutable * Update the preconditioner to make it SPD * Rename the argument krylov_solver into method for KrylovSolver * Fix everything! * Restore the old norm * Apply suggestions from code review Co-authored-by: Tomás Chor <[email protected]> Co-authored-by: Gregory L. Wagner <[email protected]> * Add a docstring for KrylovSolver * Update test/test_krylov_solver.jl Co-authored-by: Gregory L. Wagner <[email protected]> * Update src/Solvers/krylov_solver.jl Co-authored-by: Tomás Chor <[email protected]> --------- Co-authored-by: Gregory Wagner <[email protected]> Co-authored-by: Gregory L. Wagner <[email protected]> Co-authored-by: Xin Kai Lee <[email protected]> Co-authored-by: Tomás Chor <[email protected]>
1 parent a4d98f8 commit a84b3f8

12 files changed

+282
-59
lines changed

Project.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1818
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
1919
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
2020
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
21+
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
2122
KrylovPreconditioners = "45d422c2-293f-44ce-8315-2cb988662dec"
2223
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2324
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -44,8 +45,8 @@ Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
4445
MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b"
4546
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
4647
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
47-
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
4848
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
49+
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
4950

5051
[extensions]
5152
OceananigansAMDGPUExt = "AMDGPU"
@@ -74,6 +75,7 @@ InteractiveUtils = "1.9"
7475
IterativeSolvers = "0.9"
7576
JLD2 = "0.4, 0.5"
7677
KernelAbstractions = "0.9.21"
78+
Krylov = "0.9.10"
7779
KrylovPreconditioners = "0.3.0"
7880
LinearAlgebra = "1.9"
7981
Logging = "1.9"
@@ -105,11 +107,11 @@ DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
105107
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
106108
MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267"
107109
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
108-
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
109110
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
110111
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
111112
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
112113
TimesDates = "bdfc003b-8df8-5c39-adcd-3a9087f5df4a"
114+
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
113115

114116
[targets]
115117
test = ["AMDGPU", "oneAPI", "DataDeps", "SafeTestsets", "Test", "Enzyme", "Reactant", "Metal", "CUDA_Runtime_jll", "MPIPreferences", "TimesDates", "NCDatasets"]

src/Fields/field.jl

+12-11
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ using Oceananigans.Grids: parent_index_range, index_range_offset, default_indice
33
using Oceananigans.Grids: index_range_contains
44

55
using Adapt
6+
using LinearAlgebra
67
using KernelAbstractions: @kernel, @index
78
using Base: @propagate_inbounds
89

910
import Oceananigans: boundary_conditions
1011
import Oceananigans.Architectures: on_architecture
1112
import Oceananigans.BoundaryConditions: fill_halo_regions!, getbc
12-
import Statistics: norm, mean, mean!
13+
import Statistics: mean, mean!
14+
import LinearAlgebra: dot, norm
1315
import Base: ==
1416

1517
#####
@@ -587,7 +589,12 @@ const ReducedAbstractField = Union{XReducedAbstractField,
587589
XYZReducedAbstractField}
588590

589591
# TODO: needs test
590-
Statistics.dot(a::Field, b::Field) = mapreduce((x, y) -> x * y, +, interior(a), interior(b))
592+
LinearAlgebra.dot(a::AbstractField, b::AbstractField) = mapreduce((x, y) -> x * y, +, interior(a), interior(b))
593+
function LinearAlgebra.norm(a::AbstractField; condition = nothing)
594+
r = zeros(a.grid, 1)
595+
Base.mapreducedim!(x -> x * x, +, r, condition_operand(a, condition, 0))
596+
return CUDA.@allowscalar sqrt(r[1])
597+
end
591598

592599
# TODO: in-place allocations with function mappings need to be fixed in Julia Base...
593600
const SumReduction = typeof(Base.sum!)
@@ -736,17 +743,11 @@ end
736743

737744
Statistics.mean!(r::ReducedAbstractField, a::AbstractArray; kwargs...) = Statistics.mean!(identity, r, a; kwargs...)
738745

739-
function Statistics.norm(a::AbstractField; condition = nothing)
740-
r = zeros(a.grid, 1)
741-
Base.mapreducedim!(x -> x * x, +, r, condition_operand(a, condition, 0))
742-
return CUDA.@allowscalar sqrt(r[1])
743-
end
744-
745746
function Base.isapprox(a::AbstractField, b::AbstractField; kw...)
746-
conditioned_a = condition_operand(a, nothing, one(eltype(a)))
747-
conditioned_b = condition_operand(b, nothing, one(eltype(b)))
747+
conditional_a = condition_operand(a, nothing, one(eltype(a)))
748+
conditional_b = condition_operand(b, nothing, one(eltype(b)))
748749
# TODO: Make this non-allocating?
749-
return all(isapprox.(conditioned_a, conditioned_b; kw...))
750+
return all(isapprox.(conditional_a, conditional_b; kw...))
750751
end
751752

752753
#####

src/Models/HydrostaticFreeSurfaceModels/pcg_implicit_free_surface_solver.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ end
183183
"""
184184
Add `- H⁻¹ ∇H ⋅ ∇ηⁿ` to the right-hand-side.
185185
"""
186-
@inline function precondition!(P_r, preconditioner::FFTImplicitFreeSurfaceSolver, r, η, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ, g, Δt)
186+
@inline function precondition!(P_r, preconditioner::FFTImplicitFreeSurfaceSolver, r, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ, g, Δt)
187187
poisson_solver = preconditioner.fft_poisson_solver
188188
arch = architecture(poisson_solver)
189189
grid = preconditioner.three_dimensional_grid
@@ -192,12 +192,12 @@ Add `- H⁻¹ ∇H ⋅ ∇ηⁿ` to the right-hand-side.
192192

193193
launch!(arch, grid, :xy,
194194
fft_preconditioner_right_hand_side!,
195-
poisson_solver.storage, r, η, grid, Az, Lz)
195+
poisson_solver.storage, r, grid, Az, Lz)
196196

197197
return solve!(P_r, preconditioner, poisson_solver.storage, g, Δt)
198198
end
199199

200-
@kernel function fft_preconditioner_right_hand_side!(fft_rhs, pcg_rhs, η, grid, Az, Lz)
200+
@kernel function fft_preconditioner_right_hand_side!(fft_rhs, pcg_rhs, grid, Az, Lz)
201201
i, j = @index(Global, NTuple)
202202
@inbounds fft_rhs[i, j, 1] = pcg_rhs[i, j, grid.Nz+1] / (Lz * Az)
203203
end
@@ -233,11 +233,11 @@ end
233233

234234
struct DiagonallyDominantInversePreconditioner end
235235

236-
@inline precondition!(P_r, ::DiagonallyDominantInversePreconditioner, r, η, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ, g, Δt) =
237-
diagonally_dominant_precondition!(P_r, r, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ, g, Δt)
236+
@inline precondition!(P_r, ::DiagonallyDominantInversePreconditioner, r, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ, g, Δt) =
237+
diagonally_dominant_inverse_precondition!(P_r, r, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ, g, Δt)
238238

239239
"""
240-
_diagonally_dominant_precondition!(P_r, grid, r, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ, g, Δt)
240+
_diagonally_dominant_inverse_precondition!(P_r, grid, r, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ, g, Δt)
241241
242242
Return the diagonally dominant inverse preconditioner applied to the residuals consistently
243243
with `M = D⁻¹(I - (A - D)D⁻¹) ≈ A⁻¹` where `I` is the identity matrix, `A` is the linear
@@ -256,13 +256,13 @@ P_rᵢⱼ = rᵢⱼ / Acᵢⱼ - 1 / Acᵢⱼ ( Ax⁻ / Acᵢ₋₁ rᵢ₋₁
256256
where `Ac`, `Ax⁻`, `Ax⁺`, `Ay⁻` and `Ay⁺` are the coefficients of `ηᵢⱼ`, `ηᵢ₋₁ⱼ`, `ηᵢ₊₁ⱼ`, `ηᵢⱼ₋₁`,
257257
and `ηᵢⱼ₊₁` in `_implicit_free_surface_linear_operation!`
258258
"""
259-
function diagonally_dominant_precondition!(P_r, r, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ, g, Δt)
259+
function diagonally_dominant_inverse_precondition!(P_r, r, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ, g, Δt)
260260
grid = ∫ᶻ_Axᶠᶜᶜ.grid
261261
arch = architecture(P_r)
262262

263263
fill_halo_regions!(r)
264264

265-
launch!(arch, grid, :xy, _diagonally_dominant_precondition!,
265+
launch!(arch, grid, :xy, _diagonally_dominant_inverse_precondition!,
266266
P_r, grid, r, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ, g, Δt)
267267

268268
return nothing
@@ -286,7 +286,7 @@ end
286286
2 * Ay⁻(i, j, grid, ay) / (Ac(i, j-1, grid, g, Δt, ax, ay) + Ac(i, j, grid, g, Δt, ax, ay)) * r[i, j-1, grid.Nz+1] -
287287
2 * Ay⁺(i, j, grid, ay) / (Ac(i, j+1, grid, g, Δt, ax, ay) + Ac(i, j, grid, g, Δt, ax, ay)) * r[i, j+1, grid.Nz+1])
288288

289-
@kernel function _diagonally_dominant_precondition!(P_r, grid, r, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ, g, Δt)
289+
@kernel function _diagonally_dominant_inverse_precondition!(P_r, grid, r, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ, g, Δt)
290290
i, j = @index(Global, NTuple)
291291
@inbounds P_r[i, j, grid.Nz+1] = heuristic_inverse_times_residuals(i, j, r, grid, g, Δt, ∫ᶻ_Axᶠᶜᶜ, ∫ᶻ_Ayᶜᶠᶜ)
292292
end

src/Solvers/Solvers.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ export
55
FFTBasedPoissonSolver,
66
FourierTridiagonalPoissonSolver,
77
ConjugateGradientSolver,
8-
HeptadiagonalIterativeSolver
8+
HeptadiagonalIterativeSolver,
9+
KrylovSolver
910

1011
using Statistics
1112
using FFTW
@@ -43,6 +44,7 @@ include("plan_transforms.jl")
4344
include("fft_based_poisson_solver.jl")
4445
include("fourier_tridiagonal_poisson_solver.jl")
4546
include("conjugate_gradient_poisson_solver.jl")
47+
include("krylov_solver.jl")
4648
include("sparse_approximate_inverse.jl")
4749
include("matrix_solver_utils.jl")
4850
include("sparse_preconditioners.jl")

src/Solvers/conjugate_gradient_poisson_solver.jl

+9-14
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,9 @@ const FFTBasedPreconditioner = Union{FFTBasedPoissonSolver, FourierTridiagonalPo
127127

128128
function precondition!(p, preconditioner::FFTBasedPreconditioner, r, args...)
129129
compute_preconditioner_rhs!(preconditioner, r)
130-
solve!(p, preconditioner)
131-
132-
mean_p = mean(p)
133-
grid = p.grid
134-
arch = architecture(grid)
135-
launch!(arch, grid, :xyz, subtract_and_mask!, p, grid, mean_p)
136-
130+
shift = - sqrt(eps(eltype(r))) # to make the operator strictly negative definite
131+
solve!(p, preconditioner, preconditioner.storage, shift)
132+
p .*= -1
137133
return p
138134
end
139135

@@ -175,16 +171,15 @@ end
175171
Az⁻(i, j, k, grid) - Az⁺(i, j, k, grid)
176172

177173
@inline heuristic_residual(i, j, k, grid, r) =
178-
@inbounds 1 / Ac(i, j, k, grid) * (r[i, j, k] - 2 * Ax⁻(i, j, k, grid) / (Ac(i, j, k, grid) + Ac(i-1, j, k, grid)) * r[i-1, j, k] -
179-
2 * Ax⁺(i, j, k, grid) / (Ac(i, j, k, grid) + Ac(i+1, j, k, grid)) * r[i+1, j, k] -
180-
2 * Ay⁻(i, j, k, grid) / (Ac(i, j, k, grid) + Ac(i, j-1, k, grid)) * r[i, j-1, k] -
181-
2 * Ay⁺(i, j, k, grid) / (Ac(i, j, k, grid) + Ac(i, j+1, k, grid)) * r[i, j+1, k] -
182-
2 * Az⁻(i, j, k, grid) / (Ac(i, j, k, grid) + Ac(i, j, k-1, grid)) * r[i, j, k-1] -
183-
2 * Az⁺(i, j, k, grid) / (Ac(i, j, k, grid) + Ac(i, j, k+1, grid)) * r[i, j, k+1])
174+
@inbounds 1 / abs(Ac(i, j, k, grid)) * (r[i, j, k] - 2 * Ax⁻(i, j, k, grid) / (Ac(i, j, k, grid) + Ac(i-1, j, k, grid)) * r[i-1, j, k] -
175+
2 * Ax⁺(i, j, k, grid) / (Ac(i, j, k, grid) + Ac(i+1, j, k, grid)) * r[i+1, j, k] -
176+
2 * Ay⁻(i, j, k, grid) / (Ac(i, j, k, grid) + Ac(i, j-1, k, grid)) * r[i, j-1, k] -
177+
2 * Ay⁺(i, j, k, grid) / (Ac(i, j, k, grid) + Ac(i, j+1, k, grid)) * r[i, j+1, k] -
178+
2 * Az⁻(i, j, k, grid) / (Ac(i, j, k, grid) + Ac(i, j, k-1, grid)) * r[i, j, k-1] -
179+
2 * Az⁺(i, j, k, grid) / (Ac(i, j, k, grid) + Ac(i, j, k+1, grid)) * r[i, j, k+1])
184180

185181
@kernel function _diagonally_dominant_precondition!(p, grid, r)
186182
i, j, k = @index(Global, NTuple)
187183
active = !inactive_cell(i, j, k, grid)
188184
@inbounds p[i, j, k] = heuristic_residual(i, j, k, grid, r) * active
189185
end
190-

src/Solvers/conjugate_gradient_solver.jl

+18-19
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,11 @@ Arguments
7474
See [`solve!`](@ref) for more information about the preconditioned conjugate-gradient algorithm.
7575
"""
7676
function ConjugateGradientSolver(linear_operation;
77-
template_field::AbstractField,
78-
maxiter = prod(size(template_field)),
79-
reltol = sqrt(eps(eltype(template_field.grid))),
80-
abstol = 0,
81-
preconditioner = nothing)
77+
template_field::AbstractField,
78+
maxiter = prod(size(template_field)),
79+
reltol = sqrt(eps(eltype(template_field.grid))),
80+
abstol = 0,
81+
preconditioner = nothing)
8282

8383
arch = architecture(template_field)
8484
grid = template_field.grid
@@ -94,18 +94,18 @@ function ConjugateGradientSolver(linear_operation;
9494
FT = eltype(grid)
9595

9696
return ConjugateGradientSolver(arch,
97-
grid,
98-
linear_operation,
99-
FT(reltol),
100-
FT(abstol),
101-
maxiter,
102-
0,
103-
zero(FT),
104-
linear_operator_product,
105-
search_direction,
106-
residual,
107-
preconditioner,
108-
precondition_product)
97+
grid,
98+
linear_operation,
99+
FT(reltol),
100+
FT(abstol),
101+
maxiter,
102+
0,
103+
zero(FT),
104+
linear_operator_product,
105+
search_direction,
106+
residual,
107+
preconditioner,
108+
precondition_product)
109109
end
110110

111111
"""
@@ -158,7 +158,6 @@ Loop:
158158
```
159159
"""
160160
function solve!(x, solver::ConjugateGradientSolver, b, args...)
161-
162161
# Initialize
163162
solver.iteration = 0
164163

@@ -189,7 +188,7 @@ function iterate!(x, solver, b, args...)
189188

190189
# Preconditioned: z = P * r
191190
# Unpreconditioned: z = r
192-
@apply_regionally z = precondition!(solver.preconditioner_product, solver.preconditioner, r, x, args...)
191+
@apply_regionally z = precondition!(solver.preconditioner_product, solver.preconditioner, r, args...)
193192

194193
ρ = dot(z, r)
195194

src/Solvers/fft_based_poisson_solver.jl

-1
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,3 @@ end
135135

136136
@inbounds ϕ[i′, j′, k′] = real(ϕc[i, j, k])
137137
end
138-

0 commit comments

Comments
 (0)