Skip to content

Commit d0fa0f5

Browse files
committed
Change MI300A to use hipMalloc per LC tips
1 parent ff4acf1 commit d0fa0f5

File tree

2 files changed

+45
-90
lines changed

2 files changed

+45
-90
lines changed

backends/hip-ref/ceed-hip-ref-vector.c

Lines changed: 45 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
105105
CeedVector_Hip *impl;
106106

107107
CeedCallBackend(CeedVectorGetData(vec, &impl));
108-
CeedCheck(impl->h_array && !impl->d_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "Unified shared memory should only use host pointers");
108+
CeedCallHip(CeedVectorReturnCeed(vec), hipDeviceSynchronize());
109+
CeedCheck(impl->d_array && !impl->h_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND,
110+
"Unified shared memory should only use device pointers");
109111
return CEED_ERROR_SUCCESS;
110112
}
111113

@@ -155,8 +157,8 @@ static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType
155157
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
156158
CeedCallBackend(CeedVectorGetData(vec, &impl));
157159

158-
// Use host memory for unified memory
159-
mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
160+
// Use device memory for unified memory
161+
mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;
160162

161163
switch (mem_type) {
162164
case CEED_MEM_HOST:
@@ -179,8 +181,8 @@ static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, Cee
179181
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
180182
CeedCallBackend(CeedVectorGetData(vec, &impl));
181183

182-
// Use host memory for unified memory
183-
mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
184+
// Use device memory for unified memory
185+
mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;
184186

185187
switch (mem_type) {
186188
case CEED_MEM_HOST:
@@ -239,8 +241,8 @@ static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
239241
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
240242
CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
241243

242-
// Use host memory for unified memory
243-
local_mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
244+
// Use device memory for unified memory
245+
local_mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;
244246

245247
switch (local_mem_type) {
246248
case CEED_MEM_HOST:
@@ -267,7 +269,6 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
267269
CeedVector_Hip *impl;
268270
Ceed_Hip *hip_data;
269271
hipblasHandle_t handle;
270-
CeedScalar *d_array;
271272

272273
CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(vec), &handle));
273274
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
@@ -280,11 +281,8 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
280281
length = length_vec < length_copy ? length_vec : length_copy;
281282
}
282283

