Skip to content

Commit c040d5e

Browse files
authored
Merge pull request #5591 from quic/topic/ssyr2k_direct_sme1
Support for SME1 based ssyr2k_direct kernel for cblas_ssyr2k level 3 API
2 parents 20ae36b + 6939a43 commit c040d5e

File tree

9 files changed

+407
-0
lines changed

9 files changed

+407
-0
lines changed

common_level3.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,31 @@ void ssyrk_direct_alpha_betaLT(BLASLONG N, BLASLONG K,
110110
float beta,
111111
float * C, BLASLONG strideC);
112112

113+
void ssyr2k_direct_alpha_betaUN(BLASLONG N, BLASLONG K,
114+
float alpha,
115+
float * A, BLASLONG strideA,
116+
float * B, BLASLONG strideB,
117+
float beta,
118+
float * R, BLASLONG strideR);
119+
void ssyr2k_direct_alpha_betaUT(BLASLONG N, BLASLONG K,
120+
float alpha,
121+
float * A, BLASLONG strideA,
122+
float * B, BLASLONG strideB,
123+
float beta,
124+
float * R, BLASLONG strideR);
125+
void ssyr2k_direct_alpha_betaLN(BLASLONG N, BLASLONG K,
126+
float alpha,
127+
float * A, BLASLONG strideA,
128+
float * B, BLASLONG strideB,
129+
float beta,
130+
float * R, BLASLONG strideR);
131+
void ssyr2k_direct_alpha_betaLT(BLASLONG N, BLASLONG K,
132+
float alpha,
133+
float * A, BLASLONG strideA,
134+
float * B, BLASLONG strideB,
135+
float beta,
136+
float * R, BLASLONG strideR);
137+
113138
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
114139

115140
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,

