HDFS-10366: libhdfs++: Add SASL authentication. Contributed by Bob Hansen
This commit is contained in:
parent
93382381f6
commit
f1f0b8f0f8
@ -0,0 +1,49 @@
|
||||
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
# - Find Cyrus SASL (sasl.h, libsasl2.so)
|
||||
#
|
||||
# This module defines
|
||||
# CYRUS_SASL_INCLUDE_DIR, directory containing headers
|
||||
# CYRUS_SASL_SHARED_LIB, path to Cyrus SASL's shared library
|
||||
# CYRUS_SASL_FOUND, whether Cyrus SASL and its plugins have been found
|
||||
#
|
||||
# N.B: we do _not_ include sasl in thirdparty, for a fairly subtle reason. The
|
||||
# TLDR version is that newer versions of cyrus-sasl (>=2.1.26) have a bug fix
|
||||
# for https://bugzilla.cyrusimap.org/show_bug.cgi?id=3590, but that bug fix
|
||||
# relied on a change both on the plugin side and on the library side. If you
|
||||
# then try to run the new version of sasl (e.g from our thirdparty tree) with
|
||||
# an older version of a plugin (eg from RHEL6 install), you'll get a SASL_NOMECH
|
||||
# error due to this bug.
|
||||
#
|
||||
# In practice, Cyrus-SASL is so commonly used and generally non-ABI-breaking that
|
||||
# we should be OK to depend on the host installation.
|
||||
|
||||
# Note that this is modified from the version that was copied from our
|
||||
# friends at the Kudu project. The original version implicitly required
|
||||
# the Cyrus SASL. This version will only complain if REQUIRED is added.
|
||||
|
||||
|
||||
find_path(CYRUS_SASL_INCLUDE_DIR sasl/sasl.h)
|
||||
find_library(CYRUS_SASL_SHARED_LIB sasl2)
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(CYRUS_SASL DEFAULT_MSG
|
||||
CYRUS_SASL_SHARED_LIB CYRUS_SASL_INCLUDE_DIR)
|
||||
|
||||
MARK_AS_ADVANCED(CYRUS_SASL_INCLUDE_DIR CYRUS_SASL_SHARED_LIB)
|
@ -0,0 +1,44 @@
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
# - Try to find the GNU sasl library (gsasl)
|
||||
#
|
||||
# Once done this will define
|
||||
#
|
||||
# GSASL_FOUND - System has gnutls
|
||||
# GSASL_INCLUDE_DIR - The gnutls include directory
|
||||
# GSASL_LIBRARIES - The libraries needed to use gnutls
|
||||
# GSASL_DEFINITIONS - Compiler switches required for using gnutls
|
||||
|
||||
|
||||
IF (GSASL_INCLUDE_DIR AND GSASL_LIBRARIES)
|
||||
# in cache already
|
||||
SET(GSasl_FIND_QUIETLY TRUE)
|
||||
ENDIF (GSASL_INCLUDE_DIR AND GSASL_LIBRARIES)
|
||||
|
||||
FIND_PATH(GSASL_INCLUDE_DIR gsasl.h)
|
||||
|
||||
FIND_LIBRARY(GSASL_LIBRARIES gsasl)
|
||||
|
||||
INCLUDE(FindPackageHandleStandardArgs)
|
||||
|
||||
# handle the QUIETLY and REQUIRED arguments and set GSASL_FOUND to TRUE if
|
||||
# all listed variables are TRUE
|
||||
FIND_PACKAGE_HANDLE_STANDARD_ARGS(GSASL DEFAULT_MSG GSASL_LIBRARIES GSASL_INCLUDE_DIR)
|
||||
|
||||
MARK_AS_ADVANCED(GSASL_INCLUDE_DIR GSASL_LIBRARIES)
|
@ -22,10 +22,14 @@ cmake_minimum_required(VERSION 2.8)
|
||||
|
||||
enable_testing()
|
||||
include (CTest)
|
||||
SET(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/CMake" ${CMAKE_MODULE_PATH})
|
||||
SET(CMAKE_DISABLE_FIND_PACKAGE_CyrusSASL TRUE) # Until development is done.
|
||||
|
||||
find_package(Doxygen)
|
||||
find_package(OpenSSL REQUIRED)
|
||||
find_package(Protobuf REQUIRED)
|
||||
find_package(GSasl)
|
||||
find_package(CyrusSASL)
|
||||
find_package(Threads)
|
||||
|
||||
find_program(MEMORYCHECK_COMMAND valgrind HINTS ${VALGRIND_DIR} )
|
||||
@ -37,6 +41,36 @@ if (REQUIRE_VALGRIND AND MEMORYCHECK_COMMAND MATCHES "MEMORYCHECK_COMMAND-NOTFOU
|
||||
"The path can be included via a -DVALGRIND_DIR=... flag passed to CMake.")
|
||||
endif (REQUIRE_VALGRIND AND MEMORYCHECK_COMMAND MATCHES "MEMORYCHECK_COMMAND-NOTFOUND" )
|
||||
|
||||
# Prefer Cyrus SASL, but use GSASL if it is found
|
||||
# Note that the packages can be disabled by setting CMAKE_DISABLE_FIND_PACKAGE_GSasl or
|
||||
# CMAKE_DISABLE_FIND_PACKAGE_CyrusSASL, respectively (case sensitive)
|
||||
set (SASL_LIBRARIES)
|
||||
set (SASL_INCLUDE_DIR)
|
||||
if (CYRUS_SASL_FOUND)
|
||||
message(STATUS "Using Cyrus SASL; link with ${CYRUS_SASL_LIBRARIES}")
|
||||
set (SASL_INCLUDE_DIR ${CYRUS_SASL_INCLUDE_DIR})
|
||||
set (SASL_LIBRARIES ${CYRUS_SASL_SHARED_LIB})
|
||||
add_definitions(-DUSE_SASL -DUSE_CYRUS_SASL)
|
||||
else (CYRUS_SASL_FOUND)
|
||||
if (REQUIRE_CYRUS_SASL)
|
||||
message(FATAL_ERROR "Cyrus SASL was required but not found. "
|
||||
"The path can be included via a -DCYRUS_SASL_DIR=... flag passed to CMake.")
|
||||
endif (REQUIRE_CYRUS_SASL)
|
||||
|
||||
# If we didn't pick Cyrus, use GSASL instead
|
||||
if (GSASL_FOUND)
|
||||
message(STATUS "Using GSASL; link with ${GSASL_LIBRARIES}")
|
||||
set (SASL_INCLUDE_DIR ${GSASL_INCLUDE_DIR})
|
||||
set (SASL_LIBRARIES ${GSASL_LIBRARIES})
|
||||
add_definitions(-DUSE_SASL -DUSE_GSASL)
|
||||
else (GSASL_FOUND)
|
||||
if (REQUIRE_GSASL)
|
||||
message(FATAL_ERROR "GSASL was required but not found. "
|
||||
"The path can be included via a -DGSASL_DIR=... flag passed to CMake.")
|
||||
endif (REQUIRE_GSASL)
|
||||
message(STATUS "Not using SASL")
|
||||
endif (GSASL_FOUND)
|
||||
endif (CYRUS_SASL_FOUND)
|
||||
|
||||
add_definitions(-DASIO_STANDALONE -DASIO_CPP11_DATE_TIME)
|
||||
|
||||
@ -120,6 +154,7 @@ include_directories( SYSTEM
|
||||
third_party/protobuf
|
||||
third_party/uriparser2
|
||||
${OPENSSL_INCLUDE_DIR}
|
||||
${SASL_INCLUDE_DIR}
|
||||
)
|
||||
|
||||
|
||||
@ -146,6 +181,7 @@ if (HADOOP_BUILD)
|
||||
${LIB_DL}
|
||||
${PROTOBUF_LIBRARY}
|
||||
${OPENSSL_LIBRARIES}
|
||||
${SASL_LIBRARIES}
|
||||
)
|
||||
else (HADOOP_BUILD)
|
||||
add_library(hdfspp_static STATIC ${EMPTY_FILE_CC} ${LIBHDFSPP_ALL_OBJECTS})
|
||||
@ -153,12 +189,14 @@ else (HADOOP_BUILD)
|
||||
${LIB_DL}
|
||||
${PROTOBUF_LIBRARY}
|
||||
${OPENSSL_LIBRARIES}
|
||||
${SASL_LIBRARIES}
|
||||
)
|
||||
add_library(hdfspp SHARED ${EMPTY_FILE_CC} ${LIBHDFSPP_ALL_OBJECTS})
|
||||
target_link_libraries(hdfspp_static
|
||||
${LIB_DL}
|
||||
${PROTOBUF_LIBRARY}
|
||||
${OPENSSL_LIBRARIES}
|
||||
${SASL_LIBRARIES}
|
||||
)
|
||||
endif (HADOOP_BUILD)
|
||||
set(LIBHDFSPP_VERSION "0.1.0")
|
||||
|
@ -66,6 +66,7 @@ void parse_uri(const char * uri_string, struct Uri * uri) {
|
||||
|
||||
char * host_port_separator = strstr(authority, ":");
|
||||
if (host_port_separator != NULL) {
|
||||
errno = 0;
|
||||
uri->port = strtol(host_port_separator + 1, NULL, 10);
|
||||
if (errno != 0)
|
||||
return;
|
||||
|
@ -58,6 +58,17 @@ struct Options {
|
||||
*/
|
||||
URI defaultFS;
|
||||
|
||||
/**
|
||||
* Which form of authentication to use with the server
|
||||
* Default: simple
|
||||
*/
|
||||
enum Authentication {
|
||||
kSimple,
|
||||
kKerberos
|
||||
};
|
||||
Authentication authentication;
|
||||
static const Authentication kDefaultAuthentication = kSimple;
|
||||
|
||||
Options();
|
||||
};
|
||||
}
|
||||
|
@ -37,7 +37,8 @@ class Status {
|
||||
static Status Unimplemented();
|
||||
static Status Exception(const char *expception_class_name, const char *error_message);
|
||||
static Status Error(const char *error_message);
|
||||
static Status Canceled();
|
||||
static Status AuthenticationFailed();
|
||||
static Status Canceled();
|
||||
|
||||
// success
|
||||
bool ok() const { return code_ == 0; }
|
||||
@ -55,7 +56,8 @@ class Status {
|
||||
kUnimplemented = static_cast<unsigned>(std::errc::function_not_supported),
|
||||
kOperationCanceled = static_cast<unsigned>(std::errc::operation_canceled),
|
||||
kPermissionDenied = static_cast<unsigned>(std::errc::permission_denied),
|
||||
kException = 255,
|
||||
kException = 256,
|
||||
kAuthenticationFailed = 257,
|
||||
};
|
||||
|
||||
private:
|
||||
|
@ -19,6 +19,6 @@ if(NEED_LINK_DL)
|
||||
set(LIB_DL dl)
|
||||
endif()
|
||||
|
||||
add_library(common_obj OBJECT base64.cc status.cc sasl_digest_md5.cc hdfs_public_api.cc options.cc configuration.cc configuration_loader.cc hdfs_configuration.cc uri.cc util.cc retry_policy.cc cancel_tracker.cc logging.cc libhdfs_events_impl.cc)
|
||||
add_library(common_obj OBJECT base64.cc status.cc sasl_digest_md5.cc hdfs_public_api.cc options.cc configuration.cc configuration_loader.cc hdfs_configuration.cc uri.cc util.cc retry_policy.cc cancel_tracker.cc logging.cc libhdfs_events_impl.cc auth_info.cc)
|
||||
add_library(common $<TARGET_OBJECTS:common_obj> $<TARGET_OBJECTS:uriparser2_obj>)
|
||||
target_link_libraries(common ${LIB_DL})
|
||||
|
@ -0,0 +1,18 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "auth_info.h"
|
@ -0,0 +1,90 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LIB_FS_AUTHINFO_H
|
||||
#define LIB_FS_AUTHINFO_H
|
||||
|
||||
#include <optional.hpp>
|
||||
|
||||
namespace hdfs {
|
||||
|
||||
class Token {
|
||||
public:
|
||||
std::string identifier;
|
||||
std::string password;
|
||||
};
|
||||
|
||||
class AuthInfo {
|
||||
public:
|
||||
enum AuthMethod {
|
||||
kSimple,
|
||||
kKerberos,
|
||||
kToken,
|
||||
kUnknownAuth,
|
||||
kAuthFailed
|
||||
};
|
||||
|
||||
AuthInfo() :
|
||||
method(kSimple) {
|
||||
}
|
||||
|
||||
explicit AuthInfo(AuthMethod mech) :
|
||||
method(mech) {
|
||||
}
|
||||
|
||||
bool useSASL() {
|
||||
return method != kSimple;
|
||||
}
|
||||
|
||||
const std::string & getUser() const {
|
||||
return user;
|
||||
}
|
||||
|
||||
void setUser(const std::string & user) {
|
||||
this->user = user;
|
||||
}
|
||||
|
||||
AuthMethod getMethod() const {
|
||||
return method;
|
||||
}
|
||||
|
||||
void setMethod(AuthMethod method) {
|
||||
this->method = method;
|
||||
}
|
||||
|
||||
const std::experimental::optional<Token> & getToken() const {
|
||||
return token;
|
||||
}
|
||||
|
||||
void setToken(const Token & token) {
|
||||
this->token = token;
|
||||
}
|
||||
|
||||
void clearToken() {
|
||||
this->token = std::experimental::nullopt;
|
||||
}
|
||||
|
||||
private:
|
||||
AuthMethod method;
|
||||
std::string user;
|
||||
std::experimental::optional<Token> token;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif /* RPCAUTHINFO_H */
|
@ -49,6 +49,15 @@ Options HdfsConfiguration::GetOptions() {
|
||||
OptionalSet(result.rpc_retry_delay_ms, GetInt(kIpcClientConnectRetryIntervalKey));
|
||||
OptionalSet(result.defaultFS, GetUri(kFsDefaultFsKey));
|
||||
|
||||
optional<std::string> authentication_value = Get(kHadoopSecurityAuthentication);
|
||||
if (authentication_value ) {
|
||||
std::string fixed_case_value = fixCase(authentication_value.value());
|
||||
if (fixed_case_value == fixCase(kHadoopSecurityAuthentication_kerberos))
|
||||
result.authentication = Options::kKerberos;
|
||||
else
|
||||
result.authentication = Options::kSimple;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -41,6 +41,10 @@ class HdfsConfiguration : public Configuration {
|
||||
static constexpr const char * kDfsClientSocketTimeoutKey = "dfs.client.socket-timeout";
|
||||
static constexpr const char * kIpcClientConnectMaxRetriesKey = "ipc.client.connect.max.retries";
|
||||
static constexpr const char * kIpcClientConnectRetryIntervalKey = "ipc.client.connect.retry.interval";
|
||||
static constexpr const char * kHadoopSecurityAuthentication = "hadoop.security.authentication";
|
||||
static constexpr const char * kHadoopSecurityAuthentication_simple = "simple";
|
||||
static constexpr const char * kHadoopSecurityAuthentication_kerberos = "kerberos";
|
||||
|
||||
private:
|
||||
friend class ConfigurationLoader;
|
||||
|
||||
|
@ -30,6 +30,7 @@ const unsigned int Options::kDefaultHostExclusionDuration;
|
||||
Options::Options() : rpc_timeout(kDefaultRpcTimeout), max_rpc_retries(kDefaultMaxRpcRetries),
|
||||
rpc_retry_delay_ms(kDefaultRpcRetryDelayMs),
|
||||
host_exclusion_duration(kDefaultHostExclusionDuration),
|
||||
defaultFS()
|
||||
defaultFS(),
|
||||
authentication(kDefaultAuthentication)
|
||||
{}
|
||||
}
|
||||
|
@ -25,6 +25,7 @@
|
||||
namespace hdfs {
|
||||
|
||||
const char * kStatusAccessControlException = "org.apache.hadoop.security.AccessControlException";
|
||||
const char * kStatusSaslException = "javax.security.sasl.SaslException";
|
||||
|
||||
Status::Status(int code, const char *msg1) : code_(code) {
|
||||
if(msg1) {
|
||||
@ -63,12 +64,18 @@ Status Status::Unimplemented() {
|
||||
Status Status::Exception(const char *exception_class_name, const char *error_message) {
|
||||
if (exception_class_name && (strcmp(exception_class_name, kStatusAccessControlException) == 0) )
|
||||
return Status(kPermissionDenied, error_message);
|
||||
else if (exception_class_name && (strcmp(exception_class_name, kStatusSaslException) == 0))
|
||||
return AuthenticationFailed();
|
||||
else
|
||||
return Status(kException, exception_class_name, error_message);
|
||||
}
|
||||
|
||||
Status Status::Error(const char *error_message) {
|
||||
return Exception("Exception", error_message);
|
||||
return Status(kAuthenticationFailed, error_message);
|
||||
}
|
||||
|
||||
Status Status::AuthenticationFailed() {
|
||||
return Status(kAuthenticationFailed, "Authentication failed");
|
||||
}
|
||||
|
||||
Status Status::Canceled() {
|
||||
|
@ -17,6 +17,6 @@
|
||||
#
|
||||
|
||||
include_directories(${OPENSSL_INCLUDE_DIRS})
|
||||
add_library(rpc_obj OBJECT rpc_connection.cc rpc_engine.cc)
|
||||
add_library(rpc_obj OBJECT rpc_connection.cc rpc_engine.cc sasl_protocol.cc sasl_engine.cc)
|
||||
add_dependencies(rpc_obj proto)
|
||||
add_library(rpc $<TARGET_OBJECTS:rpc_obj>)
|
||||
|
@ -16,6 +16,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "rpc_engine.h"
|
||||
#include "sasl_protocol.h"
|
||||
|
||||
#include "RpcHeader.pb.h"
|
||||
#include "ProtobufRpcEngine.pb.h"
|
||||
@ -110,22 +111,22 @@ static void SetRequestHeader(LockFreeRpcEngine *engine, int call_id,
|
||||
|
||||
RpcConnection::~RpcConnection() {}
|
||||
|
||||
Request::Request(LockFreeRpcEngine *engine, const std::string &method_name,
|
||||
Request::Request(LockFreeRpcEngine *engine, const std::string &method_name, int call_id,
|
||||
const std::string &request, Handler &&handler)
|
||||
: engine_(engine),
|
||||
method_name_(method_name),
|
||||
call_id_(engine->NextCallId()),
|
||||
call_id_(call_id),
|
||||
timer_(engine->io_service()),
|
||||
handler_(std::move(handler)),
|
||||
retry_count_(engine->retry_policy() ? 0 : kNoRetry) {
|
||||
ConstructPayload(&payload_, &request);
|
||||
}
|
||||
|
||||
Request::Request(LockFreeRpcEngine *engine, const std::string &method_name,
|
||||
Request::Request(LockFreeRpcEngine *engine, const std::string &method_name, int call_id,
|
||||
const pb::MessageLite *request, Handler &&handler)
|
||||
: engine_(engine),
|
||||
method_name_(method_name),
|
||||
call_id_(engine->NextCallId()),
|
||||
call_id_(call_id),
|
||||
timer_(engine->io_service()),
|
||||
handler_(std::move(handler)),
|
||||
retry_count_(engine->retry_policy() ? 0 : kNoRetry) {
|
||||
@ -148,7 +149,12 @@ void Request::GetPacket(std::string *res) const {
|
||||
RequestHeaderProto req_header;
|
||||
SetRequestHeader(engine_, call_id_, method_name_, retry_count_, &rpc_header,
|
||||
&req_header);
|
||||
AddHeadersToPacket(res, {&rpc_header, &req_header}, &payload_);
|
||||
|
||||
// SASL messages don't have a request header
|
||||
if (method_name_ != SASL_METHOD_NAME)
|
||||
AddHeadersToPacket(res, {&rpc_header, &req_header}, &payload_);
|
||||
else
|
||||
AddHeadersToPacket(res, {&rpc_header}, &payload_);
|
||||
}
|
||||
|
||||
void Request::OnResponseArrived(pbio::CodedInputStream *is,
|
||||
@ -171,12 +177,80 @@ void RpcConnection::StartReading() {
|
||||
});
|
||||
}
|
||||
|
||||
void RpcConnection::HandshakeComplete(const Status &s) {
|
||||
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
||||
|
||||
LOG_TRACE(kRPC, << "RpcConnectionImpl::HandshakeComplete called");
|
||||
|
||||
if (s.ok()) {
|
||||
if (connected_ == kConnecting) {
|
||||
auto shared_this = shared_from_this();
|
||||
|
||||
connected_ = kAuthenticating;
|
||||
if (auth_info_.useSASL()) {
|
||||
#ifdef USE_SASL
|
||||
sasl_protocol_ = std::make_shared<SaslProtocol>(cluster_name_, auth_info_, shared_from_this());
|
||||
sasl_protocol_->SetEventHandlers(event_handlers_);
|
||||
sasl_protocol_->authenticate([shared_this, this](
|
||||
const Status & status, const AuthInfo & new_auth_info) {
|
||||
AuthComplete(status, new_auth_info); } );
|
||||
#else
|
||||
AuthComplete_locked(Status::Error("SASL is required, but no SASL library was found"), auth_info_);
|
||||
#endif
|
||||
} else {
|
||||
AuthComplete_locked(Status::OK(), auth_info_);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
CommsError(s);
|
||||
};
|
||||
}
|
||||
|
||||
void RpcConnection::AuthComplete(const Status &s, const AuthInfo & new_auth_info) {
|
||||
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
||||
AuthComplete_locked(s, new_auth_info);
|
||||
}
|
||||
|
||||
void RpcConnection::AuthComplete_locked(const Status &s, const AuthInfo & new_auth_info) {
|
||||
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
|
||||
LOG_TRACE(kRPC, << "RpcConnectionImpl::AuthComplete called");
|
||||
|
||||
// Free the sasl_protocol object
|
||||
sasl_protocol_.reset();
|
||||
|
||||
if (s.ok()) {
|
||||
auth_info_ = new_auth_info;
|
||||
|
||||
auto shared_this = shared_from_this();
|
||||
SendContext([shared_this, this](const Status & s) {
|
||||
ContextComplete(s);
|
||||
});
|
||||
} else {
|
||||
CommsError(s);
|
||||
};
|
||||
}
|
||||
|
||||
void RpcConnection::ContextComplete(const Status &s) {
|
||||
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
||||
|
||||
LOG_TRACE(kRPC, << "RpcConnectionImpl::ContextComplete called");
|
||||
|
||||
if (s.ok()) {
|
||||
if (connected_ == kAuthenticating) {
|
||||
connected_ = kConnected;
|
||||
}
|
||||
FlushPendingRequests();
|
||||
} else {
|
||||
CommsError(s);
|
||||
};
|
||||
}
|
||||
|
||||
void RpcConnection::AsyncFlushPendingRequests() {
|
||||
std::shared_ptr<RpcConnection> shared_this = shared_from_this();
|
||||
io_service().post([shared_this, this]() {
|
||||
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
||||
|
||||
LOG_TRACE(kRPC, << "RpcConnection::AsyncRpc called (connected=" << ToString(connected_) << ")");
|
||||
LOG_TRACE(kRPC, << "RpcConnection::AsyncFlushPendingRequests called (connected=" << ToString(connected_) << ")");
|
||||
|
||||
if (!request_over_the_wire_) {
|
||||
FlushPendingRequests();
|
||||
@ -246,10 +320,22 @@ std::shared_ptr<std::string> RpcConnection::PrepareHandshakePacket() {
|
||||
*
|
||||
* AuthProtocol: 0->none, -33->SASL
|
||||
*/
|
||||
static const char kHandshakeHeader[] = {'h', 'r', 'p', 'c',
|
||||
RpcEngine::kRpcVersion, 0, 0};
|
||||
|
||||
char auth_protocol = auth_info_.useSASL() ? -33 : 0;
|
||||
const char handshake_header[] = {'h', 'r', 'p', 'c',
|
||||
RpcEngine::kRpcVersion, 0, auth_protocol};
|
||||
auto res =
|
||||
std::make_shared<std::string>(kHandshakeHeader, sizeof(kHandshakeHeader));
|
||||
std::make_shared<std::string>(handshake_header, sizeof(handshake_header));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
std::shared_ptr<std::string> RpcConnection::PrepareContextPacket() {
|
||||
// This needs to be send after the SASL handshake, and
|
||||
// after the SASL handshake (if any)
|
||||
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
|
||||
|
||||
auto res = std::make_shared<std::string>();
|
||||
|
||||
RpcRequestHeaderProto h;
|
||||
h.set_rpckind(RPC_PROTOCOL_BUFFER);
|
||||
@ -259,11 +345,12 @@ std::shared_ptr<std::string> RpcConnection::PrepareHandshakePacket() {
|
||||
|
||||
IpcConnectionContextProto handshake;
|
||||
handshake.set_protocol(engine_->protocol_name());
|
||||
const std::string & user_name = engine()->user_name();
|
||||
const std::string & user_name = auth_info_.getUser();
|
||||
if (!user_name.empty()) {
|
||||
*handshake.mutable_userinfo()->mutable_effectiveuser() = user_name;
|
||||
}
|
||||
AddHeadersToPacket(res.get(), {&h, &handshake}, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@ -272,6 +359,14 @@ void RpcConnection::AsyncRpc(
|
||||
std::shared_ptr<::google::protobuf::MessageLite> resp,
|
||||
const RpcCallback &handler) {
|
||||
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
||||
AsyncRpc_locked(method_name, req, resp, handler);
|
||||
}
|
||||
|
||||
void RpcConnection::AsyncRpc_locked(
|
||||
const std::string &method_name, const ::google::protobuf::MessageLite *req,
|
||||
std::shared_ptr<::google::protobuf::MessageLite> resp,
|
||||
const RpcCallback &handler) {
|
||||
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
|
||||
|
||||
auto wrapped_handler =
|
||||
[resp, handler](pbio::CodedInputStream *is, const Status &status) {
|
||||
@ -283,29 +378,21 @@ void RpcConnection::AsyncRpc(
|
||||
handler(status);
|
||||
};
|
||||
|
||||
auto r = std::make_shared<Request>(engine_, method_name, req,
|
||||
int call_id = (method_name != SASL_METHOD_NAME ? engine_->NextCallId() : RpcEngine::kCallIdSasl);
|
||||
auto r = std::make_shared<Request>(engine_, method_name, call_id, req,
|
||||
std::move(wrapped_handler));
|
||||
|
||||
if (connected_ == kDisconnected) {
|
||||
// Oops. The connection failed _just_ before the engine got a chance
|
||||
// to send it. Register it as a failure
|
||||
Status status = Status::ResourceUnavailable("RpcConnection closed before send.");
|
||||
auto r_vector = std::vector<std::shared_ptr<Request> > (1, r);
|
||||
assert(r_vector[0].get() != nullptr);
|
||||
|
||||
engine_->AsyncRpcCommsError(status, shared_from_this(), r_vector);
|
||||
} else {
|
||||
pending_requests_.push_back(r);
|
||||
|
||||
if (connected_ == kConnected) { // Dont flush if we're waiting or handshaking
|
||||
FlushPendingRequests();
|
||||
}
|
||||
}
|
||||
auto r_vector = std::vector<std::shared_ptr<Request> > (1, r);
|
||||
SendRpcRequests(r_vector);
|
||||
}
|
||||
|
||||
void RpcConnection::AsyncRpc(const std::vector<std::shared_ptr<Request> > & requests) {
|
||||
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
||||
LOG_TRACE(kRPC, << "RpcConnection::AsyncRpc[] called; connected=" << ToString(connected_));
|
||||
SendRpcRequests(requests);
|
||||
}
|
||||
|
||||
void RpcConnection::SendRpcRequests(const std::vector<std::shared_ptr<Request> > & requests) {
|
||||
LOG_TRACE(kRPC, << "RpcConnection::SendRpcRequests[] called; connected=" << ToString(connected_));
|
||||
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
|
||||
|
||||
if (connected_ == kDisconnected) {
|
||||
// Oops. The connection failed _just_ before the engine got a chance
|
||||
@ -315,9 +402,12 @@ void RpcConnection::AsyncRpc(const std::vector<std::shared_ptr<Request> > & requ
|
||||
} else {
|
||||
pending_requests_.reserve(pending_requests_.size() + requests.size());
|
||||
for (auto r: requests) {
|
||||
pending_requests_.push_back(r);
|
||||
if (r->method_name() != SASL_METHOD_NAME)
|
||||
pending_requests_.push_back(r);
|
||||
else
|
||||
auth_requests_.push_back(r);
|
||||
}
|
||||
if (connected_ == kConnected) { // Dont flush if we're waiting or handshaking
|
||||
if (connected_ == kConnected || connected_ == kAuthenticating) { // Dont flush if we're waiting or handshaking
|
||||
FlushPendingRequests();
|
||||
}
|
||||
}
|
||||
@ -341,6 +431,9 @@ void RpcConnection::PreEnqueueRequests(
|
||||
void RpcConnection::SetEventHandlers(std::shared_ptr<LibhdfsEvents> event_handlers) {
|
||||
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
||||
event_handlers_ = event_handlers;
|
||||
if (sasl_protocol_) {
|
||||
sasl_protocol_->SetEventHandlers(event_handlers);
|
||||
}
|
||||
}
|
||||
|
||||
void RpcConnection::SetClusterName(std::string cluster_name) {
|
||||
@ -401,6 +494,7 @@ std::string RpcConnection::ToString(ConnectedState connected) {
|
||||
switch(connected) {
|
||||
case kNotYetConnected: return "NotYetConnected";
|
||||
case kConnecting: return "Connecting";
|
||||
case kAuthenticating: return "Authenticating";
|
||||
case kConnected: return "Connected";
|
||||
case kDisconnected: return "Disconnected";
|
||||
default: return "Invalid ConnectedState";
|
||||
|
@ -20,9 +20,11 @@
|
||||
|
||||
#include "rpc_engine.h"
|
||||
|
||||
#include "common/auth_info.h"
|
||||
#include "common/logging.h"
|
||||
#include "common/util.h"
|
||||
#include "common/libhdfs_events_impl.h"
|
||||
#include "sasl_protocol.h"
|
||||
|
||||
#include <asio/connect.hpp>
|
||||
#include <asio/read.hpp>
|
||||
@ -37,10 +39,12 @@ public:
|
||||
virtual ~RpcConnectionImpl() override;
|
||||
|
||||
virtual void Connect(const std::vector<::asio::ip::tcp::endpoint> &server,
|
||||
const AuthInfo & auth_info,
|
||||
RpcCallback &handler);
|
||||
virtual void ConnectAndFlush(
|
||||
const std::vector<::asio::ip::tcp::endpoint> &server) override;
|
||||
virtual void Handshake(RpcCallback &handler) override;
|
||||
virtual void SendHandshake(RpcCallback &handler) override;
|
||||
virtual void SendContext(RpcCallback &handler) override;
|
||||
virtual void Disconnect() override;
|
||||
virtual void OnSendCompleted(const ::asio::error_code &ec,
|
||||
size_t transferred) override;
|
||||
@ -59,7 +63,6 @@ public:
|
||||
NextLayer next_layer_;
|
||||
|
||||
void ConnectComplete(const ::asio::error_code &ec);
|
||||
void HandshakeComplete(const Status &s);
|
||||
};
|
||||
|
||||
template <class NextLayer>
|
||||
@ -84,8 +87,13 @@ RpcConnectionImpl<NextLayer>::~RpcConnectionImpl() {
|
||||
|
||||
template <class NextLayer>
|
||||
void RpcConnectionImpl<NextLayer>::Connect(
|
||||
const std::vector<::asio::ip::tcp::endpoint> &server, RpcCallback &handler) {
|
||||
const std::vector<::asio::ip::tcp::endpoint> &server,
|
||||
const AuthInfo & auth_info,
|
||||
RpcCallback &handler) {
|
||||
LOG_TRACE(kRPC, << "RpcConnectionImpl::Connect called");
|
||||
|
||||
this->auth_info_ = auth_info;
|
||||
|
||||
auto connectionSuccessfulReq = std::make_shared<Request>(
|
||||
engine_, [handler](::google::protobuf::io::CodedInputStream *is,
|
||||
const Status &status) {
|
||||
@ -147,7 +155,7 @@ void RpcConnectionImpl<NextLayer>::ConnectComplete(const ::asio::error_code &ec)
|
||||
|
||||
if (status.ok()) {
|
||||
StartReading();
|
||||
Handshake([shared_this, this](const Status & s) {
|
||||
SendHandshake([shared_this, this](const Status & s) {
|
||||
HandshakeComplete(s);
|
||||
});
|
||||
} else {
|
||||
@ -172,24 +180,10 @@ void RpcConnectionImpl<NextLayer>::ConnectComplete(const ::asio::error_code &ec)
|
||||
}
|
||||
|
||||
template <class NextLayer>
|
||||
void RpcConnectionImpl<NextLayer>::HandshakeComplete(const Status &s) {
|
||||
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
|
||||
|
||||
LOG_TRACE(kRPC, << "RpcConnectionImpl::HandshakeComplete called");
|
||||
|
||||
if (s.ok()) {
|
||||
FlushPendingRequests();
|
||||
} else {
|
||||
CommsError(s);
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
template <class NextLayer>
|
||||
void RpcConnectionImpl<NextLayer>::Handshake(RpcCallback &handler) {
|
||||
void RpcConnectionImpl<NextLayer>::SendHandshake(RpcCallback &handler) {
|
||||
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
|
||||
|
||||
LOG_TRACE(kRPC, << "RpcConnectionImpl::Handshake called");
|
||||
LOG_TRACE(kRPC, << "RpcConnectionImpl::SendHandshake called");
|
||||
|
||||
auto shared_this = shared_from_this();
|
||||
auto handshake_packet = PrepareHandshakePacket();
|
||||
@ -197,9 +191,22 @@ void RpcConnectionImpl<NextLayer>::Handshake(RpcCallback &handler) {
|
||||
[handshake_packet, handler, shared_this, this](
|
||||
const ::asio::error_code &ec, size_t) {
|
||||
Status status = ToStatus(ec);
|
||||
if (status.ok() && connected_ == kConnecting) {
|
||||
connected_ = kConnected;
|
||||
}
|
||||
handler(status);
|
||||
});
|
||||
}
|
||||
|
||||
template <class NextLayer>
|
||||
void RpcConnectionImpl<NextLayer>::SendContext(RpcCallback &handler) {
|
||||
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
|
||||
|
||||
LOG_TRACE(kRPC, << "RpcConnectionImpl::SendContext called");
|
||||
|
||||
auto shared_this = shared_from_this();
|
||||
auto context_packet = PrepareContextPacket();
|
||||
::asio::async_write(next_layer_, asio::buffer(*context_packet),
|
||||
[context_packet, handler, shared_this, this](
|
||||
const ::asio::error_code &ec, size_t) {
|
||||
Status status = ToStatus(ec);
|
||||
handler(status);
|
||||
});
|
||||
}
|
||||
@ -232,29 +239,43 @@ void RpcConnectionImpl<NextLayer>::FlushPendingRequests() {
|
||||
|
||||
LOG_TRACE(kRPC, << "RpcConnectionImpl::FlushPendingRequests called");
|
||||
|
||||
if (pending_requests_.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (connected_ == kDisconnected) {
|
||||
LOG_WARN(kRPC, << "RpcConnectionImpl::FlushPendingRequests attempted to flush a disconnected connection");
|
||||
return;
|
||||
}
|
||||
if (connected_ != kConnected) {
|
||||
LOG_DEBUG(kRPC, << "RpcConnectionImpl::FlushPendingRequests attempted to flush a " << ToString(connected_) << " connection");
|
||||
}
|
||||
|
||||
// Don't send if we don't need to
|
||||
if (request_over_the_wire_) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::shared_ptr<Request> req = pending_requests_.front();
|
||||
auto weak_req = std::weak_ptr<Request>(req);
|
||||
pending_requests_.erase(pending_requests_.begin());
|
||||
std::shared_ptr<Request> req;
|
||||
switch (connected_) {
|
||||
case kNotYetConnected:
|
||||
return;
|
||||
case kConnecting:
|
||||
return;
|
||||
case kAuthenticating:
|
||||
if (auth_requests_.empty()) {
|
||||
return;
|
||||
}
|
||||
req = auth_requests_.front();
|
||||
auth_requests_.erase(auth_requests_.begin());
|
||||
break;
|
||||
case kConnected:
|
||||
if (pending_requests_.empty()) {
|
||||
return;
|
||||
}
|
||||
req = pending_requests_.front();
|
||||
pending_requests_.erase(pending_requests_.begin());
|
||||
break;
|
||||
case kDisconnected:
|
||||
LOG_DEBUG(kRPC, << "RpcConnectionImpl::FlushPendingRequests attempted to flush a " << ToString(connected_) << " connection");
|
||||
return;
|
||||
default:
|
||||
LOG_DEBUG(kRPC, << "RpcConnectionImpl::FlushPendingRequests invalid state: " << ToString(connected_));
|
||||
return;
|
||||
}
|
||||
|
||||
std::shared_ptr<RpcConnection> shared_this = shared_from_this();
|
||||
auto weak_this = std::weak_ptr<RpcConnection>(shared_this);
|
||||
auto weak_req = std::weak_ptr<Request>(req);
|
||||
|
||||
std::shared_ptr<std::string> payload = std::make_shared<std::string>();
|
||||
req->GetPacket(payload.get());
|
||||
if (!payload->empty()) {
|
||||
@ -322,31 +343,31 @@ void RpcConnectionImpl<NextLayer>::OnRecvCompleted(const ::asio::error_code &asi
|
||||
return;
|
||||
}
|
||||
|
||||
if (!response_) { /* start a new one */
|
||||
response_ = std::make_shared<Response>();
|
||||
if (!current_response_state_) { /* start a new one */
|
||||
current_response_state_ = std::make_shared<Response>();
|
||||
}
|
||||
|
||||
if (response_->state_ == Response::kReadLength) {
|
||||
response_->state_ = Response::kReadContent;
|
||||
auto buf = ::asio::buffer(reinterpret_cast<char *>(&response_->length_),
|
||||
sizeof(response_->length_));
|
||||
if (current_response_state_->state_ == Response::kReadLength) {
|
||||
current_response_state_->state_ = Response::kReadContent;
|
||||
auto buf = ::asio::buffer(reinterpret_cast<char *>(¤t_response_state_->length_),
|
||||
sizeof(current_response_state_->length_));
|
||||
asio::async_read(
|
||||
next_layer_, buf,
|
||||
[shared_this, this](const ::asio::error_code &ec, size_t size) {
|
||||
OnRecvCompleted(ec, size);
|
||||
});
|
||||
} else if (response_->state_ == Response::kReadContent) {
|
||||
response_->state_ = Response::kParseResponse;
|
||||
response_->length_ = ntohl(response_->length_);
|
||||
response_->data_.resize(response_->length_);
|
||||
} else if (current_response_state_->state_ == Response::kReadContent) {
|
||||
current_response_state_->state_ = Response::kParseResponse;
|
||||
current_response_state_->length_ = ntohl(current_response_state_->length_);
|
||||
current_response_state_->data_.resize(current_response_state_->length_);
|
||||
asio::async_read(
|
||||
next_layer_, ::asio::buffer(response_->data_),
|
||||
next_layer_, ::asio::buffer(current_response_state_->data_),
|
||||
[shared_this, this](const ::asio::error_code &ec, size_t size) {
|
||||
OnRecvCompleted(ec, size);
|
||||
});
|
||||
} else if (response_->state_ == Response::kParseResponse) {
|
||||
HandleRpcResponse(response_);
|
||||
response_ = nullptr;
|
||||
} else if (current_response_state_->state_ == Response::kParseResponse) {
|
||||
HandleRpcResponse(current_response_state_);
|
||||
current_response_state_ = nullptr;
|
||||
StartReading();
|
||||
}
|
||||
}
|
||||
@ -358,7 +379,7 @@ void RpcConnectionImpl<NextLayer>::Disconnect() {
|
||||
LOG_INFO(kRPC, << "RpcConnectionImpl::Disconnect called");
|
||||
|
||||
request_over_the_wire_.reset();
|
||||
if (connected_ == kConnecting || connected_ == kConnected) {
|
||||
if (connected_ == kConnecting || connected_ == kAuthenticating || connected_ == kConnected) {
|
||||
// Don't print out errors, we were expecting a disconnect here
|
||||
SafeDisconnect(get_asio_socket_ptr(&next_layer_));
|
||||
}
|
||||
|
@ -34,13 +34,18 @@ RpcEngine::RpcEngine(::asio::io_service *io_service, const Options &options,
|
||||
: io_service_(io_service),
|
||||
options_(options),
|
||||
client_name_(client_name),
|
||||
user_name_(user_name),
|
||||
protocol_name_(protocol_name),
|
||||
protocol_version_(protocol_version),
|
||||
retry_policy_(std::move(MakeRetryPolicy(options))),
|
||||
call_id_(0),
|
||||
retry_timer(*io_service),
|
||||
event_handlers_(std::make_shared<LibhdfsEvents>()) {
|
||||
|
||||
auth_info_.setUser(user_name);
|
||||
if (options.authentication == Options::kKerberos) {
|
||||
auth_info_.setMethod(AuthInfo::kKerberos);
|
||||
}
|
||||
|
||||
LOG_DEBUG(kRPC, << "RpcEngine::RpcEngine called");
|
||||
}
|
||||
|
||||
@ -54,7 +59,7 @@ void RpcEngine::Connect(const std::string &cluster_name,
|
||||
cluster_name_ = cluster_name;
|
||||
|
||||
conn_ = InitializeConnection();
|
||||
conn_->Connect(last_endpoints_, handler);
|
||||
conn_->Connect(last_endpoints_, auth_info_, handler);
|
||||
}
|
||||
|
||||
void RpcEngine::Shutdown() {
|
||||
@ -121,6 +126,7 @@ std::shared_ptr<RpcConnection> RpcEngine::InitializeConnection()
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
void RpcEngine::AsyncRpcCommsError(
|
||||
const Status &status,
|
||||
std::shared_ptr<RpcConnection> failedConnection,
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include "hdfspp/options.h"
|
||||
#include "hdfspp/status.h"
|
||||
|
||||
#include "common/auth_info.h"
|
||||
#include "common/retry_policy.h"
|
||||
#include "common/libhdfs_events_impl.h"
|
||||
#include "common/new_delete.h"
|
||||
@ -56,6 +57,7 @@ typedef const std::function<void(const Status &)> RpcCallback;
|
||||
|
||||
class LockFreeRpcEngine;
|
||||
class RpcConnection;
|
||||
class SaslProtocol;
|
||||
|
||||
/*
|
||||
* Internal bookkeeping for an outstanding request from the consumer.
|
||||
@ -69,9 +71,9 @@ class Request {
|
||||
typedef std::function<void(::google::protobuf::io::CodedInputStream *is,
|
||||
const Status &status)> Handler;
|
||||
|
||||
Request(LockFreeRpcEngine *engine, const std::string &method_name,
|
||||
Request(LockFreeRpcEngine *engine, const std::string &method_name, int call_id,
|
||||
const std::string &request, Handler &&callback);
|
||||
Request(LockFreeRpcEngine *engine, const std::string &method_name,
|
||||
Request(LockFreeRpcEngine *engine, const std::string &method_name, int call_id,
|
||||
const ::google::protobuf::MessageLite *request, Handler &&callback);
|
||||
|
||||
// Null request (with no actual message) used to track the state of an
|
||||
@ -79,6 +81,7 @@ class Request {
|
||||
Request(LockFreeRpcEngine *engine, Handler &&handler);
|
||||
|
||||
int call_id() const { return call_id_; }
|
||||
std::string method_name() const { return method_name_; }
|
||||
::asio::deadline_timer &timer() { return timer_; }
|
||||
int IncrementRetryCount() { return retry_count_++; }
|
||||
void GetPacket(std::string *res) const;
|
||||
@ -117,9 +120,9 @@ class RpcConnection : public std::enable_shared_from_this<RpcConnection> {
|
||||
// Note that a single server can have multiple endpoints - especially both
|
||||
// an ipv4 and ipv6 endpoint
|
||||
virtual void Connect(const std::vector<::asio::ip::tcp::endpoint> &server,
|
||||
const AuthInfo & auth_info,
|
||||
RpcCallback &handler) = 0;
|
||||
virtual void ConnectAndFlush(const std::vector<::asio::ip::tcp::endpoint> &server) = 0;
|
||||
virtual void Handshake(RpcCallback &handler) = 0;
|
||||
virtual void Disconnect() = 0;
|
||||
|
||||
void StartReading();
|
||||
@ -157,18 +160,33 @@ class RpcConnection : public std::enable_shared_from_this<RpcConnection> {
|
||||
};
|
||||
|
||||
|
||||
LockFreeRpcEngine *const engine_;
|
||||
// Initial handshaking protocol: connect->handshake-->(auth)?-->context->connected
|
||||
virtual void SendHandshake(RpcCallback &handler) = 0;
|
||||
void HandshakeComplete(const Status &s);
|
||||
void AuthComplete(const Status &s, const AuthInfo & new_auth_info);
|
||||
void AuthComplete_locked(const Status &s, const AuthInfo & new_auth_info);
|
||||
virtual void SendContext(RpcCallback &handler) = 0;
|
||||
void ContextComplete(const Status &s);
|
||||
|
||||
|
||||
virtual void OnSendCompleted(const ::asio::error_code &ec,
|
||||
size_t transferred) = 0;
|
||||
virtual void OnRecvCompleted(const ::asio::error_code &ec,
|
||||
size_t transferred) = 0;
|
||||
virtual void FlushPendingRequests()=0; // Synchronously write the next request
|
||||
|
||||
void AsyncRpc_locked(
|
||||
const std::string &method_name,
|
||||
const ::google::protobuf::MessageLite *req,
|
||||
std::shared_ptr<::google::protobuf::MessageLite> resp,
|
||||
const RpcCallback &handler);
|
||||
void SendRpcRequests(const std::vector<std::shared_ptr<Request> > & requests);
|
||||
void AsyncFlushPendingRequests(); // Queue requests to be flushed at a later time
|
||||
|
||||
|
||||
|
||||
std::shared_ptr<std::string> PrepareHandshakePacket();
|
||||
std::shared_ptr<std::string> PrepareContextPacket();
|
||||
static std::string SerializeRpcRequest(
|
||||
const std::string &method_name,
|
||||
const ::google::protobuf::MessageLite *req);
|
||||
@ -180,23 +198,31 @@ class RpcConnection : public std::enable_shared_from_this<RpcConnection> {
|
||||
void ClearAndDisconnect(const ::asio::error_code &ec);
|
||||
std::shared_ptr<Request> RemoveFromRunningQueue(int call_id);
|
||||
|
||||
std::shared_ptr<Response> response_;
|
||||
LockFreeRpcEngine *const engine_;
|
||||
std::shared_ptr<Response> current_response_state_;
|
||||
AuthInfo auth_info_;
|
||||
|
||||
// Connection can have deferred connection, especially when we're pausing
|
||||
// during retry
|
||||
enum ConnectedState {
|
||||
kNotYetConnected,
|
||||
kConnecting,
|
||||
kAuthenticating,
|
||||
kConnected,
|
||||
kDisconnected
|
||||
};
|
||||
static std::string ToString(ConnectedState connected);
|
||||
|
||||
ConnectedState connected_;
|
||||
|
||||
// State machine for performing a SASL handshake
|
||||
std::shared_ptr<SaslProtocol> sasl_protocol_;
|
||||
// The request being sent over the wire; will also be in requests_on_fly_
|
||||
std::shared_ptr<Request> request_over_the_wire_;
|
||||
// Requests to be sent over the wire
|
||||
std::vector<std::shared_ptr<Request>> pending_requests_;
|
||||
// Requests to be sent over the wire during authentication; not retried if
|
||||
// there is a connection error
|
||||
std::vector<std::shared_ptr<Request>> auth_requests_;
|
||||
// Requests that are waiting for responses
|
||||
typedef std::unordered_map<int, std::shared_ptr<Request>> RequestOnFlyMap;
|
||||
RequestOnFlyMap requests_on_fly_;
|
||||
@ -206,6 +232,8 @@ class RpcConnection : public std::enable_shared_from_this<RpcConnection> {
|
||||
|
||||
// Lock for mutable parts of this class that need to be thread safe
|
||||
std::mutex connection_state_lock_;
|
||||
|
||||
friend class SaslProtocol;
|
||||
};
|
||||
|
||||
|
||||
@ -248,7 +276,8 @@ class RpcEngine : public LockFreeRpcEngine {
|
||||
kCallIdAuthorizationFailed = -1,
|
||||
kCallIdInvalid = -2,
|
||||
kCallIdConnectionContext = -3,
|
||||
kCallIdPing = -4
|
||||
kCallIdPing = -4,
|
||||
kCallIdSasl = -33
|
||||
};
|
||||
|
||||
RpcEngine(::asio::io_service *io_service, const Options &options,
|
||||
@ -286,7 +315,7 @@ class RpcEngine : public LockFreeRpcEngine {
|
||||
void TEST_SetRpcConnection(std::shared_ptr<RpcConnection> conn);
|
||||
|
||||
const std::string &client_name() const override { return client_name_; }
|
||||
const std::string &user_name() const override { return user_name_; }
|
||||
const std::string &user_name() const override { return auth_info_.getUser(); }
|
||||
const std::string &protocol_name() const override { return protocol_name_; }
|
||||
int protocol_version() const override { return protocol_version_; }
|
||||
::asio::io_service &io_service() override { return *io_service_; }
|
||||
@ -307,10 +336,10 @@ private:
|
||||
::asio::io_service * const io_service_;
|
||||
const Options options_;
|
||||
const std::string client_name_;
|
||||
const std::string user_name_;
|
||||
const std::string protocol_name_;
|
||||
const int protocol_version_;
|
||||
const std::unique_ptr<const RetryPolicy> retry_policy_; //null --> no retry
|
||||
AuthInfo auth_info_;
|
||||
std::string cluster_name_;
|
||||
std::atomic_int call_id_;
|
||||
::asio::deadline_timer retry_timer;
|
||||
|
@ -0,0 +1,185 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "sasl_engine.h"
|
||||
|
||||
#include "common/logging.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
|
||||
namespace hdfs {
|
||||
|
||||
/*****************************************************************************
|
||||
* BASE CLASS
|
||||
*/
|
||||
|
||||
SaslEngine::State SaslEngine::getState()
|
||||
{
|
||||
return state_;
|
||||
}
|
||||
|
||||
SaslEngine::~SaslEngine() {
|
||||
}
|
||||
|
||||
Status SaslEngine::setKerberosInfo(const std::string &principal)
|
||||
{
|
||||
principal_ = principal;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SaslEngine::setPasswordInfo(const std::string &id,
|
||||
const std::string &password)
|
||||
{
|
||||
id_ = id;
|
||||
password_ = password;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
#ifdef USE_GSASL
|
||||
/*****************************************************************************
|
||||
* GSASL
|
||||
*/
|
||||
|
||||
#include <gsasl.h>
|
||||
|
||||
|
||||
/*****************************************************************************
|
||||
* UTILITY FUNCTIONS
|
||||
*/
|
||||
|
||||
Status gsasl_rc_to_status(int rc)
|
||||
{
|
||||
if (rc == GSASL_OK) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
std::ostringstream ss;
|
||||
ss << "Cannot initialize client (" << rc << "): " << gsasl_strerror(rc);
|
||||
return Status::Error(ss.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
GSaslEngine::~GSaslEngine()
|
||||
{
|
||||
if (session_ != nullptr) {
|
||||
gsasl_finish(session_);
|
||||
}
|
||||
|
||||
if (ctx_ != nullptr) {
|
||||
gsasl_done(ctx_);
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<Status, SaslMethod> GSaslEngine::start(const std::vector<SaslMethod> &protocols)
|
||||
{
|
||||
int rc = gsasl_init(&ctx_);
|
||||
if (rc != GSASL_OK) {
|
||||
state_ = kError;
|
||||
return std::make_pair(gsasl_rc_to_status(rc), SaslMethod());
|
||||
}
|
||||
|
||||
// Hack to only do GSSAPI at the moment
|
||||
for (auto protocol: protocols) {
|
||||
if (protocol.mechanism == "GSSAPI") {
|
||||
Status init = init_kerberos(protocol);
|
||||
if (init.ok()) {
|
||||
state_ = kWaitingForData;
|
||||
return std::make_pair(init, protocol);
|
||||
} else {
|
||||
state_ = kError;
|
||||
return std::make_pair(init, SaslMethod());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
state_ = kError;
|
||||
return std::make_pair(Status::Error("No good protocol"), SaslMethod());
|
||||
}
|
||||
|
||||
Status GSaslEngine::init_kerberos(const SaslMethod & mechanism) {
|
||||
/* Create new authentication session. */
|
||||
int rc = gsasl_client_start(ctx_, mechanism.mechanism.c_str(), &session_);
|
||||
if (rc != GSASL_OK) {
|
||||
return gsasl_rc_to_status(rc);
|
||||
}
|
||||
|
||||
if (!principal_) {
|
||||
return Status::Error("Attempted kerberos authentication with no principal");
|
||||
}
|
||||
|
||||
gsasl_property_set(session_, GSASL_SERVICE, mechanism.protocol.c_str());
|
||||
gsasl_property_set(session_, GSASL_AUTHID, principal_.value().c_str());
|
||||
gsasl_property_set(session_, GSASL_HOSTNAME, mechanism.serverid.c_str());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::pair<Status, std::string> GSaslEngine::step(const std::string data)
|
||||
{
|
||||
if (state_ != kWaitingForData)
|
||||
LOG_WARN(kRPC, << "GSaslEngine::step when state is " << state_);
|
||||
|
||||
char * output = NULL;
|
||||
size_t outputSize;
|
||||
int rc = gsasl_step(session_, data.c_str(), data.size(), &output,
|
||||
&outputSize);
|
||||
|
||||
if (rc == GSASL_NEEDS_MORE || rc == GSASL_OK) {
|
||||
std::string retval(output, output ? outputSize : 0);
|
||||
if (output) {
|
||||
free(output);
|
||||
}
|
||||
|
||||
if (rc == GSASL_OK) {
|
||||
state_ = kSuccess;
|
||||
}
|
||||
|
||||
return std::make_pair(Status::OK(), retval);
|
||||
}
|
||||
else {
|
||||
if (output) {
|
||||
free(output);
|
||||
}
|
||||
state_ = kFailure;
|
||||
return std::make_pair(gsasl_rc_to_status(rc), "");
|
||||
}
|
||||
}
|
||||
|
||||
Status GSaslEngine::finish()
|
||||
{
|
||||
if (state_ != kSuccess && state_ != kFailure && state_ != kError )
|
||||
LOG_WARN(kRPC, << "GSaslEngine::finish when state is " << state_);
|
||||
|
||||
if (session_ != nullptr) {
|
||||
gsasl_finish(session_);
|
||||
session_ = NULL;
|
||||
}
|
||||
|
||||
if (ctx_ != nullptr) {
|
||||
gsasl_done(ctx_);
|
||||
ctx_ = nullptr;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif // USE_GSASL
|
||||
|
||||
|
||||
|
||||
}
|
@ -0,0 +1,125 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LIB_RPC_SASLENGINE_H
|
||||
#define LIB_RPC_SASLENGINE_H
|
||||
|
||||
#include "hdfspp/status.h"
|
||||
#include "optional.hpp"
|
||||
|
||||
#ifdef USE_GSASL
|
||||
#include "gsasl.h"
|
||||
#endif
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace hdfs {
|
||||
|
||||
template <class T>
|
||||
using optional = std::experimental::optional<T>;
|
||||
|
||||
class SaslMethod {
|
||||
public:
|
||||
std::string protocol;
|
||||
std::string mechanism;
|
||||
std::string serverid;
|
||||
void * data;
|
||||
};
|
||||
|
||||
class SaslEngine {
|
||||
public:
|
||||
enum State {
|
||||
kUnstarted,
|
||||
kWaitingForData,
|
||||
kSuccess,
|
||||
kFailure,
|
||||
kError,
|
||||
};
|
||||
|
||||
// State transitions:
|
||||
// \--------------------------/
|
||||
// kUnstarted --start--> kWaitingForData --step-+--> kSuccess --finish--v
|
||||
// \-> kFailure -/
|
||||
|
||||
SaslEngine() : state_ (kUnstarted) {}
|
||||
virtual ~SaslEngine();
|
||||
|
||||
// Must be called when state is kUnstarted
|
||||
Status setKerberosInfo(const std::string &principal);
|
||||
// Must be called when state is kUnstarted
|
||||
Status setPasswordInfo(const std::string &id,
|
||||
const std::string &password);
|
||||
|
||||
// Returns the current state
|
||||
State getState();
|
||||
|
||||
// Must be called when state is kUnstarted
|
||||
virtual std::pair<Status,SaslMethod> start(
|
||||
const std::vector<SaslMethod> &protocols) = 0;
|
||||
|
||||
// Must be called when state is kWaitingForData
|
||||
// Returns kOK and any data that should be sent to the server
|
||||
virtual std::pair<Status,std::string> step(const std::string data) = 0;
|
||||
|
||||
// Must only be called when state is kSuccess, kFailure, or kError
|
||||
virtual Status finish() = 0;
|
||||
protected:
|
||||
State state_;
|
||||
|
||||
optional<std::string> principal_;
|
||||
optional<std::string> id_;
|
||||
optional<std::string> password_;
|
||||
|
||||
};
|
||||
|
||||
#ifdef USE_GSASL
|
||||
class GSaslEngine : public SaslEngine
|
||||
{
|
||||
public:
|
||||
GSaslEngine() : SaslEngine(), ctx_(nullptr), session_(nullptr) {}
|
||||
virtual ~GSaslEngine();
|
||||
|
||||
virtual std::pair<Status,SaslMethod> start(
|
||||
const std::vector<SaslMethod> &protocols);
|
||||
virtual std::pair<Status,std::string> step(const std::string data);
|
||||
virtual Status finish();
|
||||
private:
|
||||
Gsasl * ctx_;
|
||||
Gsasl_session * session_;
|
||||
|
||||
Status init_kerberos(const SaslMethod & mechanism);
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef USE_CYRUS_SASL
|
||||
class CyrusSaslEngine : public SaslEngine
|
||||
{
|
||||
public:
|
||||
GSaslEngine() : SaslEngine(), ctx_(nullptr), session_(nullptr) {}
|
||||
virtual ~GSaslEngine();
|
||||
|
||||
virtual std::pair<Status,SaslMethod> start(
|
||||
const std::vector<SaslMethod> &protocols);
|
||||
virtual std::pair<Status,std::string> step(const std::string data);
|
||||
virtual Status finish();
|
||||
private:
|
||||
};
|
||||
#endif
|
||||
|
||||
}
|
||||
#endif /* LIB_RPC_SASLENGINE_H */
|
@ -0,0 +1,320 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "sasl_protocol.h"
|
||||
#include "sasl_engine.h"
|
||||
#include "rpc_engine.h"
|
||||
#include "common/logging.h"
|
||||
|
||||
#include <optional.hpp>
|
||||
|
||||
namespace hdfs {
|
||||
|
||||
using namespace hadoop::common;
|
||||
using namespace google::protobuf;
|
||||
template <class T>
|
||||
using optional = std::experimental::optional<T>;
|
||||
|
||||
/*****
|
||||
* Threading model: all entry points need to acquire the sasl_lock before accessing
|
||||
* members of the class
|
||||
*
|
||||
* Lifecycle model: asio may have outstanding callbacks into this class for arbitrary
|
||||
* amounts of time, so any references to the class must be shared_ptr's. The
|
||||
* SASLProtocol keeps a weak_ptr to the owning RpcConnection, which might go away,
|
||||
* so the weak_ptr should be locked only long enough to make callbacks into the
|
||||
* RpcConnection.
|
||||
*/
|
||||
|
||||
SaslProtocol::SaslProtocol(const std::string & cluster_name,
|
||||
const AuthInfo & auth_info,
|
||||
std::shared_ptr<RpcConnection> connection) :
|
||||
state_(kUnstarted),
|
||||
cluster_name_(cluster_name),
|
||||
auth_info_(auth_info),
|
||||
connection_(connection)
|
||||
{
|
||||
}
|
||||
|
||||
SaslProtocol::~SaslProtocol()
|
||||
{
|
||||
std::lock_guard<std::mutex> state_lock(sasl_state_lock_);
|
||||
event_handlers_->call("SASL End", cluster_name_.c_str(), 0);
|
||||
}
|
||||
|
||||
void SaslProtocol::SetEventHandlers(std::shared_ptr<LibhdfsEvents> event_handlers) {
|
||||
std::lock_guard<std::mutex> state_lock(sasl_state_lock_);
|
||||
event_handlers_ = event_handlers;
|
||||
}
|
||||
|
||||
void SaslProtocol::authenticate(std::function<void(const Status & status, const AuthInfo new_auth_info)> callback)
|
||||
{
|
||||
std::lock_guard<std::mutex> state_lock(sasl_state_lock_);
|
||||
|
||||
LOG_TRACE(kRPC, << "Authenticating as " << auth_info_.getUser());
|
||||
|
||||
assert(state_ == kUnstarted);
|
||||
event_handlers_->call("SASL Start", cluster_name_.c_str(), 0);
|
||||
|
||||
callback_ = callback;
|
||||
state_ = kNegotiate;
|
||||
|
||||
std::shared_ptr<RpcSaslProto> req_msg = std::make_shared<RpcSaslProto>();
|
||||
req_msg->set_state(RpcSaslProto_SaslState_NEGOTIATE);
|
||||
|
||||
// We cheat here since this is always called while holding the RpcConnection's lock
|
||||
std::shared_ptr<RpcConnection> connection = connection_.lock();
|
||||
if (!connection) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::shared_ptr<RpcSaslProto> resp_msg = std::make_shared<RpcSaslProto>();
|
||||
auto self(shared_from_this());
|
||||
connection->AsyncRpc_locked(SASL_METHOD_NAME, req_msg.get(), resp_msg,
|
||||
[self, req_msg, resp_msg] (const Status & status) { self->OnServerResponse(status, resp_msg.get()); } );
|
||||
}
|
||||
|
||||
AuthInfo::AuthMethod ParseMethod(const std::string & method)
|
||||
{
|
||||
if (0 == strcasecmp(method.c_str(), "SIMPLE")) {
|
||||
return AuthInfo::kSimple;
|
||||
}
|
||||
else if (0 == strcasecmp(method.c_str(), "KERBEROS")) {
|
||||
return AuthInfo::kKerberos;
|
||||
}
|
||||
else if (0 == strcasecmp(method.c_str(), "TOKEN")) {
|
||||
return AuthInfo::kToken;
|
||||
}
|
||||
else {
|
||||
return AuthInfo::kUnknownAuth;
|
||||
}
|
||||
}
|
||||
|
||||
void SaslProtocol::Negotiate(const hadoop::common::RpcSaslProto * response)
|
||||
{
|
||||
std::vector<SaslMethod> protocols;
|
||||
|
||||
bool simple_available = false;
|
||||
|
||||
#if defined USE_SASL
|
||||
#if defined USE_CYRUS_SASL
|
||||
sasl_engine_.reset(new CyrusSaslEngine());
|
||||
#elif defined USE_GSASL
|
||||
sasl_engine_.reset(new GSaslEngine());
|
||||
#else
|
||||
#error USE_SASL defined but no engine (USE_GSASL) defined
|
||||
#endif
|
||||
#endif
|
||||
if (auth_info_.getToken()) {
|
||||
sasl_engine_->setPasswordInfo(auth_info_.getToken().value().identifier,
|
||||
auth_info_.getToken().value().password);
|
||||
}
|
||||
sasl_engine_->setKerberosInfo(auth_info_.getUser()); // HDFS-10451 will look up principal by username
|
||||
|
||||
|
||||
auto auths = response->auths();
|
||||
for (int i = 0; i < auths.size(); ++i) {
|
||||
auto auth = auths.Get(i);
|
||||
AuthInfo::AuthMethod method = ParseMethod(auth.method());
|
||||
|
||||
switch(method) {
|
||||
case AuthInfo::kToken:
|
||||
case AuthInfo::kKerberos: {
|
||||
SaslMethod new_method;
|
||||
new_method.mechanism = auth.mechanism();
|
||||
new_method.protocol = auth.protocol();
|
||||
new_method.serverid = auth.serverid();
|
||||
new_method.data = const_cast<RpcSaslProto_SaslAuth *>(&response->auths().Get(i));
|
||||
protocols.push_back(new_method);
|
||||
}
|
||||
break;
|
||||
case AuthInfo::kSimple:
|
||||
simple_available = true;
|
||||
break;
|
||||
case AuthInfo::kUnknownAuth:
|
||||
LOG_WARN(kRPC, << "Unknown auth method " << auth.method() << "; ignoring");
|
||||
break;
|
||||
default:
|
||||
LOG_WARN(kRPC, << "Invalid auth type: " << method << "; ignoring");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!protocols.empty()) {
|
||||
auto init = sasl_engine_->start(protocols);
|
||||
if (init.first.ok()) {
|
||||
auto chosen_auth = reinterpret_cast<RpcSaslProto_SaslAuth *>(init.second.data);
|
||||
|
||||
// Prepare initiate message
|
||||
RpcSaslProto initiate;
|
||||
initiate.set_state(RpcSaslProto_SaslState_INITIATE);
|
||||
RpcSaslProto_SaslAuth * respAuth = initiate.add_auths();
|
||||
respAuth->CopyFrom(*chosen_auth);
|
||||
|
||||
LOG_TRACE(kRPC, << "Using auth: " << chosen_auth->protocol() << "/" <<
|
||||
chosen_auth->mechanism() << "/" << chosen_auth->serverid());
|
||||
|
||||
std::string challenge = chosen_auth->has_challenge() ? chosen_auth->challenge() : "";
|
||||
auto sasl_challenge = sasl_engine_->step(challenge);
|
||||
|
||||
if (sasl_challenge.first.ok()) {
|
||||
if (!sasl_challenge.second.empty()) {
|
||||
initiate.set_token(sasl_challenge.second);
|
||||
}
|
||||
|
||||
std::shared_ptr<RpcSaslProto> return_msg = std::make_shared<RpcSaslProto>();
|
||||
SendSaslMessage(initiate);
|
||||
return;
|
||||
} else {
|
||||
AuthComplete(sasl_challenge.first, auth_info_);
|
||||
return;
|
||||
}
|
||||
} else if (!simple_available) {
|
||||
// If simple IS available, fall through to below
|
||||
AuthComplete(init.first, auth_info_);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// There were no protocols, or the SaslEngine couldn't make one work
|
||||
if (simple_available) {
|
||||
// Simple was the only one we could use. That's OK.
|
||||
AuthComplete(Status::OK(), auth_info_);
|
||||
return;
|
||||
} else {
|
||||
// We didn't understand any of the protocols; give back some information
|
||||
std::stringstream ss;
|
||||
ss << "Client cannot authenticate via: ";
|
||||
|
||||
for (int i = 0; i < auths.size(); ++i) {
|
||||
auto auth = auths.Get(i);
|
||||
ss << auth.mechanism() << ", ";
|
||||
}
|
||||
|
||||
AuthComplete(Status::Error(ss.str().c_str()), auth_info_);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void SaslProtocol::Challenge(const hadoop::common::RpcSaslProto * challenge)
|
||||
{
|
||||
if (!sasl_engine_) {
|
||||
AuthComplete(Status::Error("Received challenge before negotiate"), auth_info_);
|
||||
return;
|
||||
}
|
||||
|
||||
RpcSaslProto response;
|
||||
response.CopyFrom(*challenge);
|
||||
response.set_state(RpcSaslProto_SaslState_RESPONSE);
|
||||
|
||||
std::string challenge_token = challenge->has_token() ? challenge->token() : "";
|
||||
auto sasl_response = sasl_engine_->step(challenge_token);
|
||||
|
||||
if (sasl_response.first.ok()) {
|
||||
response.set_token(sasl_response.second);
|
||||
|
||||
std::shared_ptr<RpcSaslProto> return_msg = std::make_shared<RpcSaslProto>();
|
||||
SendSaslMessage(response);
|
||||
} else {
|
||||
AuthComplete(sasl_response.first, auth_info_);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
bool SaslProtocol::SendSaslMessage(RpcSaslProto & message)
|
||||
{
|
||||
assert(lock_held(sasl_state_lock_)); // Must be holding lock before calling
|
||||
|
||||
// RpcConnection might have been freed when we weren't looking. Lock it
|
||||
// to make sure it's there long enough for us
|
||||
std::shared_ptr<RpcConnection> connection = connection_.lock();
|
||||
if (!connection) {
|
||||
LOG_DEBUG(kRPC, << "Tried sending a SASL Message but the RPC connection was gone");
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<RpcSaslProto> resp_msg = std::make_shared<RpcSaslProto>();
|
||||
auto self(shared_from_this());
|
||||
connection->AsyncRpc(SASL_METHOD_NAME, &message, resp_msg,
|
||||
[self, resp_msg] (const Status & status) {
|
||||
self->OnServerResponse(status, resp_msg.get());
|
||||
} );
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SaslProtocol::AuthComplete(const Status & status, const AuthInfo & auth_info)
|
||||
{
|
||||
assert(lock_held(sasl_state_lock_)); // Must be holding lock before calling
|
||||
|
||||
// RpcConnection might have been freed when we weren't looking. Lock it
|
||||
// to make sure it's there long enough for us
|
||||
std::shared_ptr<RpcConnection> connection = connection_.lock();
|
||||
if (!connection) {
|
||||
LOG_DEBUG(kRPC, << "Tried sending an AuthComplete but the RPC connection was gone: " << status.ToString());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!status.ok()) {
|
||||
auth_info_.setMethod(AuthInfo::kAuthFailed);
|
||||
}
|
||||
|
||||
LOG_TRACE(kRPC, << "AuthComplete: " << status.ToString());
|
||||
connection->AuthComplete(status, auth_info);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void SaslProtocol::OnServerResponse(const Status & status, const hadoop::common::RpcSaslProto * response)
|
||||
{
|
||||
std::lock_guard<std::mutex> state_lock(sasl_state_lock_);
|
||||
LOG_TRACE(kRPC, << "Received SASL response: " << status.ToString());
|
||||
|
||||
if (status.ok()) {
|
||||
switch(response->state()) {
|
||||
case RpcSaslProto_SaslState_NEGOTIATE:
|
||||
Negotiate(response);
|
||||
break;
|
||||
case RpcSaslProto_SaslState_CHALLENGE:
|
||||
Challenge(response);
|
||||
break;
|
||||
case RpcSaslProto_SaslState_SUCCESS:
|
||||
if (sasl_engine_) {
|
||||
sasl_engine_->finish();
|
||||
}
|
||||
AuthComplete(Status::OK(), auth_info_);
|
||||
break;
|
||||
|
||||
case RpcSaslProto_SaslState_INITIATE: // Server side only
|
||||
case RpcSaslProto_SaslState_RESPONSE: // Server side only
|
||||
case RpcSaslProto_SaslState_WRAP:
|
||||
LOG_ERROR(kRPC, << "Invalid client-side SASL state: " << response->state());
|
||||
AuthComplete(Status::Error("Invalid client-side state"), auth_info_);
|
||||
break;
|
||||
default:
|
||||
LOG_ERROR(kRPC, << "Unknown client-side SASL state: " << response->state());
|
||||
AuthComplete(Status::Error("Unknown client-side state"), auth_info_);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
AuthComplete(status, auth_info_);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -0,0 +1,81 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LIB_RPC_SASLPROTOCOL_H
|
||||
#define LIB_RPC_SASLPROTOCOL_H
|
||||
|
||||
#include "hdfspp/status.h"
|
||||
#include "common/auth_info.h"
|
||||
#include "common/libhdfs_events_impl.h"
|
||||
|
||||
#include <RpcHeader.pb.h>
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <functional>
|
||||
|
||||
namespace hdfs {
|
||||
|
||||
static constexpr const char * SASL_METHOD_NAME = "sasl message";
|
||||
|
||||
class RpcConnection;
|
||||
class SaslEngine;
|
||||
|
||||
class SaslProtocol : public std::enable_shared_from_this<SaslProtocol>
|
||||
{
|
||||
public:
|
||||
SaslProtocol(const std::string &cluster_name,
|
||||
const AuthInfo & auth_info,
|
||||
std::shared_ptr<RpcConnection> connection);
|
||||
virtual ~SaslProtocol();
|
||||
|
||||
void SetEventHandlers(std::shared_ptr<LibhdfsEvents> event_handlers);
|
||||
|
||||
// Start the async authentication process. Must be called while holding the
|
||||
// connection lock, but all callbacks will occur outside of the connection lock
|
||||
void authenticate(std::function<void(const Status & status, const AuthInfo new_auth_info)> callback);
|
||||
void OnServerResponse(const Status & status, const hadoop::common::RpcSaslProto * response);
|
||||
private:
|
||||
enum State {
|
||||
kUnstarted,
|
||||
kNegotiate,
|
||||
kAuthenticate,
|
||||
kComplete
|
||||
};
|
||||
|
||||
// Lock for access to members of the class
|
||||
std::mutex sasl_state_lock_;
|
||||
|
||||
State state_;
|
||||
const std::string cluster_name_;
|
||||
AuthInfo auth_info_;
|
||||
std::weak_ptr<RpcConnection> connection_;
|
||||
std::function<void(const Status & status, const AuthInfo new_auth_info)> callback_;
|
||||
std::unique_ptr<SaslEngine> sasl_engine_;
|
||||
std::shared_ptr<LibhdfsEvents> event_handlers_;
|
||||
|
||||
bool SendSaslMessage(hadoop::common::RpcSaslProto & message);
|
||||
bool AuthComplete(const Status & status, const AuthInfo & auth_info);
|
||||
|
||||
void Negotiate(const hadoop::common::RpcSaslProto * response);
|
||||
void Challenge(const hadoop::common::RpcSaslProto * response);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif /* LIB_RPC_SASLPROTOCOL_H */
|
@ -73,15 +73,15 @@ add_memcheck_test(retry_policy retry_policy_test)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR})
|
||||
add_executable(rpc_engine_test rpc_engine_test.cc ${PROTO_TEST_SRCS} ${PROTO_TEST_HDRS})
|
||||
target_link_libraries(rpc_engine_test test_common rpc proto common ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(rpc_engine_test test_common rpc proto common ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} ${SASL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})
|
||||
add_memcheck_test(rpc_engine rpc_engine_test)
|
||||
|
||||
add_executable(bad_datanode_test bad_datanode_test.cc)
|
||||
target_link_libraries(bad_datanode_test rpc reader proto fs bindings_c rpc proto common reader connection ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(bad_datanode_test rpc reader proto fs bindings_c rpc proto common reader connection ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} ${SASL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})
|
||||
add_memcheck_test(bad_datanode bad_datanode_test)
|
||||
|
||||
add_executable(node_exclusion_test node_exclusion_test.cc)
|
||||
target_link_libraries(node_exclusion_test fs gmock_main common ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(node_exclusion_test fs gmock_main common ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} ${SASL_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
add_memcheck_test(node_exclusion node_exclusion_test)
|
||||
|
||||
add_executable(configuration_test configuration_test.cc)
|
||||
@ -93,7 +93,7 @@ target_link_libraries(hdfs_configuration_test common gmock_main ${CMAKE_THREAD_L
|
||||
add_memcheck_test(hdfs_configuration hdfs_configuration_test)
|
||||
|
||||
add_executable(hdfspp_errors_test hdfspp_errors.cc)
|
||||
target_link_libraries(hdfspp_errors_test common gmock_main bindings_c fs rpc proto common reader connection ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(hdfspp_errors_test common gmock_main bindings_c fs rpc proto common reader connection ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} ${SASL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})
|
||||
add_memcheck_test(hdfspp_errors hdfspp_errors_test)
|
||||
|
||||
#This test requires a great deal of Hadoop Java infrastructure to run.
|
||||
@ -101,15 +101,15 @@ if(HADOOP_BUILD)
|
||||
add_library(hdfspp_test_shim_static STATIC hdfs_shim.c libhdfs_wrapper.c libhdfspp_wrapper.cc ${LIBHDFSPP_BINDING_C}/hdfs.cc)
|
||||
|
||||
build_libhdfs_test(libhdfs_threaded hdfspp_test_shim_static expect.c test_libhdfs_threaded.c ${OS_DIR}/thread.c)
|
||||
link_libhdfs_test(libhdfs_threaded hdfspp_test_shim_static fs reader rpc proto common connection ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} native_mini_dfs ${JAVA_JVM_LIBRARY})
|
||||
link_libhdfs_test(libhdfs_threaded hdfspp_test_shim_static fs reader rpc proto common connection ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} ${SASL_LIBRARIES} native_mini_dfs ${JAVA_JVM_LIBRARY})
|
||||
add_libhdfs_test(libhdfs_threaded hdfspp_test_shim_static)
|
||||
|
||||
endif(HADOOP_BUILD)
|
||||
|
||||
add_executable(hdfs_builder_test hdfs_builder_test.cc)
|
||||
target_link_libraries(hdfs_builder_test test_common gmock_main bindings_c fs rpc proto common reader connection ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(hdfs_builder_test test_common gmock_main bindings_c fs rpc proto common reader connection ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} ${SASL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})
|
||||
add_memcheck_test(hdfs_builder_test hdfs_builder_test)
|
||||
|
||||
add_executable(logging_test logging_test.cc)
|
||||
target_link_libraries(logging_test common gmock_main bindings_c fs rpc proto common reader connection ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(logging_test common gmock_main bindings_c fs rpc proto common reader connection ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} ${SASL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})
|
||||
add_memcheck_test(logging_test logging_test)
|
||||
|
Loading…
Reference in New Issue
Block a user