Make command formatters gracefully abort when out of memory

This commit is contained in:
Pieter Noordhuis 2011-04-21 20:59:41 +02:00
parent d4ebb60d65
commit dd5fc26457
1 changed files with 77 additions and 31 deletions

108
hiredis.c
View File

@ -652,59 +652,74 @@ static int intlen(int i) {
return len;
}
/* Helper function for redisvFormatCommand(). */
static void addArgument(sds a, char ***argv, int *argc, int *totlen) {
(*argc)++;
if ((*argv = realloc(*argv, sizeof(char*)*(*argc))) == NULL) redisOOM();
if (totlen) *totlen = *totlen+1+intlen(sdslen(a))+2+sdslen(a)+2;
(*argv)[(*argc)-1] = a;
/* Helper that calculates the bulk length given a certain string length. */
static size_t bulklen(size_t len) {
return 1+intlen(len)+2+len+2;
}
int redisvFormatCommand(char **target, const char *format, va_list ap) {
size_t size;
const char *arg, *c = format;
const char *c = format;
char *cmd = NULL; /* final command */
int pos; /* position in final command */
sds current; /* current argument */
sds curarg, newarg; /* current argument */
int touched = 0; /* was the current argument touched? */
char **argv = NULL;
int argc = 0, j;
char **curargv = NULL, **newargv = NULL;
int argc = 0;
int totlen = 0;
int j;
/* Abort if there is not target to set */
if (target == NULL)
return -1;
/* Build the command string accordingly to protocol */
current = sdsempty();
curarg = sdsempty();
if (curarg == NULL)
return -1;
while(*c != '\0') {
if (*c != '%' || c[1] == '\0') {
if (*c == ' ') {
if (touched) {
addArgument(current, &argv, &argc, &totlen);
current = sdsempty();
newargv = realloc(curargv,sizeof(char*)*(argc+1));
if (newargv == NULL) goto err;
curargv = newargv;
curargv[argc++] = curarg;
totlen += bulklen(sdslen(curarg));
/* curarg is put in argv so it can be overwritten. */
curarg = sdsempty();
if (curarg == NULL) goto err;
touched = 0;
}
} else {
current = sdscatlen(current,c,1);
newarg = sdscatlen(curarg,c,1);
if (newarg == NULL) goto err;
curarg = newarg;
touched = 1;
}
} else {
char *arg;
size_t size;
/* Set newarg so it can be checked even if it is not touched. */
newarg = curarg;
switch(c[1]) {
case 's':
arg = va_arg(ap,char*);
size = strlen(arg);
if (size > 0)
current = sdscatlen(current,arg,size);
newarg = sdscatlen(curarg,arg,size);
break;
case 'b':
arg = va_arg(ap,char*);
size = va_arg(ap,size_t);
if (size > 0)
current = sdscatlen(current,arg,size);
newarg = sdscatlen(curarg,arg,size);
break;
case '%':
current = sdscat(current,"%");
newarg = sdscat(curarg,"%");
break;
default:
/* Try to detect printf format */
@ -746,7 +761,7 @@ int redisvFormatCommand(char **target, const char *format, va_list ap) {
memcpy(_format,c,_l);
_format[_l] = '\0';
va_copy(_cpy,ap);
current = sdscatvprintf(current,_format,_cpy);
newarg = sdscatvprintf(curarg,_format,_cpy);
va_end(_cpy);
/* Update current position (note: outer blocks
@ -759,6 +774,10 @@ int redisvFormatCommand(char **target, const char *format, va_list ap) {
va_arg(ap,void);
}
}
if (newarg == NULL) goto err;
curarg = newarg;
touched = 1;
c++;
}
@ -767,31 +786,55 @@ int redisvFormatCommand(char **target, const char *format, va_list ap) {
/* Add the last argument if needed */
if (touched) {
addArgument(current, &argv, &argc, &totlen);
newargv = realloc(curargv,sizeof(char*)*(argc+1));
if (newargv == NULL) goto err;
curargv = newargv;
curargv[argc++] = curarg;
totlen += bulklen(sdslen(curarg));
} else {
sdsfree(current);
sdsfree(curarg);
}
/* Clear curarg because it was put in curargv or was free'd. */
curarg = NULL;
/* Add bytes needed to hold multi bulk count */
totlen += 1+intlen(argc)+2;
/* Build the command at protocol level */
cmd = malloc(totlen+1);
if (!cmd) redisOOM();
if (cmd == NULL) goto err;
pos = sprintf(cmd,"*%d\r\n",argc);
for (j = 0; j < argc; j++) {
pos += sprintf(cmd+pos,"$%zu\r\n",sdslen(argv[j]));
memcpy(cmd+pos,argv[j],sdslen(argv[j]));
pos += sdslen(argv[j]);
sdsfree(argv[j]);
pos += sprintf(cmd+pos,"$%zu\r\n",sdslen(curargv[j]));
memcpy(cmd+pos,curargv[j],sdslen(curargv[j]));
pos += sdslen(curargv[j]);
sdsfree(curargv[j]);
cmd[pos++] = '\r';
cmd[pos++] = '\n';
}
assert(pos == totlen);
free(argv);
cmd[totlen] = '\0';
cmd[pos] = '\0';
free(curargv);
*target = cmd;
return totlen;
err:
while(argc--)
sdsfree(curargv[argc]);
free(curargv);
if (curarg != NULL)
sdsfree(curarg);
/* No need to check cmd since it is the last statement that can fail,
* but do it anyway to be as defensive as possible. */
if (cmd != NULL)
free(cmd);
return -1;
}
/* Format a command according to the Redis protocol. This function
@ -830,12 +873,14 @@ int redisFormatCommandArgv(char **target, int argc, const char **argv, const siz
totlen = 1+intlen(argc)+2;
for (j = 0; j < argc; j++) {
len = argvlen ? argvlen[j] : strlen(argv[j]);
totlen += 1+intlen(len)+2+len+2;
totlen += bulklen(len);
}
/* Build the command at protocol level */
cmd = malloc(totlen+1);
if (!cmd) redisOOM();
if (cmd == NULL)
return -1;
pos = sprintf(cmd,"*%d\r\n",argc);
for (j = 0; j < argc; j++) {
len = argvlen ? argvlen[j] : strlen(argv[j]);
@ -846,7 +891,8 @@ int redisFormatCommandArgv(char **target, int argc, const char **argv, const siz
cmd[pos++] = '\n';
}
assert(pos == totlen);
cmd[totlen] = '\0';
cmd[pos] = '\0';
*target = cmd;
return totlen;
}