common_param.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,10 @@ int (*shgemv_t) (BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BL
268268
void (*ssyrk_direct_alpha_betaUT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
269269
void (*ssyrk_direct_alpha_betaLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
270270
void (*ssyrk_direct_alpha_betaLT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
271+
void (*ssyr2k_direct_alpha_betaUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float *, BLASLONG);
272+
void (*ssyr2k_direct_alpha_betaUT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float *, BLASLONG);
273+
void (*ssyr2k_direct_alpha_betaLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float *, BLASLONG);
274+
void (*ssyr2k_direct_alpha_betaLT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float *, BLASLONG);
271275
#endif
272276

273277

common_s.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@
6060
#define SSYRK_DIRECT_ALPHA_BETA_UT ssyrk_direct_alpha_betaUT
6161
#define SSYRK_DIRECT_ALPHA_BETA_LN ssyrk_direct_alpha_betaLN
6262
#define SSYRK_DIRECT_ALPHA_BETA_LT ssyrk_direct_alpha_betaLT
/* Direct SYR2K kernels: each macro names the kernel symbol itself. */
#define	SSYR2K_DIRECT_ALPHA_BETA_UN	ssyr2k_direct_alpha_betaUN
#define	SSYR2K_DIRECT_ALPHA_BETA_UT	ssyr2k_direct_alpha_betaUT
#define	SSYR2K_DIRECT_ALPHA_BETA_LN	ssyr2k_direct_alpha_betaLN
#define	SSYR2K_DIRECT_ALPHA_BETA_LT	ssyr2k_direct_alpha_betaLT
6367

6468
#define SGEMM_ONCOPY sgemm_oncopy
6569
#define SGEMM_OTCOPY sgemm_otcopy
@@ -240,6 +244,10 @@
240244
#define SSYRK_DIRECT_ALPHA_BETA_UT gotoblas -> ssyrk_direct_alpha_betaUT
241245
#define SSYRK_DIRECT_ALPHA_BETA_LN gotoblas -> ssyrk_direct_alpha_betaLN
242246
#define SSYRK_DIRECT_ALPHA_BETA_LT gotoblas -> ssyrk_direct_alpha_betaLT
/* Direct SYR2K kernels: each macro resolves through the gotoblas table. */
#define	SSYR2K_DIRECT_ALPHA_BETA_UN	gotoblas -> ssyr2k_direct_alpha_betaUN
#define	SSYR2K_DIRECT_ALPHA_BETA_UT	gotoblas -> ssyr2k_direct_alpha_betaUT
#define	SSYR2K_DIRECT_ALPHA_BETA_LN	gotoblas -> ssyr2k_direct_alpha_betaLN
#define	SSYR2K_DIRECT_ALPHA_BETA_LT	gotoblas -> ssyr2k_direct_alpha_betaLT
243251
#endif
244252

245253
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy

interface/syr2k.c

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,35 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Tr
345345
return;
346346
}
347347

#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
#if defined(ARCH_ARM64) && (defined(USE_SSYR2K_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
  /*
   * Single-precision fast path: hand the call straight to a direct SYR2K
   * kernel when the matrices are row-major and tightly packed (every
   * leading dimension equals the corresponding row length).
   *
   * In DYNAMIC_ARCH builds the whole fast path must be gated on the
   * runtime SME check. Without the braces below, `if (support_sme1())`
   * would only guard the `args.n == 0` early return and the direct
   * kernels would still be dispatched when the check fails.
   */
#if defined(DYNAMIC_ARCH)
  if (support_sme1())
#endif
  {
    if (args.n == 0) return;

    if (order == CblasRowMajor && n == ldc) {
      if (Trans == CblasNoTrans && k == lda && k == ldb) {
        /* A and B are packed n x k. */
        if (Uplo == CblasUpper) {
          SSYR2K_DIRECT_ALPHA_BETA_UN(n, k, alpha, a, lda, b, ldb, beta, c, ldc);
          return;
        } else if (Uplo == CblasLower) {
          SSYR2K_DIRECT_ALPHA_BETA_LN(n, k, alpha, a, lda, b, ldb, beta, c, ldc);
          return;
        }
      } else if (Trans == CblasTrans && n == lda && n == ldb) {
        /* A and B are packed k x n. */
        if (Uplo == CblasUpper) {
          SSYR2K_DIRECT_ALPHA_BETA_UT(n, k, alpha, a, lda, b, ldb, beta, c, ldc);
          return;
        } else if (Uplo == CblasLower) {
          SSYR2K_DIRECT_ALPHA_BETA_LT(n, k, alpha, a, lda, b, ldb, beta, c, ldc);
          return;
        }
      }
    }
    /* Anything else falls through to the generic SYR2K path. */
  }
#endif
#endif
376+
348377
#endif
349378

350379
if (args.n == 0) return;

kernel/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
249249
if (ARM64)
250250
set(USE_DIRECT_SSYRK true)
251251
endif()
# The direct SSYR2K kernels are enabled on ARM64 only.
if (ARM64)
  set(USE_DIRECT_SSYR2K true)
else()
  set(USE_DIRECT_SSYR2K false)
endif()
252256
set(USE_DIRECT_SGEMM false)
253257
if (X86_64 OR ARM64)
254258
set(USE_DIRECT_SGEMM true)
@@ -311,6 +315,16 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
311315
endif ()
312316
endif()
313317

# Build the four direct SSYR2K kernel variants (upper/lower triangle x
# no-transpose/transpose) from one source file. Each variant is selected
# with the UPPER / TRANSA preprocessor definitions, matching the
# -DUPPER/-UUPPER and -DTRANSA/-UTRANSA flags that kernel/Makefile.L3
# passes for the same objects. With an empty definition list every
# variant would be compiled with identical preprocessor state.
# NOTE(review): assumes GenerateNamedObjects takes its definitions
# argument as a ;-separated list -- confirm against its definition.
if (USE_DIRECT_SSYR2K)
  if (ARM64)
    set (SSYR2KDIRECTKERNEL_ALPHA_BETA ssyr2k_direct_alpha_beta_arm64_sme1.c)
    GenerateNamedObjects("${KERNELDIR}/${SSYR2KDIRECTKERNEL_ALPHA_BETA}" "UPPER" "syr2k_direct_alpha_betaUN" false "" "" false SINGLE)
    GenerateNamedObjects("${KERNELDIR}/${SSYR2KDIRECTKERNEL_ALPHA_BETA}" "UPPER;TRANSA" "syr2k_direct_alpha_betaUT" false "" "" false SINGLE)
    GenerateNamedObjects("${KERNELDIR}/${SSYR2KDIRECTKERNEL_ALPHA_BETA}" "" "syr2k_direct_alpha_betaLN" false "" "" false SINGLE)
    GenerateNamedObjects("${KERNELDIR}/${SSYR2KDIRECTKERNEL_ALPHA_BETA}" "TRANSA" "syr2k_direct_alpha_betaLT" false "" "" false SINGLE)
  endif ()
endif()
327+
314328
foreach (float_type SINGLE DOUBLE)
315329
string(SUBSTRING ${float_type} 0 1 float_char)
316330
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})

kernel/Makefile.L3

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ USE_DIRECT_SGEMM = 1
5555
USE_DIRECT_SSYMM = 1
5656
USE_DIRECT_STRMM = 1
5757
USE_DIRECT_SSYRK = 1
58+
USE_DIRECT_SSYR2K = 1
5859
endif
5960

6061
ifeq ($(ARCH), riscv64)
@@ -173,6 +174,17 @@ endif
173174
endif
174175
endif
175176

# Pick the source file for the direct SSYR2K kernels, unless one has
# already been chosen. Only arm64 has an implementation.
ifdef USE_DIRECT_SSYR2K
ifndef SSYR2KDIRECTKERNEL_ALPHA_BETA
ifeq ($(ARCH), arm64)
SSYR2KDIRECTKERNEL_ALPHA_BETA = ssyr2k_direct_alpha_beta_arm64_sme1.c
# Builds targeting ARMV9SME are flagged as having SME.
ifeq ($(TARGET_CORE), ARMV9SME)
HAVE_SME = 1
endif
endif
endif
endif
187+
176188
ifeq ($(BUILD_BFLOAT16), 1)
177189
ifndef BGEMMKERNEL
178190
BGEMM_BETA = ../generic/gemm_beta.c
@@ -280,6 +292,16 @@ SKERNELOBJS += \
280292
endif
281293
endif
282294

# Add the four direct SSYR2K kernel objects to the single-precision
# kernel list. Each object is listed exactly once.
ifdef USE_DIRECT_SSYR2K
ifeq ($(ARCH), arm64)
SKERNELOBJS += \
	ssyr2k_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) \
	ssyr2k_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) \
	ssyr2k_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) \
	ssyr2k_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX)
