diff --git a/hadoop-common-project/hadoop-common/CHANGES.txt b/hadoop-common-project/hadoop-common/CHANGES.txt
index 0d452f7c87..39062a81f0 100644
--- a/hadoop-common-project/hadoop-common/CHANGES.txt
+++ b/hadoop-common-project/hadoop-common/CHANGES.txt
@@ -648,6 +648,9 @@ Release 2.7.0 - UNRELEASED
HADOOP-11506. Configuration variable expansion regex expensive for long
values. (Gera Shegalov via gera)
+ HADOOP-11620. Add support for load balancing across a group of KMS for HA.
+ (Arun Suresh via wang)
+
BUG FIXES
HADOOP-11512. Use getTrimmedStrings when reading serialization keys
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java
index 97ab2535d9..223e69a1a8 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java
@@ -52,6 +52,7 @@
import java.lang.reflect.UndeclaredThrowableException;
import java.net.HttpURLConnection;
import java.net.InetSocketAddress;
+import java.net.MalformedURLException;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.net.URISyntaxException;
@@ -74,6 +75,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
+import com.google.common.base.Strings;
/**
* KMS client KeyProvider
implementation.
@@ -221,14 +223,71 @@ private static void writeJson(Map map, OutputStream os) throws IOException {
*/
public static class Factory extends KeyProviderFactory {
+ /**
+ * This provider expects URIs in the following form :
+ * kms://@/
+ *
+ * where :
+ * - PROTO = http or https
+ * - AUTHORITY = [:]
+ * - HOSTS = [;]
+ * - HOSTNAME = string
+ * - PORT = integer
+ *
+ * If multiple hosts are provider, the Factory will create a
+ * {@link LoadBalancingKMSClientProvider} that round-robins requests
+ * across the provided list of hosts.
+ */
@Override
- public KeyProvider createProvider(URI providerName, Configuration conf)
+ public KeyProvider createProvider(URI providerUri, Configuration conf)
throws IOException {
- if (SCHEME_NAME.equals(providerName.getScheme())) {
- return new KMSClientProvider(providerName, conf);
+ if (SCHEME_NAME.equals(providerUri.getScheme())) {
+ URL origUrl = new URL(extractKMSPath(providerUri).toString());
+ String authority = origUrl.getAuthority();
+ // check for ';' which delimits the backup hosts
+ if (Strings.isNullOrEmpty(authority)) {
+ throw new IOException(
+ "No valid authority in kms uri [" + origUrl + "]");
+ }
+ // Check if port is present in authority
+ // In the current scheme, all hosts have to run on the same port
+ int port = -1;
+ String hostsPart = authority;
+ if (authority.contains(":")) {
+ String[] t = authority.split(":");
+ try {
+ port = Integer.parseInt(t[1]);
+ } catch (Exception e) {
+ throw new IOException(
+ "Could not parse port in kms uri [" + origUrl + "]");
+ }
+ hostsPart = t[0];
+ }
+ return createProvider(providerUri, conf, origUrl, port, hostsPart);
}
return null;
}
+
+ private KeyProvider createProvider(URI providerUri, Configuration conf,
+ URL origUrl, int port, String hostsPart) throws IOException {
+ String[] hosts = hostsPart.split(";");
+ if (hosts.length == 1) {
+ return new KMSClientProvider(providerUri, conf);
+ } else {
+ KMSClientProvider[] providers = new KMSClientProvider[hosts.length];
+ for (int i = 0; i < hosts.length; i++) {
+ try {
+ providers[i] =
+ new KMSClientProvider(
+ new URI("kms", origUrl.getProtocol(), hosts[i], port,
+ origUrl.getPath(), null, null), conf);
+ } catch (URISyntaxException e) {
+ throw new IOException("Could not instantiate KMSProvider..", e);
+ }
+ }
+ return new LoadBalancingKMSClientProvider(providers, conf);
+ }
+ }
}
public static T checkNotNull(T o, String name)
@@ -302,10 +361,8 @@ public HttpURLConnection configure(HttpURLConnection conn)
public KMSClientProvider(URI uri, Configuration conf) throws IOException {
super(conf);
- Path path = ProviderUtils.unnestUri(uri);
- URL url = path.toUri().toURL();
- kmsUrl = createServiceURL(url);
- if ("https".equalsIgnoreCase(url.getProtocol())) {
+ kmsUrl = createServiceURL(extractKMSPath(uri));
+ if ("https".equalsIgnoreCase(new URL(kmsUrl).getProtocol())) {
sslFactory = new SSLFactory(SSLFactory.Mode.CLIENT, conf);
try {
sslFactory.init();
@@ -346,8 +403,12 @@ public KMSClientProvider(URI uri, Configuration conf) throws IOException {
.getCurrentUser();
}
- private String createServiceURL(URL url) throws IOException {
- String str = url.toExternalForm();
+ private static Path extractKMSPath(URI uri) throws MalformedURLException, IOException {
+ return ProviderUtils.unnestUri(uri);
+ }
+
+ private static String createServiceURL(Path path) throws IOException {
+ String str = new URL(path.toString()).toExternalForm();
if (str.endsWith("/")) {
str = str.substring(0, str.length() - 1);
}
@@ -853,4 +914,9 @@ public void close() throws IOException {
}
}
}
+
+ @VisibleForTesting
+ String getKMSUrl() {
+ return kmsUrl;
+ }
}
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java
new file mode 100644
index 0000000000..c1579e7132
--- /dev/null
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java
@@ -0,0 +1,347 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.crypto.key.kms;
+
+import java.io.IOException;
+import java.security.GeneralSecurityException;
+import java.security.NoSuchAlgorithmException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.crypto.key.KeyProvider;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion;
+import org.apache.hadoop.crypto.key.KeyProviderDelegationTokenExtension;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.util.Time;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.annotations.VisibleForTesting;
+
+/**
+ * A simple LoadBalancing KMSClientProvider that round-robins requests
+ * across a provided array of KMSClientProviders. It also retries failed
+ * requests on the next available provider in the load balancer group. It
+ * only retries failed requests that result in an IOException, sending back
+ * all other Exceptions to the caller without retry.
+ */
+public class LoadBalancingKMSClientProvider extends KeyProvider implements
+ CryptoExtension,
+ KeyProviderDelegationTokenExtension.DelegationTokenExtension {
+
+ public static Logger LOG =
+ LoggerFactory.getLogger(LoadBalancingKMSClientProvider.class);
+
+ static interface ProviderCallable {
+ public T call(KMSClientProvider provider) throws IOException, Exception;
+ }
+
+ @SuppressWarnings("serial")
+ static class WrapperException extends RuntimeException {
+ public WrapperException(Throwable cause) {
+ super(cause);
+ }
+ }
+
+ private final KMSClientProvider[] providers;
+ private final AtomicInteger currentIdx;
+
+ public LoadBalancingKMSClientProvider(KMSClientProvider[] providers,
+ Configuration conf) {
+ this(shuffle(providers), Time.monotonicNow(), conf);
+ }
+
+ @VisibleForTesting
+ LoadBalancingKMSClientProvider(KMSClientProvider[] providers, long seed,
+ Configuration conf) {
+ super(conf);
+ this.providers = providers;
+ this.currentIdx = new AtomicInteger((int)(seed % providers.length));
+ }
+
+ @VisibleForTesting
+ KMSClientProvider[] getProviders() {
+ return providers;
+ }
+
+ private T doOp(ProviderCallable op, int currPos)
+ throws IOException {
+ IOException ex = null;
+ for (int i = 0; i < providers.length; i++) {
+ KMSClientProvider provider = providers[(currPos + i) % providers.length];
+ try {
+ return op.call(provider);
+ } catch (IOException ioe) {
+ LOG.warn("KMS provider at [{}] threw an IOException [{}]!!",
+ provider.getKMSUrl(), ioe.getMessage());
+ ex = ioe;
+ } catch (Exception e) {
+ if (e instanceof RuntimeException) {
+ throw (RuntimeException)e;
+ } else {
+ throw new WrapperException(e);
+ }
+ }
+ }
+ if (ex != null) {
+ LOG.warn("Aborting since the Request has failed with all KMS"
+ + " providers in the group. !!");
+ throw ex;
+ }
+ throw new IOException("No providers configured !!");
+ }
+
+ private int nextIdx() {
+ while (true) {
+ int current = currentIdx.get();
+ int next = (current + 1) % providers.length;
+ if (currentIdx.compareAndSet(current, next)) {
+ return current;
+ }
+ }
+ }
+
+ @Override
+ public Token>[]
+ addDelegationTokens(final String renewer, final Credentials credentials)
+ throws IOException {
+ return doOp(new ProviderCallable[]>() {
+ @Override
+ public Token>[] call(KMSClientProvider provider) throws IOException {
+ return provider.addDelegationTokens(renewer, credentials);
+ }
+ }, nextIdx());
+ }
+
+ // This request is sent to all providers in the load-balancing group
+ @Override
+ public void warmUpEncryptedKeys(String... keyNames) throws IOException {
+ for (KMSClientProvider provider : providers) {
+ try {
+ provider.warmUpEncryptedKeys(keyNames);
+ } catch (IOException ioe) {
+ LOG.error(
+ "Error warming up keys for provider with url"
+ + "[" + provider.getKMSUrl() + "]");
+ }
+ }
+ }
+
+ // This request is sent to all providers in the load-balancing group
+ @Override
+ public void drain(String keyName) {
+ for (KMSClientProvider provider : providers) {
+ provider.drain(keyName);
+ }
+ }
+
+ @Override
+ public EncryptedKeyVersion
+ generateEncryptedKey(final String encryptionKeyName)
+ throws IOException, GeneralSecurityException {
+ try {
+ return doOp(new ProviderCallable() {
+ @Override
+ public EncryptedKeyVersion call(KMSClientProvider provider)
+ throws IOException, GeneralSecurityException {
+ return provider.generateEncryptedKey(encryptionKeyName);
+ }
+ }, nextIdx());
+ } catch (WrapperException we) {
+ throw (GeneralSecurityException) we.getCause();
+ }
+ }
+
+ @Override
+ public KeyVersion
+ decryptEncryptedKey(final EncryptedKeyVersion encryptedKeyVersion)
+ throws IOException, GeneralSecurityException {
+ try {
+ return doOp(new ProviderCallable() {
+ @Override
+ public KeyVersion call(KMSClientProvider provider)
+ throws IOException, GeneralSecurityException {
+ return provider.decryptEncryptedKey(encryptedKeyVersion);
+ }
+ }, nextIdx());
+ } catch (WrapperException we) {
+ throw (GeneralSecurityException)we.getCause();
+ }
+ }
+
+ @Override
+ public KeyVersion getKeyVersion(final String versionName) throws IOException {
+ return doOp(new ProviderCallable() {
+ @Override
+ public KeyVersion call(KMSClientProvider provider) throws IOException {
+ return provider.getKeyVersion(versionName);
+ }
+ }, nextIdx());
+ }
+
+ @Override
+ public List getKeys() throws IOException {
+ return doOp(new ProviderCallable>() {
+ @Override
+ public List call(KMSClientProvider provider) throws IOException {
+ return provider.getKeys();
+ }
+ }, nextIdx());
+ }
+
+ @Override
+ public Metadata[] getKeysMetadata(final String... names) throws IOException {
+ return doOp(new ProviderCallable() {
+ @Override
+ public Metadata[] call(KMSClientProvider provider) throws IOException {
+ return provider.getKeysMetadata(names);
+ }
+ }, nextIdx());
+ }
+
+ @Override
+ public List getKeyVersions(final String name) throws IOException {
+ return doOp(new ProviderCallable>() {
+ @Override
+ public List call(KMSClientProvider provider)
+ throws IOException {
+ return provider.getKeyVersions(name);
+ }
+ }, nextIdx());
+ }
+
+ @Override
+ public KeyVersion getCurrentKey(final String name) throws IOException {
+ return doOp(new ProviderCallable() {
+ @Override
+ public KeyVersion call(KMSClientProvider provider) throws IOException {
+ return provider.getCurrentKey(name);
+ }
+ }, nextIdx());
+ }
+ @Override
+ public Metadata getMetadata(final String name) throws IOException {
+ return doOp(new ProviderCallable() {
+ @Override
+ public Metadata call(KMSClientProvider provider) throws IOException {
+ return provider.getMetadata(name);
+ }
+ }, nextIdx());
+ }
+
+ @Override
+ public KeyVersion createKey(final String name, final byte[] material,
+ final Options options) throws IOException {
+ return doOp(new ProviderCallable() {
+ @Override
+ public KeyVersion call(KMSClientProvider provider) throws IOException {
+ return provider.createKey(name, material, options);
+ }
+ }, nextIdx());
+ }
+
+ @Override
+ public KeyVersion createKey(final String name, final Options options)
+ throws NoSuchAlgorithmException, IOException {
+ try {
+ return doOp(new ProviderCallable() {
+ @Override
+ public KeyVersion call(KMSClientProvider provider) throws IOException,
+ NoSuchAlgorithmException {
+ return provider.createKey(name, options);
+ }
+ }, nextIdx());
+ } catch (WrapperException e) {
+ throw (NoSuchAlgorithmException)e.getCause();
+ }
+ }
+ @Override
+ public void deleteKey(final String name) throws IOException {
+ doOp(new ProviderCallable() {
+ @Override
+ public Void call(KMSClientProvider provider) throws IOException {
+ provider.deleteKey(name);
+ return null;
+ }
+ }, nextIdx());
+ }
+ @Override
+ public KeyVersion rollNewVersion(final String name, final byte[] material)
+ throws IOException {
+ return doOp(new ProviderCallable() {
+ @Override
+ public KeyVersion call(KMSClientProvider provider) throws IOException {
+ return provider.rollNewVersion(name, material);
+ }
+ }, nextIdx());
+ }
+
+ @Override
+ public KeyVersion rollNewVersion(final String name)
+ throws NoSuchAlgorithmException, IOException {
+ try {
+ return doOp(new ProviderCallable() {
+ @Override
+ public KeyVersion call(KMSClientProvider provider) throws IOException,
+ NoSuchAlgorithmException {
+ return provider.rollNewVersion(name);
+ }
+ }, nextIdx());
+ } catch (WrapperException e) {
+ throw (NoSuchAlgorithmException)e.getCause();
+ }
+ }
+
+ // Close all providers in the LB group
+ @Override
+ public void close() throws IOException {
+ for (KMSClientProvider provider : providers) {
+ try {
+ provider.close();
+ } catch (IOException ioe) {
+ LOG.error("Error closing provider with url"
+ + "[" + provider.getKMSUrl() + "]");
+ }
+ }
+ }
+
+
+ @Override
+ public void flush() throws IOException {
+ for (KMSClientProvider provider : providers) {
+ try {
+ provider.flush();
+ } catch (IOException ioe) {
+ LOG.error("Error flushing provider with url"
+ + "[" + provider.getKMSUrl() + "]");
+ }
+ }
+ }
+
+ private static KMSClientProvider[] shuffle(KMSClientProvider[] providers) {
+ List list = Arrays.asList(providers);
+ Collections.shuffle(list);
+ return list.toArray(providers);
+ }
+}
diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java
new file mode 100644
index 0000000000..08a3d93d2f
--- /dev/null
+++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java
@@ -0,0 +1,166 @@
+/** when(p1.getKMSUrl()).thenReturn("p1");
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.crypto.key.kms;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.net.URI;
+import java.security.NoSuchAlgorithmException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.crypto.key.KeyProvider;
+import org.apache.hadoop.crypto.key.KeyProvider.Options;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import com.google.common.collect.Sets;
+
+public class TestLoadBalancingKMSClientProvider {
+
+ @Test
+ public void testCreation() throws Exception {
+ Configuration conf = new Configuration();
+ KeyProvider kp = new KMSClientProvider.Factory().createProvider(new URI(
+ "kms://http@host1/kms/foo"), conf);
+ assertTrue(kp instanceof KMSClientProvider);
+ assertEquals("http://host1/kms/foo/v1/",
+ ((KMSClientProvider) kp).getKMSUrl());
+
+ kp = new KMSClientProvider.Factory().createProvider(new URI(
+ "kms://http@host1;host2;host3/kms/foo"), conf);
+ assertTrue(kp instanceof LoadBalancingKMSClientProvider);
+ KMSClientProvider[] providers =
+ ((LoadBalancingKMSClientProvider) kp).getProviders();
+ assertEquals(3, providers.length);
+ assertEquals(Sets.newHashSet("http://host1/kms/foo/v1/",
+ "http://host2/kms/foo/v1/",
+ "http://host3/kms/foo/v1/"),
+ Sets.newHashSet(providers[0].getKMSUrl(),
+ providers[1].getKMSUrl(),
+ providers[2].getKMSUrl()));
+
+ kp = new KMSClientProvider.Factory().createProvider(new URI(
+ "kms://http@host1;host2;host3:16000/kms/foo"), conf);
+ assertTrue(kp instanceof LoadBalancingKMSClientProvider);
+ providers =
+ ((LoadBalancingKMSClientProvider) kp).getProviders();
+ assertEquals(3, providers.length);
+ assertEquals(Sets.newHashSet("http://host1:16000/kms/foo/v1/",
+ "http://host2:16000/kms/foo/v1/",
+ "http://host3:16000/kms/foo/v1/"),
+ Sets.newHashSet(providers[0].getKMSUrl(),
+ providers[1].getKMSUrl(),
+ providers[2].getKMSUrl()));
+ }
+
+ @Test
+ public void testLoadBalancing() throws Exception {
+ Configuration conf = new Configuration();
+ KMSClientProvider p1 = mock(KMSClientProvider.class);
+ when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+ .thenReturn(
+ new KMSClientProvider.KMSKeyVersion("p1", "v1", new byte[0]));
+ KMSClientProvider p2 = mock(KMSClientProvider.class);
+ when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+ .thenReturn(
+ new KMSClientProvider.KMSKeyVersion("p2", "v2", new byte[0]));
+ KMSClientProvider p3 = mock(KMSClientProvider.class);
+ when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+ .thenReturn(
+ new KMSClientProvider.KMSKeyVersion("p3", "v3", new byte[0]));
+ KeyProvider kp = new LoadBalancingKMSClientProvider(
+ new KMSClientProvider[] { p1, p2, p3 }, 0, conf);
+ assertEquals("p1", kp.createKey("test1", new Options(conf)).getName());
+ assertEquals("p2", kp.createKey("test2", new Options(conf)).getName());
+ assertEquals("p3", kp.createKey("test3", new Options(conf)).getName());
+ assertEquals("p1", kp.createKey("test4", new Options(conf)).getName());
+ }
+
+ @Test
+ public void testLoadBalancingWithFailure() throws Exception {
+ Configuration conf = new Configuration();
+ KMSClientProvider p1 = mock(KMSClientProvider.class);
+ when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+ .thenReturn(
+ new KMSClientProvider.KMSKeyVersion("p1", "v1", new byte[0]));
+ when(p1.getKMSUrl()).thenReturn("p1");
+ // This should not be retried
+ KMSClientProvider p2 = mock(KMSClientProvider.class);
+ when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+ .thenThrow(new NoSuchAlgorithmException("p2"));
+ when(p2.getKMSUrl()).thenReturn("p2");
+ KMSClientProvider p3 = mock(KMSClientProvider.class);
+ when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+ .thenReturn(
+ new KMSClientProvider.KMSKeyVersion("p3", "v3", new byte[0]));
+ when(p3.getKMSUrl()).thenReturn("p3");
+ // This should be retried
+ KMSClientProvider p4 = mock(KMSClientProvider.class);
+ when(p4.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+ .thenThrow(new IOException("p4"));
+ when(p4.getKMSUrl()).thenReturn("p4");
+ KeyProvider kp = new LoadBalancingKMSClientProvider(
+ new KMSClientProvider[] { p1, p2, p3, p4 }, 0, conf);
+
+ assertEquals("p1", kp.createKey("test4", new Options(conf)).getName());
+ // Exceptions other than IOExceptions will not be retried
+ try {
+ kp.createKey("test1", new Options(conf)).getName();
+ fail("Should fail since its not an IOException");
+ } catch (Exception e) {
+ assertTrue(e instanceof NoSuchAlgorithmException);
+ }
+ assertEquals("p3", kp.createKey("test2", new Options(conf)).getName());
+ // IOException will trigger retry in next provider
+ assertEquals("p1", kp.createKey("test3", new Options(conf)).getName());
+ }
+
+ @Test
+ public void testLoadBalancingWithAllBadNodes() throws Exception {
+ Configuration conf = new Configuration();
+ KMSClientProvider p1 = mock(KMSClientProvider.class);
+ when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+ .thenThrow(new IOException("p1"));
+ KMSClientProvider p2 = mock(KMSClientProvider.class);
+ when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+ .thenThrow(new IOException("p2"));
+ KMSClientProvider p3 = mock(KMSClientProvider.class);
+ when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+ .thenThrow(new IOException("p3"));
+ KMSClientProvider p4 = mock(KMSClientProvider.class);
+ when(p4.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+ .thenThrow(new IOException("p4"));
+ when(p1.getKMSUrl()).thenReturn("p1");
+ when(p2.getKMSUrl()).thenReturn("p2");
+ when(p3.getKMSUrl()).thenReturn("p3");
+ when(p4.getKMSUrl()).thenReturn("p4");
+ KeyProvider kp = new LoadBalancingKMSClientProvider(
+ new KMSClientProvider[] { p1, p2, p3, p4 }, 0, conf);
+ try {
+ kp.createKey("test3", new Options(conf)).getName();
+ fail("Should fail since all providers threw an IOException");
+ } catch (Exception e) {
+ assertTrue(e instanceof IOException);
+ }
+ }
+}
diff --git a/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java b/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java
index 70ba95f28d..c5a990b58b 100644
--- a/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java
+++ b/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java
@@ -24,9 +24,11 @@
import org.apache.hadoop.crypto.key.KeyProvider.KeyVersion;
import org.apache.hadoop.crypto.key.KeyProvider.Options;
import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension;
import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion;
import org.apache.hadoop.crypto.key.KeyProviderDelegationTokenExtension;
import org.apache.hadoop.crypto.key.kms.KMSClientProvider;
+import org.apache.hadoop.crypto.key.kms.LoadBalancingKMSClientProvider;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.minikdc.MiniKdc;
@@ -99,6 +101,12 @@ protected URL getKMSUrl() {
}
}
+ protected KeyProvider createProvider(URI uri, Configuration conf)
+ throws IOException {
+ return new LoadBalancingKMSClientProvider(
+ new KMSClientProvider[] { new KMSClientProvider(uri, conf) }, conf);
+ }
+
protected T runServer(String keystore, String password, File confDir,
KMSCallable callable) throws Exception {
return runServer(-1, keystore, password, confDir, callable);
@@ -305,7 +313,7 @@ public Void call() throws Exception {
final URI uri = createKMSUri(getKMSUrl());
if (ssl) {
- KeyProvider testKp = new KMSClientProvider(uri, conf);
+ KeyProvider testKp = createProvider(uri, conf);
ThreadGroup threadGroup = Thread.currentThread().getThreadGroup();
while (threadGroup.getParent() != null) {
threadGroup = threadGroup.getParent();
@@ -335,12 +343,14 @@ public Void call() throws Exception {
doAs(user, new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- final KeyProvider kp = new KMSClientProvider(uri, conf);
+ final KeyProvider kp = createProvider(uri, conf);
// getKeys() empty
Assert.assertTrue(kp.getKeys().isEmpty());
Thread.sleep(4000);
- Token>[] tokens = ((KMSClientProvider)kp).addDelegationTokens("myuser", new Credentials());
+ Token>[] tokens =
+ ((KeyProviderDelegationTokenExtension.DelegationTokenExtension)kp)
+ .addDelegationTokens("myuser", new Credentials());
Assert.assertEquals(1, tokens.length);
Assert.assertEquals("kms-dt", tokens[0].getKind().toString());
return null;
@@ -348,12 +358,14 @@ public Void run() throws Exception {
});
}
} else {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
// getKeys() empty
Assert.assertTrue(kp.getKeys().isEmpty());
Thread.sleep(4000);
- Token>[] tokens = ((KMSClientProvider)kp).addDelegationTokens("myuser", new Credentials());
+ Token>[] tokens =
+ ((KeyProviderDelegationTokenExtension.DelegationTokenExtension)kp)
+ .addDelegationTokens("myuser", new Credentials());
Assert.assertEquals(1, tokens.length);
Assert.assertEquals("kms-dt", tokens[0].getKind().toString());
}
@@ -404,7 +416,7 @@ public Void call() throws Exception {
Date started = new Date();
Configuration conf = new Configuration();
URI uri = createKMSUri(getKMSUrl());
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
// getKeys() empty
Assert.assertTrue(kp.getKeys().isEmpty());
@@ -687,7 +699,7 @@ public Void call() throws Exception {
doAs("CREATE", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
Options options = new KeyProvider.Options(conf);
Map attributes = options.getAttributes();
@@ -727,7 +739,7 @@ public Void run() throws Exception {
doAs("DECRYPT_EEK", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
Options options = new KeyProvider.Options(conf);
Map attributes = options.getAttributes();
@@ -760,7 +772,7 @@ public Void run() throws Exception {
doAs("ROLLOVER", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
Options options = new KeyProvider.Options(conf);
Map attributes = options.getAttributes();
@@ -804,7 +816,7 @@ public Void run() throws Exception {
doAs("GET", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
Options options = new KeyProvider.Options(conf);
Map attributes = options.getAttributes();
@@ -836,7 +848,7 @@ public Void run() throws Exception {
final EncryptedKeyVersion ekv = doAs("GENERATE_EEK", new PrivilegedExceptionAction() {
@Override
public EncryptedKeyVersion run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
Options options = new KeyProvider.Options(conf);
Map attributes = options.getAttributes();
@@ -861,7 +873,7 @@ public EncryptedKeyVersion run() throws Exception {
doAs("ROLLOVER", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
KeyProviderCryptoExtension kpce =
KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp);
@@ -891,7 +903,7 @@ public Void call() throws Exception {
doAs("GENERATE_EEK", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
KeyProviderCryptoExtension kpce =
KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp);
@@ -964,7 +976,7 @@ public KeyProvider call() throws Exception {
new PrivilegedExceptionAction() {
@Override
public KeyProvider run() throws Exception {
- KMSClientProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
kp.createKey("k1", new byte[16],
new KeyProvider.Options(conf));
return kp;
@@ -1041,7 +1053,7 @@ public Void call() throws Exception {
new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KMSClientProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
kp.createKey("k0", new byte[16],
new KeyProvider.Options(conf));
@@ -1072,7 +1084,7 @@ public Void call() throws Exception {
new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KMSClientProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
kp.createKey("k3", new byte[16],
new KeyProvider.Options(conf));
// Atleast 2 rollovers.. so should induce signer Exception
@@ -1132,7 +1144,7 @@ public Void call() throws Exception {
doAs("client", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
kp.createKey("k", new KeyProvider.Options(conf));
Assert.fail();
@@ -1223,7 +1235,7 @@ public Void run() throws Exception {
doAs("CREATE", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
KeyProvider.KeyVersion kv = kp.createKey("k0",
new KeyProvider.Options(conf));
@@ -1238,7 +1250,7 @@ public Void run() throws Exception {
doAs("DELETE", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
kp.deleteKey("k0");
} catch (Exception ex) {
@@ -1251,7 +1263,7 @@ public Void run() throws Exception {
doAs("SET_KEY_MATERIAL", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
KeyProvider.KeyVersion kv = kp.createKey("k1", new byte[16],
new KeyProvider.Options(conf));
@@ -1266,7 +1278,7 @@ public Void run() throws Exception {
doAs("ROLLOVER", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
KeyProvider.KeyVersion kv = kp.rollNewVersion("k1");
Assert.assertNull(kv.getMaterial());
@@ -1280,7 +1292,7 @@ public Void run() throws Exception {
doAs("SET_KEY_MATERIAL", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
KeyProvider.KeyVersion kv =
kp.rollNewVersion("k1", new byte[16]);
@@ -1296,7 +1308,7 @@ public Void run() throws Exception {
doAs("GET", new PrivilegedExceptionAction() {
@Override
public KeyVersion run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
kp.getKeyVersion("k1@0");
KeyVersion kv = kp.getCurrentKey("k1");
@@ -1313,7 +1325,7 @@ public KeyVersion run() throws Exception {
new PrivilegedExceptionAction() {
@Override
public EncryptedKeyVersion run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
KeyProviderCryptoExtension kpCE = KeyProviderCryptoExtension.
createKeyProviderCryptoExtension(kp);
@@ -1330,7 +1342,7 @@ public EncryptedKeyVersion run() throws Exception {
doAs("DECRYPT_EEK", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
KeyProviderCryptoExtension kpCE = KeyProviderCryptoExtension.
createKeyProviderCryptoExtension(kp);
@@ -1345,7 +1357,7 @@ public Void run() throws Exception {
doAs("GET_KEYS", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
kp.getKeys();
} catch (Exception ex) {
@@ -1358,7 +1370,7 @@ public Void run() throws Exception {
doAs("GET_METADATA", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
try {
kp.getMetadata("k1");
kp.getKeysMetadata("k1");
@@ -1385,7 +1397,7 @@ public Void run() throws Exception {
@Override
public Void run() throws Exception {
try {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
KeyProvider.KeyVersion kv = kp.createKey("k2",
new KeyProvider.Options(conf));
Assert.fail();
@@ -1440,12 +1452,12 @@ public Void call() throws Exception {
@Override
public Void run() throws Exception {
try {
- KMSClientProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
KeyProvider.KeyVersion kv = kp.createKey("ck0",
new KeyProvider.Options(conf));
EncryptedKeyVersion eek =
- kp.generateEncryptedKey("ck0");
- kp.decryptEncryptedKey(eek);
+ ((CryptoExtension)kp).generateEncryptedKey("ck0");
+ ((CryptoExtension)kp).decryptEncryptedKey(eek);
Assert.assertNull(kv.getMaterial());
} catch (Exception ex) {
Assert.fail(ex.getMessage());
@@ -1458,12 +1470,12 @@ public Void run() throws Exception {
@Override
public Void run() throws Exception {
try {
- KMSClientProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
KeyProvider.KeyVersion kv = kp.createKey("ck1",
new KeyProvider.Options(conf));
EncryptedKeyVersion eek =
- kp.generateEncryptedKey("ck1");
- kp.decryptEncryptedKey(eek);
+ ((CryptoExtension)kp).generateEncryptedKey("ck1");
+ ((CryptoExtension)kp).decryptEncryptedKey(eek);
Assert.fail("admin user must not be allowed to decrypt !!");
} catch (Exception ex) {
}
@@ -1475,12 +1487,12 @@ public Void run() throws Exception {
@Override
public Void run() throws Exception {
try {
- KMSClientProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
KeyProvider.KeyVersion kv = kp.createKey("ck2",
new KeyProvider.Options(conf));
EncryptedKeyVersion eek =
- kp.generateEncryptedKey("ck2");
- kp.decryptEncryptedKey(eek);
+ ((CryptoExtension)kp).generateEncryptedKey("ck2");
+ ((CryptoExtension)kp).decryptEncryptedKey(eek);
Assert.fail("admin user must not be allowed to decrypt !!");
} catch (Exception ex) {
}
@@ -1525,7 +1537,7 @@ public Void call() throws Exception {
@Override
public Void run() throws Exception {
try {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
KeyProvider.KeyVersion kv = kp.createKey("ck0",
new KeyProvider.Options(conf));
Assert.assertNull(kv.getMaterial());
@@ -1540,7 +1552,7 @@ public Void run() throws Exception {
@Override
public Void run() throws Exception {
try {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
KeyProvider.KeyVersion kv = kp.createKey("ck1",
new KeyProvider.Options(conf));
Assert.assertNull(kv.getMaterial());
@@ -1583,7 +1595,7 @@ public void testKMSTimeout() throws Exception {
boolean caughtTimeout = false;
try {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
kp.getKeys();
} catch (SocketTimeoutException e) {
caughtTimeout = true;
@@ -1593,7 +1605,7 @@ public void testKMSTimeout() throws Exception {
caughtTimeout = false;
try {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp)
.generateEncryptedKey("a");
} catch (SocketTimeoutException e) {
@@ -1604,7 +1616,7 @@ public void testKMSTimeout() throws Exception {
caughtTimeout = false;
try {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp)
.decryptEncryptedKey(
new KMSClientProvider.KMSEncryptedKeyVersion("a",
@@ -1651,7 +1663,7 @@ public Void call() throws Exception {
UserGroupInformation.getCurrentUser();
try {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
kp.createKey(keyA, new KeyProvider.Options(conf));
} catch (IOException ex) {
System.out.println(ex.getMessage());
@@ -1660,7 +1672,7 @@ public Void call() throws Exception {
doAs("client", new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
KeyProviderDelegationTokenExtension kpdte =
KeyProviderDelegationTokenExtension.
createKeyProviderDelegationTokenExtension(kp);
@@ -1672,7 +1684,7 @@ public Void run() throws Exception {
nonKerberosUgi.addCredentials(credentials);
try {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
kp.createKey(keyA, new KeyProvider.Options(conf));
} catch (IOException ex) {
System.out.println(ex.getMessage());
@@ -1681,7 +1693,7 @@ public Void run() throws Exception {
nonKerberosUgi.doAs(new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
kp.createKey(keyD, new KeyProvider.Options(conf));
return null;
}
@@ -1767,7 +1779,7 @@ public KeyProvider call() throws Exception {
new PrivilegedExceptionAction() {
@Override
public KeyProvider run() throws Exception {
- KMSClientProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
kp.createKey("k1", new byte[16],
new KeyProvider.Options(conf));
kp.createKey("k2", new byte[16],
@@ -1844,7 +1856,7 @@ public Void call() throws Exception {
clientUgi.doAs(new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- final KeyProvider kp = new KMSClientProvider(uri, conf);
+ final KeyProvider kp = createProvider(uri, conf);
kp.createKey("kaa", new KeyProvider.Options(conf));
// authorized proxyuser
@@ -1956,7 +1968,7 @@ public Void run() throws Exception {
fooUgi.doAs(new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
Assert.assertNotNull(kp.createKey("kaa",
new KeyProvider.Options(conf)));
return null;
@@ -1970,7 +1982,7 @@ public Void run() throws Exception {
@Override
public Void run() throws Exception {
try {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
kp.createKey("kbb", new KeyProvider.Options(conf));
Assert.fail();
} catch (Exception ex) {
@@ -1986,7 +1998,7 @@ public Void run() throws Exception {
barUgi.doAs(new PrivilegedExceptionAction() {
@Override
public Void run() throws Exception {
- KeyProvider kp = new KMSClientProvider(uri, conf);
+ KeyProvider kp = createProvider(uri, conf);
Assert.assertNotNull(kp.createKey("kcc",
new KeyProvider.Options(conf)));
return null;