From 5d78214557f043ee135ef898f4738a93bbcbe525 Mon Sep 17 00:00:00 2001 From: Pieter Noordhuis Date: Wed, 20 Apr 2011 13:15:58 +0200 Subject: [PATCH] First pass at making the protocol reader properly handle OOM --- hiredis.c | 124 +++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 94 insertions(+), 30 deletions(-) diff --git a/hiredis.c b/hiredis.c index f2135ba..f9e8488 100644 --- a/hiredis.c +++ b/hiredis.c @@ -41,6 +41,10 @@ #include "sds.h" #include "util.h" +#define REDIS_READER_OOM -2 +#define REDIS_READER_NEED_MORE_DATA -1 +#define REDIS_READER_OK 0 + typedef struct redisReader { struct redisReplyObjectFunctions *fn; sds error; /* holds optional error */ @@ -62,7 +66,8 @@ static void *createIntegerObject(const redisReadTask *task, long long value); static void *createNilObject(const redisReadTask *task); static void redisSetReplyReaderError(redisReader *r, sds err); -/* Default set of functions to build the reply. */ +/* Default set of functions to build the reply. Keep in mind that such a + * function returning NULL is interpreted as OOM. */ static redisReplyObjectFunctions defaultFunctions = { createStringObject, createArrayObject, @@ -73,9 +78,11 @@ static redisReplyObjectFunctions defaultFunctions = { /* Create a reply object */ static redisReply *createReplyObject(int type) { - redisReply *r = malloc(sizeof(*r)); + redisReply *r = calloc(1,sizeof(*r)); + + if (r == NULL) + return NULL; - if (!r) redisOOM(); r->type = type; return r; } @@ -89,35 +96,49 @@ void freeReplyObject(void *reply) { case REDIS_REPLY_INTEGER: break; /* Nothing to free */ case REDIS_REPLY_ARRAY: - for (j = 0; j < r->elements; j++) - if (r->element[j]) freeReplyObject(r->element[j]); - free(r->element); + if (r->elements > 0 && r->element != NULL) { + for (j = 0; j < r->elements; j++) + if (r->element[j] != NULL) + freeReplyObject(r->element[j]); + free(r->element); + } break; case REDIS_REPLY_ERROR: case REDIS_REPLY_STATUS: case REDIS_REPLY_STRING: - free(r->str); + if (r->str != NULL) + free(r->str); break; } free(r); } static void *createStringObject(const redisReadTask *task, char *str, size_t len) { - redisReply *r = createReplyObject(task->type); - char *value = malloc(len+1); - if (!value) redisOOM(); - assert(task->type == REDIS_REPLY_ERROR || + redisReply *r, *parent; + char *buf; + + r = createReplyObject(task->type); + if (r == NULL) + return NULL; + + buf = malloc(len+1); + if (buf == NULL) { + freeReplyObject(r); + return NULL; + } + + assert(task->type == REDIS_REPLY_ERROR || task->type == REDIS_REPLY_STATUS || task->type == REDIS_REPLY_STRING); /* Copy string value */ - memcpy(value,str,len); - value[len] = '\0'; - r->str = value; + memcpy(buf,str,len); + buf[len] = '\0'; + r->str = buf; r->len = len; if (task->parent) { - redisReply *parent = task->parent->obj; + parent = task->parent->obj; assert(parent->type == REDIS_REPLY_ARRAY); parent->element[task->idx] = r; } @@ -125,12 +146,22 @@ static void *createStringObject(const redisReadTask *task, char *str, size_t len } static void *createArrayObject(const redisReadTask *task, int elements) { - redisReply *r = createReplyObject(REDIS_REPLY_ARRAY); + redisReply *r, *parent; + + r = createReplyObject(REDIS_REPLY_ARRAY); + if (r == NULL) + return NULL; + + r->element = calloc(elements,sizeof(redisReply*)); + if (r->element == NULL) { + freeReplyObject(r); + return NULL; + } + r->elements = elements; - if ((r->element = calloc(sizeof(redisReply*),elements)) == NULL) - redisOOM(); + if (task->parent) { - redisReply *parent = task->parent->obj; + parent = task->parent->obj; assert(parent->type == REDIS_REPLY_ARRAY); parent->element[task->idx] = r; } @@ -138,10 +169,16 @@ static void *createArrayObject(const redisReadTask *task, int elements) { } static void *createIntegerObject(const redisReadTask *task, long long value) { - redisReply *r = createReplyObject(REDIS_REPLY_INTEGER); + redisReply *r, *parent; + + r = createReplyObject(REDIS_REPLY_INTEGER); + if (r == NULL) + return NULL; + r->integer = value; + if (task->parent) { - redisReply *parent = task->parent->obj; + parent = task->parent->obj; assert(parent->type == REDIS_REPLY_ARRAY); parent->element[task->idx] = r; } @@ -149,9 +186,14 @@ static void *createIntegerObject(const redisReadTask *task, long long value) { } static void *createNilObject(const redisReadTask *task) { - redisReply *r = createReplyObject(REDIS_REPLY_NIL); + redisReply *r, *parent; + + r = createReplyObject(REDIS_REPLY_NIL); + if (r == NULL) + return NULL; + if (task->parent) { - redisReply *parent = task->parent->obj; + parent = task->parent->obj; assert(parent->type == REDIS_REPLY_ARRAY); parent->element[task->idx] = r; } @@ -284,12 +326,16 @@ static int processLineItem(redisReader *r) { obj = (void*)(size_t)(cur->type); } + if (obj == NULL) + return REDIS_READER_OOM; + /* Set reply if this is the root object. */ if (r->ridx == 0) r->reply = obj; moveToNextTask(r); - return 0; + return REDIS_READER_OK; } - return -1; + + return REDIS_READER_NEED_MORE_DATA; } static int processBulkItem(redisReader *r) { @@ -328,15 +374,19 @@ static int processBulkItem(redisReader *r) { /* Proceed when obj was created. */ if (success) { + if (obj == NULL) + return REDIS_READER_OOM; + r->pos += bytelen; /* Set reply if this is the root object. */ if (r->ridx == 0) r->reply = obj; moveToNextTask(r); - return 0; + return REDIS_READER_OK; } } - return -1; + + return REDIS_READER_NEED_MORE_DATA; } static int processMultiBulkItem(redisReader *r) { @@ -362,6 +412,10 @@ static int processMultiBulkItem(redisReader *r) { obj = r->fn->createNil(cur); else obj = (void*)REDIS_REPLY_NIL; + + if (obj == NULL) + return REDIS_READER_OOM; + moveToNextTask(r); } else { if (r->fn && r->fn->createArray) @@ -369,6 +423,9 @@ static int processMultiBulkItem(redisReader *r) { else obj = (void*)REDIS_REPLY_ARRAY; + if (obj == NULL) + return REDIS_READER_OOM; + /* Modify task stack when there are more than 0 elements. */ if (elements > 0) { cur->elements = elements; @@ -387,9 +444,10 @@ static int processMultiBulkItem(redisReader *r) { /* Set reply if this is the root object. */ if (root) r->reply = obj; - return 0; + return REDIS_READER_OK; } - return -1; + + return REDIS_READER_NEED_MORE_DATA; } static int processItem(redisReader *r) { @@ -534,6 +592,8 @@ void redisReplyReaderFeed(void *reader, const char *buf, size_t len) { int redisReplyReaderGetReply(void *reader, void **reply) { redisReader *r = reader; + int ret = REDIS_READER_OK; + if (reply != NULL) *reply = NULL; /* When the buffer is empty, there will never be a reply. */ @@ -553,9 +613,13 @@ int redisReplyReaderGetReply(void *reader, void **reply) { /* Process items in reply. */ while (r->ridx >= 0) - if (processItem(r) < 0) + if ((ret = processItem(r)) != REDIS_READER_OK) break; + /* Set errors on OOM. */ + if (ret == REDIS_READER_OOM) + return REDIS_ERR; + /* Discard part of the buffer when we've consumed at least 1k, to avoid * doing unnecessary calls to memmove() in sds.c. */ if (r->pos >= 1024) {