Skip to content

Commit 57fac03

Browse files
alankelly authored and xnnpack-bot committed
Turn on c2 asm f32 kernels
This is the first time that we exploit the broken dependency between gemm & igemm.
PiperOrigin-RevId: 724255293
1 parent 114acd2 commit 57fac03

File tree

89 files changed

+8527
-2608
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

89 files changed

+8527
-2608
lines changed

bench/f32-gemm-minmax.cc

+176
Original file line numberDiff line numberDiff line change
@@ -1583,6 +1583,182 @@
15831583
}
15841584

15851585
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast)
1586+
1587+
static void f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1588+
GEMMBenchmark(state,
1589+
xnn_f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast,
1590+
xnn_init_f32_minmax_scalar_params,
1591+
xnn_pack_f32_gemm_goi_w,
1592+
/*mr=*/1, /*nr=*/16, /*kr=*/2, /*sr=*/1,
1593+
benchmark::utils::CheckAVX512F);
1594+
}
1595+
1596+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast)
1597+
1598+
static void f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1599+
GEMMBenchmark(state,
1600+
xnn_f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast,
1601+
xnn_init_f32_minmax_scalar_params,
1602+
xnn_pack_f32_gemm_goi_w,
1603+
/*mr=*/2, /*nr=*/16, /*kr=*/2, /*sr=*/1,
1604+
benchmark::utils::CheckAVX512F);
1605+
}
1606+
1607+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast)
1608+
1609+
static void f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1610+
GEMMBenchmark(state,
1611+
xnn_f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast,
1612+
xnn_init_f32_minmax_scalar_params,
1613+
xnn_pack_f32_gemm_goi_w,
1614+
/*mr=*/3, /*nr=*/16, /*kr=*/2, /*sr=*/1,
1615+
benchmark::utils::CheckAVX512F);
1616+
}
1617+
1618+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast)
1619+
1620+
static void f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1621+
GEMMBenchmark(state,
1622+
xnn_f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast,
1623+
xnn_init_f32_minmax_scalar_params,
1624+
xnn_pack_f32_gemm_goi_w,
1625+
/*mr=*/4, /*nr=*/16, /*kr=*/2, /*sr=*/1,
1626+
benchmark::utils::CheckAVX512F);
1627+
}
1628+
1629+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast)
1630+
1631+
static void f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1632+
GEMMBenchmark(state,
1633+
xnn_f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast,
1634+
xnn_init_f32_minmax_scalar_params,
1635+
xnn_pack_f32_gemm_goi_w,
1636+
/*mr=*/5, /*nr=*/16, /*kr=*/2, /*sr=*/1,
1637+
benchmark::utils::CheckAVX512F);
1638+
}
1639+
1640+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast)
1641+
1642+
static void f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1643+
GEMMBenchmark(state,
1644+
xnn_f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast,
1645+
xnn_init_f32_minmax_scalar_params,
1646+
xnn_pack_f32_gemm_goi_w,
1647+
/*mr=*/6, /*nr=*/16, /*kr=*/2, /*sr=*/1,
1648+
benchmark::utils::CheckAVX512F);
1649+
}
1650+
1651+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast)
1652+
1653+
static void f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1654+
GEMMBenchmark(state,
1655+
xnn_f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast,
1656+
xnn_init_f32_minmax_scalar_params,
1657+
xnn_pack_f32_gemm_goi_w,
1658+
/*mr=*/7, /*nr=*/16, /*kr=*/2, /*sr=*/1,
1659+
benchmark::utils::CheckAVX512F);
1660+
}
1661+
1662+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast)
1663+
1664+
static void f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1665+
GEMMBenchmark(state,
1666+
xnn_f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast,
1667+
xnn_init_f32_minmax_scalar_params,
1668+
xnn_pack_f32_gemm_goi_w,
1669+
/*mr=*/8, /*nr=*/16, /*kr=*/2, /*sr=*/1,
1670+
benchmark::utils::CheckAVX512F);
1671+
}
1672+
1673+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast)
1674+
1675+
static void f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1676+
GEMMBenchmark(state,
1677+
xnn_f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast,
1678+
xnn_init_f32_minmax_scalar_params,
1679+
xnn_pack_f32_gemm_goi_w,
1680+
/*mr=*/9, /*nr=*/16, /*kr=*/2, /*sr=*/1,
1681+
benchmark::utils::CheckAVX512F);
1682+
}
1683+
1684+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast)
1685+
1686+
static void f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1687+
GEMMBenchmark(state,
1688+
xnn_f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast,
1689+
xnn_init_f32_minmax_scalar_params,
1690+
xnn_pack_f32_gemm_goi_w,
1691+
/*mr=*/10, /*nr=*/16, /*kr=*/2, /*sr=*/1,
1692+
benchmark::utils::CheckAVX512F);
1693+
}
1694+
1695+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast)
1696+
1697+
static void f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1698+
GEMMBenchmark(state,
1699+
xnn_f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast,
1700+
xnn_init_f32_minmax_scalar_params,
1701+
xnn_pack_f32_gemm_goi_w,
1702+
/*mr=*/11, /*nr=*/16, /*kr=*/2, /*sr=*/1,
1703+
benchmark::utils::CheckAVX512F);
1704+
}
1705+
1706+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast)
1707+
1708+
static void f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1709+
GEMMBenchmark(state,
1710+
xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast,
1711+
xnn_init_f32_minmax_scalar_params,
1712+
xnn_pack_f32_gemm_goi_w,
1713+
/*mr=*/1, /*nr=*/32, /*kr=*/2, /*sr=*/1,
1714+
benchmark::utils::CheckAVX512F);
1715+
}
1716+
1717+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast)
1718+
1719+
static void f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1720+
GEMMBenchmark(state,
1721+
xnn_f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast,
1722+
xnn_init_f32_minmax_scalar_params,
1723+
xnn_pack_f32_gemm_goi_w,
1724+
/*mr=*/2, /*nr=*/32, /*kr=*/2, /*sr=*/1,
1725+
benchmark::utils::CheckAVX512F);
1726+
}
1727+
1728+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast)
1729+
1730+
static void f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1731+
GEMMBenchmark(state,
1732+
xnn_f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast,
1733+
xnn_init_f32_minmax_scalar_params,
1734+
xnn_pack_f32_gemm_goi_w,
1735+
/*mr=*/3, /*nr=*/32, /*kr=*/2, /*sr=*/1,
1736+
benchmark::utils::CheckAVX512F);
1737+
}
1738+
1739+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast)
1740+
1741+
static void f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1742+
GEMMBenchmark(state,
1743+
xnn_f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast,
1744+
xnn_init_f32_minmax_scalar_params,
1745+
xnn_pack_f32_gemm_goi_w,
1746+
/*mr=*/4, /*nr=*/32, /*kr=*/2, /*sr=*/1,
1747+
benchmark::utils::CheckAVX512F);
1748+
}
1749+
1750+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast)
1751+
1752+
static void f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
1753+
GEMMBenchmark(state,
1754+
xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast,
1755+
xnn_init_f32_minmax_scalar_params,
1756+
xnn_pack_f32_gemm_goi_w,
1757+
/*mr=*/5, /*nr=*/32, /*kr=*/2, /*sr=*/1,
1758+
benchmark::utils::CheckAVX512F);
1759+
}
1760+
1761+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast)
15861762
#endif // XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY
15871763

