Description
Hi,
I am solving a two-equation problem in Firedrake and I am concerned about the disproportionate amount of time it takes to compute the derivative of my reduced functional Jhat.
The 1D problem I'm trying to solve is as follows:
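Read off from the weak form in the code below (an inference from the script, assuming the constant velocity u = 1 and the initial values c = 0, q = 1 that the script sets), the system has the form:

$$
\begin{aligned}
\frac{\partial c}{\partial t} + \frac{\partial (u c)}{\partial x} &= -k\, q\, c, \\
\frac{\partial q}{\partial t} &= -k_2\, q\, c, \\
c(0, t) &= c_{in}, \qquad c(x, 0) = 0, \qquad q(x, 0) = 1 .
\end{aligned}
$$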
The equation for c and the solution method are basically the same as in the "DG advection equation with upwinding" tutorial; k and k_2 are constants, and the control is c_in.
Solving for c and q takes only a couple of seconds, but computing Jhat.derivative() takes around 6 minutes.
Does this look normal, or is there a problem in my code and/or a way to solve this faster?
Thank you for your help!
from firedrake import *
from firedrake_adjoint import *
# Set up the mesh
mesh = UnitIntervalMesh(40)
# Set up the function spaces
Vec = VectorFunctionSpace(mesh, "CG", 1)
V_c = FunctionSpace(mesh, "DG", 1)
V_q = FunctionSpace(mesh, "DG", 0)
W = V_c*V_q
# Get the spatial coordinate for x and set constant velocity with static boundary conditions
x, = SpatialCoordinate(mesh)
velocity = as_vector((1, ))
u = Function(Vec).interpolate(velocity)
c_in = Constant(1.0)
bcs = [DirichletBC(W.sub(0), c_in, 1)]
# Set the initial condition
f = Function(W)
with stop_annotating():
    c, q = f.split()
    q.assign(1.0)
# Set time T, step dt
T = 2
dt = T/600
dtc = Constant(dt)
# Set the left hand side of our equation
dc_trial, dq_trial = TrialFunctions(W)
phi, psi = TestFunctions(W)
a = phi*dc_trial*dx + psi*dq_trial*dx
# We define ``n`` to be the built-in ``FacetNormal`` object; a unit normal vector
# that can be used in integrals over exterior and interior facets. We next define
# ``un`` to be an object which is equal to :math:`\vec{u}\cdot\vec{n}` if this is
# positive, and zero if this is negative. This will be useful in the upwind terms.
n = FacetNormal(mesh)
un = 0.5*(dot(u, n) + abs(dot(u, n)))
k = 0.8
k2 = 0.1
# Right-hand side
L1 = dtc*(c*div(phi*u)*dx
          - conditional(dot(u, n) < 0, phi*dot(u, n)*c_in, 0.0)*ds
          - conditional(dot(u, n) > 0, phi*dot(u, n)*c, 0.0)*ds
          - (phi('+') - phi('-'))*(un('+')*c('+') - un('-')*c('-'))*dS
          - k*phi*q*c*dx
          - k2*psi*q*c*dx)
# Runge-Kutta
f1 = Function(W)
f2 = Function(W)
L2 = replace(L1, {c: split(f1)[0], q: split(f1)[1]})
L3 = replace(L1, {c: split(f2)[0], q: split(f2)[1]})
# We now declare a variable to hold the temporary increments at each stage.
df = Function(W)
# We make use of the ``LinearVariationalProblem`` and
# ``LinearVariationalSolver`` objects for each of our Runge-Kutta stages.
params = {'ksp_type': 'preonly', 'pc_type': 'bjacobi', 'sub_pc_type': 'ilu', 'mat_type': 'aij'}
prob1 = LinearVariationalProblem(a, L1, df, bcs=bcs)
solv1 = LinearVariationalSolver(prob1, solver_parameters=params)
prob2 = LinearVariationalProblem(a, L2, df, bcs=bcs)
solv2 = LinearVariationalSolver(prob2, solver_parameters=params)
prob3 = LinearVariationalProblem(a, L3, df, bcs=bcs)
solv3 = LinearVariationalSolver(prob3, solver_parameters=params)
# Run the time loop with three Runge-Kutta stages, and write the results
# into the results list
t = 0.0
step = 0
with stop_annotating():
    c_, q_ = f.split()
    results = [[Function(c_)], [Function(q_)]]
while t < T - 0.5*dt:
    solv1.solve()
    f1.assign(f + df)
    solv2.solve()
    f2.assign(0.75*f + 0.25*(f1 + df))
    solv3.solve()
    f.assign((1.0/3.0)*f + (2.0/3.0)*(f2 + df))
    with stop_annotating():
        c_, q_ = f.split()
        results[0].append(Function(c_))
        results[1].append(Function(q_))
    step += 1
    t += dt
# Set up control and reduced functional Jhat
c, q = split(f)
J = assemble(c*ds(2))
m = Control(c_in)
Jhat = ReducedFunctional(J, m)
d = Jhat.derivative()
print(d.dat.data)
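For reference, the three-stage update in the time loop above has the same structure as the third-order strong-stability-preserving Runge-Kutta scheme (SSPRK3) used in the tutorial. Below is a minimal sketch of that update in plain Python on a scalar ODE, so it runs without Firedrake. The names `ssprk3_step`, `integrate` and `decay` are hypothetical helpers introduced only for illustration, and the test equation y' = -y is an assumption, not part of my problem.

```python
import math


def ssprk3_step(rhs, y, dt):
    """Advance y by one step of size dt using the three-stage SSPRK3 update."""
    # Stage 1: forward Euler step (mirrors f1 = f + df).
    y1 = y + dt * rhs(y)
    # Stage 2: mirrors f2 = 0.75*f + 0.25*(f1 + df).
    y2 = 0.75 * y + 0.25 * (y1 + dt * rhs(y1))
    # Stage 3: mirrors f = (1/3)*f + (2/3)*(f2 + df).
    return y / 3.0 + (2.0 / 3.0) * (y2 + dt * rhs(y2))


def integrate(rhs, y0, T, n_steps):
    """Integrate y' = rhs(y) from 0 to T with n_steps equal SSPRK3 steps."""
    dt = T / n_steps
    y = y0
    for _ in range(n_steps):
        y = ssprk3_step(rhs, y, dt)
    return y


def decay(y):
    """Illustrative right-hand side: y' = -y, exact solution exp(-t)."""
    return -y


# Same final time and step count as the script above (T = 2, 600 steps).
approx = integrate(decay, 1.0, 2.0, 600)
exact = math.exp(-2.0)
print(abs(approx - exact))
```

Each time step performs three solves and three assignments, so with 600 steps the forward run records on the order of a few thousand operations for the derivative computation to walk back through.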