diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java index c1579e7132..83bc7b5316 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java @@ -169,7 +169,10 @@ public EncryptedKeyVersion call(KMSClientProvider provider) } }, nextIdx()); } catch (WrapperException we) { - throw (GeneralSecurityException) we.getCause(); + if (we.getCause() instanceof GeneralSecurityException) { + throw (GeneralSecurityException) we.getCause(); + } + throw new IOException(we.getCause()); } } @@ -186,7 +189,10 @@ public KeyVersion call(KMSClientProvider provider) } }, nextIdx()); } catch (WrapperException we) { - throw (GeneralSecurityException)we.getCause(); + if (we.getCause() instanceof GeneralSecurityException) { + throw (GeneralSecurityException) we.getCause(); + } + throw new IOException(we.getCause()); } } @@ -273,7 +279,10 @@ public KeyVersion call(KMSClientProvider provider) throws IOException, } }, nextIdx()); } catch (WrapperException e) { - throw (NoSuchAlgorithmException)e.getCause(); + if (e.getCause() instanceof GeneralSecurityException) { + throw (NoSuchAlgorithmException) e.getCause(); + } + throw new IOException(e.getCause()); } } @Override @@ -309,7 +318,10 @@ public KeyVersion call(KMSClientProvider provider) throws IOException, } }, nextIdx()); } catch (WrapperException e) { - throw (NoSuchAlgorithmException)e.getCause(); + if (e.getCause() instanceof GeneralSecurityException) { + throw (NoSuchAlgorithmException) e.getCause(); + } + throw new IOException(e.getCause()); } } diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java index 4e421da221..0499691820 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java @@ -17,6 +17,7 @@ */ package org.apache.hadoop.crypto.key.kms; +import static org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -25,11 +26,14 @@ import java.io.IOException; import java.net.URI; +import java.security.GeneralSecurityException; import java.security.NoSuchAlgorithmException; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.crypto.key.KeyProvider; import org.apache.hadoop.crypto.key.KeyProvider.Options; +import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension; +import org.apache.hadoop.security.authentication.client.AuthenticationException; import org.junit.Test; import org.mockito.Mockito; @@ -163,4 +167,94 @@ public void testLoadBalancingWithAllBadNodes() throws Exception { assertTrue(e instanceof IOException); } } + + // copied from HttpExceptionUtils: + + // trick, riding on generics to throw an undeclared exception + + private static void throwEx(Throwable ex) { + TestLoadBalancingKMSClientProvider.throwException(ex); + } + + @SuppressWarnings("unchecked") + private static void throwException(Throwable ex) + throws E { + throw (E) ex; + } + + private class MyKMSClientProvider extends KMSClientProvider { + public MyKMSClientProvider(URI uri, Configuration conf) throws IOException { + super(uri, conf); + } + + @Override + public EncryptedKeyVersion generateEncryptedKey( + final String encryptionKeyName) + throws IOException, GeneralSecurityException { + throwEx(new AuthenticationException("bar")); + return null; + } + + @Override + public KeyVersion decryptEncryptedKey( + final EncryptedKeyVersion encryptedKeyVersion) throws IOException, + GeneralSecurityException { + throwEx(new AuthenticationException("bar")); + return null; + } + + @Override + public KeyVersion createKey(final String name, final Options options) + throws NoSuchAlgorithmException, IOException { + throwEx(new AuthenticationException("bar")); + return null; + } + + @Override + public KeyVersion rollNewVersion(final String name) + throws NoSuchAlgorithmException, IOException { + throwEx(new AuthenticationException("bar")); + return null; + } + } + + @Test + public void testClassCastException() throws Exception { + Configuration conf = new Configuration(); + KMSClientProvider p1 = new MyKMSClientProvider( + new URI("kms://http@host1/kms/foo"), conf); + LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] {p1}, 0, conf); + try { + kp.generateEncryptedKey("foo"); + } catch (IOException ioe) { + assertTrue(ioe.getCause().getClass().getName().contains( + "AuthenticationException")); + } + + try { + final KeyProviderCryptoExtension.EncryptedKeyVersion + encryptedKeyVersion = + mock(KeyProviderCryptoExtension.EncryptedKeyVersion.class); + kp.decryptEncryptedKey(encryptedKeyVersion); + } catch (IOException ioe) { + assertTrue(ioe.getCause().getClass().getName().contains( + "AuthenticationException")); + } + + try { + final KeyProvider.Options options = KeyProvider.options(conf); + kp.createKey("foo", options); + } catch (IOException ioe) { + assertTrue(ioe.getCause().getClass().getName().contains( + "AuthenticationException")); + } + + try { + kp.rollNewVersion("foo"); + } catch (IOException ioe) { + assertTrue(ioe.getCause().getClass().getName().contains( + "AuthenticationException")); + } + } }