Skip to content

Commit bdb8f9d

Browse files
committed
memcheck - emulate device vector methods too
1 parent 03f9705 commit bdb8f9d

File tree

1 file changed

+118
-0
lines changed

1 file changed

+118
-0
lines changed

backends/memcheck/ceed-memcheck-vector.c

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <ceed.h>
99
#include <ceed/backend.h>
10+
#include <assert.h>
1011
#include <math.h>
1112
#include <stdbool.h>
1213
#include <string.h>
@@ -94,6 +95,38 @@ static int CeedVectorSetArray_Memcheck(CeedVector vec, CeedMemType mem_type, Cee
9495
return CEED_ERROR_SUCCESS;
9596
}
9697

98+
//------------------------------------------------------------------------------
99+
// Set internal array to value
100+
//------------------------------------------------------------------------------
101+
static int CeedVectorSetValue_Memcheck(CeedVector vec, CeedScalar value) {
102+
CeedSize length;
103+
CeedVector_Memcheck *impl;
104+
105+
CeedCallBackend(CeedVectorGetData(vec, &impl));
106+
CeedCallBackend(CeedVectorGetLength(vec, &length));
107+
108+
if (!impl->array_allocated) CeedCallBackend(CeedVectorSetArray_Memcheck(vec, CEED_MEM_HOST, CEED_COPY_VALUES, NULL));
109+
assert(impl->array_allocated);
110+
for (CeedSize i = 0; i < length; i++) impl->array_allocated[i] = value;
111+
return CEED_ERROR_SUCCESS;
112+
}
113+
114+
//------------------------------------------------------------------------------
115+
// Set internal array to value strided
116+
//------------------------------------------------------------------------------
117+
static int CeedVectorSetValueStrided_Memcheck(CeedVector vec, CeedSize start, CeedSize step, CeedScalar val) {
118+
CeedSize length;
119+
CeedVector_Memcheck *impl;
120+
121+
CeedCallBackend(CeedVectorGetData(vec, &impl));
122+
CeedCallBackend(CeedVectorGetLength(vec, &length));
123+
124+
if (!impl->array_allocated) CeedCallBackend(CeedVectorSetArray_Memcheck(vec, CEED_MEM_HOST, CEED_COPY_VALUES, NULL));
125+
assert(impl->array_allocated);
126+
for (CeedSize i = start; i < length; i += step) impl->array_allocated[i] = val;
127+
return CEED_ERROR_SUCCESS;
128+
}
129+
97130
//------------------------------------------------------------------------------
98131
// Sync arrays
99132
//------------------------------------------------------------------------------
@@ -267,6 +300,84 @@ static int CeedVectorRestoreArrayRead_Memcheck(CeedVector vec) {
267300
return CEED_ERROR_SUCCESS;
268301
}
269302

