HDFS-11210. Enhance key rolling to guarantee new KeyVersion is returned from generateEncryptedKeys after a key is rolled.

This commit is contained in:
Xiao Chen 2017-02-07 20:35:05 -08:00
parent a87e3850b9
commit 2007e0cf2a
16 changed files with 375 additions and 66 deletions

View File

@ -141,8 +141,7 @@ public void deleteKey(String name) throws IOException {
public KeyVersion rollNewVersion(String name, byte[] material) public KeyVersion rollNewVersion(String name, byte[] material)
throws IOException { throws IOException {
KeyVersion key = getKeyProvider().rollNewVersion(name, material); KeyVersion key = getKeyProvider().rollNewVersion(name, material);
getExtension().currentKeyCache.invalidate(name); invalidateCache(name);
getExtension().keyMetadataCache.invalidate(name);
return key; return key;
} }
@ -150,9 +149,18 @@ public KeyVersion rollNewVersion(String name, byte[] material)
public KeyVersion rollNewVersion(String name) public KeyVersion rollNewVersion(String name)
throws NoSuchAlgorithmException, IOException { throws NoSuchAlgorithmException, IOException {
KeyVersion key = getKeyProvider().rollNewVersion(name); KeyVersion key = getKeyProvider().rollNewVersion(name);
invalidateCache(name);
return key;
}
@Override
public void invalidateCache(String name) throws IOException {
getKeyProvider().invalidateCache(name);
getExtension().currentKeyCache.invalidate(name); getExtension().currentKeyCache.invalidate(name);
getExtension().keyMetadataCache.invalidate(name); getExtension().keyMetadataCache.invalidate(name);
return key; // invalidating all key versions as we don't know
// which ones belonged to the deleted key
getExtension().keyVersionCache.invalidateAll();
} }
@Override @Override

View File

@ -593,6 +593,18 @@ public KeyVersion rollNewVersion(String name) throws NoSuchAlgorithmException,
return rollNewVersion(name, material); return rollNewVersion(name, material);
} }
/**
* Can be used by implementing classes to invalidate the caches. This could be
* used after rollNewVersion to provide a strong guarantee to return the new
* version of the given key.
*
* @param name the basename of the key
* @throws IOException
*/
public void invalidateCache(String name) throws IOException {
// NOP
}
/** /**
* Ensures that any changes to the keys are written to persistent store. * Ensures that any changes to the keys are written to persistent store.
* @throws IOException * @throws IOException

View File

@ -117,6 +117,11 @@ public KeyVersion rollNewVersion(String name, byte[] material)
return keyProvider.rollNewVersion(name, material); return keyProvider.rollNewVersion(name, material);
} }
@Override
public void invalidateCache(String name) throws IOException {
keyProvider.invalidateCache(name);
}
@Override @Override
public void flush() throws IOException { public void flush() throws IOException {
keyProvider.flush(); keyProvider.flush();

View File

@ -46,7 +46,8 @@ public class KeyShell extends CommandShell {
" [" + CreateCommand.USAGE + "]\n" + " [" + CreateCommand.USAGE + "]\n" +
" [" + RollCommand.USAGE + "]\n" + " [" + RollCommand.USAGE + "]\n" +
" [" + DeleteCommand.USAGE + "]\n" + " [" + DeleteCommand.USAGE + "]\n" +
" [" + ListCommand.USAGE + "]\n"; " [" + ListCommand.USAGE + "]\n" +
" [" + InvalidateCacheCommand.USAGE + "]\n";
private static final String LIST_METADATA = "keyShell.list.metadata"; private static final String LIST_METADATA = "keyShell.list.metadata";
@VisibleForTesting @VisibleForTesting
public static final String NO_VALID_PROVIDERS = public static final String NO_VALID_PROVIDERS =
@ -70,6 +71,7 @@ public class KeyShell extends CommandShell {
* % hadoop key roll keyName [-provider providerPath] * % hadoop key roll keyName [-provider providerPath]
* % hadoop key list [-provider providerPath] * % hadoop key list [-provider providerPath]
* % hadoop key delete keyName [-provider providerPath] [-i] * % hadoop key delete keyName [-provider providerPath] [-i]
* % hadoop key invalidateCache keyName [-provider providerPath]
* </pre> * </pre>
* @param args Command line arguments. * @param args Command line arguments.
* @return 0 on success, 1 on failure. * @return 0 on success, 1 on failure.
@ -111,6 +113,15 @@ protected int init(String[] args) throws IOException {
} }
} else if ("list".equals(args[i])) { } else if ("list".equals(args[i])) {
setSubCommand(new ListCommand()); setSubCommand(new ListCommand());
} else if ("invalidateCache".equals(args[i])) {
String keyName = "-help";
if (moreTokens) {
keyName = args[++i];
}
setSubCommand(new InvalidateCacheCommand(keyName));
if ("-help".equals(keyName)) {
return 1;
}
} else if ("-size".equals(args[i]) && moreTokens) { } else if ("-size".equals(args[i]) && moreTokens) {
options.setBitLength(Integer.parseInt(args[++i])); options.setBitLength(Integer.parseInt(args[++i]));
} else if ("-cipher".equals(args[i]) && moreTokens) { } else if ("-cipher".equals(args[i]) && moreTokens) {
@ -168,6 +179,9 @@ public String getCommandUsage() {
sbuf.append(DeleteCommand.USAGE + ":\n\n" + DeleteCommand.DESC + "\n"); sbuf.append(DeleteCommand.USAGE + ":\n\n" + DeleteCommand.DESC + "\n");
sbuf.append(banner + "\n"); sbuf.append(banner + "\n");
sbuf.append(ListCommand.USAGE + ":\n\n" + ListCommand.DESC + "\n"); sbuf.append(ListCommand.USAGE + ":\n\n" + ListCommand.DESC + "\n");
sbuf.append(banner + "\n");
sbuf.append(InvalidateCacheCommand.USAGE + ":\n\n"
+ InvalidateCacheCommand.DESC + "\n");
return sbuf.toString(); return sbuf.toString();
} }
@ -466,6 +480,57 @@ public String getUsage() {
} }
} }
private class InvalidateCacheCommand extends Command {
public static final String USAGE =
"invalidateCache <keyname> [-provider <provider>] [-help]";
public static final String DESC =
"The invalidateCache subcommand invalidates the cached key versions\n"
+ "of the specified key, on the provider indicated using the"
+ " -provider argument.\n";
private String keyName = null;
InvalidateCacheCommand(String keyName) {
this.keyName = keyName;
}
public boolean validate() {
boolean rc = true;
provider = getKeyProvider();
if (provider == null) {
getOut().println("Invalid provider.");
rc = false;
}
if (keyName == null) {
getOut().println("Please provide a <keyname>.\n" +
"See the usage description by using -help.");
rc = false;
}
return rc;
}
public void execute() throws NoSuchAlgorithmException, IOException {
try {
warnIfTransientProvider();
getOut().println("Invalidating cache on KeyProvider: "
+ provider + "\n for key name: " + keyName);
provider.invalidateCache(keyName);
getOut().println("Cached keyversions of " + keyName
+ " has been successfully invalidated.");
printProviderWritten();
} catch (IOException e) {
getOut().println("Cannot invalidate cache for key: " + keyName +
" within KeyProvider: " + provider + ". " + e.toString());
throw e;
}
}
@Override
public String getUsage() {
return USAGE + ":\n\n" + DESC;
}
}
/** /**
* main() entry point for the KeyShell. While strictly speaking the * main() entry point for the KeyShell. While strictly speaking the
* return is void, it will System.exit() with a return code: 0 is for * return is void, it will System.exit() with a return code: 0 is for

View File

@ -757,6 +757,17 @@ public KeyVersion createKey(String name, byte[] material, Options options)
} }
} }
@Override
public void invalidateCache(String name) throws IOException {
checkNotEmpty(name, "name");
final URL url = createURL(KMSRESTConstants.KEY_RESOURCE, name,
KMSRESTConstants.INVALIDATECACHE_RESOURCE, null);
final HttpURLConnection conn = createConnection(url, HTTP_POST);
// invalidate the server cache first, then drain local cache.
call(conn, null, HttpURLConnection.HTTP_OK, null);
drain(name);
}
private KeyVersion rollNewVersionInternal(String name, byte[] material) private KeyVersion rollNewVersionInternal(String name, byte[] material)
throws NoSuchAlgorithmException, IOException { throws NoSuchAlgorithmException, IOException {
checkNotEmpty(name, "name"); checkNotEmpty(name, "name");
@ -771,7 +782,7 @@ private KeyVersion rollNewVersionInternal(String name, byte[] material)
Map response = call(conn, jsonMaterial, Map response = call(conn, jsonMaterial,
HttpURLConnection.HTTP_OK, Map.class); HttpURLConnection.HTTP_OK, Map.class);
KeyVersion keyVersion = parseJSONKeyVersion(response); KeyVersion keyVersion = parseJSONKeyVersion(response);
encKeyVersionQueue.drain(name); invalidateCache(name);
return keyVersion; return keyVersion;
} }

View File

@ -36,6 +36,7 @@ public class KMSRESTConstants {
public static final String VERSIONS_SUB_RESOURCE = "_versions"; public static final String VERSIONS_SUB_RESOURCE = "_versions";
public static final String EEK_SUB_RESOURCE = "_eek"; public static final String EEK_SUB_RESOURCE = "_eek";
public static final String CURRENT_VERSION_SUB_RESOURCE = "_currentversion"; public static final String CURRENT_VERSION_SUB_RESOURCE = "_currentversion";
public static final String INVALIDATECACHE_RESOURCE = "_invalidatecache";
public static final String KEY = "key"; public static final String KEY = "key";
public static final String EEK_OP = "eek_op"; public static final String EEK_OP = "eek_op";

View File

@ -178,6 +178,14 @@ public void drain(String keyName) {
} }
} }
// This request is sent to all providers in the load-balancing group
@Override
public void invalidateCache(String keyName) throws IOException {
for (KMSClientProvider provider : providers) {
provider.invalidateCache(keyName);
}
}
@Override @Override
public EncryptedKeyVersion public EncryptedKeyVersion
generateEncryptedKey(final String encryptionKeyName) generateEncryptedKey(final String encryptionKeyName)
@ -218,14 +226,14 @@ public KeyVersion call(KMSClientProvider provider)
} }
} }
public EncryptedKeyVersion reencryptEncryptedKey(EncryptedKeyVersion edek) public EncryptedKeyVersion reencryptEncryptedKey(EncryptedKeyVersion ekv)
throws IOException, GeneralSecurityException { throws IOException, GeneralSecurityException {
try { try {
return doOp(new ProviderCallable<EncryptedKeyVersion>() { return doOp(new ProviderCallable<EncryptedKeyVersion>() {
@Override @Override
public EncryptedKeyVersion call(KMSClientProvider provider) public EncryptedKeyVersion call(KMSClientProvider provider)
throws IOException, GeneralSecurityException { throws IOException, GeneralSecurityException {
return provider.reencryptEncryptedKey(edek); return provider.reencryptEncryptedKey(ekv);
} }
}, nextIdx()); }, nextIdx());
} catch (WrapperException we) { } catch (WrapperException we) {
@ -325,6 +333,7 @@ public KeyVersion call(KMSClientProvider provider) throws IOException,
throw new IOException(e.getCause()); throw new IOException(e.getCause());
} }
} }
@Override @Override
public void deleteKey(final String name) throws IOException { public void deleteKey(final String name) throws IOException {
doOp(new ProviderCallable<Void>() { doOp(new ProviderCallable<Void>() {
@ -335,28 +344,33 @@ public Void call(KMSClientProvider provider) throws IOException {
} }
}, nextIdx()); }, nextIdx());
} }
@Override @Override
public KeyVersion rollNewVersion(final String name, final byte[] material) public KeyVersion rollNewVersion(final String name, final byte[] material)
throws IOException { throws IOException {
return doOp(new ProviderCallable<KeyVersion>() { final KeyVersion newVersion = doOp(new ProviderCallable<KeyVersion>() {
@Override @Override
public KeyVersion call(KMSClientProvider provider) throws IOException { public KeyVersion call(KMSClientProvider provider) throws IOException {
return provider.rollNewVersion(name, material); return provider.rollNewVersion(name, material);
} }
}, nextIdx()); }, nextIdx());
invalidateCache(name);
return newVersion;
} }
@Override @Override
public KeyVersion rollNewVersion(final String name) public KeyVersion rollNewVersion(final String name)
throws NoSuchAlgorithmException, IOException { throws NoSuchAlgorithmException, IOException {
try { try {
return doOp(new ProviderCallable<KeyVersion>() { final KeyVersion newVersion = doOp(new ProviderCallable<KeyVersion>() {
@Override @Override
public KeyVersion call(KMSClientProvider provider) throws IOException, public KeyVersion call(KMSClientProvider provider) throws IOException,
NoSuchAlgorithmException { NoSuchAlgorithmException {
return provider.rollNewVersion(name); return provider.rollNewVersion(name);
} }
}, nextIdx()); }, nextIdx());
invalidateCache(name);
return newVersion;
} catch (WrapperException e) { } catch (WrapperException e) {
if (e.getCause() instanceof GeneralSecurityException) { if (e.getCause() instanceof GeneralSecurityException) {
throw (NoSuchAlgorithmException) e.getCause(); throw (NoSuchAlgorithmException) e.getCause();

View File

@ -18,8 +18,9 @@
package org.apache.hadoop.crypto.key.kms; package org.apache.hadoop.crypto.key.kms;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashSet; import java.util.HashMap;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -28,6 +29,9 @@
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheBuilder;
@ -67,8 +71,17 @@ public void fillQueueForKey(String keyName,
private static final String REFILL_THREAD = private static final String REFILL_THREAD =
ValueQueue.class.getName() + "_thread"; ValueQueue.class.getName() + "_thread";
private static final int LOCK_ARRAY_SIZE = 16;
// Using a mask assuming array size is the power of 2, of MAX_VALUE.
private static final int MASK = LOCK_ARRAY_SIZE == Integer.MAX_VALUE ?
LOCK_ARRAY_SIZE :
LOCK_ARRAY_SIZE - 1;
private final LoadingCache<String, LinkedBlockingQueue<E>> keyQueues; private final LoadingCache<String, LinkedBlockingQueue<E>> keyQueues;
// Stripped rwlocks based on key name to synchronize the queue from
// the sync'ed rw-thread and the background async refill thread.
private final List<ReadWriteLock> lockArray =
new ArrayList<>(LOCK_ARRAY_SIZE);
private final ThreadPoolExecutor executor; private final ThreadPoolExecutor executor;
private final UniqueKeyBlockingQueue queue = new UniqueKeyBlockingQueue(); private final UniqueKeyBlockingQueue queue = new UniqueKeyBlockingQueue();
private final QueueRefiller<E> refiller; private final QueueRefiller<E> refiller;
@ -84,9 +97,47 @@ public void fillQueueForKey(String keyName,
*/ */
private abstract static class NamedRunnable implements Runnable { private abstract static class NamedRunnable implements Runnable {
final String name; final String name;
private AtomicBoolean canceled = new AtomicBoolean(false);
private NamedRunnable(String keyName) { private NamedRunnable(String keyName) {
this.name = keyName; this.name = keyName;
} }
public void cancel() {
canceled.set(true);
}
public boolean isCanceled() {
return canceled.get();
}
}
private void readLock(String keyName) {
getLock(keyName).readLock().lock();
}
private void readUnlock(String keyName) {
getLock(keyName).readLock().unlock();
}
private void writeUnlock(String keyName) {
getLock(keyName).writeLock().unlock();
}
private void writeLock(String keyName) {
getLock(keyName).writeLock().lock();
}
/**
* Get the stripped lock given a key name.
*
* @param keyName The key name.
*/
private ReadWriteLock getLock(String keyName) {
return lockArray.get(indexFor(keyName));
}
private static int indexFor(String keyName) {
return keyName.hashCode() & MASK;
} }
/** /**
@ -103,11 +154,12 @@ private static class UniqueKeyBlockingQueue extends
LinkedBlockingQueue<Runnable> { LinkedBlockingQueue<Runnable> {
private static final long serialVersionUID = -2152747693695890371L; private static final long serialVersionUID = -2152747693695890371L;
private HashSet<String> keysInProgress = new HashSet<String>(); private HashMap<String, Runnable> keysInProgress = new HashMap<>();
@Override @Override
public synchronized void put(Runnable e) throws InterruptedException { public synchronized void put(Runnable e) throws InterruptedException {
if (keysInProgress.add(((NamedRunnable)e).name)) { if (!keysInProgress.containsKey(((NamedRunnable)e).name)) {
keysInProgress.put(((NamedRunnable)e).name, e);
super.put(e); super.put(e);
} }
} }
@ -131,6 +183,14 @@ public Runnable poll(long timeout, TimeUnit unit)
return k; return k;
} }
public Runnable deleteByName(String name) {
NamedRunnable e = (NamedRunnable) keysInProgress.remove(name);
if (e != null) {
e.cancel();
super.remove(e);
}
return e;
}
} }
/** /**
@ -172,6 +232,9 @@ public ValueQueue(final int numValues, final float lowWatermark,
this.policy = policy; this.policy = policy;
this.numValues = numValues; this.numValues = numValues;
this.lowWatermark = lowWatermark; this.lowWatermark = lowWatermark;
for (int i = 0; i < LOCK_ARRAY_SIZE; ++i) {
lockArray.add(i, new ReentrantReadWriteLock());
}
keyQueues = CacheBuilder.newBuilder() keyQueues = CacheBuilder.newBuilder()
.expireAfterAccess(expiry, TimeUnit.MILLISECONDS) .expireAfterAccess(expiry, TimeUnit.MILLISECONDS)
.build(new CacheLoader<String, LinkedBlockingQueue<E>>() { .build(new CacheLoader<String, LinkedBlockingQueue<E>>() {
@ -233,9 +296,18 @@ public E getNext(String keyName)
* *
* @param keyName the key to drain the Queue for * @param keyName the key to drain the Queue for
*/ */
public void drain(String keyName ) { public void drain(String keyName) {
try {
Runnable e;
while ((e = queue.deleteByName(keyName)) != null) {
executor.remove(e);
}
writeLock(keyName);
try { try {
keyQueues.get(keyName).clear(); keyQueues.get(keyName).clear();
} finally {
writeUnlock(keyName);
}
} catch (ExecutionException ex) { } catch (ExecutionException ex) {
//NOP //NOP
} }
@ -247,6 +319,8 @@ public void drain(String keyName ) {
* @return int queue size * @return int queue size
*/ */
public int getSize(String keyName) { public int getSize(String keyName) {
readLock(keyName);
try {
// We can't do keyQueues.get(keyName).size() here, // We can't do keyQueues.get(keyName).size() here,
// since that will have the side effect of populating the cache. // since that will have the side effect of populating the cache.
Map<String, LinkedBlockingQueue<E>> map = Map<String, LinkedBlockingQueue<E>> map =
@ -255,6 +329,9 @@ public int getSize(String keyName) {
return 0; return 0;
} }
return map.get(keyName).size(); return map.get(keyName).size();
} finally {
readUnlock(keyName);
}
} }
/** /**
@ -276,7 +353,9 @@ public List<E> getAtMost(String keyName, int num) throws IOException,
LinkedList<E> ekvs = new LinkedList<E>(); LinkedList<E> ekvs = new LinkedList<E>();
try { try {
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
readLock(keyName);
E val = keyQueue.poll(); E val = keyQueue.poll();
readUnlock(keyName);
// If queue is empty now, Based on the provided SyncGenerationPolicy, // If queue is empty now, Based on the provided SyncGenerationPolicy,
// figure out how many new values need to be generated synchronously // figure out how many new values need to be generated synchronously
if (val == null) { if (val == null) {
@ -336,10 +415,18 @@ public void run() {
int threshold = (int) (lowWatermark * (float) cacheSize); int threshold = (int) (lowWatermark * (float) cacheSize);
// Need to ensure that only one refill task per key is executed // Need to ensure that only one refill task per key is executed
try { try {
if (keyQueue.size() < threshold) { writeLock(keyName);
try {
if (keyQueue.size() < threshold && !isCanceled()) {
refiller.fillQueueForKey(name, keyQueue, refiller.fillQueueForKey(name, keyQueue,
cacheSize - keyQueue.size()); cacheSize - keyQueue.size());
} }
if (isCanceled()) {
keyQueue.clear();
}
} finally {
writeUnlock(keyName);
}
} catch (final Exception e) { } catch (final Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }

View File

@ -138,6 +138,15 @@ public void testKeySuccessfulKeyLifecycle() throws Exception {
assertTrue(outContent.toString().contains("key1 has been successfully " + assertTrue(outContent.toString().contains("key1 has been successfully " +
"rolled.")); "rolled."));
// jceks provider's invalidate is a no-op.
outContent.reset();
final String[] args3 =
{"invalidateCache", keyName, "-provider", jceksProvider};
rc = ks.run(args3);
assertEquals(0, rc);
assertTrue(outContent.toString()
.contains("key1 has been successfully " + "invalidated."));
deleteKey(ks, keyName); deleteKey(ks, keyName);
listOut = listKeys(ks, false); listOut = listKeys(ks, false);

View File

@ -183,4 +183,10 @@ public KeyVersion rollNewVersion(String name, byte[] material)
getExtension().drain(name); getExtension().drain(name);
return keyVersion; return keyVersion;
} }
@Override
public void invalidateCache(String name) throws IOException {
super.invalidateCache(name);
getExtension().drain(name);
}
} }

View File

@ -61,7 +61,7 @@
public class KMS { public class KMS {
public static enum KMSOp { public static enum KMSOp {
CREATE_KEY, DELETE_KEY, ROLL_NEW_VERSION, CREATE_KEY, DELETE_KEY, ROLL_NEW_VERSION, INVALIDATE_CACHE,
GET_KEYS, GET_KEYS_METADATA, GET_KEYS, GET_KEYS_METADATA,
GET_KEY_VERSIONS, GET_METADATA, GET_KEY_VERSION, GET_CURRENT_KEY, GET_KEY_VERSIONS, GET_METADATA, GET_KEY_VERSION, GET_CURRENT_KEY,
GENERATE_EEK, DECRYPT_EEK, REENCRYPT_EEK GENERATE_EEK, DECRYPT_EEK, REENCRYPT_EEK
@ -252,6 +252,37 @@ public KeyVersion run() throws Exception {
} }
} }
@POST
@Path(KMSRESTConstants.KEY_RESOURCE + "/{name:.*}/"
+ KMSRESTConstants.INVALIDATECACHE_RESOURCE)
public Response invalidateCache(@PathParam("name") final String name)
throws Exception {
try {
LOG.trace("Entering invalidateCache Method.");
KMSWebApp.getAdminCallsMeter().mark();
KMSClientProvider.checkNotEmpty(name, "name");
UserGroupInformation user = HttpUserGroupInformation.get();
assertAccess(KMSACLs.Type.ROLLOVER, user, KMSOp.INVALIDATE_CACHE, name);
LOG.debug("Invalidating cache with key name {}.", name);
user.doAs(new PrivilegedExceptionAction<Void>() {
@Override
public Void run() throws Exception {
provider.invalidateCache(name);
provider.flush();
return null;
}
});
kmsAudit.ok(user, KMSOp.INVALIDATE_CACHE, name, "");
LOG.trace("Exiting invalidateCache for key name {}.", name);
return Response.ok().build();
} catch (Exception e) {
LOG.debug("Exception in invalidateCache for key name {}.", name, e);
throw e;
}
}
@GET @GET
@Path(KMSRESTConstants.KEYS_METADATA_RESOURCE) @Path(KMSRESTConstants.KEYS_METADATA_RESOURCE)
@Produces(MediaType.APPLICATION_JSON + "; " + JettyUtils.UTF_8) @Produces(MediaType.APPLICATION_JSON + "; " + JettyUtils.UTF_8)

View File

@ -210,6 +210,17 @@ public KeyVersion rollNewVersion(String name, byte[] material)
} }
} }
@Override
public void invalidateCache(String name) throws IOException {
writeLock.lock();
try {
doAccessCheck(name, KeyOpType.MANAGEMENT);
provider.invalidateCache(name);
} finally {
writeLock.unlock();
}
}
@Override @Override
public void warmUpEncryptedKeys(String... names) throws IOException { public void warmUpEncryptedKeys(String... names) throws IOException {
readLock.lock(); readLock.lock();

View File

@ -103,7 +103,9 @@ This cache is used with the following 3 methods only, `getCurrentKey()` and `get
For the `getCurrentKey()` method, cached entries are kept for a maximum of 30000 milliseconds regardless the number of times the key is being accessed (to avoid stale keys to be considered current). For the `getCurrentKey()` method, cached entries are kept for a maximum of 30000 milliseconds regardless the number of times the key is being accessed (to avoid stale keys to be considered current).
For the `getKeyVersion()` method, cached entries are kept with a default inactivity timeout of 600000 milliseconds (10 mins). For the `getKeyVersion()` and `getMetadata()` methods, cached entries are kept with a default inactivity timeout of 600000 milliseconds (10 mins).
The cache is invalidated when the key is deleted by `deleteKey()`, or when `invalidateCache()` is called.
These configurations can be changed via the following properties in the `etc/hadoop/kms-site.xml` configuration file: These configurations can be changed via the following properties in the `etc/hadoop/kms-site.xml` configuration file:
@ -841,6 +843,16 @@ $H4 Rollover Key
"material" : "<material>", //base64, not present without GET ACL "material" : "<material>", //base64, not present without GET ACL
} }
$H4 Invalidate Cache of a Key
*REQUEST:*
POST http://HOST:PORT/kms/v1/key/<key-name>/_invalidatecache
*RESPONSE:*
200 OK
$H4 Delete Key $H4 Delete Key
*REQUEST:* *REQUEST:*

View File

@ -18,6 +18,7 @@
package org.apache.hadoop.crypto.key.kms.server; package org.apache.hadoop.crypto.key.kms.server;
import com.google.common.base.Supplier; import com.google.common.base.Supplier;
import com.google.common.cache.LoadingCache;
import org.apache.curator.test.TestingServer; import org.apache.curator.test.TestingServer;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.crypto.key.KeyProviderFactory; import org.apache.hadoop.crypto.key.KeyProviderFactory;
@ -31,7 +32,7 @@
import org.apache.hadoop.crypto.key.kms.KMSClientProvider; import org.apache.hadoop.crypto.key.kms.KMSClientProvider;
import org.apache.hadoop.crypto.key.kms.KMSDelegationToken; import org.apache.hadoop.crypto.key.kms.KMSDelegationToken;
import org.apache.hadoop.crypto.key.kms.LoadBalancingKMSClientProvider; import org.apache.hadoop.crypto.key.kms.LoadBalancingKMSClientProvider;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic; import org.apache.hadoop.crypto.key.kms.ValueQueue;
import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.Path;
import org.apache.hadoop.minikdc.MiniKdc; import org.apache.hadoop.minikdc.MiniKdc;
import org.apache.hadoop.security.Credentials; import org.apache.hadoop.security.Credentials;
@ -49,6 +50,8 @@
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.Timeout; import org.junit.rules.Timeout;
import org.mockito.Mockito;
import org.mockito.internal.util.reflection.Whitebox;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -79,11 +82,14 @@
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.LinkedBlockingQueue;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.Mockito.when;
public class TestKMS { public class TestKMS {
private static final Logger LOG = LoggerFactory.getLogger(TestKMS.class); private static final Logger LOG = LoggerFactory.getLogger(TestKMS.class);
@ -128,6 +134,11 @@ protected KeyProvider createProvider(URI uri, Configuration conf)
new KMSClientProvider[] { new KMSClientProvider(uri, conf) }, conf); new KMSClientProvider[] { new KMSClientProvider(uri, conf) }, conf);
} }
private KMSClientProvider createKMSClientProvider(URI uri, Configuration conf)
throws IOException {
return new KMSClientProvider(uri, conf);
}
protected <T> T runServer(String keystore, String password, File confDir, protected <T> T runServer(String keystore, String password, File confDir,
KMSCallable<T> callable) throws Exception { KMSCallable<T> callable) throws Exception {
return runServer(-1, keystore, password, confDir, callable); return runServer(-1, keystore, password, confDir, callable);
@ -723,24 +734,68 @@ public Void call() throws Exception {
EncryptedKeyVersion ekv1 = kpce.generateEncryptedKey("k6"); EncryptedKeyVersion ekv1 = kpce.generateEncryptedKey("k6");
kpce.rollNewVersion("k6"); kpce.rollNewVersion("k6");
kpce.invalidateCache("k6");
/**
* due to the cache on the server side, client may get old keys.
* @see EagerKeyGeneratorKeyProviderCryptoExtension#rollNewVersion(String)
*/
boolean rollSucceeded = false;
for (int i = 0; i <= EagerKeyGeneratorKeyProviderCryptoExtension
.KMS_KEY_CACHE_SIZE_DEFAULT + CommonConfigurationKeysPublic.
KMS_CLIENT_ENC_KEY_CACHE_SIZE_DEFAULT; ++i) {
EncryptedKeyVersion ekv2 = kpce.generateEncryptedKey("k6"); EncryptedKeyVersion ekv2 = kpce.generateEncryptedKey("k6");
if (!(ekv1.getEncryptionKeyVersionName() assertNotEquals("rollover did not generate a new key even after"
.equals(ekv2.getEncryptionKeyVersionName()))) { + " queue is drained", ekv1.getEncryptionKeyVersionName(),
rollSucceeded = true; ekv2.getEncryptionKeyVersionName());
break; return null;
} }
});
}
@Test
public void testKMSProviderCaching() throws Exception {
Configuration conf = new Configuration();
File confDir = getTestDir();
conf = createBaseKMSConf(confDir, conf);
conf.set(KeyAuthorizationKeyProvider.KEY_ACL + "k1.ALL", "*");
writeConf(confDir, conf);
runServer(null, null, confDir, new KMSCallable<Void>() {
@Override
public Void call() throws Exception {
final String keyName = "k1";
final String mockVersionName = "mock";
final Configuration conf = new Configuration();
final URI uri = createKMSUri(getKMSUrl());
KMSClientProvider kmscp = createKMSClientProvider(uri, conf);
// get the reference to the internal cache, to test invalidation.
ValueQueue vq =
(ValueQueue) Whitebox.getInternalState(kmscp, "encKeyVersionQueue");
LoadingCache<String, LinkedBlockingQueue<EncryptedKeyVersion>> kq =
((LoadingCache<String, LinkedBlockingQueue<EncryptedKeyVersion>>)
Whitebox.getInternalState(vq, "keyQueues"));
EncryptedKeyVersion mockEKV = Mockito.mock(EncryptedKeyVersion.class);
when(mockEKV.getEncryptionKeyName()).thenReturn(keyName);
when(mockEKV.getEncryptionKeyVersionName()).thenReturn(mockVersionName);
// createKey()
KeyProvider.Options options = new KeyProvider.Options(conf);
options.setCipher("AES/CTR/NoPadding");
options.setBitLength(128);
options.setDescription("l1");
KeyProvider.KeyVersion kv0 = kmscp.createKey(keyName, options);
assertNotNull(kv0.getVersionName());
assertEquals("Default key version name is incorrect.", "k1@0",
kmscp.generateEncryptedKey(keyName).getEncryptionKeyVersionName());
kmscp.invalidateCache(keyName);
kq.get(keyName).put(mockEKV);
assertEquals("Key version incorrect after invalidating cache + putting"
+ " mock key.", mockVersionName,
kmscp.generateEncryptedKey(keyName).getEncryptionKeyVersionName());
// test new version is returned after invalidation.
for (int i = 0; i < 100; ++i) {
kq.get(keyName).put(mockEKV);
kmscp.invalidateCache(keyName);
assertEquals("Cache invalidation guarantee failed.", "k1@0",
kmscp.generateEncryptedKey(keyName)
.getEncryptionKeyVersionName());
} }
Assert.assertTrue("rollover did not generate a new key even after"
+ " queue is drained", rollSucceeded);
return null; return null;
} }
}); });

