From 65be21267587f04a2c33af65b951211cc9085b15 Mon Sep 17 00:00:00 2001 From: Daryn Sharp Date: Mon, 29 Jul 2013 14:44:21 +0000 Subject: [PATCH] HADOOP-9698. [RPC v9] Client must honor server's SASL negotiate response (daryn) git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@1508086 13f79535-47bb-0310-9956-ffa450edef68 --- .../hadoop-common/CHANGES.txt | 2 + .../java/org/apache/hadoop/ipc/Client.java | 131 ++----- .../java/org/apache/hadoop/ipc/Server.java | 113 +++--- .../apache/hadoop/security/SaslRpcClient.java | 341 +++++++++++++----- .../org/apache/hadoop/ipc/TestSaslRPC.java | 275 ++++++++------ 5 files changed, 503 insertions(+), 359 deletions(-) diff --git a/hadoop-common-project/hadoop-common/CHANGES.txt b/hadoop-common-project/hadoop-common/CHANGES.txt index 610a15800f..10b517587f 100644 --- a/hadoop-common-project/hadoop-common/CHANGES.txt +++ b/hadoop-common-project/hadoop-common/CHANGES.txt @@ -352,6 +352,8 @@ Release 2.1.0-beta - 2013-07-02 HADOOP-9683. [RPC v9] Wrap IpcConnectionContext in RPC headers (daryn) + HADOOP-9698. [RPC v9] Client must honor server's SASL negotiate response (daryn) + NEW FEATURES HADOOP-9283. Add support for running the Hadoop client on AIX. (atm) diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java index d57876b22f..0c6b765e39 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java @@ -82,11 +82,6 @@ import org.apache.hadoop.security.SaslRpcServer.AuthMethod; import org.apache.hadoop.security.SecurityUtil; import org.apache.hadoop.security.UserGroupInformation; -import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod; -import org.apache.hadoop.security.token.Token; -import org.apache.hadoop.security.token.TokenIdentifier; -import org.apache.hadoop.security.token.TokenInfo; -import org.apache.hadoop.security.token.TokenSelector; import org.apache.hadoop.util.ProtoUtil; import org.apache.hadoop.util.ReflectionUtils; import org.apache.hadoop.util.StringUtils; @@ -368,10 +363,9 @@ public synchronized Writable getRpcResponse() { * socket: responses may be delivered out of order. */ private class Connection extends Thread { private InetSocketAddress server; // server ip:port - private String serverPrincipal; // server's krb5 principal name private final ConnectionId remoteId; // connection id private AuthMethod authMethod; // authentication method - private Token token; + private AuthProtocol authProtocol; private int serviceClass; private SaslRpcClient saslRpcClient; @@ -418,45 +412,11 @@ public Connection(ConnectionId remoteId, int serviceClass) throws IOException { } UserGroupInformation ticket = remoteId.getTicket(); - Class protocol = remoteId.getProtocol(); - if (protocol != null) { - TokenInfo tokenInfo = SecurityUtil.getTokenInfo(protocol, conf); - if (tokenInfo != null) { - TokenSelector tokenSelector = null; - try { - tokenSelector = tokenInfo.value().newInstance(); - } catch (InstantiationException e) { - throw new IOException(e.toString()); - } catch (IllegalAccessException e) { - throw new IOException(e.toString()); - } - token = tokenSelector.selectToken( - SecurityUtil.buildTokenService(server), - ticket.getTokens()); - } - KerberosInfo krbInfo = SecurityUtil.getKerberosInfo(protocol, conf); - if (krbInfo != null) { - serverPrincipal = remoteId.getServerPrincipal(); - if (LOG.isDebugEnabled()) { - LOG.debug("RPC Server's Kerberos principal name for protocol=" - + protocol.getCanonicalName() + " is " + serverPrincipal); - } - } - } - - AuthenticationMethod authentication; - if (token != null) { - authentication = AuthenticationMethod.TOKEN; - } else if (ticket != null) { - authentication = ticket.getRealAuthenticationMethod(); - } else { // this only happens in lazy tests - authentication = AuthenticationMethod.SIMPLE; - } - authMethod = authentication.getAuthMethod(); - - if (LOG.isDebugEnabled()) - LOG.debug("Use " + authMethod + " authentication for protocol " - + (protocol == null? null: protocol.getSimpleName())); + // try SASL if security is enabled or if the ugi contains tokens. + // this causes a SIMPLE client with tokens to attempt SASL + boolean trySasl = UserGroupInformation.isSecurityEnabled() || + (ticket != null && !ticket.getTokens().isEmpty()); + this.authProtocol = trySasl ? AuthProtocol.SASL : AuthProtocol.NONE; this.setName("IPC Client (" + socketFactory.hashCode() +") connection to " + server.toString() + @@ -567,11 +527,10 @@ private synchronized boolean shouldAuthenticateOverKrb() throws IOException { return false; } - private synchronized boolean setupSaslConnection(final InputStream in2, - final OutputStream out2) - throws IOException { - saslRpcClient = new SaslRpcClient(authMethod, token, serverPrincipal, - fallbackAllowed); + private synchronized AuthMethod setupSaslConnection(final InputStream in2, + final OutputStream out2) throws IOException, InterruptedException { + saslRpcClient = new SaslRpcClient(remoteId.getTicket(), + remoteId.getProtocol(), remoteId.getAddress(), conf); return saslRpcClient.saslConnect(in2, out2); } @@ -609,7 +568,8 @@ private synchronized void setupConnection() throws IOException { * client, to ensure Server matching address of the client connection * to host name in principal passed. */ - if (UserGroupInformation.isSecurityEnabled()) { + UserGroupInformation ticket = remoteId.getTicket(); + if (ticket != null && ticket.hasKerberosCredentials()) { KerberosInfo krbInfo = remoteId.getProtocol().getAnnotation(KerberosInfo.class); if (krbInfo != null && krbInfo.clientPrincipal() != null) { @@ -687,7 +647,7 @@ public Object run() throws IOException, InterruptedException { } else { String msg = "Couldn't setup connection for " + UserGroupInformation.getLoginUser().getUserName() + " to " - + serverPrincipal; + + remoteId; LOG.warn(msg); throw (IOException) new IOException(msg).initCause(ex); } @@ -723,19 +683,19 @@ private synchronized void setupIOstreams() { InputStream inStream = NetUtils.getInputStream(socket); OutputStream outStream = NetUtils.getOutputStream(socket); writeConnectionHeader(outStream); - if (authMethod != AuthMethod.SIMPLE) { + if (authProtocol == AuthProtocol.SASL) { final InputStream in2 = inStream; final OutputStream out2 = outStream; UserGroupInformation ticket = remoteId.getTicket(); if (ticket.getRealUser() != null) { ticket = ticket.getRealUser(); } - boolean continueSasl = false; try { - continueSasl = ticket - .doAs(new PrivilegedExceptionAction() { + authMethod = ticket + .doAs(new PrivilegedExceptionAction() { @Override - public Boolean run() throws IOException { + public AuthMethod run() + throws IOException, InterruptedException { return setupSaslConnection(in2, out2); } }); @@ -747,13 +707,15 @@ public Boolean run() throws IOException { ticket); continue; } - if (continueSasl) { + if (authMethod != AuthMethod.SIMPLE) { // Sasl connect is successful. Let's set up Sasl i/o streams. inStream = saslRpcClient.getInputStream(inStream); outStream = saslRpcClient.getOutputStream(outStream); - } else { - // fall back to simple auth because server told us so. - authMethod = AuthMethod.SIMPLE; + } else if (UserGroupInformation.isSecurityEnabled() && + !fallbackAllowed) { + throw new IOException("Server asks us to fall back to SIMPLE " + + "auth, but this client is configured to only allow secure " + + "connections."); } } @@ -873,14 +835,6 @@ private void writeConnectionHeader(OutputStream outStream) out.write(RpcConstants.HEADER.array()); out.write(RpcConstants.CURRENT_VERSION); out.write(serviceClass); - final AuthProtocol authProtocol; - switch (authMethod) { - case SIMPLE: - authProtocol = AuthProtocol.NONE; - break; - default: - authProtocol = AuthProtocol.SASL; - } out.write(authProtocol.callId); out.flush(); } @@ -1493,7 +1447,6 @@ public static class ConnectionId { final Class protocol; private static final int PRIME = 16777619; private final int rpcTimeout; - private final String serverPrincipal; private final int maxIdleTime; //connections will be culled if it was idle for //maxIdleTime msecs private final RetryPolicy connectionRetryPolicy; @@ -1504,15 +1457,13 @@ public static class ConnectionId { private final int pingInterval; // how often sends ping to the server in msecs ConnectionId(InetSocketAddress address, Class protocol, - UserGroupInformation ticket, int rpcTimeout, - String serverPrincipal, int maxIdleTime, + UserGroupInformation ticket, int rpcTimeout, int maxIdleTime, RetryPolicy connectionRetryPolicy, int maxRetriesOnSocketTimeouts, boolean tcpNoDelay, boolean doPing, int pingInterval) { this.protocol = protocol; this.address = address; this.ticket = ticket; this.rpcTimeout = rpcTimeout; - this.serverPrincipal = serverPrincipal; this.maxIdleTime = maxIdleTime; this.connectionRetryPolicy = connectionRetryPolicy; this.maxRetriesOnSocketTimeouts = maxRetriesOnSocketTimeouts; @@ -1537,10 +1488,6 @@ private int getRpcTimeout() { return rpcTimeout; } - String getServerPrincipal() { - return serverPrincipal; - } - int getMaxIdleTime() { return maxIdleTime; } @@ -1590,11 +1537,9 @@ static ConnectionId getConnectionId(InetSocketAddress addr, max, 1, TimeUnit.SECONDS); } - String remotePrincipal = getRemotePrincipal(conf, addr, protocol); boolean doPing = conf.getBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, true); - return new ConnectionId(addr, protocol, ticket, - rpcTimeout, remotePrincipal, + return new ConnectionId(addr, protocol, ticket, rpcTimeout, conf.getInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_KEY, CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_DEFAULT), connectionRetryPolicy, @@ -1607,25 +1552,6 @@ static ConnectionId getConnectionId(InetSocketAddress addr, (doPing ? Client.getPingInterval(conf) : 0)); } - private static String getRemotePrincipal(Configuration conf, - InetSocketAddress address, Class protocol) throws IOException { - if (!UserGroupInformation.isSecurityEnabled() || protocol == null) { - return null; - } - KerberosInfo krbInfo = SecurityUtil.getKerberosInfo(protocol, conf); - if (krbInfo != null) { - String serverKey = krbInfo.serverPrincipal(); - if (serverKey == null) { - throw new IOException( - "Can't obtain server Kerberos config key from protocol=" - + protocol.getCanonicalName()); - } - return SecurityUtil.getServerPrincipal(conf.get(serverKey), address - .getAddress()); - } - return null; - } - static boolean isEqual(Object a, Object b) { return a == null ? b == null : a.equals(b); } @@ -1644,7 +1570,6 @@ && isEqual(this.connectionRetryPolicy, that.connectionRetryPolicy) && this.pingInterval == that.pingInterval && isEqual(this.protocol, that.protocol) && this.rpcTimeout == that.rpcTimeout - && isEqual(this.serverPrincipal, that.serverPrincipal) && this.tcpNoDelay == that.tcpNoDelay && isEqual(this.ticket, that.ticket); } @@ -1660,8 +1585,6 @@ public int hashCode() { result = PRIME * result + pingInterval; result = PRIME * result + ((protocol == null) ? 0 : protocol.hashCode()); result = PRIME * result + rpcTimeout; - result = PRIME * result - + ((serverPrincipal == null) ? 0 : serverPrincipal.hashCode()); result = PRIME * result + (tcpNoDelay ? 1231 : 1237); result = PRIME * result + ((ticket == null) ? 0 : ticket.hashCode()); return result; @@ -1669,7 +1592,7 @@ public int hashCode() { @Override public String toString() { - return serverPrincipal + "@" + address; + return address.toString(); } } diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java index a31f50c30e..60fecddb68 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java @@ -86,6 +86,7 @@ import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcErrorCodeProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcStatusProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcSaslProto; +import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcSaslProto.SaslAuth; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcSaslProto.SaslState; import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.security.AccessControlException; @@ -795,7 +796,10 @@ void doRead(SelectionKey key) throws InterruptedException { LOG.info(getName() + ": readAndProcess caught InterruptedException", ieo); throw ieo; } catch (Exception e) { - // log stack trace for "interesting" exceptions not sent to client + // a WrappedRpcServerException is an exception that has been sent + // to the client, so the stacktrace is unnecessary; any other + // exceptions are unexpected internal server errors and thus the + // stacktrace should be logged LOG.info(getName() + ": readAndProcess from client " + c.getHostAddress() + " threw exception [" + e + "]", (e instanceof WrappedRpcServerException) ? null : e); @@ -1164,7 +1168,6 @@ public class Connection { private AuthMethod authMethod; private AuthProtocol authProtocol; private boolean saslContextEstablished; - private boolean skipInitialSaslHandshake; private ByteBuffer connectionHeaderBuf = null; private ByteBuffer unwrappedData; private ByteBuffer unwrappedDataLengthBuffer; @@ -1339,23 +1342,39 @@ private RpcSaslProto processSaslMessage(DataInputStream dis) "Client already attempted negotiation"); } saslResponse = buildSaslNegotiateResponse(); + // simple-only server negotiate response is success which client + // interprets as switch to simple + if (saslResponse.getState() == SaslState.SUCCESS) { + switchToSimple(); + } break; } case INITIATE: { if (saslMessage.getAuthsCount() != 1) { throw new SaslException("Client mechanism is malformed"); } - String authMethodName = saslMessage.getAuths(0).getMethod(); - authMethod = createSaslServer(authMethodName); - if (authMethod == null) { // the auth method is not supported + // verify the client requested an advertised authType + SaslAuth clientSaslAuth = saslMessage.getAuths(0); + if (!negotiateResponse.getAuthsList().contains(clientSaslAuth)) { if (sentNegotiate) { throw new AccessControlException( - authMethodName + " authentication is not enabled." + clientSaslAuth.getMethod() + " authentication is not enabled." + " Available:" + enabledAuthMethods); } saslResponse = buildSaslNegotiateResponse(); break; } + authMethod = AuthMethod.valueOf(clientSaslAuth.getMethod()); + // abort SASL for SIMPLE auth, server has already ensured that + // SIMPLE is a legit option above. we will send no response + if (authMethod == AuthMethod.SIMPLE) { + switchToSimple(); + break; + } + // sasl server for tokens may already be instantiated + if (saslServer == null || authMethod != AuthMethod.TOKEN) { + saslServer = createSaslServer(authMethod); + } // fallthru to process sasl token } case RESPONSE: { @@ -1378,6 +1397,12 @@ private RpcSaslProto processSaslMessage(DataInputStream dis) } return saslResponse; } + + private void switchToSimple() { + // disable SASL and blank out any SASL server + authProtocol = AuthProtocol.NONE; + saslServer = null; + } private RpcSaslProto buildSaslResponse(SaslState state, byte[] replyToken) { if (LOG.isDebugEnabled()) { @@ -1434,7 +1459,8 @@ private void checkDataLength(int dataLength) throws IOException { } } - public int readAndProcess() throws IOException, InterruptedException { + public int readAndProcess() + throws WrappedRpcServerException, IOException, InterruptedException { while (true) { /* Read at most one RPC. If the header is not read completely yet * then iterate until we read first RPC or until there is no data left. @@ -1537,15 +1563,7 @@ private AuthProtocol initializeAuthContext(int authType) } break; } - case SASL: { - // switch to simple hack, but don't switch if other auths are - // supported, ex. tokens - if (isSimpleEnabled && enabledAuthMethods.size() == 1) { - authProtocol = AuthProtocol.NONE; - skipInitialSaslHandshake = true; - doSaslReply(buildSaslResponse(SaslState.SUCCESS, null)); - } - // else wait for a negotiate or initiate + default: { break; } } @@ -1570,25 +1588,6 @@ private RpcSaslProto buildSaslNegotiateResponse() return negotiateMessage; } - private AuthMethod createSaslServer(String authMethodName) - throws IOException, InterruptedException { - AuthMethod authMethod; - try { - authMethod = AuthMethod.valueOf(authMethodName); - if (!enabledAuthMethods.contains(authMethod)) { - authMethod = null; - } - } catch (IllegalArgumentException iae) { - authMethod = null; - } - if (authMethod != null && - // sasl server for tokens may already be instantiated - (saslServer == null || authMethod != AuthMethod.TOKEN)) { - saslServer = createSaslServer(authMethod); - } - return authMethod; - } - private SaslServer createSaslServer(AuthMethod authMethod) throws IOException, InterruptedException { return new SaslRpcServer(authMethod).create(this, secretManager); @@ -1703,8 +1702,8 @@ private void processConnectionContext(DataInputStream dis) * or the request could not be decoded into a Call * @throws InterruptedException */ - private void processRpcRequestPacket(byte[] buf) throws IOException, - InterruptedException { + private void processRpcRequestPacket(byte[] buf) + throws WrappedRpcServerException, IOException, InterruptedException { if (saslContextEstablished && useWrap) { if (LOG.isDebugEnabled()) LOG.debug("Have read input token of size " + buf.length @@ -1717,8 +1716,8 @@ private void processRpcRequestPacket(byte[] buf) throws IOException, } } - private void unwrapPacketAndProcessRpcs(byte[] inBuf) throws IOException, - InterruptedException { + private void unwrapPacketAndProcessRpcs(byte[] inBuf) + throws WrappedRpcServerException, IOException, InterruptedException { ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream( inBuf)); // Read all RPCs contained in the inBuf, even partial ones @@ -1903,13 +1902,9 @@ private void processRpcOutOfBandRequest(RpcRequestHeaderProto header, } else if (callId == AuthProtocol.SASL.callId) { // if client was switched to simple, ignore first SASL message if (authProtocol != AuthProtocol.SASL) { - if (!skipInitialSaslHandshake) { - throw new WrappedRpcServerException( - RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, - "SASL protocol not requested by client"); - } - skipInitialSaslHandshake = false; - return; + throw new WrappedRpcServerException( + RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, + "SASL protocol not requested by client"); } RpcSaslProto response = saslReadAndProcess(dis); // send back response if any, may throw IOException @@ -2220,17 +2215,23 @@ protected Server(String bindAddress, int port, private RpcSaslProto buildNegotiateResponse(List authMethods) throws IOException { RpcSaslProto.Builder negotiateBuilder = RpcSaslProto.newBuilder(); - negotiateBuilder.setState(SaslState.NEGOTIATE); - for (AuthMethod authMethod : authMethods) { - if (authMethod == AuthMethod.SIMPLE) { // not a SASL method - continue; + if (authMethods.contains(AuthMethod.SIMPLE) && authMethods.size() == 1) { + // SIMPLE-only servers return success in response to negotiate + negotiateBuilder.setState(SaslState.SUCCESS); + } else { + negotiateBuilder.setState(SaslState.NEGOTIATE); + for (AuthMethod authMethod : authMethods) { + SaslRpcServer saslRpcServer = new SaslRpcServer(authMethod); + SaslAuth.Builder builder = negotiateBuilder.addAuthsBuilder() + .setMethod(authMethod.toString()) + .setMechanism(saslRpcServer.mechanism); + if (saslRpcServer.protocol != null) { + builder.setProtocol(saslRpcServer.protocol); + } + if (saslRpcServer.serverId != null) { + builder.setServerId(saslRpcServer.serverId); + } } - SaslRpcServer saslRpcServer = new SaslRpcServer(authMethod); - negotiateBuilder.addAuthsBuilder() - .setMethod(authMethod.toString()) - .setMechanism(saslRpcServer.mechanism) - .setProtocol(saslRpcServer.protocol) - .setServerId(saslRpcServer.serverId); } return negotiateBuilder.build(); } diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java index aacd792794..fe6afd2390 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java @@ -25,6 +25,9 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import javax.security.auth.callback.Callback; @@ -32,6 +35,7 @@ import javax.security.auth.callback.NameCallback; import javax.security.auth.callback.PasswordCallback; import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.kerberos.KerberosPrincipal; import javax.security.sasl.RealmCallback; import javax.security.sasl.RealmChoiceCallback; import javax.security.sasl.Sasl; @@ -42,6 +46,7 @@ import org.apache.commons.logging.LogFactory; import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.classification.InterfaceStability; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcRequestMessageWrapper; import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseMessageWrapper; import org.apache.hadoop.ipc.RPC.RpcKind; @@ -58,6 +63,8 @@ import org.apache.hadoop.security.authentication.util.KerberosName; import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.TokenIdentifier; +import org.apache.hadoop.security.token.TokenInfo; +import org.apache.hadoop.security.token.TokenSelector; import org.apache.hadoop.util.ProtoUtil; import com.google.protobuf.ByteString; @@ -69,9 +76,13 @@ public class SaslRpcClient { public static final Log LOG = LogFactory.getLog(SaslRpcClient.class); - private final AuthMethod authMethod; - private final SaslClient saslClient; - private final boolean fallbackAllowed; + private final UserGroupInformation ugi; + private final Class protocol; + private final InetSocketAddress serverAddr; + private final Configuration conf; + + private SaslClient saslClient; + private static final RpcRequestHeaderProto saslHeader = ProtoUtil .makeRpcRequestHeader(RpcKind.RPC_PROTOCOL_BUFFER, OperationProto.RPC_FINAL_PACKET, AuthProtocol.SASL.callId, @@ -80,44 +91,121 @@ public class SaslRpcClient { RpcSaslProto.newBuilder().setState(SaslState.NEGOTIATE).build(); /** - * Create a SaslRpcClient for an authentication method - * - * @param method - * the requested authentication method - * @param token - * token to use if needed by the authentication method + * Create a SaslRpcClient that can be used by a RPC client to negotiate + * SASL authentication with a RPC server + * @param ugi - connecting user + * @param protocol - RPC protocol + * @param serverAddr - InetSocketAddress of remote server + * @param conf - Configuration */ - public SaslRpcClient(AuthMethod method, - Token token, String serverPrincipal, - boolean fallbackAllowed) - throws IOException { - this.authMethod = method; - this.fallbackAllowed = fallbackAllowed; + public SaslRpcClient(UserGroupInformation ugi, Class protocol, + InetSocketAddress serverAddr, Configuration conf) { + this.ugi = ugi; + this.protocol = protocol; + this.serverAddr = serverAddr; + this.conf = conf; + } + + /** + * Instantiate a sasl client for the first supported auth type in the + * given list. The auth type must be defined, enabled, and the user + * must possess the required credentials, else the next auth is tried. + * + * @param authTypes to attempt in the given order + * @return SaslAuth of instantiated client + * @throws AccessControlException - client doesn't support any of the auths + * @throws IOException - misc errors + */ + private SaslAuth selectSaslClient(List authTypes) + throws SaslException, AccessControlException, IOException { + SaslAuth selectedAuthType = null; + boolean switchToSimple = false; + for (SaslAuth authType : authTypes) { + if (!isValidAuthType(authType)) { + continue; // don't know what it is, try next + } + AuthMethod authMethod = AuthMethod.valueOf(authType.getMethod()); + if (authMethod == AuthMethod.SIMPLE) { + switchToSimple = true; + } else { + saslClient = createSaslClient(authType); + if (saslClient == null) { // client lacks credentials, try next + continue; + } + } + selectedAuthType = authType; + break; + } + if (saslClient == null && !switchToSimple) { + List serverAuthMethods = new ArrayList(); + for (SaslAuth authType : authTypes) { + serverAuthMethods.add(authType.getMethod()); + } + throw new AccessControlException( + "Client cannot authenticate via:" + serverAuthMethods); + } + if (LOG.isDebugEnabled()) { + LOG.debug("Use " + selectedAuthType.getMethod() + + " authentication for protocol " + protocol.getSimpleName()); + } + return selectedAuthType; + } + + + private boolean isValidAuthType(SaslAuth authType) { + AuthMethod authMethod; + try { + authMethod = AuthMethod.valueOf(authType.getMethod()); + } catch (IllegalArgumentException iae) { // unknown auth + authMethod = null; + } + // do we know what it is? is it using our mechanism? + return authMethod != null && + authMethod.getMechanismName().equals(authType.getMechanism()); + } + + /** + * Try to create a SaslClient for an authentication type. May return + * null if the type isn't supported or the client lacks the required + * credentials. + * + * @param authType - the requested authentication method + * @return SaslClient for the authType or null + * @throws SaslException - error instantiating client + * @throws IOException - misc errors + */ + private SaslClient createSaslClient(SaslAuth authType) + throws SaslException, IOException { String saslUser = null; - String saslProtocol = null; - String saslServerName = null; + // SASL requires the client and server to use the same proto and serverId + // if necessary, auth types below will verify they are valid + final String saslProtocol = authType.getProtocol(); + final String saslServerName = authType.getServerId(); Map saslProperties = SaslRpcServer.SASL_PROPS; CallbackHandler saslCallback = null; + final AuthMethod method = AuthMethod.valueOf(authType.getMethod()); switch (method) { case TOKEN: { - saslProtocol = ""; - saslServerName = SaslRpcServer.SASL_DEFAULT_REALM; + Token token = getServerToken(authType); + if (token == null) { + return null; // tokens aren't supported or user doesn't have one + } saslCallback = new SaslClientCallbackHandler(token); break; } case KERBEROS: { - if (serverPrincipal == null || serverPrincipal.isEmpty()) { - throw new IOException( - "Failed to specify server's Kerberos principal name"); + if (ugi.getRealAuthenticationMethod().getAuthMethod() != + AuthMethod.KERBEROS) { + return null; // client isn't using kerberos } - KerberosName name = new KerberosName(serverPrincipal); - saslProtocol = name.getServiceName(); - saslServerName = name.getHostName(); - if (saslServerName == null) { - throw new IOException( - "Kerberos principal name does NOT have the expected hostname part: " - + serverPrincipal); + String serverPrincipal = getServerPrincipal(authType); + if (serverPrincipal == null) { + return null; // protocol doesn't use kerberos + } + if (LOG.isDebugEnabled()) { + LOG.debug("RPC Server's Kerberos principal name for protocol=" + + protocol.getCanonicalName() + " is " + serverPrincipal); } break; } @@ -127,16 +215,85 @@ public SaslRpcClient(AuthMethod method, String mechanism = method.getMechanismName(); if (LOG.isDebugEnabled()) { - LOG.debug("Creating SASL " + mechanism + "(" + authMethod + ") " + LOG.debug("Creating SASL " + mechanism + "(" + method + ") " + " client to authenticate to service at " + saslServerName); } - saslClient = Sasl.createSaslClient( + return Sasl.createSaslClient( new String[] { mechanism }, saslUser, saslProtocol, saslServerName, saslProperties, saslCallback); - if (saslClient == null) { - throw new IOException("Unable to find SASL client implementation"); - } } + + /** + * Try to locate the required token for the server. + * + * @param authType of the SASL client + * @return Token for server, or null if no token available + * @throws IOException - token selector cannot be instantiated + */ + private Token getServerToken(SaslAuth authType) throws IOException { + TokenInfo tokenInfo = SecurityUtil.getTokenInfo(protocol, conf); + LOG.debug("Get token info proto:"+protocol+" info:"+tokenInfo); + if (tokenInfo == null) { // protocol has no support for tokens + return null; + } + TokenSelector tokenSelector = null; + try { + tokenSelector = tokenInfo.value().newInstance(); + } catch (InstantiationException e) { + throw new IOException(e.toString()); + } catch (IllegalAccessException e) { + throw new IOException(e.toString()); + } + return tokenSelector.selectToken( + SecurityUtil.buildTokenService(serverAddr), ugi.getTokens()); + } + + /** + * Get the remote server's principal. The value will be obtained from + * the config and cross-checked against the server's advertised principal. + * + * @param authType of the SASL client + * @return String of the server's principal + * @throws IOException - error determining configured principal + */ + + // try to get the configured principal for the remote server + private String getServerPrincipal(SaslAuth authType) throws IOException { + KerberosInfo krbInfo = SecurityUtil.getKerberosInfo(protocol, conf); + LOG.debug("Get kerberos info proto:"+protocol+" info:"+krbInfo); + if (krbInfo == null) { // protocol has no support for kerberos + return null; + } + String serverKey = krbInfo.serverPrincipal(); + if (serverKey == null) { + throw new IllegalArgumentException( + "Can't obtain server Kerberos config key from protocol=" + + protocol.getCanonicalName()); + } + // construct the expected principal from the config + String confPrincipal = SecurityUtil.getServerPrincipal( + conf.get(serverKey), serverAddr.getAddress()); + if (confPrincipal == null || confPrincipal.isEmpty()) { + throw new IllegalArgumentException( + "Failed to specify server's Kerberos principal name"); + } + // ensure it looks like a host-based service principal + KerberosName name = new KerberosName(confPrincipal); + if (name.getHostName() == null) { + throw new IllegalArgumentException( + "Kerberos principal name does NOT have the expected hostname part: " + + confPrincipal); + } + // check that the server advertised principal matches our conf + KerberosPrincipal serverPrincipal = new KerberosPrincipal( + authType.getProtocol() + "/" + authType.getServerId()); + if (!serverPrincipal.getName().equals(confPrincipal)) { + throw new IllegalArgumentException( + "Server has invalid Kerberos principal: " + serverPrincipal); + } + return confPrincipal; + } + /** * Do client side SASL authentication with server via the given InputStream @@ -146,18 +303,18 @@ public SaslRpcClient(AuthMethod method, * InputStream to use * @param outS * OutputStream to use - * @return true if connection is set up, or false if needs to switch - * to simple Auth. + * @return AuthMethod used to negotiate the connection * @throws IOException */ - public boolean saslConnect(InputStream inS, OutputStream outS) + public AuthMethod saslConnect(InputStream inS, OutputStream outS) throws IOException { DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS)); DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream( outS)); - // track if SASL ever started, or server switched us to simple - boolean inSasl = false; + // redefined if/when a SASL negotiation completes + AuthMethod authMethod = AuthMethod.SIMPLE; + sendSaslMessage(outStream, negotiateRequest); // loop until sasl is complete or a rpc error occurs @@ -191,50 +348,48 @@ public boolean saslConnect(InputStream inS, OutputStream outS) RpcSaslProto.Builder response = null; switch (saslMessage.getState()) { case NEGOTIATE: { - inSasl = true; - // TODO: should instantiate sasl client based on advertisement - // but just blindly use the pre-instantiated sasl client for now - String clientAuthMethod = authMethod.toString(); - SaslAuth saslAuthType = null; - for (SaslAuth authType : saslMessage.getAuthsList()) { - if (clientAuthMethod.equals(authType.getMethod())) { - saslAuthType = authType; - break; + // create a compatible SASL client, throws if no supported auths + SaslAuth saslAuthType = selectSaslClient(saslMessage.getAuthsList()); + authMethod = AuthMethod.valueOf(saslAuthType.getMethod()); + + byte[] responseToken = null; + if (authMethod == AuthMethod.SIMPLE) { // switching to SIMPLE + done = true; // not going to wait for success ack + } else { + byte[] challengeToken = null; + if (saslAuthType.hasChallenge()) { + // server provided the first challenge + challengeToken = saslAuthType.getChallenge().toByteArray(); + saslAuthType = + SaslAuth.newBuilder(saslAuthType).clearChallenge().build(); + } else if (saslClient.hasInitialResponse()) { + challengeToken = new byte[0]; } + responseToken = (challengeToken != null) + ? saslClient.evaluateChallenge(challengeToken) + : new byte[0]; } - if (saslAuthType == null) { - saslAuthType = SaslAuth.newBuilder() - .setMethod(clientAuthMethod) - .setMechanism(saslClient.getMechanismName()) - .build(); - } - - byte[] challengeToken = null; - if (saslAuthType != null && saslAuthType.hasChallenge()) { - // server provided the first challenge - challengeToken = saslAuthType.getChallenge().toByteArray(); - saslAuthType = - SaslAuth.newBuilder(saslAuthType).clearChallenge().build(); - } else if (saslClient.hasInitialResponse()) { - challengeToken = new byte[0]; - } - byte[] responseToken = (challengeToken != null) - ? saslClient.evaluateChallenge(challengeToken) - : new byte[0]; - response = createSaslReply(SaslState.INITIATE, responseToken); response.addAuths(saslAuthType); break; } case CHALLENGE: { - inSasl = true; + if (saslClient == null) { + // should probably instantiate a client to allow a server to + // demand a specific negotiation + throw new SaslException("Server sent unsolicited challenge"); + } byte[] responseToken = saslEvaluateToken(saslMessage, false); response = createSaslReply(SaslState.RESPONSE, responseToken); break; } case SUCCESS: { - if (inSasl && saslEvaluateToken(saslMessage, true) != null) { - throw new SaslException("SASL client generated spurious token"); + // simple server sends immediate success to a SASL client for + // switch to simple + if (saslClient == null) { + authMethod = AuthMethod.SIMPLE; + } else { + saslEvaluateToken(saslMessage, true); } done = true; break; @@ -248,12 +403,7 @@ public boolean saslConnect(InputStream inS, OutputStream outS) sendSaslMessage(outStream, response.build()); } } while (!done); - if (!inSasl && !fallbackAllowed) { - throw new IOException("Server asks us to fall back to SIMPLE " + - "auth, but this client is configured to only allow secure " + - "connections."); - } - return inSasl; + return authMethod; } private void sendSaslMessage(DataOutputStream out, RpcSaslProto message) @@ -268,17 +418,37 @@ private void sendSaslMessage(DataOutputStream out, RpcSaslProto message) out.flush(); } + /** + * Evaluate the server provided challenge. The server must send a token + * if it's not done. If the server is done, the challenge token is + * optional because not all mechanisms send a final token for the client to + * update its internal state. The client must also be done after + * evaluating the optional token to ensure a malicious server doesn't + * prematurely end the negotiation with a phony success. + * + * @param saslResponse - client response to challenge + * @param serverIsDone - server negotiation state + * @throws SaslException - any problems with negotiation + */ private byte[] saslEvaluateToken(RpcSaslProto saslResponse, - boolean done) throws SaslException { + boolean serverIsDone) throws SaslException { byte[] saslToken = null; if (saslResponse.hasToken()) { saslToken = saslResponse.getToken().toByteArray(); saslToken = saslClient.evaluateChallenge(saslToken); - } else if (!done) { - throw new SaslException("Challenge contains no token"); + } else if (!serverIsDone) { + // the server may only omit a token when it's done + throw new SaslException("Server challenge contains no token"); } - if (done && !saslClient.isComplete()) { - throw new SaslException("Client is out of sync with server"); + if (serverIsDone) { + // server tried to report success before our client completed + if (!saslClient.isComplete()) { + throw new SaslException("Client is out of sync with server"); + } + // a client cannot generate a response to a success message + if (saslToken != null) { + throw new SaslException("Client generated spurious response"); + } } return saslToken; } @@ -327,7 +497,10 @@ public OutputStream getOutputStream(OutputStream out) throws IOException { /** Release resources used by wrapped saslClient */ public void dispose() throws SaslException { - saslClient.dispose(); + if (saslClient != null) { + saslClient.dispose(); + saslClient = null; + } } private static class SaslClientCallbackHandler implements CallbackHandler { @@ -377,4 +550,4 @@ public void handle(Callback[] callbacks) } } } -} +} \ No newline at end of file diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestSaslRPC.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestSaslRPC.java index 7fdfa98871..138e12f851 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestSaslRPC.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestSaslRPC.java @@ -18,14 +18,9 @@ package org.apache.hadoop.ipc; -import static org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.KERBEROS; -import static org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.SIMPLE; -import static org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.TOKEN; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; +import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION; +import static org.apache.hadoop.security.SaslRpcServer.AuthMethod.*; +import static org.junit.Assert.*; import java.io.DataInput; import java.io.DataOutput; @@ -51,6 +46,7 @@ import junit.framework.Assert; +import org.apache.commons.lang.StringUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.impl.Log4JLogger; @@ -103,6 +99,13 @@ public class TestSaslRPC { static Boolean forceSecretManager = null; static Boolean clientFallBackToSimpleAllowed = true; + static enum UseToken { + NONE(), + VALID(), + INVALID(), + OTHER(); + } + @BeforeClass public static void setupKerb() { System.setProperty("java.security.krb5.kdc", ""); @@ -113,9 +116,11 @@ public static void setupKerb() { @Before public void setup() { conf = new Configuration(); - SecurityUtil.setAuthenticationMethod(KERBEROS, conf); + conf.set(HADOOP_SECURITY_AUTHENTICATION, KERBEROS.toString()); UserGroupInformation.setConfiguration(conf); enableSecretManager = null; + forceSecretManager = null; + clientFallBackToSimpleAllowed = true; } static { @@ -367,28 +372,6 @@ public void testPingInterval() throws Exception { assertEquals(0, remoteId.getPingInterval()); } - @Test - public void testGetRemotePrincipal() throws Exception { - try { - Configuration newConf = new Configuration(conf); - newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1); - ConnectionId remoteId = ConnectionId.getConnectionId( - new InetSocketAddress(0), TestSaslProtocol.class, null, 0, newConf); - assertEquals(SERVER_PRINCIPAL_1, remoteId.getServerPrincipal()); - // this following test needs security to be off - SecurityUtil.setAuthenticationMethod(SIMPLE, newConf); - UserGroupInformation.setConfiguration(newConf); - remoteId = ConnectionId.getConnectionId(new InetSocketAddress(0), - TestSaslProtocol.class, null, 0, newConf); - assertEquals( - "serverPrincipal should be null when security is turned off", null, - remoteId.getServerPrincipal()); - } finally { - // revert back to security is on - UserGroupInformation.setConfiguration(conf); - } - } - @Test public void testPerConnectionConf() throws Exception { TestTokenSecretManager sm = new TestTokenSecretManager(); @@ -409,12 +392,13 @@ public void testPerConnectionConf() throws Exception { Configuration newConf = new Configuration(conf); newConf.set(CommonConfigurationKeysPublic. HADOOP_RPC_SOCKET_FACTORY_CLASS_DEFAULT_KEY, ""); - newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1); TestSaslProtocol proxy1 = null; TestSaslProtocol proxy2 = null; TestSaslProtocol proxy3 = null; + int timeouts[] = {111222, 3333333}; try { + newConf.setInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_KEY, timeouts[0]); proxy1 = RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr, newConf); proxy1.getAuthMethod(); @@ -427,20 +411,21 @@ public void testPerConnectionConf() throws Exception { proxy2.getAuthMethod(); assertEquals("number of connections in cache is wrong", 1, conns.size()); // different conf, new connection should be set up - newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2); + newConf.setInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_KEY, timeouts[1]); proxy3 = RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr, newConf); proxy3.getAuthMethod(); - ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]); - assertEquals("number of connections in cache is wrong", 2, - connsArray.length); - String p1 = connsArray[0].getServerPrincipal(); - String p2 = connsArray[1].getServerPrincipal(); - assertFalse("should have different principals", p1.equals(p2)); - assertTrue("principal not as expected", p1.equals(SERVER_PRINCIPAL_1) - || p1.equals(SERVER_PRINCIPAL_2)); - assertTrue("principal not as expected", p2.equals(SERVER_PRINCIPAL_1) - || p2.equals(SERVER_PRINCIPAL_2)); + assertEquals("number of connections in cache is wrong", 2, conns.size()); + // now verify the proxies have the correct connection ids and timeouts + ConnectionId[] connsArray = { + RPC.getConnectionIdForProxy(proxy1), + RPC.getConnectionIdForProxy(proxy2), + RPC.getConnectionIdForProxy(proxy3) + }; + assertEquals(connsArray[0], connsArray[1]); + assertEquals(connsArray[0].getMaxIdleTime(), timeouts[0]); + assertFalse(connsArray[0].equals(connsArray[2])); + assertNotSame(connsArray[2].getMaxIdleTime(), timeouts[1]); } finally { server.stop(); RPC.stopProxy(proxy1); @@ -599,75 +584,118 @@ public void handle(Callback[] callbacks) private static Pattern KrbFailed = Pattern.compile(".*Failed on local exception:.* " + "Failed to specify server's Kerberos principal name.*"); - private static Pattern Denied(AuthenticationMethod method) { + private static Pattern Denied(AuthMethod method) { return Pattern.compile(".*RemoteException.*AccessControlException.*: " - +method.getAuthMethod() + " authentication is not enabled.*"); + + method + " authentication is not enabled.*"); + } + private static Pattern No(AuthMethod ... method) { + String methods = StringUtils.join(method, ",\\s*"); + return Pattern.compile(".*Failed on local exception:.* " + + "Client cannot authenticate via:\\[" + methods + "\\].*"); } private static Pattern NoTokenAuth = Pattern.compile(".*IllegalArgumentException: " + "TOKEN authentication requires a secret manager"); - + private static Pattern NoFallback = + Pattern.compile(".*Failed on local exception:.* " + + "Server asks us to fall back to SIMPLE auth, " + + "but this client is configured to only allow secure connections.*"); + /* * simple server */ @Test public void testSimpleServer() throws Exception { assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE)); - // SASL methods are reverted to SIMPLE, but test setup fails - assertAuthEquals(KrbFailed, getAuthMethod(KERBEROS, SIMPLE)); + assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE, UseToken.OTHER)); + // SASL methods are normally reverted to SIMPLE + assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE)); + assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER)); } @Test - public void testSimpleServerWithTokensWithNoClientFallbackToSimple() + public void testNoClientFallbackToSimple() throws Exception { - clientFallBackToSimpleAllowed = false; + // tokens are irrelevant w/o secret manager enabled + assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE)); + assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE, UseToken.OTHER)); + assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE, UseToken.VALID)); + assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE, UseToken.INVALID)); - try{ - // Client has a token even though its configs says simple auth. Server - // is configured for simple auth, but as client sends the token, and - // server asks to switch to simple, this should fail. - getAuthMethod(SIMPLE, SIMPLE, true); - } catch (IOException ioe) { - Assert - .assertTrue(ioe.getMessage().contains("Failed on local exception: " + - "java.io.IOException: java.io.IOException: " + - "Server asks us to fall back to SIMPLE auth, " + - "but this client is configured to only allow secure connections" - )); - } + // A secure client must not fallback + assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE)); + assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER)); + assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID)); + assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID)); // Now set server to simple and also force the secret-manager. Now server // should have both simple and token enabled. forceSecretManager = true; - assertAuthEquals(TOKEN, getAuthMethod(SIMPLE, SIMPLE, true)); - forceSecretManager = false; - clientFallBackToSimpleAllowed = true; + assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE)); + assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE, UseToken.OTHER)); + assertAuthEquals(TOKEN, getAuthMethod(SIMPLE, SIMPLE, UseToken.VALID)); + assertAuthEquals(BadToken, getAuthMethod(SIMPLE, SIMPLE, UseToken.INVALID)); + + // A secure client must not fallback + assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE)); + assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER)); + assertAuthEquals(TOKEN, getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID)); + assertAuthEquals(BadToken, getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID)); + + // doesn't try SASL + assertAuthEquals(Denied(SIMPLE), getAuthMethod(SIMPLE, TOKEN)); + // does try SASL + assertAuthEquals(No(TOKEN), getAuthMethod(SIMPLE, TOKEN, UseToken.OTHER)); + assertAuthEquals(TOKEN, getAuthMethod(SIMPLE, TOKEN, UseToken.VALID)); + assertAuthEquals(BadToken, getAuthMethod(SIMPLE, TOKEN, UseToken.INVALID)); + + assertAuthEquals(No(TOKEN), getAuthMethod(KERBEROS, TOKEN)); + assertAuthEquals(No(TOKEN), getAuthMethod(KERBEROS, TOKEN, UseToken.OTHER)); + assertAuthEquals(TOKEN, getAuthMethod(KERBEROS, TOKEN, UseToken.VALID)); + assertAuthEquals(BadToken, getAuthMethod(KERBEROS, TOKEN, UseToken.INVALID)); } @Test public void testSimpleServerWithTokens() throws Exception { // Client not using tokens assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE)); - // SASL methods are reverted to SIMPLE, but test setup fails - assertAuthEquals(KrbFailed, getAuthMethod(KERBEROS, SIMPLE)); + // SASL methods are reverted to SIMPLE + assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE)); // Use tokens. But tokens are ignored because client is reverted to simple - assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, true)); + // due to server not using tokens + assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID)); + assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER)); + // server isn't really advertising tokens enableSecretManager = true; - assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE, true)); - assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, true)); + assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE, UseToken.VALID)); + assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE, UseToken.OTHER)); + + assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID)); + assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER)); + + // now the simple server takes tokens + forceSecretManager = true; + assertAuthEquals(TOKEN, getAuthMethod(SIMPLE, SIMPLE, UseToken.VALID)); + assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE, UseToken.OTHER)); + + assertAuthEquals(TOKEN, getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID)); + assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER)); } @Test public void testSimpleServerWithInvalidTokens() throws Exception { // Tokens are ignored because client is reverted to simple - assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE, false)); - assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, false)); + assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE, UseToken.INVALID)); + assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID)); enableSecretManager = true; - assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE, false)); - assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, false)); + assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE, SIMPLE, UseToken.INVALID)); + assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID)); + forceSecretManager = true; + assertAuthEquals(BadToken, getAuthMethod(SIMPLE, SIMPLE, UseToken.INVALID)); + assertAuthEquals(BadToken, getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID)); } /* @@ -675,26 +703,29 @@ public void testSimpleServerWithInvalidTokens() throws Exception { */ @Test public void testTokenOnlyServer() throws Exception { + // simple client w/o tokens won't try SASL, so server denies assertAuthEquals(Denied(SIMPLE), getAuthMethod(SIMPLE, TOKEN)); - assertAuthEquals(KrbFailed, getAuthMethod(KERBEROS, TOKEN)); + assertAuthEquals(No(TOKEN), getAuthMethod(SIMPLE, TOKEN, UseToken.OTHER)); + assertAuthEquals(No(TOKEN), getAuthMethod(KERBEROS, TOKEN)); + assertAuthEquals(No(TOKEN), getAuthMethod(KERBEROS, TOKEN, UseToken.OTHER)); } @Test public void testTokenOnlyServerWithTokens() throws Exception { - assertAuthEquals(TOKEN, getAuthMethod(SIMPLE, TOKEN, true)); - assertAuthEquals(TOKEN, getAuthMethod(KERBEROS, TOKEN, true)); + assertAuthEquals(TOKEN, getAuthMethod(SIMPLE, TOKEN, UseToken.VALID)); + assertAuthEquals(TOKEN, getAuthMethod(KERBEROS, TOKEN, UseToken.VALID)); enableSecretManager = false; - assertAuthEquals(NoTokenAuth, getAuthMethod(SIMPLE, TOKEN, true)); - assertAuthEquals(NoTokenAuth, getAuthMethod(KERBEROS, TOKEN, true)); + assertAuthEquals(NoTokenAuth, getAuthMethod(SIMPLE, TOKEN, UseToken.VALID)); + assertAuthEquals(NoTokenAuth, getAuthMethod(KERBEROS, TOKEN, UseToken.VALID)); } @Test public void testTokenOnlyServerWithInvalidTokens() throws Exception { - assertAuthEquals(BadToken, getAuthMethod(SIMPLE, TOKEN, false)); - assertAuthEquals(BadToken, getAuthMethod(KERBEROS, TOKEN, false)); + assertAuthEquals(BadToken, getAuthMethod(SIMPLE, TOKEN, UseToken.INVALID)); + assertAuthEquals(BadToken, getAuthMethod(KERBEROS, TOKEN, UseToken.INVALID)); enableSecretManager = false; - assertAuthEquals(NoTokenAuth, getAuthMethod(SIMPLE, TOKEN, false)); - assertAuthEquals(NoTokenAuth, getAuthMethod(KERBEROS, TOKEN, false)); + assertAuthEquals(NoTokenAuth, getAuthMethod(SIMPLE, TOKEN, UseToken.INVALID)); + assertAuthEquals(NoTokenAuth, getAuthMethod(KERBEROS, TOKEN, UseToken.INVALID)); } /* @@ -702,38 +733,43 @@ public void testTokenOnlyServerWithInvalidTokens() throws Exception { */ @Test public void testKerberosServer() throws Exception { - assertAuthEquals(Denied(SIMPLE), getAuthMethod(SIMPLE, KERBEROS)); - assertAuthEquals(KrbFailed, getAuthMethod(KERBEROS, KERBEROS)); + // doesn't try SASL + assertAuthEquals(Denied(SIMPLE), getAuthMethod(SIMPLE, KERBEROS)); + // does try SASL + assertAuthEquals(No(TOKEN,KERBEROS), getAuthMethod(SIMPLE, KERBEROS, UseToken.OTHER)); + // no tgt + assertAuthEquals(KrbFailed, getAuthMethod(KERBEROS, KERBEROS)); + assertAuthEquals(KrbFailed, getAuthMethod(KERBEROS, KERBEROS, UseToken.OTHER)); } @Test public void testKerberosServerWithTokens() throws Exception { // can use tokens regardless of auth - assertAuthEquals(TOKEN, getAuthMethod(SIMPLE, KERBEROS, true)); - assertAuthEquals(TOKEN, getAuthMethod(KERBEROS, KERBEROS, true)); - // can't fallback to simple when using kerberos w/o tokens + assertAuthEquals(TOKEN, getAuthMethod(SIMPLE, KERBEROS, UseToken.VALID)); + assertAuthEquals(TOKEN, getAuthMethod(KERBEROS, KERBEROS, UseToken.VALID)); enableSecretManager = false; - assertAuthEquals(Denied(TOKEN), getAuthMethod(SIMPLE, KERBEROS, true)); - assertAuthEquals(Denied(TOKEN), getAuthMethod(KERBEROS, KERBEROS, true)); + // shouldn't even try token because server didn't tell us to + assertAuthEquals(No(KERBEROS), getAuthMethod(SIMPLE, KERBEROS, UseToken.VALID)); + assertAuthEquals(KrbFailed, getAuthMethod(KERBEROS, KERBEROS, UseToken.VALID)); } @Test public void testKerberosServerWithInvalidTokens() throws Exception { - assertAuthEquals(BadToken, getAuthMethod(SIMPLE, KERBEROS, false)); - assertAuthEquals(BadToken, getAuthMethod(KERBEROS, KERBEROS, false)); + assertAuthEquals(BadToken, getAuthMethod(SIMPLE, KERBEROS, UseToken.INVALID)); + assertAuthEquals(BadToken, getAuthMethod(KERBEROS, KERBEROS, UseToken.INVALID)); enableSecretManager = false; - assertAuthEquals(Denied(TOKEN), getAuthMethod(SIMPLE, KERBEROS, false)); - assertAuthEquals(Denied(TOKEN), getAuthMethod(KERBEROS, KERBEROS, false)); + assertAuthEquals(No(KERBEROS), getAuthMethod(SIMPLE, KERBEROS, UseToken.INVALID)); + assertAuthEquals(KrbFailed, getAuthMethod(KERBEROS, KERBEROS, UseToken.INVALID)); } // test helpers private String getAuthMethod( - final AuthenticationMethod clientAuth, - final AuthenticationMethod serverAuth) throws Exception { + final AuthMethod clientAuth, + final AuthMethod serverAuth) throws Exception { try { - return internalGetAuthMethod(clientAuth, serverAuth, false, false); + return internalGetAuthMethod(clientAuth, serverAuth, UseToken.NONE); } catch (Exception e) { LOG.warn("Auth method failure", e); return e.toString(); @@ -741,11 +777,11 @@ private String getAuthMethod( } private String getAuthMethod( - final AuthenticationMethod clientAuth, - final AuthenticationMethod serverAuth, - final boolean useValidToken) throws Exception { + final AuthMethod clientAuth, + final AuthMethod serverAuth, + final UseToken tokenType) throws Exception { try { - return internalGetAuthMethod(clientAuth, serverAuth, true, useValidToken); + return internalGetAuthMethod(clientAuth, serverAuth, tokenType); } catch (Exception e) { LOG.warn("Auth method failure", e); return e.toString(); @@ -753,15 +789,14 @@ private String getAuthMethod( } private String internalGetAuthMethod( - final AuthenticationMethod clientAuth, - final AuthenticationMethod serverAuth, - final boolean useToken, - final boolean useValidToken) throws Exception { + final AuthMethod clientAuth, + final AuthMethod serverAuth, + final UseToken tokenType) throws Exception { String currentUser = UserGroupInformation.getCurrentUser().getUserName(); final Configuration serverConf = new Configuration(conf); - SecurityUtil.setAuthenticationMethod(serverAuth, serverConf); + serverConf.set(HADOOP_SECURITY_AUTHENTICATION, serverAuth.toString()); UserGroupInformation.setConfiguration(serverConf); final UserGroupInformation serverUgi = @@ -793,7 +828,7 @@ public Server run() throws IOException { }); final Configuration clientConf = new Configuration(conf); - SecurityUtil.setAuthenticationMethod(clientAuth, clientConf); + clientConf.set(HADOOP_SECURITY_AUTHENTICATION, clientAuth.toString()); clientConf.setBoolean( CommonConfigurationKeys.IPC_CLIENT_FALLBACK_TO_SIMPLE_AUTH_ALLOWED_KEY, clientFallBackToSimpleAllowed); @@ -804,16 +839,26 @@ public Server run() throws IOException { clientUgi.setAuthenticationMethod(clientAuth); final InetSocketAddress addr = NetUtils.getConnectAddress(server); - if (useToken) { + if (tokenType != UseToken.NONE) { TestTokenIdentifier tokenId = new TestTokenIdentifier( new Text(clientUgi.getUserName())); - Token token = useValidToken - ? new Token(tokenId, sm) - : new Token( + Token token = null; + switch (tokenType) { + case VALID: + token = new Token(tokenId, sm); + SecurityUtil.setTokenService(token, addr); + break; + case INVALID: + token = new Token( tokenId.getBytes(), "bad-password!".getBytes(), tokenId.getKind(), null); - - SecurityUtil.setTokenService(token, addr); + SecurityUtil.setTokenService(token, addr); + break; + case OTHER: + token = new Token(); + break; + case NONE: // won't get here + } clientUgi.addToken(token); } @@ -848,7 +893,7 @@ public String run() throws IOException { } } - private static void assertAuthEquals(AuthenticationMethod expect, + private static void assertAuthEquals(AuthMethod expect, String actual) { assertEquals(expect.toString(), actual); }