First pass at making the protocol reader properly handle OOM

This commit is contained in:
Pieter Noordhuis 2011-04-20 13:15:58 +02:00
parent 178024244d
commit 5d78214557

124
hiredis.c
View File

@ -41,6 +41,10 @@
#include "sds.h" #include "sds.h"
#include "util.h" #include "util.h"
#define REDIS_READER_OOM -2
#define REDIS_READER_NEED_MORE_DATA -1
#define REDIS_READER_OK 0
typedef struct redisReader { typedef struct redisReader {
struct redisReplyObjectFunctions *fn; struct redisReplyObjectFunctions *fn;
sds error; /* holds optional error */ sds error; /* holds optional error */
@ -62,7 +66,8 @@ static void *createIntegerObject(const redisReadTask *task, long long value);
static void *createNilObject(const redisReadTask *task); static void *createNilObject(const redisReadTask *task);
static void redisSetReplyReaderError(redisReader *r, sds err); static void redisSetReplyReaderError(redisReader *r, sds err);
/* Default set of functions to build the reply. */ /* Default set of functions to build the reply. Keep in mind that such a
* function returning NULL is interpreted as OOM. */
static redisReplyObjectFunctions defaultFunctions = { static redisReplyObjectFunctions defaultFunctions = {
createStringObject, createStringObject,
createArrayObject, createArrayObject,
@ -73,9 +78,11 @@ static redisReplyObjectFunctions defaultFunctions = {
/* Create a reply object */ /* Create a reply object */
static redisReply *createReplyObject(int type) { static redisReply *createReplyObject(int type) {
redisReply *r = malloc(sizeof(*r)); redisReply *r = calloc(1,sizeof(*r));
if (r == NULL)
return NULL;
if (!r) redisOOM();
r->type = type; r->type = type;
return r; return r;
} }
@ -89,35 +96,49 @@ void freeReplyObject(void *reply) {
case REDIS_REPLY_INTEGER: case REDIS_REPLY_INTEGER:
break; /* Nothing to free */ break; /* Nothing to free */
case REDIS_REPLY_ARRAY: case REDIS_REPLY_ARRAY:
for (j = 0; j < r->elements; j++) if (r->elements > 0 && r->element != NULL) {
if (r->element[j]) freeReplyObject(r->element[j]); for (j = 0; j < r->elements; j++)
free(r->element); if (r->element[j] != NULL)
freeReplyObject(r->element[j]);
free(r->element);
}
break; break;
case REDIS_REPLY_ERROR: case REDIS_REPLY_ERROR:
case REDIS_REPLY_STATUS: case REDIS_REPLY_STATUS:
case REDIS_REPLY_STRING: case REDIS_REPLY_STRING:
free(r->str); if (r->str != NULL)
free(r->str);
break; break;
} }
free(r); free(r);
} }
static void *createStringObject(const redisReadTask *task, char *str, size_t len) { static void *createStringObject(const redisReadTask *task, char *str, size_t len) {
redisReply *r = createReplyObject(task->type); redisReply *r, *parent;
char *value = malloc(len+1); char *buf;
if (!value) redisOOM();
assert(task->type == REDIS_REPLY_ERROR || r = createReplyObject(task->type);
if (r == NULL)
return NULL;
buf = malloc(len+1);
if (buf == NULL) {
freeReplyObject(r);
return NULL;
}
assert(task->type == REDIS_REPLY_ERROR ||
task->type == REDIS_REPLY_STATUS || task->type == REDIS_REPLY_STATUS ||
task->type == REDIS_REPLY_STRING); task->type == REDIS_REPLY_STRING);
/* Copy string value */ /* Copy string value */
memcpy(value,str,len); memcpy(buf,str,len);
value[len] = '\0'; buf[len] = '\0';
r->str = value; r->str = buf;
r->len = len; r->len = len;
if (task->parent) { if (task->parent) {
redisReply *parent = task->parent->obj; parent = task->parent->obj;
assert(parent->type == REDIS_REPLY_ARRAY); assert(parent->type == REDIS_REPLY_ARRAY);
parent->element[task->idx] = r; parent->element[task->idx] = r;
} }
@ -125,12 +146,22 @@ static void *createStringObject(const redisReadTask *task, char *str, size_t len
} }
static void *createArrayObject(const redisReadTask *task, int elements) { static void *createArrayObject(const redisReadTask *task, int elements) {
redisReply *r = createReplyObject(REDIS_REPLY_ARRAY); redisReply *r, *parent;
r = createReplyObject(REDIS_REPLY_ARRAY);
if (r == NULL)
return NULL;
r->element = calloc(elements,sizeof(redisReply*));
if (r->element == NULL) {
freeReplyObject(r);
return NULL;
}
r->elements = elements; r->elements = elements;
if ((r->element = calloc(sizeof(redisReply*),elements)) == NULL)
redisOOM();
if (task->parent) { if (task->parent) {
redisReply *parent = task->parent->obj; parent = task->parent->obj;
assert(parent->type == REDIS_REPLY_ARRAY); assert(parent->type == REDIS_REPLY_ARRAY);
parent->element[task->idx] = r; parent->element[task->idx] = r;
} }
@ -138,10 +169,16 @@ static void *createArrayObject(const redisReadTask *task, int elements) {
} }
static void *createIntegerObject(const redisReadTask *task, long long value) { static void *createIntegerObject(const redisReadTask *task, long long value) {
redisReply *r = createReplyObject(REDIS_REPLY_INTEGER); redisReply *r, *parent;
r = createReplyObject(REDIS_REPLY_INTEGER);
if (r == NULL)
return NULL;
r->integer = value; r->integer = value;
if (task->parent) { if (task->parent) {
redisReply *parent = task->parent->obj; parent = task->parent->obj;
assert(parent->type == REDIS_REPLY_ARRAY); assert(parent->type == REDIS_REPLY_ARRAY);
parent->element[task->idx] = r; parent->element[task->idx] = r;
} }
@ -149,9 +186,14 @@ static void *createIntegerObject(const redisReadTask *task, long long value) {
} }
static void *createNilObject(const redisReadTask *task) { static void *createNilObject(const redisReadTask *task) {
redisReply *r = createReplyObject(REDIS_REPLY_NIL); redisReply *r, *parent;
r = createReplyObject(REDIS_REPLY_NIL);
if (r == NULL)
return NULL;
if (task->parent) { if (task->parent) {
redisReply *parent = task->parent->obj; parent = task->parent->obj;
assert(parent->type == REDIS_REPLY_ARRAY); assert(parent->type == REDIS_REPLY_ARRAY);
parent->element[task->idx] = r; parent->element[task->idx] = r;
} }
@ -284,12 +326,16 @@ static int processLineItem(redisReader *r) {
obj = (void*)(size_t)(cur->type); obj = (void*)(size_t)(cur->type);
} }
if (obj == NULL)
return REDIS_READER_OOM;
/* Set reply if this is the root object. */ /* Set reply if this is the root object. */
if (r->ridx == 0) r->reply = obj; if (r->ridx == 0) r->reply = obj;
moveToNextTask(r); moveToNextTask(r);
return 0; return REDIS_READER_OK;
} }
return -1;
return REDIS_READER_NEED_MORE_DATA;
} }
static int processBulkItem(redisReader *r) { static int processBulkItem(redisReader *r) {
@ -328,15 +374,19 @@ static int processBulkItem(redisReader *r) {
/* Proceed when obj was created. */ /* Proceed when obj was created. */
if (success) { if (success) {
if (obj == NULL)
return REDIS_READER_OOM;
r->pos += bytelen; r->pos += bytelen;
/* Set reply if this is the root object. */ /* Set reply if this is the root object. */
if (r->ridx == 0) r->reply = obj; if (r->ridx == 0) r->reply = obj;
moveToNextTask(r); moveToNextTask(r);
return 0; return REDIS_READER_OK;
} }
} }
return -1;
return REDIS_READER_NEED_MORE_DATA;
} }
static int processMultiBulkItem(redisReader *r) { static int processMultiBulkItem(redisReader *r) {
@ -362,6 +412,10 @@ static int processMultiBulkItem(redisReader *r) {
obj = r->fn->createNil(cur); obj = r->fn->createNil(cur);
else else
obj = (void*)REDIS_REPLY_NIL; obj = (void*)REDIS_REPLY_NIL;
if (obj == NULL)
return REDIS_READER_OOM;
moveToNextTask(r); moveToNextTask(r);
} else { } else {
if (r->fn && r->fn->createArray) if (r->fn && r->fn->createArray)
@ -369,6 +423,9 @@ static int processMultiBulkItem(redisReader *r) {
else else
obj = (void*)REDIS_REPLY_ARRAY; obj = (void*)REDIS_REPLY_ARRAY;
if (obj == NULL)
return REDIS_READER_OOM;
/* Modify task stack when there are more than 0 elements. */ /* Modify task stack when there are more than 0 elements. */
if (elements > 0) { if (elements > 0) {
cur->elements = elements; cur->elements = elements;
@ -387,9 +444,10 @@ static int processMultiBulkItem(redisReader *r) {
/* Set reply if this is the root object. */ /* Set reply if this is the root object. */
if (root) r->reply = obj; if (root) r->reply = obj;
return 0; return REDIS_READER_OK;
} }
return -1;
return REDIS_READER_NEED_MORE_DATA;
} }
static int processItem(redisReader *r) { static int processItem(redisReader *r) {
@ -534,6 +592,8 @@ void redisReplyReaderFeed(void *reader, const char *buf, size_t len) {
int redisReplyReaderGetReply(void *reader, void **reply) { int redisReplyReaderGetReply(void *reader, void **reply) {
redisReader *r = reader; redisReader *r = reader;
int ret = REDIS_READER_OK;
if (reply != NULL) *reply = NULL; if (reply != NULL) *reply = NULL;
/* When the buffer is empty, there will never be a reply. */ /* When the buffer is empty, there will never be a reply. */
@ -553,9 +613,13 @@ int redisReplyReaderGetReply(void *reader, void **reply) {
/* Process items in reply. */ /* Process items in reply. */
while (r->ridx >= 0) while (r->ridx >= 0)
if (processItem(r) < 0) if ((ret = processItem(r)) != REDIS_READER_OK)
break; break;
/* Set errors on OOM. */
if (ret == REDIS_READER_OOM)
return REDIS_ERR;
/* Discard part of the buffer when we've consumed at least 1k, to avoid /* Discard part of the buffer when we've consumed at least 1k, to avoid
* doing unnecessary calls to memmove() in sds.c. */ * doing unnecessary calls to memmove() in sds.c. */
if (r->pos >= 1024) { if (r->pos >= 1024) {