View File

@ -104,6 +104,7 @@ public void testAggregation() throws Exception {
kmsAudit.ok(luser, KMSOp.DECRYPT_EEK, "k1", "testmsg"); kmsAudit.ok(luser, KMSOp.DECRYPT_EEK, "k1", "testmsg");
kmsAudit.ok(luser, KMSOp.DELETE_KEY, "k1", "testmsg"); kmsAudit.ok(luser, KMSOp.DELETE_KEY, "k1", "testmsg");
kmsAudit.ok(luser, KMSOp.ROLL_NEW_VERSION, "k1", "testmsg"); kmsAudit.ok(luser, KMSOp.ROLL_NEW_VERSION, "k1", "testmsg");
kmsAudit.ok(luser, KMSOp.INVALIDATE_CACHE, "k1", "testmsg");
kmsAudit.ok(luser, KMSOp.DECRYPT_EEK, "k1", "testmsg"); kmsAudit.ok(luser, KMSOp.DECRYPT_EEK, "k1", "testmsg");
kmsAudit.ok(luser, KMSOp.DECRYPT_EEK, "k1", "testmsg"); kmsAudit.ok(luser, KMSOp.DECRYPT_EEK, "k1", "testmsg");
kmsAudit.ok(luser, KMSOp.DECRYPT_EEK, "k1", "testmsg"); kmsAudit.ok(luser, KMSOp.DECRYPT_EEK, "k1", "testmsg");
@ -122,6 +123,7 @@ public void testAggregation() throws Exception {
// Not aggregated !! // Not aggregated !!
+ "OK\\[op=DELETE_KEY, key=k1, user=luser\\] testmsg" + "OK\\[op=DELETE_KEY, key=k1, user=luser\\] testmsg"
+ "OK\\[op=ROLL_NEW_VERSION, key=k1, user=luser\\] testmsg" + "OK\\[op=ROLL_NEW_VERSION, key=k1, user=luser\\] testmsg"
+ "OK\\[op=INVALIDATE_CACHE, key=k1, user=luser\\] testmsg"
// Aggregated // Aggregated
+ "OK\\[op=DECRYPT_EEK, key=k1, user=luser, accessCount=6, interval=[^m]{1,4}ms\\] testmsg" + "OK\\[op=DECRYPT_EEK, key=k1, user=luser, accessCount=6, interval=[^m]{1,4}ms\\] testmsg"
+ "OK\\[op=DECRYPT_EEK, key=k1, user=luser, accessCount=1, interval=[^m]{1,4}ms\\] testmsg" + "OK\\[op=DECRYPT_EEK, key=k1, user=luser, accessCount=1, interval=[^m]{1,4}ms\\] testmsg"

View File

@ -44,9 +44,7 @@
import org.apache.hadoop.crypto.CryptoProtocolVersion; import org.apache.hadoop.crypto.CryptoProtocolVersion;
import org.apache.hadoop.crypto.key.JavaKeyStoreProvider; import org.apache.hadoop.crypto.key.JavaKeyStoreProvider;
import org.apache.hadoop.crypto.key.KeyProvider; import org.apache.hadoop.crypto.key.KeyProvider;
import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
import org.apache.hadoop.crypto.key.KeyProviderFactory; import org.apache.hadoop.crypto.key.KeyProviderFactory;
import org.apache.hadoop.crypto.key.kms.server.EagerKeyGeneratorKeyProviderCryptoExtension;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic; import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.fs.CreateFlag; import org.apache.hadoop.fs.CreateFlag;
import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FSDataOutputStream;
@ -730,33 +728,15 @@ public void testReadWrite() throws Exception {
// Roll the key of the encryption zone // Roll the key of the encryption zone
assertNumZones(1); assertNumZones(1);
String keyName = dfsAdmin.listEncryptionZones().next().getKeyName(); String keyName = dfsAdmin.listEncryptionZones().next().getKeyName();
FileEncryptionInfo feInfo1 = getFileEncryptionInfo(encFile1);
cluster.getNamesystem().getProvider().rollNewVersion(keyName); cluster.getNamesystem().getProvider().rollNewVersion(keyName);
/** cluster.getNamesystem().getProvider().invalidateCache(keyName);
* due to the cache on the server side, client may get old keys.
* @see EagerKeyGeneratorKeyProviderCryptoExtension#rollNewVersion(String)
*/
boolean rollSucceeded = false;
for (int i = 0; i <= EagerKeyGeneratorKeyProviderCryptoExtension
.KMS_KEY_CACHE_SIZE_DEFAULT + CommonConfigurationKeysPublic.
KMS_CLIENT_ENC_KEY_CACHE_SIZE_DEFAULT; ++i) {
KeyProviderCryptoExtension.EncryptedKeyVersion ekv2 =
cluster.getNamesystem().getProvider().generateEncryptedKey(TEST_KEY);
if (!(feInfo1.getEzKeyVersionName()
.equals(ekv2.getEncryptionKeyVersionName()))) {
rollSucceeded = true;
break;
}
}
Assert.assertTrue("rollover did not generate a new key even after"
+ " queue is drained", rollSucceeded);
// Read them back in and compare byte-by-byte // Read them back in and compare byte-by-byte
verifyFilesEqual(fs, baseFile, encFile1, len); verifyFilesEqual(fs, baseFile, encFile1, len);
// Write a new enc file and validate // Write a new enc file and validate
final Path encFile2 = new Path(zone, "myfile2"); final Path encFile2 = new Path(zone, "myfile2");
DFSTestUtil.createFile(fs, encFile2, len, (short) 1, 0xFEED); DFSTestUtil.createFile(fs, encFile2, len, (short) 1, 0xFEED);
// FEInfos should be different // FEInfos should be different
FileEncryptionInfo feInfo1 = getFileEncryptionInfo(encFile1);
FileEncryptionInfo feInfo2 = getFileEncryptionInfo(encFile2); FileEncryptionInfo feInfo2 = getFileEncryptionInfo(encFile2);
assertFalse("EDEKs should be different", Arrays assertFalse("EDEKs should be different", Arrays
.equals(feInfo1.getEncryptedDataEncryptionKey(), .equals(feInfo1.getEncryptedDataEncryptionKey(),