283-
// Use host memory for unified memory
284-
d_array = hip_data->has_unified_addressing ? impl->h_array : impl->d_array;
285-
286284
// Set value for synced device/host array
287-
if (d_array) {
285+
if (impl->d_array) {
288286
CeedScalar *copy_array;
289287

290288
// Number of values to copy
@@ -293,12 +291,11 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
293291
CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_DEVICE, &copy_array));
294292
#if defined(CEED_SCALAR_IS_FP32)
295293
CeedCallHipblas(CeedVectorReturnCeed(vec),
296-
hipblasScopy_64(handle, (int64_t)length, d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
294+
hipblasScopy_64(handle, (int64_t)length, impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
297295
#else
298296
CeedCallHipblas(CeedVectorReturnCeed(vec),
299-
hipblasDcopy_64(handle, (int64_t)length, d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
297+
hipblasDcopy_64(handle, (int64_t)length, impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
300298
#endif
301-
CeedCallHip(CeedVectorReturnCeed(vec), hipDeviceSynchronize());
302299
CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
303300
} else if (impl->h_array) {
304301
CeedScalar *copy_array;
@@ -331,7 +328,6 @@ int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val)
331328
static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
332329
CeedSize length;
333330
CeedVector_Hip *impl;
334-
CeedScalar *d_array;
335331
Ceed_Hip *hip_data;
336332

337333
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
@@ -352,15 +348,11 @@ static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
352348
}
353349
}
354350

355-
// Use host memory for unified memory
356-
d_array = hip_data->has_unified_addressing ? impl->h_array : impl->d_array;
357-
358-
if (d_array) {
359-
CeedCallBackend(CeedDeviceSetValue_Hip(d_array, length, val));
360-
CeedCallHip(CeedVectorReturnCeed(vec), hipDeviceSynchronize());
361-
if (!hip_data->has_unified_addressing) impl->h_array = NULL;
351+
if (impl->d_array) {
352+
CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
353+
impl->h_array = NULL;
362354
}
363-
if (impl->h_array && d_array != impl->h_array) {
355+
if (impl->h_array) {
364356
CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
365357
impl->d_array = NULL;
366358
}
@@ -387,20 +379,15 @@ static int CeedVectorSetValueStrided_Hip(CeedVector vec, CeedSize start, CeedSiz
387379
CeedSize length;
388380
CeedVector_Hip *impl;
389381
Ceed_Hip *hip_data;
390-
CeedScalar *d_array;
391382

392383
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
393384
CeedCallBackend(CeedVectorGetData(vec, &impl));
394385
CeedCallBackend(CeedVectorGetLength(vec, &length));
395386

396-
// Use host memory for unified memory
397-
d_array = hip_data->has_unified_addressing ? impl->h_array : impl->d_array;
398-
399387
// Set value for synced device/host array
400-
if (d_array) {
401-
CeedCallBackend(CeedDeviceSetValueStrided_Hip(d_array, start, step, length, val));
402-
CeedCallHip(CeedVectorReturnCeed(vec), hipDeviceSynchronize());
403-
if (!hip_data->has_unified_addressing) impl->h_array = NULL;
388+
if (impl->d_array) {
389+
CeedCallBackend(CeedDeviceSetValueStrided_Hip(impl->d_array, start, step, length, val));
390+
impl->h_array = NULL;
404391
} else if (impl->h_array) {
405392
CeedCallBackend(CeedHostSetValueStrided_Hip(impl->h_array, start, step, length, val));
406393
impl->d_array = NULL;
@@ -420,8 +407,8 @@ static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedSca
420407
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
421408
CeedCallBackend(CeedVectorGetData(vec, &impl));
422409

423-
// Use host memory for unified memory
424-
mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
410+
// Use device memory for unified memory
411+
mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;
425412

426413
// Sync array to requested mem_type
427414
CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
@@ -453,8 +440,8 @@ static int CeedVectorGetArrayCore_Hip(const CeedVector vec, CeedMemType mem_type
453440
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
454441
CeedCallBackend(CeedVectorGetData(vec, &impl));
455442

456-
// Use host memory for unified memory
457-
mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
443+
// Use device memory for unified memory
444+
mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;
458445

459446
// Sync array to requested mem_type
460447
CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
@@ -489,8 +476,8 @@ static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
489476
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
490477
CeedCallBackend(CeedVectorGetData(vec, &impl));
491478

492-
// Use host memory for unified memory
493-
local_mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
479+
// Use device memory for unified memory
480+
local_mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;
494481

495482
CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, local_mem_type, array));
496483
CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
@@ -517,8 +504,8 @@ static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType m
517504
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
518505
CeedCallBackend(CeedVectorGetData(vec, &impl));
519506

520-
// Use host memory for unified memory
521-
local_mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
507+
// Use device memory for unified memory
508+
local_mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;
522509

523510
CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, local_mem_type, &has_array_of_type));
524511
if (!has_array_of_type) {
@@ -557,8 +544,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
557544
CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
558545

559546
// Compute norm
560-
CeedMemType mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : CEED_MEM_DEVICE;
561-
CeedCallBackend(CeedVectorGetArrayRead(vec, mem_type, &d_array));
547+
CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
562548

563549
*norm = 0.0;
564550
switch (type) {
@@ -625,20 +611,14 @@ static int CeedVectorReciprocal_Hip(CeedVector vec) {
625611
CeedSize length;
626612
CeedVector_Hip *impl;
627613
Ceed_Hip *hip_data;
628-
CeedScalar *d_array;
629614

630615
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
631616
CeedCallBackend(CeedVectorGetData(vec, &impl));
632617
CeedCallBackend(CeedVectorGetLength(vec, &length));
633618

634-
d_array = hip_data->has_unified_addressing ? impl->h_array : impl->d_array;
635-
636619
// Set value for synced device/host array
637-
if (d_array) {
638-
CeedCallBackend(CeedDeviceReciprocal_Hip(d_array, length));
639-
CeedCallHip(CeedVectorReturnCeed(vec), hipDeviceSynchronize());
640-
}
641-
if (impl->h_array && d_array != impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
620+
if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
621+
if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
642622
return CEED_ERROR_SUCCESS;
643623
}
644624

@@ -658,25 +638,21 @@ static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
658638
CeedVector_Hip *impl;
659639
Ceed_Hip *hip_data;
660640
hipblasHandle_t handle;
661-
CeedScalar *d_array;
662641

663642
CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(x), &handle));
664643
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(x), &hip_data));
665644
CeedCallBackend(CeedVectorGetData(x, &impl));
666645
CeedCallBackend(CeedVectorGetLength(x, &length));
667646

668-
d_array = hip_data->has_unified_addressing ? impl->h_array : impl->d_array;
669-
670647
// Set value for synced device/host array
671-
if (d_array) {
648+
if (impl->d_array) {
672649
#if defined(CEED_SCALAR_IS_FP32)
673-
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasSscal_64(handle, (int64_t)length, &alpha, d_array, 1));
650+
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasSscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
674651
#else
675-
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasDscal_64(handle, (int64_t)length, &alpha, d_array, 1));
652+
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasDscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
676653
#endif
677-
CeedCallHip(CeedVectorReturnCeed(x), hipDeviceSynchronize());
678654
}
679-
if (impl->h_array && d_array != impl->h_array) CeedCallBackend(CeedHostScale_Hip(impl->h_array, alpha, length));
655+
if (impl->h_array) CeedCallBackend(CeedHostScale_Hip(impl->h_array, alpha, length));
680656
return CEED_ERROR_SUCCESS;
681657
}
682658

@@ -696,7 +672,6 @@ static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
696672
CeedVector_Hip *y_impl, *x_impl;
697673
Ceed_Hip *hip_data;
698674
hipblasHandle_t handle;
699-
CeedScalar *x_d_array, *y_d_array;
700675

701676
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(y), &hip_data));
702677
CeedCallBackend(CeedVectorGetData(y, &y_impl));
@@ -705,20 +680,16 @@ static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
705680

