Skip to content

Commit c54385e

Browse files
committed
some clean-up in Hdiv-mixed folder
1 parent 038dc82 commit c54385e

14 files changed

+277
-253
lines changed

examples/Hdiv-mixed/include/post-processing.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
#include "../include/setup-libceed.h"
88
#include "structs.h"
9-
PetscErrorCode PrintOutput(DM dm, Ceed ceed, AppCtx app_ctx, PetscBool has_ts, CeedMemType mem_type_backend, TS ts, SNES snes, KSP ksp, Vec U,
10-
CeedScalar l2_error_u, CeedScalar l2_error_p);
9+
PetscErrorCode PrintOutput(DM dm, Ceed ceed, AppCtx app_ctx, PetscBool has_ts, TS ts, SNES snes, KSP ksp, Vec U, CeedScalar l2_error_u,
10+
CeedScalar l2_error_p);
1111
PetscErrorCode SetupProjectVelocityCtx_Hdiv(MPI_Comm comm, DM dm, Ceed ceed, CeedData ceed_data, OperatorApplyContext ctx_Hdiv);
12-
PetscErrorCode SetupProjectVelocityCtx_H1(MPI_Comm comm, DM dm_H1, Ceed ceed, CeedData ceed_data, VecType vec_type, OperatorApplyContext ctx_H1);
12+
PetscErrorCode SetupProjectVelocityCtx_H1(MPI_Comm comm, DM dm_H1, Ceed ceed, CeedData ceed_data, OperatorApplyContext ctx_H1);
1313
PetscErrorCode ProjectVelocity(AppCtx app_ctx, Vec U, Vec *U_H1);
1414
PetscErrorCode CtxVecDestroy(AppCtx app_ctx);
1515
#endif // post_processing_h

examples/Hdiv-mixed/include/setup-dm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
// ---------------------------------------------------------------------------
1212
// Setup DM
1313
// ---------------------------------------------------------------------------
14-
PetscErrorCode CreateDM(MPI_Comm comm, MatType mat_type, VecType vec_type, DM *dm);
14+
PetscErrorCode CreateDM(MPI_Comm comm, Ceed ceed, DM *dm);
1515
PetscErrorCode PerturbVerticesSmooth(DM dm);
1616
PetscErrorCode PerturbVerticesRandom(DM dm);
1717

examples/Hdiv-mixed/include/setup-fe.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111
// ---------------------------------------------------------------------------
1212
// Setup FE
1313
// ---------------------------------------------------------------------------
14+
CeedMemType MemTypeP2C(PetscMemType mtype);
1415
PetscErrorCode SetupFEHdiv(MPI_Comm comm, DM dm, DM dm_u0, DM dm_p0);
1516
PetscErrorCode SetupFEH1(ProblemData problem_data, AppCtx app_ctx, DM dm_H1);
17+
PetscInt Involute(PetscInt i);
18+
PetscErrorCode CreateRestrictionFromPlex(Ceed ceed, DM dm, CeedInt height, DMLabel domain_label, CeedInt value, CeedElemRestriction *elem_restr);
19+
// Utility function to create local CEED Oriented restriction from DMPlex
20+
PetscErrorCode CreateRestrictionFromPlexOriented(Ceed ceed, DM dm, DM dm_u0, DM dm_p0, CeedInt P, CeedElemRestriction *elem_restr_u,
21+
CeedElemRestriction *elem_restr_p, CeedElemRestriction *elem_restr_u0,
22+
CeedElemRestriction *elem_restr_p0);
1623
#endif // setupfe_h
Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,10 @@
11
#ifndef setuplibceed_h
22
#define setuplibceed_h
33

4+
#include "setup-fe.h"
45
#include "structs.h"
56

