diff --git a/src/async.c b/src/async.c index 522fb55d..c930a43e 100644 --- a/src/async.c +++ b/src/async.c @@ -46,6 +46,7 @@ #include "async_private.h" #include "net.h" #include "valkey_private.h" +#include "vkutil.h" #include #include @@ -61,6 +62,7 @@ typedef struct { sds command; + size_t len; valkeyCallbackFn *user_callback; void *user_priv_data; } ssubscribeCallbackData; @@ -808,21 +810,78 @@ void valkeyAsyncHandleTimeout(valkeyAsyncContext *ac) { valkeyAsyncDisconnectInternal(ac); } -/* Sets a pointer to the first argument and its length starting at p. Returns - * the number of bytes to skip to get to the following argument. */ -static const char *nextArgument(const char *start, const char **str, size_t *len) { - const char *p = start; - if (p[0] != '$') { - p = strchr(p, '$'); - if (p == NULL) +static inline int vk_isdigit_ascii(char c) { + return (unsigned)(c - '0') < 10; +} + +#define MAX_BULK_LEN (512ULL * 1024ULL * 1024ULL) +vk_static_assert(MAX_BULK_LEN < (UINT64_MAX - 9U) / 10); +static const char *parseBulkLen(const char *p, const char *end, uint64_t *len) { + uint64_t acc = 0; + + assert(p != NULL && end != NULL && end - p >= 0 && len != NULL); + + if (end == p || !vk_isdigit_ascii(*p)) + return NULL; + + while (p < end && vk_isdigit_ascii(*p)) { + unsigned d = *p - '0'; + + acc = acc * 10 + d; + if (acc > (uint64_t)MAX_BULK_LEN) return NULL; + + p++; } - *len = (int)strtol(p + 1, NULL, 10); - p = strchr(p, '\r'); - assert(p); + *len = acc; + return p; +} + +/* Find the next argument in a command buffer, i.e. find the next bulkstring + * in an array of bulkstrings. + * Returns a pointer to the end of a found argument, which can be used when + * finding following arguments, or NULL when an argument is not found. + * The found string is returned by pointer via `str` and length in `strlen`. */ +static const char *nextArgument(const char *buf, size_t buflen, const char **str, size_t *strlen) { + if (buf == NULL || buflen == 0) + goto error; + + const char *p = buf; + + /* Find a bulkstring identifier. */ + if (p[0] != '$') { + if ((p = memchr(p, '$', buflen)) == NULL) + goto error; + } + p++; /* Skip found '$' */ + + uint64_t len; + + p = parseBulkLen(p, buf + buflen, &len); + if (p == NULL) + goto error; + + /* Calculate end pointer for \r\n\r\n */ + const char *end = p + 2 + len + 2; + + /* Validate the parsed length and field separators. */ + if ((size_t)(end - buf) > buflen || p[0] != '\r' || p[len + 2] != '\r') + goto error; + + /* Return pointer to the string, length, and pointer to next element. */ *str = p + 2; - return p + 2 + (*len) + 2; + *strlen = len; + + if ((size_t)(end - buf) == buflen) /* No more data in buffer? */ + return NULL; + + return end; + +error: + *str = NULL; + *strlen = 0; + return NULL; } void valkeySsubscribeCallback(struct valkeyAsyncContext *ac, void *reply, void *privdata) { @@ -846,8 +905,8 @@ void valkeySsubscribeCallback(struct valkeyAsyncContext *ac, void *reply, void * assert(r != NULL); if (r->type == VALKEY_REPLY_ERROR) { /*/ On CROSSSLOT, MOVED and other errors */ - p = nextArgument(data->command, &cstr, &clen); - while ((p = nextArgument(p, &astr, &alen)) != NULL) { + p = nextArgument(data->command, data->len, &cstr, &clen); + while ((p = nextArgument(p, data->len - (p - data->command), &astr, &alen)) != NULL || astr != NULL) { sname = sdsnewlen(astr, alen); if (sname == NULL) goto oom; @@ -865,8 +924,8 @@ void valkeySsubscribeCallback(struct valkeyAsyncContext *ac, void *reply, void * } } else { if ((r->type == VALKEY_REPLY_ARRAY || r->type == VALKEY_REPLY_PUSH) && strncasecmp(r->element[0]->str, "ssubscribe", 10) == 0) { - p = nextArgument(data->command, &cstr, &clen); - while ((p = nextArgument(p, &astr, &alen)) != NULL) { + p = nextArgument(data->command, data->len, &cstr, &clen); + while ((p = nextArgument(p, data->len - (p - data->command), &astr, &alen)) != NULL || astr != NULL) { sname = sdsnewlen(astr, alen); if (sname == NULL) goto oom; @@ -921,29 +980,31 @@ static int valkeyAsyncAppendCmdLen(valkeyAsyncContext *ac, valkeyCallbackFn *fn, if (c->flags & (VALKEY_DISCONNECTING | VALKEY_FREEING)) return VALKEY_ERR; - /* Setup callback */ - cb.fn = fn; - cb.privdata = privdata; - cb.pending_subs = 1; - cb.unsubscribe_sent = 0; - cb.subscribed = 0; + /* Get the first string in the command, and don't accept empty commands. */ + p = nextArgument(cmd, len, &cstr, &clen); + if (cstr == NULL) + return VALKEY_ERR; - /* Find out which command will be appended. */ - p = nextArgument(cmd, &cstr, &clen); - assert(p != NULL); - hasnext = (p[0] == '$'); + hasnext = (p && (p[0] == '$')); pvariant = (tolower(cstr[0]) == 'p') ? 1 : 0; svariant = valkeyIsShardedVariant(cstr); hasprefix = svariant || pvariant; cstr += hasprefix; clen -= hasprefix; + /* Setup callback */ + cb.fn = fn; + cb.privdata = privdata; + cb.pending_subs = 1; + cb.unsubscribe_sent = 0; + cb.subscribed = 0; + if (hasnext && strncasecmp(cstr, "subscribe\r\n", 11) == 0) { int was_subscribed = c->flags & VALKEY_SUBSCRIBED; c->flags |= VALKEY_SUBSCRIBED; /* Add every channel/pattern to the list of subscription callbacks. */ - while ((p = nextArgument(p, &astr, &alen)) != NULL) { + while ((p = nextArgument(p, len - (p - cmd), &astr, &alen)) != NULL || astr != NULL) { sname = sdsnewlen(astr, alen); if (sname == NULL) goto oom; @@ -977,15 +1038,12 @@ static int valkeyAsyncAppendCmdLen(valkeyAsyncContext *ac, valkeyCallbackFn *fn, if (ssubscribe_data == NULL) goto oom; - /* copy command to iterate over all channels. - * actual length of cmd is actually len + 1 (see valkeyvFormatCommand). - * last byte important in nextArgument function. - */ - ssubscribe_data->command = vk_malloc(len + 1); + /* Copy command to iterate over all channels. */ + ssubscribe_data->command = vk_malloc(len); if (ssubscribe_data->command == NULL) goto oom; - memcpy(ssubscribe_data->command, cmd, len + 1); - + memcpy(ssubscribe_data->command, cmd, len); + ssubscribe_data->len = len; ssubscribe_data->user_callback = fn; ssubscribe_data->user_priv_data = privdata; @@ -1015,7 +1073,7 @@ static int valkeyAsyncAppendCmdLen(valkeyAsyncContext *ac, valkeyCallbackFn *fn, if (hasnext) { /* Send an unsubscribe with specific channels/patterns. * Bookkeeping the number of expected replies */ - while ((p = nextArgument(p, &astr, &alen)) != NULL) { + while ((p = nextArgument(p, len - (p - cmd), &astr, &alen)) != NULL || astr != NULL) { sname = sdsnewlen(astr, alen); if (sname == NULL) goto oom; diff --git a/tests/client_test.c b/tests/client_test.c index 6f14885c..4950fb40 100644 --- a/tests/client_test.c +++ b/tests/client_test.c @@ -1884,6 +1884,53 @@ void null_cb(valkeyAsyncContext *ac, void *r, void *privdata) { state->checkpoint++; } +/* Test the command parsing, required for pub/sub in the async API. */ +void test_async_command_parsing(struct config config) { + test("Async command parsing: "); + valkeyOptions options = get_server_tcp_options(config); + valkeyAsyncContext *ac = valkeyAsyncConnectWithOptions(&options); + assert(ac); + + /* Null ptr. */ + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, NULL, 45)); + /* Empty command. */ + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, "", 0)); + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, " $", 2)); + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*0\r\n", 4)); + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*1\r\n$-1\r\n", 9)); + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*1\r\n$-1\r\nUNSUBSCRIBE\r\n", 22)); + /* Protocol error: erroneous bulkstring length and data. */ + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*1\r\n$100000", 11)); + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*1\r\n$100000\r", 12)); + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*1\r\n$100000\r\n", 13)); + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*1\r\n$10HELP\r\n\r\n", 15)); + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*1\r\n$100000\r\nTO-SHORT\r\n", 23)); + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*1\r\n$1\r\nTO-LONG\r\n", 17)); + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*1\r\n$123456789\r\n", 11)); + + /* Faulty length given to function. */ + for (int i = 0; i < 19; i++) { + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*2\r\n$9\r\nSUBSCRIBE\r\n$7\r\nCHANNEL\r\n", i)); + } + for (int i = 0; i < 21; i++) { + assert(VALKEY_ERR == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*1\r\n$11\r\nUNSUBSCRIBE\r\n", i)); + } + + /* Complete command. */ + assert(VALKEY_OK == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*2\r\n$9\r\nSUBSCRIBE\r\n$7\r\nCHANNEL\r\n", 32)); + assert(VALKEY_OK == valkeyAsyncFormattedCommand(ac, NULL, NULL, "*1\r\n$11\r\nUNSUBSCRIBE\r\n", 22)); + + // Heap allocate command without NULL terminator. + const char ping[] = "*1\r\n$4\r\nPING\r\n"; + size_t len = sizeof(ping) - 1; + char *buf = vk_malloc_safe(len); + memcpy(buf, ping, len); + assert(VALKEY_OK == valkeyAsyncFormattedCommand(ac, NULL, NULL, buf, len)); + free(buf); + + valkeyAsyncFree(ac); +} + static void test_pubsub_handling(struct config config) { test("Subscribe, handle published message and unsubscribe: "); /* Setup event dispatcher with a testcase timeout */ @@ -2817,6 +2864,7 @@ int main(int argc, char **argv) { get_server_version(c, &major, NULL); disconnect(c, 0); + test_async_command_parsing(cfg); test_pubsub_handling(cfg); test_pubsub_multiple_channels(cfg); test_monitor(cfg);