Skip to content

Commit 0d1781c

Browse files
committed
Merge 'Websocket server fixes' from Povilas Kanapickas
This PR fixes several issues in the websocket frame parser. The most important one is that the parser did not properly support partial packets: it would fail if it received part of a websocket frame in one chunk of data and the rest of the frame in a later chunk. The rest of the fixes are general cleanups and fixes for bugs that can only be reproduced when using the websocket parser in client mode (there will be more PRs in this area). Unit tests have been added to better cover this area of the websocket parser. The PR is best reviewed commit by commit.

Closes #2535

* https://github.com/scylladb/seastar:
  websocket: Remove unnecessary condition in frame parsing
  websocket: Fix logic when parsing header
  websocket: Avoid memory copy when full websocket frames are received
  websocket: Fix websocket frame parsing on partial packets
2 parents 6104aee + 0065fd2 commit 0d1781c

File tree

3 files changed

+93
-49
lines changed

3 files changed

+93
-49
lines changed

include/seastar/websocket/server.hh

+4
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class websocket_parser {
135135
sstring _buffer;
136136
std::unique_ptr<frame_header> _header;
137137
uint64_t _payload_length;
138+
uint64_t _consumed_payload_length = 0;
138139
uint32_t _masking_key;
139140
buff_t _result;
140141

@@ -144,6 +145,9 @@ class websocket_parser {
144145
static future<consumption_result_t> stop(buff_t data) {
145146
return make_ready_future<consumption_result_t>(stop_consuming(std::move(data)));
146147
}
148+
uint64_t remaining_payload_length() const {
149+
return _payload_length - _consumed_payload_length;
150+
}
147151

148152
// Removes mask from payload given in p.
149153
void remove_mask(buff_t& p, size_t n) {

src/websocket/server.cc

+59-36
Original file line numberDiff line numberDiff line change
@@ -207,23 +207,27 @@ future<websocket_parser::consumption_result_t> websocket_parser::operator()(
207207
}
208208
if (_state == parsing_state::flags_and_payload_data) {
209209
if (_buffer.length() + data.size() >= 2) {
210-
if (_buffer.length() < 2) {
211-
size_t hlen = _buffer.length();
212-
_buffer.append(data.get(), 2 - hlen);
213-
data.trim_front(2 - hlen);
214-
_header = std::make_unique<frame_header>(_buffer.data());
215-
_buffer = {};
216-
217-
// https://datatracker.ietf.org/doc/html/rfc6455#section-5.1
218-
// We must close the connection if data isn't masked.
219-
if ((!_header->masked) ||
220-
// RSVX must be 0
221-
(_header->rsv1 | _header->rsv2 | _header->rsv3) ||
222-
// Opcode must be known.
223-
(!_header->is_opcode_known())) {
224-
_cstate = connection_state::error;
225-
return websocket_parser::stop(std::move(data));
226-
}
210+
// _buffer.length() is less than 2 when entering this if body due to how
211+
// the rest of code is structured. The else branch will never increase
212+
// _buffer.length() to >=2 and other paths to this condition will always
213+
// have buffer cleared.
214+
assert(_buffer.length() < 2);
215+
216+
size_t hlen = _buffer.length();
217+
_buffer.append(data.get(), 2 - hlen);
218+
data.trim_front(2 - hlen);
219+
_header = std::make_unique<frame_header>(_buffer.data());
220+
_buffer = {};
221+
222+
// https://datatracker.ietf.org/doc/html/rfc6455#section-5.1
223+
// We must close the connection if data isn't masked.
224+
if ((!_header->masked) ||
225+
// RSVX must be 0
226+
(_header->rsv1 | _header->rsv2 | _header->rsv3) ||
227+
// Opcode must be known.
228+
(!_header->is_opcode_known())) {
229+
_cstate = connection_state::error;
230+
return websocket_parser::stop(std::move(data));
227231
}
228232
_state = parsing_state::payload_length_and_mask;
229233
} else {
@@ -238,35 +242,54 @@ future<websocket_parser::consumption_result_t> websocket_parser::operator()(
238242
size_t hlen = _buffer.length();
239243
_buffer.append(data.get(), required_bytes - hlen);
240244
data.trim_front(required_bytes - hlen);
241-
242-
_payload_length = _header->length;
243-
char const *input = _buffer.data();
244-
if (_header->length == 126) {
245-
_payload_length = consume_be<uint16_t>(input);
246-
} else if (_header->length == 127) {
247-
_payload_length = consume_be<uint64_t>(input);
248-
}
249-
250-
_masking_key = consume_be<uint32_t>(input);
251-
_buffer = {};
252245
}
246+
_payload_length = _header->length;
247+
char const *input = _buffer.data();
248+
if (_header->length == 126) {
249+
_payload_length = consume_be<uint16_t>(input);
250+
} else if (_header->length == 127) {
251+
_payload_length = consume_be<uint64_t>(input);
252+
}
253+
254+
_masking_key = consume_be<uint32_t>(input);
255+
_buffer = {};
253256
_state = parsing_state::payload;
254257
} else {
255258
_buffer.append(data.get(), data.size());
256259
return websocket_parser::dont_stop();
257260
}
258261
}
259262
if (_state == parsing_state::payload) {
260-
if (_payload_length > data.size()) {
261-
_payload_length -= data.size();
262-
remove_mask(data, data.size());
263-
_result = std::move(data);
264-
return websocket_parser::stop(buff_t(0));
263+
if (data.size() < remaining_payload_length()) {
264+
// data has insufficient data to complete the frame - consume data.size() bytes
265+
if (_result.empty()) {
266+
_result = temporary_buffer<char>(remaining_payload_length());
267+
_consumed_payload_length = 0;
268+
}
269+
std::copy(data.begin(), data.end(), _result.get_write() + _consumed_payload_length);
270+
_consumed_payload_length += data.size();
271+
return websocket_parser::dont_stop();
265272
} else {
266-
_result = data.clone();
273+
// data has sufficient data to complete the frame - consume remaining_payload_length()
274+
auto consumed_bytes = remaining_payload_length();
275+
if (_result.empty()) {
276+
// Try to avoid memory copies in case when network packets contain one or more full
277+
// websocket frames.
278+
if (consumed_bytes == data.size()) {
279+
_result = std::move(data);
280+
data = temporary_buffer<char>(0);
281+
} else {
282+
_result = data.share();
283+
_result.trim(consumed_bytes);
284+
data.trim_front(consumed_bytes);
285+
}
286+
} else {
287+
std::copy(data.begin(), data.begin() + consumed_bytes,
288+
_result.get_write() + _consumed_payload_length);
289+
data.trim_front(consumed_bytes);
290+
}
267291
remove_mask(_result, _payload_length);
268-
data.trim_front(_payload_length);
269-
_payload_length = 0;
292+
_consumed_payload_length = 0;
270293
_state = parsing_state::flags_and_payload_data;
271294
return websocket_parser::stop(std::move(data));
272295
}

tests/unit/websocket_test.cc

+30-13
Original file line numberDiff line numberDiff line change
@@ -136,19 +136,36 @@ SEASTAR_TEST_CASE(test_websocket_handler_registration) {
136136
output.flush().get();
137137
input.read_exactly(186).get();
138138

139-
// Sending and receiving a websocket frame
140-
const auto ws_frame = std::string(
141-
"\202\204" // 1000 0002 1000 0100
142-
"TEST" // Masking Key
143-
"\0\0\0\0", 10); // Masked Message - TEST
144-
const auto rs_frame = std::string(
145-
"\202\004" // 1000 0002 0000 0100
146-
"TEST", 6); // Message - TEST
147-
output.write(ws_frame).get();
148-
output.flush().get();
139+
unsigned ws_frame_len = 10;
140+
for (unsigned split_i = 0; split_i < ws_frame_len - 1; ++split_i) {
141+
// The loop tests various combinations of partial websocket frame coming in
142+
143+
// Sending and receiving a websocket frame
144+
const std::string ws_frame = std::string(
145+
"\202\204" // 1000 0002 1000 0100
146+
"TEST" // Masking Key
147+
"\0\0\0\0", ws_frame_len); // Masked Message - TEST
148+
const auto rs_frame = std::string(
149+
"\202\004" // 1000 0002 0000 0100
150+
"TEST", 6); // Message - TEST
151+
152+
if (split_i == 0) {
153+
output.write(ws_frame).get();
154+
output.flush().get();
155+
} else {
156+
output.write(ws_frame.substr(0, split_i)).get();
157+
output.flush().get();
149158

150-
auto response = input.read_exactly(6).get();
151-
auto response_str = std::string(response.begin(), response.end());
152-
BOOST_REQUIRE_EQUAL(rs_frame, response_str);
159+
// ensure that server attempts to read before the second part of the frame lands
160+
sleep(std::chrono::milliseconds(100)).get();
161+
162+
output.write(ws_frame.substr(split_i)).get();
163+
output.flush().get();
164+
}
165+
166+
auto response = input.read_exactly(6).get();
167+
auto response_str = std::string(response.begin(), response.end());
168+
BOOST_REQUIRE_EQUAL(rs_frame, response_str);
169+
}
153170
});
154171
}

0 commit comments

Comments
 (0)