diff --git a/demos/helmholtz/helmholtz.txt b/demos/helmholtz/helmholtz.txt new file mode 100644 index 0000000000..95334ccb6a --- /dev/null +++ b/demos/helmholtz/helmholtz.txt @@ -0,0 +1,57 @@ +Main Stage 366614 +Main Stage;firedrake 44369 +Main Stage;firedrake;firedrake.solving.solve 86 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve 196 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve 140 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval 736 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval;ParLoopExecute 212 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval;ParLoopExecute;Parloop_Cells_wrap_form0_cell_integral 112 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval;ParLoopExecute;Parloop_Cells_wrap_form0_cell_integral;pyop2.global_kernel.GlobalKernel.compile 415552 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval;firedrake.tsfc_interface.compile_form 42597 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESJacobianEval 866 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESJacobianEval;ParLoopExecute 149 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESJacobianEval;ParLoopExecute;Parloop_Cells_wrap_form00_cell_integral 136 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESJacobianEval;ParLoopExecute;Parloop_Cells_wrap_form00_cell_integral;pyop2.global_kernel.GlobalKernel.compile 407506 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.__init__ 1771 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.__init__;firedrake.tsfc_interface.compile_form 56423 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.__init__;firedrake.tsfc_interface.compile_form;firedrake.formmanipulation.split_form 1907 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.__init__;firedrake.solving_utils._SNESContext.__init__ 618 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.LinearVariationalProblem.__init__ 145 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.LinearVariationalProblem.__init__;firedrake.ufl_expr.action 4387 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.LinearVariationalProblem.__init__;firedrake.variational_solver.NonlinearVariationalProblem.__init__ 332 +Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.LinearVariationalProblem.__init__;firedrake.variational_solver.NonlinearVariationalProblem.__init__;firedrake.ufl_expr.adjoint 2798 +Main Stage;firedrake;firedrake.function.Function.interpolate 342 +Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble 5644 +Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate 29 +Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute 298 +Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute;Parloop_Cells_wrap_expression_kernel 204 +Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute;Parloop_Cells_wrap_expression_kernel;pyop2.global_kernel.GlobalKernel.compile 682292 +Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.make_interpolator 40658 +Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write 2473 +Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate 303 +Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble 1080 +Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate 23 +Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute 328 +Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute;Parloop_Cells_wrap_expression_kernel 165 +Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute;Parloop_Cells_wrap_expression_kernel;pyop2.global_kernel.GlobalKernel.compile 663410 +Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.make_interpolator 55147 +Main Stage;firedrake;firedrake.__init__ 495196 +Main Stage;firedrake;firedrake.assemble.assemble 949 +Main Stage;firedrake;firedrake.assemble.assemble;ParLoopExecute 310 +Main Stage;firedrake;firedrake.assemble.assemble;ParLoopExecute;Parloop_Cells_wrap_form_cell_integral 95 +Main Stage;firedrake;firedrake.assemble.assemble;ParLoopExecute;Parloop_Cells_wrap_form_cell_integral;pyop2.global_kernel.GlobalKernel.compile 355507 +Main Stage;firedrake;firedrake.assemble.assemble;firedrake.tsfc_interface.compile_form 20219 +Main Stage;firedrake;CreateFunctionSpace 919 +Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace 79 +Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__ 165 +Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__;firedrake.functionspacedata.get_shared_data 13 +Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__;firedrake.functionspacedata.get_shared_data;firedrake.functionspacedata.FunctionSpaceData.__init__ 825 +Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__;firedrake.functionspacedata.get_shared_data;firedrake.functionspacedata.FunctionSpaceData.__init__;FunctionSpaceData: CreateElement 1274 +Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__;firedrake.functionspacedata.get_shared_data;firedrake.functionspacedata.FunctionSpaceData.__init__;firedrake.mesh.MeshTopology._facets 789 +Main Stage;firedrake;CreateFunctionSpace;CreateMesh 147 +Main Stage;firedrake;CreateFunctionSpace;CreateMesh;Mesh: numbering 376 +Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh 12 +Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh;firedrake.utility_meshes.SquareMesh 11 +Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh;firedrake.utility_meshes.SquareMesh;firedrake.utility_meshes.RectangleMesh 834 +Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh;firedrake.utility_meshes.SquareMesh;firedrake.utility_meshes.RectangleMesh;CreateMesh 676 +Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh;firedrake.utility_meshes.SquareMesh;firedrake.utility_meshes.RectangleMesh;DMPlexInterp 382 diff --git a/firedrake/linear_solver.py b/firedrake/linear_solver.py index c1dfbcc07e..c8e58c0ec6 100644 --- a/firedrake/linear_solver.py +++ b/firedrake/linear_solver.py @@ -55,6 +55,8 @@ def __init__(self, A, *, P=None, solver_parameters=None, solver_parameters = solving_utils.set_defaults(solver_parameters, A.arguments(), ksp_defaults=self.DEFAULT_KSP_PARAMETERS) + # todo: add offload to solver parameters - how? prefix? + self.A = A self.comm = A.comm self._comm = internal_comm(self.comm, self) @@ -163,6 +165,18 @@ def solve(self, x, b): else: acc = x.dat.vec_wo + # if "cu" in self.A.petscmat.type: # todo: cuda or cu? + # with self.inserted_options(), b.dat.vec_ro as rhs, acc as solution, dmhooks.add_hooks(self.ksp.dm, self): + # b_cu = PETSc.Vec() + # b_cu.createCUDAWithArrays(rhs) + # u = PETSc.Vec() + # u.createCUDAWithArrays(solution) + # self.ksp.solve(b_cu, u) + # u.getArray() + + # else: + # instead: preconditioner + with self.inserted_options(), b.dat.vec_ro as rhs, acc as solution, dmhooks.add_hooks(self.ksp.dm, self): self.ksp.solve(rhs, solution) diff --git a/firedrake/preconditioners/__init__.py b/firedrake/preconditioners/__init__.py index cd75ae7380..ca04bd9cbd 100644 --- a/firedrake/preconditioners/__init__.py +++ b/firedrake/preconditioners/__init__.py @@ -12,3 +12,4 @@ from firedrake.preconditioners.fdm import * # noqa: F401 from firedrake.preconditioners.hiptmair import * # noqa: F401 from firedrake.preconditioners.facet_split import * # noqa: F401 +from firedrake.preconditioners.offload import * # noqa: F401 diff --git a/firedrake/preconditioners/offload.py b/firedrake/preconditioners/offload.py new file mode 100644 index 0000000000..7a306ae24e --- /dev/null +++ b/firedrake/preconditioners/offload.py @@ -0,0 +1,117 @@ +from firedrake.preconditioners.base import PCBase +from firedrake.functionspace import FunctionSpace, MixedFunctionSpace +from firedrake.petsc import PETSc +from firedrake.ufl_expr import TestFunction, TrialFunction +import firedrake.dmhooks as dmhooks +from firedrake.dmhooks import get_function_space + +__all__ = ("OffloadPC",) + + +class OffloadPC(PCBase): + """Offload PC from CPU to GPU and back. + + Internally this makes a PETSc PC object that can be controlled by + options using the extra options prefix ``offload_``. + """ + + _prefix = "offload_" + + def initialize(self, pc): + with PETSc.Log.Event("Event: initialize offload"): # + A, P = pc.getOperators() + + outer_pc = pc + appctx = self.get_appctx(pc) + fcp = appctx.get("form_compiler_parameters") + + V = get_function_space(pc.getDM()) + if len(V) == 1: + V = FunctionSpace(V.mesh(), V.ufl_element()) + else: + V = MixedFunctionSpace([V_ for V_ in V]) + test = TestFunction(V) + trial = TrialFunction(V) + + (a, bcs) = self.form(pc, test, trial) + + if P.type == "assembled": + context = P.getPythonContext() + # It only makes sense to preconditioner/invert a diagonal + # block in general. That's all we're going to allow. + if not context.on_diag: + raise ValueError("Only makes sense to invert diagonal block") + + prefix = pc.getOptionsPrefix() + options_prefix = prefix + self._prefix + + mat_type = PETSc.Options().getString(options_prefix + "mat_type", "cusparse") + + # Convert matrix to ajicusparse + with PETSc.Log.Event("Event: matrix offload"): + P_cu = P.convert(mat_type='aijcusparse') # todo + + # Transfer nullspace + P_cu.setNullSpace(P.getNullSpace()) + tnullsp = P.getTransposeNullSpace() + if tnullsp.handle != 0: + P_cu.setTransposeNullSpace(tnullsp) + P_cu.setNearNullSpace(P.getNearNullSpace()) + + # PC object set-up + pc = PETSc.PC().create(comm=outer_pc.comm) + pc.incrementTabLevel(1, parent=outer_pc) + + # We set a DM and an appropriate SNESContext on the constructed PC + # so one can do e.g. multigrid or patch solves. + dm = outer_pc.getDM() + self._ctx_ref = self.new_snes_ctx( + outer_pc, a, bcs, mat_type, + fcp=fcp, options_prefix=options_prefix + ) + + pc.setDM(dm) + pc.setOptionsPrefix(options_prefix) + pc.setOperators(A, P_cu) + self.pc = pc + with dmhooks.add_hooks(dm, self, appctx=self._ctx_ref, save=False): + pc.setFromOptions() + + def update(self, pc): + _, P = pc.getOperators() + _, P_cu = self.pc.getOperators() + P.copy(P_cu) + + def form(self, pc, test, trial): + _, P = pc.getOperators() + if P.getType() == "python": + context = P.getPythonContext() + return (context.a, context.row_bcs) + else: + context = dmhooks.get_appctx(pc.getDM()) + return (context.Jp or context.J, context._problem.bcs) + + # Convert vectors to CUDA, solve and get solution on CPU back + def apply(self, pc, x, y): + with PETSc.Log.Event("Event: apply offload"): # + dm = pc.getDM() + with dmhooks.add_hooks(dm, self, appctx=self._ctx_ref): + with PETSc.Log.Event("Event: vectors offload"): + y_cu = PETSc.Vec() # begin + y_cu.createCUDAWithArrays(y) + x_cu = PETSc.Vec() + x_cu.createCUDAWithArrays(x) # end + with PETSc.Log.Event("Event: solve"): + self.pc.apply(x_cu, y_cu) # + with PETSc.Log.Event("Event: vectors copy back"): + y.copy(y_cu) # + + def applyTranspose(self, pc, X, Y): + raise NotImplementedError + + def view(self, pc, viewer=None): + super().view(pc, viewer) + print("viewing PC") + if hasattr(self, "pc"): + viewer.printfASCII("PC to solve on GPU\n") + self.pc.view(viewer) diff --git a/firedrake/solving.py b/firedrake/solving.py index de55e52048..46754a2a17 100644 --- a/firedrake/solving.py +++ b/firedrake/solving.py @@ -252,15 +252,18 @@ def _la_solve(A, x, b, **kwargs): options_prefix=options_prefix) if isinstance(x, firedrake.Vector): x = x.function - # linear MG doesn't need RHS, supply zero. - lvp = vs.LinearVariationalProblem(a=A.a, L=0, u=x, bcs=A.bcs) - mat_type = A.mat_type - appctx = solver_parameters.get("appctx", {}) - ctx = solving_utils._SNESContext(lvp, - mat_type=mat_type, - pmat_type=mat_type, - appctx=appctx, - options_prefix=options_prefix) + if not isinstance(A, firedrake.matrix.AssembledMatrix): + # linear MG doesn't need RHS, supply zero. + lvp = vs.LinearVariationalProblem(a=A.a, L=0, u=x, bcs=A.bcs) + mat_type = A.mat_type + appctx = solver_parameters.get("appctx", {}) + ctx = solving_utils._SNESContext(lvp, + mat_type=mat_type, + pmat_type=mat_type, + appctx=appctx, + options_prefix=options_prefix) + else: + ctx = None dm = solver.ksp.dm with dmhooks.add_hooks(dm, solver, appctx=ctx):