diff --git a/src/Models/NonhydrostaticModels/solve_for_pressure.jl b/src/Models/NonhydrostaticModels/solve_for_pressure.jl index 580bad5762..869903ab2e 100644 --- a/src/Models/NonhydrostaticModels/solve_for_pressure.jl +++ b/src/Models/NonhydrostaticModels/solve_for_pressure.jl @@ -3,6 +3,7 @@ using Oceananigans.DistributedComputations: DistributedFFTBasedPoissonSolver using Oceananigans.Grids: XDirection, YDirection, ZDirection, inactive_cell using Oceananigans.Solvers: FFTBasedPoissonSolver, FourierTridiagonalPoissonSolver using Oceananigans.Solvers: ConjugateGradientPoissonSolver +using Oceananigans.Solvers: KrylovPoissonSolver using Oceananigans.Solvers: solve! ##### @@ -89,3 +90,11 @@ function solve_for_pressure!(pressure, solver::ConjugateGradientPoissonSolver, return solve!(pressure, solver.conjugate_gradient_solver, rhs) end +function solve_for_pressure!(pressure, solver::KrylovPoissonSolver, Δt, Ũ) + rhs = solver.right_hand_side + grid = solver.grid + arch = architecture(grid) + launch!(arch, grid, :xyz, _compute_source_term!, rhs, grid, Δt, Ũ) + return solve!(pressure, solver.krylov_solver, rhs) +end + diff --git a/src/Solvers/Solvers.jl b/src/Solvers/Solvers.jl index f410322592..17d65da9e7 100644 --- a/src/Solvers/Solvers.jl +++ b/src/Solvers/Solvers.jl @@ -5,8 +5,10 @@ export FFTBasedPoissonSolver, FourierTridiagonalPoissonSolver, ConjugateGradientSolver, + ConjugateGradientPoissonSolver, HeptadiagonalIterativeSolver, - KrylovSolver + KrylovSolver, + KrylovPoissonSolver using Statistics using FFTW @@ -45,6 +47,7 @@ include("fft_based_poisson_solver.jl") include("fourier_tridiagonal_poisson_solver.jl") include("conjugate_gradient_poisson_solver.jl") include("krylov_solver.jl") +include("krylov_poisson_solver.jl") include("sparse_approximate_inverse.jl") include("matrix_solver_utils.jl") include("sparse_preconditioners.jl") diff --git a/src/Solvers/conjugate_gradient_poisson_solver.jl b/src/Solvers/conjugate_gradient_poisson_solver.jl index f809b492f9..ce3cb783ac 100644 --- a/src/Solvers/conjugate_gradient_poisson_solver.jl +++ b/src/Solvers/conjugate_gradient_poisson_solver.jl @@ -12,7 +12,7 @@ struct ConjugateGradientPoissonSolver{G, R, S} conjugate_gradient_solver :: S end -architecture(solver::ConjugateGradientPoissonSolver) = architecture(cgps.grid) +architecture(cgps::ConjugateGradientPoissonSolver) = architecture(cgps.grid) iteration(cgps::ConjugateGradientPoissonSolver) = iteration(cgps.conjugate_gradient_solver) Base.summary(ips::ConjugateGradientPoissonSolver) = diff --git a/src/Solvers/krylov_poisson_solver.jl b/src/Solvers/krylov_poisson_solver.jl new file mode 100644 index 0000000000..5a5c6db9e5 --- /dev/null +++ b/src/Solvers/krylov_poisson_solver.jl @@ -0,0 +1,77 @@ +using Oceananigans.Operators +using Oceananigans.ImmersedBoundaries: ImmersedBoundaryGrid +using Statistics: mean +using Oceananigans.Solvers: compute_laplacian!, DefaultPreconditioner + +using KernelAbstractions: @kernel, @index + +import Oceananigans.Architectures: architecture + +struct KrylovPoissonSolver{G, R, S} + grid :: G + right_hand_side :: R + krylov_solver :: S +end + +architecture(kps::KrylovPoissonSolver) = architecture(kps.grid) +iteration(kps::KrylovPoissonSolver) = iteration(kps.krylov_solver) + +Base.summary(kps::KrylovPoissonSolver) = + "KrylovPoissonSolver with $(kps.krylov_solver) method, $(summary(kps.krylov_solver.preconditioner)) preconditioner on $(summary(kps.grid))" + +function Base.show(io::IO, kps::KrylovPoissonSolver) + A = architecture(kps.grid) + print(io, "KrylovPoissonSolver:", '\n', + "├── grid: ", summary(kps.grid), '\n', + "└── krylov_solver: ", summary(kps.krylov_solver), '\n', + " ├── maxiter: ", prettysummary(kps.krylov_solver.maxiter), '\n', + " ├── reltol: ", prettysummary(kps.krylov_solver.reltol), '\n', + " ├── abstol: ", prettysummary(kps.krylov_solver.abstol), '\n', + " ├── preconditioner: ", prettysummary(kps.krylov_solver.preconditioner), '\n', + " └── iteration: ", prettysummary(kps.krylov_solver.workspace.stats.niter)) +end + +""" + KrylovPoissonSolver(grid; + method = :cg, + preconditioner = DefaultPreconditioner(), + reltol = sqrt(eps(grid)), + abstol = sqrt(eps(grid)), + kw...) + +Creates a `KrylovPoissonSolver` with `method` on `grid` using a `preconditioner`. +`KrylovPoissonSolver` is iterative, and will stop when both the relative error in the +pressure solution is smaller than `reltol` and the absolute error is smaller than `abstol`. Other +keyword arguments are passed to `KrylovSolver`. +""" +function KrylovPoissonSolver(grid; + method = :cg, + preconditioner = DefaultPreconditioner(), + reltol = sqrt(eps(grid)), + abstol = sqrt(eps(grid)), + kw...) + + # if method ∉ [:cg, :bicgstab] + # @warn "Currently, KrylovPoissonSolver only supports :cg and :bicgstab methods. Support for other methods will be added soon!" + # end + + if preconditioner isa DefaultPreconditioner # try to make a useful default + if grid isa ImmersedBoundaryGrid && grid.underlying_grid isa GridWithFFTSolver + preconditioner = fft_poisson_solver(grid.underlying_grid) + else + preconditioner = DiagonallyDominantPreconditioner() + end + end + + rhs = CenterField(grid) + + krylov_solver = KrylovSolver(compute_laplacian!; + method, + reltol, + abstol, + preconditioner, + template_field = rhs, + kw...) + + return KrylovPoissonSolver(grid, rhs, krylov_solver) +end \ No newline at end of file diff --git a/src/Solvers/krylov_solver.jl b/src/Solvers/krylov_solver.jl index 2c16b1bd30..2c0920a48a 100644 --- a/src/Solvers/krylov_solver.jl +++ b/src/Solvers/krylov_solver.jl @@ -180,3 +180,15 @@ function solve!(x, solver::KrylovSolver, b, args...; kwargs...) copyto!(x, solver.workspace.x.field) return x end + +function Base.show(io::IO, solver::KrylovSolver) + print(io, "KrylovSolver on ", summary(solver.architecture), "\n", + "├── method: ", solver.method, "\n", + "├── grid: ", summary(solver.grid), "\n", + "├── preconditioner: ", prettysummary(solver.preconditioner), "\n", + "├── reltol: ", prettysummary(solver.reltol), "\n", + "├── abstol: ", prettysummary(solver.abstol), "\n", + "├── maxiter: ", solver.maxiter, "\n", + "├── work field: ", summary(solver.workspace.r), "\n", + "└── linear operation: ", prettysummary(solver.op.fun)) +end \ No newline at end of file