From 030ae1fd97f04c0ff4536e4e35567147409fb985 Mon Sep 17 00:00:00 2001 From: Harishmcw Date: Tue, 25 Feb 2025 15:40:39 +0530 Subject: [PATCH] Redefined threading logic for WoA --- interface/gemv.c | 5 +++++ interface/lapack/gesv.c | 10 ++++++---- interface/zgemv.c | 15 ++++++++++----- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/interface/gemv.c b/interface/gemv.c index f91f364eed..0f8fe66782 100644 --- a/interface/gemv.c +++ b/interface/gemv.c @@ -79,6 +79,11 @@ static inline int get_gemv_optimal_nthreads_neoversev1(BLASLONG MN, int ncpu) { static inline int get_gemv_optimal_nthreads(BLASLONG MN) { int ncpu = num_cpu_avail(3); +#if defined(_WIN64) && defined(_M_ARM64) + if (MN > 100000000L) + return num_cpu_avail(4); + return 1; +#endif #if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) return get_gemv_optimal_nthreads_neoversev1(MN, ncpu); #elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) diff --git a/interface/lapack/gesv.c b/interface/lapack/gesv.c index 51a38de60d..21fcc20970 100644 --- a/interface/lapack/gesv.c +++ b/interface/lapack/gesv.c @@ -117,13 +117,15 @@ int NAME(blasint *N, blasint *NRHS, FLOAT *a, blasint *ldA, blasint *ipiv, #if defined(_WIN64) && defined(_M_ARM64) #ifdef COMPLEX - if (args.m * args.n > 600) + if (args.m * args.n <= 300) #else - if (args.m * args.n > 1000) + if (args.m * args.n <= 500) #endif - args.nthreads = num_cpu_avail(4); - else args.nthreads = 1; + else if (args.m * args.n <= 1000) + args.nthreads = 4; + else + args.nthreads = num_cpu_avail(4); #else #ifndef DOUBLE if (args.m * args.n < 40000) diff --git a/interface/zgemv.c b/interface/zgemv.c index 3e98dba7ff..3438575b90 100644 --- a/interface/zgemv.c +++ b/interface/zgemv.c @@ -252,25 +252,30 @@ void CNAME(enum CBLAS_ORDER order, #ifdef SMP - if ( 1L * m * n < 1024L * GEMM_MULTITHREAD_THRESHOLD ) +#if defined(_WIN64) && defined(_M_ARM64) + if (m*n > 25000000L) + nthreads = num_cpu_avail(4); + else + nthreads = 1; +#else + if (1L * m * n < 1024L * GEMM_MULTITHREAD_THRESHOLD) nthreads = 1; else nthreads = num_cpu_avail(2); +#endif if (nthreads == 1) { -#endif +#endif (gemv[(int)trans])(m, n, 0, alpha_r, alpha_i, a, lda, x, incx, y, incy, buffer); #ifdef SMP - } else { - (gemv_thread[(int)trans])(m, n, ALPHA, a, lda, x, incx, y, incy, buffer, nthreads); - } #endif + STACK_FREE(buffer); FUNCTION_PROFILE_END(4, m * n + m + n, 2 * m * n);