303+
//------------------------------------------------------------------------------
304+
// Take reciprocal of a vector
305+
//------------------------------------------------------------------------------
306+
static int CeedVectorReciprocal_Memcheck(CeedVector vec) {
307+
CeedSize length;
308+
CeedVector_Memcheck *impl;
309+
310+
CeedCallBackend(CeedVectorGetData(vec, &impl));
311+
CeedCallBackend(CeedVectorGetLength(vec, &length));
312+
313+
for (CeedSize i = 0; i < length; i++) {
314+
if (fabs(impl->array_allocated[i]) > CEED_EPSILON) impl->array_allocated[i] = 1. / impl->array_allocated[i];
315+
}
316+
return CEED_ERROR_SUCCESS;
317+
}
318+
319+
//------------------------------------------------------------------------------
320+
// Compute x = alpha x
321+
//------------------------------------------------------------------------------
322+
static int CeedVectorScale_Memcheck(CeedVector x, CeedScalar alpha) {
323+
CeedSize length;
324+
CeedVector_Memcheck *impl;
325+
326+
CeedCallBackend(CeedVectorGetData(x, &impl));
327+
CeedCallBackend(CeedVectorGetLength(x, &length));
328+
329+
for (CeedSize i = 0; i < length; i++) impl->array_allocated[i] *= alpha;
330+
return CEED_ERROR_SUCCESS;
331+
}
332+
333+
//------------------------------------------------------------------------------
334+
// Compute y = alpha x + y
335+
//------------------------------------------------------------------------------
336+
static int CeedVectorAXPY_Memcheck(CeedVector y, CeedScalar alpha, CeedVector x) {
337+
CeedSize length;
338+
CeedVector_Memcheck *impl_x, *impl_y;
339+
340+
CeedCallBackend(CeedVectorGetData(x, &impl_x));
341+
CeedCallBackend(CeedVectorGetData(y, &impl_y));
342+
CeedCallBackend(CeedVectorGetLength(y, &length));
343+
344+
for (CeedSize i = 0; i < length; i++) impl_y->array_allocated[i] += alpha * impl_x->array_allocated[i];
345+
return CEED_ERROR_SUCCESS;
346+
}
347+
348+
//------------------------------------------------------------------------------
349+
// Compute y = alpha x + beta y
350+
//------------------------------------------------------------------------------
351+
static int CeedVectorAXPBY_Memcheck(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
352+
CeedSize length;
353+
CeedVector_Memcheck *impl_x, *impl_y;
354+
355+
CeedCallBackend(CeedVectorGetData(x, &impl_x));
356+
CeedCallBackend(CeedVectorGetData(y, &impl_y));
357+
CeedCallBackend(CeedVectorGetLength(y, &length));
358+
359+
for (CeedSize i = 0; i < length; i++) impl_y->array_allocated[i] = alpha * impl_x->array_allocated[i] + beta * impl_y->array_allocated[i];
360+
return CEED_ERROR_SUCCESS;
361+
}
362+
363+
//------------------------------------------------------------------------------
364+
// Compute the pointwise multiplication w = x .* y
365+
//------------------------------------------------------------------------------
366+
static int CeedVectorPointwiseMult_Memcheck(CeedVector w, CeedVector x, CeedVector y) {
367+
CeedSize length;
368+
CeedVector_Memcheck *impl_x, *impl_y, *impl_w;
369+
370+
CeedCallBackend(CeedVectorGetData(x, &impl_x));
371+
CeedCallBackend(CeedVectorGetData(y, &impl_y));
372+
CeedCallBackend(CeedVectorGetData(w, &impl_w));
373+
CeedCallBackend(CeedVectorGetLength(w, &length));
374+
375+
if (!impl_w->array_allocated) CeedCallBackend(CeedVectorSetArray_Memcheck(w, CEED_MEM_HOST, CEED_COPY_VALUES, NULL));
376+
assert(impl_w->array_allocated);
377+
for (CeedSize i = 0; i < length; i++) impl_w->array_allocated[i] = impl_x->array_allocated[i] * impl_y->array_allocated[i];
378+
return CEED_ERROR_SUCCESS;
379+
}
380+
270381
//------------------------------------------------------------------------------
271382
// Vector Destroy
272383
//------------------------------------------------------------------------------
@@ -304,13 +415,20 @@ int CeedVectorCreate_Memcheck(CeedSize n, CeedVector vec) {
304415
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Memcheck));
305416
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Memcheck));
306417
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Memcheck));
418+
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Memcheck));
419+
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValueStrided", CeedVectorSetValueStrided_Memcheck));
307420
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Memcheck));
308421
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Memcheck));
309422
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Memcheck));
310423
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Memcheck));
311424
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Memcheck));
312425
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "RestoreArray", CeedVectorRestoreArray_Memcheck));
313426
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "RestoreArrayRead", CeedVectorRestoreArrayRead_Memcheck));
427+
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Memcheck));
428+
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", CeedVectorScale_Memcheck));
429+
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Memcheck));
430+
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", CeedVectorAXPBY_Memcheck));
431+
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Memcheck));
314432
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Memcheck));
315433
return CEED_ERROR_SUCCESS;
316434
}

0 commit comments

Comments
 (0)