diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/FastSaslClientFactory.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/FastSaslClientFactory.java
new file mode 100644
index 0000000000..d5259d338f
--- /dev/null
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/FastSaslClientFactory.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.security;
+
+import java.util.ArrayList;
+import java.util.Enumeration;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslClientFactory;
+import javax.security.sasl.SaslException;
+
+import org.apache.hadoop.classification.InterfaceAudience;
+
+/**
+ * Class for dealing with caching SASL client factories.
+ */
+@InterfaceAudience.LimitedPrivate({ "HDFS", "MapReduce" })
+public class FastSaslClientFactory implements SaslClientFactory {
+  private final Map<String, List<SaslClientFactory>> factoryCache =
+      new HashMap<String, List<SaslClientFactory>>();
+
+  public FastSaslClientFactory(Map<String, ?> props) {
+    final Enumeration<SaslClientFactory> factories =
+        Sasl.getSaslClientFactories();
+    while (factories.hasMoreElements()) {
+      SaslClientFactory factory = factories.nextElement();
+      for (String mech : factory.getMechanismNames(props)) {
+        if (!factoryCache.containsKey(mech)) {
+          factoryCache.put(mech, new ArrayList<SaslClientFactory>());
+        }
+        factoryCache.get(mech).add(factory);
+      }
+    }
+  }
+
+  @Override
+  public String[] getMechanismNames(Map<String, ?> props) {
+    return factoryCache.keySet().toArray(new String[0]);
+  }
+
+  @Override
+  public SaslClient createSaslClient(String[] mechanisms,
+      String authorizationId, String protocol, String serverName,
+      Map<String, ?> props, CallbackHandler cbh) throws SaslException {
+    for (String mechanism : mechanisms) {
+      List<SaslClientFactory> factories = factoryCache.get(mechanism);
+      if (factories != null) {
+        for (SaslClientFactory factory : factories) {
+          SaslClient saslClient =
+              factory.createSaslClient(new String[] {mechanism},
+                  authorizationId, protocol, serverName, props, cbh);
+          if (saslClient != null) {
+            return saslClient;
+          }
+        }
+      }
+    }
+    return null;
+  }
+}
\ No newline at end of file
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/FastSaslServerFactory.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/FastSaslServerFactory.java
new file mode 100644
index 0000000000..79519d408f
--- /dev/null
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/FastSaslServerFactory.java
@@ -0,0 +1,78 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.security;
+
+import java.util.ArrayList;
+import java.util.Enumeration;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+import javax.security.sasl.SaslServerFactory;
+
+import org.apache.hadoop.classification.InterfaceAudience;
+
+/**
+ * Class for dealing with caching SASL server factories.
+ */
+@InterfaceAudience.LimitedPrivate({"HDFS", "MapReduce"})
+public class FastSaslServerFactory implements SaslServerFactory {
+  private final Map<String, List<SaslServerFactory>> factoryCache =
+      new HashMap<String, List<SaslServerFactory>>();
+
+  public FastSaslServerFactory(Map<String, ?> props) {
+    final Enumeration<SaslServerFactory> factories =
+        Sasl.getSaslServerFactories();
+    while (factories.hasMoreElements()) {
+      SaslServerFactory factory = factories.nextElement();
+      for (String mech : factory.getMechanismNames(props)) {
+        if (!factoryCache.containsKey(mech)) {
+          factoryCache.put(mech, new ArrayList<SaslServerFactory>());
+        }
+        factoryCache.get(mech).add(factory);
+      }
+    }
+  }
+
+  @Override
+  public SaslServer createSaslServer(String mechanism, String protocol,
+      String serverName, Map<String, ?> props, CallbackHandler cbh)
+      throws SaslException {
+    SaslServer saslServer = null;
+    List<SaslServerFactory> factories = factoryCache.get(mechanism);
+    if (factories != null) {
+      for (SaslServerFactory factory : factories) {
+        saslServer = factory.createSaslServer(
+            mechanism, protocol, serverName, props, cbh);
+        if (saslServer != null) {
+          break;
+        }
+      }
+    }
+    return saslServer;
+  }
+
+  @Override
+  public String[] getMechanismNames(Map<String, ?> props) {
+    return factoryCache.keySet().toArray(new String[0]);
+  }
+}
\ No newline at end of file
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java
index d236ab0c0e..a63ad4fdbb 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java
@@ -44,6 +44,7 @@
 import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslException;
 import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslClientFactory;
 
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.classification.InterfaceStability;
@@ -93,6 +94,7 @@ public class SaslRpcClient {
   private SaslClient saslClient;
   private SaslPropertiesResolver saslPropsResolver;
   private AuthMethod authMethod;
+  private static SaslClientFactory saslFactory;
 
   private static final RpcRequestHeaderProto saslHeader = ProtoUtil
       .makeRpcRequestHeader(RpcKind.RPC_PROTOCOL_BUFFER,
@@ -101,6 +103,10 @@ public class SaslRpcClient {
   private static final RpcSaslProto negotiateRequest =
       RpcSaslProto.newBuilder().setState(SaslState.NEGOTIATE).build();
 
+  static {
+    saslFactory = new FastSaslClientFactory(null);
+  }
+
   /**
    * Create a SaslRpcClient that can be used by a RPC client to negotiate
    * SASL authentication with a RPC server
@@ -251,8 +257,8 @@ private SaslClient createSaslClient(SaslAuth authType)
       LOG.debug("Creating SASL " + mechanism + "(" + method + ") "
           + " client to authenticate to service at " + saslServerName);
     }
-    return Sasl.createSaslClient(
-        new String[] { mechanism }, saslUser, saslProtocol, saslServerName,
+    return saslFactory.createSaslClient(
+        new String[] {mechanism}, saslUser, saslProtocol, saslServerName,
         saslProperties, saslCallback);
   }
 
@@ -687,4 +693,4 @@ public void handle(Callback[] callbacks)
       }
     }
   }
-}
+}
\ No newline at end of file
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcServer.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcServer.java
index 643af79e4b..7c3f14da21 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcServer.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcServer.java
@@ -26,10 +26,6 @@
 import java.nio.charset.StandardCharsets;
 import java.security.PrivilegedExceptionAction;
 import java.security.Security;
-import java.util.ArrayList;
-import java.util.Enumeration;
-import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 
 import javax.security.auth.callback.Callback;
@@ -39,7 +35,6 @@
 import javax.security.auth.callback.UnsupportedCallbackException;
 import javax.security.sasl.AuthorizeCallback;
 import javax.security.sasl.RealmCallback;
-import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslException;
 import javax.security.sasl.SaslServer;
 import javax.security.sasl.SaslServerFactory;
@@ -178,10 +173,12 @@ public SaslServer run() throws SaslException {
   }
 
   public static void init(Configuration conf) {
-    Security.addProvider(new SaslPlainServer.SecurityProvider());
-    // passing null so factory is populated with all possibilities.  the
-    // properties passed when instantiating a server are what really matter
-    saslFactory = new FastSaslServerFactory(null);
+    if (saslFactory == null) {
+      Security.addProvider(new SaslPlainServer.SecurityProvider());
+      // passing null so factory is populated with all possibilities.  the
+      // properties passed when instantiating a server are what really matter
+      saslFactory = new FastSaslServerFactory(null);
+    }
   }
 
   static String encodeIdentifier(byte[] identifier) {
@@ -367,47 +364,4 @@ public void handle(Callback[] callbacks) throws
       }
     }
   }
-
-  // Sasl.createSaslServer is 100-200X slower than caching the factories!
-  private static class FastSaslServerFactory implements SaslServerFactory {
-    private final Map<String, List<SaslServerFactory>> factoryCache =
-        new HashMap<String, List<SaslServerFactory>>();
-
-    FastSaslServerFactory(Map<String, ?> props) {
-      final Enumeration<SaslServerFactory> factories =
-          Sasl.getSaslServerFactories();
-      while (factories.hasMoreElements()) {
-        SaslServerFactory factory = factories.nextElement();
-        for (String mech : factory.getMechanismNames(props)) {
-          if (!factoryCache.containsKey(mech)) {
-            factoryCache.put(mech, new ArrayList<SaslServerFactory>());
-          }
-          factoryCache.get(mech).add(factory);
-        }
-      }
-    }
-
-    @Override
-    public SaslServer createSaslServer(String mechanism, String protocol,
-        String serverName, Map<String, ?> props, CallbackHandler cbh)
-        throws SaslException {
-      SaslServer saslServer = null;
-      List<SaslServerFactory> factories = factoryCache.get(mechanism);
-      if (factories != null) {
-        for (SaslServerFactory factory : factories) {
-          saslServer = factory.createSaslServer(
-              mechanism, protocol, serverName, props, cbh);
-          if (saslServer != null) {
-            break;
-          }
-        }
-      }
-      return saslServer;
-    }
-
-    @Override
-    public String[] getMechanismNames(Map<String, ?> props) {
-      return factoryCache.keySet().toArray(new String[0]);
-    }
-  }
 }
diff --git a/hadoop-hdfs-project/hadoop-hdfs-client/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/SaslParticipant.java b/hadoop-hdfs-project/hadoop-hdfs-client/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/SaslParticipant.java
index 1db9f50a8c..f51f458fb2 100644
--- a/hadoop-hdfs-project/hadoop-hdfs-client/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/SaslParticipant.java
+++ b/hadoop-hdfs-project/hadoop-hdfs-client/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/SaslParticipant.java
@@ -23,11 +23,15 @@
 import javax.security.auth.callback.CallbackHandler;
 import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslClientFactory;
 import javax.security.sasl.SaslException;
 import javax.security.sasl.SaslServer;
+import javax.security.sasl.SaslServerFactory;
 
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.hdfs.protocol.datatransfer.IOStreamPair;
+import org.apache.hadoop.security.FastSaslClientFactory;
+import org.apache.hadoop.security.FastSaslServerFactory;
 import org.apache.hadoop.security.SaslInputStream;
 import org.apache.hadoop.security.SaslOutputStream;
 
@@ -51,7 +55,20 @@ class SaslParticipant {
   // One of these will always be null.
   private final SaslServer saslServer;
   private final SaslClient saslClient;
+  private static SaslServerFactory saslServerFactory;
+  private static SaslClientFactory saslClientFactory;
+  private static void initializeSaslServerFactory() {
+    if (saslServerFactory == null) {
+      saslServerFactory = new FastSaslServerFactory(null);
+    }
+  }
+
+  private static void initializeSaslClientFactory() {
+    if (saslClientFactory == null) {
+      saslClientFactory = new FastSaslClientFactory(null);
+    }
+  }
 
   /**
    * Creates a SaslParticipant wrapping a SaslServer.
   *
@@ -63,7 +80,8 @@ class SaslParticipant {
   public static SaslParticipant createServerSaslParticipant(
       Map<String, String> saslProps, CallbackHandler callbackHandler)
       throws SaslException {
-    return new SaslParticipant(Sasl.createSaslServer(MECHANISM,
+    initializeSaslServerFactory();
+    return new SaslParticipant(saslServerFactory.createSaslServer(MECHANISM,
         PROTOCOL, SERVER_NAME, saslProps, callbackHandler));
   }
 
@@ -79,8 +97,10 @@ public static SaslParticipant createServerSaslParticipant(
   public static SaslParticipant createClientSaslParticipant(String userName,
       Map<String, String> saslProps, CallbackHandler callbackHandler)
       throws SaslException {
-    return new SaslParticipant(Sasl.createSaslClient(new String[] { MECHANISM },
-        userName, PROTOCOL, SERVER_NAME, saslProps, callbackHandler));
+    initializeSaslClientFactory();
+    return new SaslParticipant(
+        saslClientFactory.createSaslClient(new String[] {MECHANISM}, userName,
+            PROTOCOL, SERVER_NAME, saslProps, callbackHandler));
   }
 
   /**
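
Note (not part of the patch): the sketch below is a minimal, hypothetical caller showing what the change buys. It assumes only the FastSaslClientFactory API added above; the class name FastSaslClientFactoryDemo, the "demo-user"/"demo-pass" credentials, and the choice of the PLAIN mechanism are illustrative. The point is that the walk over Sasl.getSaslClientFactories() happens once, in the constructor, so later createSaslClient() calls only consult the in-memory mechanism-to-factory cache instead of re-scanning the registered security providers the way a direct Sasl.createSaslClient() call does on every connection.

import java.util.Map;
import java.util.TreeMap;

import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;

import org.apache.hadoop.security.FastSaslClientFactory;

// Hypothetical demo class, for illustration only.
public class FastSaslClientFactoryDemo {
  public static void main(String[] args) throws Exception {
    // Answers the name/password callbacks a PLAIN client asks for.
    CallbackHandler cbh = new CallbackHandler() {
      @Override
      public void handle(Callback[] callbacks) {
        for (Callback cb : callbacks) {
          if (cb instanceof NameCallback) {
            ((NameCallback) cb).setName("demo-user");
          } else if (cb instanceof PasswordCallback) {
            ((PasswordCallback) cb).setPassword("demo-pass".toCharArray());
          }
        }
      }
    };

    Map<String, String> props = new TreeMap<String, String>();
    props.put(Sasl.QOP, "auth");

    // One-time scan of the registered SaslClientFactory providers; passing
    // null mirrors the patch so every available mechanism is cached.
    FastSaslClientFactory factory = new FastSaslClientFactory(null);

    // Repeated lookups hit the cache; no provider scan per call.
    for (int i = 0; i < 3; i++) {
      SaslClient client = factory.createSaslClient(new String[] {"PLAIN"},
          null, "demo-protocol", "localhost", props, cbh);
      if (client == null) {
        System.out.println("attempt " + i + ": no factory for mechanism");
      } else {
        System.out.println("attempt " + i + ": " + client.getMechanismName());
        client.dispose();
      }
    }
  }
}

The same reasoning applies on the server side: SaslRpcServer and SaslParticipant now build a FastSaslServerFactory/FastSaslClientFactory once and reuse it, rather than paying the provider scan for every negotiated connection.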