Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions components/tcp_transport/include/esp_transport_ws.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ typedef enum ws_transport_opcodes {
WS_TRANSPORT_OPCODES_CLOSE = 0x08,
WS_TRANSPORT_OPCODES_PING = 0x09,
WS_TRANSPORT_OPCODES_PONG = 0x0a,
WS_TRANSPORT_OPCODES_COMPRESSED = 0x40,
WS_TRANSPORT_OPCODES_FIN = 0x80,
WS_TRANSPORT_OPCODES_NONE = 0x100, /*!< not a valid opcode to indicate no message previously received
* from the API esp_transport_ws_get_read_opcode() */
Expand All @@ -48,6 +49,13 @@ typedef struct {
* If false, only user frames are propagated, control frames are handled
* automatically during read operations
*/
bool per_msg_compress; /*!< Hint the server to enable per-message compression (RFC7692) */
int per_msg_client_deflate_window_bit; /*!< Hint the server Per-message deflate window bit 8 to 15; or leave 0 to let server decide */
int per_msg_server_deflate_window_bit; /*!< Hint the server Per-message deflate window bit 8 to 15; or leave 0 to let server decide */
bool per_msg_server_no_ctx_takeover; /*!< Hint the server to reset the compression stream on every WS frame on server side
* True for a safer transfer, false for better performance */
bool per_msg_client_no_ctx_takeover; /*!< Hint the server to reset the compression stream on every WS frame on client side
* True for a safer transfer, false for better performance */
} esp_transport_ws_config_t;

/**
Expand Down Expand Up @@ -184,6 +192,78 @@ int esp_transport_ws_send_raw(esp_transport_handle_t t, ws_transport_opcodes_t o
*/
bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t);

/**
* @brief Returns the RSV1 flag (permessage-deflate) of the last read frame
*
* @param[in] t The transport handle
*
* @return
* - true if the last read frame was compressed
* - false otherwise
*/
bool esp_transport_ws_get_rsv1_flag(esp_transport_handle_t t);

/**
* @brief Get per-message compression flag
*
* @param[in] t The transport handle
*
* @return
* - true if per-message compression is enabled
* - false if per-message compression is disabled
*/
bool esp_transport_ws_get_per_msg_compress(esp_transport_handle_t t);

/**
* @brief Get client deflate window bit for per-message compression
*
* @param[in] t The transport handle
*
* @return
* - client deflate window bit
*/
int esp_transport_ws_get_per_msg_client_deflate_window_bit(esp_transport_handle_t t);

/**
* @brief Get server deflate window bit for per-message compression
*
* @param[in] t The transport handle
*
* @return
* - server deflate window bit
*/
int esp_transport_ws_get_per_msg_server_deflate_window_bit(esp_transport_handle_t t);

/**
* @brief Get server no context takeover flag for per-message compression
*
* If this is returned to be true, then the server-to-client's compression handle should be reset
* on every frame transfer. If this is false, then the server-to-client's compression handle
* should not be reset over the lifespan of this esp_transport_handle_t.
*
* @param[in] t The transport handle
*
* @return
* - true if server no context takeover is enabled
* - false if server no context takeover is disabled
*/
bool esp_transport_ws_get_per_msg_server_no_ctx_takeover(esp_transport_handle_t t);

/**
* @brief Get client no context takeover flag for per-message compression
*
* If this is returned to be true, then the client-to-server's compression handle should be reset
* on every frame transfer. If this is false, then the client-to-server's compression handle
* should not be reset over the lifespan of this esp_transport_handle_t.
*
* @param[in] t The transport handle
*
* @return
* - true if client no context takeover is enabled
* - false if client no context takeover is disabled
*/
bool esp_transport_ws_get_per_msg_client_no_ctx_takeover(esp_transport_handle_t t);

/**
* @brief Returns the HTTP status code of the websocket handshake
*
Expand Down
189 changes: 185 additions & 4 deletions components/tcp_transport/transport_ws.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ static const char *TAG = "transport_ws";

#define WS_BUFFER_SIZE CONFIG_WS_BUFFER_SIZE
#define WS_FIN 0x80
#define WS_COMPRESSED 0x40
#define WS_OPCODE_CONT 0x00
#define WS_OPCODE_TEXT 0x01
#define WS_OPCODE_BINARY 0x02
Expand Down Expand Up @@ -56,6 +57,7 @@ typedef struct {
int payload_len; /*!< Total length of the payload */
int bytes_remaining; /*!< Bytes left to read of the payload */
bool header_received; /*!< Flag to indicate that a new message header was received */
bool compressed; /*!< Per-message deflate compress flag (RSV1) */
} ws_transport_frame_state_t;

