Skip to content

Commit dc350c7

Browse files
committed
Try to reuse mask vectors when unmasking websocket frames
1 parent 1d376b5 commit dc350c7

File tree

1 file changed

+58
-31
lines changed

1 file changed

+58
-31
lines changed

src/lib/lwan-websocket.c

+58-31
Original file line numberDiff line numberDiff line change
@@ -150,63 +150,90 @@ static size_t get_frame_length(struct lwan_request *request, uint16_t header)
150150

151151
static void unmask(char *msg, size_t msg_len, char mask[static 4])
152152
{
153-
const int32_t mask32 = (int32_t)string_as_uint32(mask);
154-
const char *msg_end = msg + msg_len;
153+
/* TODO: handle alignment of `msg` to use (at least) NT loads
154+
* as we're rewriting msg anyway. (NT writes aren't that
155+
* useful as the unmasked value will be used right after.) */
155156

156157
#if defined(__AVX2__)
157-
const size_t len256 = msg_len / 32;
158-
if (len256) {
159-
const __m256i mask256 = _mm256_setr_epi32(
160-
mask32, mask32, mask32, mask32, mask32, mask32, mask32, mask32);
161-
for (size_t i = 0; i < len256; i++) {
162-
__m256i v = _mm256_loadu_si256((__m256i *)msg);
158+
const __m256i mask256 =
159+
_mm256_castps_si256(_mm256_broadcast_ss((const float *)mask));
160+
if (msg_len >= 32) {
161+
do {
162+
__m256i v = _mm256_lddqu_si256((const __m256i *)msg);
163163
_mm256_storeu_si256((__m256i *)msg, _mm256_xor_si256(v, mask256));
164-
msg += 32;
165-
}
166164

167-
msg_len = (size_t)(msg_end - msg);
165+
msg += 32;
166+
msg_len -= 32;
167+
} while (msg_len >= 32);
168168
}
169169
#endif
170170

171171
#if defined(__SSE2__)
172-
const size_t len128 = msg_len / 16;
173-
if (len128) {
174-
const __m128i mask128 = _mm_setr_epi32(mask32, mask32, mask32, mask32);
175-
for (size_t i = 0; i < len128; i++) {
176-
__m128i v = _mm_loadu_si128((__m128i *)msg);
172+
#if defined(__AVX2__)
173+
const __m128i mask128 = _mm256_extracti128_si256(mask256, 0);
174+
#elif defined(__SSE3__)
175+
const __m128i mask128 = _mm_lddqu_si128((const float *)mask);
176+
#else
177+
const __m128i mask128 = _mm_loadu_si128((const __m128i *)mask);
178+
#endif
179+
if (msg_len >= 16) {
180+
do {
181+
#if defined(__SSE3__)
182+
__m128i v = _mm_lddqu_si128((const __m128i *)msg);
183+
#else
184+
__m128i v = _mm_loadu_si128((const __m128i *)msg);
185+
#endif
186+
177187
_mm_storeu_si128((__m128i *)msg, _mm_xor_si128(v, mask128));
178-
msg += 16;
179-
}
180188

181-
msg_len = (size_t)(msg_end - msg);
189+
msg += 16;
190+
msg_len -= 16;
191+
} while (msg_len >= 16);
182192
}
183193
#endif
184194

185195
if (sizeof(void *) == 8) {
186-
const uint64_t mask64 = (uint64_t)mask32 << 32 | (uint64_t)mask32;
187-
const size_t len64 = msg_len / 8;
188-
for (size_t i = 0; i < len64; i++) {
189-
uint64_t v = string_as_uint64(msg);
190-
v ^= mask64;
191-
msg = mempcpy(msg, &v, sizeof(v));
196+
if (msg_len >= 8) {
197+
#if defined(__SSE_4_1__)
198+
/* We're far away enough from the AVX2 path that it's
199+
* probably better to use mask128 instead of mask256
200+
* here. */
201+
const __int64 mask64 = _mm_extract_epi64(mask128, 0);
202+
#else
203+
const uint32_t mask32 = string_as_uint32(mask);
204+
const uint64_t mask64 = (uint64_t)mask32 << 32 | (uint64_t)mask32;
205+
#endif
206+
do {
207+
uint64_t v = string_as_uint64(msg);
208+
v ^= (uint64_t)mask64;
209+
msg = mempcpy(msg, &v, sizeof(v));
210+
msg_len -= 8;
211+
} while (msg_len >= 8);
192212
}
193213
}
194214

195-
const size_t len32 = (size_t)((msg_end - msg) / 4);
196-
for (size_t i = 0; i < len32; i++) {
197-
uint32_t v = string_as_uint32(msg);
198-
v ^= (uint32_t)mask32;
199-
msg = mempcpy(msg, &v, sizeof(v));
215+
if (msg_len >= 4) {
216+
const uint32_t mask32 = string_as_uint32(mask);
217+
do {
218+
uint32_t v = string_as_uint32(msg);
219+
v ^= (uint32_t)mask32;
220+
msg = mempcpy(msg, &v, sizeof(v));
221+
msg_len -= 4;
222+
} while (msg_len >= 4);
200223
}
201224

202-
switch (msg_end - msg) {
225+
switch (msg_len) {
203226
case 3:
204227
msg[2] ^= mask[2]; /* fallthrough */
205228
case 2:
206229
msg[1] ^= mask[1]; /* fallthrough */
207230
case 1:
208231
msg[0] ^= mask[0];
232+
break;
233+
default:
234+
__builtin_unreachable();
209235
}
236+
#undef MASK32_SET
210237
}
211238

212239
static void send_websocket_pong(struct lwan_request *request, uint16_t header)

0 commit comments

Comments
 (0)