Move SSL management to a distinct private pointer. (#855)

We need to allow our users to use redisContext->privdata as context
for any RESP3 PUSH messages, which means we can't use it for managing
SSL connections.

Bulletpoints:

* Create a secondary redisContext member for internal use only called
  privctx and rename the redisContextFuncs->free_privdata accordingly.

* Adds a `free_privdata` function pointer so the user can tie allocated
  memory to the lifetime of a redisContext (like they can already do
  with redisAsyncContext)

* Enables SSL tests in .travis.yml
This commit is contained in:
Michael Grunder 2020-07-29 11:53:03 -07:00 committed by GitHub
parent be32bcdc8e
commit d8ff72387d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 119 additions and 36 deletions

View File

@ -7,6 +7,8 @@ os:
- linux - linux
- osx - osx
dist: bionic
branches: branches:
only: only:
- staging - staging
@ -14,6 +16,13 @@ branches:
- master - master
- /^release\/.*$/ - /^release\/.*$/
install:
- if [ "$BITS" == "64" ]; then
wget https://github.com/redis/redis/archive/6.0.6.tar.gz;
tar -xzvf 6.0.6.tar.gz;
pushd redis-6.0.6 && BUILD_TLS=yes make && export PATH=$PWD/src:$PATH && popd;
fi
before_script: before_script:
- if [ "$TRAVIS_OS_NAME" == "osx" ]; then - if [ "$TRAVIS_OS_NAME" == "osx" ]; then
curl -O https://distfiles.macports.org/MacPorts/MacPorts-2.6.2-10.13-HighSierra.pkg; curl -O https://distfiles.macports.org/MacPorts/MacPorts-2.6.2-10.13-HighSierra.pkg;
@ -45,6 +54,9 @@ env:
script: script:
- EXTRA_CMAKE_OPTS="-DENABLE_EXAMPLES:BOOL=ON -DENABLE_SSL:BOOL=ON"; - EXTRA_CMAKE_OPTS="-DENABLE_EXAMPLES:BOOL=ON -DENABLE_SSL:BOOL=ON";
if [ "$BITS" == "64" ]; then
EXTRA_CMAKE_OPTS="$EXTRA_CMAKE_OPTS -DENABLE_SSL_TESTS:BOOL=ON";
fi;
if [ "$TRAVIS_OS_NAME" == "osx" ]; then if [ "$TRAVIS_OS_NAME" == "osx" ]; then
if [ "$BITS" == "32" ]; then if [ "$BITS" == "32" ]; then
CFLAGS="-m32 -Werror"; CFLAGS="-m32 -Werror";
@ -79,7 +91,11 @@ script:
- mkdir build/ && cd build/ - mkdir build/ && cd build/
- cmake .. ${EXTRA_CMAKE_OPTS} - cmake .. ${EXTRA_CMAKE_OPTS}
- make VERBOSE=1 - make VERBOSE=1
- SKIPS_AS_FAILS=1 ctest -V - if [ "$BITS" == "64" ]; then
TEST_SSL=1 SKIPS_AS_FAILS=1 ctest -V;
else
SKIPS_AS_FAILS=1 ctest -V;
fi;
jobs: jobs:
include: include:

View File

@ -4,6 +4,7 @@ PROJECT(hiredis)
OPTION(ENABLE_SSL "Build hiredis_ssl for SSL support" OFF) OPTION(ENABLE_SSL "Build hiredis_ssl for SSL support" OFF)
OPTION(DISABLE_TESTS "If tests should be compiled or not" OFF) OPTION(DISABLE_TESTS "If tests should be compiled or not" OFF)
OPTION(ENABLE_SSL_TESTS, "Should we test SSL connections" OFF)
MACRO(getVersionBit name) MACRO(getVersionBit name)
SET(VERSION_REGEX "^#define ${name} (.+)$") SET(VERSION_REGEX "^#define ${name} (.+)$")
@ -148,7 +149,12 @@ ENDIF()
IF(NOT DISABLE_TESTS) IF(NOT DISABLE_TESTS)
ENABLE_TESTING() ENABLE_TESTING()
ADD_EXECUTABLE(hiredis-test test.c) ADD_EXECUTABLE(hiredis-test test.c)
IF(ENABLE_SSL_TESTS)
ADD_DEFINITIONS(-DHIREDIS_TEST_SSL=1)
TARGET_LINK_LIBRARIES(hiredis-test hiredis hiredis_ssl)
ELSE()
TARGET_LINK_LIBRARIES(hiredis-test hiredis) TARGET_LINK_LIBRARIES(hiredis-test hiredis)
ENDIF()
ADD_TEST(NAME hiredis-test ADD_TEST(NAME hiredis-test
COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/test.sh) COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/test.sh)
ENDIF() ENDIF()

View File

@ -97,6 +97,13 @@ void pushReplyHandler(void *privdata, void *r) {
freeReplyObject(reply); freeReplyObject(reply);
} }
/* We aren't actually freeing anything here, but it is included to show that we can
* have hiredis call our data destructor when freeing the context */
void privdata_dtor(void *privdata) {
unsigned int *icount = privdata;
printf("privdata_dtor(): In context privdata dtor (invalidations: %u)\n", *icount);
}
int main(int argc, char **argv) { int main(int argc, char **argv) {
unsigned int j, invalidations = 0; unsigned int j, invalidations = 0;
redisContext *c; redisContext *c;
@ -108,6 +115,16 @@ int main(int argc, char **argv) {
redisOptions o = {0}; redisOptions o = {0};
REDIS_OPTIONS_SET_TCP(&o, hostname, port); REDIS_OPTIONS_SET_TCP(&o, hostname, port);
/* Set our context privdata to the address of our invalidation counter. Each
* time our PUSH handler is called, hiredis will pass the privdata for context.
*
* This could also be done after we create the context like so:
*
* c->privdata = &invalidations;
* c->free_privdata = privdata_dtor;
*/
REDIS_OPTIONS_SET_PRIVDATA(&o, &invalidations, privdata_dtor);
/* Set our custom PUSH message handler */ /* Set our custom PUSH message handler */
o.push_cb = pushReplyHandler; o.push_cb = pushReplyHandler;
@ -118,10 +135,6 @@ int main(int argc, char **argv) {
/* Enable RESP3 and turn on client tracking */ /* Enable RESP3 and turn on client tracking */
enableClientTracking(c); enableClientTracking(c);
/* Set our context privdata to the address of our invalidation counter. Each
* time our PUSH handler is called, hiredis will pass the privdata for context */
c->privdata = &invalidations;
/* Set some keys and then read them back. Once we do that, Redis will deliver /* Set some keys and then read them back. Once we do that, Redis will deliver
* invalidation push messages whenever the key is modified */ * invalidation push messages whenever the key is modified */
for (j = 0; j < KEY_COUNT; j++) { for (j = 0; j < KEY_COUNT; j++) {

View File

@ -48,7 +48,7 @@ extern int redisContextUpdateConnectTimeout(redisContext *c, const struct timeva
extern int redisContextUpdateCommandTimeout(redisContext *c, const struct timeval *timeout); extern int redisContextUpdateCommandTimeout(redisContext *c, const struct timeval *timeout);
static redisContextFuncs redisContextDefaultFuncs = { static redisContextFuncs redisContextDefaultFuncs = {
.free_privdata = NULL, .free_privctx = NULL,
.async_read = redisAsyncRead, .async_read = redisAsyncRead,
.async_write = redisAsyncWrite, .async_write = redisAsyncWrite,
.read = redisNetRead, .read = redisNetRead,
@ -688,7 +688,7 @@ static void redisPushAutoFree(void *privdata, void *reply) {
freeReplyObject(reply); freeReplyObject(reply);
} }
static redisContext *redisContextInit(const redisOptions *options) { static redisContext *redisContextInit(void) {
redisContext *c; redisContext *c;
c = hi_calloc(1, sizeof(*c)); c = hi_calloc(1, sizeof(*c));
@ -697,13 +697,6 @@ static redisContext *redisContextInit(const redisOptions *options) {
c->funcs = &redisContextDefaultFuncs; c->funcs = &redisContextDefaultFuncs;
/* Set any user supplied RESP3 PUSH handler or use freeReplyObject
* as a default unless specifically flagged that we don't want one. */
if (options->push_cb != NULL)
redisSetPushCallback(c, options->push_cb);
else if (!(options->options & REDIS_OPT_NO_PUSH_AUTOFREE))
redisSetPushCallback(c, redisPushAutoFree);
c->obuf = sdsempty(); c->obuf = sdsempty();
c->reader = redisReaderCreate(); c->reader = redisReaderCreate();
c->fd = REDIS_INVALID_FD; c->fd = REDIS_INVALID_FD;
@ -712,7 +705,7 @@ static redisContext *redisContextInit(const redisOptions *options) {
redisFree(c); redisFree(c);
return NULL; return NULL;
} }
(void)options; /* options are used in other functions */
return c; return c;
} }
@ -729,9 +722,13 @@ void redisFree(redisContext *c) {
hi_free(c->connect_timeout); hi_free(c->connect_timeout);
hi_free(c->command_timeout); hi_free(c->command_timeout);
hi_free(c->saddr); hi_free(c->saddr);
if (c->funcs->free_privdata) {
c->funcs->free_privdata(c->privdata); if (c->privdata && c->free_privdata)
} c->free_privdata(c->privdata);
if (c->funcs->free_privctx)
c->funcs->free_privctx(c->privctx);
memset(c, 0xff, sizeof(*c)); memset(c, 0xff, sizeof(*c));
hi_free(c); hi_free(c);
} }
@ -747,9 +744,9 @@ int redisReconnect(redisContext *c) {
c->err = 0; c->err = 0;
memset(c->errstr, '\0', strlen(c->errstr)); memset(c->errstr, '\0', strlen(c->errstr));
if (c->privdata && c->funcs->free_privdata) { if (c->privctx && c->funcs->free_privctx) {
c->funcs->free_privdata(c->privdata); c->funcs->free_privctx(c->privctx);
c->privdata = NULL; c->privctx = NULL;
} }
redisNetClose(c); redisNetClose(c);
@ -786,7 +783,7 @@ int redisReconnect(redisContext *c) {
} }
redisContext *redisConnectWithOptions(const redisOptions *options) { redisContext *redisConnectWithOptions(const redisOptions *options) {
redisContext *c = redisContextInit(options); redisContext *c = redisContextInit();
if (c == NULL) { if (c == NULL) {
return NULL; return NULL;
} }
@ -800,6 +797,16 @@ redisContext *redisConnectWithOptions(const redisOptions *options) {
c->flags |= REDIS_NO_AUTO_FREE; c->flags |= REDIS_NO_AUTO_FREE;
} }
/* Set any user supplied RESP3 PUSH handler or use freeReplyObject
* as a default unless specifically flagged that we don't want one. */
if (options->push_cb != NULL)
redisSetPushCallback(c, options->push_cb);
else if (!(options->options & REDIS_OPT_NO_PUSH_AUTOFREE))
redisSetPushCallback(c, redisPushAutoFree);
c->privdata = options->privdata;
c->free_privdata = options->free_privdata;
if (redisContextUpdateConnectTimeout(c, options->connect_timeout) != REDIS_OK || if (redisContextUpdateConnectTimeout(c, options->connect_timeout) != REDIS_OK ||
redisContextUpdateCommandTimeout(c, options->command_timeout) != REDIS_OK) { redisContextUpdateCommandTimeout(c, options->command_timeout) != REDIS_OK) {
__redisSetError(c, REDIS_ERR_OOM, "Out of memory"); __redisSetError(c, REDIS_ERR_OOM, "Out of memory");

View File

@ -196,6 +196,10 @@ typedef struct {
redisFD fd; redisFD fd;
} endpoint; } endpoint;
/* Optional user defined data/destructor */
void *privdata;
void (*free_privdata)(void *);
/* A user defined PUSH message callback */ /* A user defined PUSH message callback */
redisPushFn *push_cb; redisPushFn *push_cb;
redisAsyncPushFn *async_push_cb; redisAsyncPushFn *async_push_cb;
@ -213,8 +217,12 @@ typedef struct {
(opts)->type = REDIS_CONN_UNIX; \ (opts)->type = REDIS_CONN_UNIX; \
(opts)->endpoint.unix_socket = path; (opts)->endpoint.unix_socket = path;
#define REDIS_OPTIONS_SET_PRIVDATA(opts, data, dtor) \
(opts)->privdata = data; \
(opts)->free_privdata = dtor; \
typedef struct redisContextFuncs { typedef struct redisContextFuncs {
void (*free_privdata)(void *); void (*free_privctx)(void *);
void (*async_read)(struct redisAsyncContext *); void (*async_read)(struct redisAsyncContext *);
void (*async_write)(struct redisAsyncContext *); void (*async_write)(struct redisAsyncContext *);
ssize_t (*read)(struct redisContext *, char *, size_t); ssize_t (*read)(struct redisContext *, char *, size_t);
@ -250,8 +258,14 @@ typedef struct redisContext {
struct sockadr *saddr; struct sockadr *saddr;
size_t addrlen; size_t addrlen;
/* Additional private data for hiredis addons such as SSL */ /* Optional data and corresponding destructor users can use to provide
* context to a given redisContext. Not used by hiredis. */
void *privdata; void *privdata;
void (*free_privdata)(void *);
/* Internal context pointer presently used by hiredis to manage
* SSL connections. */
void *privctx;
/* An optional RESP3 PUSH handler */ /* An optional RESP3 PUSH handler */
redisPushFn *push_cb; redisPushFn *push_cb;

22
ssl.c
View File

@ -267,7 +267,7 @@ error:
static int redisSSLConnect(redisContext *c, SSL *ssl) { static int redisSSLConnect(redisContext *c, SSL *ssl) {
if (c->privdata) { if (c->privctx) {
__redisSetError(c, REDIS_ERR_OTHER, "redisContext was already associated"); __redisSetError(c, REDIS_ERR_OTHER, "redisContext was already associated");
return REDIS_ERR; return REDIS_ERR;
} }
@ -288,14 +288,14 @@ static int redisSSLConnect(redisContext *c, SSL *ssl) {
ERR_clear_error(); ERR_clear_error();
int rv = SSL_connect(rssl->ssl); int rv = SSL_connect(rssl->ssl);
if (rv == 1) { if (rv == 1) {
c->privdata = rssl; c->privctx = rssl;
return REDIS_OK; return REDIS_OK;
} }
rv = SSL_get_error(rssl->ssl, rv); rv = SSL_get_error(rssl->ssl, rv);
if (((c->flags & REDIS_BLOCK) == 0) && if (((c->flags & REDIS_BLOCK) == 0) &&
(rv == SSL_ERROR_WANT_READ || rv == SSL_ERROR_WANT_WRITE)) { (rv == SSL_ERROR_WANT_READ || rv == SSL_ERROR_WANT_WRITE)) {
c->privdata = rssl; c->privctx = rssl;
return REDIS_OK; return REDIS_OK;
} }
@ -337,7 +337,7 @@ int redisInitiateSSLWithContext(redisContext *c, redisSSLContext *redis_ssl_ctx)
/* We want to verify that redisSSLConnect() won't fail on this, as it will /* We want to verify that redisSSLConnect() won't fail on this, as it will
* not own the SSL object in that case and we'll end up leaking. * not own the SSL object in that case and we'll end up leaking.
*/ */
if (c->privdata) if (c->privctx)
return REDIS_ERR; return REDIS_ERR;
SSL *ssl = SSL_new(redis_ssl_ctx->ssl_ctx); SSL *ssl = SSL_new(redis_ssl_ctx->ssl_ctx);
@ -381,8 +381,8 @@ static int maybeCheckWant(redisSSL *rssl, int rv) {
* Implementation of redisContextFuncs for SSL connections. * Implementation of redisContextFuncs for SSL connections.
*/ */
static void redisSSLFree(void *privdata){ static void redisSSLFree(void *privctx){
redisSSL *rsc = privdata; redisSSL *rsc = privctx;
if (!rsc) return; if (!rsc) return;
if (rsc->ssl) { if (rsc->ssl) {
@ -393,7 +393,7 @@ static void redisSSLFree(void *privdata){
} }
static ssize_t redisSSLRead(redisContext *c, char *buf, size_t bufcap) { static ssize_t redisSSLRead(redisContext *c, char *buf, size_t bufcap) {
redisSSL *rssl = c->privdata; redisSSL *rssl = c->privctx;
int nread = SSL_read(rssl->ssl, buf, bufcap); int nread = SSL_read(rssl->ssl, buf, bufcap);
if (nread > 0) { if (nread > 0) {
@ -435,7 +435,7 @@ static ssize_t redisSSLRead(redisContext *c, char *buf, size_t bufcap) {
} }
static ssize_t redisSSLWrite(redisContext *c) { static ssize_t redisSSLWrite(redisContext *c) {
redisSSL *rssl = c->privdata; redisSSL *rssl = c->privctx;
size_t len = rssl->lastLen ? rssl->lastLen : sdslen(c->obuf); size_t len = rssl->lastLen ? rssl->lastLen : sdslen(c->obuf);
int rv = SSL_write(rssl->ssl, c->obuf, len); int rv = SSL_write(rssl->ssl, c->obuf, len);
@ -458,7 +458,7 @@ static ssize_t redisSSLWrite(redisContext *c) {
static void redisSSLAsyncRead(redisAsyncContext *ac) { static void redisSSLAsyncRead(redisAsyncContext *ac) {
int rv; int rv;
redisSSL *rssl = ac->c.privdata; redisSSL *rssl = ac->c.privctx;
redisContext *c = &ac->c; redisContext *c = &ac->c;
rssl->wantRead = 0; rssl->wantRead = 0;
@ -488,7 +488,7 @@ static void redisSSLAsyncRead(redisAsyncContext *ac) {
static void redisSSLAsyncWrite(redisAsyncContext *ac) { static void redisSSLAsyncWrite(redisAsyncContext *ac) {
int rv, done = 0; int rv, done = 0;
redisSSL *rssl = ac->c.privdata; redisSSL *rssl = ac->c.privctx;
redisContext *c = &ac->c; redisContext *c = &ac->c;
rssl->pendingWrite = 0; rssl->pendingWrite = 0;
@ -517,7 +517,7 @@ static void redisSSLAsyncWrite(redisAsyncContext *ac) {
} }
redisContextFuncs redisContextSSLFuncs = { redisContextFuncs redisContextSSLFuncs = {
.free_privdata = redisSSLFree, .free_privctx = redisSSLFree,
.async_read = redisSSLAsyncRead, .async_read = redisSSLAsyncRead,
.async_write = redisSSLAsyncWrite, .async_write = redisSSLAsyncWrite,
.read = redisSSLRead, .read = redisSSLRead,

27
test.c
View File

@ -49,6 +49,10 @@ struct config {
} ssl; } ssl;
}; };
struct privdata {
int dtor_counter;
};
#ifdef HIREDIS_TEST_SSL #ifdef HIREDIS_TEST_SSL
redisSSLContext *_ssl_ctx = NULL; redisSSLContext *_ssl_ctx = NULL;
#endif #endif
@ -786,6 +790,27 @@ static void test_resp3_push_options(struct config config) {
redisAsyncFree(ac); redisAsyncFree(ac);
} }
void free_privdata(void *privdata) {
struct privdata *data = privdata;
data->dtor_counter++;
}
static void test_privdata_hooks(struct config config) {
struct privdata data = {0};
redisOptions options;
redisContext *c;
test("We can use redisOptions to set privdata: ");
options = get_redis_tcp_options(config);
REDIS_OPTIONS_SET_PRIVDATA(&options, &data, free_privdata);
assert((c = redisConnectWithOptions(&options)) != NULL);
test_cond(c->privdata == &data);
test("Our privdata destructor fires when we free the context: ");
redisFree(c);
test_cond(data.dtor_counter == 1);
}
static void test_blocking_connection(struct config config) { static void test_blocking_connection(struct config config) {
redisContext *c; redisContext *c;
redisReply *reply; redisReply *reply;
@ -871,6 +896,8 @@ static void test_blocking_connection(struct config config) {
if (major >= 6) test_resp3_push_handler(c); if (major >= 6) test_resp3_push_handler(c);
test_resp3_push_options(config); test_resp3_push_options(config);
test_privdata_hooks(config);
disconnect(c, 0); disconnect(c, 0);
} }