6-
// Convert PETSc MemType to libCEED MemType
7-
CeedMemType MemTypeP2C(PetscMemType mtype);
87
// Destroy libCEED objects
98
PetscErrorCode CeedDataDestroy(CeedData ceed_data, ProblemData problem_data);
10-
// Utility function - essential BC dofs are encoded in closure indices as -(i+1)
11-
PetscInt Involute(PetscInt i);
12-
// Utility function to create local CEED restriction from DMPlex
13-
PetscErrorCode CreateRestrictionFromPlex(Ceed ceed, DM dm, CeedInt height, DMLabel domain_label, CeedInt value, CeedElemRestriction *elem_restr);
14-
// Utility function to create local CEED Oriented restriction from DMPlex
15-
PetscErrorCode CreateRestrictionFromPlexOriented(Ceed ceed, DM dm, DM dm_u0, DM dm_p0, CeedInt P, CeedElemRestriction *elem_restr_u,
16-
CeedElemRestriction *elem_restr_p, CeedElemRestriction *elem_restr_u0,
17-
CeedElemRestriction *elem_restr_p0);
18-
// Set up libCEED for a given degree
199
PetscErrorCode SetupLibceed(DM dm, DM dm_u0, DM dm_p0, DM dm_H1, Ceed ceed, AppCtx app_ctx, ProblemData problem_data, CeedData ceed_data);
2010
#endif // setuplibceed_h

examples/Hdiv-mixed/include/setup-solvers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include "petscvec.h"
88
#include "structs.h"
99

10-
PetscErrorCode SetupJacobianOperatorCtx(DM dm, Ceed ceed, CeedData ceed_data, VecType vec_type, OperatorApplyContext ctx_jacobian);
10+
PetscErrorCode SetupJacobianOperatorCtx(DM dm, Ceed ceed, CeedData ceed_data, OperatorApplyContext ctx_jacobian);
1111
PetscErrorCode SetupResidualOperatorCtx(DM dm, Ceed ceed, CeedData ceed_data, OperatorApplyContext ctx_residual);
1212
PetscErrorCode SetupErrorOperatorCtx(DM dm, Ceed ceed, CeedData ceed_data, OperatorApplyContext ctx_error);
1313
PetscErrorCode ApplyMatOp(Mat A, Vec X, Vec Y);

examples/Hdiv-mixed/include/setup-ts.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#include "structs.h"
88

9-
PetscErrorCode CreateInitialConditions(CeedData ceed_data, AppCtx app_ctx, VecType vec_type, Vec U);
9+
PetscErrorCode CreateInitialConditions(CeedData ceed_data, AppCtx app_ctx, Vec U);
1010
PetscErrorCode SetupResidualOperatorCtx_Ut(MPI_Comm comm, DM dm, Ceed ceed, CeedData ceed_data, OperatorApplyContext ctx_residual_ut);
1111
PetscErrorCode SetupResidualOperatorCtx_U0(MPI_Comm comm, DM dm, Ceed ceed, CeedData ceed_data, OperatorApplyContext ctx_initial_u0);
1212
PetscErrorCode SetupResidualOperatorCtx_P0(MPI_Comm comm, DM dm, Ceed ceed, CeedData ceed_data, OperatorApplyContext ctx_initial_p0);

examples/Hdiv-mixed/main.c

