Commit fb40051

optimize bit unpack on arm64 using neon instructions
Signed-off-by: Achille Roussel <achille.roussel@gmail.com>
1 parent c17926b commit fb40051
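
For context, "bit unpack" here means expanding bit-packed values into full int32 values; for bitWidth=1 each byte holds 8 values, least-significant bit first. A minimal pure-Go sketch of those semantics, using an illustrative function name rather than the repository's actual scalar fallback:

// unpackInt32x1bitScalar is a hypothetical reference implementation of the
// 1-bit case: output value i is bit (i % 8) of byte (i / 8) of src.
func unpackInt32x1bitScalar(dst []int32, src []byte) {
	for i := range dst {
		dst[i] = int32((src[i/8] >> (uint(i) % 8)) & 1)
	}
}

The NEON routine added below produces the same values, but expands 64 of them per loop iteration instead of one at a time.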

14 files changed: +1198 -108 lines changed

masks_int32_arm64.s

Lines changed: 0 additions & 106 deletions
This file was deleted.

unpack_int32_1bit_arm64.s

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
//go:build !purego

#include "textflag.h"
#include "unpack_neon_macros_arm64.h"

// unpackInt32x1bitNEON implements NEON unpacking for bitWidth=1 using direct bit manipulation.
// Each byte contains 8 bits: [bit7][bit6][bit5][bit4][bit3][bit2][bit1][bit0]
//
// func unpackInt32x1bitNEON(dst []int32, src []byte, bitWidth uint)
TEXT ·unpackInt32x1bitNEON(SB), NOSPLIT, $0-56
    MOVD dst_base+0(FP), R0        // R0 = dst pointer
    MOVD dst_len+8(FP), R1         // R1 = dst length
    MOVD src_base+24(FP), R2       // R2 = src pointer
    MOVD bitWidth+48(FP), R3       // R3 = bitWidth (should be 1)

    MOVD $0, R5                    // R5 = index (initialize early for tail path)

    // Check if we have at least 64 values to process
    CMP $64, R1
    BLT neon1_tail

    // Round down to multiple of 64 for NEON processing
    MOVD R1, R4
    LSR $6, R4, R4                 // R4 = len / 64
    LSL $6, R4, R4                 // R4 = aligned length (multiple of 64)

    // Load mask for 1 bit (0x01010101...)
    MOVD $0x0101010101010101, R6
    VMOV R6, V31.D[0]
    VMOV R6, V31.D[1]              // V31 = mask for single bits

neon1_loop:
    // Load 8 bytes (contains 64 x 1-bit values)
    VLD1 (R2), [V0.B8]

    // Extract each bit position (8 separate streams)
    VAND V31.B16, V0.B16, V1.B16   // V1 = bit 0

    VUSHR $1, V0.B16, V2.B16
    VAND V31.B16, V2.B16, V2.B16   // V2 = bit 1

    VUSHR $2, V0.B16, V3.B16
    VAND V31.B16, V3.B16, V3.B16   // V3 = bit 2

    VUSHR $3, V0.B16, V4.B16
    VAND V31.B16, V4.B16, V4.B16   // V4 = bit 3

    VUSHR $4, V0.B16, V5.B16
    VAND V31.B16, V5.B16, V5.B16   // V5 = bit 4

    VUSHR $5, V0.B16, V6.B16
    VAND V31.B16, V6.B16, V6.B16   // V6 = bit 5

    VUSHR $6, V0.B16, V7.B16
    VAND V31.B16, V7.B16, V7.B16   // V7 = bit 6

    VUSHR $7, V0.B16, V8.B16
    VAND V31.B16, V8.B16, V8.B16   // V8 = bit 7

    // Stage 1: ZIP pairs (8 streams → 4 streams of pairs)
    VZIP1 V2.B8, V1.B8, V9.B8      // V9  = [bit0,bit1] interleaved
    VZIP1 V4.B8, V3.B8, V10.B8     // V10 = [bit2,bit3] interleaved
    VZIP1 V6.B8, V5.B8, V11.B8     // V11 = [bit4,bit5] interleaved
    VZIP1 V8.B8, V7.B8, V12.B8     // V12 = [bit6,bit7] interleaved

    VZIP2 V2.B8, V1.B8, V13.B8     // V13 = [bit0,bit1] upper half
    VZIP2 V4.B8, V3.B8, V14.B8     // V14 = [bit2,bit3] upper half
    VZIP2 V6.B8, V5.B8, V15.B8     // V15 = [bit4,bit5] upper half
    VZIP2 V8.B8, V7.B8, V16.B8     // V16 = [bit6,bit7] upper half

    // Stage 2: ZIP quads (4 streams → 2 streams of quads)
    VZIP1 V10.H4, V9.H4, V17.H4    // V17 = [0,1,2,3] interleaved
    VZIP1 V12.H4, V11.H4, V18.H4   // V18 = [4,5,6,7] interleaved
    VZIP2 V10.H4, V9.H4, V19.H4    // V19 = [0,1,2,3] next
    VZIP2 V12.H4, V11.H4, V20.H4   // V20 = [4,5,6,7] next

    VZIP1 V14.H4, V13.H4, V21.H4   // V21 = upper [0,1,2,3]
    VZIP1 V16.H4, V15.H4, V22.H4   // V22 = upper [4,5,6,7]
    VZIP2 V14.H4, V13.H4, V23.H4   // V23 = upper [0,1,2,3] next
    VZIP2 V16.H4, V15.H4, V24.H4   // V24 = upper [4,5,6,7] next

    // Stage 3: ZIP octets (2 streams → fully sequential)
    VZIP1 V18.S2, V17.S2, V25.S2   // V25 = values 0-7
    VZIP2 V18.S2, V17.S2, V26.S2   // V26 = values 8-15
    VZIP1 V20.S2, V19.S2, V27.S2   // V27 = values 16-23
    VZIP2 V20.S2, V19.S2, V28.S2   // V28 = values 24-31
    VZIP1 V22.S2, V21.S2, V1.S2    // V1 = values 32-39
    VZIP2 V22.S2, V21.S2, V2.S2    // V2 = values 40-47
    VZIP1 V24.S2, V23.S2, V3.S2    // V3 = values 48-55
    VZIP2 V24.S2, V23.S2, V4.S2    // V4 = values 56-63

    // Widen to int32 and store - process first 32 values
    USHLL_8H_8B(5, 25)
    USHLL_4S_4H(6, 5)
    USHLL2_4S_8H(7, 5)
    VST1 [V6.S4, V7.S4], (R0)
    ADD $32, R0, R0

    USHLL_8H_8B(5, 26)
    USHLL_4S_4H(6, 5)
    USHLL2_4S_8H(7, 5)
    VST1 [V6.S4, V7.S4], (R0)
    ADD $32, R0, R0

    USHLL_8H_8B(5, 27)
    USHLL_4S_4H(6, 5)
    USHLL2_4S_8H(7, 5)
    VST1 [V6.S4, V7.S4], (R0)
    ADD $32, R0, R0

    USHLL_8H_8B(5, 28)
    USHLL_4S_4H(6, 5)
    USHLL2_4S_8H(7, 5)
    VST1 [V6.S4, V7.S4], (R0)
    ADD $32, R0, R0

    // Process second 32 values
    USHLL_8H_8B(5, 1)
    USHLL_4S_4H(6, 5)
    USHLL2_4S_8H(7, 5)
    VST1 [V6.S4, V7.S4], (R0)
    ADD $32, R0, R0

    USHLL_8H_8B(5, 2)
    USHLL_4S_4H(6, 5)
    USHLL2_4S_8H(7, 5)
    VST1 [V6.S4, V7.S4], (R0)
    ADD $32, R0, R0

    USHLL_8H_8B(5, 3)
    USHLL_4S_4H(6, 5)
    USHLL2_4S_8H(7, 5)
    VST1 [V6.S4, V7.S4], (R0)
    ADD $32, R0, R0

    USHLL_8H_8B(5, 4)
    USHLL_4S_4H(6, 5)
    USHLL2_4S_8H(7, 5)
    VST1 [V6.S4, V7.S4], (R0)
    ADD $32, R0, R0

    // Advance pointers
    ADD $8, R2, R2                 // src += 8 bytes
    ADD $64, R5, R5                // index += 64

    CMP R4, R5
    BLT neon1_loop

neon1_tail:
    // Handle remaining elements with scalar fallback
    CMP R1, R5
    BEQ neon1_done

    // Compute remaining elements
    SUB R5, R1, R1

    // Fall back to scalar unpack for tail
    MOVD $1, R4                    // bitMask = 1
    MOVD $0, R6                    // bitOffset = 0
    MOVD $0, R7                    // index = 0
    B neon1_scalar_test

neon1_scalar_loop:
    MOVD R6, R8
    LSR $3, R8, R8                 // byte_index = bitOffset / 8
    MOVBU (R2)(R8), R9             // Load byte

    MOVD R6, R10
    AND $7, R10, R10               // bit_offset = bitOffset % 8

    LSR R10, R9, R9                // Shift right by bit offset
    AND $1, R9, R9                 // Mask to get bit
    MOVW R9, (R0)                  // Store as int32

    ADD $4, R0, R0                 // dst++
    ADD $1, R6, R6                 // bitOffset++
    ADD $1, R7, R7                 // index++

neon1_scalar_test:
    CMP R1, R7
    BLT neon1_scalar_loop

neon1_done:
    RET
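
The header comment and the Stage 1-3 ZIP / USHLL sequence above describe a transpose-then-widen scheme whose net effect is the same LSB-first expansion as the scalar rule shown earlier. A hedged test sketch that checks one 64-value block against that rule; the package name and file placement are assumptions, and only unpackInt32x1bitNEON comes from this commit:

//go:build arm64 && !purego

package bitpack // assumed package; place alongside the assembly file

import "testing"

func TestUnpackInt32x1bitNEONSketch(t *testing.T) {
	src := []byte{0b10110010, 0xFF, 0x00, 0xA5, 0x5A, 0x0F, 0xF0, 0x81}
	got := make([]int32, 64)
	unpackInt32x1bitNEON(got, src, 1)

	for i := range got {
		want := int32((src[i/8] >> (uint(i) % 8)) & 1) // LSB-first, 8 values per byte
		if got[i] != want {
			t.Fatalf("value %d: got %d, want %d", i, got[i], want)
		}
	}
}

With exactly 64 values the call stays on the vector path (the scalar tail is skipped), so the sketch exercises the ZIP/USHLL pipeline end to end.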

0 commit comments
