diff --git a/CMakeLists.txt b/CMakeLists.txt index 7844575..87a4ac6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,14 +24,14 @@ PROJECT(hiredis VERSION "${VERSION}") SET(ENABLE_EXAMPLES OFF CACHE BOOL "Enable building hiredis examples") SET(hiredis_sources + alloc.c async.c dict.c hiredis.c net.c read.c sds.c - sockcompat.c - alloc.c) + sockcompat.c) SET(hiredis_sources ${hiredis_sources}) diff --git a/Makefile b/Makefile index 438ab8f..f19e760 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ # Copyright (C) 2010-2011 Pieter Noordhuis # This file is released under the BSD license, see the COPYING file -OBJ=net.o hiredis.o sds.o async.o read.o sockcompat.o alloc.o +OBJ=alloc.o net.o hiredis.o sds.o async.o read.o sockcompat.o SSL_OBJ=ssl.o EXAMPLES=hiredis-example hiredis-example-libevent hiredis-example-libev hiredis-example-glib ifeq ($(USE_SSL),1) @@ -109,16 +109,16 @@ all: $(SSL_DYLIBNAME) $(SSL_STLIBNAME) $(SSL_PKGCONFNAME) endif # Deps (use make dep to generate this) +alloc.o: alloc.c fmacros.h alloc.h async.o: async.c fmacros.h alloc.h async.h hiredis.h read.h sds.h net.h dict.c dict.h win32.h async_private.h dict.o: dict.c fmacros.h alloc.h dict.h hiredis.o: hiredis.c fmacros.h hiredis.h read.h sds.h alloc.h net.h async.h win32.h -alloc.o: alloc.c alloc.h net.o: net.c fmacros.h net.h hiredis.h read.h sds.h alloc.h sockcompat.h win32.h -read.o: read.c fmacros.h read.h sds.h win32.h -sds.o: sds.c sds.h sdsalloc.h +read.o: read.c fmacros.h alloc.h read.h sds.h win32.h +sds.o: sds.c sds.h sdsalloc.h alloc.h sockcompat.o: sockcompat.c sockcompat.h -ssl.o: ssl.c hiredis.h read.h sds.h alloc.h async.h async_private.h -test.o: test.c fmacros.h hiredis.h read.h sds.h alloc.h net.h +ssl.o: ssl.c hiredis.h read.h sds.h alloc.h async.h win32.h async_private.h +test.o: test.c fmacros.h hiredis.h read.h sds.h alloc.h net.h sockcompat.h win32.h $(DYLIBNAME): $(OBJ) $(DYLIB_MAKE_CMD) -o $(DYLIBNAME) $(OBJ) $(REAL_LDFLAGS) diff --git a/README.md b/README.md index a979110..4e7b228 100644 --- a/README.md +++ b/README.md @@ -404,6 +404,16 @@ This should be done only in order to maximize performances when working with large payloads. The context should be set back to `REDIS_READER_MAX_BUF` again as soon as possible in order to prevent allocation of useless memory. +### Reader max array elements + +By default the hiredis reply parser sets the maximum number of multi-bulk elements +to 2^32 - 1 or 4,294,967,295 entries. If you need to process multi-bulk replies +with more than this many elements you can set the value higher or to zero, meaning +unlimited with: +```c +context->reader->maxelements = 0; +``` + ## SSL/TLS Support ### Building diff --git a/adapters/ae.h b/adapters/ae.h index 0393992..660d82e 100644 --- a/adapters/ae.h +++ b/adapters/ae.h @@ -96,7 +96,7 @@ static void redisAeCleanup(void *privdata) { redisAeEvents *e = (redisAeEvents*)privdata; redisAeDelRead(privdata); redisAeDelWrite(privdata); - free(e); + hi_free(e); } static int redisAeAttach(aeEventLoop *loop, redisAsyncContext *ac) { @@ -109,6 +109,9 @@ static int redisAeAttach(aeEventLoop *loop, redisAsyncContext *ac) { /* Create container for context and r/w events */ e = (redisAeEvents*)hi_malloc(sizeof(*e)); + if (e == NULL) + return REDIS_ERR; + e->context = ac; e->loop = loop; e->fd = c->fd; diff --git a/adapters/glib.h b/adapters/glib.h index e0a6411..ad59dd1 100644 --- a/adapters/glib.h +++ b/adapters/glib.h @@ -134,6 +134,9 @@ redis_source_new (redisAsyncContext *ac) g_return_val_if_fail(ac != NULL, NULL); source = (RedisSource *)g_source_new(&source_funcs, sizeof *source); + if (source == NULL) + return NULL; + source->ac = ac; source->poll_fd.fd = c->fd; source->poll_fd.events = 0; diff --git a/adapters/ivykis.h b/adapters/ivykis.h index 75616ee..179f6ab 100644 --- a/adapters/ivykis.h +++ b/adapters/ivykis.h @@ -43,7 +43,7 @@ static void redisIvykisCleanup(void *privdata) { redisIvykisEvents *e = (redisIvykisEvents*)privdata; iv_fd_unregister(&e->fd); - free(e); + hi_free(e); } static int redisIvykisAttach(redisAsyncContext *ac) { @@ -56,6 +56,9 @@ static int redisIvykisAttach(redisAsyncContext *ac) { /* Create container for context and r/w events */ e = (redisIvykisEvents*)hi_malloc(sizeof(*e)); + if (e == NULL) + return REDIS_ERR; + e->context = ac; /* Register functions to start/stop listening for events */ diff --git a/adapters/libev.h b/adapters/libev.h index 1520923..7057dbd 100644 --- a/adapters/libev.h +++ b/adapters/libev.h @@ -116,7 +116,7 @@ static void redisLibevCleanup(void *privdata) { redisLibevDelRead(privdata); redisLibevDelWrite(privdata); redisLibevStopTimer(privdata); - free(e); + hi_free(e); } static void redisLibevTimeout(EV_P_ ev_timer *timer, int revents) { @@ -149,6 +149,9 @@ static int redisLibevAttach(EV_P_ redisAsyncContext *ac) { /* Create container for context and r/w events */ e = (redisLibevEvents*)hi_calloc(1, sizeof(*e)); + if (e == NULL) + return REDIS_ERR; + e->context = ac; #if EV_MULTIPLICITY e->loop = loop; diff --git a/adapters/libevent.h b/adapters/libevent.h index 0674ca6..9150979 100644 --- a/adapters/libevent.h +++ b/adapters/libevent.h @@ -47,7 +47,7 @@ typedef struct redisLibeventEvents { } redisLibeventEvents; static void redisLibeventDestroy(redisLibeventEvents *e) { - free(e); + hi_free(e); } static void redisLibeventHandler(int fd, short event, void *arg) { @@ -153,6 +153,9 @@ static int redisLibeventAttach(redisAsyncContext *ac, struct event_base *base) { /* Create container for context and r/w events */ e = (redisLibeventEvents*)hi_calloc(1, sizeof(*e)); + if (e == NULL) + return REDIS_ERR; + e->context = ac; /* Register functions to start/stop listening for events */ diff --git a/adapters/libuv.h b/adapters/libuv.h index 7aac127..c120b1b 100644 --- a/adapters/libuv.h +++ b/adapters/libuv.h @@ -73,7 +73,7 @@ static void redisLibuvDelWrite(void *privdata) { static void on_close(uv_handle_t* handle) { redisLibuvEvents* p = (redisLibuvEvents*)handle->data; - free(p); + hi_free(p); } @@ -98,11 +98,9 @@ static int redisLibuvAttach(redisAsyncContext* ac, uv_loop_t* loop) { ac->ev.delWrite = redisLibuvDelWrite; ac->ev.cleanup = redisLibuvCleanup; - redisLibuvEvents* p = (redisLibuvEvents*)malloc(sizeof(*p)); - - if (!p) { - return REDIS_ERR; - } + redisLibuvEvents* p = (redisLibuvEvents*)hi_malloc(sizeof(*p)); + if (p == NULL) + return REDIS_ERR; memset(p, 0, sizeof(*p)); diff --git a/adapters/macosx.h b/adapters/macosx.h index 72121f6..3c87f1b 100644 --- a/adapters/macosx.h +++ b/adapters/macosx.h @@ -27,7 +27,7 @@ static int freeRedisRunLoop(RedisRunLoop* redisRunLoop) { CFSocketInvalidate(redisRunLoop->socketRef); CFRelease(redisRunLoop->socketRef); } - free(redisRunLoop); + hi_free(redisRunLoop); } return REDIS_ERR; } @@ -80,8 +80,9 @@ static int redisMacOSAttach(redisAsyncContext *redisAsyncCtx, CFRunLoopRef runLo /* Nothing should be attached when something is already attached */ if( redisAsyncCtx->ev.data != NULL ) return REDIS_ERR; - RedisRunLoop* redisRunLoop = (RedisRunLoop*) calloc(1, sizeof(RedisRunLoop)); - if( !redisRunLoop ) return REDIS_ERR; + RedisRunLoop* redisRunLoop = (RedisRunLoop*) hi_calloc(1, sizeof(RedisRunLoop)); + if (redisRunLoop == NULL) + return REDIS_ERR; /* Setup redis stuff */ redisRunLoop->context = redisAsyncCtx; diff --git a/alloc.c b/alloc.c index 55c3020..2b19c48 100644 --- a/alloc.c +++ b/alloc.c @@ -31,35 +31,56 @@ #include "fmacros.h" #include "alloc.h" #include +#include + +hiredisAllocFuncs hiredisAllocFns = { + .malloc = malloc, + .calloc = calloc, + .realloc = realloc, + .strdup = strdup, + .free = free, +}; + +/* Override hiredis' allocators with ones supplied by the user */ +hiredisAllocFuncs hiredisSetAllocators(hiredisAllocFuncs *override) { + hiredisAllocFuncs orig = hiredisAllocFns; + + hiredisAllocFns = *override; + + return orig; +} + +/* Reset allocators to use libc defaults */ +void hiredisResetAllocators(void) { + hiredisAllocFns = (hiredisAllocFuncs) { + .malloc = malloc, + .calloc = calloc, + .realloc = realloc, + .strdup = strdup, + .free = free, + }; +} + +#ifdef _WIN32 void *hi_malloc(size_t size) { - void *ptr = malloc(size); - if (ptr == NULL) - HIREDIS_OOM_HANDLER; - - return ptr; + return hiredisAllocFns.malloc(size); } void *hi_calloc(size_t nmemb, size_t size) { - void *ptr = calloc(nmemb, size); - if (ptr == NULL) - HIREDIS_OOM_HANDLER; - - return ptr; + return hiredisAllocFns.calloc(nmemb, size); } void *hi_realloc(void *ptr, size_t size) { - void *newptr = realloc(ptr, size); - if (newptr == NULL) - HIREDIS_OOM_HANDLER; - - return newptr; + return hiredisAllocFns.realloc(ptr, size); } char *hi_strdup(const char *str) { - char *newstr = strdup(str); - if (newstr == NULL) - HIREDIS_OOM_HANDLER; - - return newstr; + return hiredisAllocFns.strdup(str); } + +void hi_free(void *ptr) { + hiredisAllocFns.free(ptr); +} + +#endif diff --git a/alloc.h b/alloc.h index 2c9b04e..b581e56 100644 --- a/alloc.h +++ b/alloc.h @@ -31,23 +31,61 @@ #ifndef HIREDIS_ALLOC_H #define HIREDIS_ALLOC_H -#include /* for size_t */ - -#ifndef HIREDIS_OOM_HANDLER -#define HIREDIS_OOM_HANDLER abort() -#endif +#include /* for size_t */ #ifdef __cplusplus extern "C" { #endif +/* Structure pointing to our actually configured allocators */ +typedef struct hiredisAllocFuncs { + void *(*malloc)(size_t); + void *(*calloc)(size_t,size_t); + void *(*realloc)(void*,size_t); + char *(*strdup)(const char*); + void (*free)(void*); +} hiredisAllocFuncs; + +hiredisAllocFuncs hiredisSetAllocators(hiredisAllocFuncs *ha); +void hiredisResetAllocators(void); + +#ifndef _WIN32 + +/* Hiredis' configured allocator function pointer struct */ +extern hiredisAllocFuncs hiredisAllocFns; + +static inline void *hi_malloc(size_t size) { + return hiredisAllocFns.malloc(size); +} + +static inline void *hi_calloc(size_t nmemb, size_t size) { + return hiredisAllocFns.calloc(nmemb, size); +} + +static inline void *hi_realloc(void *ptr, size_t size) { + return hiredisAllocFns.realloc(ptr, size); +} + +static inline char *hi_strdup(const char *str) { + return hiredisAllocFns.strdup(str); +} + +static inline void hi_free(void *ptr) { + hiredisAllocFns.free(ptr); +} + +#else + void *hi_malloc(size_t size); void *hi_calloc(size_t nmemb, size_t size); void *hi_realloc(void *ptr, size_t size); char *hi_strdup(const char *str); +void hi_free(void *ptr); + +#endif #ifdef __cplusplus } #endif -#endif /* HIREDIS_ALLOC_H */ +#endif /* HIREDIS_ALLOC_H */ diff --git a/async.c b/async.c index 68a656f..261fe24 100644 --- a/async.c +++ b/async.c @@ -47,8 +47,9 @@ #include "async_private.h" -/* Forward declaration of function in hiredis.c */ +/* Forward declarations of hiredis.c functions */ int __redisAppendCommand(redisContext *c, const char *cmd, size_t len); +void __redisSetError(redisContext *c, int type, const char *str); /* Functions managing dictionary of callbacks for pub/sub. */ static unsigned int callbackHash(const void *key) { @@ -58,7 +59,12 @@ static unsigned int callbackHash(const void *key) { static void *callbackValDup(void *privdata, const void *src) { ((void) privdata); - redisCallback *dup = hi_malloc(sizeof(*dup)); + redisCallback *dup; + + dup = hi_malloc(sizeof(*dup)); + if (dup == NULL) + return NULL; + memcpy(dup,src,sizeof(*dup)); return dup; } @@ -80,7 +86,7 @@ static void callbackKeyDestructor(void *privdata, void *key) { static void callbackValDestructor(void *privdata, void *val) { ((void) privdata); - free(val); + hi_free(val); } static dictType callbackDict = { @@ -94,10 +100,19 @@ static dictType callbackDict = { static redisAsyncContext *redisAsyncInitialize(redisContext *c) { redisAsyncContext *ac; + dict *channels = NULL, *patterns = NULL; - ac = realloc(c,sizeof(redisAsyncContext)); + channels = dictCreate(&callbackDict,NULL); + if (channels == NULL) + goto oom; + + patterns = dictCreate(&callbackDict,NULL); + if (patterns == NULL) + goto oom; + + ac = hi_realloc(c,sizeof(redisAsyncContext)); if (ac == NULL) - return NULL; + goto oom; c = &(ac->c); @@ -126,9 +141,14 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) { ac->replies.tail = NULL; ac->sub.invalid.head = NULL; ac->sub.invalid.tail = NULL; - ac->sub.channels = dictCreate(&callbackDict,NULL); - ac->sub.patterns = dictCreate(&callbackDict,NULL); + ac->sub.channels = channels; + ac->sub.patterns = patterns; + return ac; +oom: + if (channels) dictRelease(channels); + if (patterns) dictRelease(patterns); + return NULL; } /* We want the error field to be accessible directly instead of requiring @@ -216,7 +236,7 @@ static int __redisPushCallback(redisCallbackList *list, redisCallback *source) { redisCallback *cb; /* Copy callback from stack to heap */ - cb = malloc(sizeof(*cb)); + cb = hi_malloc(sizeof(*cb)); if (cb == NULL) return REDIS_ERR_OOM; @@ -244,7 +264,7 @@ static int __redisShiftCallback(redisCallbackList *list, redisCallback *target) /* Copy callback from heap to stack */ if (target != NULL) memcpy(target,cb,sizeof(*cb)); - free(cb); + hi_free(cb); return REDIS_OK; } return REDIS_ERR; @@ -275,17 +295,27 @@ static void __redisAsyncFree(redisAsyncContext *ac) { __redisRunCallback(ac,&cb,NULL); /* Run subscription callbacks callbacks with NULL reply */ - it = dictGetIterator(ac->sub.channels); - while ((de = dictNext(it)) != NULL) - __redisRunCallback(ac,dictGetEntryVal(de),NULL); - dictReleaseIterator(it); - dictRelease(ac->sub.channels); + if (ac->sub.channels) { + it = dictGetIterator(ac->sub.channels); + if (it != NULL) { + while ((de = dictNext(it)) != NULL) + __redisRunCallback(ac,dictGetEntryVal(de),NULL); + dictReleaseIterator(it); + } - it = dictGetIterator(ac->sub.patterns); - while ((de = dictNext(it)) != NULL) - __redisRunCallback(ac,dictGetEntryVal(de),NULL); - dictReleaseIterator(it); - dictRelease(ac->sub.patterns); + dictRelease(ac->sub.channels); + } + + if (ac->sub.patterns) { + it = dictGetIterator(ac->sub.patterns); + if (it != NULL) { + while ((de = dictNext(it)) != NULL) + __redisRunCallback(ac,dictGetEntryVal(de),NULL); + dictReleaseIterator(it); + } + + dictRelease(ac->sub.patterns); + } /* Signal event lib to clean up */ _EL_CLEANUP(ac); @@ -388,6 +418,9 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, /* Locate the right callback */ assert(reply->element[1]->type == REDIS_REPLY_STRING); sname = sdsnewlen(reply->element[1]->str,reply->element[1]->len); + if (sname == NULL) + goto oom; + de = dictFind(callbacks,sname); if (de != NULL) { cb = dictGetEntryVal(de); @@ -421,6 +454,9 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, __redisShiftCallback(&ac->sub.invalid,dstcb); } return REDIS_OK; +oom: + __redisSetError(&(ac->c), REDIS_ERR_OOM, "Out of memory"); + return REDIS_ERR; } void redisProcessCallbacks(redisAsyncContext *ac) { @@ -588,8 +624,6 @@ void redisAsyncHandleWrite(redisAsyncContext *ac) { c->funcs->async_write(ac); } -void __redisSetError(redisContext *c, int type, const char *str); - void redisAsyncHandleTimeout(redisAsyncContext *ac) { redisContext *c = &(ac->c); redisCallback cb; @@ -672,6 +706,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void /* Add every channel/pattern to the list of subscription callbacks. */ while ((p = nextArgument(p,&astr,&alen)) != NULL) { sname = sdsnewlen(astr,alen); + if (sname == NULL) + goto oom; + if (pvariant) cbdict = ac->sub.patterns; else @@ -715,6 +752,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void _EL_ADD_WRITE(ac); return REDIS_OK; +oom: + __redisSetError(&(ac->c), REDIS_ERR_OOM, "Out of memory"); + return REDIS_ERR; } int redisvAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata, const char *format, va_list ap) { @@ -728,7 +768,7 @@ int redisvAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdat return REDIS_ERR; status = __redisAsyncCommand(ac,fn,privdata,cmd,len); - free(cmd); + hi_free(cmd); return status; } @@ -758,15 +798,21 @@ int redisAsyncFormattedCommand(redisAsyncContext *ac, redisCallbackFn *fn, void return status; } -void redisAsyncSetTimeout(redisAsyncContext *ac, struct timeval tv) { +int redisAsyncSetTimeout(redisAsyncContext *ac, struct timeval tv) { if (!ac->c.timeout) { ac->c.timeout = hi_calloc(1, sizeof(tv)); + if (ac->c.timeout == NULL) { + __redisSetError(&ac->c, REDIS_ERR_OOM, "Out of memory"); + __redisAsyncCopyError(ac); + return REDIS_ERR; + } } - if (tv.tv_sec == ac->c.timeout->tv_sec && - tv.tv_usec == ac->c.timeout->tv_usec) { - return; + if (tv.tv_sec != ac->c.timeout->tv_sec || + tv.tv_usec != ac->c.timeout->tv_usec) + { + *ac->c.timeout = tv; } - *ac->c.timeout = tv; + return REDIS_OK; } diff --git a/async.h b/async.h index 781f1d1..ba301da 100644 --- a/async.h +++ b/async.h @@ -118,7 +118,7 @@ redisAsyncContext *redisAsyncConnectUnix(const char *path); int redisAsyncSetConnectCallback(redisAsyncContext *ac, redisConnectCallback *fn); int redisAsyncSetDisconnectCallback(redisAsyncContext *ac, redisDisconnectCallback *fn); -void redisAsyncSetTimeout(redisAsyncContext *ac, struct timeval tv); +int redisAsyncSetTimeout(redisAsyncContext *ac, struct timeval tv); void redisAsyncDisconnect(redisAsyncContext *ac); void redisAsyncFree(redisAsyncContext *ac); diff --git a/dict.c b/dict.c index eaf4e4d..34a33ea 100644 --- a/dict.c +++ b/dict.c @@ -73,6 +73,9 @@ static void _dictReset(dict *ht) { /* Create a new hash table */ static dict *dictCreate(dictType *type, void *privDataPtr) { dict *ht = hi_malloc(sizeof(*ht)); + if (ht == NULL) + return NULL; + _dictInit(ht,type,privDataPtr); return ht; } @@ -98,7 +101,9 @@ static int dictExpand(dict *ht, unsigned long size) { _dictInit(&n, ht->type, ht->privdata); n.size = realsize; n.sizemask = realsize-1; - n.table = calloc(realsize,sizeof(dictEntry*)); + n.table = hi_calloc(realsize,sizeof(dictEntry*)); + if (n.table == NULL) + return DICT_ERR; /* Copy all the elements from the old to the new table: * note that if the old hash table is empty ht->size is zero, @@ -125,7 +130,7 @@ static int dictExpand(dict *ht, unsigned long size) { } } assert(ht->used == 0); - free(ht->table); + hi_free(ht->table); /* Remap the new hashtable in the old */ *ht = n; @@ -144,6 +149,9 @@ static int dictAdd(dict *ht, void *key, void *val) { /* Allocates the memory and stores key */ entry = hi_malloc(sizeof(*entry)); + if (entry == NULL) + return DICT_ERR; + entry->next = ht->table[index]; ht->table[index] = entry; @@ -167,6 +175,9 @@ static int dictReplace(dict *ht, void *key, void *val) { return 1; /* It already exists, get the entry */ entry = dictFind(ht, key); + if (entry == NULL) + return 0; + /* Free the old value and set the new one */ /* Set the new value and free the old one. Note that it is important * to do that in this order, as the value may just be exactly the same @@ -200,7 +211,7 @@ static int dictDelete(dict *ht, const void *key) { dictFreeEntryKey(ht,de); dictFreeEntryVal(ht,de); - free(de); + hi_free(de); ht->used--; return DICT_OK; } @@ -223,13 +234,13 @@ static int _dictClear(dict *ht) { nextHe = he->next; dictFreeEntryKey(ht, he); dictFreeEntryVal(ht, he); - free(he); + hi_free(he); ht->used--; he = nextHe; } } /* Free the table and the allocated cache structure */ - free(ht->table); + hi_free(ht->table); /* Re-initialize the table */ _dictReset(ht); return DICT_OK; /* never fails */ @@ -238,7 +249,7 @@ static int _dictClear(dict *ht) { /* Clear & Release the hash table */ static void dictRelease(dict *ht) { _dictClear(ht); - free(ht); + hi_free(ht); } static dictEntry *dictFind(dict *ht, const void *key) { @@ -258,6 +269,8 @@ static dictEntry *dictFind(dict *ht, const void *key) { static dictIterator *dictGetIterator(dict *ht) { dictIterator *iter = hi_malloc(sizeof(*iter)); + if (iter == NULL) + return NULL; iter->ht = ht; iter->index = -1; @@ -287,7 +300,7 @@ static dictEntry *dictNext(dictIterator *iter) { } static void dictReleaseIterator(dictIterator *iter) { - free(iter); + hi_free(iter); } /* ------------------------- private functions ------------------------------ */ diff --git a/hiredis.c b/hiredis.c index 1c9e8ae..3133716 100644 --- a/hiredis.c +++ b/hiredis.c @@ -74,7 +74,7 @@ static redisReplyObjectFunctions defaultFunctions = { /* Create a reply object */ static redisReply *createReplyObject(int type) { - redisReply *r = calloc(1,sizeof(*r)); + redisReply *r = hi_calloc(1,sizeof(*r)); if (r == NULL) return NULL; @@ -101,7 +101,7 @@ void freeReplyObject(void *reply) { if (r->element != NULL) { for (j = 0; j < r->elements; j++) freeReplyObject(r->element[j]); - free(r->element); + hi_free(r->element); } break; case REDIS_REPLY_ERROR: @@ -109,10 +109,10 @@ void freeReplyObject(void *reply) { case REDIS_REPLY_STRING: case REDIS_REPLY_DOUBLE: case REDIS_REPLY_VERB: - free(r->str); + hi_free(r->str); break; } - free(r); + hi_free(r); } static void *createStringObject(const redisReadTask *task, char *str, size_t len) { @@ -130,22 +130,18 @@ static void *createStringObject(const redisReadTask *task, char *str, size_t len /* Copy string value */ if (task->type == REDIS_REPLY_VERB) { - buf = malloc(len-4+1); /* Skip 4 bytes of verbatim type header. */ - if (buf == NULL) { - freeReplyObject(r); - return NULL; - } + buf = hi_malloc(len-4+1); /* Skip 4 bytes of verbatim type header. */ + if (buf == NULL) goto oom; + memcpy(r->vtype,str,3); r->vtype[3] = '\0'; memcpy(buf,str+4,len-4); buf[len-4] = '\0'; r->len = len - 4; } else { - buf = malloc(len+1); - if (buf == NULL) { - freeReplyObject(r); - return NULL; - } + buf = hi_malloc(len+1); + if (buf == NULL) goto oom; + memcpy(buf,str,len); buf[len] = '\0'; r->len = len; @@ -161,6 +157,10 @@ static void *createStringObject(const redisReadTask *task, char *str, size_t len parent->element[task->idx] = r; } return r; + +oom: + freeReplyObject(r); + return NULL; } static void *createArrayObject(const redisReadTask *task, size_t elements) { @@ -171,7 +171,7 @@ static void *createArrayObject(const redisReadTask *task, size_t elements) { return NULL; if (elements > 0) { - r->element = calloc(elements,sizeof(redisReply*)); + r->element = hi_calloc(elements,sizeof(redisReply*)); if (r->element == NULL) { freeReplyObject(r); return NULL; @@ -218,7 +218,7 @@ static void *createDoubleObject(const redisReadTask *task, double value, char *s return NULL; r->dval = value; - r->str = malloc(len+1); + r->str = hi_malloc(len+1); if (r->str == NULL) { freeReplyObject(r); return NULL; @@ -322,7 +322,7 @@ int redisvFormatCommand(char **target, const char *format, va_list ap) { if (*c != '%' || c[1] == '\0') { if (*c == ' ') { if (touched) { - newargv = realloc(curargv,sizeof(char*)*(argc+1)); + newargv = hi_realloc(curargv,sizeof(char*)*(argc+1)); if (newargv == NULL) goto memory_err; curargv = newargv; curargv[argc++] = curarg; @@ -471,7 +471,7 @@ int redisvFormatCommand(char **target, const char *format, va_list ap) { /* Add the last argument if needed */ if (touched) { - newargv = realloc(curargv,sizeof(char*)*(argc+1)); + newargv = hi_realloc(curargv,sizeof(char*)*(argc+1)); if (newargv == NULL) goto memory_err; curargv = newargv; curargv[argc++] = curarg; @@ -487,7 +487,7 @@ int redisvFormatCommand(char **target, const char *format, va_list ap) { totlen += 1+countDigits(argc)+2; /* Build the command at protocol level */ - cmd = malloc(totlen+1); + cmd = hi_malloc(totlen+1); if (cmd == NULL) goto memory_err; pos = sprintf(cmd,"*%d\r\n",argc); @@ -502,7 +502,7 @@ int redisvFormatCommand(char **target, const char *format, va_list ap) { assert(pos == totlen); cmd[pos] = '\0'; - free(curargv); + hi_free(curargv); *target = cmd; return totlen; @@ -518,11 +518,11 @@ cleanup: if (curargv) { while(argc--) sdsfree(curargv[argc]); - free(curargv); + hi_free(curargv); } sdsfree(curarg); - free(cmd); + hi_free(cmd); return error_type; } @@ -563,7 +563,7 @@ int redisFormatCommand(char **target, const char *format, ...) { int redisFormatSdsCommandArgv(sds *target, int argc, const char **argv, const size_t *argvlen) { - sds cmd; + sds cmd, aux; unsigned long long totlen; int j; size_t len; @@ -585,9 +585,13 @@ int redisFormatSdsCommandArgv(sds *target, int argc, const char **argv, return -1; /* We already know how much storage we need */ - cmd = sdsMakeRoomFor(cmd, totlen); - if (cmd == NULL) + aux = sdsMakeRoomFor(cmd, totlen); + if (aux == NULL) { + sdsfree(cmd); return -1; + } + + cmd = aux; /* Construct command */ cmd = sdscatfmt(cmd, "*%i\r\n", argc); @@ -631,7 +635,7 @@ int redisFormatCommandArgv(char **target, int argc, const char **argv, const siz } /* Build the command at protocol level */ - cmd = malloc(totlen+1); + cmd = hi_malloc(totlen+1); if (cmd == NULL) return -1; @@ -652,7 +656,7 @@ int redisFormatCommandArgv(char **target, int argc, const char **argv, const siz } void redisFreeCommand(char *cmd) { - free(cmd); + hi_free(cmd); } void __redisSetError(redisContext *c, int type, const char *str) { @@ -678,7 +682,7 @@ redisReader *redisReaderCreate(void) { static redisContext *redisContextInit(const redisOptions *options) { redisContext *c; - c = calloc(1, sizeof(*c)); + c = hi_calloc(1, sizeof(*c)); if (c == NULL) return NULL; @@ -702,16 +706,16 @@ void redisFree(redisContext *c) { sdsfree(c->obuf); redisReaderFree(c->reader); - free(c->tcp.host); - free(c->tcp.source_addr); - free(c->unix_sock.path); - free(c->timeout); - free(c->saddr); + hi_free(c->tcp.host); + hi_free(c->tcp.source_addr); + hi_free(c->unix_sock.path); + hi_free(c->timeout); + hi_free(c->saddr); if (c->funcs->free_privdata) { c->funcs->free_privdata(c->privdata); } memset(c, 0xff, sizeof(*c)); - free(c); + hi_free(c); } redisFD redisFreeKeepFd(redisContext *c) { @@ -738,6 +742,11 @@ int redisReconnect(redisContext *c) { c->obuf = sdsempty(); c->reader = redisReaderCreate(); + if (c->obuf == NULL || c->reader == NULL) { + __redisSetError(c, REDIS_ERR_OOM, "Out of memory"); + return REDIS_ERR; + } + if (c->connection_type == REDIS_CONN_TCP) { return redisContextConnectBindTcp(c, c->tcp.host, c->tcp.port, c->timeout, c->tcp.source_addr); @@ -918,6 +927,8 @@ int redisBufferWrite(redisContext *c, int *done) { if (nwritten == (signed)sdslen(c->obuf)) { sdsfree(c->obuf); c->obuf = sdsempty(); + if (c->obuf == NULL) + goto oom; } else { sdsrange(c->obuf,nwritten,-1); } @@ -925,6 +936,10 @@ int redisBufferWrite(redisContext *c, int *done) { } if (done != NULL) *done = (sdslen(c->obuf) == 0); return REDIS_OK; + +oom: + __redisSetError(c, REDIS_ERR_OOM, "Out of memory"); + return REDIS_ERR; } /* Internal helper function to try and get a reply from the reader, @@ -1015,11 +1030,11 @@ int redisvAppendCommand(redisContext *c, const char *format, va_list ap) { } if (__redisAppendCommand(c,cmd,len) != REDIS_OK) { - free(cmd); + hi_free(cmd); return REDIS_ERR; } - free(cmd); + hi_free(cmd); return REDIS_OK; } diff --git a/net.c b/net.c index bdb8943..54d904d 100644 --- a/net.c +++ b/net.c @@ -328,6 +328,22 @@ int redisContextSetTimeout(redisContext *c, const struct timeval tv) { return REDIS_OK; } +static int _redisContextUpdateTimeout(redisContext *c, const struct timeval *timeout) { + /* Same timeval struct, short circuit */ + if (c->timeout == timeout) + return REDIS_OK; + + /* Allocate context timeval if we need to */ + if (c->timeout == NULL) { + c->timeout = hi_malloc(sizeof(*c->timeout)); + if (c->timeout == NULL) + return REDIS_ERR; + } + + memcpy(c->timeout, timeout, sizeof(*c->timeout)); + return REDIS_OK; +} + static int _redisContextConnectTcp(redisContext *c, const char *addr, int port, const struct timeval *timeout, const char *source_addr) { @@ -352,20 +368,18 @@ static int _redisContextConnectTcp(redisContext *c, const char *addr, int port, * This is a bit ugly, but atleast it works and doesn't leak memory. **/ if (c->tcp.host != addr) { - free(c->tcp.host); + hi_free(c->tcp.host); c->tcp.host = hi_strdup(addr); + if (c->tcp.host == NULL) + goto oom; } if (timeout) { - if (c->timeout != timeout) { - if (c->timeout == NULL) - c->timeout = hi_malloc(sizeof(struct timeval)); - - memcpy(c->timeout, timeout, sizeof(struct timeval)); - } + if (_redisContextUpdateTimeout(c, timeout) == REDIS_ERR) + goto oom; } else { - free(c->timeout); + hi_free(c->timeout); c->timeout = NULL; } @@ -375,10 +389,10 @@ static int _redisContextConnectTcp(redisContext *c, const char *addr, int port, } if (source_addr == NULL) { - free(c->tcp.source_addr); + hi_free(c->tcp.source_addr); c->tcp.source_addr = NULL; } else if (c->tcp.source_addr != source_addr) { - free(c->tcp.source_addr); + hi_free(c->tcp.source_addr); c->tcp.source_addr = hi_strdup(source_addr); } @@ -442,8 +456,11 @@ addrretry: } /* For repeat connection */ - free(c->saddr); + hi_free(c->saddr); c->saddr = hi_malloc(p->ai_addrlen); + if (c->saddr == NULL) + goto oom; + memcpy(c->saddr, p->ai_addr, p->ai_addrlen); c->addrlen = p->ai_addrlen; @@ -488,6 +505,8 @@ addrretry: goto error; } +oom: + __redisSetError(c, REDIS_ERR_OOM, "Out of memory"); error: rv = REDIS_ERR; end: @@ -521,25 +540,32 @@ int redisContextConnectUnix(redisContext *c, const char *path, const struct time return REDIS_ERR; c->connection_type = REDIS_CONN_UNIX; - if (c->unix_sock.path != path) + if (c->unix_sock.path != path) { + hi_free(c->unix_sock.path); + c->unix_sock.path = hi_strdup(path); + if (c->unix_sock.path == NULL) + goto oom; + } if (timeout) { - if (c->timeout != timeout) { - if (c->timeout == NULL) - c->timeout = hi_malloc(sizeof(struct timeval)); - - memcpy(c->timeout, timeout, sizeof(struct timeval)); - } + if (_redisContextUpdateTimeout(c, timeout) == REDIS_ERR) + goto oom; } else { - free(c->timeout); + hi_free(c->timeout); c->timeout = NULL; } if (redisContextTimeoutMsec(c,&timeout_msec) != REDIS_OK) return REDIS_ERR; + /* Don't leak sockaddr if we're reconnecting */ + if (c->saddr) hi_free(c->saddr); + sa = (struct sockaddr_un*)(c->saddr = hi_malloc(sizeof(struct sockaddr_un))); + if (sa == NULL) + goto oom; + c->addrlen = sizeof(struct sockaddr_un); sa->sun_family = AF_UNIX; strncpy(sa->sun_path, path, sizeof(sa->sun_path) - 1); @@ -564,4 +590,7 @@ int redisContextConnectUnix(redisContext *c, const char *path, const struct time errno = EPROTONOSUPPORT; return REDIS_ERR; #endif /* _WIN32 */ +oom: + __redisSetError(c, REDIS_ERR_OOM, "Out of memory"); + return REDIS_ERR; } diff --git a/read.c b/read.c index 4924014..d10b62a 100644 --- a/read.c +++ b/read.c @@ -42,6 +42,7 @@ #include #include +#include "alloc.h" #include "read.h" #include "sds.h" #include "win32.h" @@ -425,7 +426,7 @@ static int redisReaderGrow(redisReader *r) { /* Grow our stack size */ newlen = r->tasks + REDIS_READER_STACK_SIZE; - aux = realloc(r->task, sizeof(*r->task) * newlen); + aux = hi_realloc(r->task, sizeof(*r->task) * newlen); if (aux == NULL) goto oom; @@ -433,7 +434,7 @@ static int redisReaderGrow(redisReader *r) { /* Allocate new tasks */ for (; r->tasks < newlen; r->tasks++) { - r->task[r->tasks] = calloc(1, sizeof(**r->task)); + r->task[r->tasks] = hi_calloc(1, sizeof(**r->task)); if (r->task[r->tasks] == NULL) goto oom; } @@ -467,7 +468,9 @@ static int processAggregateItem(redisReader *r) { root = (r->ridx == 0); - if (elements < -1 || (LLONG_MAX > SIZE_MAX && elements > SIZE_MAX)) { + if (elements < -1 || (LLONG_MAX > SIZE_MAX && elements > SIZE_MAX) || + (r->maxelements > 0 && elements > r->maxelements)) + { __redisReaderSetError(r,REDIS_ERR_PROTOCOL, "Multi-bulk length out of range"); return REDIS_ERR; @@ -602,7 +605,7 @@ static int processItem(redisReader *r) { redisReader *redisReaderCreateWithFunctions(redisReplyObjectFunctions *fn) { redisReader *r; - r = calloc(1,sizeof(redisReader)); + r = hi_calloc(1,sizeof(redisReader)); if (r == NULL) return NULL; @@ -610,22 +613,22 @@ redisReader *redisReaderCreateWithFunctions(redisReplyObjectFunctions *fn) { if (r->buf == NULL) goto oom; - r->task = calloc(REDIS_READER_STACK_SIZE, sizeof(*r->task)); + r->task = hi_calloc(REDIS_READER_STACK_SIZE, sizeof(*r->task)); if (r->task == NULL) goto oom; for (; r->tasks < REDIS_READER_STACK_SIZE; r->tasks++) { - r->task[r->tasks] = calloc(1, sizeof(**r->task)); + r->task[r->tasks] = hi_calloc(1, sizeof(**r->task)); if (r->task[r->tasks] == NULL) goto oom; } r->fn = fn; r->maxbuf = REDIS_READER_MAX_BUF; - + r->maxelements = REDIS_READER_MAX_ARRAY_ELEMENTS; r->ridx = -1; - return r; + return r; oom: redisReaderFree(r); return NULL; @@ -640,14 +643,14 @@ void redisReaderFree(redisReader *r) { /* We know r->task[i] is allocatd if i < r->tasks */ for (int i = 0; i < r->tasks; i++) { - free(r->task[i]); + hi_free(r->task[i]); } if (r->task) - free(r->task); + hi_free(r->task); sdsfree(r->buf); - free(r); + hi_free(r); } int redisReaderFeed(redisReader *r, const char *buf, size_t len) { @@ -663,23 +666,22 @@ int redisReaderFeed(redisReader *r, const char *buf, size_t len) { if (r->len == 0 && r->maxbuf != 0 && sdsavail(r->buf) > r->maxbuf) { sdsfree(r->buf); r->buf = sdsempty(); - r->pos = 0; + if (r->buf == 0) goto oom; - /* r->buf should not be NULL since we just free'd a larger one. */ - assert(r->buf != NULL); + r->pos = 0; } newbuf = sdscatlen(r->buf,buf,len); - if (newbuf == NULL) { - __redisReaderSetErrorOOM(r); - return REDIS_ERR; - } + if (newbuf == NULL) goto oom; r->buf = newbuf; r->len = sdslen(r->buf); } return REDIS_OK; +oom: + __redisReaderSetErrorOOM(r); + return REDIS_ERR; } int redisReaderGetReply(redisReader *r, void **reply) { diff --git a/read.h b/read.h index 788c401..2d74d77 100644 --- a/read.h +++ b/read.h @@ -63,7 +63,11 @@ #define REDIS_REPLY_BIGNUM 13 #define REDIS_REPLY_VERB 14 -#define REDIS_READER_MAX_BUF (1024*16) /* Default max unused reader buffer. */ +/* Default max unused reader buffer. */ +#define REDIS_READER_MAX_BUF (1024*16) + +/* Default multi-bulk element limit */ +#define REDIS_READER_MAX_ARRAY_ELEMENTS ((1LL<<32) - 1) #ifdef __cplusplus extern "C" { @@ -71,7 +75,7 @@ extern "C" { typedef struct redisReadTask { int type; - int elements; /* number of elements in multibulk container */ + long long elements; /* number of elements in multibulk container */ int idx; /* index in parent (array) object */ void *obj; /* holds user-generated value for a read task */ struct redisReadTask *parent; /* parent task */ @@ -96,6 +100,7 @@ typedef struct redisReader { size_t pos; /* Buffer cursor */ size_t len; /* Buffer length */ size_t maxbuf; /* Max length of unused buffer */ + long long maxelements; /* Max multi-bulk elements */ redisReadTask **task; int tasks; diff --git a/sds.c b/sds.c index e8b996a..f7811a7 100644 --- a/sds.c +++ b/sds.c @@ -30,6 +30,7 @@ * POSSIBILITY OF SUCH DAMAGE. */ +#include "fmacros.h" #include #include #include @@ -219,10 +220,7 @@ sds sdsMakeRoomFor(sds s, size_t addlen) { hdrlen = sdsHdrSize(type); if (oldtype==type) { newsh = s_realloc(sh, hdrlen+newlen+1); - if (newsh == NULL) { - s_free(sh); - return NULL; - } + if (newsh == NULL) return NULL; s = (char*)newsh+hdrlen; } else { /* Since the header size changes, need to move the string forward, diff --git a/sdsalloc.h b/sdsalloc.h index f43023c..5538dd9 100644 --- a/sdsalloc.h +++ b/sdsalloc.h @@ -37,6 +37,8 @@ * the include of your alternate allocator if needed (not needed in order * to use the default libc allocator). */ -#define s_malloc malloc -#define s_realloc realloc -#define s_free free +#include "alloc.h" + +#define s_malloc hi_malloc +#define s_realloc hi_realloc +#define s_free hi_free diff --git a/ssl.c b/ssl.c index 502814d..8cb133a 100644 --- a/ssl.c +++ b/ssl.c @@ -163,18 +163,22 @@ static void opensslDoLock(int mode, int lkid, const char *f, int line) { (void)line; } -static void initOpensslLocks(void) { +static int initOpensslLocks(void) { unsigned ii, nlocks; if (CRYPTO_get_locking_callback() != NULL) { /* Someone already set the callback before us. Don't destroy it! */ - return; + return REDIS_OK; } nlocks = CRYPTO_num_locks(); ossl_locks = hi_malloc(sizeof(*ossl_locks) * nlocks); + if (ossl_locks == NULL) + return REDIS_ERR; + for (ii = 0; ii < nlocks; ii++) { sslLockInit(ossl_locks + ii); } CRYPTO_set_locking_callback(opensslDoLock); + return REDIS_OK; } #endif /* HIREDIS_USE_CRYPTO_LOCKS */ @@ -183,15 +187,20 @@ static void initOpensslLocks(void) { */ static int redisSSLConnect(redisContext *c, SSL_CTX *ssl_ctx, SSL *ssl) { + redisSSLContext *rssl; + if (c->privdata) { __redisSetError(c, REDIS_ERR_OTHER, "redisContext was already associated"); return REDIS_ERR; } - c->privdata = calloc(1, sizeof(redisSSLContext)); + + rssl = hi_calloc(1, sizeof(redisSSLContext)); + if (rssl == NULL) { + __redisSetError(c, REDIS_ERR_OOM, "Out of memory"); + return REDIS_ERR; + } c->funcs = &redisContextSSLFuncs; - redisSSLContext *rssl = c->privdata; - rssl->ssl_ctx = ssl_ctx; rssl->ssl = ssl; @@ -202,12 +211,14 @@ static int redisSSLConnect(redisContext *c, SSL_CTX *ssl_ctx, SSL *ssl) { ERR_clear_error(); int rv = SSL_connect(rssl->ssl); if (rv == 1) { + c->privdata = rssl; return REDIS_OK; } rv = SSL_get_error(rssl->ssl, rv); if (((c->flags & REDIS_BLOCK) == 0) && (rv == SSL_ERROR_WANT_READ || rv == SSL_ERROR_WANT_WRITE)) { + c->privdata = rssl; return REDIS_OK; } @@ -222,6 +233,8 @@ static int redisSSLConnect(redisContext *c, SSL_CTX *ssl_ctx, SSL *ssl) { } __redisSetError(c, REDIS_ERR_IO, err); } + + hi_free(rssl); return REDIS_ERR; } @@ -241,7 +254,10 @@ int redisSecureConnection(redisContext *c, const char *capath, isInit = 1; SSL_library_init(); #ifdef HIREDIS_USE_CRYPTO_LOCKS - initOpensslLocks(); + if (initOpensslLocks() == REDIS_ERR) { + __redisSetError(c, REDIS_ERR_OOM, "Out of memory"); + goto error; + } #endif } @@ -290,7 +306,8 @@ int redisSecureConnection(redisContext *c, const char *capath, } } - return redisSSLConnect(c, ssl_ctx, ssl); + if (redisSSLConnect(c, ssl_ctx, ssl) == REDIS_OK) + return REDIS_OK; error: if (ssl) SSL_free(ssl); @@ -330,7 +347,7 @@ static void redisSSLFreeContext(void *privdata){ SSL_CTX_free(rsc->ssl_ctx); rsc->ssl_ctx = NULL; } - free(rsc); + hi_free(rsc); } static int redisSSLRead(redisContext *c, char *buf, size_t bufcap) { diff --git a/test.c b/test.c index 1d38caa..48d36d0 100644 --- a/test.c +++ b/test.c @@ -180,43 +180,43 @@ static void test_format_commands(void) { len = redisFormatCommand(&cmd,"SET foo bar"); test_cond(strncmp(cmd,"*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",len) == 0 && len == 4+4+(3+2)+4+(3+2)+4+(3+2)); - free(cmd); + hi_free(cmd); test("Format command with %%s string interpolation: "); len = redisFormatCommand(&cmd,"SET %s %s","foo","bar"); test_cond(strncmp(cmd,"*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",len) == 0 && len == 4+4+(3+2)+4+(3+2)+4+(3+2)); - free(cmd); + hi_free(cmd); test("Format command with %%s and an empty string: "); len = redisFormatCommand(&cmd,"SET %s %s","foo",""); test_cond(strncmp(cmd,"*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$0\r\n\r\n",len) == 0 && len == 4+4+(3+2)+4+(3+2)+4+(0+2)); - free(cmd); + hi_free(cmd); test("Format command with an empty string in between proper interpolations: "); len = redisFormatCommand(&cmd,"SET %s %s","","foo"); test_cond(strncmp(cmd,"*3\r\n$3\r\nSET\r\n$0\r\n\r\n$3\r\nfoo\r\n",len) == 0 && len == 4+4+(3+2)+4+(0+2)+4+(3+2)); - free(cmd); + hi_free(cmd); test("Format command with %%b string interpolation: "); len = redisFormatCommand(&cmd,"SET %b %b","foo",(size_t)3,"b\0r",(size_t)3); test_cond(strncmp(cmd,"*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nb\0r\r\n",len) == 0 && len == 4+4+(3+2)+4+(3+2)+4+(3+2)); - free(cmd); + hi_free(cmd); test("Format command with %%b and an empty string: "); len = redisFormatCommand(&cmd,"SET %b %b","foo",(size_t)3,"",(size_t)0); test_cond(strncmp(cmd,"*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$0\r\n\r\n",len) == 0 && len == 4+4+(3+2)+4+(3+2)+4+(0+2)); - free(cmd); + hi_free(cmd); test("Format command with literal %%: "); len = redisFormatCommand(&cmd,"SET %% %%"); test_cond(strncmp(cmd,"*3\r\n$3\r\nSET\r\n$1\r\n%\r\n$1\r\n%\r\n",len) == 0 && len == 4+4+(3+2)+4+(1+2)+4+(1+2)); - free(cmd); + hi_free(cmd); /* Vararg width depends on the type. These tests make sure that the * width is correctly determined using the format and subsequent varargs @@ -227,7 +227,7 @@ static void test_format_commands(void) { len = redisFormatCommand(&cmd,"key:%08" fmt " str:%s", value, "hello"); \ test_cond(strncmp(cmd,"*2\r\n$12\r\nkey:00000123\r\n$9\r\nstr:hello\r\n",len) == 0 && \ len == 4+5+(12+2)+4+(9+2)); \ - free(cmd); \ + hi_free(cmd); \ } while(0) #define FLOAT_WIDTH_TEST(type) do { \ @@ -236,7 +236,7 @@ static void test_format_commands(void) { len = redisFormatCommand(&cmd,"key:%08.3f str:%s", value, "hello"); \ test_cond(strncmp(cmd,"*2\r\n$12\r\nkey:0123.000\r\n$9\r\nstr:hello\r\n",len) == 0 && \ len == 4+5+(12+2)+4+(9+2)); \ - free(cmd); \ + hi_free(cmd); \ } while(0) INTEGER_WIDTH_TEST("d", int); @@ -267,13 +267,13 @@ static void test_format_commands(void) { len = redisFormatCommandArgv(&cmd,argc,argv,NULL); test_cond(strncmp(cmd,"*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",len) == 0 && len == 4+4+(3+2)+4+(3+2)+4+(3+2)); - free(cmd); + hi_free(cmd); test("Format command by passing argc/argv with lengths: "); len = redisFormatCommandArgv(&cmd,argc,argv,lens); test_cond(strncmp(cmd,"*3\r\n$3\r\nSET\r\n$7\r\nfoo\0xxx\r\n$3\r\nbar\r\n",len) == 0 && len == 4+4+(3+2)+4+(7+2)+4+(3+2)); - free(cmd); + hi_free(cmd); sds sds_cmd; @@ -308,7 +308,7 @@ static void test_append_formatted_commands(struct config config) { assert(redisGetReply(c, (void*)&reply) == REDIS_OK); - free(cmd); + hi_free(cmd); freeReplyObject(reply); disconnect(c, 0); @@ -418,6 +418,16 @@ static void test_reply_reader(void) { freeReplyObject(reply); redisReaderFree(reader); + test("Can configure maximum multi-bulk elements: "); + reader = redisReaderCreate(); + reader->maxelements = 1024; + redisReaderFeed(reader, "*1025\r\n", 7); + ret = redisReaderGetReply(reader,&reply); + test_cond(ret == REDIS_ERR && + strcasecmp(reader->errstr, "Multi-bulk length out of range") == 0); + freeReplyObject(reply); + redisReaderFree(reader); + #if LLONG_MAX > SIZE_MAX test("Set error when array > SIZE_MAX: "); reader = redisReaderCreate(); @@ -518,6 +528,46 @@ static void test_free_null(void) { test_cond(reply == NULL); } +static void *hi_malloc_fail(size_t size) { + (void)size; + return NULL; +} + +static void *hi_calloc_fail(size_t nmemb, size_t size) { + (void)nmemb; + (void)size; + return NULL; +} + +static void *hi_realloc_fail(void *ptr, size_t size) { + (void)ptr; + (void)size; + return NULL; +} + +static void test_allocator_injection(void) { + hiredisAllocFuncs ha = { + .malloc = hi_malloc_fail, + .calloc = hi_calloc_fail, + .realloc = hi_realloc_fail, + .free = NULL, + }; + + // Override hiredis allocators + hiredisSetAllocators(&ha); + + test("redisContext uses injected allocators: "); + redisContext *c = redisConnect("localhost", 6379); + test_cond(c == NULL); + + test("redisReader uses injected allocators: "); + redisReader *reader = redisReaderCreate(); + test_cond(reader == NULL); + + // Return allocators to default + hiredisResetAllocators(); +} + #define HIREDIS_BAD_DOMAIN "idontexist-noreally.com" static void test_blocking_connection_errors(void) { redisContext *c; @@ -799,6 +849,18 @@ static void test_invalid_timeout_errors(struct config config) { redisFree(c); } +/* Wrap malloc to abort on failure so OOM checks don't make the test logic + * harder to follow. */ +void *hi_malloc_safe(size_t size) { + void *ptr = hi_malloc(size); + if (ptr == NULL) { + fprintf(stderr, "Error: Out of memory\n"); + exit(-1); + } + + return ptr; +} + static void test_throughput(struct config config) { redisContext *c = do_connect(config); redisReply **replies; @@ -810,7 +872,7 @@ static void test_throughput(struct config config) { freeReplyObject(redisCommand(c,"LPUSH mylist foo")); num = 1000; - replies = malloc(sizeof(redisReply*)*num); + replies = hi_malloc_safe(sizeof(redisReply*)*num); t1 = usec(); for (i = 0; i < num; i++) { replies[i] = redisCommand(c,"PING"); @@ -818,10 +880,10 @@ static void test_throughput(struct config config) { } t2 = usec(); for (i = 0; i < num; i++) freeReplyObject(replies[i]); - free(replies); + hi_free(replies); printf("\t(%dx PING: %.3fs)\n", num, (t2-t1)/1000000.0); - replies = malloc(sizeof(redisReply*)*num); + replies = hi_malloc_safe(sizeof(redisReply*)*num); t1 = usec(); for (i = 0; i < num; i++) { replies[i] = redisCommand(c,"LRANGE mylist 0 499"); @@ -830,10 +892,10 @@ static void test_throughput(struct config config) { } t2 = usec(); for (i = 0; i < num; i++) freeReplyObject(replies[i]); - free(replies); + hi_free(replies); printf("\t(%dx LRANGE with 500 elements: %.3fs)\n", num, (t2-t1)/1000000.0); - replies = malloc(sizeof(redisReply*)*num); + replies = hi_malloc_safe(sizeof(redisReply*)*num); t1 = usec(); for (i = 0; i < num; i++) { replies[i] = redisCommand(c, "INCRBY incrkey %d", 1000000); @@ -841,11 +903,11 @@ static void test_throughput(struct config config) { } t2 = usec(); for (i = 0; i < num; i++) freeReplyObject(replies[i]); - free(replies); + hi_free(replies); printf("\t(%dx INCRBY: %.3fs)\n", num, (t2-t1)/1000000.0); num = 10000; - replies = malloc(sizeof(redisReply*)*num); + replies = hi_malloc_safe(sizeof(redisReply*)*num); for (i = 0; i < num; i++) redisAppendCommand(c,"PING"); t1 = usec(); @@ -855,10 +917,10 @@ static void test_throughput(struct config config) { } t2 = usec(); for (i = 0; i < num; i++) freeReplyObject(replies[i]); - free(replies); + hi_free(replies); printf("\t(%dx PING (pipelined): %.3fs)\n", num, (t2-t1)/1000000.0); - replies = malloc(sizeof(redisReply*)*num); + replies = hi_malloc_safe(sizeof(redisReply*)*num); for (i = 0; i < num; i++) redisAppendCommand(c,"LRANGE mylist 0 499"); t1 = usec(); @@ -869,10 +931,10 @@ static void test_throughput(struct config config) { } t2 = usec(); for (i = 0; i < num; i++) freeReplyObject(replies[i]); - free(replies); + hi_free(replies); printf("\t(%dx LRANGE with 500 elements (pipelined): %.3fs)\n", num, (t2-t1)/1000000.0); - replies = malloc(sizeof(redisReply*)*num); + replies = hi_malloc_safe(sizeof(redisReply*)*num); for (i = 0; i < num; i++) redisAppendCommand(c,"INCRBY incrkey %d", 1000000); t1 = usec(); @@ -882,7 +944,7 @@ static void test_throughput(struct config config) { } t2 = usec(); for (i = 0; i < num; i++) freeReplyObject(replies[i]); - free(replies); + hi_free(replies); printf("\t(%dx INCRBY (pipelined): %.3fs)\n", num, (t2-t1)/1000000.0); disconnect(c, 0); @@ -1049,11 +1111,14 @@ int main(int argc, char **argv) { signal(SIGPIPE, SIG_IGN); test_unix_socket = access(cfg.unix_sock.path, F_OK) == 0; + #else /* Unix sockets don't exist in Windows */ test_unix_socket = 0; #endif + test_allocator_injection(); + test_format_commands(); test_reply_reader(); test_blocking_connection_errors();