Lines changed: 47 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ int main(int argc, char **argv) {
4141
// Initialize PETSc
4242
// ---------------------------------------------------------------------------
4343
PetscCall(PetscInitialize(&argc, &argv, NULL, help));
44+
MPI_Comm comm = PETSC_COMM_WORLD;
4445

4546
// ---------------------------------------------------------------------------
4647
// Create structs
@@ -77,6 +78,12 @@ int main(int argc, char **argv) {
7778
// Context for post-processing
7879
app_ctx->ctx_Hdiv = ctx_Hdiv;
7980
app_ctx->ctx_H1 = ctx_H1;
81+
app_ctx->comm = comm;
82+
83+
// ---------------------------------------------------------------------------
84+
// Process command line options
85+
// ---------------------------------------------------------------------------
86+
PetscCall(ProcessCommandLineOptions(app_ctx));
8087

8188
// ---------------------------------------------------------------------------
8289
// Initialize libCEED
@@ -85,41 +92,20 @@ int main(int argc, char **argv) {
8592
Ceed ceed;
8693
CeedInit("/cpu/self/ref/serial", &ceed);
8794
// CeedInit(app_ctx->ceed_resource, &ceed);
88-
CeedMemType mem_type_backend;
89-
CeedGetPreferredMemType(ceed, &mem_type_backend);
90-
91-
VecType vec_type = NULL;
92-
MatType mat_type = NULL;
93-
switch (mem_type_backend) {
94-
case CEED_MEM_HOST:
95-
vec_type = VECSTANDARD;
96-
break;
97-
case CEED_MEM_DEVICE: {
98-
const char *resolved;
99-
CeedGetResource(ceed, &resolved);
100-
if (strstr(resolved, "/gpu/cuda")) vec_type = VECCUDA;
101-
else if (strstr(resolved, "/gpu/hip")) vec_type = VECKOKKOS;
102-
else vec_type = VECSTANDARD;
103-
}
104-
}
105-
if (strstr(vec_type, VECCUDA)) mat_type = MATAIJCUSPARSE;
106-
else if (strstr(vec_type, VECKOKKOS)) mat_type = MATAIJKOKKOS;
107-
else mat_type = MATAIJ;
10895

10996
// -- Process general command line options
110-
MPI_Comm comm = PETSC_COMM_WORLD;
11197
// ---------------------------------------------------------------------------
11298
// Create DM
11399
// ---------------------------------------------------------------------------
114100
DM dm, dm_u0, dm_p0, dm_H1;
115101
// DM for mixed problem
116-
PetscCall(CreateDM(comm, mat_type, vec_type, &dm));
102+
PetscCall(CreateDM(app_ctx->comm, ceed, &dm));
117103
// DM for projecting initial velocity to Hdiv space
118-
PetscCall(CreateDM(comm, mat_type, vec_type, &dm_u0));
104+
PetscCall(CreateDM(app_ctx->comm, ceed, &dm_u0));
119105
// DM for projecting initial pressure in L2
120-
PetscCall(CreateDM(comm, mat_type, vec_type, &dm_p0));
106+
PetscCall(CreateDM(app_ctx->comm, ceed, &dm_p0));
121107
// DM for projecting solution U into H1 space for PetscViewer
122-
PetscCall(CreateDM(comm, mat_type, vec_type, &dm_H1));
108+
PetscCall(CreateDM(app_ctx->comm, ceed, &dm_H1));
123109
// TODO: add mesh option
124110
// perturb dm to have smooth random mesh
125111
// PetscCall( PerturbVerticesSmooth(dm) );
@@ -129,18 +115,10 @@ int main(int argc, char **argv) {
129115
// PetscCall(PerturbVerticesRandom(dm) );
130116
// PetscCall(PerturbVerticesRandom(dm_H1) );
131117

132-
// ---------------------------------------------------------------------------
133-
// Process command line options
134-
// ---------------------------------------------------------------------------
135-
// -- Register problems to be available on the command line
136-
PetscCall(RegisterProblems_Hdiv(app_ctx));
137-
138-
app_ctx->comm = comm;
139-
PetscCall(ProcessCommandLineOptions(app_ctx));
140-
141118
// ---------------------------------------------------------------------------
142119
// Choose the problem from the list of registered problems
143120
// ---------------------------------------------------------------------------
121+
PetscCall(RegisterProblems_Hdiv(app_ctx));
144122
{
145123
PetscErrorCode (*p)(Ceed, ProblemData, DM, void *);
146124
PetscCall(PetscFunctionListFind(app_ctx->problems, app_ctx->problem_name, &p));
@@ -151,7 +129,7 @@ int main(int argc, char **argv) {
151129
// ---------------------------------------------------------------------------
152130
// Setup FE for H(div) mixed-problem and H1 projection in post-processing.c
153131
// ---------------------------------------------------------------------------
154-
PetscCall(SetupFEHdiv(comm, dm, dm_u0, dm_p0));
132+
PetscCall(SetupFEHdiv(app_ctx->comm, dm, dm_u0, dm_p0));
155133
PetscCall(SetupFEH1(problem_data, app_ctx, dm_H1));
156134

157135
// ---------------------------------------------------------------------------
@@ -167,36 +145,36 @@ int main(int argc, char **argv) {
167145
PetscCall(SetupLibceed(dm, dm_u0, dm_p0, dm_H1, ceed, app_ctx, problem_data, ceed_data));
168146

169147
// ---------------------------------------------------------------------------
170-
// Setup pressure boundary conditions
148+
// Setup pressure boundary conditions (not working)
171149
// ---------------------------------------------------------------------------
172150
// --Create empty local vector for libCEED
173-
Vec P_loc;
174-
PetscInt P_loc_size;
175-
CeedScalar *p0;
176-
CeedVector P_ceed;
177-
PetscMemType pressure_mem_type;
178-
PetscCall(DMCreateLocalVector(dm, &P_loc));
179-
PetscCall(VecGetSize(P_loc, &P_loc_size));
180-
PetscCall(VecZeroEntries(P_loc));
181-
PetscCall(VecGetArrayAndMemType(P_loc, &p0, &pressure_mem_type));
182-
CeedVectorCreate(ceed, P_loc_size, &P_ceed);
183-
CeedVectorSetArray(P_ceed, MemTypeP2C(pressure_mem_type), CEED_USE_POINTER, p0);
184-
// -- Apply operator to create local pressure vector on boundary
185-
PetscCall(DMAddBoundariesPressure(ceed, ceed_data, app_ctx, problem_data, dm, P_ceed));
186-
// CeedVectorView(P_ceed, "%12.8f", stdout);
187-
// -- Map local to global
188-
Vec P;
189-
CeedVectorTakeArray(P_ceed, MemTypeP2C(pressure_mem_type), NULL);
190-
PetscCall(VecRestoreArrayAndMemType(P_loc, &p0));
191-
PetscCall(DMCreateGlobalVector(dm, &P));
192-
PetscCall(VecZeroEntries(P));
193-
PetscCall(DMLocalToGlobal(dm, P_loc, ADD_VALUES, P));
151+
// Vec P_loc;
152+
// PetscInt P_loc_size;
153+
// CeedScalar *p0;
154+
// CeedVector P_ceed;
155+
// PetscMemType pressure_mem_type;
156+
// PetscCall(DMCreateLocalVector(dm, &P_loc));
157+
// PetscCall(VecGetSize(P_loc, &P_loc_size));
158+
// PetscCall(VecZeroEntries(P_loc));
159+
// PetscCall(VecGetArrayAndMemType(P_loc, &p0, &pressure_mem_type));
160+
// CeedVectorCreate(ceed, P_loc_size, &P_ceed);
161+
// CeedVectorSetArray(P_ceed, MemTypeP2C(pressure_mem_type), CEED_USE_POINTER, p0);
162+
//// -- Apply operator to create local pressure vector on boundary
163+
// PetscCall(DMAddBoundariesPressure(ceed, ceed_data, app_ctx, problem_data, dm, P_ceed));
164+
//// CeedVectorView(P_ceed, "%12.8f", stdout);
165+
//// -- Map local to global
166+
// Vec P;
167+
// CeedVectorTakeArray(P_ceed, MemTypeP2C(pressure_mem_type), NULL);
168+
// PetscCall(VecRestoreArrayAndMemType(P_loc, &p0));
169+
// PetscCall(DMCreateGlobalVector(dm, &P));
170+
// PetscCall(VecZeroEntries(P));
171+
// PetscCall(DMLocalToGlobal(dm, P_loc, ADD_VALUES, P));
194172

195173
// ---------------------------------------------------------------------------
196174
// Setup context for projection problem; post-processing.c
197175
// ---------------------------------------------------------------------------
198-
PetscCall(SetupProjectVelocityCtx_Hdiv(comm, dm, ceed, ceed_data, app_ctx->ctx_Hdiv));
199-
PetscCall(SetupProjectVelocityCtx_H1(comm, dm_H1, ceed, ceed_data, vec_type, app_ctx->ctx_H1));
176+
PetscCall(SetupProjectVelocityCtx_Hdiv(app_ctx->comm, dm, ceed, ceed_data, app_ctx->ctx_Hdiv));
177+
PetscCall(SetupProjectVelocityCtx_H1(app_ctx->comm, dm_H1, ceed, ceed_data, app_ctx->ctx_H1));
200178

201179
// ---------------------------------------------------------------------------
202180
// Setup TSSolve for Richard problem
@@ -206,13 +184,13 @@ int main(int argc, char **argv) {
206184
// ---------------------------------------------------------------------------
207185
// Setup context for initial conditions
208186
// ---------------------------------------------------------------------------
209-
PetscCall(SetupResidualOperatorCtx_U0(comm, dm_u0, ceed, ceed_data, app_ctx->ctx_initial_u0));
210-
PetscCall(SetupResidualOperatorCtx_P0(comm, dm_p0, ceed, ceed_data, app_ctx->ctx_initial_p0));
211-
PetscCall(SetupResidualOperatorCtx_Ut(comm, dm, ceed, ceed_data, app_ctx->ctx_residual_ut));
212-
PetscCall(CreateInitialConditions(ceed_data, app_ctx, vec_type, U));
187+
PetscCall(SetupResidualOperatorCtx_U0(app_ctx->comm, dm_u0, ceed, ceed_data, app_ctx->ctx_initial_u0));
188+
PetscCall(SetupResidualOperatorCtx_P0(app_ctx->comm, dm_p0, ceed, ceed_data, app_ctx->ctx_initial_p0));
189+
PetscCall(SetupResidualOperatorCtx_Ut(app_ctx->comm, dm, ceed, ceed_data, app_ctx->ctx_residual_ut));
190+
PetscCall(CreateInitialConditions(ceed_data, app_ctx, U));
213191
// VecView(U, PETSC_VIEWER_STDOUT_WORLD);
214192
// Solve Richards problem
215-
PetscCall(TSCreate(comm, &ts));
193+
PetscCall(TSCreate(app_ctx->comm, &ts));
216194
PetscCall(VecZeroEntries(app_ctx->ctx_residual_ut->X_loc));
217195
PetscCall(VecZeroEntries(app_ctx->ctx_residual_ut->X_t_loc));
218196
PetscCall(TSSolveRichard(ceed_data, app_ctx, ts, &U));
@@ -225,10 +203,10 @@ int main(int argc, char **argv) {
225203
SNES snes;
226204
KSP ksp;
227205
if (!problem_data->has_ts) {
228-
PetscCall(SetupJacobianOperatorCtx(dm, ceed, ceed_data, vec_type, app_ctx->ctx_jacobian));
206+
PetscCall(SetupJacobianOperatorCtx(dm, ceed, ceed_data, app_ctx->ctx_jacobian));
229207
PetscCall(SetupResidualOperatorCtx(dm, ceed, ceed_data, app_ctx->ctx_residual));
230208
// Create SNES
231-
PetscCall(SNESCreate(comm, &snes));
209+
PetscCall(SNESCreate(app_ctx->comm, &snes));
232210
PetscCall(SNESGetKSP(snes, &ksp));
233211
PetscCall(PDESolver(ceed_data, app_ctx, snes, ksp, &U));
234212
// VecView(U, PETSC_VIEWER_STDOUT_WORLD);
@@ -244,14 +222,14 @@ int main(int argc, char **argv) {
244222
// ---------------------------------------------------------------------------
245223
// Print solver iterations and final norms
246224
// ---------------------------------------------------------------------------
247-
PetscCall(PrintOutput(dm, ceed, app_ctx, problem_data->has_ts, mem_type_backend, ts, snes, ksp, U, l2_error_u, l2_error_p));
225+
PetscCall(PrintOutput(dm, ceed, app_ctx, problem_data->has_ts, ts, snes, ksp, U, l2_error_u, l2_error_p));
248226

249227
// ---------------------------------------------------------------------------
250228
// Save solution (paraview)
251229
// ---------------------------------------------------------------------------
252230
if (app_ctx->view_solution) {
253231
PetscViewer viewer_p;
254-
PetscCall(PetscViewerVTKOpen(comm, "darcy_pressure.vtu", FILE_MODE_WRITE, &viewer_p));
232+
PetscCall(PetscViewerVTKOpen(app_ctx->comm, "darcy_pressure.vtu", FILE_MODE_WRITE, &viewer_p));
255233
PetscCall(VecView(U, viewer_p));
256234
PetscCall(PetscViewerDestroy(&viewer_p));
257235

@@ -260,7 +238,7 @@ int main(int argc, char **argv) {
260238
PetscCall(ProjectVelocity(app_ctx, U, &U_H1));
261239

262240
PetscViewer viewer_u;
263-
PetscCall(PetscViewerVTKOpen(comm, "darcy_velocity.vtu", FILE_MODE_WRITE, &viewer_u));
241+
PetscCall(PetscViewerVTKOpen(app_ctx->comm, "darcy_velocity.vtu", FILE_MODE_WRITE, &viewer_u));
264242
PetscCall(VecView(U_H1, viewer_u));
265243
PetscCall(PetscViewerDestroy(&viewer_u));
266244
PetscCall(VecDestroy(&U_H1));

examples/Hdiv-mixed/src/cl-options.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ PetscErrorCode ProcessCommandLineOptions(AppCtx app_ctx) {
2525
PetscBool ceed_flag = PETSC_FALSE;
2626
PetscFunctionBeginUser;
2727

28-
PetscOptionsBegin(app_ctx->comm, NULL, "H(div) examples in PETSc with libCEED", NULL);
28+
PetscOptionsBegin(app_ctx->comm, NULL, "H(div) mixed-problem in PETSc with libCEED", NULL);
2929

3030
PetscCall(PetscOptionsString("-ceed", "CEED resource specifier", NULL, app_ctx->ceed_resource, app_ctx->ceed_resource,
3131
sizeof(app_ctx->ceed_resource), &ceed_flag));

examples/Hdiv-mixed/src/post-processing.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
// -----------------------------------------------------------------------------
66
// This function print the output
77
// -----------------------------------------------------------------------------
8-
PetscErrorCode PrintOutput(DM dm, Ceed ceed, AppCtx app_ctx, PetscBool has_ts, CeedMemType mem_type_backend, TS ts, SNES snes, KSP ksp, Vec U,
9-
CeedScalar l2_error_u, CeedScalar l2_error_p) {
8+
PetscErrorCode PrintOutput(DM dm, Ceed ceed, AppCtx app_ctx, PetscBool has_ts, TS ts, SNES snes, KSP ksp, Vec U, CeedScalar l2_error_u,
9+
CeedScalar l2_error_p) {
1010
PetscFunctionBeginUser;
1111

1212
const char *used_resource;
13+
CeedMemType mem_type_backend;
1314
CeedGetResource(ceed, &used_resource);
15+
CeedGetPreferredMemType(ceed, &mem_type_backend);
1416
char hostname[PETSC_MAX_PATH_LEN];
1517
PetscCall(PetscGetHostName(hostname, sizeof hostname));
1618
PetscInt comm_size;
@@ -128,9 +130,11 @@ PetscErrorCode SetupProjectVelocityCtx_Hdiv(MPI_Comm comm, DM dm, Ceed ceed, Cee
128130
PetscFunctionReturn(0);
129131
}
130132

131-
PetscErrorCode SetupProjectVelocityCtx_H1(MPI_Comm comm, DM dm_H1, Ceed ceed, CeedData ceed_data, VecType vec_type, OperatorApplyContext ctx_H1) {
133+
PetscErrorCode SetupProjectVelocityCtx_H1(MPI_Comm comm, DM dm_H1, Ceed ceed, CeedData ceed_data, OperatorApplyContext ctx_H1) {
132134
PetscFunctionBeginUser;
133135

136+
VecType vec_type;
137+
PetscCall(DMGetVecType(dm_H1, &vec_type));
134138
ctx_H1->comm = comm;
135139
ctx_H1->dm = dm_H1;
136140
PetscCall(DMCreateLocalVector(dm_H1, &ctx_H1->X_loc));

examples/Hdiv-mixed/src/setup-dm.c

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,33 @@
33
#include "petscerror.h"
44

55
// ---------------------------------------------------------------------------
6-
// Setup DM
6+
// Create DM
77
// ---------------------------------------------------------------------------
8-
PetscErrorCode CreateDM(MPI_Comm comm, MatType mat_type, VecType vec_type, DM *dm) {
8+
PetscErrorCode CreateDM(MPI_Comm comm, Ceed ceed, DM *dm) {
99
PetscFunctionBeginUser;
1010

11+
CeedMemType mem_type_backend;
12+
CeedGetPreferredMemType(ceed, &mem_type_backend);
13+
14+
VecType vec_type = NULL;
15+
MatType mat_type = NULL;
16+
switch (mem_type_backend) {
17+
case CEED_MEM_HOST:
18+
vec_type = VECSTANDARD;
19+
break;
20+
case CEED_MEM_DEVICE: {
21+
const char *resolved;
22+
CeedGetResource(ceed, &resolved);
23+
if (strstr(resolved, "/gpu/cuda")) vec_type = VECCUDA;
24+
else if (strstr(resolved, "/gpu/hip/occa")) vec_type = VECSTANDARD; // https://github.com/CEED/libCEED/issues/678
25+
else if (strstr(resolved, "/gpu/hip")) vec_type = VECHIP;
26+
else vec_type = VECSTANDARD;
27+
}
28+
}
29+
if (strstr(vec_type, VECCUDA)) mat_type = MATAIJCUSPARSE;
30+
else if (strstr(vec_type, VECKOKKOS)) mat_type = MATAIJKOKKOS;
31+
else mat_type = MATAIJ;
32+
1133
// Create DMPLEX
1234
PetscCall(DMCreate(comm, dm));
1335
PetscCall(DMSetType(*dm, DMPLEX));

0 commit comments

Comments
 (0)