HADOOP-14521. KMS client needs retry logic. Contributed by Rushabh S Shah.

This commit is contained in:
Xiao Chen 2017-10-05 19:38:00 -07:00
parent 644c2f6924
commit 25f31d9fc4
6 changed files with 466 additions and 43 deletions

View File

@ -250,9 +250,8 @@ public static class Factory extends KeyProviderFactory {
* - HOSTNAME = string * - HOSTNAME = string
* - PORT = integer * - PORT = integer
* *
* If multiple hosts are provider, the Factory will create a * This will always create a {@link LoadBalancingKMSClientProvider}
* {@link LoadBalancingKMSClientProvider} that round-robins requests * if the uri is correct.
* across the provided list of hosts.
*/ */
@Override @Override
public KeyProvider createProvider(URI providerUri, Configuration conf) public KeyProvider createProvider(URI providerUri, Configuration conf)
@ -279,30 +278,26 @@ public KeyProvider createProvider(URI providerUri, Configuration conf)
} }
hostsPart = t[0]; hostsPart = t[0];
} }
return createProvider(providerUri, conf, origUrl, port, hostsPart); return createProvider(conf, origUrl, port, hostsPart);
} }
return null; return null;
} }
private KeyProvider createProvider(URI providerUri, Configuration conf, private KeyProvider createProvider(Configuration conf,
URL origUrl, int port, String hostsPart) throws IOException { URL origUrl, int port, String hostsPart) throws IOException {
String[] hosts = hostsPart.split(";"); String[] hosts = hostsPart.split(";");
if (hosts.length == 1) { KMSClientProvider[] providers = new KMSClientProvider[hosts.length];
return new KMSClientProvider(providerUri, conf); for (int i = 0; i < hosts.length; i++) {
} else { try {
KMSClientProvider[] providers = new KMSClientProvider[hosts.length]; providers[i] =
for (int i = 0; i < hosts.length; i++) { new KMSClientProvider(
try { new URI("kms", origUrl.getProtocol(), hosts[i], port,
providers[i] = origUrl.getPath(), null, null), conf);
new KMSClientProvider( } catch (URISyntaxException e) {
new URI("kms", origUrl.getProtocol(), hosts[i], port, throw new IOException("Could not instantiate KMSProvider.", e);
origUrl.getPath(), null, null), conf);
} catch (URISyntaxException e) {
throw new IOException("Could not instantiate KMSProvider..", e);
}
} }
return new LoadBalancingKMSClientProvider(providers, conf);
} }
return new LoadBalancingKMSClientProvider(providers, conf);
} }
} }
@ -1031,7 +1026,11 @@ public Token<?> run() throws Exception {
} catch (InterruptedException e) { } catch (InterruptedException e) {
Thread.currentThread().interrupt(); Thread.currentThread().interrupt();
} catch (Exception e) { } catch (Exception e) {
throw new IOException(e); if (e instanceof IOException) {
throw (IOException) e;
} else {
throw new IOException(e);
}
} }
} }
return tokens; return tokens;

View File