15881764

cmake/gen/amd64_microkernels.cmake

+17-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
SET(PROD_AMD64_ASM_MICROKERNEL_SRCS
1313
src/bf16-f32-gemm/gen/bf16-f32-gemm-1x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
14-
src/bf16-f32-gemm/gen/bf16-f32-gemm-7x32c2-minmax-asm-amd64-avx512bf16-broadcast.S)
14+
src/bf16-f32-gemm/gen/bf16-f32-gemm-7x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
15+
src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S
16+
src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S)
1517

1618
SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS
1719
src/bf16-f32-gemm/gen/bf16-f32-gemm-1x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
@@ -40,31 +42,45 @@ SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS
4042
src/bf16-f32-gemm/gen/bf16-f32-gemm-11x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
4143
src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
4244
src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S
45+
src/f32-gemm/gen/f32-gemm-1x16c2-minmax-asm-amd64-avx512f-broadcast.S
4346
src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S
4447
src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S
4548
src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S
49+
src/f32-gemm/gen/f32-gemm-2x16c2-minmax-asm-amd64-avx512f-broadcast.S
4650
src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S
51+
src/f32-gemm/gen/f32-gemm-2x32c2-minmax-asm-amd64-avx512f-broadcast.S
4752
src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S
4853
src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S
54+
src/f32-gemm/gen/f32-gemm-3x16c2-minmax-asm-amd64-avx512f-broadcast.S
4955
src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S
56+
src/f32-gemm/gen/f32-gemm-3x32c2-minmax-asm-amd64-avx512f-broadcast.S
5057
src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S
5158
src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S
59+
src/f32-gemm/gen/f32-gemm-4x16c2-minmax-asm-amd64-avx512f-broadcast.S
5260
src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S
61+
src/f32-gemm/gen/f32-gemm-4x32c2-minmax-asm-amd64-avx512f-broadcast.S
5362
src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S
5463
src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S
64+
src/f32-gemm/gen/f32-gemm-5x16c2-minmax-asm-amd64-avx512f-broadcast.S
5565
src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S
5666
src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S
5767
src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S
68+
src/f32-gemm/gen/f32-gemm-6x16c2-minmax-asm-amd64-avx512f-broadcast.S
5869
src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S
5970
src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S
71+
src/f32-gemm/gen/f32-gemm-7x16c2-minmax-asm-amd64-avx512f-broadcast.S
6072
src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S
6173
src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S
74+
src/f32-gemm/gen/f32-gemm-8x16c2-minmax-asm-amd64-avx512f-broadcast.S
6275
src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S
6376
src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S
77+
src/f32-gemm/gen/f32-gemm-9x16c2-minmax-asm-amd64-avx512f-broadcast.S
6478
src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S
6579
src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S
80+
src/f32-gemm/gen/f32-gemm-10x16c2-minmax-asm-amd64-avx512f-broadcast.S
6681
src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S
6782
src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S
83+
src/f32-gemm/gen/f32-gemm-11x16c2-minmax-asm-amd64-avx512f-broadcast.S
6884
src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S
6985
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S
7086
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S

