From c93a9128ff14605fe9c08c0f5bb3fa374d852eaf Mon Sep 17 00:00:00 2001 From: Owen O'Malley Date: Sat, 27 Feb 2010 06:17:00 +0000 Subject: [PATCH] HADOOP-6589. Provide better error messages when RPC authentication fails. (Kan Zhang via omalley) git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@916915 13f79535-47bb-0310-9956-ffa450edef68 --- CHANGES.txt | 4 + src/java/org/apache/hadoop/ipc/Server.java | 162 +++++++++++------- .../apache/hadoop/security/SaslRpcClient.java | 13 ++ .../apache/hadoop/security/SaslRpcServer.java | 28 ++- .../core/org/apache/hadoop/ipc/TestRPC.java | 25 +++ .../org/apache/hadoop/ipc/TestSaslRPC.java | 29 ++++ 6 files changed, 189 insertions(+), 72 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index 2927bf47a4..dd066ba8ff 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -67,6 +67,7 @@ Trunk (unreleased changes) HADOOP-6586. Log authentication and authorization failures and successes for RPC (boryas) + IMPROVEMENTS HADOOP-6283. Improve the exception messages thrown by @@ -177,6 +178,9 @@ Trunk (unreleased changes) HADOOP-6594. Provide a fetchdt tool via bin/hdfs. (jhoman via acmurthy) + HADOOP-6589. Provide better error messages when RPC authentication fails. + (Kan Zhang via omalley) + OPTIMIZATIONS HADOOP-6467. Improve the performance on HarFileSystem.listStatus(..). diff --git a/src/java/org/apache/hadoop/ipc/Server.java b/src/java/org/apache/hadoop/ipc/Server.java index 6d7e154922..1b0e53b444 100644 --- a/src/java/org/apache/hadoop/ipc/Server.java +++ b/src/java/org/apache/hadoop/ipc/Server.java @@ -60,12 +60,15 @@ import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.CommonConfigurationKeys; +import org.apache.hadoop.io.BytesWritable; +import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.WritableUtils; import org.apache.hadoop.ipc.metrics.RpcMetrics; import org.apache.hadoop.security.AccessControlException; import org.apache.hadoop.security.SaslRpcServer; import org.apache.hadoop.security.SaslRpcServer.AuthMethod; +import org.apache.hadoop.security.SaslRpcServer.SaslStatus; import org.apache.hadoop.security.SaslRpcServer.SaslDigestCallbackHandler; import org.apache.hadoop.security.SaslRpcServer.SaslGssCallbackHandler; import org.apache.hadoop.security.UserGroupInformation; @@ -74,6 +77,7 @@ import org.apache.hadoop.security.authorize.ServiceAuthorizationManager; import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.security.token.SecretManager; +import org.apache.hadoop.security.token.SecretManager.InvalidToken; import org.apache.hadoop.util.ReflectionUtils; import org.apache.hadoop.util.StringUtils; @@ -757,11 +761,11 @@ public class Connection { // Fake 'call' for failed authorization response private static final int AUTHROIZATION_FAILED_CALLID = -1; private final Call authFailedCall = - new Call(AUTHROIZATION_FAILED_CALLID, null, null); + new Call(AUTHROIZATION_FAILED_CALLID, null, this); private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream(); // Fake 'call' for SASL context setup private static final int SASL_CALLID = -33; - private final Call saslCall = new Call(SASL_CALLID, null, null); + private final Call saslCall = new Call(SASL_CALLID, null, this); private final ByteArrayOutputStream saslResponse = new ByteArrayOutputStream(); public Connection(SelectionKey key, SocketChannel channel, @@ -843,68 +847,78 @@ private UserGroupInformation getAuthorizedUgi(String authorizedId) private void saslReadAndProcess(byte[] saslToken) throws IOException, InterruptedException { if (!saslContextEstablished) { - if (saslServer == null) { - switch (authMethod) { - case DIGEST: - saslServer = Sasl.createSaslServer(AuthMethod.DIGEST - .getMechanismName(), null, SaslRpcServer.SASL_DEFAULT_REALM, - SaslRpcServer.SASL_PROPS, new SaslDigestCallbackHandler( - secretManager, this)); - break; - default: - UserGroupInformation current = UserGroupInformation - .getCurrentUser(); - String fullName = current.getUserName(); - if (LOG.isDebugEnabled()) - LOG.debug("Kerberos principal name is " + fullName); - final String names[] = SaslRpcServer.splitKerberosName(fullName); - if (names.length != 3) { - throw new IOException( - "Kerberos principal name does NOT have the expected " - + "hostname part: " + fullName); - } - current.doAs(new PrivilegedExceptionAction() { - @Override - public Object run() throws IOException { - saslServer = Sasl.createSaslServer(AuthMethod.KERBEROS - .getMechanismName(), names[0], names[1], - SaslRpcServer.SASL_PROPS, new SaslGssCallbackHandler()); - return null; - } - }); - } - if (saslServer == null) - throw new IOException( - "Unable to find SASL server implementation for " - + authMethod.getMechanismName()); - if (LOG.isDebugEnabled()) - LOG.debug("Created SASL server with mechanism = " - + authMethod.getMechanismName()); - } - if (LOG.isDebugEnabled()) - LOG.debug("Have read input token of size " + saslToken.length - + " for processing by saslServer.evaluateResponse()"); - byte[] replyToken; + byte[] replyToken = null; try { + if (saslServer == null) { + switch (authMethod) { + case DIGEST: + if (secretManager == null) { + throw new AccessControlException( + "Server is not configured to do DIGEST authentication."); + } + saslServer = Sasl.createSaslServer(AuthMethod.DIGEST + .getMechanismName(), null, SaslRpcServer.SASL_DEFAULT_REALM, + SaslRpcServer.SASL_PROPS, new SaslDigestCallbackHandler( + secretManager, this)); + break; + default: + UserGroupInformation current = UserGroupInformation + .getCurrentUser(); + String fullName = current.getUserName(); + if (LOG.isDebugEnabled()) + LOG.debug("Kerberos principal name is " + fullName); + final String names[] = SaslRpcServer.splitKerberosName(fullName); + if (names.length != 3) { + throw new AccessControlException( + "Kerberos principal name does NOT have the expected " + + "hostname part: " + fullName); + } + current.doAs(new PrivilegedExceptionAction() { + @Override + public Object run() throws SaslException { + saslServer = Sasl.createSaslServer(AuthMethod.KERBEROS + .getMechanismName(), names[0], names[1], + SaslRpcServer.SASL_PROPS, new SaslGssCallbackHandler()); + return null; + } + }); + } + if (saslServer == null) + throw new AccessControlException( + "Unable to find SASL server implementation for " + + authMethod.getMechanismName()); + if (LOG.isDebugEnabled()) + LOG.debug("Created SASL server with mechanism = " + + authMethod.getMechanismName()); + } + if (LOG.isDebugEnabled()) + LOG.debug("Have read input token of size " + saslToken.length + + " for processing by saslServer.evaluateResponse()"); replyToken = saslServer.evaluateResponse(saslToken); - } catch (SaslException se) { + } catch (IOException e) { + IOException sendToClient = e; + Throwable cause = e; + while (cause != null) { + if (cause instanceof InvalidToken) { + sendToClient = (InvalidToken) cause; + break; + } + cause = cause.getCause(); + } + doSaslReply(SaslStatus.ERROR, null, sendToClient.getClass().getName(), + sendToClient.getLocalizedMessage()); rpcMetrics.authenticationFailures.inc(); String clientIP = this.toString(); // attempting user could be null - auditLOG.warn(AUTH_FAILED_FOR + clientIP + ":" + attemptingUser, se); - throw se; + auditLOG.warn(AUTH_FAILED_FOR + clientIP + ":" + attemptingUser, e); + throw e; } if (replyToken != null) { if (LOG.isDebugEnabled()) LOG.debug("Will send token of size " + replyToken.length + " from saslServer."); - saslCall.connection = this; - saslResponse.reset(); - DataOutputStream out = new DataOutputStream(saslResponse); - out.writeInt(replyToken.length); - out.write(replyToken, 0, replyToken.length); - saslCall.setResponse(ByteBuffer.wrap(saslResponse.toByteArray())); - responder.doRespond(saslCall); + doSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null, + null); } if (saslServer.isComplete()) { if (LOG.isDebugEnabled()) { @@ -927,6 +941,21 @@ public Object run() throws IOException { } } + private void doSaslReply(SaslStatus status, Writable rv, + String errorClass, String error) throws IOException { + saslResponse.reset(); + DataOutputStream out = new DataOutputStream(saslResponse); + out.writeInt(status.state); // write status + if (status == SaslStatus.SUCCESS) { + rv.write(out); + } else { + WritableUtils.writeString(out, errorClass); + WritableUtils.writeString(out, error); + } + saslCall.setResponse(ByteBuffer.wrap(saslResponse.toByteArray())); + responder.doRespond(saslCall); + } + private void disposeSasl() { if (saslServer != null) { try { @@ -936,15 +965,6 @@ private void disposeSasl() { } } - private void askClientToUseSimpleAuth() throws IOException { - saslCall.connection = this; - saslResponse.reset(); - DataOutputStream out = new DataOutputStream(saslResponse); - out.writeInt(SaslRpcServer.SWITCH_TO_SIMPLE_AUTH); - saslCall.setResponse(ByteBuffer.wrap(saslResponse.toByteArray())); - responder.doRespond(saslCall); - } - public int readAndProcess() throws IOException, InterruptedException { while (true) { /* Read at most one RPC. If the header is not read completely yet @@ -974,10 +994,16 @@ public int readAndProcess() throws IOException, InterruptedException { throw new IOException("Unable to read authentication method"); } if (isSecurityEnabled && authMethod == AuthMethod.SIMPLE) { - throw new IOException("Authentication is required"); + AccessControlException ae = new AccessControlException( + "Authentication is required"); + setupResponse(authFailedResponse, authFailedCall, Status.FATAL, + null, ae.getClass().getName(), ae.getMessage()); + responder.doRespond(authFailedCall); + throw ae; } if (!isSecurityEnabled && authMethod != AuthMethod.SIMPLE) { - askClientToUseSimpleAuth(); + doSaslReply(SaslStatus.SUCCESS, new IntWritable( + SaslRpcServer.SWITCH_TO_SIMPLE_AUTH), null, null); authMethod = AuthMethod.SIMPLE; // client has already sent the initial Sasl message and we // should ignore it. Both client and server should fall back @@ -1159,7 +1185,6 @@ private boolean authorizeConnection() throws IOException { rpcMetrics.authorizationSuccesses.inc(); } catch (AuthorizationException ae) { rpcMetrics.authorizationFailures.inc(); - authFailedCall.connection = this; setupResponse(authFailedResponse, authFailedCall, Status.FATAL, null, ae.getClass().getName(), ae.getMessage()); responder.doRespond(authFailedCall); @@ -1387,6 +1412,11 @@ void disableSecurity() { this.isSecurityEnabled = false; } + /** for unit testing only, should be called before server is started */ + void enableSecurity() { + this.isSecurityEnabled = true; + } + /** Sets the socket buffer size used for responding to RPCs */ public void setSocketSendBufSize(int size) { this.socketSendBufferSize = size; } diff --git a/src/java/org/apache/hadoop/security/SaslRpcClient.java b/src/java/org/apache/hadoop/security/SaslRpcClient.java index 7b171c5cc0..6e2820308d 100644 --- a/src/java/org/apache/hadoop/security/SaslRpcClient.java +++ b/src/java/org/apache/hadoop/security/SaslRpcClient.java @@ -39,7 +39,10 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.io.WritableUtils; +import org.apache.hadoop.ipc.RemoteException; import org.apache.hadoop.security.SaslRpcServer.AuthMethod; +import org.apache.hadoop.security.SaslRpcServer.SaslStatus; import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.TokenIdentifier; @@ -99,6 +102,14 @@ public SaslRpcClient(AuthMethod method, throw new IOException("Unable to find SASL client implementation"); } + private static void readStatus(DataInputStream inStream) throws IOException { + int status = inStream.readInt(); // read status + if (status != SaslStatus.SUCCESS.state) { + throw new RemoteException(WritableUtils.readString(inStream), + WritableUtils.readString(inStream)); + } + } + /** * Do client side SASL authentication with server via the given InputStream * and OutputStream @@ -130,6 +141,7 @@ public boolean saslConnect(InputStream inS, OutputStream outS) + " from initSASLContext."); } if (!saslClient.isComplete()) { + readStatus(inStream); int len = inStream.readInt(); if (len == SaslRpcServer.SWITCH_TO_SIMPLE_AUTH) { if (LOG.isDebugEnabled()) @@ -155,6 +167,7 @@ public boolean saslConnect(InputStream inS, OutputStream outS) outStream.flush(); } if (!saslClient.isComplete()) { + readStatus(inStream); saslToken = new byte[inStream.readInt()]; if (LOG.isDebugEnabled()) LOG.debug("Will read input token of size " + saslToken.length diff --git a/src/java/org/apache/hadoop/security/SaslRpcServer.java b/src/java/org/apache/hadoop/security/SaslRpcServer.java index 7b69476a6a..16a7edbc3d 100644 --- a/src/java/org/apache/hadoop/security/SaslRpcServer.java +++ b/src/java/org/apache/hadoop/security/SaslRpcServer.java @@ -41,6 +41,7 @@ import org.apache.hadoop.ipc.Server; import org.apache.hadoop.security.token.SecretManager; import org.apache.hadoop.security.token.TokenIdentifier; +import org.apache.hadoop.security.token.SecretManager.InvalidToken; /** * A utility class for dealing with SASL on RPC server @@ -67,11 +68,16 @@ static byte[] decodeIdentifier(String identifier) { } public static TokenIdentifier getIdentifier(String id, - SecretManager secretManager) throws IOException { + SecretManager secretManager) throws InvalidToken { byte[] tokenId = decodeIdentifier(id); TokenIdentifier tokenIdentifier = secretManager.createIdentifier(); - tokenIdentifier.readFields(new DataInputStream(new ByteArrayInputStream( - tokenId))); + try { + tokenIdentifier.readFields(new DataInputStream(new ByteArrayInputStream( + tokenId))); + } catch (IOException e) { + throw (InvalidToken) new InvalidToken( + "Can't de-serialize tokenIdentifier").initCause(e); + } return tokenIdentifier; } @@ -84,6 +90,16 @@ public static String[] splitKerberosName(String fullName) { return fullName.split("[/@]"); } + public enum SaslStatus { + SUCCESS (0), + ERROR (1); + + public final int state; + private SaslStatus(int state) { + this.state = state; + } + } + /** Authentication method */ public static enum AuthMethod { SIMPLE((byte) 80, ""), // no authentication @@ -135,13 +151,13 @@ public SaslDigestCallbackHandler( this.connection = connection; } - private char[] getPassword(TokenIdentifier tokenid) throws IOException { + private char[] getPassword(TokenIdentifier tokenid) throws InvalidToken { return encodePassword(secretManager.retrievePassword(tokenid)); } /** {@inheritDoc} */ @Override - public void handle(Callback[] callbacks) throws IOException, + public void handle(Callback[] callbacks) throws InvalidToken, UnsupportedCallbackException { NameCallback nc = null; PasswordCallback pc = null; @@ -198,7 +214,7 @@ public static class SaslGssCallbackHandler implements CallbackHandler { /** {@inheritDoc} */ @Override - public void handle(Callback[] callbacks) throws IOException, + public void handle(Callback[] callbacks) throws UnsupportedCallbackException { AuthorizeCallback ac = null; for (Callback callback : callbacks) { diff --git a/src/test/core/org/apache/hadoop/ipc/TestRPC.java b/src/test/core/org/apache/hadoop/ipc/TestRPC.java index c7b41a4ed4..7f2e170893 100644 --- a/src/test/core/org/apache/hadoop/ipc/TestRPC.java +++ b/src/test/core/org/apache/hadoop/ipc/TestRPC.java @@ -39,6 +39,7 @@ import org.apache.hadoop.security.authorize.PolicyProvider; import org.apache.hadoop.security.authorize.Service; import org.apache.hadoop.security.authorize.ServiceAuthorizationManager; +import org.apache.hadoop.security.AccessControlException; import static org.mockito.Mockito.*; @@ -421,6 +422,30 @@ public void testStopNonRegisteredProxy() throws Exception { RPC.stopProxy(mock(TestProtocol.class)); } + public void testErrorMsgForInsecureClient() throws Exception { + final Server server = RPC.getServer(TestProtocol.class, + new TestImpl(), ADDRESS, 0, 5, true, conf, null); + server.enableSecurity(); + server.start(); + boolean succeeded = false; + final InetSocketAddress addr = NetUtils.getConnectAddress(server); + TestProtocol proxy = null; + try { + proxy = (TestProtocol) RPC.getProxy(TestProtocol.class, + TestProtocol.versionID, addr, conf); + } catch (RemoteException e) { + LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage()); + assertTrue(e.unwrapRemoteException() instanceof AccessControlException); + succeeded = true; + } finally { + server.stop(); + if (proxy != null) { + RPC.stopProxy(proxy); + } + } + assertTrue(succeeded); + } + public static void main(String[] args) throws Exception { new TestRPC("test").testCalls(conf); diff --git a/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java b/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java index 6b2bf790d8..4ea5415cba 100644 --- a/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java +++ b/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java @@ -19,6 +19,7 @@ package org.apache.hadoop.ipc; import static org.apache.hadoop.fs.CommonConfigurationKeys.HADOOP_SECURITY_AUTHENTICATION; +import static org.junit.Assert.*; import java.io.DataInput; import java.io.DataOutput; @@ -38,6 +39,7 @@ import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.security.token.TokenInfo; import org.apache.hadoop.security.token.TokenSelector; +import org.apache.hadoop.security.token.SecretManager.InvalidToken; import org.apache.hadoop.security.SaslInputStream; import org.apache.hadoop.security.SaslRpcClient; import org.apache.hadoop.security.SaslRpcServer; @@ -53,6 +55,7 @@ public class TestSaslRPC { public static final Log LOG = LogFactory.getLog(TestSaslRPC.class); + static final String ERROR_MESSAGE = "Token is invalid"; static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal"; private static Configuration conf; static { @@ -127,6 +130,14 @@ public TestTokenIdentifier createIdentifier() { return new TestTokenIdentifier(); } } + + public static class BadTokenSecretManager extends TestTokenSecretManager { + + public byte[] retrievePassword(TestTokenIdentifier id) + throws InvalidToken { + throw new InvalidToken(ERROR_MESSAGE); + } + } public static class TestTokenSelector implements TokenSelector { @@ -174,6 +185,24 @@ public void testSecureToInsecureRpc() throws Exception { doDigestRpc(server, sm); } + @Test + public void testErrorMessage() throws Exception { + BadTokenSecretManager sm = new BadTokenSecretManager(); + final Server server = RPC.getServer(TestSaslProtocol.class, + new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm); + + boolean succeeded = false; + try { + doDigestRpc(server, sm); + } catch (RemoteException e) { + LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage()); + assertTrue(ERROR_MESSAGE.equals(e.getLocalizedMessage())); + assertTrue(e.unwrapRemoteException() instanceof InvalidToken); + succeeded = true; + } + assertTrue(succeeded); + } + private void doDigestRpc(Server server, TestTokenSecretManager sm) throws Exception { server.start();