Skip to content

Commit cf7113f

Browse files
committed
gpu - gen ApplyAdd functions
1 parent cad1658 commit cf7113f

File tree

4 files changed

+76
-2
lines changed

4 files changed

+76
-2
lines changed

backends/cuda-gen/ceed-cuda-gen-operator.c

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,17 +293,53 @@ static int CeedOperatorApplyAdd_Cuda_gen(CeedOperator op, CeedVector input_vec,
293293
return CEED_ERROR_SUCCESS;
294294
}
295295

296+
//------------------------------------------------------------------------------
// Apply and add to output for a composite operator
//
// The device arrays of input_vec and output_vec are acquired once (skipped for
// CEED_VECTOR_NONE) and handed to CeedOperatorApplyAddCore_Cuda_gen for every
// sub-operator. Sub-operators whose is_run_good flag is still false afterwards
// are re-applied through the operator returned by CeedOperatorGetFallback.
//
// Returns CEED_ERROR_SUCCESS, or the first error reported by a wrapped call.
//------------------------------------------------------------------------------
static int CeedOperatorApplyAddComposite_Cuda_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
  // One flag per sub-operator; every entry starts false and is only written by the core apply call
  bool              is_run_good[CEED_COMPOSITE_MAX] = {false};
  CeedInt           num_suboperators;
  const CeedScalar *input_arr  = NULL;
  CeedScalar       *output_arr = NULL;
  CeedOperator     *sub_operators;

  // Backend code: use CeedCallBackend throughout so error handling matches the rest of this function
  CeedCallBackend(CeedCompositeOperatorGetNumSub(op, &num_suboperators));
  CeedCallBackend(CeedCompositeOperatorGetSubList(op, &sub_operators));

  // Arrays stay NULL when the corresponding vector is CEED_VECTOR_NONE
  if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr));
  if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr));

  // Apply each sub-operator on the shared arrays
  // NOTE(review): the second argument is passed as NULL; presumably the stream - confirm against CeedOperatorApplyAddCore_Cuda_gen
  for (CeedInt i = 0; i < num_suboperators; i++) {
    CeedCallBackend(CeedOperatorApplyAddCore_Cuda_gen(sub_operators[i], NULL, input_arr, output_arr, &is_run_good[i], request));
  }

  // Restore the vectors before the fallback pass, which applies through the CeedVector handles
  if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr));
  if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr));

  // Fallback on unsuccessful run
  for (CeedInt i = 0; i < num_suboperators; i++) {
    if (!is_run_good[i]) {
      CeedOperator op_fallback;

      CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/cuda/ref CeedOperator");
      CeedCallBackend(CeedOperatorGetFallback(sub_operators[i], &op_fallback));
      CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
    }
  }
  return CEED_ERROR_SUCCESS;
}
325+
296326
//------------------------------------------------------------------------------
297327
// Create operator
298328
//------------------------------------------------------------------------------
299329
int CeedOperatorCreate_Cuda_gen(CeedOperator op) {
330+
bool is_composite;
300331
Ceed ceed;
301332
CeedOperator_Cuda_gen *impl;
302333

303334
CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
304335
CeedCallBackend(CeedCalloc(1, &impl));
305336
CeedCallBackend(CeedOperatorSetData(op, impl));
306-
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Cuda_gen));
337+
CeedCall(CeedOperatorIsComposite(op, &is_composite));
338+
if (is_composite) {
339+
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAddComposite", CeedOperatorApplyAddComposite_Cuda_gen));
340+
} else {
341+
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Cuda_gen));
342+
}
307343
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda_gen));
308344
CeedCallBackend(CeedDestroy(&ceed));
309345
return CEED_ERROR_SUCCESS;

backends/cuda-gen/ceed-cuda-gen.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ static int CeedInit_Cuda_gen(const char *resource, Ceed ceed) {
3939

4040
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "QFunctionCreate", CeedQFunctionCreate_Cuda_gen));
4141
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "OperatorCreate", CeedOperatorCreate_Cuda_gen));
42+
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "CompositeOperatorCreate", CeedOperatorCreate_Cuda_gen));
4243
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "OperatorCreateAtPoints", CeedOperatorCreate_Cuda_gen));
4344
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "Destroy", CeedDestroy_Cuda));
4445
return CEED_ERROR_SUCCESS;

