Skip to content

Commit 79df7ad

Browse files
committed
hip - skip duplicate output rstr
1 parent d8d280d commit 79df7ad

File tree

2 files changed

+51
-14
lines changed

2 files changed

+51
-14
lines changed

backends/hip-ref/ceed-hip-ref-operator.c

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ static int CeedOperatorDestroy_Hip(CeedOperator op) {
2727

2828
// Apply data
2929
CeedCallBackend(CeedFree(&impl->skip_rstr_in));
30+
CeedCallBackend(CeedFree(&impl->skip_rstr_out));
31+
CeedCallBackend(CeedFree(&impl->apply_add_basis_out));
3032
for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) {
3133
CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[i]));
3234
}
@@ -96,8 +98,8 @@ static int CeedOperatorDestroy_Hip(CeedOperator op) {
9698
//------------------------------------------------------------------------------
9799
// Setup infields or outfields
98100
//------------------------------------------------------------------------------
99-
static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, CeedVector *e_vecs,
100-
CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) {
101+
static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, bool *apply_add_basis,
102+
CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) {
101103
Ceed ceed;
102104
CeedQFunctionField *qf_fields;
103105
CeedOperatorField *op_fields;
@@ -183,7 +185,7 @@ static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool i
183185
break;
184186
}
185187
}
186-
// Drop duplicate input restrictions
188+
// Drop duplicate restrictions
187189
if (is_input) {
188190
for (CeedInt i = 0; i < num_fields; i++) {
189191
CeedVector vec_i;
@@ -198,11 +200,31 @@ static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool i
198200
CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
199201
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
200202
if (vec_i == vec_j && rstr_i == rstr_j) {
201-
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
203+
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e]));
202204
skip_rstr[j] = true;
203205
}
204206
}
205207
}
208+
} else {
209+
for (CeedInt i = num_fields - 1; i >= 0; i--) {
210+
CeedVector vec_i;
211+
CeedElemRestriction rstr_i;
212+
213+
CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i));
214+
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i));
215+
for (CeedInt j = i - 1; j >= 0; j--) {
216+
CeedVector vec_j;
217+
CeedElemRestriction rstr_j;
218+
219+
CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
220+
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
221+
if (vec_i == vec_j && rstr_i == rstr_j) {
222+
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e]));
223+
skip_rstr[j] = true;
224+
apply_add_basis[i] = true;
225+
}
226+
}
227+
}
206228
}
207229
return CEED_ERROR_SUCCESS;
208230
}
@@ -233,6 +255,8 @@ static int CeedOperatorSetup_Hip(CeedOperator op) {
233255
// Allocate
234256
CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs));
235257
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in));
258+
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out));
259+
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out));
236260
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states));
237261
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in));
238262
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out));
@@ -242,10 +266,10 @@ static int CeedOperatorSetup_Hip(CeedOperator op) {
242266
// Set up infield and outfield e_vecs and q_vecs
243267
// Infields
244268
CeedCallBackend(
245-
CeedOperatorSetupFields_Hip(qf, op, true, false, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem));
269+
CeedOperatorSetupFields_Hip(qf, op, true, false, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem));
246270
// Outfields
247-
CeedCallBackend(
248-
CeedOperatorSetupFields_Hip(qf, op, false, false, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, Q, num_elem));
271+
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, false, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out,
272+
num_input_fields, num_output_fields, Q, num_elem));
249273

250274
CeedCallBackend(CeedOperatorSetSetupDone(op));
251275
return CEED_ERROR_SUCCESS;
@@ -430,7 +454,11 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect
430454
case CEED_EVAL_DIV:
431455
case CEED_EVAL_CURL:
432456
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
433-
CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
457+
if (impl->apply_add_basis_out[i]) {
458+
CeedCallBackend(CeedBasisApplyAdd(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
459+
} else {
460+
CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
461+
}
434462
break;
435463
// LCOV_EXCL_START
436464
case CEED_EVAL_WEIGHT: {
@@ -451,6 +479,7 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect
451479
if (eval_mode == CEED_EVAL_NONE) {
452480
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields]));
453481
}
482+
if (impl->skip_rstr_out[i]) continue;
454483
// Get output vector
455484
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
456485
// Restrict
@@ -498,6 +527,8 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) {
498527
// Allocate
499528
CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs));
500529
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in));
530+
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out));
531+
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out));
501532
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states));
502533
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in));
503534
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out));
@@ -506,11 +537,11 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) {
506537

507538
// Set up infield and outfield e_vecs and q_vecs
508539
// Infields
509-
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, true, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields,
540+
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, true, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields,
510541
max_num_points, num_elem));
511542
// Outfields
512-
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields,
513-
max_num_points, num_elem));
543+
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out,
544+
num_input_fields, num_output_fields, max_num_points, num_elem));
514545

515546
CeedCallBackend(CeedOperatorSetSetupDone(op));
516547
return CEED_ERROR_SUCCESS;
@@ -634,8 +665,13 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec,
634665
case CEED_EVAL_DIV:
635666
case CEED_EVAL_CURL:
636667
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
637-
CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i],
638-
impl->e_vecs[i + impl->num_inputs]));
668+
if (impl->apply_add_basis_out[i]) {
669+
CeedCallBackend(CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem,
670+
impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
671+
} else {
672+
CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i],
673+
impl->e_vecs[i + impl->num_inputs]));
674+
}
639675
break;
640676
// LCOV_EXCL_START
641677
case CEED_EVAL_WEIGHT: {
@@ -656,6 +692,7 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec,
656692
if (eval_mode == CEED_EVAL_NONE) {
657693
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields]));
658694
}
695+
if (impl->skip_rstr_out[i]) continue;
659696
// Get output vector
660697
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
661698
// Restrict

backends/hip-ref/ceed-hip-ref.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ typedef struct {
132132
} CeedOperatorAssemble_Hip;
133133

134134
typedef struct {
135-
bool *skip_rstr_in;
135+
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
136136
uint64_t *input_states; // State tracking for passive inputs
137137
CeedVector *e_vecs; // E-vectors, inputs followed by outputs
138138
CeedVector *q_vecs_in; // Input Q-vectors needed to apply operator

0 commit comments

Comments
 (0)