@ -19,6 +19,7 @@
package org.apache.hadoop.crypto.key.kms; package org.apache.hadoop.crypto.key.kms;
import java.io.IOException; import java.io.IOException;
import java.io.InterruptedIOException;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.util.Arrays; import java.util.Arrays;
@ -31,9 +32,13 @@
import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension; import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension;
import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion; import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion;
import org.apache.hadoop.crypto.key.KeyProviderDelegationTokenExtension; import org.apache.hadoop.crypto.key.KeyProviderDelegationTokenExtension;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.io.retry.RetryPolicies;
import org.apache.hadoop.io.retry.RetryPolicy;
import org.apache.hadoop.io.retry.RetryPolicy.RetryAction;
import org.apache.hadoop.security.AccessControlException;
import org.apache.hadoop.security.Credentials; import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.util.Time; import org.apache.hadoop.util.Time;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -69,6 +74,8 @@ public WrapperException(Throwable cause) {
private final KMSClientProvider[] providers; private final KMSClientProvider[] providers;
private final AtomicInteger currentIdx; private final AtomicInteger currentIdx;
private RetryPolicy retryPolicy = null;
public LoadBalancingKMSClientProvider(KMSClientProvider[] providers, public LoadBalancingKMSClientProvider(KMSClientProvider[] providers,
Configuration conf) { Configuration conf) {
this(shuffle(providers), Time.monotonicNow(), conf); this(shuffle(providers), Time.monotonicNow(), conf);
@ -80,24 +87,82 @@ public LoadBalancingKMSClientProvider(KMSClientProvider[] providers,
super(conf); super(conf);
this.providers = providers; this.providers = providers;
this.currentIdx = new AtomicInteger((int)(seed % providers.length)); this.currentIdx = new AtomicInteger((int)(seed % providers.length));
int maxNumRetries = conf.getInt(CommonConfigurationKeysPublic.
KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, providers.length);
int sleepBaseMillis = conf.getInt(CommonConfigurationKeysPublic.
KMS_CLIENT_FAILOVER_SLEEP_BASE_MILLIS_KEY,
CommonConfigurationKeysPublic.
KMS_CLIENT_FAILOVER_SLEEP_BASE_MILLIS_DEFAULT);
int sleepMaxMillis = conf.getInt(CommonConfigurationKeysPublic.
KMS_CLIENT_FAILOVER_SLEEP_MAX_MILLIS_KEY,
CommonConfigurationKeysPublic.
KMS_CLIENT_FAILOVER_SLEEP_MAX_MILLIS_DEFAULT);
Preconditions.checkState(maxNumRetries >= 0);
Preconditions.checkState(sleepBaseMillis >= 0);
Preconditions.checkState(sleepMaxMillis >= 0);
this.retryPolicy = RetryPolicies.failoverOnNetworkException(
RetryPolicies.TRY_ONCE_THEN_FAIL, maxNumRetries, 0, sleepBaseMillis,
sleepMaxMillis);
} }
@VisibleForTesting @VisibleForTesting
KMSClientProvider[] getProviders() { public KMSClientProvider[] getProviders() {
return providers; return providers;
} }
private <T> T doOp(ProviderCallable<T> op, int currPos) private <T> T doOp(ProviderCallable<T> op, int currPos)
throws IOException { throws IOException {
if (providers.length == 0) {
throw new IOException("No providers configured !");
}
IOException ex = null; IOException ex = null;
for (int i = 0; i < providers.length; i++) { int numFailovers = 0;
for (int i = 0;; i++, numFailovers++) {
KMSClientProvider provider = providers[(currPos + i) % providers.length]; KMSClientProvider provider = providers[(currPos + i) % providers.length];
try { try {
return op.call(provider); return op.call(provider);
} catch (AccessControlException ace) {
// No need to retry on AccessControlException
// and AuthorizationException.
// This assumes all the servers are configured with identical
// permissions and identical key acls.
throw ace;
} catch (IOException ioe) { } catch (IOException ioe) {
LOG.warn("KMS provider at [{}] threw an IOException!! {}", LOG.warn("KMS provider at [{}] threw an IOException: ",
provider.getKMSUrl(), StringUtils.stringifyException(ioe)); provider.getKMSUrl(), ioe);
ex = ioe; ex = ioe;
RetryAction action = null;
try {
action = retryPolicy.shouldRetry(ioe, 0, numFailovers, false);
} catch (Exception e) {
if (e instanceof IOException) {
throw (IOException)e;
}
throw new IOException(e);
}
// make sure each provider is tried at least once, to keep behavior
// compatible with earlier versions of LBKMSCP
if (action.action == RetryAction.RetryDecision.FAIL
&& numFailovers >= providers.length - 1) {
LOG.warn("Aborting since the Request has failed with all KMS"
+ " providers(depending on {}={} setting and numProviders={})"
+ " in the group OR the exception is not recoverable",
CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY,
getConf().getInt(
CommonConfigurationKeysPublic.
KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, providers.length),
providers.length);
throw ex;
}
if (((numFailovers + 1) % providers.length) == 0) {
// Sleep only after we try all the providers for every cycle.
try {
Thread.sleep(action.delayMillis);
} catch (InterruptedException e) {
throw new InterruptedIOException("Thread Interrupted");
}
}
} catch (Exception e) { } catch (Exception e) {
if (e instanceof RuntimeException) { if (e instanceof RuntimeException) {
throw (RuntimeException)e; throw (RuntimeException)e;
@ -106,12 +171,6 @@ private <T> T doOp(ProviderCallable<T> op, int currPos)
} }
} }
} }
if (ex != null) {
LOG.warn("Aborting since the Request has failed with all KMS"
+ " providers in the group. !!");
throw ex;
}
throw new IOException("No providers configured !!");
} }
private int nextIdx() { private int nextIdx() {

View File

@ -721,6 +721,35 @@ public class CommonConfigurationKeysPublic {
/** Default value for KMS_CLIENT_ENC_KEY_CACHE_EXPIRY (12 hrs)*/ /** Default value for KMS_CLIENT_ENC_KEY_CACHE_EXPIRY (12 hrs)*/
public static final int KMS_CLIENT_ENC_KEY_CACHE_EXPIRY_DEFAULT = 43200000; public static final int KMS_CLIENT_ENC_KEY_CACHE_EXPIRY_DEFAULT = 43200000;
/**
* @see
* <a href="{@docRoot}/../hadoop-project-dist/hadoop-common/core-default.xml">
* core-default.xml</a>
*/
/** Default value is the number of providers specified. */
public static final String KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY =
"hadoop.security.kms.client.failover.max.retries";
/**
* @see
* <a href="{@docRoot}/../hadoop-project-dist/hadoop-common/core-default.xml">
* core-default.xml</a>
*/
public static final String KMS_CLIENT_FAILOVER_SLEEP_BASE_MILLIS_KEY =
"hadoop.security.kms.client.failover.sleep.base.millis";
/** Default value is 100 ms. */
public static final int KMS_CLIENT_FAILOVER_SLEEP_BASE_MILLIS_DEFAULT = 100;
/**
* @see
* <a href="{@docRoot}/../hadoop-project-dist/hadoop-common/core-default.xml">
* core-default.xml</a>
*/
public static final String KMS_CLIENT_FAILOVER_SLEEP_MAX_MILLIS_KEY =
"hadoop.security.kms.client.failover.sleep.max.millis";
/** Default value is 2 secs. */
public static final int KMS_CLIENT_FAILOVER_SLEEP_MAX_MILLIS_DEFAULT = 2000;
/** /**
* @see * @see
* <a href="{@docRoot}/../hadoop-project-dist/hadoop-common/core-default.xml"> * <a href="{@docRoot}/../hadoop-project-dist/hadoop-common/core-default.xml">

View File

@ -2335,6 +2335,34 @@
</description> </description>
</property> </property>
<property>
<name>hadoop.security.kms.client.failover.sleep.base.millis</name>
<value>100</value>
<description>
Expert only. The time to wait, in milliseconds, between failover
attempts increases exponentially as a function of the number of
attempts made so far, with a random factor of +/- 50%. This option
specifies the base value used in the failover calculation. The
first failover will retry immediately. The 2nd failover attempt
will delay at least hadoop.security.client.failover.sleep.base.millis
milliseconds. And so on.
</description>
</property>
<property>
<name>hadoop.security.kms.client.failover.sleep.max.millis</name>
<value>2000</value>
<description>
Expert only. The time to wait, in milliseconds, between failover
attempts increases exponentially as a function of the number of
attempts made so far, with a random factor of +/- 50%. This option
specifies the maximum value to wait between failovers.
Specifically, the time between two failover attempts will not
exceed +/- 50% of hadoop.security.client.failover.sleep.max.millis
milliseconds.
</description>
</property>
<property> <property>
<name>ipc.server.max.connections</name> <name>ipc.server.max.connections</name>
<value>0</value> <value>0</value>

View File

@ -23,9 +23,12 @@
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.mockito.Mockito.verify;
import java.io.IOException; import java.io.IOException;
import java.net.NoRouteToHostException;
import java.net.URI; import java.net.URI;
import java.net.UnknownHostException;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
@ -33,6 +36,9 @@
import org.apache.hadoop.crypto.key.KeyProvider; import org.apache.hadoop.crypto.key.KeyProvider;
import org.apache.hadoop.crypto.key.KeyProvider.Options; import org.apache.hadoop.crypto.key.KeyProvider.Options;
import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension; import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.net.ConnectTimeoutException;
import org.apache.hadoop.security.AccessControlException;
import org.apache.hadoop.security.authentication.client.AuthenticationException; import org.apache.hadoop.security.authentication.client.AuthenticationException;
import org.apache.hadoop.security.authorize.AuthorizationException; import org.apache.hadoop.security.authorize.AuthorizationException;
import org.junit.Test; import org.junit.Test;
@ -47,14 +53,17 @@ public void testCreation() throws Exception {
Configuration conf = new Configuration(); Configuration conf = new Configuration();
KeyProvider kp = new KMSClientProvider.Factory().createProvider(new URI( KeyProvider kp = new KMSClientProvider.Factory().createProvider(new URI(
"kms://http@host1/kms/foo"), conf); "kms://http@host1/kms/foo"), conf);
assertTrue(kp instanceof KMSClientProvider); assertTrue(kp instanceof LoadBalancingKMSClientProvider);
assertEquals("http://host1/kms/foo/v1/", KMSClientProvider[] providers =
((KMSClientProvider) kp).getKMSUrl()); ((LoadBalancingKMSClientProvider) kp).getProviders();
assertEquals(1, providers.length);
assertEquals(Sets.newHashSet("http://host1/kms/foo/v1/"),
Sets.newHashSet(providers[0].getKMSUrl()));
kp = new KMSClientProvider.Factory().createProvider(new URI( kp = new KMSClientProvider.Factory().createProvider(new URI(
"kms://http@host1;host2;host3/kms/foo"), conf); "kms://http@host1;host2;host3/kms/foo"), conf);
assertTrue(kp instanceof LoadBalancingKMSClientProvider); assertTrue(kp instanceof LoadBalancingKMSClientProvider);
KMSClientProvider[] providers = providers =
((LoadBalancingKMSClientProvider) kp).getProviders(); ((LoadBalancingKMSClientProvider) kp).getProviders();
assertEquals(3, providers.length); assertEquals(3, providers.length);
assertEquals(Sets.newHashSet("http://host1/kms/foo/v1/", assertEquals(Sets.newHashSet("http://host1/kms/foo/v1/",
@ -320,4 +329,298 @@ public void testWarmUpEncryptedKeysWhenOneProviderSucceeds()
Mockito.verify(p1, Mockito.times(1)).warmUpEncryptedKeys(keyName); Mockito.verify(p1, Mockito.times(1)).warmUpEncryptedKeys(keyName);
Mockito.verify(p2, Mockito.times(1)).warmUpEncryptedKeys(keyName); Mockito.verify(p2, Mockito.times(1)).warmUpEncryptedKeys(keyName);
} }
/**
* Tests whether retryPolicy fails immediately, after trying each provider
* once, on encountering IOException which is not SocketException.
* @throws Exception
*/
@Test
public void testClientRetriesWithIOException() throws Exception {
Configuration conf = new Configuration();
// Setting total failover attempts to .
conf.setInt(
CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 10);
KMSClientProvider p1 = mock(KMSClientProvider.class);
when(p1.getMetadata(Mockito.anyString()))
.thenThrow(new IOException("p1"));
KMSClientProvider p2 = mock(KMSClientProvider.class);
when(p2.getMetadata(Mockito.anyString()))
.thenThrow(new IOException("p2"));
KMSClientProvider p3 = mock(KMSClientProvider.class);
when(p3.getMetadata(Mockito.anyString()))
.thenThrow(new IOException("p3"));
when(p1.getKMSUrl()).thenReturn("p1");
when(p2.getKMSUrl()).thenReturn("p2");
when(p3.getKMSUrl()).thenReturn("p3");
LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider(
new KMSClientProvider[] {p1, p2, p3}, 0, conf);
try {
kp.getMetadata("test3");
fail("Should fail since all providers threw an IOException");
} catch (Exception e) {
assertTrue(e instanceof IOException);
}
verify(kp.getProviders()[0], Mockito.times(1))
.getMetadata(Mockito.eq("test3"));
verify(kp.getProviders()[1], Mockito.times(1))
.getMetadata(Mockito.eq("test3"));
verify(kp.getProviders()[2], Mockito.times(1))
.getMetadata(Mockito.eq("test3"));
}
/**
* Tests that client doesn't retry once it encounters AccessControlException
* from first provider.
* This assumes all the kms servers are configured with identical access to
* keys.
* @throws Exception
*/
@Test
public void testClientRetriesWithAccessControlException() throws Exception {
Configuration conf = new Configuration();
conf.setInt(
CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 3);
KMSClientProvider p1 = mock(KMSClientProvider.class);
when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new AccessControlException("p1"));
KMSClientProvider p2 = mock(KMSClientProvider.class);
when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new IOException("p2"));
KMSClientProvider p3 = mock(KMSClientProvider.class);
when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new IOException("p3"));
when(p1.getKMSUrl()).thenReturn("p1");
when(p2.getKMSUrl()).thenReturn("p2");
when(p3.getKMSUrl()).thenReturn("p3");
LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider(
new KMSClientProvider[] {p1, p2, p3}, 0, conf);
try {
kp.createKey("test3", new Options(conf));
fail("Should fail because provider p1 threw an AccessControlException");
} catch (Exception e) {
assertTrue(e instanceof AccessControlException);
}
verify(p1, Mockito.times(1)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
verify(p2, Mockito.never()).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
verify(p3, Mockito.never()).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
}
/**
* Tests that client doesn't retry once it encounters RunTimeException
* from first provider.
* This assumes all the kms servers are configured with identical access to
* keys.
* @throws Exception
*/
@Test
public void testClientRetriesWithRuntimeException() throws Exception {
Configuration conf = new Configuration();
conf.setInt(
CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 3);
KMSClientProvider p1 = mock(KMSClientProvider.class);
when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new RuntimeException("p1"));
KMSClientProvider p2 = mock(KMSClientProvider.class);
when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new IOException("p2"));
when(p1.getKMSUrl()).thenReturn("p1");
when(p2.getKMSUrl()).thenReturn("p2");
LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider(
new KMSClientProvider[] {p1, p2}, 0, conf);
try {
kp.createKey("test3", new Options(conf));
fail("Should fail since provider p1 threw RuntimeException");
} catch (Exception e) {
assertTrue(e instanceof RuntimeException);
}
verify(p1, Mockito.times(1)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
verify(p2, Mockito.never()).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
}
/**
* Tests the client retries until it finds a good provider.
* @throws Exception
*/
@Test
public void testClientRetriesWithTimeoutsException() throws Exception {
Configuration conf = new Configuration();
conf.setInt(
CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 4);
KMSClientProvider p1 = mock(KMSClientProvider.class);
when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new ConnectTimeoutException("p1"));
KMSClientProvider p2 = mock(KMSClientProvider.class);
when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new UnknownHostException("p2"));
KMSClientProvider p3 = mock(KMSClientProvider.class);
when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new NoRouteToHostException("p3"));
KMSClientProvider p4 = mock(KMSClientProvider.class);
when(p4.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenReturn(
new KMSClientProvider.KMSKeyVersion("test3", "v1", new byte[0]));
when(p1.getKMSUrl()).thenReturn("p1");
when(p2.getKMSUrl()).thenReturn("p2");
when(p3.getKMSUrl()).thenReturn("p3");
when(p4.getKMSUrl()).thenReturn("p4");
LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider(
new KMSClientProvider[] {p1, p2, p3, p4}, 0, conf);
try {
kp.createKey("test3", new Options(conf));
} catch (Exception e) {
fail("Provider p4 should have answered the request.");
}
verify(p1, Mockito.times(1)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
verify(p2, Mockito.times(1)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
verify(p3, Mockito.times(1)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
verify(p4, Mockito.times(1)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
}
/**
* Tests the operation succeeds second time after ConnectTimeoutException.
* @throws Exception
*/
@Test
public void testClientRetriesSucceedsSecondTime() throws Exception {
Configuration conf = new Configuration();
conf.setInt(
CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 3);
KMSClientProvider p1 = mock(KMSClientProvider.class);
when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new ConnectTimeoutException("p1"))
.thenReturn(new KMSClientProvider.KMSKeyVersion("test3", "v1",
new byte[0]));
KMSClientProvider p2 = mock(KMSClientProvider.class);
when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new ConnectTimeoutException("p2"));
when(p1.getKMSUrl()).thenReturn("p1");
when(p2.getKMSUrl()).thenReturn("p2");
LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider(
new KMSClientProvider[] {p1, p2}, 0, conf);
try {
kp.createKey("test3", new Options(conf));
} catch (Exception e) {
fail("Provider p1 should have answered the request second time.");
}
verify(p1, Mockito.times(2)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
verify(p2, Mockito.times(1)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
}
/**
* Tests whether retryPolicy retries specified number of times.
* @throws Exception
*/
@Test
public void testClientRetriesSpecifiedNumberOfTimes() throws Exception {
Configuration conf = new Configuration();
conf.setInt(
CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 10);
KMSClientProvider p1 = mock(KMSClientProvider.class);
when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new ConnectTimeoutException("p1"));
KMSClientProvider p2 = mock(KMSClientProvider.class);
when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new ConnectTimeoutException("p2"));
when(p1.getKMSUrl()).thenReturn("p1");
when(p2.getKMSUrl()).thenReturn("p2");
LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider(
new KMSClientProvider[] {p1, p2}, 0, conf);
try {
kp.createKey("test3", new Options(conf));
fail("Should fail");
} catch (Exception e) {
assert (e instanceof ConnectTimeoutException);
}
verify(p1, Mockito.times(6)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
verify(p2, Mockito.times(5)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
}
/**
* Tests whether retryPolicy retries number of times equals to number of
* providers if conf kms.client.failover.max.attempts is not set.
* @throws Exception
*/
@Test
public void testClientRetriesIfMaxAttemptsNotSet() throws Exception {
Configuration conf = new Configuration();
KMSClientProvider p1 = mock(KMSClientProvider.class);
when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new ConnectTimeoutException("p1"));
KMSClientProvider p2 = mock(KMSClientProvider.class);
when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new ConnectTimeoutException("p2"));
when(p1.getKMSUrl()).thenReturn("p1");
when(p2.getKMSUrl()).thenReturn("p2");
LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider(
new KMSClientProvider[] {p1, p2}, 0, conf);
try {
kp.createKey("test3", new Options(conf));
fail("Should fail");
} catch (Exception e) {
assert (e instanceof ConnectTimeoutException);
}
verify(p1, Mockito.times(2)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
verify(p2, Mockito.times(1)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
}
/**
* Tests that client reties each provider once, when it encounters
* AuthenticationException wrapped in an IOException from first provider.
* @throws Exception
*/
@Test
public void testClientRetriesWithAuthenticationExceptionWrappedinIOException()
throws Exception {
Configuration conf = new Configuration();
conf.setInt(
CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 3);
KMSClientProvider p1 = mock(KMSClientProvider.class);
when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new IOException(new AuthenticationException("p1")));
KMSClientProvider p2 = mock(KMSClientProvider.class);
when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
.thenThrow(new IOException(new AuthenticationException("p1")));
when(p1.getKMSUrl()).thenReturn("p1");
when(p2.getKMSUrl()).thenReturn("p2");
LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider(
new KMSClientProvider[] {p1, p2}, 0, conf);
try {
kp.createKey("test3", new Options(conf));
fail("Should fail since provider p1 threw AuthenticationException");
} catch (Exception e) {
assertTrue(e.getCause() instanceof AuthenticationException);
}
verify(p1, Mockito.times(1)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
verify(p2, Mockito.times(1)).createKey(Mockito.eq("test3"),
Mockito.any(Options.class));
}
} }

View File

@ -21,6 +21,7 @@
import com.google.common.base.Supplier; import com.google.common.base.Supplier;
import org.apache.hadoop.crypto.key.kms.KMSClientProvider; import org.apache.hadoop.crypto.key.kms.KMSClientProvider;
import org.apache.hadoop.crypto.key.kms.LoadBalancingKMSClientProvider;
import org.apache.hadoop.crypto.key.kms.server.MiniKMS; import org.apache.hadoop.crypto.key.kms.server.MiniKMS;
import org.apache.hadoop.security.Credentials; import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation;
@ -69,14 +70,21 @@ public void teardown() {
protected void setProvider() { protected void setProvider() {
} }
private KMSClientProvider getKMSClientProvider() {
LoadBalancingKMSClientProvider lbkmscp =
(LoadBalancingKMSClientProvider) Whitebox
.getInternalState(cluster.getNamesystem().getProvider(), "extension");
assert lbkmscp.getProviders().length == 1;
return lbkmscp.getProviders()[0];
}
@Test(timeout = 120000) @Test(timeout = 120000)
public void testCreateEZPopulatesEDEKCache() throws Exception { public void testCreateEZPopulatesEDEKCache() throws Exception {
final Path zonePath = new Path("/TestEncryptionZone"); final Path zonePath = new Path("/TestEncryptionZone");
fsWrapper.mkdir(zonePath, FsPermission.getDirDefault(), false); fsWrapper.mkdir(zonePath, FsPermission.getDirDefault(), false);
dfsAdmin.createEncryptionZone(zonePath, TEST_KEY, NO_TRASH); dfsAdmin.createEncryptionZone(zonePath, TEST_KEY, NO_TRASH);
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
KMSClientProvider kcp = (KMSClientProvider) Whitebox KMSClientProvider kcp = getKMSClientProvider();
.getInternalState(cluster.getNamesystem().getProvider(), "extension");
assertTrue(kcp.getEncKeyQueueSize(TEST_KEY) > 0); assertTrue(kcp.getEncKeyQueueSize(TEST_KEY) > 0);
} }
@ -110,8 +118,7 @@ public void testWarmupEDEKCacheOnStartup() throws Exception {
dfsAdmin.createEncryptionZone(zonePath, anotherKey, NO_TRASH); dfsAdmin.createEncryptionZone(zonePath, anotherKey, NO_TRASH);
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
KMSClientProvider spy = (KMSClientProvider) Whitebox KMSClientProvider spy = getKMSClientProvider();
.getInternalState(cluster.getNamesystem().getProvider(), "extension");
assertTrue("key queue is empty after creating encryption zone", assertTrue("key queue is empty after creating encryption zone",
spy.getEncKeyQueueSize(TEST_KEY) > 0); spy.getEncKeyQueueSize(TEST_KEY) > 0);
@ -122,9 +129,7 @@ public void testWarmupEDEKCacheOnStartup() throws Exception {
GenericTestUtils.waitFor(new Supplier<Boolean>() { GenericTestUtils.waitFor(new Supplier<Boolean>() {
@Override @Override
public Boolean get() { public Boolean get() {
final KMSClientProvider kspy = (KMSClientProvider) Whitebox final KMSClientProvider kspy = getKMSClientProvider();
.getInternalState(cluster.getNamesystem().getProvider(),
"extension");
return kspy.getEncKeyQueueSize(TEST_KEY) > 0; return kspy.getEncKeyQueueSize(TEST_KEY) > 0;
} }
}, 1000, 60000); }, 1000, 60000);