@@ -105,7 +105,9 @@ static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
105105 CeedVector_Hip * impl ;
106106
107107 CeedCallBackend (CeedVectorGetData (vec , & impl ));
108- CeedCheck (impl -> h_array && !impl -> d_array , CeedVectorReturnCeed (vec ), CEED_ERROR_BACKEND , "Unified shared memory should only use host pointers" );
108+ CeedCallHip (CeedVectorReturnCeed (vec ), hipDeviceSynchronize ());
109+ CeedCheck (impl -> d_array && !impl -> h_array , CeedVectorReturnCeed (vec ), CEED_ERROR_BACKEND ,
110+ "Unified shared memory should only use device pointers" );
109111 return CEED_ERROR_SUCCESS ;
110112 }
111113
@@ -155,8 +157,8 @@ static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType
155157 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
156158 CeedCallBackend (CeedVectorGetData (vec , & impl ));
157159
158- // Use host memory for unified memory
159- mem_type = hip_data -> has_unified_addressing ? CEED_MEM_HOST : mem_type ;
160+ // Use device memory for unified memory
161+ mem_type = hip_data -> has_unified_addressing ? CEED_MEM_DEVICE : mem_type ;
160162
161163 switch (mem_type ) {
162164 case CEED_MEM_HOST :
@@ -179,8 +181,8 @@ static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, Cee
179181 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
180182 CeedCallBackend (CeedVectorGetData (vec , & impl ));
181183
182- // Use host memory for unified memory
183- mem_type = hip_data -> has_unified_addressing ? CEED_MEM_HOST : mem_type ;
184+ // Use device memory for unified memory
185+ mem_type = hip_data -> has_unified_addressing ? CEED_MEM_DEVICE : mem_type ;
184186
185187 switch (mem_type ) {
186188 case CEED_MEM_HOST :
@@ -239,8 +241,8 @@ static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
239241 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
240242 CeedCallBackend (CeedVectorSetAllInvalid_Hip (vec ));
241243
242- // Use host memory for unified memory
243- local_mem_type = hip_data -> has_unified_addressing ? CEED_MEM_HOST : mem_type ;
244+ // Use device memory for unified memory
245+ local_mem_type = hip_data -> has_unified_addressing ? CEED_MEM_DEVICE : mem_type ;
244246
245247 switch (local_mem_type ) {
246248 case CEED_MEM_HOST :
@@ -267,7 +269,6 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
267269 CeedVector_Hip * impl ;
268270 Ceed_Hip * hip_data ;
269271 hipblasHandle_t handle ;
270- CeedScalar * d_array ;
271272
272273 CeedCallBackend (CeedGetHipblasHandle_Hip (CeedVectorReturnCeed (vec ), & handle ));
273274 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
@@ -280,11 +281,8 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
280281 length = length_vec < length_copy ? length_vec : length_copy ;
281282 }
282283
283- // Use host memory for unified memory
284- d_array = hip_data -> has_unified_addressing ? impl -> h_array : impl -> d_array ;
285-
286284 // Set value for synced device/host array
287- if (d_array ) {
285+ if (impl -> d_array ) {
288286 CeedScalar * copy_array ;
289287
290288 // Number of values to copy
@@ -293,12 +291,11 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
293291 CeedCallBackend (CeedVectorGetArray (vec_copy , CEED_MEM_DEVICE , & copy_array ));
294292#if defined(CEED_SCALAR_IS_FP32 )
295293 CeedCallHipblas (CeedVectorReturnCeed (vec ),
296- hipblasScopy_64 (handle , (int64_t )length , d_array + start , (int64_t )step , copy_array + start , (int64_t )step ));
294+ hipblasScopy_64 (handle , (int64_t )length , impl -> d_array + start , (int64_t )step , copy_array + start , (int64_t )step ));
297295#else
298296 CeedCallHipblas (CeedVectorReturnCeed (vec ),
299- hipblasDcopy_64 (handle , (int64_t )length , d_array + start , (int64_t )step , copy_array + start , (int64_t )step ));
297+ hipblasDcopy_64 (handle , (int64_t )length , impl -> d_array + start , (int64_t )step , copy_array + start , (int64_t )step ));
300298#endif
301- CeedCallHip (CeedVectorReturnCeed (vec ), hipDeviceSynchronize ());
302299 CeedCallBackend (CeedVectorRestoreArray (vec_copy , & copy_array ));
303300 } else if (impl -> h_array ) {
304301 CeedScalar * copy_array ;
@@ -331,7 +328,6 @@ int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val)
331328static int CeedVectorSetValue_Hip (CeedVector vec , CeedScalar val ) {
332329 CeedSize length ;
333330 CeedVector_Hip * impl ;
334- CeedScalar * d_array ;
335331 Ceed_Hip * hip_data ;
336332
337333 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
@@ -352,15 +348,11 @@ static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
352348 }
353349 }
354350
355- // Use host memory for unified memory
356- d_array = hip_data -> has_unified_addressing ? impl -> h_array : impl -> d_array ;
357-
358- if (d_array ) {
359- CeedCallBackend (CeedDeviceSetValue_Hip (d_array , length , val ));
360- CeedCallHip (CeedVectorReturnCeed (vec ), hipDeviceSynchronize ());
361- if (!hip_data -> has_unified_addressing ) impl -> h_array = NULL ;
351+ if (impl -> d_array ) {
352+ CeedCallBackend (CeedDeviceSetValue_Hip (impl -> d_array , length , val ));
353+ impl -> h_array = NULL ;
362354 }
363- if (impl -> h_array && d_array != impl -> h_array ) {
355+ if (impl -> h_array ) {
364356 CeedCallBackend (CeedHostSetValue_Hip (impl -> h_array , length , val ));
365357 impl -> d_array = NULL ;
366358 }
@@ -387,20 +379,15 @@ static int CeedVectorSetValueStrided_Hip(CeedVector vec, CeedSize start, CeedSiz
387379 CeedSize length ;
388380 CeedVector_Hip * impl ;
389381 Ceed_Hip * hip_data ;
390- CeedScalar * d_array ;
391382
392383 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
393384 CeedCallBackend (CeedVectorGetData (vec , & impl ));
394385 CeedCallBackend (CeedVectorGetLength (vec , & length ));
395386
396- // Use host memory for unified memory
397- d_array = hip_data -> has_unified_addressing ? impl -> h_array : impl -> d_array ;
398-
399387 // Set value for synced device/host array
400- if (d_array ) {
401- CeedCallBackend (CeedDeviceSetValueStrided_Hip (d_array , start , step , length , val ));
402- CeedCallHip (CeedVectorReturnCeed (vec ), hipDeviceSynchronize ());
403- if (!hip_data -> has_unified_addressing ) impl -> h_array = NULL ;
388+ if (impl -> d_array ) {
389+ CeedCallBackend (CeedDeviceSetValueStrided_Hip (impl -> d_array , start , step , length , val ));
390+ impl -> h_array = NULL ;
404391 } else if (impl -> h_array ) {
405392 CeedCallBackend (CeedHostSetValueStrided_Hip (impl -> h_array , start , step , length , val ));
406393 impl -> d_array = NULL ;
@@ -420,8 +407,8 @@ static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedSca
420407 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
421408 CeedCallBackend (CeedVectorGetData (vec , & impl ));
422409
423- // Use host memory for unified memory
424- mem_type = hip_data -> has_unified_addressing ? CEED_MEM_HOST : mem_type ;
410+ // Use device memory for unified memory
411+ mem_type = hip_data -> has_unified_addressing ? CEED_MEM_DEVICE : mem_type ;
425412
426413 // Sync array to requested mem_type
427414 CeedCallBackend (CeedVectorSyncArray (vec , mem_type ));
@@ -453,8 +440,8 @@ static int CeedVectorGetArrayCore_Hip(const CeedVector vec, CeedMemType mem_type
453440 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
454441 CeedCallBackend (CeedVectorGetData (vec , & impl ));
455442
456- // Use host memory for unified memory
457- mem_type = hip_data -> has_unified_addressing ? CEED_MEM_HOST : mem_type ;
443+ // Use device memory for unified memory
444+ mem_type = hip_data -> has_unified_addressing ? CEED_MEM_DEVICE : mem_type ;
458445
459446 // Sync array to requested mem_type
460447 CeedCallBackend (CeedVectorSyncArray (vec , mem_type ));
@@ -489,8 +476,8 @@ static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
489476 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
490477 CeedCallBackend (CeedVectorGetData (vec , & impl ));
491478
492- // Use host memory for unified memory
493- local_mem_type = hip_data -> has_unified_addressing ? CEED_MEM_HOST : mem_type ;
479+ // Use device memory for unified memory
480+ local_mem_type = hip_data -> has_unified_addressing ? CEED_MEM_DEVICE : mem_type ;
494481
495482 CeedCallBackend (CeedVectorGetArrayCore_Hip (vec , local_mem_type , array ));
496483 CeedCallBackend (CeedVectorSetAllInvalid_Hip (vec ));
@@ -517,8 +504,8 @@ static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType m
517504 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
518505 CeedCallBackend (CeedVectorGetData (vec , & impl ));
519506
520- // Use host memory for unified memory
521- local_mem_type = hip_data -> has_unified_addressing ? CEED_MEM_HOST : mem_type ;
507+ // Use device memory for unified memory
508+ local_mem_type = hip_data -> has_unified_addressing ? CEED_MEM_DEVICE : mem_type ;
522509
523510 CeedCallBackend (CeedVectorHasArrayOfType_Hip (vec , local_mem_type , & has_array_of_type ));
524511 if (!has_array_of_type ) {
@@ -557,8 +544,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
557544 CeedCallBackend (CeedGetHipblasHandle_Hip (ceed , & handle ));
558545
559546 // Compute norm
560- CeedMemType mem_type = hip_data -> has_unified_addressing ? CEED_MEM_HOST : CEED_MEM_DEVICE ;
561- CeedCallBackend (CeedVectorGetArrayRead (vec , mem_type , & d_array ));
547+ CeedCallBackend (CeedVectorGetArrayRead (vec , CEED_MEM_DEVICE , & d_array ));
562548
563549 * norm = 0.0 ;
564550 switch (type ) {
@@ -625,20 +611,14 @@ static int CeedVectorReciprocal_Hip(CeedVector vec) {
625611 CeedSize length ;
626612 CeedVector_Hip * impl ;
627613 Ceed_Hip * hip_data ;
628- CeedScalar * d_array ;
629614
630615 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
631616 CeedCallBackend (CeedVectorGetData (vec , & impl ));
632617 CeedCallBackend (CeedVectorGetLength (vec , & length ));
633618
634- d_array = hip_data -> has_unified_addressing ? impl -> h_array : impl -> d_array ;
635-
636619 // Set value for synced device/host array
637- if (d_array ) {
638- CeedCallBackend (CeedDeviceReciprocal_Hip (d_array , length ));
639- CeedCallHip (CeedVectorReturnCeed (vec ), hipDeviceSynchronize ());
640- }
641- if (impl -> h_array && d_array != impl -> h_array ) CeedCallBackend (CeedHostReciprocal_Hip (impl -> h_array , length ));
620+ if (impl -> d_array ) CeedCallBackend (CeedDeviceReciprocal_Hip (impl -> d_array , length ));
621+ if (impl -> h_array ) CeedCallBackend (CeedHostReciprocal_Hip (impl -> h_array , length ));
642622 return CEED_ERROR_SUCCESS ;
643623}
644624
@@ -658,25 +638,21 @@ static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
658638 CeedVector_Hip * impl ;
659639 Ceed_Hip * hip_data ;
660640 hipblasHandle_t handle ;
661- CeedScalar * d_array ;
662641
663642 CeedCallBackend (CeedGetHipblasHandle_Hip (CeedVectorReturnCeed (x ), & handle ));
664643 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (x ), & hip_data ));
665644 CeedCallBackend (CeedVectorGetData (x , & impl ));
666645 CeedCallBackend (CeedVectorGetLength (x , & length ));
667646
668- d_array = hip_data -> has_unified_addressing ? impl -> h_array : impl -> d_array ;
669-
670647 // Set value for synced device/host array
671- if (d_array ) {
648+ if (impl -> d_array ) {
672649#if defined(CEED_SCALAR_IS_FP32 )
673- CeedCallHipblas (CeedVectorReturnCeed (x ), hipblasSscal_64 (handle , (int64_t )length , & alpha , d_array , 1 ));
650+ CeedCallHipblas (CeedVectorReturnCeed (x ), hipblasSscal_64 (handle , (int64_t )length , & alpha , impl -> d_array , 1 ));
674651#else
675- CeedCallHipblas (CeedVectorReturnCeed (x ), hipblasDscal_64 (handle , (int64_t )length , & alpha , d_array , 1 ));
652+ CeedCallHipblas (CeedVectorReturnCeed (x ), hipblasDscal_64 (handle , (int64_t )length , & alpha , impl -> d_array , 1 ));
676653#endif
677- CeedCallHip (CeedVectorReturnCeed (x ), hipDeviceSynchronize ());
678654 }
679- if (impl -> h_array && d_array != impl -> h_array ) CeedCallBackend (CeedHostScale_Hip (impl -> h_array , alpha , length ));
655+ if (impl -> h_array ) CeedCallBackend (CeedHostScale_Hip (impl -> h_array , alpha , length ));
680656 return CEED_ERROR_SUCCESS ;
681657}
682658
@@ -696,7 +672,6 @@ static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
696672 CeedVector_Hip * y_impl , * x_impl ;
697673 Ceed_Hip * hip_data ;
698674 hipblasHandle_t handle ;
699- CeedScalar * x_d_array , * y_d_array ;
700675
701676 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (y ), & hip_data ));
702677 CeedCallBackend (CeedVectorGetData (y , & y_impl ));
@@ -705,20 +680,16 @@ static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
705680
706681 CeedCallBackend (CeedVectorGetLength (y , & length ));
707682
708- y_d_array = hip_data -> has_unified_addressing ? y_impl -> h_array : y_impl -> d_array ;
709-
710683 // Set value for synced device/host array
711- if (y_d_array ) {
684+ if (y_impl -> d_array ) {
712685 CeedCallBackend (CeedVectorSyncArray (x , CEED_MEM_DEVICE ));
713- x_d_array = hip_data -> has_unified_addressing ? x_impl -> h_array : x_impl -> d_array ;
714686#if defined(CEED_SCALAR_IS_FP32 )
715- CeedCallHipblas (CeedVectorReturnCeed (y ), hipblasSaxpy_64 (handle , (int64_t )length , & alpha , x_d_array , 1 , y_d_array , 1 ));
687+ CeedCallHipblas (CeedVectorReturnCeed (y ), hipblasSaxpy_64 (handle , (int64_t )length , & alpha , x_impl -> d_array , 1 , y_impl -> d_array , 1 ));
716688#else
717- CeedCallHipblas (CeedVectorReturnCeed (y ), hipblasDaxpy_64 (handle , (int64_t )length , & alpha , x_d_array , 1 , y_d_array , 1 ));
689+ CeedCallHipblas (CeedVectorReturnCeed (y ), hipblasDaxpy_64 (handle , (int64_t )length , & alpha , x_impl -> d_array , 1 , y_impl -> d_array , 1 ));
718690#endif
719- CeedCallHip (CeedVectorReturnCeed (y ), hipDeviceSynchronize ());
720691 }
721- if (y_impl -> h_array && y_d_array != y_impl -> h_array ) {
692+ if (y_impl -> h_array ) {
722693 CeedCallBackend (CeedVectorSyncArray (x , CEED_MEM_HOST ));
723694 CeedCallBackend (CeedHostAXPY_Hip (y_impl -> h_array , alpha , x_impl -> h_array , length ));
724695 }
@@ -746,7 +717,6 @@ static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta,
746717 CeedVector_Hip * y_impl , * x_impl ;
747718 Ceed_Hip * hip_data ;
748719 hipblasHandle_t handle ;
749- CeedScalar * x_d_array , * y_d_array ;
750720
751721 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (y ), & hip_data ));
752722 CeedCallBackend (CeedGetHipblasHandle_Hip (CeedVectorReturnCeed (y ), & handle ));
@@ -755,16 +725,12 @@ static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta,
755725
756726 CeedCallBackend (CeedVectorGetLength (y , & length ));
757727
758- y_d_array = hip_data -> has_unified_addressing ? y_impl -> h_array : y_impl -> d_array ;
759-
760728 // Set value for synced device/host array
761- if (y_d_array ) {
729+ if (y_impl -> d_array ) {
762730 CeedCallBackend (CeedVectorSyncArray (x , CEED_MEM_DEVICE ));
763- x_d_array = hip_data -> has_unified_addressing ? x_impl -> h_array : x_impl -> d_array ;
764- CeedCallBackend (CeedDeviceAXPBY_Hip (y_d_array , alpha , beta , x_d_array , length ));
765- CeedCallHip (CeedVectorReturnCeed (y ), hipDeviceSynchronize ());
731+ CeedCallBackend (CeedDeviceAXPBY_Hip (y_impl -> d_array , alpha , beta , x_impl -> d_array , length ));
766732 }
767- if (y_impl -> h_array && y_d_array != y_impl -> h_array ) {
733+ if (y_impl -> h_array ) {
768734 CeedCallBackend (CeedVectorSyncArray (x , CEED_MEM_HOST ));
769735 CeedCallBackend (CeedHostAXPBY_Hip (y_impl -> h_array , alpha , beta , x_impl -> h_array , length ));
770736 }
@@ -790,7 +756,6 @@ int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedSc
790756static int CeedVectorPointwiseMult_Hip (CeedVector w , CeedVector x , CeedVector y ) {
791757 CeedSize length ;
792758 CeedVector_Hip * w_impl , * x_impl , * y_impl ;
793- CeedScalar * w_d_array , * x_d_array , * y_d_array ;
794759 Ceed_Hip * hip_data ;
795760
796761 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (x ), & hip_data ));
@@ -804,17 +769,12 @@ static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y)
804769 CeedCallBackend (CeedVectorSetValue (w , 0.0 ));
805770 }
806771
807- w_d_array = hip_data -> has_unified_addressing ? w_impl -> h_array : w_impl -> d_array ;
808-
809- if (w_d_array ) {
772+ if (w_impl -> d_array ) {
810773 CeedCallBackend (CeedVectorSyncArray (x , CEED_MEM_DEVICE ));
811774 CeedCallBackend (CeedVectorSyncArray (y , CEED_MEM_DEVICE ));
812- x_d_array = hip_data -> has_unified_addressing ? x_impl -> h_array : x_impl -> d_array ;
813- y_d_array = hip_data -> has_unified_addressing ? y_impl -> h_array : y_impl -> d_array ;
814- CeedCallBackend (CeedDevicePointwiseMult_Hip (w_d_array , x_d_array , y_d_array , length ));
815- CeedCallHip (CeedVectorReturnCeed (y ), hipDeviceSynchronize ());
775+ CeedCallBackend (CeedDevicePointwiseMult_Hip (w_impl -> d_array , x_impl -> d_array , y_impl -> d_array , length ));
816776 }
817- if (w_impl -> h_array && w_d_array != w_impl -> h_array ) {
777+ if (w_impl -> h_array ) {
818778 CeedCallBackend (CeedVectorSyncArray (x , CEED_MEM_HOST ));
819779 CeedCallBackend (CeedVectorSyncArray (y , CEED_MEM_HOST ));
820780 CeedCallBackend (CeedHostPointwiseMult_Hip (w_impl -> h_array , x_impl -> h_array , y_impl -> h_array , length ));
@@ -832,9 +792,7 @@ static int CeedVectorDestroy_Hip(const CeedVector vec) {
832792 CeedCallBackend (CeedVectorGetData (vec , & impl ));
833793 CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
834794
835- if (!hip_data -> has_unified_addressing ) {
836- CeedCallHip (CeedVectorReturnCeed (vec ), hipFree (impl -> d_array_owned ));
837- }
795+ CeedCallHip (CeedVectorReturnCeed (vec ), hipFree (impl -> d_array_owned ));
838796 CeedCallBackend (CeedFree (& impl -> h_array_owned ));
839797 CeedCallBackend (CeedFree (& impl ));
840798 return CEED_ERROR_SUCCESS ;
0 commit comments