typedef struct {
Expand All @@ -75,6 +77,11 @@ typedef struct {
char *redir_host;
char *response_header;
size_t response_header_len;
bool per_msg_compress;
int per_msg_client_deflate_window_bit;
int per_msg_server_deflate_window_bit;
bool per_msg_server_no_ctx_takeover;
bool per_msg_client_no_ctx_takeover;
} transport_ws_t;

/**
Expand Down Expand Up @@ -201,6 +208,72 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
#endif

size_t outlen = 0;
char extension_header[168] = { 0 };
if (ws->per_msg_compress) {
int offset = 0;
int ext_ret = snprintf(extension_header, sizeof(extension_header), "Sec-WebSocket-Extensions: permessage-deflate");
if (ext_ret <= 0) {
ESP_LOGE(TAG, "Failed to write header to permessage-deflate");
return -1;
}

offset += ext_ret;

if (ws->per_msg_client_no_ctx_takeover) {
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; client_no_context_takeover");
if (ext_ret <= 0) {
ESP_LOGE(TAG, "Failed to write header to permessage-deflate client_no_context_takeover");
return -1;
}

offset += ext_ret;
}

if (ws->per_msg_server_no_ctx_takeover) {
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; server_no_context_takeover");
if (ext_ret <= 0) {
ESP_LOGE(TAG, "Failed to write header to permessage-deflate server_no_context_takeover");
return -1;
}

offset += ext_ret;
}

// If this is 0 then it means to let server decide the client window bit
if (ws->per_msg_client_deflate_window_bit != 0) {
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; client_max_window_bits=%d", ws->per_msg_client_deflate_window_bit);
} else {
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; client_max_window_bits");
}

if (ext_ret <= 0) {
ESP_LOGE(TAG, "Failed to write header to permessage-deflate client_max_window_bits");
return -1;
}

offset += ext_ret;

// If this is 0 then it means to let server decide the server window bit
if (ws->per_msg_server_deflate_window_bit != 0) {
ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "; server_max_window_bits=%d", ws->per_msg_server_deflate_window_bit);

if (ext_ret <= 0) {
ESP_LOGE(TAG, "Failed to write header to permessage-deflate server_max_window_bits");
return -1;
}

offset += ext_ret;
}

ext_ret = snprintf(extension_header + offset, sizeof(extension_header) - offset, "\r\n");
if (ext_ret <= 0) {
ESP_LOGE(TAG, "Failed to concat permessage-deflate header");
return -1;
}

extension_header[sizeof(extension_header) - 1] = '\0';
}

esp_crypto_base64_encode(client_key, sizeof(client_key), &outlen, random_key, sizeof(random_key));
int len = snprintf(ws->buffer, WS_BUFFER_SIZE,
"GET %s HTTP/1.1\r\n"
Expand All @@ -209,10 +282,12 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
"User-Agent: %s\r\n"
"Upgrade: websocket\r\n"
"Sec-WebSocket-Version: 13\r\n"
"Sec-WebSocket-Key: %s\r\n",
"Sec-WebSocket-Key: %s\r\n"
"%s", // For "Sec-WebSocket-Extensions"
ws->path,
host, port, user_agent_ptr,
client_key);
client_key,
extension_header);
if (len <= 0 || len >= WS_BUFFER_SIZE) {
ESP_LOGE(TAG, "Error in request generation, desired request len: %d, buffer size: %d", len, WS_BUFFER_SIZE);
return -1;
Expand Down Expand Up @@ -306,6 +381,9 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
}
header_cursor += strlen("\r\n");

// If compression was requested, we need to check server response
bool pmd_negotiated = false;

while(header_cursor < delim_ptr){
const char * end_of_line = strnstr(header_cursor, "\r\n", header_len - (header_cursor - ws->buffer));
if(!end_of_line){
Expand All @@ -332,6 +410,53 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
server_key = header_cursor + header_sec_websocket_accept_len;
server_key_len = line_len - header_sec_websocket_accept_len;
}
// Check for Sec-WebSocket-Extensions header
else if (ws->per_msg_compress && line_len >= strlen("Sec-WebSocket-Extensions: ") && !strncasecmp(header_cursor, "Sec-WebSocket-Extensions: ", strlen("Sec-WebSocket-Extensions: "))) {
const char* ext_params = header_cursor + strlen("Sec-WebSocket-Extensions: ");
int ext_params_len = line_len - strlen("Sec-WebSocket-Extensions: ");
ESP_LOGD(TAG, "Found Sec-WebSocket-Extensions: %.*s", ext_params_len, ext_params);

if (strcasestr(ext_params, "permessage-deflate")) {
pmd_negotiated = true;

// Server must agree to context takeover settings
if (!strcasestr(ext_params, "server_no_context_takeover")) {
ws->per_msg_server_no_ctx_takeover = false;
}
if (!strcasestr(ext_params, "client_no_context_takeover")) {
ws->per_msg_client_no_ctx_takeover = false;
}

const char *smwb_str = "server_max_window_bits=";
const char *found = strcasestr(ext_params, smwb_str);
if (found) {
char *endptr;
long smwb = strtol(found + strlen(smwb_str), &endptr, 10);
if (smwb < 8 || smwb > 15) {
ESP_LOGE(TAG, "compression: Server Max Window Bits is invalid: %ld", smwb);
return -1;
}

ws->per_msg_server_deflate_window_bit = (int)smwb;
} else {
ws->per_msg_server_deflate_window_bit = 15;
}

const char *cmwb_str = "client_max_window_bits=";
found = strcasestr(ext_params, cmwb_str);
if (found) {
char *endptr;
long cmwb = strtol(found + strlen(cmwb_str), &endptr, 10);

if (cmwb < 8 || cmwb > 15) {
ESP_LOGE(TAG, "compression: Client Max Window Bits is invalid: %ld", cmwb);
return -1;
}

ws->per_msg_client_deflate_window_bit = (int)cmwb;
}
}
}
else if (ws->header_hook) {
ws->header_hook(ws->header_user_context, header_cursor, line_len);
}
Expand All @@ -349,6 +474,10 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
header_cursor += strlen("\r\n");
}

if (ws->per_msg_compress && !pmd_negotiated) {
ws->per_msg_compress = false;
}

if (WS_HTTP_TEMPORARY_REDIRECT(ws->http_status_code) || WS_HTTP_PERMANENT_REDIRECT(ws->http_status_code)) {
if (location == NULL || location_len <= 0) {
ESP_LOGE(TAG, "Location header not found");
Expand Down Expand Up @@ -575,6 +704,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t
ws->frame_state.header_received = true;
ws->frame_state.fin = (*data_ptr & 0x80) != 0;
ws->frame_state.opcode = (*data_ptr & 0x0F);
ws->frame_state.compressed = (*data_ptr & 0x40) != 0; // RSV1 bit in the header
data_ptr ++;
mask = ((*data_ptr >> 7) & 0x01);
payload_len = (*data_ptr & 0x7F);
Expand Down Expand Up @@ -979,14 +1109,65 @@ esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transp
}

ws->propagate_control_frames = config->propagate_control_frames;
ws->per_msg_compress = config->per_msg_compress;
ws->per_msg_client_no_ctx_takeover = config->per_msg_client_no_ctx_takeover;
ws->per_msg_server_no_ctx_takeover = config->per_msg_server_no_ctx_takeover;

if (config->per_msg_client_deflate_window_bit < 8 || config->per_msg_client_deflate_window_bit > 15) {
ws->per_msg_client_deflate_window_bit = 0;
} else {
ws->per_msg_client_deflate_window_bit = config->per_msg_client_deflate_window_bit;
}

if (config->per_msg_server_deflate_window_bit < 8 || config->per_msg_server_deflate_window_bit > 15) {
ws->per_msg_server_deflate_window_bit = 0;
} else {
ws->per_msg_server_deflate_window_bit = config->per_msg_server_deflate_window_bit;
}

return err;
}

bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t)
{
transport_ws_t *ws = esp_transport_get_context_data(t);
return ws->frame_state.fin;
transport_ws_t *ws = esp_transport_get_context_data(t);
return ws->frame_state.fin;
}

bool esp_transport_ws_get_rsv1_flag(esp_transport_handle_t t)
{
transport_ws_t *ws = esp_transport_get_context_data(t);
return ws->frame_state.compressed;
}

bool esp_transport_ws_get_per_msg_compress(esp_transport_handle_t t)
{
transport_ws_t *ws = esp_transport_get_context_data(t);
return ws->per_msg_compress;
}

int esp_transport_ws_get_per_msg_client_deflate_window_bit(esp_transport_handle_t t)
{
transport_ws_t *ws = esp_transport_get_context_data(t);
return ws->per_msg_client_deflate_window_bit;
}

int esp_transport_ws_get_per_msg_server_deflate_window_bit(esp_transport_handle_t t)
{
transport_ws_t *ws = esp_transport_get_context_data(t);
return ws->per_msg_server_deflate_window_bit;
}

bool esp_transport_ws_get_per_msg_server_no_ctx_takeover(esp_transport_handle_t t)
{
transport_ws_t *ws = esp_transport_get_context_data(t);
return ws->per_msg_server_no_ctx_takeover && ws->per_msg_compress;
}

bool esp_transport_ws_get_per_msg_client_no_ctx_takeover(esp_transport_handle_t t)
{
transport_ws_t *ws = esp_transport_get_context_data(t);
return ws->per_msg_client_no_ctx_takeover && ws->per_msg_compress;
}

int esp_transport_ws_get_upgrade_request_status(esp_transport_handle_t t)
Expand Down