706681
CeedCallBackend(CeedVectorGetLength(y, &length));
707682

708-
y_d_array = hip_data->has_unified_addressing ? y_impl->h_array : y_impl->d_array;
709-
710683
// Set value for synced device/host array
711-
if (y_d_array) {
684+
if (y_impl->d_array) {
712685
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
713-
x_d_array = hip_data->has_unified_addressing ? x_impl->h_array : x_impl->d_array;
714686
#if defined(CEED_SCALAR_IS_FP32)
715-
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasSaxpy_64(handle, (int64_t)length, &alpha, x_d_array, 1, y_d_array, 1));
687+
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasSaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
716688
#else
717-
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasDaxpy_64(handle, (int64_t)length, &alpha, x_d_array, 1, y_d_array, 1));
689+
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasDaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
718690
#endif
719-
CeedCallHip(CeedVectorReturnCeed(y), hipDeviceSynchronize());
720691
}
721-
if (y_impl->h_array && y_d_array != y_impl->h_array) {
692+
if (y_impl->h_array) {
722693
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
723694
CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
724695
}
@@ -746,7 +717,6 @@ static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta,
746717
CeedVector_Hip *y_impl, *x_impl;
747718
Ceed_Hip *hip_data;
748719
hipblasHandle_t handle;
749-
CeedScalar *x_d_array, *y_d_array;
750720

751721
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(y), &hip_data));
752722
CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(y), &handle));
@@ -755,16 +725,12 @@ static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta,
755725

756726
CeedCallBackend(CeedVectorGetLength(y, &length));
757727

