Skip to content

Commit b337d9d

Browse files
committed
LinearSolver: fix zero initial guess and update after error
1 parent 22ff4d1 commit b337d9d

File tree

3 files changed

+89
-35
lines changed

3 files changed

+89
-35
lines changed

firedrake/linear_solver.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,13 @@ def solve(self, x, b):
8383
if b.function_space() != self.b.function_space():
8484
raise ValueError(f"b must be a Cofunction in {self.b.function_space()}.")
8585

86-
self.x.assign(x)
8786
self.b.assign(b)
88-
super().solve()
89-
x.assign(self.x)
87+
if self.ksp.getInitialGuessNonzero():
88+
self.x.assign(x)
89+
else:
90+
self.x.zero()
91+
try:
92+
super().solve()
93+
finally:
94+
# Update x even when ConvergenceError is raised
95+
x.assign(self.x)
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from firedrake import *
2+
from firedrake.petsc import PETSc
3+
import numpy
4+
5+
6+
def test_linear_solver_zero_initial_guess():
7+
mesh = UnitIntervalMesh(10)
8+
space = FunctionSpace(mesh, "CG", 1)
9+
test = TestFunction(space)
10+
trial = TrialFunction(space)
11+
12+
solver = LinearSolver(assemble(inner(trial, test) * dx),
13+
solver_parameters={"ksp_type": "preonly",
14+
"pc_type": "jacobi",
15+
"ksp_max_it": 1,
16+
"ksp_initial_guess_nonzero": False})
17+
b = assemble(inner(Constant(1), test) * dx)
18+
19+
u1 = Function(space, name="u1")
20+
u1.assign(0)
21+
solver.solve(u1, b)
22+
23+
u2 = Function(space, name="u2")
24+
u2.assign(1)
25+
solver.solve(u2, b)
26+
assert numpy.allclose(u1.dat.data_ro, u2.dat.data_ro)
27+
28+
29+
def test_linear_solver_update_after_error():
30+
mesh = UnitIntervalMesh(10)
31+
space = FunctionSpace(mesh, "CG", 1)
32+
test = TestFunction(space)
33+
trial = TrialFunction(space)
34+
35+
solver = LinearSolver(assemble(inner(trial, test) * dx),
36+
solver_parameters={"ksp_type": "cg",
37+
"pc_type": "none",
38+
"ksp_max_it": 1,
39+
"ksp_atol": 1.0e-2})
40+
b = assemble(inner(Constant(1), test) * dx)
41+
42+
u = Function(space, name="u")
43+
u.assign(-1)
44+
uinit = Function(u, name="uinit")
45+
try:
46+
solver.solve(u, b)
47+
except firedrake.exceptions.ConvergenceError:
48+
assert solver.ksp.getConvergedReason() == PETSc.KSP.ConvergedReason.DIVERGED_MAX_IT
49+
50+
assert not numpy.allclose(u.dat.data_ro, uinit.dat.data_ro)
51+
52+
53+
def test_linear_solver_change_bc():
54+
mesh = UnitSquareMesh(4, 4, quadrilateral=False)
55+
V = FunctionSpace(mesh, "P", 1)
56+
u = TrialFunction(V)
57+
v = TestFunction(V)
58+
59+
a = inner(grad(u), grad(v))*dx
60+
61+
bcval = Function(V)
62+
x, y = SpatialCoordinate(mesh)
63+
bcval.interpolate(1 + 2*y)
64+
bc = DirichletBC(V, bcval, "on_boundary")
65+
66+
A = assemble(a, bcs=bc)
67+
b = Cofunction(V.dual())
68+
69+
solver = LinearSolver(A)
70+
71+
uh = Function(V)
72+
73+
solver.solve(uh, b)
74+
75+
assert numpy.allclose(uh.dat.data_ro, bc.function_arg.dat.data_ro)
76+
77+
bcval.interpolate(-(1 + 2*y))
78+
79+
solver.solve(uh, b)
80+
assert numpy.allclose(uh.dat.data_ro, bc.function_arg.dat.data_ro)

tests/firedrake/regression/test_linear_solver_change_bc.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

0 commit comments

Comments
 (0)