diff --git a/example.c b/example.c index c3f1473..96283c6 100644 --- a/example.c +++ b/example.c @@ -5,13 +5,13 @@ #include "hiredis.h" int main(void) { - int fd; + redisContext *fd; unsigned int j; redisReply *reply; - reply = redisConnect(&fd, "127.0.0.1", 6379); - if (reply != NULL) { - printf("Connection error: %s", reply->reply); + fd = redisConnect((char*)"127.0.0.1", 6379, NULL); + if (fd->error != NULL) { + printf("Connection error: %s", ((redisReply*)fd->error)->reply); exit(1); } diff --git a/hiredis.c b/hiredis.c index 6ab8063..6dc3e1d 100644 --- a/hiredis.c +++ b/hiredis.c @@ -33,6 +33,7 @@ #include #include #include +#include #include "hiredis.h" #include "anet.h" @@ -50,9 +51,8 @@ typedef struct redisReader { unsigned int rpos; /* list cursor */ } redisReader; -static redisReply *redisReadReply(int fd); static redisReply *createReplyObject(int type, sds reply); -static void *createErrorObject(redisReader *context, const char *fmt, ...); +static void *createErrorObject(const char *str, size_t len); static void *createStringObject(redisReadTask *task, char *str, size_t len); static void *createArrayObject(redisReadTask *task, int elements); static void *createIntegerObject(redisReadTask *task, long long value); @@ -61,6 +61,7 @@ static void redisSetReplyReaderError(redisReader *r, void *obj); /* Default set of functions to build the reply. */ static redisReplyFunctions defaultFunctions = { + createErrorObject, createStringObject, createArrayObject, createIntegerObject, @@ -74,20 +75,6 @@ static void redisOOM(void) { exit(1); } -/* Connect to a Redis instance. On success NULL is returned and *fd is set - * to the socket file descriptor. On error a redisReply object is returned - * with reply->type set to REDIS_REPLY_ERROR and reply->string containing - * the error message. This replyObject must be freed with redisFreeReply(). */ -redisReply *redisConnect(int *fd, const char *ip, int port) { - char err[ANET_ERR_LEN]; - - *fd = anetTcpConnect(err,ip,port); - if (*fd == ANET_ERR) - return (redisReply*)createErrorObject(NULL,err); - anetTcpNoDelay(NULL,*fd); - return NULL; -} - /* Create a reply object */ static redisReply *createReplyObject(int type, sds reply) { redisReply *r = calloc(sizeof(*r),1); @@ -119,28 +106,27 @@ void freeReplyObject(void *reply) { free(r); } -static void *createErrorObject(redisReader *context, const char *fmt, ...) { +/* Helper function that allows printf-like creation of error objects. */ +static void *formatError(redisReplyFunctions *fn, const char *fmt, ...) { va_list ap; sds err; void *obj; - redisReadTask t = { REDIS_PROTOCOL_ERROR, NULL, -1 }; va_start(ap,fmt); err = sdscatvprintf(sdsempty(),fmt,ap); va_end(ap); - - /* Use the context of the reader if it is provided. */ - if (context) - obj = context->fn->createString(&t,err,sdslen(err)); - else - obj = createStringObject(&t,err,sdslen(err)); + obj = fn->createError(err,sdslen(err)); sdsfree(err); return obj; } +static void *createErrorObject(const char *str, size_t len) { + redisReply *r = createReplyObject(REDIS_ERROR,sdsnewlen(str,len)); + return r; +} + static void *createStringObject(redisReadTask *task, char *str, size_t len) { redisReply *r = createReplyObject(task->type,sdsnewlen(str,len)); - assert(task->type == REDIS_PROTOCOL_ERROR || - task->type == REDIS_REPLY_ERROR || + assert(task->type == REDIS_REPLY_ERROR || task->type == REDIS_REPLY_STATUS || task->type == REDIS_REPLY_STRING); @@ -322,6 +308,7 @@ static int processItem(redisReader *r) { redisReadTask *cur = &(r->rlist[r->rpos]); char *p; sds byte; + void *err; /* check if we need to read type */ if (cur->type < 0) { @@ -344,8 +331,9 @@ static int processItem(redisReader *r) { break; default: byte = sdscatrepr(sdsempty(),p,1); - redisSetReplyReaderError(r,createErrorObject(r, - "protocol error, got %s as reply type byte", byte)); + err = formatError(r->fn, + "protocol error, got %s as reply type byte", byte); + redisSetReplyReaderError(r,err); sdsfree(byte); return -1; } @@ -366,33 +354,12 @@ static int processItem(redisReader *r) { case REDIS_REPLY_ARRAY: return processMultiBulkItem(r); default: - redisSetReplyReaderError(r,createErrorObject(r, - "unknown item type '%d'", cur->type)); + err = formatError(r->fn,"unknown item type '%d'", cur->type); + redisSetReplyReaderError(r,err); return -1; } } -#define READ_BUFFER_SIZE 2048 -static redisReply *redisReadReply(int fd) { - void *reader = redisReplyReaderCreate(&defaultFunctions); - redisReply *reply; - char buf[1024]; - int nread; - - do { - if ((nread = read(fd,buf,sizeof(buf))) <= 0) { - reply = createErrorObject(reader,"I/O error"); - break; - } else { - redisReplyReaderFeed(reader,buf,nread); - reply = redisReplyReaderGetReply(reader); - } - } while (reply == NULL); - - redisReplyReaderFree(reader); - return reply; -} - void *redisReplyReaderCreate(redisReplyFunctions *fn) { redisReader *r = calloc(sizeof(redisReader),1); r->fn = fn == NULL ? &defaultFunctions : fn; @@ -592,17 +559,143 @@ static sds redisFormatCommand(const char *format, va_list ap) { sdsfree(argv[j]); } free(argv); + return cmd; } -redisReply *redisCommand(int fd, const char *format, ...) { +static int redisContextConnect(redisContext *c, const char *ip, int port) { + char err[ANET_ERR_LEN]; + if (c->flags & HIREDIS_BLOCK) { + c->fd = anetTcpConnect(err,(char*)ip,port); + } else { + c->fd = anetTcpNonBlockConnect(err,(char*)ip,port); + } + + if (c->fd == ANET_ERR) { + c->error = c->fn->createError(err,strlen(err)); + return HIREDIS_ERR; + } + if (anetTcpNoDelay(err,c->fd) == ANET_ERR) { + c->error = c->fn->createError(err,strlen(err)); + return HIREDIS_ERR; + } + return HIREDIS_OK; +} + +static redisContext *redisContextInit(redisReplyFunctions *fn) { + redisContext *c = malloc(sizeof(*c)); + c->fn = fn == NULL ? &defaultFunctions : fn; + c->obuf = sdsempty(); + c->error = NULL; + c->reader = redisReplyReaderCreate(fn); + return c; +} + +/* Connect to a Redis instance. On error the field error in the returned + * context will be set to the return value of the error function. + * When no set of reply functions is given, the default set will be used. */ +redisContext *redisConnect(const char *ip, int port, redisReplyFunctions *fn) { + redisContext *c = redisContextInit(fn); + c->flags |= HIREDIS_BLOCK; + redisContextConnect(c,ip,port); + return c; +} + +redisContext *redisConnectNonBlock(const char *ip, int port, redisReplyFunctions *fn) { + redisContext *c = redisContextInit(fn); + c->flags &= ~HIREDIS_BLOCK; + redisContextConnect(c,ip,port); + return c; +} + +/* Use this function to handle a read event on the descriptor. It will try + * and read some bytes from the socket and feed them to the reply parser. + * + * After this function is called, you may use redisContextReadReply to + * see if there is a reply available. */ +int redisBufferRead(redisContext *c) { + char buf[2048]; + int nread = read(c->fd,buf,sizeof(buf)); + if (nread == -1) { + if (errno == EAGAIN) { + /* Try again later */ + } else { + /* Set error in context */ + c->error = formatError(c->fn, + "Error reading from socket: %s", strerror(errno)); + return HIREDIS_ERR; + } + } else if (nread == 0) { + c->error = formatError(c->fn, + "Server closed the connection"); + return HIREDIS_ERR; + } else { + redisReplyReaderFeed(c->reader,buf,nread); + } + return HIREDIS_OK; +} + +void *redisGetReply(redisContext *c) { + return redisReplyReaderGetReply(c->reader); +} + +/* Use this function to try and write the entire output buffer to the + * descriptor. Returns 1 when the entire buffer was written, 0 otherwise. */ +int redisBufferWrite(redisContext *c, int *done) { + int nwritten = write(c->fd,c->obuf,sdslen(c->obuf)); + if (nwritten == -1) { + if (errno == EAGAIN) { + /* Try again later */ + } else { + /* Set error in context */ + c->error = formatError(c->fn, + "Error writing to socket: %s", strerror(errno)); + return HIREDIS_ERR; + } + } else if (nwritten > 0) { + if (nwritten == (signed)sdslen(c->obuf)) { + sdsfree(c->obuf); + c->obuf = sdsempty(); + } else { + c->obuf = sdsrange(c->obuf,nwritten,-1); + } + } + if (done != NULL) *done = (sdslen(c->obuf) == 0); + return HIREDIS_OK; +} + +static void* redisCommandWrite(redisContext *c, char *str, size_t len) { + void *reply = NULL; + int wdone = 0; + c->obuf = sdscatlen(c->obuf,str,len); + + /* Only take action when this is a blocking context. */ + if (c->flags & HIREDIS_BLOCK) { + do { /* Write until done. */ + if (redisBufferWrite(c,&wdone) == HIREDIS_ERR) + return c->error; + } while (!wdone); + + do { /* Read until there is a reply. */ + if (redisBufferRead(c) == HIREDIS_ERR) + return c->error; + reply = redisGetReply(c); + } while (reply == NULL); + } + return reply; +} + +/* Write a formatted command to the output buffer, and, if the context is a + * non-blocking connection, read the reply and return it. When this function + * is called from a blocking context, it will always return NULL. */ +void *redisCommand(redisContext *c, const char *format, ...) { va_list ap; sds cmd; + void *reply; va_start(ap,format); cmd = redisFormatCommand(format,ap); va_end(ap); - /* Send the command via socket */ - anetWrite(fd,cmd,sdslen(cmd)); + reply = redisCommandWrite(c,cmd,sdslen(cmd)); sdsfree(cmd); - return redisReadReply(fd); + return reply; } diff --git a/hiredis.h b/hiredis.h index ea01dee..60615b1 100644 --- a/hiredis.h +++ b/hiredis.h @@ -30,13 +30,20 @@ #ifndef __HIREDIS_H #define __HIREDIS_H +#define HIREDIS_ERR -1 +#define HIREDIS_OK 0 + +/* Connection type can be blocking or non-blocking and is set in the + * least significant bit of the flags field in redisContext. */ +#define HIREDIS_BLOCK 0x1 + +#define REDIS_ERROR -1 #define REDIS_REPLY_ERROR 0 #define REDIS_REPLY_STRING 1 #define REDIS_REPLY_ARRAY 2 #define REDIS_REPLY_INTEGER 3 #define REDIS_REPLY_NIL 4 #define REDIS_REPLY_STATUS 5 -#define REDIS_PROTOCOL_ERROR 6 #include "sds.h" @@ -56,6 +63,7 @@ typedef struct redisReadTask { } redisReadTask; typedef struct redisReplyObjectFunctions { + void *(*createError)(const char*, size_t); void *(*createString)(redisReadTask*, char*, size_t); void *(*createArray)(redisReadTask*, int); void *(*createInteger)(redisReadTask*, long long); @@ -63,13 +71,29 @@ typedef struct redisReplyObjectFunctions { void (*freeObject)(void*); } redisReplyFunctions; -redisReply *redisConnect(int *fd, const char *ip, int port); +/* Context for a connection to Redis */ +typedef struct redisContext { + int fd; + int flags; + char *error; /* error object is set when in erronous state */ + void *reader; /* reply reader */ + sds obuf; /* output buffer */ + redisReplyFunctions *fn; /* functions for reply buildup */ +} redisContext; + void freeReplyObject(void *reply); -redisReply *redisCommand(int fd, const char *format, ...); void *redisReplyReaderCreate(redisReplyFunctions *fn); void *redisReplyReaderGetObject(void *reader); void redisReplyReaderFree(void *ptr); void redisReplyReaderFeed(void *reader, char *buf, int len); void *redisReplyReaderGetReply(void *reader); +redisContext *redisConnect(const char *ip, int port, redisReplyFunctions *fn); +redisContext *redisConnectNonBlock(const char *ip, int port, redisReplyFunctions *fn); +int redisBufferRead(redisContext *c); +int redisBufferWrite(redisContext *c, int *done); +void *redisGetReply(redisContext *c); + +void *redisCommand(redisContext *c, const char *format, ...); + #endif diff --git a/test.c b/test.c index 640cabd..fd63726 100644 --- a/test.c +++ b/test.c @@ -16,16 +16,16 @@ static long long usec(void) { return (((long long)tv.tv_sec)*1000000)+tv.tv_usec; } -static void __connect(int *fd) { - redisReply *reply = redisConnect(fd, "127.0.0.1", 6379); - if (reply != NULL) { - printf("Connection error: %s", reply->reply); +static void __connect(redisContext **c) { + *c = redisConnect((char*)"127.0.0.1", 6379, NULL); + if ((*c)->error != NULL) { + printf("Connection error: %s", ((redisReply*)(*c)->error)->reply); exit(1); } } int main(void) { - int fd; + redisContext *fd; int i, tests = 0, fails = 0; long long t1, t2; redisReply *reply; @@ -34,8 +34,8 @@ int main(void) { test("Returns I/O error when the connection is lost: "); reply = redisCommand(fd,"QUIT"); - test_cond(reply->type == REDIS_PROTOCOL_ERROR && - strcasecmp(reply->reply,"i/o error") == 0); + test_cond(reply->type == REDIS_ERROR && + strcmp(reply->reply,"Server closed the connection") == 0); freeReplyObject(reply); __connect(&fd); /* reconnect */ @@ -127,7 +127,7 @@ int main(void) { reader = redisReplyReaderCreate(NULL); redisReplyReaderFeed(reader,(char*)"@foo\r\n",6); reply = redisReplyReaderGetReply(reader); - test_cond(reply->type == REDIS_PROTOCOL_ERROR && + test_cond(reply->type == REDIS_ERROR && strcasecmp(reply->reply,"protocol error, got \"@\" as reply type byte") == 0); freeReplyObject(reply); redisReplyReaderFree(reader); @@ -140,7 +140,7 @@ int main(void) { redisReplyReaderFeed(reader,(char*)"$5\r\nhello\r\n",11); redisReplyReaderFeed(reader,(char*)"@foo\r\n",6); reply = redisReplyReaderGetReply(reader); - test_cond(reply->type == REDIS_PROTOCOL_ERROR && + test_cond(reply->type == REDIS_ERROR && strcasecmp(reply->reply,"protocol error, got \"@\" as reply type byte") == 0); freeReplyObject(reply); redisReplyReaderFree(reader);