endif
endif
304+
283305
ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
284306
DKERNELOBJS += \
285307
dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -1193,6 +1215,20 @@ $(KDIR)ssyrk_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIREC
11931215
endif
11941216
endif
11951217

# Compile rules for the direct SSYR2K kernels. All four objects come from
# $(SSYR2KDIRECTKERNEL_ALPHA_BETA); the variant is chosen by defining or
# undefining UPPER (triangle) and TRANSA (transpose).
ifdef USE_DIRECT_SSYR2K
ifeq ($(ARCH), arm64)

# Upper triangle, no transpose.
$(KDIR)ssyr2k_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYR2KDIRECTKERNEL_ALPHA_BETA)
	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -UTRANSA $< -o $@

# Upper triangle, transpose.
$(KDIR)ssyr2k_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYR2KDIRECTKERNEL_ALPHA_BETA)
	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -DTRANSA $< -o $@

# Lower triangle, no transpose.
$(KDIR)ssyr2k_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYR2KDIRECTKERNEL_ALPHA_BETA)
	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -UTRANSA $< -o $@

# Lower triangle, transpose.
$(KDIR)ssyr2k_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYR2KDIRECTKERNEL_ALPHA_BETA)
	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -DTRANSA $< -o $@

endif
endif
1231+
11961232
ifdef USE_TRMM
11971233
$(KDIR)strmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMKERNEL)
11981234
ifeq ($(OS), AIX)

0 commit comments

Comments
 (0)