@@ -223,20 +223,20 @@ static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
223223//------------------------------------------------------------------------------
224224// Copy host array to value strided
225225//------------------------------------------------------------------------------
226- static int CeedHostCopyStrided_Hip (CeedScalar * h_array , CeedSize start , CeedSize step , CeedSize length , CeedScalar * h_copy_array ) {
227- for (CeedSize i = start ; i < length ; i += step ) h_copy_array [i ] = h_array [i ];
226+ static int CeedHostCopyStrided_Hip (CeedScalar * h_array , CeedSize start , CeedSize stop , CeedSize step , CeedScalar * h_copy_array ) {
227+ for (CeedSize i = start ; i < stop ; i += step ) h_copy_array [i ] = h_array [i ];
228228 return CEED_ERROR_SUCCESS ;
229229}
230230
231231//------------------------------------------------------------------------------
232232// Copy device array to value strided (impl in .hip.cpp file)
233233//------------------------------------------------------------------------------
234- int CeedDeviceCopyStrided_Hip (CeedScalar * d_array , CeedSize start , CeedSize step , CeedSize length , CeedScalar * d_copy_array );
234+ int CeedDeviceCopyStrided_Hip (CeedScalar * d_array , CeedSize start , CeedSize stop , CeedSize step , CeedScalar * d_copy_array );
235235
236236//------------------------------------------------------------------------------
237237// Copy a vector to a value strided
238238//------------------------------------------------------------------------------
239- static int CeedVectorCopyStrided_Hip (CeedVector vec , CeedSize start , CeedSize step , CeedVector vec_copy ) {
239+ static int CeedVectorCopyStrided_Hip (CeedVector vec , CeedSize start , CeedSize stop , CeedSize step , CeedVector vec_copy ) {
240240 CeedSize length ;
241241 CeedVector_Hip * impl ;
242242
@@ -248,6 +248,7 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
248248 CeedCallBackend (CeedVectorGetLength (vec_copy , & length_copy ));
249249 length = length_vec < length_copy ? length_vec : length_copy ;
250250 }
251+ if (stop == -1 ) stop = length ;
251252 // Set value for synced device/host array
252253 if (impl -> d_array ) {
253254 CeedScalar * copy_array ;
@@ -260,12 +261,12 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
260261 CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
261262 CeedCallBackend (CeedGetHipblasHandle_Hip (ceed , & handle ));
262263#if defined(CEED_SCALAR_IS_FP32 )
263- CeedCallHipblas (ceed , hipblasScopy_64 (handle , (int64_t )length , impl -> d_array + start , (int64_t )step , copy_array + start , (int64_t )step ));
264+ CeedCallHipblas (ceed , hipblasScopy_64 (handle , (int64_t )( stop - start ) , impl -> d_array + start , (int64_t )step , copy_array + start , (int64_t )step ));
264265#else /* CEED_SCALAR */
265- CeedCallHipblas (ceed , hipblasDcopy_64 (handle , (int64_t )length , impl -> d_array + start , (int64_t )step , copy_array + start , (int64_t )step ));
266+ CeedCallHipblas (ceed , hipblasDcopy_64 (handle , (int64_t )( stop - start ) , impl -> d_array + start , (int64_t )step , copy_array + start , (int64_t )step ));
266267#endif /* CEED_SCALAR */
267268#else /* HIP_VERSION */
268- CeedCallBackend (CeedDeviceCopyStrided_Hip (impl -> d_array , start , step , length , copy_array ));
269+ CeedCallBackend (CeedDeviceCopyStrided_Hip (impl -> d_array , start , stop , step , copy_array ));
269270#endif /* HIP_VERSION */
270271 CeedCallBackend (CeedVectorRestoreArray (vec_copy , & copy_array ));
271272 impl -> h_array = NULL ;
@@ -274,7 +275,7 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
274275 CeedScalar * copy_array ;
275276
276277 CeedCallBackend (CeedVectorGetArray (vec_copy , CEED_MEM_HOST , & copy_array ));
277- CeedCallBackend (CeedHostCopyStrided_Hip (impl -> h_array , start , step , length , copy_array ));
278+ CeedCallBackend (CeedHostCopyStrided_Hip (impl -> h_array , start , stop , step , copy_array ));
278279 CeedCallBackend (CeedVectorRestoreArray (vec_copy , & copy_array ));
279280 impl -> d_array = NULL ;
280281 } else {
@@ -336,31 +337,32 @@ static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
336337//------------------------------------------------------------------------------
337338// Set host array to value strided
338339//------------------------------------------------------------------------------
339- static int CeedHostSetValueStrided_Hip (CeedScalar * h_array , CeedSize start , CeedSize step , CeedSize length , CeedScalar val ) {
340- for (CeedSize i = start ; i < length ; i += step ) h_array [i ] = val ;
340+ static int CeedHostSetValueStrided_Hip (CeedScalar * h_array , CeedSize start , CeedSize stop , CeedSize step , CeedScalar val ) {
341+ for (CeedSize i = start ; i < stop ; i += step ) h_array [i ] = val ;
341342 return CEED_ERROR_SUCCESS ;
342343}
343344
344345//------------------------------------------------------------------------------
345346// Set device array to value strided (impl in .hip.cpp file)
346347//------------------------------------------------------------------------------
347- int CeedDeviceSetValueStrided_Hip (CeedScalar * d_array , CeedSize start , CeedSize step , CeedSize length , CeedScalar val );
348+ int CeedDeviceSetValueStrided_Hip (CeedScalar * d_array , CeedSize start , CeedSize stop , CeedSize step , CeedScalar val );
348349
349350//------------------------------------------------------------------------------
350351// Set a vector to a value strided
351352//------------------------------------------------------------------------------
352- static int CeedVectorSetValueStrided_Hip (CeedVector vec , CeedSize start , CeedSize step , CeedScalar val ) {
353+ static int CeedVectorSetValueStrided_Hip (CeedVector vec , CeedSize start , CeedSize stop , CeedSize step , CeedScalar val ) {
353354 CeedSize length ;
354355 CeedVector_Hip * impl ;
355356
356357 CeedCallBackend (CeedVectorGetData (vec , & impl ));
357358 CeedCallBackend (CeedVectorGetLength (vec , & length ));
358359 // Set value for synced device/host array
360+ if (stop == -1 ) stop = length ;
359361 if (impl -> d_array ) {
360- CeedCallBackend (CeedDeviceSetValueStrided_Hip (impl -> d_array , start , step , length , val ));
362+ CeedCallBackend (CeedDeviceSetValueStrided_Hip (impl -> d_array , start , stop , step , val ));
361363 impl -> h_array = NULL ;
362364 } else if (impl -> h_array ) {
363- CeedCallBackend (CeedHostSetValueStrided_Hip (impl -> h_array , start , step , length , val ));
365+ CeedCallBackend (CeedHostSetValueStrided_Hip (impl -> h_array , start , stop , step , val ));
364366 impl -> d_array = NULL ;
365367 } else {
366368 return CeedError (CeedVectorReturnCeed (vec ), CEED_ERROR_BACKEND , "CeedVector must have valid data set" );
0 commit comments