gemm_compiler/avx512bf16_template.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
class Avx512Bf16(isa.Avx512F):
1313

1414
def __init__(self):
15-
pass # Empty constructor
15+
self.c = 2
1616

1717
def isa(self):
1818
return 'avx512bf16'
@@ -33,27 +33,28 @@ def compute_asm(self):
3333
def function_name(self, M, N, isa):
3434
return f'xnn_bf16_f32_gemm_minmax_ukernel_{M}x{N}c2__asm_amd64_{isa}_broadcast'
3535

36+
def init_accumulators(self, M, N):
37+
asm_string = super().init_accumulators(M, N)
38+
asm_string += """
39+
# Are there at least 4 bytes?
40+
cmp rdx, 4
41+
js inner_loop_tail\n"""
42+
43+
return asm_string
44+
3645
def outer_loop_prepare(self, M, N):
3746
k_register = self.k_register()
3847
kc_register = self.kc_register()
3948
offset = M * 16 + self.c_ptr_stack_offset()
49+
kmask = self.k_mask()
4050
asm_string = f"""
4151
# Copy k and flip bit.
4252
mov {k_register}, rdx
4353
and {k_register}, 0x2
44-
and {kc_register}, 0xFFFFFFFFFFFFFFFD
54+
and {kc_register}, {kmask}
4555
mov [rsp + {offset}], {k_register}\n"""
4656
return asm_string
4757

48-
def init_accumulators(self, M, N):
49-
asm_string = super().init_accumulators(M, N)
50-
asm_string += """
51-
# Are there at least 4 bytes?
52-
cmp rdx, 4
53-
js inner_loop_tail\n"""
54-
55-
return asm_string
56-
5758
def inner_loop_tail(self, M, N):
5859
k_register = self.k_register()
5960
nc_register = self.nc_register()
@@ -75,3 +76,9 @@ def inner_loop_tail(self, M, N):
7576
else:
7677
asm_string += self.inner_loop_small_M_N(M=M, N=N, tail=True)
7778
return asm_string
79+
80+
def element_size(self):
81+
return 2
82+
83+
def k_mask(self):
84+
return "0xFFFFFFFFFFFFFFFD"

0 commit comments

Comments
 (0)