backends/hip-gen/ceed-hip-gen-operator.c

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,17 +240,53 @@ static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, C
240240
return CEED_ERROR_SUCCESS;
241241
}
242242

243+
//------------------------------------------------------------------------------
// Apply and add to output for a composite operator
//
// The device arrays of input_vec and output_vec are acquired once (skipped for
// CEED_VECTOR_NONE) and handed to CeedOperatorApplyAddCore_Hip_gen for every
// sub-operator. Sub-operators whose is_run_good flag is still false afterwards
// are re-applied through the operator returned by CeedOperatorGetFallback.
//
// Returns CEED_ERROR_SUCCESS, or the first error reported by a wrapped call.
//------------------------------------------------------------------------------
static int CeedOperatorApplyAddComposite_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
  // One flag per sub-operator; every entry starts false and is only written by the core apply call
  bool              is_run_good[CEED_COMPOSITE_MAX] = {false};
  CeedInt           num_suboperators;
  const CeedScalar *input_arr  = NULL;
  CeedScalar       *output_arr = NULL;
  CeedOperator     *sub_operators;

  // Backend code: use CeedCallBackend throughout so error handling matches the rest of this function
  CeedCallBackend(CeedCompositeOperatorGetNumSub(op, &num_suboperators));
  CeedCallBackend(CeedCompositeOperatorGetSubList(op, &sub_operators));

  // Arrays stay NULL when the corresponding vector is CEED_VECTOR_NONE
  if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr));
  if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr));

  // Apply each sub-operator on the shared arrays
  // NOTE(review): the second argument is passed as NULL; presumably the stream - confirm against CeedOperatorApplyAddCore_Hip_gen
  for (CeedInt i = 0; i < num_suboperators; i++) {
    CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(sub_operators[i], NULL, input_arr, output_arr, &is_run_good[i], request));
  }

  // Restore the vectors before the fallback pass, which applies through the CeedVector handles
  if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr));
  if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr));

  // Fallback on unsuccessful run
  for (CeedInt i = 0; i < num_suboperators; i++) {
    if (!is_run_good[i]) {
      CeedOperator op_fallback;

      CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator");
      CeedCallBackend(CeedOperatorGetFallback(sub_operators[i], &op_fallback));
      CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
    }
  }
  return CEED_ERROR_SUCCESS;
}
272+
243273
//------------------------------------------------------------------------------
244274
// Create operator
245275
//------------------------------------------------------------------------------
246276
int CeedOperatorCreate_Hip_gen(CeedOperator op) {
277+
bool is_composite;
247278
Ceed ceed;
248279
CeedOperator_Hip_gen *impl;
249280

250281
CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
251282
CeedCallBackend(CeedCalloc(1, &impl));
252283
CeedCallBackend(CeedOperatorSetData(op, impl));
253-
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen));
284+
CeedCall(CeedOperatorIsComposite(op, &is_composite));
285+
if (is_composite) {
286+
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAddComposite", CeedOperatorApplyAddComposite_Hip_gen));
287+
} else {
288+
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen));
289+
}
254290
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip_gen));
255291
CeedCallBackend(CeedDestroy(&ceed));
256292
return CEED_ERROR_SUCCESS;

backends/hip-gen/ceed-hip-gen.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ static int CeedInit_Hip_gen(const char *resource, Ceed ceed) {
3939

4040
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "QFunctionCreate", CeedQFunctionCreate_Hip_gen));
4141
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "OperatorCreate", CeedOperatorCreate_Hip_gen));
42+
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "CompositeOperatorCreate", CeedOperatorCreate_Hip_gen));
4243
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "OperatorCreateAtPoints", CeedOperatorCreate_Hip_gen));
4344
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "Destroy", CeedDestroy_Hip));
4445
return CEED_ERROR_SUCCESS;

0 commit comments

Comments
 (0)