758-
y_d_array = hip_data->has_unified_addressing ? y_impl->h_array : y_impl->d_array;
759-
760728
// Set value for synced device/host array
761-
if (y_d_array) {
729+
if (y_impl->d_array) {
762730
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
763-
x_d_array = hip_data->has_unified_addressing ? x_impl->h_array : x_impl->d_array;
764-
CeedCallBackend(CeedDeviceAXPBY_Hip(y_d_array, alpha, beta, x_d_array, length));
765-
CeedCallHip(CeedVectorReturnCeed(y), hipDeviceSynchronize());
731+
CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
766732
}
767-
if (y_impl->h_array && y_d_array != y_impl->h_array) {
733+
if (y_impl->h_array) {
768734
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
769735
CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
770736
}
@@ -790,7 +756,6 @@ int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedSc
790756
static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
791757
CeedSize length;
792758
CeedVector_Hip *w_impl, *x_impl, *y_impl;
793-
CeedScalar *w_d_array, *x_d_array, *y_d_array;
794759
Ceed_Hip *hip_data;
795760

796761
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(x), &hip_data));
@@ -804,17 +769,12 @@ static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y)
804769
CeedCallBackend(CeedVectorSetValue(w, 0.0));
805770
}
806771

807-
w_d_array = hip_data->has_unified_addressing ? w_impl->h_array : w_impl->d_array;
808-
809-
if (w_d_array) {
772+
if (w_impl->d_array) {
810773
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
811774
CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
812-
x_d_array = hip_data->has_unified_addressing ? x_impl->h_array : x_impl->d_array;
813-
y_d_array = hip_data->has_unified_addressing ? y_impl->h_array : y_impl->d_array;
814-
CeedCallBackend(CeedDevicePointwiseMult_Hip(w_d_array, x_d_array, y_d_array, length));
815-
CeedCallHip(CeedVectorReturnCeed(y), hipDeviceSynchronize());
775+
CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
816776
}
817-
if (w_impl->h_array && w_d_array != w_impl->h_array) {
777+
if (w_impl->h_array) {
818778
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
819779
CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
820780
CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
@@ -832,9 +792,7 @@ static int CeedVectorDestroy_Hip(const CeedVector vec) {
832792
CeedCallBackend(CeedVectorGetData(vec, &impl));
833793
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
834794

835-
if (!hip_data->has_unified_addressing) {
836-
CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
837-
}
795+
CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
838796
CeedCallBackend(CeedFree(&impl->h_array_owned));
839797
CeedCallBackend(CeedFree(&impl));
840798
return CEED_ERROR_SUCCESS;

backends/hip/ceed-hip-compile.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ int CeedRunKernel_Hip(Ceed ceed, hipFunction_t kernel, const int grid_size, cons
171171

172172
CeedCallBackend(CeedGetData(ceed, &data));
173173
CeedCallHip(ceed, hipModuleLaunchKernel(kernel, grid_size, 1, 1, block_size, 1, 1, 0, NULL, args, NULL));
174-
if (data->has_unified_addressing) CeedCallHip(ceed, hipDeviceSynchronize());
175174
return CEED_ERROR_SUCCESS;
176175
}
177176

@@ -184,7 +183,6 @@ int CeedRunKernelDim_Hip(Ceed ceed, hipFunction_t kernel, const int grid_size, c
184183

185184
CeedCallBackend(CeedGetData(ceed, &data));
186185
CeedCallHip(ceed, hipModuleLaunchKernel(kernel, grid_size, 1, 1, block_size_x, block_size_y, block_size_z, 0, NULL, args, NULL));
187-
if (data->has_unified_addressing) CeedCallHip(ceed, hipDeviceSynchronize());
188186
return CEED_ERROR_SUCCESS;
189187
}
190188

@@ -197,7 +195,6 @@ int CeedRunKernelDimShared_Hip(Ceed ceed, hipFunction_t kernel, const int grid_s
197195

198196
CeedCallBackend(CeedGetData(ceed, &data));
199197
CeedCallHip(ceed, hipModuleLaunchKernel(kernel, grid_size, 1, 1, block_size_x, block_size_y, block_size_z, shared_mem_size, NULL, args, NULL));
200-
if (data->has_unified_addressing) CeedCallHip(ceed, hipDeviceSynchronize());
201198
return CEED_ERROR_SUCCESS;
202199
}
203200

0 commit comments

Comments
 (0)