diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java index 1ac3fd3b52..23cdc50d66 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.io.InterruptedIOException; +import java.net.ConnectException; import java.security.GeneralSecurityException; import java.security.NoSuchAlgorithmException; import java.util.Arrays; @@ -27,6 +28,8 @@ import java.util.List; import java.util.concurrent.atomic.AtomicInteger; +import javax.net.ssl.SSLHandshakeException; + import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.crypto.key.KeyProvider; import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension; @@ -115,7 +118,6 @@ private T doOp(ProviderCallable op, int currPos) if (providers.length == 0) { throw new IOException("No providers configured !"); } - IOException ex = null; int numFailovers = 0; for (int i = 0;; i++, numFailovers++) { KMSClientProvider provider = providers[(currPos + i) % providers.length]; @@ -130,8 +132,15 @@ private T doOp(ProviderCallable op, int currPos) } catch (IOException ioe) { LOG.warn("KMS provider at [{}] threw an IOException: ", provider.getKMSUrl(), ioe); - ex = ioe; - + // SSLHandshakeException can occur here because of lost connection + // with the KMS server, creating a ConnectException from it, + // so that the FailoverOnNetworkExceptionRetry policy will retry + if (ioe instanceof SSLHandshakeException) { + Exception cause = ioe; + ioe = new ConnectException("SSLHandshakeException: " + + cause.getMessage()); + ioe.initCause(cause); + } RetryAction action = null; try { action = retryPolicy.shouldRetry(ioe, 0, numFailovers, false); @@ -153,7 +162,7 @@ private T doOp(ProviderCallable op, int currPos) CommonConfigurationKeysPublic. KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, providers.length), providers.length); - throw ex; + throw ioe; } if (((numFailovers + 1) % providers.length) == 0) { // Sleep only after we try all the providers for every cycle. diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java index bd68dca22c..4e7aed9cac 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java @@ -18,6 +18,7 @@ package org.apache.hadoop.crypto.key.kms; import static org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion; +import static org.apache.hadoop.test.LambdaTestUtils.intercept; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -26,12 +27,15 @@ import static org.mockito.Mockito.verify; import java.io.IOException; +import java.net.ConnectException; import java.net.NoRouteToHostException; import java.net.URI; import java.net.UnknownHostException; import java.security.GeneralSecurityException; import java.security.NoSuchAlgorithmException; +import javax.net.ssl.SSLHandshakeException; + import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.crypto.key.KeyProvider; import org.apache.hadoop.crypto.key.KeyProvider.Options; @@ -44,13 +48,18 @@ import org.apache.hadoop.security.authorize.AuthorizationException; import org.junit.After; import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.Timeout; import org.mockito.Mockito; import com.google.common.collect.Sets; public class TestLoadBalancingKMSClientProvider { + @Rule + public Timeout testTimeout = new Timeout(30 * 1000); + @BeforeClass public static void setup() throws IOException { SecurityUtil.setTokenServiceUseIp(false); @@ -638,4 +647,74 @@ public void testClientRetriesWithAuthenticationExceptionWrappedinIOException() verify(p2, Mockito.times(1)).createKey(Mockito.eq("test3"), Mockito.any(Options.class)); } + + /** + * Tests the operation succeeds second time after SSLHandshakeException. + * @throws Exception + */ + @Test + public void testClientRetriesWithSSLHandshakeExceptionSucceedsSecondTime() + throws Exception { + Configuration conf = new Configuration(); + conf.setInt( + CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 3); + final String keyName = "test"; + KMSClientProvider p1 = mock(KMSClientProvider.class); + when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new SSLHandshakeException("p1")) + .thenReturn(new KMSClientProvider.KMSKeyVersion(keyName, "v1", + new byte[0])); + KMSClientProvider p2 = mock(KMSClientProvider.class); + when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new ConnectException("p2")); + + when(p1.getKMSUrl()).thenReturn("p1"); + when(p2.getKMSUrl()).thenReturn("p2"); + + LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] {p1, p2}, 0, conf); + + kp.createKey(keyName, new Options(conf)); + verify(p1, Mockito.times(2)).createKey(Mockito.eq(keyName), + Mockito.any(Options.class)); + verify(p2, Mockito.times(1)).createKey(Mockito.eq(keyName), + Mockito.any(Options.class)); + } + + /** + * Tests the operation fails at every attempt after SSLHandshakeException. + * @throws Exception + */ + @Test + public void testClientRetriesWithSSLHandshakeExceptionFailsAtEveryAttempt() + throws Exception { + Configuration conf = new Configuration(); + conf.setInt( + CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 2); + final String keyName = "test"; + final String exceptionMessage = "p1 exception message"; + KMSClientProvider p1 = mock(KMSClientProvider.class); + Exception originalSslEx = new SSLHandshakeException(exceptionMessage); + when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(originalSslEx); + KMSClientProvider p2 = mock(KMSClientProvider.class); + when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new ConnectException("p2 exception message")); + + when(p1.getKMSUrl()).thenReturn("p1"); + when(p2.getKMSUrl()).thenReturn("p2"); + + LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] {p1, p2}, 0, conf); + + Exception interceptedEx = intercept(ConnectException.class, + "SSLHandshakeException: " + exceptionMessage, + ()-> kp.createKey(keyName, new Options(conf))); + assertEquals(originalSslEx, interceptedEx.getCause()); + + verify(p1, Mockito.times(2)).createKey(Mockito.eq(keyName), + Mockito.any(Options.class)); + verify(p2, Mockito.times(1)).createKey(Mockito.eq(keyName), + Mockito.any(Options.class)); + } } \ No newline at end of file