Skip to content

Commit e704478

Browse files
authored
Merge pull request #1678 from CEED/jeremy/vec-fix
vec - fix poinwisemult length check
2 parents bdd4742 + 54404f0 commit e704478

File tree

3 files changed

+32
-6
lines changed

3 files changed

+32
-6
lines changed

backends/cuda-ref/ceed-cuda-ref-operator.c

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1721,6 +1721,19 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
17211721
// Work vector
17221722
CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_in));
17231723
CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_out));
1724+
{
1725+
CeedSize length_in, length_out;
1726+
1727+
CeedCallBackend(CeedVectorGetLength(active_e_vec_in, &length_in));
1728+
CeedCallBackend(CeedVectorGetLength(active_e_vec_out, &length_out));
1729+
// Need input e_vec to be longer
1730+
if (length_in < length_out) {
1731+
CeedVector temp = active_e_vec_in;
1732+
1733+
active_e_vec_in = active_e_vec_out;
1734+
active_e_vec_out = temp;
1735+
}
1736+
}
17241737

17251738
// Get point coordinates
17261739
if (!impl->point_coords_elem) {
@@ -1804,7 +1817,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
18041817
const CeedScalar *e_vec_array;
18051818

18061819
CeedCallBackend(CeedVectorGetArrayRead(active_e_vec_in, CEED_MEM_DEVICE, &e_vec_array));
1807-
CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array));
1820+
CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array));
18081821
break;
18091822
}
18101823
case CEED_EVAL_INTERP:

backends/hip-ref/ceed-hip-ref-operator.c

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1718,6 +1718,19 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce
17181718
// Work vector
17191719
CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_in));
17201720
CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_out));
1721+
{
1722+
CeedSize length_in, length_out;
1723+
1724+
CeedCallBackend(CeedVectorGetLength(active_e_vec_in, &length_in));
1725+
CeedCallBackend(CeedVectorGetLength(active_e_vec_out, &length_out));
1726+
// Need input e_vec to be longer
1727+
if (length_in < length_out) {
1728+
CeedVector temp = active_e_vec_in;
1729+
1730+
active_e_vec_in = active_e_vec_out;
1731+
active_e_vec_out = temp;
1732+
}
1733+
}
17211734

17221735
// Get point coordinates
17231736
if (!impl->point_coords_elem) {
@@ -1801,7 +1814,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce
18011814
const CeedScalar *e_vec_array;
18021815

18031816
CeedCallBackend(CeedVectorGetArrayRead(active_e_vec_in, CEED_MEM_DEVICE, &e_vec_array));
1804-
CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array));
1817+
CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array));
18051818
break;
18061819
}
18071820
case CEED_EVAL_INTERP:

interface/ceed-vector.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -862,10 +862,10 @@ int CeedVectorPointwiseMult(CeedVector w, CeedVector x, CeedVector y) {
862862
CeedCall(CeedVectorGetLength(w, &length_w));
863863
CeedCall(CeedVectorGetLength(x, &length_x));
864864
CeedCall(CeedVectorGetLength(y, &length_y));
865-
CeedCheck(length_x >= length_x && length_y >= length_w, ceed, CEED_ERROR_UNSUPPORTED,
866-
"Cannot multiply vectors of different lengths."
867-
" x length: %" CeedSize_FMT " y length: %" CeedSize_FMT,
868-
length_x, length_y);
865+
CeedCheck(length_x >= length_w && length_y >= length_w, ceed, CEED_ERROR_UNSUPPORTED,
866+
"Cannot pointwise multiply vectors of incompatible lengths."
867+
" w length: %" CeedSize_FMT " x length: %" CeedSize_FMT " y length: %" CeedSize_FMT,
868+
length_w, length_x, length_y);
869869

870870
CeedCall(CeedGetParent(w->ceed, &ceed_parent_w));
871871
CeedCall(CeedGetParent(x->ceed, &ceed_parent_x));

0 commit comments

Comments
 (0)