diff --git a/hadoop-hdfs-project/hadoop-hdfs-client/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/SaslDataTransferClient.java b/hadoop-hdfs-project/hadoop-hdfs-client/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/SaslDataTransferClient.java index 7804bec4ae..a23a1080be 100644 --- a/hadoop-hdfs-project/hadoop-hdfs-client/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/SaslDataTransferClient.java +++ b/hadoop-hdfs-project/hadoop-hdfs-client/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/SaslDataTransferClient.java @@ -203,11 +203,15 @@ private IOStreamPair checkTrustAndSend(InetAddress addr, DataEncryptionKeyFactory encryptionKeyFactory, Token accessToken, DatanodeID datanodeId) throws IOException { - if (!trustedChannelResolver.isTrusted() && - !trustedChannelResolver.isTrusted(addr)) { + boolean localTrusted = trustedChannelResolver.isTrusted(); + boolean remoteTrusted = trustedChannelResolver.isTrusted(addr); + LOG.debug("SASL encryption trust check: localHostTrusted = {}, " + + "remoteHostTrusted = {}", localTrusted, remoteTrusted); + + if (!localTrusted || !remoteTrusted) { // The encryption key factory only returns a key if encryption is enabled. - DataEncryptionKey encryptionKey = - encryptionKeyFactory.newDataEncryptionKey(); + DataEncryptionKey encryptionKey = encryptionKeyFactory + .newDataEncryptionKey(); return send(addr, underlyingOut, underlyingIn, encryptionKey, accessToken, datanodeId); } else { diff --git a/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/TestSaslDataTransfer.java b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/TestSaslDataTransfer.java index 2fe0a1c295..363bb5a76d 100644 --- a/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/TestSaslDataTransfer.java +++ b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/TestSaslDataTransfer.java @@ -23,8 +23,12 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import java.io.IOException; +import java.net.InetAddress; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketTimeoutException; @@ -56,6 +60,7 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.rules.Timeout; +import org.mockito.Mockito; public class TestSaslDataTransfer extends SaslDataTransferTestCase { @@ -74,7 +79,7 @@ public class TestSaslDataTransfer extends SaslDataTransferTestCase { @After public void shutdown() { - IOUtils.cleanup(null, fs); + IOUtils.cleanupWithLogger(null, fs); if (cluster != null) { cluster.shutdown(); cluster = null; @@ -245,4 +250,151 @@ public DataEncryptionKey newDataEncryptionKey() { IOUtils.cleanup(null, socket, serverSocket); } } + + /** + * Verifies that SaslDataTransferClient#checkTrustAndSend should not trust a + * partially trusted channel. + */ + @Test + public void testSaslDataTransferWithTrustedServerUntrustedClient() throws + Exception { + HdfsConfiguration conf = createSecureConfig( + "authentication,integrity,privacy"); + + AtomicBoolean fallbackToSimpleAuth = new AtomicBoolean(false); + TrustedChannelResolver trustedChannelResolver = new + TrustedChannelResolver() { + @Override + public boolean isTrusted() { + return true; + } + + @Override + public boolean isTrusted(InetAddress peerAddress) { + return false; + } + }; + + SaslDataTransferClient saslClient = new SaslDataTransferClient( + conf, DataTransferSaslUtil.getSaslPropertiesResolver(conf), + trustedChannelResolver, fallbackToSimpleAuth); + + ServerSocket serverSocket = null; + Socket socket = null; + DataEncryptionKeyFactory dataEncryptionKeyFactory = null; + try { + serverSocket = new ServerSocket(10002, 10); + socket = new Socket(serverSocket.getInetAddress(), + serverSocket.getLocalPort()); + + dataEncryptionKeyFactory = mock(DataEncryptionKeyFactory.class); + Mockito.when(dataEncryptionKeyFactory.newDataEncryptionKey()) + .thenThrow(new IOException("Encryption enabled")); + + saslClient.socketSend(socket, null, null, dataEncryptionKeyFactory, + null, null); + + Assert.fail("Expected IOException from " + + "SaslDataTransferClient#checkTrustAndSend"); + } catch (IOException e) { + GenericTestUtils.assertExceptionContains("Encryption enabled", e); + verify(dataEncryptionKeyFactory, times(1)).newDataEncryptionKey(); + } finally { + IOUtils.cleanupWithLogger(null, socket, serverSocket); + } + } + + @Test + public void testSaslDataTransferWithUntrustedServerUntrustedClient() throws + Exception { + HdfsConfiguration conf = createSecureConfig( + "authentication,integrity,privacy"); + + AtomicBoolean fallbackToSimpleAuth = new AtomicBoolean(false); + TrustedChannelResolver trustedChannelResolver = new + TrustedChannelResolver() { + @Override + public boolean isTrusted() { + return false; + } + + @Override + public boolean isTrusted(InetAddress peerAddress) { + return false; + } + }; + + SaslDataTransferClient saslClient = new SaslDataTransferClient( + conf, DataTransferSaslUtil.getSaslPropertiesResolver(conf), + trustedChannelResolver, fallbackToSimpleAuth); + + ServerSocket serverSocket = null; + Socket socket = null; + DataEncryptionKeyFactory dataEncryptionKeyFactory = null; + try { + serverSocket = new ServerSocket(10002, 10); + socket = new Socket(serverSocket.getInetAddress(), + serverSocket.getLocalPort()); + + dataEncryptionKeyFactory = mock(DataEncryptionKeyFactory.class); + Mockito.when(dataEncryptionKeyFactory.newDataEncryptionKey()) + .thenThrow(new IOException("Encryption enabled")); + + saslClient.socketSend(socket, null, null, dataEncryptionKeyFactory, + null, null); + + Assert.fail("Expected IOException from " + + "SaslDataTransferClient#checkTrustAndSend"); + } catch (IOException e) { + GenericTestUtils.assertExceptionContains("Encryption enabled", e); + verify(dataEncryptionKeyFactory, times(1)).newDataEncryptionKey(); + } finally { + IOUtils.cleanupWithLogger(null, socket, serverSocket); + } + } + + @Test + public void testSaslDataTransferWithTrustedServerTrustedClient() throws + Exception { + HdfsConfiguration conf = createSecureConfig( + "authentication,integrity,privacy"); + + AtomicBoolean fallbackToSimpleAuth = new AtomicBoolean(false); + TrustedChannelResolver trustedChannelResolver = new + TrustedChannelResolver() { + @Override + public boolean isTrusted() { + return true; + } + + @Override + public boolean isTrusted(InetAddress peerAddress) { + return true; + } + }; + + SaslDataTransferClient saslClient = new SaslDataTransferClient( + conf, DataTransferSaslUtil.getSaslPropertiesResolver(conf), + trustedChannelResolver, fallbackToSimpleAuth); + + ServerSocket serverSocket = null; + Socket socket = null; + DataEncryptionKeyFactory dataEncryptionKeyFactory = null; + try { + serverSocket = new ServerSocket(10002, 10); + socket = new Socket(serverSocket.getInetAddress(), + serverSocket.getLocalPort()); + + dataEncryptionKeyFactory = mock(DataEncryptionKeyFactory.class); + Mockito.when(dataEncryptionKeyFactory.newDataEncryptionKey()) + .thenThrow(new IOException("Encryption enabled")); + + saslClient.socketSend(socket, null, null, dataEncryptionKeyFactory, + null, null); + verify(dataEncryptionKeyFactory, times(0)).newDataEncryptionKey(); + } finally { + IOUtils.cleanupWithLogger(null, socket, serverSocket); + } + } + }