diff --git a/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/KMSBenchmark.java b/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/KMSBenchmark.java
new file mode 100644
index 0000000000..49518cb457
--- /dev/null
+++ b/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/KMSBenchmark.java
@@ -0,0 +1,627 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.crypto.key.kms.server;
+
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.crypto.key.KeyProvider;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
+import org.apache.hadoop.util.ExitUtil;
+import org.apache.hadoop.util.GenericOptionsParser;
+import org.apache.hadoop.util.KMSUtil;
+import org.apache.hadoop.util.StringUtils;
+import org.apache.hadoop.util.Time;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.log4j.Level;
+
+import java.io.IOException;
+import java.security.GeneralSecurityException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Main class for a series of KMS benchmarks.
+ *
+ * Each benchmark measures throughput and average execution time
+ * of a specific kms operation, e.g. encrypt or decrypt of
+ * Data Encryption Keys.
+ *
+ * The benchmark does not involve any other hadoop components
+ * except for kms operations. Each operation is executed
+ * by calling directly the respective kms operation.
+ *
+ * For usage, please see
+ * the documentation.
+ * Meanwhile, if you change the usage of this program, please also update the
+ * documentation accordingly.
+ */
+public class KMSBenchmark implements Tool {
+ private static final Logger LOG =
+ LoggerFactory.getLogger(KMSBenchmark.class);
+
+ private static final String GENERAL_OPTIONS_USAGE = "[-logLevel L] |";
+
+ private static Configuration config;
+
+ private KeyProviderCryptoExtension kp;
+ private KeyProviderCryptoExtension.EncryptedKeyVersion eek = null;
+ private String encryptionKeyName = "systest";
+ private boolean createEncryptionKey = false;
+ private boolean warmupKey = false;
+ private List keys = new ArrayList();
+
+ KMSBenchmark(Configuration conf, String[] args)
+ throws IOException {
+ config = conf;
+ kp = createKeyProviderCryptoExtension(config);
+ // create key and/or warm up
+ for (int i = 2; i < args.length; i++) {
+ if (args[i].equals("-warmup")) {
+ warmupKey = Boolean.parseBoolean(args[++i]);
+ } else if (args[i].equals("-createkey")) {
+ encryptionKeyName = args[++i];
+ }
+ }
+ try {
+ if (createEncryptionKey) {
+ keys = kp.getKeys();
+ if (!keys.contains(encryptionKeyName)) {
+ kp.createKey(encryptionKeyName, KeyProvider.options(conf));
+ } else {
+ LOG.warn("encryption key already exists: {}",
+ encryptionKeyName);
+ }
+ }
+ if (warmupKey) {
+ kp.warmUpEncryptedKeys(encryptionKeyName);
+ }
+ } catch (GeneralSecurityException e) {
+ LOG.warn(" failed to create or warmup encryption key", e);
+ }
+ }
+
+ /**
+ * Base class for collecting operation statistics.
+ *
+ * Overload this class in order to run statistics for a
+ * specific kms operation.
+ */
+ abstract class OperationStatsBase {
+ protected static final String OP_ALL_NAME = "all";
+ protected static final String OP_ALL_USAGE =
+ "-op all ";
+
+ // number of threads
+ private int numThreads = 0;
+
+ // number of operations requested
+ private int numOpsRequired = 0;
+
+ // number of operations executed
+ private int numOpsExecuted = 0;
+
+ // sum of times for each op
+ private long cumulativeTime = 0;
+
+ // time from start to finish
+ private long elapsedTime = 0;
+
+ // logging level, ERROR by default
+ private Level logLevel;
+
+ private List daemons;
+
+ /**
+ * Operation name.
+ */
+ abstract String getOpName();
+
+ /**
+ * Parse command line arguments.
+ *
+ * @param args arguments
+ * @throws IOException
+ */
+ abstract void parseArguments(List args) throws IOException;
+
+ /**
+ * This corresponds to the arg1 argument of
+ * {@link #executeOp(int, int, String)}, which can have
+ * different meanings depending on the operation performed.
+ *
+ * @param daemonId id of the daemon calling this method
+ * @return the argument
+ */
+ abstract String getExecutionArgument(int daemonId);
+
+ /**
+ * Execute kms operation.
+ *
+ * @param daemonId id of the daemon calling this method.
+ * @param inputIdx serial index of the operation called by the deamon.
+ * @param arg1 operation specific argument.
+ * @return time of the individual kms call.
+ * @throws IOException
+ */
+ abstract long executeOp(int daemonId, int inputIdx, String arg1)
+ throws IOException;
+
+ /**
+ * Print the results of the benchmarking.
+ */
+ abstract void printResults();
+
+ OperationStatsBase() {
+ numOpsRequired = 10000;
+ numThreads = 3;
+ logLevel = Level.ERROR;
+ }
+
+ void benchmark() throws IOException {
+ daemons = new ArrayList();
+ long start = 0;
+ try {
+ numOpsExecuted = 0;
+ cumulativeTime = 0;
+ if (numThreads < 1) {
+ return;
+ }
+ // thread index < nrThreads
+ int tIdx = 0;
+ int[] opsPerThread = new int[numThreads];
+ for (int opsScheduled = 0; opsScheduled < numOpsRequired;
+ opsScheduled += opsPerThread[tIdx++]) {
+ // execute in a separate thread
+ opsPerThread[tIdx] =
+ (numOpsRequired-opsScheduled)/(numThreads-tIdx);
+ if (opsPerThread[tIdx] == 0) {
+ opsPerThread[tIdx] = 1;
+ }
+ }
+ // if numThreads > numOpsRequired then the remaining threads
+ // will do nothing
+ for (; tIdx < numThreads; tIdx++) {
+ opsPerThread[tIdx] = 0;
+ }
+ for (tIdx=0; tIdx < numThreads; tIdx++) {
+ daemons.add(new StatsDaemon(tIdx, opsPerThread[tIdx], this));
+ }
+ start = Time.now();
+ LOG.info("Starting "+numOpsRequired+" "+getOpName()+"(s).");
+ for (StatsDaemon d : daemons) {
+ d.start();
+ }
+ } finally {
+ while(isInProgress()) {
+ try {
+ Thread.sleep(500);
+ } catch (InterruptedException e) {}
+ }
+ elapsedTime = Time.now() - start;
+ for (StatsDaemon d : daemons) {
+ incrementStats(d.localNumOpsExecuted, d.localCumulativeTime);
+ System.out.println(d.toString() + ": ops Exec = " +
+ d.localNumOpsExecuted);
+ }
+ }
+ }
+
+ private boolean isInProgress() {
+ for (StatsDaemon d : daemons) {
+ if (d.isInProgress()) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ void cleanUp() throws IOException {
+ }
+
+ int getNumOpsExecuted() {
+ return numOpsExecuted;
+ }
+
+ long getCumulativeTime() {
+ return cumulativeTime;
+ }
+
+ long getElapsedTime() {
+ return elapsedTime;
+ }
+
+ long getAverageTime() {
+ LOG.info("getAverageTime, cumulativeTime = " + cumulativeTime);
+ LOG.info("getAverageTime, numOpsExecuted = " + numOpsExecuted);
+ return numOpsExecuted == 0? 0 : cumulativeTime/numOpsExecuted;
+ }
+
+ double getOpsPerSecond() {
+ return elapsedTime == 0?
+ 0 : 1000*(double)numOpsExecuted / elapsedTime;
+ }
+
+ String getClientName(int idx) {
+ return getOpName() + "-client-" + idx;
+ }
+
+ void incrementStats(int ops, long time) {
+ numOpsExecuted += ops;
+ cumulativeTime += time;
+ }
+
+ int getNumThreads() {
+ return numThreads;
+ }
+
+ void setNumThreads(int num) {
+ numThreads = num;
+ }
+
+ int getNumOpsRequired() {
+ return numOpsRequired;
+ }
+
+ void setNumOpsRequired(int num) {
+ numOpsRequired = num;
+ }
+
+ /**
+ * Parse first 2 arguments, corresponding to the "-op" option.
+ *
+ * @param args argument list
+ * @return true if operation is all, which means that options not
+ * related to this operation should be ignored, or false
+ * otherwise, meaning that usage should be printed when an
+ * unrelated option is encountered.
+ */
+ protected boolean verifyOpArgument(List args) {
+ if (args.size() < 2 || !args.get(0).startsWith("-op")) {
+ printUsage();
+ }
+
+ // process common options
+ int llIndex = args.indexOf("-logLevel");
+ if (llIndex >= 0) {
+ if (args.size() <= llIndex + 1) {
+ printUsage();
+ }
+ logLevel = Level.toLevel(args.get(llIndex+1), Level.ERROR);
+ args.remove(llIndex+1);
+ args.remove(llIndex);
+ }
+
+ String type = args.get(1);
+ if (OP_ALL_NAME.equals(type)) {
+ type = getOpName();
+ return true;
+ }
+ if (!getOpName().equals(type)) {
+ printUsage();
+ }
+ return false;
+ }
+
+ void printStats() {
+ LOG.info("--- " + getOpName() + " stats ---");
+ LOG.info("# operations: " + getNumOpsExecuted());
+ LOG.info("Elapsed Time: " + getElapsedTime());
+ LOG.info(" Ops per sec: " + getOpsPerSecond());
+ LOG.info("Average Time: " + getAverageTime());
+ }
+ }
+
+ /**
+ * One of the threads that perform stats operations.
+ */
+ private class StatsDaemon extends Thread {
+ private final int daemonId;
+ private int opsPerThread;
+ private String arg1; // argument passed to executeOp()
+ private volatile int localNumOpsExecuted = 0;
+ private volatile long localCumulativeTime = 0;
+ private final OperationStatsBase statsOp;
+
+ StatsDaemon(int daemonId, int nOps, OperationStatsBase op) {
+ this.daemonId = daemonId;
+ this.opsPerThread = nOps;
+ this.statsOp = op;
+ setName(toString());
+ }
+
+ @Override
+ public void run() {
+ localNumOpsExecuted = 0;
+ localCumulativeTime = 0;
+ arg1 = statsOp.getExecutionArgument(daemonId);
+ try {
+ benchmarkOne();
+ } catch(IOException ex) {
+ LOG.error("StatsDaemon " + daemonId + " failed: \n"
+ + StringUtils.stringifyException(ex));
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "StatsDaemon-" + daemonId;
+ }
+
+ void benchmarkOne() throws IOException {
+ for (int idx = 0; idx < opsPerThread; idx++) {
+ long stat = statsOp.executeOp(daemonId, idx, arg1);
+ localNumOpsExecuted++;
+ localCumulativeTime += stat;
+ }
+ }
+
+ boolean isInProgress() {
+ return localNumOpsExecuted < opsPerThread;
+ }
+
+ /**
+ * Schedule to stop this daemon.
+ */
+ void terminate() {
+ opsPerThread = localNumOpsExecuted;
+ }
+ }
+
+ /**
+ * Encrypt key statistics.
+ *
+ * Each thread encrypts the key.
+ */
+ class EncryptKeyStats extends OperationStatsBase {
+ // Operation types
+ static final String OP_ENCRYPT_KEY = "encrypt";
+ static final String OP_ENCRYPT_USAGE =
+ "-op encrypt [-threads T -numops N -warmup F]";
+
+ EncryptKeyStats(List args) {
+ super();
+ parseArguments(args);
+ }
+
+ @Override
+ String getOpName() {
+ return OP_ENCRYPT_KEY;
+ }
+
+ @Override
+ void parseArguments(List args) {
+ verifyOpArgument(args);
+ // parse command line
+ for (int i = 2; i < args.size(); i++) {
+ if (args.get(i).equals("-threads")) {
+ if (i+1 == args.size()) {
+ printUsage();
+ }
+ setNumThreads(Integer.parseInt(args.get(++i)));
+ } else if (args.get(i).equals("-numops")) {
+ setNumOpsRequired(Integer.parseInt(args.get(++i)));
+ }
+ }
+ }
+
+ /**
+ * Returns client name.
+ */
+ @Override
+ String getExecutionArgument(int daemonId) {
+ return getClientName(daemonId);
+ }
+
+ /**
+ * Execute key encryption.
+ */
+ @Override
+ long executeOp(int daemonId, int inputIdx, String clientName)
+ throws IOException {
+ long start = Time.now();
+ try {
+ eek = kp.generateEncryptedKey(encryptionKeyName);
+ } catch (GeneralSecurityException e) {
+ LOG.warn("failed to generate encrypted key", e);
+ }
+
+ long end = Time.now();
+ return end-start;
+ }
+
+ @Override
+ void printResults() {
+ LOG.info("--- " + getOpName() + " inputs ---");
+ LOG.info("nOps = " + getNumOpsRequired());
+ LOG.info("nThreads = " + getNumThreads());
+ printStats();
+ }
+ }
+
+ /**
+ * Decrypt key statistics.
+ *
+ * Each thread decrypts the key.
+ */
+ class DecryptKeyStats extends OperationStatsBase {
+ // Operation types
+ static final String OP_DECRYPT_KEY = "decrypt";
+ static final String OP_DECRYPT_USAGE =
+ "-op decrypt [-threads T -numops N -warmup F]";
+
+ DecryptKeyStats(List args) {
+ super();
+ parseArguments(args);
+ }
+
+ @Override
+ String getOpName() {
+ return OP_DECRYPT_KEY;
+ }
+
+ @Override
+ void parseArguments(List args) {
+ verifyOpArgument(args);
+ // parse command line
+ for (int i = 2; i < args.size(); i++) {
+ if (args.get(i).equals("-threads")) {
+ if (i+1 == args.size()) {
+ printUsage();
+ }
+ setNumThreads(Integer.parseInt(args.get(++i)));
+ } else if (args.get(i).equals("-numops")) {
+ setNumOpsRequired(Integer.parseInt(args.get(++i)));
+ }
+ }
+ }
+
+ /**
+ * returns client name.
+ */
+ @Override
+ String getExecutionArgument(int daemonId) {
+ return getClientName(daemonId);
+ }
+
+ /**
+ * Execute key decryption.
+ */
+ @Override
+ long executeOp(int daemonId, int inputIdx, String clientName)
+ throws IOException {
+ long start = Time.now();
+ try {
+ eek = kp.generateEncryptedKey(encryptionKeyName);
+ kp.decryptEncryptedKey(eek);
+ } catch (GeneralSecurityException e) {
+ LOG.warn("failed to generate and/or decrypt key", e);
+ }
+ long end = Time.now();
+ return end - start;
+ }
+
+ @Override
+ void printResults() {
+ LOG.info("--- " + getOpName() + " inputs ---");
+ LOG.info("nrOps = " + getNumOpsRequired());
+ LOG.info("nrThreads = " + getNumThreads());
+ printStats();
+ }
+ }
+
+ static void printUsage() {
+ System.err.println("Usage: KMSBenchmark"
+ + "\n\t" + OperationStatsBase.OP_ALL_USAGE
+ + " | \n\t" + EncryptKeyStats.OP_ENCRYPT_USAGE
+ + " | \n\t" + DecryptKeyStats.OP_DECRYPT_USAGE
+ + " | \n\t" + GENERAL_OPTIONS_USAGE
+ );
+ System.err.println();
+ GenericOptionsParser.printGenericCommandUsage(System.err);
+ ExitUtil.terminate(-1);
+ }
+
+ public static KeyProviderCryptoExtension createKeyProviderCryptoExtension(
+ final Configuration conf) throws IOException {
+
+ KeyProvider keyProvider = KMSUtil.createKeyProvider(conf,
+ CommonConfigurationKeysPublic.HADOOP_SECURITY_KEY_PROVIDER_PATH);
+ if (keyProvider == null) {
+ throw new IOException("Key provider was not configured.");
+ }
+ return KeyProviderCryptoExtension.
+ createKeyProviderCryptoExtension(keyProvider);
+ }
+
+ public static void runBenchmark(Configuration conf, String[] args)
+ throws Exception {
+ KMSBenchmark bench = null;
+ try {
+ bench = new KMSBenchmark(conf, args);
+ ToolRunner.run(bench, args);
+ } finally {
+ LOG.info("runBenchmark finished.");
+ }
+ }
+
+ /**
+ * Main method of the benchmark.
+ * @param aArgs command line parameters
+ */
+ @Override // Tool
+ public int run(String[] aArgs) throws Exception {
+ List args = new ArrayList(Arrays.asList(aArgs));
+ if (args.size() < 2 || !args.get(0).startsWith("-op")) {
+ printUsage();
+ }
+
+ String type = args.get(1);
+ boolean runAll = OperationStatsBase.OP_ALL_NAME.equals(type);
+
+ List ops = new ArrayList();
+ OperationStatsBase opStat = null;
+ try {
+ if (runAll || EncryptKeyStats.OP_ENCRYPT_KEY.equals(type)) {
+ opStat = new EncryptKeyStats(args);
+ ops.add(opStat);
+ }
+ if (runAll || DecryptKeyStats.OP_DECRYPT_KEY.equals(type)) {
+ opStat = new DecryptKeyStats(args);
+ ops.add(opStat);
+ }
+ if (ops.isEmpty()) {
+ printUsage();
+ }
+
+ // run each benchmark
+ for (OperationStatsBase op : ops) {
+ LOG.info("Starting benchmark: " + op.getOpName());
+ op.benchmark();
+ op.cleanUp();
+ }
+ // print statistics
+ for (OperationStatsBase op : ops) {
+ LOG.info("");
+ op.printResults();
+ }
+ } catch(Exception e) {
+ LOG.error("failed to run benchmarks", e);
+ throw e;
+ }
+ return 0;
+ }
+
+ public static void main(String[] args) throws Exception {
+ runBenchmark(new Configuration(), args);
+ }
+
+ @Override // Configurable
+ public void setConf(Configuration conf) {
+ config = conf;
+ }
+
+ @Override // Configurable
+ public Configuration getConf() {
+ return config;
+ }
+}
\ No newline at end of file