Skip to content

Commit fde09a2

Browse files
authored
Merge pull request #1890 from CEED/zach/at-points-update-reference-coords
at-points - ensure reference point coordinates are always up to date
2 parents d80bc30 + 68c01f3 commit fde09a2

File tree

4 files changed

+38
-14
lines changed

4 files changed

+38
-14
lines changed

backends/cuda-ref/ceed-cuda-ref-operator.c

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -848,13 +848,19 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec,
848848
CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec));
849849

850850
// Get point coordinates
851-
if (!impl->point_coords_elem) {
851+
{
852852
CeedVector point_coords = NULL;
853853
CeedElemRestriction rstr_points = NULL;
854854

855855
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords));
856-
CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem));
857-
CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request));
856+
if (!impl->point_coords_elem) CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem));
857+
{
858+
uint64_t state;
859+
CeedCallBackend(CeedVectorGetState(point_coords, &state));
860+
if (impl->points_state != state) {
861+
CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request));
862+
}
863+
}
858864
CeedCallBackend(CeedVectorDestroy(&point_coords));
859865
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
860866
}
@@ -1855,13 +1861,19 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
18551861
}
18561862

18571863
// Get point coordinates
1858-
if (!impl->point_coords_elem) {
1864+
{
18591865
CeedVector point_coords = NULL;
18601866
CeedElemRestriction rstr_points = NULL;
18611867

18621868
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords));
1863-
CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem));
1864-
CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request));
1869+
if (!impl->point_coords_elem) CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem));
1870+
{
1871+
uint64_t state;
1872+
CeedCallBackend(CeedVectorGetState(point_coords, &state));
1873+
if (impl->points_state != state) {
1874+
CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request));
1875+
}
1876+
}
18651877
CeedCallBackend(CeedVectorDestroy(&point_coords));
18661878
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
18671879
}

backends/cuda-ref/ceed-cuda-ref.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ typedef struct {
133133

134134
typedef struct {
135135
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
136-
uint64_t *input_states; // State tracking for passive inputs
136+
uint64_t *input_states, points_state; // State tracking for passive inputs
137137
CeedVector *e_vecs_in, *e_vecs_out;
138138
CeedVector *q_vecs_in, *q_vecs_out;
139139
CeedInt num_inputs, num_outputs;

backends/hip-ref/ceed-hip-ref-operator.c

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -846,13 +846,19 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec,
846846
CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec));
847847

848848
// Get point coordinates
849-
if (!impl->point_coords_elem) {
849+
{
850850
CeedVector point_coords = NULL;
851851
CeedElemRestriction rstr_points = NULL;
852852

853853
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords));
854-
CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem));
855-
CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request));
854+
if (!impl->point_coords_elem) CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem));
855+
{
856+
uint64_t state;
857+
CeedCallBackend(CeedVectorGetState(point_coords, &state));
858+
if (impl->points_state != state) {
859+
CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request));
860+
}
861+
}
856862
CeedCallBackend(CeedVectorDestroy(&point_coords));
857863
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
858864
}
@@ -1852,13 +1858,19 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce
18521858
}
18531859

18541860
// Get point coordinates
1855-
if (!impl->point_coords_elem) {
1861+
{
18561862
CeedVector point_coords = NULL;
18571863
CeedElemRestriction rstr_points = NULL;
18581864

18591865
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords));
1860-
CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem));
1861-
CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request));
1866+
if (!impl->point_coords_elem) CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem));
1867+
{
1868+
uint64_t state;
1869+
CeedCallBackend(CeedVectorGetState(point_coords, &state));
1870+
if (impl->points_state != state) {
1871+
CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request));
1872+
}
1873+
}
18621874
CeedCallBackend(CeedVectorDestroy(&point_coords));
18631875
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
18641876
}

backends/hip-ref/ceed-hip-ref.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ typedef struct {
138138

139139
typedef struct {
140140
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
141-
uint64_t *input_states; // State tracking for passive inputs
141+
uint64_t *input_states, points_state; // State tracking for passive inputs
142142
CeedVector *e_vecs_in, *e_vecs_out;
143143
CeedVector *q_vecs_in, *q_vecs_out;
144144
CeedInt num_inputs, num_outputs;

0 commit comments

Comments
 (0)