HADOOP-9421. [RPC v9] Convert SASL to use ProtoBuf and provide negotiation capabilities (daryn)

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@1495577 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Daryn Sharp 2013-06-21 20:09:31 +00:00
parent 6cb5ad16d0
commit 5f9b4c14a1
11 changed files with 628 additions and 253 deletions

View File

@ -319,6 +319,9 @@ Release 2.1.0-beta - UNRELEASED
HADOOP-9630. [RPC v9] Remove IpcSerializationType. (Junping Du via llu) HADOOP-9630. [RPC v9] Remove IpcSerializationType. (Junping Du via llu)
HADOOP-9421. [RPC v9] Convert SASL to use ProtoBuf and provide
negotiation capabilities (daryn)
NEW FEATURES NEW FEATURES
HADOOP-9283. Add support for running the Hadoop client on AIX. (atm) HADOOP-9283. Add support for running the Hadoop client on AIX. (atm)

View File

@ -320,6 +320,15 @@
<Field name="in" /> <Field name="in" />
<Bug pattern="IS2_INCONSISTENT_SYNC" /> <Bug pattern="IS2_INCONSISTENT_SYNC" />
</Match> </Match>
<!--
The switch condition for INITIATE is expected to fallthru to RESPONSE
to process initial sasl response token included in the INITIATE
-->
<Match>
<Class name="org.apache.hadoop.ipc.Server$Connection" />
<Method name="processSaslMessage" />
<Bug pattern="SF_SWITCH_FALLTHROUGH" />
</Match>
<!-- Synchronization performed on util.concurrent instance. --> <!-- Synchronization performed on util.concurrent instance. -->
<Match> <Match>

View File

@ -62,6 +62,7 @@
import org.apache.hadoop.io.retry.RetryPolicies; import org.apache.hadoop.io.retry.RetryPolicies;
import org.apache.hadoop.io.retry.RetryPolicy; import org.apache.hadoop.io.retry.RetryPolicy;
import org.apache.hadoop.io.retry.RetryPolicy.RetryAction; import org.apache.hadoop.io.retry.RetryPolicy.RetryAction;
import org.apache.hadoop.ipc.Server.AuthProtocol;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto.OperationProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto.OperationProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
@ -751,7 +752,7 @@ private void handleConnectionFailure(int curRetries, IOException ioe
* +----------------------------------+ * +----------------------------------+
* | Service Class (1 byte) | * | Service Class (1 byte) |
* +----------------------------------+ * +----------------------------------+
* | Authmethod (1 byte) | * | AuthProtocol (1 byte) |
* +----------------------------------+ * +----------------------------------+
*/ */
private void writeConnectionHeader(OutputStream outStream) private void writeConnectionHeader(OutputStream outStream)
@ -761,7 +762,15 @@ private void writeConnectionHeader(OutputStream outStream)
out.write(Server.HEADER.array()); out.write(Server.HEADER.array());
out.write(Server.CURRENT_VERSION); out.write(Server.CURRENT_VERSION);
out.write(serviceClass); out.write(serviceClass);
authMethod.write(out); final AuthProtocol authProtocol;
switch (authMethod) {
case SIMPLE:
authProtocol = AuthProtocol.NONE;
break;
default:
authProtocol = AuthProtocol.SASL;
}
out.write(authProtocol.callId);
out.flush(); out.flush();
} }

View File

@ -41,6 +41,8 @@
import org.apache.hadoop.ipc.Client.ConnectionId; import org.apache.hadoop.ipc.Client.ConnectionId;
import org.apache.hadoop.ipc.RPC.RpcInvoker; import org.apache.hadoop.ipc.RPC.RpcInvoker;
import org.apache.hadoop.ipc.protobuf.ProtobufRpcEngineProtos.RequestHeaderProto; import org.apache.hadoop.ipc.protobuf.ProtobufRpcEngineProtos.RequestHeaderProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.SecretManager; import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.security.token.TokenIdentifier;
@ -48,10 +50,10 @@
import org.apache.hadoop.util.Time; import org.apache.hadoop.util.Time;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.AbstractMessageLite;
import com.google.protobuf.BlockingService; import com.google.protobuf.BlockingService;
import com.google.protobuf.CodedOutputStream; import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.Descriptors.MethodDescriptor; import com.google.protobuf.Descriptors.MethodDescriptor;
import com.google.protobuf.GeneratedMessage;
import com.google.protobuf.Message; import com.google.protobuf.Message;
import com.google.protobuf.ServiceException; import com.google.protobuf.ServiceException;
import com.google.protobuf.TextFormat; import com.google.protobuf.TextFormat;
@ -279,16 +281,16 @@ interface RpcWrapper extends Writable {
* Protobuf. Several methods on {@link org.apache.hadoop.ipc.Server and RPC} * Protobuf. Several methods on {@link org.apache.hadoop.ipc.Server and RPC}
* use type Writable as a wrapper to work across multiple RpcEngine kinds. * use type Writable as a wrapper to work across multiple RpcEngine kinds.
*/ */
private static class RpcRequestWrapper implements RpcWrapper { private static abstract class RpcMessageWithHeader<T extends GeneratedMessage>
RequestHeaderProto requestHeader; implements RpcWrapper {
T requestHeader;
Message theRequest; // for clientSide, the request is here Message theRequest; // for clientSide, the request is here
byte[] theRequestRead; // for server side, the request is here byte[] theRequestRead; // for server side, the request is here
@SuppressWarnings("unused") public RpcMessageWithHeader() {
public RpcRequestWrapper() {
} }
RpcRequestWrapper(RequestHeaderProto requestHeader, Message theRequest) { public RpcMessageWithHeader(T requestHeader, Message theRequest) {
this.requestHeader = requestHeader; this.requestHeader = requestHeader;
this.theRequest = theRequest; this.theRequest = theRequest;
} }
@ -303,21 +305,31 @@ public void write(DataOutput out) throws IOException {
@Override @Override
public void readFields(DataInput in) throws IOException { public void readFields(DataInput in) throws IOException {
int length = ProtoUtil.readRawVarint32(in); requestHeader = parseHeaderFrom(readVarintBytes(in));
byte[] bytes = new byte[length]; theRequestRead = readMessageRequest(in);
in.readFully(bytes);
requestHeader = RequestHeaderProto.parseFrom(bytes);
length = ProtoUtil.readRawVarint32(in);
theRequestRead = new byte[length];
in.readFully(theRequestRead);
}
@Override
public String toString() {
return requestHeader.getDeclaringClassProtocolName() + "." +
requestHeader.getMethodName();
} }
abstract T parseHeaderFrom(byte[] bytes) throws IOException;
byte[] readMessageRequest(DataInput in) throws IOException {
return readVarintBytes(in);
}
private static byte[] readVarintBytes(DataInput in) throws IOException {
final int length = ProtoUtil.readRawVarint32(in);
final byte[] bytes = new byte[length];
in.readFully(bytes);
return bytes;
}
public T getMessageHeader() {
return requestHeader;
}
public byte[] getMessageBytes() {
return theRequestRead;
}
@Override @Override
public int getLength() { public int getLength() {
int headerLen = requestHeader.getSerializedSize(); int headerLen = requestHeader.getSerializedSize();
@ -328,12 +340,78 @@ public int getLength() {
reqLen = theRequestRead.length; reqLen = theRequestRead.length;
} else { } else {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"getLenght on uninilialized RpcWrapper"); "getLength on uninitialized RpcWrapper");
} }
return CodedOutputStream.computeRawVarint32Size(headerLen) + headerLen return CodedOutputStream.computeRawVarint32Size(headerLen) + headerLen
+ CodedOutputStream.computeRawVarint32Size(reqLen) + reqLen; + CodedOutputStream.computeRawVarint32Size(reqLen) + reqLen;
} }
} }
private static class RpcRequestWrapper
extends RpcMessageWithHeader<RequestHeaderProto> {
@SuppressWarnings("unused")
public RpcRequestWrapper() {}
public RpcRequestWrapper(
RequestHeaderProto requestHeader, Message theRequest) {
super(requestHeader, theRequest);
}
@Override
RequestHeaderProto parseHeaderFrom(byte[] bytes) throws IOException {
return RequestHeaderProto.parseFrom(bytes);
}
@Override
public String toString() {
return requestHeader.getDeclaringClassProtocolName() + "." +
requestHeader.getMethodName();
}
}
@InterfaceAudience.LimitedPrivate({"RPC"})
public static class RpcRequestMessageWrapper
extends RpcMessageWithHeader<RpcRequestHeaderProto> {
public RpcRequestMessageWrapper() {}
public RpcRequestMessageWrapper(
RpcRequestHeaderProto requestHeader, Message theRequest) {
super(requestHeader, theRequest);
}
@Override
RpcRequestHeaderProto parseHeaderFrom(byte[] bytes) throws IOException {
return RpcRequestHeaderProto.parseFrom(bytes);
}
}
@InterfaceAudience.LimitedPrivate({"RPC"})
public static class RpcResponseMessageWrapper
extends RpcMessageWithHeader<RpcResponseHeaderProto> {
public RpcResponseMessageWrapper() {}
public RpcResponseMessageWrapper(
RpcResponseHeaderProto responseHeader, Message theRequest) {
super(responseHeader, theRequest);
}
@Override
byte[] readMessageRequest(DataInput in) throws IOException {
// error message contain no message body
switch (requestHeader.getStatus()) {
case ERROR:
case FATAL:
return null;
default:
return super.readMessageRequest(in);
}
}
@Override
RpcResponseHeaderProto parseHeaderFrom(byte[] bytes) throws IOException {
return RpcResponseHeaderProto.parseFrom(bytes);
}
}
/** /**
* Wrapper for Protocol Buffer Responses * Wrapper for Protocol Buffer Responses
@ -342,11 +420,11 @@ public int getLength() {
* Protobuf. Several methods on {@link org.apache.hadoop.ipc.Server and RPC} * Protobuf. Several methods on {@link org.apache.hadoop.ipc.Server and RPC}
* use type Writable as a wrapper to work across multiple RpcEngine kinds. * use type Writable as a wrapper to work across multiple RpcEngine kinds.
*/ */
private static class RpcResponseWrapper implements RpcWrapper { @InterfaceAudience.LimitedPrivate({"RPC"}) // temporarily exposed
public static class RpcResponseWrapper implements RpcWrapper {
Message theResponse; // for senderSide, the response is here Message theResponse; // for senderSide, the response is here
byte[] theResponseRead; // for receiver side, the response is here byte[] theResponseRead; // for receiver side, the response is here
@SuppressWarnings("unused")
public RpcResponseWrapper() { public RpcResponseWrapper() {
} }
@ -376,7 +454,7 @@ public int getLength() {
resLen = theResponseRead.length; resLen = theResponseRead.length;
} else { } else {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"getLenght on uninilialized RpcWrapper"); "getLength on uninitialized RpcWrapper");
} }
return CodedOutputStream.computeRawVarint32Size(resLen) + resLen; return CodedOutputStream.computeRawVarint32Size(resLen) + resLen;
} }

View File

@ -21,7 +21,6 @@
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.DataInputStream; import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream; import java.io.DataOutputStream;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.UndeclaredThrowableException; import java.lang.reflect.UndeclaredThrowableException;
@ -46,7 +45,6 @@
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
@ -59,7 +57,6 @@
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.Sasl; import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException; import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer; import javax.security.sasl.SaslServer;
@ -72,11 +69,11 @@
import org.apache.hadoop.conf.Configuration.IntegerRanges; import org.apache.hadoop.conf.Configuration.IntegerRanges;
import org.apache.hadoop.fs.CommonConfigurationKeys; import org.apache.hadoop.fs.CommonConfigurationKeys;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic; import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils; import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseWrapper;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcRequestMessageWrapper;
import org.apache.hadoop.ipc.RPC.RpcInvoker; import org.apache.hadoop.ipc.RPC.RpcInvoker;
import org.apache.hadoop.ipc.RPC.VersionMismatch; import org.apache.hadoop.ipc.RPC.VersionMismatch;
import org.apache.hadoop.ipc.metrics.RpcDetailedMetrics; import org.apache.hadoop.ipc.metrics.RpcDetailedMetrics;
@ -84,18 +81,15 @@
import org.apache.hadoop.ipc.protobuf.IpcConnectionContextProtos.IpcConnectionContextProto; import org.apache.hadoop.ipc.protobuf.IpcConnectionContextProtos.IpcConnectionContextProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcStatusProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcStatusProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcErrorCodeProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcErrorCodeProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcSaslProto.*;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.*; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.*;
import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.AccessControlException; import org.apache.hadoop.security.AccessControlException;
import org.apache.hadoop.security.SaslRpcServer; import org.apache.hadoop.security.SaslRpcServer;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod; import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.SaslRpcServer.SaslDigestCallbackHandler;
import org.apache.hadoop.security.SaslRpcServer.SaslGssCallbackHandler;
import org.apache.hadoop.security.SaslRpcServer.SaslStatus;
import org.apache.hadoop.security.SecurityUtil; import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod; import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
import org.apache.hadoop.security.authentication.util.KerberosName;
import org.apache.hadoop.security.authorize.AuthorizationException; import org.apache.hadoop.security.authorize.AuthorizationException;
import org.apache.hadoop.security.authorize.PolicyProvider; import org.apache.hadoop.security.authorize.PolicyProvider;
import org.apache.hadoop.security.authorize.ProxyUsers; import org.apache.hadoop.security.authorize.ProxyUsers;
@ -109,7 +103,9 @@
import org.apache.hadoop.util.Time; import org.apache.hadoop.util.Time;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import com.google.protobuf.CodedOutputStream; import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.Message;
/** An abstract IPC service. IPC calls take a single {@link Writable} as a /** An abstract IPC service. IPC calls take a single {@link Writable} as a
* parameter, and return a {@link Writable} as their value. A service runs on * parameter, and return a {@link Writable} as their value. A service runs on
@ -121,7 +117,8 @@
@InterfaceStability.Evolving @InterfaceStability.Evolving
public abstract class Server { public abstract class Server {
private final boolean authorize; private final boolean authorize;
private EnumSet<AuthMethod> enabledAuthMethods; private List<AuthMethod> enabledAuthMethods;
private RpcSaslProto negotiateResponse;
private ExceptionsHandler exceptionsHandler = new ExceptionsHandler(); private ExceptionsHandler exceptionsHandler = new ExceptionsHandler();
public void addTerseExceptions(Class<?>... exceptionClass) { public void addTerseExceptions(Class<?>... exceptionClass) {
@ -1065,6 +1062,26 @@ private synchronized void waitPending() throws InterruptedException {
} }
} }
@InterfaceAudience.Private
public static enum AuthProtocol {
NONE(0),
SASL(-33);
public final int callId;
AuthProtocol(int callId) {
this.callId = callId;
}
static AuthProtocol valueOf(int callId) {
for (AuthProtocol authType : AuthProtocol.values()) {
if (authType.callId == callId) {
return authType;
}
}
return null;
}
};
/** Reads calls from a connection and queues them for handling. */ /** Reads calls from a connection and queues them for handling. */
public class Connection { public class Connection {
private boolean connectionHeaderRead = false; // connection header is read? private boolean connectionHeaderRead = false; // connection header is read?
@ -1089,6 +1106,7 @@ public class Connection {
String protocolName; String protocolName;
SaslServer saslServer; SaslServer saslServer;
private AuthMethod authMethod; private AuthMethod authMethod;
private AuthProtocol authProtocol;
private boolean saslContextEstablished; private boolean saslContextEstablished;
private boolean skipInitialSaslHandshake; private boolean skipInitialSaslHandshake;
private ByteBuffer connectionHeaderBuf = null; private ByteBuffer connectionHeaderBuf = null;
@ -1104,12 +1122,11 @@ public class Connection {
private final Call authFailedCall = private final Call authFailedCall =
new Call(AUTHORIZATION_FAILED_CALLID, null, this); new Call(AUTHORIZATION_FAILED_CALLID, null, this);
private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream(); private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream();
// Fake 'call' for SASL context setup
private static final int SASL_CALLID = -33;
private final Call saslCall = new Call(SASL_CALLID, null, this); private final Call saslCall = new Call(AuthProtocol.SASL.callId, null, this);
private final ByteArrayOutputStream saslResponse = new ByteArrayOutputStream(); private final ByteArrayOutputStream saslResponse = new ByteArrayOutputStream();
private boolean sentNegotiate = false;
private boolean useWrap = false; private boolean useWrap = false;
public Connection(SelectionKey key, SocketChannel channel, public Connection(SelectionKey key, SocketChannel channel,
@ -1183,7 +1200,7 @@ private boolean timedOut(long currentTime) {
private UserGroupInformation getAuthorizedUgi(String authorizedId) private UserGroupInformation getAuthorizedUgi(String authorizedId)
throws IOException { throws IOException {
if (authMethod == SaslRpcServer.AuthMethod.DIGEST) { if (authMethod == AuthMethod.TOKEN) {
TokenIdentifier tokenId = SaslRpcServer.getIdentifier(authorizedId, TokenIdentifier tokenId = SaslRpcServer.getIdentifier(authorizedId,
secretManager); secretManager);
UserGroupInformation ugi = tokenId.getUser(); UserGroupInformation ugi = tokenId.getUser();
@ -1201,12 +1218,9 @@ private UserGroupInformation getAuthorizedUgi(String authorizedId)
private void saslReadAndProcess(byte[] saslToken) throws IOException, private void saslReadAndProcess(byte[] saslToken) throws IOException,
InterruptedException { InterruptedException {
if (!saslContextEstablished) { if (!saslContextEstablished) {
byte[] replyToken = null; RpcSaslProto saslResponse;
try { try {
if (LOG.isDebugEnabled()) saslResponse = processSaslMessage(saslToken);
LOG.debug("Have read input token of size " + saslToken.length
+ " for processing by saslServer.evaluateResponse()");
replyToken = saslServer.evaluateResponse(saslToken);
} catch (IOException e) { } catch (IOException e) {
IOException sendToClient = e; IOException sendToClient = e;
Throwable cause = e; Throwable cause = e;
@ -1217,27 +1231,17 @@ private void saslReadAndProcess(byte[] saslToken) throws IOException,
} }
cause = cause.getCause(); cause = cause.getCause();
} }
doSaslReply(SaslStatus.ERROR, null, sendToClient.getClass().getName(),
sendToClient.getLocalizedMessage());
rpcMetrics.incrAuthenticationFailures(); rpcMetrics.incrAuthenticationFailures();
String clientIP = this.toString(); String clientIP = this.toString();
// attempting user could be null // attempting user could be null
AUDITLOG.warn(AUTH_FAILED_FOR + clientIP + ":" + attemptingUser + AUDITLOG.warn(AUTH_FAILED_FOR + clientIP + ":" + attemptingUser +
" (" + e.getLocalizedMessage() + ")"); " (" + e.getLocalizedMessage() + ")");
// wait to send response until failure is logged
doSaslReply(sendToClient);
throw e; throw e;
} }
if (saslServer.isComplete() && replyToken == null) {
// send final response for success if (saslServer != null && saslServer.isComplete()) {
replyToken = new byte[0];
}
if (replyToken != null) {
if (LOG.isDebugEnabled())
LOG.debug("Will send token of size " + replyToken.length
+ " from saslServer.");
doSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null,
null);
}
if (saslServer.isComplete()) {
if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
LOG.debug("SASL server context established. Negotiated QoP is " LOG.debug("SASL server context established. Negotiated QoP is "
+ saslServer.getNegotiatedProperty(Sasl.QOP)); + saslServer.getNegotiatedProperty(Sasl.QOP));
@ -1252,6 +1256,9 @@ private void saslReadAndProcess(byte[] saslToken) throws IOException,
AUDITLOG.info(AUTH_SUCCESSFUL_FOR + user); AUDITLOG.info(AUTH_SUCCESSFUL_FOR + user);
saslContextEstablished = true; saslContextEstablished = true;
} }
// send reply here to avoid a successful auth being logged as a
// failure if response can't be sent
doSaslReply(saslResponse);
} else { } else {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Have read input token of size " + saslToken.length LOG.debug("Have read input token of size " + saslToken.length
@ -1267,21 +1274,101 @@ private void saslReadAndProcess(byte[] saslToken) throws IOException,
} }
} }
private void doSaslReply(SaslStatus status, Writable rv, private RpcSaslProto processSaslMessage(byte[] buf)
String errorClass, String error) throws IOException { throws IOException, InterruptedException {
saslResponse.reset(); final DataInputStream dis =
DataOutputStream out = new DataOutputStream(saslResponse); new DataInputStream(new ByteArrayInputStream(buf));
out.writeInt(status.state); // write status RpcRequestMessageWrapper requestWrapper = new RpcRequestMessageWrapper();
if (status == SaslStatus.SUCCESS) { requestWrapper.readFields(dis);
rv.write(out);
} else { final RpcRequestHeaderProto rpcHeader = requestWrapper.requestHeader;
WritableUtils.writeString(out, errorClass); if (rpcHeader.getCallId() != AuthProtocol.SASL.callId) {
WritableUtils.writeString(out, error); throw new SaslException("Client sent non-SASL request");
}
final RpcSaslProto saslMessage =
RpcSaslProto.parseFrom(requestWrapper.theRequestRead);
RpcSaslProto saslResponse = null;
final SaslState state = saslMessage.getState(); // required
switch (state) {
case NEGOTIATE: {
if (sentNegotiate) {
throw new AccessControlException(
"Client already attempted negotiation");
}
saslResponse = buildSaslNegotiateResponse();
break;
}
case INITIATE: {
if (saslMessage.getAuthsCount() != 1) {
throw new SaslException("Client mechanism is malformed");
}
String authMethodName = saslMessage.getAuths(0).getMethod();
authMethod = createSaslServer(authMethodName);
if (authMethod == null) { // the auth method is not supported
if (sentNegotiate) {
throw new AccessControlException(
authMethodName + " authentication is not enabled."
+ " Available:" + enabledAuthMethods);
}
saslResponse = buildSaslNegotiateResponse();
break;
}
// fallthru to process sasl token
}
case RESPONSE: {
if (!saslMessage.hasToken()) {
throw new SaslException("Client did not send a token");
}
byte[] saslToken = saslMessage.getToken().toByteArray();
if (LOG.isDebugEnabled()) {
LOG.debug("Have read input token of size " + saslToken.length
+ " for processing by saslServer.evaluateResponse()");
}
saslToken = saslServer.evaluateResponse(saslToken);
saslResponse = buildSaslResponse(
saslServer.isComplete() ? SaslState.SUCCESS : SaslState.CHALLENGE,
saslToken);
break;
}
default:
throw new SaslException("Client sent unsupported state " + state);
} }
saslCall.setResponse(ByteBuffer.wrap(saslResponse.toByteArray())); return saslResponse;
}
private RpcSaslProto buildSaslResponse(SaslState state, byte[] replyToken)
throws IOException {
if (LOG.isDebugEnabled()) {
LOG.debug("Will send " + state + " token of size "
+ ((replyToken != null) ? replyToken.length : null)
+ " from saslServer.");
}
RpcSaslProto.Builder response = RpcSaslProto.newBuilder();
response.setState(state);
if (replyToken != null) {
response.setToken(ByteString.copyFrom(replyToken));
}
return response.build();
}
private void doSaslReply(Message message)
throws IOException {
if (LOG.isDebugEnabled()) {
LOG.debug("Sending sasl message "+message);
}
setupResponse(saslResponse, saslCall,
RpcStatusProto.SUCCESS, null,
new RpcResponseWrapper(message), null, null);
responder.doRespond(saslCall); responder.doRespond(saslCall);
} }
private void doSaslReply(Exception ioe) throws IOException {
setupResponse(authFailedResponse, authFailedCall,
RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_UNAUTHORIZED,
null, ioe.getClass().getName(), ioe.getLocalizedMessage());
responder.doRespond(authFailedCall);
}
private void disposeSasl() { private void disposeSasl() {
if (saslServer != null) { if (saslServer != null) {
try { try {
@ -1315,10 +1402,6 @@ public int readAndProcess() throws IOException, InterruptedException {
int version = connectionHeaderBuf.get(0); int version = connectionHeaderBuf.get(0);
// TODO we should add handler for service class later // TODO we should add handler for service class later
this.setServiceClass(connectionHeaderBuf.get(1)); this.setServiceClass(connectionHeaderBuf.get(1));
byte[] method = new byte[] {connectionHeaderBuf.get(2)};
authMethod = AuthMethod.read(new DataInputStream(
new ByteArrayInputStream(method)));
dataLengthBuffer.flip(); dataLengthBuffer.flip();
// Check if it looks like the user is hitting an IPC port // Check if it looks like the user is hitting an IPC port
@ -1339,14 +1422,10 @@ public int readAndProcess() throws IOException, InterruptedException {
return -1; return -1;
} }
dataLengthBuffer.clear(); // this may switch us into SIMPLE
if (authMethod == null) { authProtocol = initializeAuthContext(connectionHeaderBuf.get(2));
throw new IOException("Unable to read authentication method");
}
// this may create a SASL server, or switch us into SIMPLE
authMethod = initializeAuthContext(authMethod);
dataLengthBuffer.clear();
connectionHeaderBuf = null; connectionHeaderBuf = null;
connectionHeaderRead = true; connectionHeaderRead = true;
continue; continue;
@ -1373,14 +1452,14 @@ public int readAndProcess() throws IOException, InterruptedException {
if (data.remaining() == 0) { if (data.remaining() == 0) {
dataLengthBuffer.clear(); dataLengthBuffer.clear();
data.flip(); data.flip();
if (skipInitialSaslHandshake) {
data = null;
skipInitialSaslHandshake = false;
continue;
}
boolean isHeaderRead = connectionContextRead; boolean isHeaderRead = connectionContextRead;
if (saslServer != null) { if (authProtocol == AuthProtocol.SASL) {
saslReadAndProcess(data.array()); // switch to simple must ignore next negotiate or initiate
if (skipInitialSaslHandshake) {
authProtocol = AuthProtocol.NONE;
} else {
saslReadAndProcess(data.array());
}
} else { } else {
processOneRpc(data.array()); processOneRpc(data.array());
} }
@ -1393,102 +1472,79 @@ public int readAndProcess() throws IOException, InterruptedException {
} }
} }
private AuthMethod initializeAuthContext(AuthMethod authMethod) private AuthProtocol initializeAuthContext(int authType)
throws IOException, InterruptedException { throws IOException, InterruptedException {
AuthProtocol authProtocol = AuthProtocol.valueOf(authType);
if (authProtocol == null) {
IOException ioe = new IpcException("Unknown auth protocol:" + authType);
doSaslReply(ioe);
throw ioe;
}
boolean isSimpleEnabled = enabledAuthMethods.contains(AuthMethod.SIMPLE);
switch (authProtocol) {
case NONE: {
// don't reply if client is simple and server is insecure
if (!isSimpleEnabled) {
IOException ioe = new AccessControlException(
"SIMPLE authentication is not enabled."
+ " Available:" + enabledAuthMethods);
doSaslReply(ioe);
throw ioe;
}
break;
}
case SASL: {
if (isSimpleEnabled) { // switch to simple hack
skipInitialSaslHandshake = true;
doSaslReply(buildSaslResponse(SaslState.SUCCESS, null));
}
// else wait for a negotiate or initiate
break;
}
}
return authProtocol;
}
private RpcSaslProto buildSaslNegotiateResponse()
throws IOException, InterruptedException {
RpcSaslProto negotiateMessage = negotiateResponse;
// accelerate token negotiation by sending initial challenge
// in the negotiation response
if (enabledAuthMethods.contains(AuthMethod.TOKEN)) {
saslServer = createSaslServer(AuthMethod.TOKEN);
byte[] challenge = saslServer.evaluateResponse(new byte[0]);
RpcSaslProto.Builder negotiateBuilder =
RpcSaslProto.newBuilder(negotiateResponse);
negotiateBuilder.getAuthsBuilder(0) // TOKEN is always first
.setChallenge(ByteString.copyFrom(challenge));
negotiateMessage = negotiateBuilder.build();
}
sentNegotiate = true;
return negotiateMessage;
}
private AuthMethod createSaslServer(String authMethodName)
throws IOException, InterruptedException {
AuthMethod authMethod;
try { try {
if (enabledAuthMethods.contains(authMethod)) { authMethod = AuthMethod.valueOf(authMethodName);
saslServer = createSaslServer(authMethod); if (!enabledAuthMethods.contains(authMethod)) {
} else if (enabledAuthMethods.contains(AuthMethod.SIMPLE)) { authMethod = null;
doSaslReply(SaslStatus.SUCCESS, new IntWritable(
SaslRpcServer.SWITCH_TO_SIMPLE_AUTH), null, null);
authMethod = AuthMethod.SIMPLE;
// client has already sent the initial Sasl message and we
// should ignore it. Both client and server should fall back
// to simple auth from now on.
skipInitialSaslHandshake = true;
} else {
throw new AccessControlException(
authMethod + " authentication is not enabled."
+ " Available:" + enabledAuthMethods);
} }
} catch (IOException ioe) { } catch (IllegalArgumentException iae) {
final String ioeClass = ioe.getClass().getName(); authMethod = null;
final String ioeMessage = ioe.getLocalizedMessage(); }
if (authMethod == AuthMethod.SIMPLE) { if (authMethod != null &&
setupResponse(authFailedResponse, authFailedCall, // sasl server for tokens may already be instantiated
RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_UNAUTHORIZED, (saslServer == null || authMethod != AuthMethod.TOKEN)) {
null, ioeClass, ioeMessage); saslServer = createSaslServer(authMethod);
responder.doRespond(authFailedCall);
} else {
doSaslReply(SaslStatus.ERROR, null, ioeClass, ioeMessage);
}
throw ioe;
} }
return authMethod; return authMethod;
} }
private SaslServer createSaslServer(AuthMethod authMethod) private SaslServer createSaslServer(AuthMethod authMethod)
throws IOException, InterruptedException { throws IOException, InterruptedException {
String hostname = null; return new SaslRpcServer(authMethod).create(this, secretManager);
String saslProtocol = null;
CallbackHandler saslCallback = null;
switch (authMethod) {
case SIMPLE: {
return null; // no sasl for simple
}
case DIGEST: {
secretManager.checkAvailableForRead();
hostname = SaslRpcServer.SASL_DEFAULT_REALM;
saslCallback = new SaslDigestCallbackHandler(secretManager, this);
break;
}
case KERBEROS: {
String fullName = UserGroupInformation.getCurrentUser().getUserName();
if (LOG.isDebugEnabled())
LOG.debug("Kerberos principal name is " + fullName);
KerberosName krbName = new KerberosName(fullName);
hostname = krbName.getHostName();
if (hostname == null) {
throw new AccessControlException(
"Kerberos principal name does NOT have the expected "
+ "hostname part: " + fullName);
}
saslProtocol = krbName.getServiceName();
saslCallback = new SaslGssCallbackHandler();
break;
}
default:
// we should never be able to get here
throw new AccessControlException(
"Server does not support SASL " + authMethod);
}
return createSaslServer(authMethod.getMechanismName(), saslProtocol,
hostname, saslCallback);
}
private SaslServer createSaslServer(final String mechanism,
final String protocol,
final String hostname,
final CallbackHandler callback
) throws IOException, InterruptedException {
SaslServer saslServer = UserGroupInformation.getCurrentUser().doAs(
new PrivilegedExceptionAction<SaslServer>() {
@Override
public SaslServer run() throws SaslException {
return Sasl.createSaslServer(mechanism, protocol, hostname,
SaslRpcServer.SASL_PROPS, callback);
}
});
if (saslServer == null) {
throw new AccessControlException(
"Unable to find SASL server implementation for " + mechanism);
}
if (LOG.isDebugEnabled()) {
LOG.debug("Created SASL server with mechanism = " + mechanism);
}
return saslServer;
} }
/** /**
@ -1557,7 +1613,7 @@ private void processConnectionContext(byte[] buf) throws IOException {
//this is not allowed if user authenticated with DIGEST. //this is not allowed if user authenticated with DIGEST.
if ((protocolUser != null) if ((protocolUser != null)
&& (!protocolUser.getUserName().equals(user.getUserName()))) { && (!protocolUser.getUserName().equals(user.getUserName()))) {
if (authMethod == AuthMethod.DIGEST) { if (authMethod == AuthMethod.TOKEN) {
// Not allowed to doAs if token authentication is used // Not allowed to doAs if token authentication is used
throw new AccessControlException("Authenticated user (" + user throw new AccessControlException("Authenticated user (" + user
+ ") doesn't match what the client claims to be (" + ") doesn't match what the client claims to be ("
@ -1713,7 +1769,7 @@ private boolean authorizeConnection() throws IOException {
// authorize real user. doAs is allowed only for simple or kerberos // authorize real user. doAs is allowed only for simple or kerberos
// authentication // authentication
if (user != null && user.getRealUser() != null if (user != null && user.getRealUser() != null
&& (authMethod != AuthMethod.DIGEST)) { && (authMethod != AuthMethod.TOKEN)) {
ProxyUsers.authorize(user, this.getHostAddress(), conf); ProxyUsers.authorize(user, this.getHostAddress(), conf);
} }
authorize(user, protocolName, getHostInetAddress()); authorize(user, protocolName, getHostInetAddress());
@ -1954,6 +2010,7 @@ protected Server(String bindAddress, int port,
// configure supported authentications // configure supported authentications
this.enabledAuthMethods = getAuthMethods(secretManager, conf); this.enabledAuthMethods = getAuthMethods(secretManager, conf);
this.negotiateResponse = buildNegotiateResponse(enabledAuthMethods);
// Start the listener here and let it bind to the port // Start the listener here and let it bind to the port
listener = new Listener(); listener = new Listener();
@ -1973,17 +2030,33 @@ protected Server(String bindAddress, int port,
this.exceptionsHandler.addTerseExceptions(StandbyException.class); this.exceptionsHandler.addTerseExceptions(StandbyException.class);
} }
private RpcSaslProto buildNegotiateResponse(List<AuthMethod> authMethods)
throws IOException {
RpcSaslProto.Builder negotiateBuilder = RpcSaslProto.newBuilder();
negotiateBuilder.setState(SaslState.NEGOTIATE);
for (AuthMethod authMethod : authMethods) {
if (authMethod == AuthMethod.SIMPLE) { // not a SASL method
continue;
}
SaslRpcServer saslRpcServer = new SaslRpcServer(authMethod);
negotiateBuilder.addAuthsBuilder()
.setMethod(authMethod.toString())
.setMechanism(saslRpcServer.mechanism)
.setProtocol(saslRpcServer.protocol)
.setServerId(saslRpcServer.serverId);
}
return negotiateBuilder.build();
}
// get the security type from the conf. implicitly include token support // get the security type from the conf. implicitly include token support
// if a secret manager is provided, or fail if token is the conf value but // if a secret manager is provided, or fail if token is the conf value but
// there is no secret manager // there is no secret manager
private EnumSet<AuthMethod> getAuthMethods(SecretManager<?> secretManager, private List<AuthMethod> getAuthMethods(SecretManager<?> secretManager,
Configuration conf) { Configuration conf) {
AuthenticationMethod confAuthenticationMethod = AuthenticationMethod confAuthenticationMethod =
SecurityUtil.getAuthenticationMethod(conf); SecurityUtil.getAuthenticationMethod(conf);
EnumSet<AuthMethod> authMethods = List<AuthMethod> authMethods = new ArrayList<AuthMethod>();
EnumSet.of(confAuthenticationMethod.getAuthMethod());
if (confAuthenticationMethod == AuthenticationMethod.TOKEN) { if (confAuthenticationMethod == AuthenticationMethod.TOKEN) {
if (secretManager == null) { if (secretManager == null) {
throw new IllegalArgumentException(AuthenticationMethod.TOKEN + throw new IllegalArgumentException(AuthenticationMethod.TOKEN +
@ -1992,8 +2065,10 @@ private EnumSet<AuthMethod> getAuthMethods(SecretManager<?> secretManager,
} else if (secretManager != null) { } else if (secretManager != null) {
LOG.debug(AuthenticationMethod.TOKEN + LOG.debug(AuthenticationMethod.TOKEN +
" authentication enabled for secret manager"); " authentication enabled for secret manager");
// most preferred, go to the front of the line!
authMethods.add(AuthenticationMethod.TOKEN.getAuthMethod()); authMethods.add(AuthenticationMethod.TOKEN.getAuthMethod());
} }
authMethods.add(confAuthenticationMethod.getAuthMethod());
LOG.debug("Server accepts auth methods:" + authMethods); LOG.debug("Server accepts auth methods:" + authMethods);
return authMethods; return authMethods;

View File

@ -42,14 +42,24 @@
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability; import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.io.WritableUtils; import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcRequestMessageWrapper;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseMessageWrapper;
import org.apache.hadoop.ipc.RPC.RpcKind;
import org.apache.hadoop.ipc.RemoteException; import org.apache.hadoop.ipc.RemoteException;
import org.apache.hadoop.ipc.Server.AuthProtocol;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto.OperationProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcSaslProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcSaslProto.SaslAuth;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcSaslProto.SaslState;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod; import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.SaslRpcServer.SaslStatus;
import org.apache.hadoop.security.authentication.util.KerberosName; import org.apache.hadoop.security.authentication.util.KerberosName;
import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.util.ProtoUtil;
import com.google.protobuf.ByteString;
/** /**
* A utility class that encapsulates SASL logic for RPC client * A utility class that encapsulates SASL logic for RPC client
*/ */
@ -58,9 +68,15 @@
public class SaslRpcClient { public class SaslRpcClient {
public static final Log LOG = LogFactory.getLog(SaslRpcClient.class); public static final Log LOG = LogFactory.getLog(SaslRpcClient.class);
private final AuthMethod authMethod;
private final SaslClient saslClient; private final SaslClient saslClient;
private final boolean fallbackAllowed; private final boolean fallbackAllowed;
private static final RpcRequestHeaderProto saslHeader =
ProtoUtil.makeRpcRequestHeader(RpcKind.RPC_PROTOCOL_BUFFER,
OperationProto.RPC_FINAL_PACKET, AuthProtocol.SASL.callId);
private static final RpcSaslProto negotiateRequest =
RpcSaslProto.newBuilder().setState(SaslState.NEGOTIATE).build();
/** /**
* Create a SaslRpcClient for an authentication method * Create a SaslRpcClient for an authentication method
* *
@ -73,6 +89,7 @@ public SaslRpcClient(AuthMethod method,
Token<? extends TokenIdentifier> token, String serverPrincipal, Token<? extends TokenIdentifier> token, String serverPrincipal,
boolean fallbackAllowed) boolean fallbackAllowed)
throws IOException { throws IOException {
this.authMethod = method;
this.fallbackAllowed = fallbackAllowed; this.fallbackAllowed = fallbackAllowed;
String saslUser = null; String saslUser = null;
String saslProtocol = null; String saslProtocol = null;
@ -81,7 +98,8 @@ public SaslRpcClient(AuthMethod method,
CallbackHandler saslCallback = null; CallbackHandler saslCallback = null;
switch (method) { switch (method) {
case DIGEST: { case TOKEN: {
saslProtocol = "";
saslServerName = SaslRpcServer.SASL_DEFAULT_REALM; saslServerName = SaslRpcServer.SASL_DEFAULT_REALM;
saslCallback = new SaslClientCallbackHandler(token); saslCallback = new SaslClientCallbackHandler(token);
break; break;
@ -107,7 +125,7 @@ public SaslRpcClient(AuthMethod method,
String mechanism = method.getMechanismName(); String mechanism = method.getMechanismName();
if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
LOG.debug("Creating SASL " + mechanism LOG.debug("Creating SASL " + mechanism + "(" + authMethod + ") "
+ " client to authenticate to service at " + saslServerName); + " client to authenticate to service at " + saslServerName);
} }
saslClient = Sasl.createSaslClient( saslClient = Sasl.createSaslClient(
@ -118,14 +136,6 @@ public SaslRpcClient(AuthMethod method,
} }
} }
private static void readStatus(DataInputStream inStream) throws IOException {
int status = inStream.readInt(); // read status
if (status != SaslStatus.SUCCESS.state) {
throw new RemoteException(WritableUtils.readString(inStream),
WritableUtils.readString(inStream));
}
}
/** /**
* Do client side SASL authentication with server via the given InputStream * Do client side SASL authentication with server via the given InputStream
* and OutputStream * and OutputStream
@ -143,56 +153,142 @@ public boolean saslConnect(InputStream inS, OutputStream outS)
DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS)); DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS));
DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream( DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream(
outS)); outS));
try { // track if SASL ever started, or server switched us to simple
byte[] saslToken = new byte[0]; boolean inSasl = false;
if (saslClient.hasInitialResponse()) sendSaslMessage(outStream, negotiateRequest);
saslToken = saslClient.evaluateChallenge(saslToken);
while (saslToken != null) { // loop until sasl is complete or a rpc error occurs
outStream.writeInt(saslToken.length); boolean done = false;
outStream.write(saslToken, 0, saslToken.length); do {
outStream.flush(); int totalLen = inStream.readInt();
if (LOG.isDebugEnabled()) RpcResponseMessageWrapper responseWrapper =
LOG.debug("Have sent token of size " + saslToken.length new RpcResponseMessageWrapper();
+ " from initSASLContext."); responseWrapper.readFields(inStream);
readStatus(inStream); RpcResponseHeaderProto header = responseWrapper.getMessageHeader();
int len = inStream.readInt(); switch (header.getStatus()) {
if (len == SaslRpcServer.SWITCH_TO_SIMPLE_AUTH) { case ERROR: // might get a RPC error during
if (!fallbackAllowed) { case FATAL:
throw new IOException("Server asks us to fall back to SIMPLE " + throw new RemoteException(header.getExceptionClassName(),
"auth, but this client is configured to only allow secure " + header.getErrorMsg());
"connections."); default: break;
}
if (totalLen != responseWrapper.getLength()) {
throw new SaslException("Received malformed response length");
}
if (header.getCallId() != AuthProtocol.SASL.callId) {
throw new SaslException("Non-SASL response during negotiation");
}
RpcSaslProto saslMessage =
RpcSaslProto.parseFrom(responseWrapper.getMessageBytes());
if (LOG.isDebugEnabled()) {
LOG.debug("Received SASL message "+saslMessage);
}
// handle sasl negotiation process
RpcSaslProto.Builder response = null;
switch (saslMessage.getState()) {
case NEGOTIATE: {
inSasl = true;
// TODO: should instantiate sasl client based on advertisement
// but just blindly use the pre-instantiated sasl client for now
String clientAuthMethod = authMethod.toString();
SaslAuth saslAuthType = null;
for (SaslAuth authType : saslMessage.getAuthsList()) {
if (clientAuthMethod.equals(authType.getMethod())) {
saslAuthType = authType;
break;
}
} }
if (LOG.isDebugEnabled()) if (saslAuthType == null) {
LOG.debug("Server asks us to fall back to simple auth."); saslAuthType = SaslAuth.newBuilder()
saslClient.dispose(); .setMethod(clientAuthMethod)
return false; .setMechanism(saslClient.getMechanismName())
} else if ((len == 0) && saslClient.isComplete()) { .build();
}
byte[] challengeToken = null;
if (saslAuthType != null && saslAuthType.hasChallenge()) {
// server provided the first challenge
challengeToken = saslAuthType.getChallenge().toByteArray();
saslAuthType =
SaslAuth.newBuilder(saslAuthType).clearChallenge().build();
} else if (saslClient.hasInitialResponse()) {
challengeToken = new byte[0];
}
byte[] responseToken = (challengeToken != null)
? saslClient.evaluateChallenge(challengeToken)
: new byte[0];
response = createSaslReply(SaslState.INITIATE, responseToken);
response.addAuths(saslAuthType);
break; break;
} }
saslToken = new byte[len]; case CHALLENGE: {
if (LOG.isDebugEnabled()) inSasl = true;
LOG.debug("Will read input token of size " + saslToken.length byte[] responseToken = saslEvaluateToken(saslMessage, false);
+ " for processing by initSASLContext"); response = createSaslReply(SaslState.RESPONSE, responseToken);
inStream.readFully(saslToken); break;
saslToken = saslClient.evaluateChallenge(saslToken); }
case SUCCESS: {
if (inSasl && saslEvaluateToken(saslMessage, true) != null) {
throw new SaslException("SASL client generated spurious token");
}
done = true;
break;
}
default: {
throw new SaslException(
"RPC client doesn't support SASL " + saslMessage.getState());
}
} }
if (!saslClient.isComplete()) { // shouldn't happen if (response != null) {
throw new SaslException("Internal negotiation error"); sendSaslMessage(outStream, response.build());
} }
if (LOG.isDebugEnabled()) { } while (!done);
LOG.debug("SASL client context established. Negotiated QoP: " if (!inSasl && !fallbackAllowed) {
+ saslClient.getNegotiatedProperty(Sasl.QOP)); throw new IOException("Server asks us to fall back to SIMPLE " +
} "auth, but this client is configured to only allow secure " +
return true; "connections.");
} catch (IOException e) {
try {
saslClient.dispose();
} catch (SaslException ignored) {
// ignore further exceptions during cleanup
}
throw e;
} }
return inSasl;
}
private void sendSaslMessage(DataOutputStream out, RpcSaslProto message)
throws IOException {
if (LOG.isDebugEnabled()) {
LOG.debug("Sending sasl message "+message);
}
RpcRequestMessageWrapper request =
new RpcRequestMessageWrapper(saslHeader, message);
out.writeInt(request.getLength());
request.write(out);
out.flush();
}
private byte[] saslEvaluateToken(RpcSaslProto saslResponse,
boolean done) throws SaslException {
byte[] saslToken = null;
if (saslResponse.hasToken()) {
saslToken = saslResponse.getToken().toByteArray();
saslToken = saslClient.evaluateChallenge(saslToken);
} else if (!done) {
throw new SaslException("Challenge contains no token");
}
if (done && !saslClient.isComplete()) {
throw new SaslException("Client is out of sync with server");
}
return saslToken;
}
private RpcSaslProto.Builder createSaslReply(SaslState state,
byte[] responseToken) {
RpcSaslProto.Builder response = RpcSaslProto.newBuilder();
response.setState(state);
if (responseToken != null) {
response.setToken(ByteString.copyFrom(responseToken));
}
return response;
} }
/** /**

View File

@ -23,6 +23,7 @@
import java.io.DataInputStream; import java.io.DataInputStream;
import java.io.DataOutput; import java.io.DataOutput;
import java.io.IOException; import java.io.IOException;
import java.security.PrivilegedExceptionAction;
import java.security.Security; import java.security.Security;
import java.util.Map; import java.util.Map;
import java.util.TreeMap; import java.util.TreeMap;
@ -35,6 +36,8 @@
import javax.security.sasl.AuthorizeCallback; import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback; import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl; import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.commons.codec.binary.Base64; import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
@ -43,6 +46,8 @@
import org.apache.hadoop.classification.InterfaceStability; import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.ipc.Server; import org.apache.hadoop.ipc.Server;
import org.apache.hadoop.ipc.Server.Connection;
import org.apache.hadoop.security.authentication.util.KerberosName;
import org.apache.hadoop.security.token.SecretManager; import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.SecretManager.InvalidToken; import org.apache.hadoop.security.token.SecretManager.InvalidToken;
@ -58,8 +63,6 @@ public class SaslRpcServer {
public static final Map<String, String> SASL_PROPS = public static final Map<String, String> SASL_PROPS =
new TreeMap<String, String>(); new TreeMap<String, String>();
public static final int SWITCH_TO_SIMPLE_AUTH = -88;
public static enum QualityOfProtection { public static enum QualityOfProtection {
AUTHENTICATION("auth"), AUTHENTICATION("auth"),
INTEGRITY("auth-int"), INTEGRITY("auth-int"),
@ -75,7 +78,93 @@ public String getSaslQop() {
return saslQop; return saslQop;
} }
} }
@InterfaceAudience.Private
@InterfaceStability.Unstable
public AuthMethod authMethod;
public String mechanism;
public String protocol;
public String serverId;
@InterfaceAudience.Private
@InterfaceStability.Unstable
public SaslRpcServer(AuthMethod authMethod) throws IOException {
this.authMethod = authMethod;
mechanism = authMethod.getMechanismName();
switch (authMethod) {
case SIMPLE: {
return; // no sasl for simple
}
case TOKEN: {
protocol = "";
serverId = SaslRpcServer.SASL_DEFAULT_REALM;
break;
}
case KERBEROS: {
String fullName = UserGroupInformation.getCurrentUser().getUserName();
if (LOG.isDebugEnabled())
LOG.debug("Kerberos principal name is " + fullName);
KerberosName krbName = new KerberosName(fullName);
serverId = krbName.getHostName();
if (serverId == null) {
serverId = "";
}
protocol = krbName.getServiceName();
break;
}
default:
// we should never be able to get here
throw new AccessControlException(
"Server does not support SASL " + authMethod);
}
}
@InterfaceAudience.Private
@InterfaceStability.Unstable
public SaslServer create(Connection connection,
SecretManager<TokenIdentifier> secretManager
) throws IOException, InterruptedException {
UserGroupInformation ugi = UserGroupInformation.getCurrentUser();
final CallbackHandler callback;
switch (authMethod) {
case TOKEN: {
secretManager.checkAvailableForRead();
callback = new SaslDigestCallbackHandler(secretManager, connection);
break;
}
case KERBEROS: {
if (serverId.isEmpty()) {
throw new AccessControlException(
"Kerberos principal name does NOT have the expected "
+ "hostname part: " + ugi.getUserName());
}
callback = new SaslGssCallbackHandler();
break;
}
default:
// we should never be able to get here
throw new AccessControlException(
"Server does not support SASL " + authMethod);
}
SaslServer saslServer = ugi.doAs(
new PrivilegedExceptionAction<SaslServer>() {
@Override
public SaslServer run() throws SaslException {
return Sasl.createSaslServer(mechanism, protocol, serverId,
SaslRpcServer.SASL_PROPS, callback);
}
});
if (saslServer == null) {
throw new AccessControlException(
"Unable to find SASL server implementation for " + mechanism);
}
if (LOG.isDebugEnabled()) {
LOG.debug("Created SASL server with mechanism = " + mechanism);
}
return saslServer;
}
public static void init(Configuration conf) { public static void init(Configuration conf) {
QualityOfProtection saslQOP = QualityOfProtection.AUTHENTICATION; QualityOfProtection saslQOP = QualityOfProtection.AUTHENTICATION;
String rpcProtection = conf.get("hadoop.rpc.protection", String rpcProtection = conf.get("hadoop.rpc.protection",
@ -124,23 +213,14 @@ public static String[] splitKerberosName(String fullName) {
return fullName.split("[/@]"); return fullName.split("[/@]");
} }
@InterfaceStability.Evolving
public enum SaslStatus {
SUCCESS (0),
ERROR (1);
public final int state;
private SaslStatus(int state) {
this.state = state;
}
}
/** Authentication method */ /** Authentication method */
@InterfaceStability.Evolving @InterfaceStability.Evolving
public static enum AuthMethod { public static enum AuthMethod {
SIMPLE((byte) 80, ""), SIMPLE((byte) 80, ""),
KERBEROS((byte) 81, "GSSAPI"), KERBEROS((byte) 81, "GSSAPI"),
@Deprecated
DIGEST((byte) 82, "DIGEST-MD5"), DIGEST((byte) 82, "DIGEST-MD5"),
TOKEN((byte) 82, "DIGEST-MD5"),
PLAIN((byte) 83, "PLAIN"); PLAIN((byte) 83, "PLAIN");
/** The code for this method. */ /** The code for this method. */

View File

@ -1076,7 +1076,7 @@ public static enum AuthenticationMethod {
HadoopConfiguration.SIMPLE_CONFIG_NAME), HadoopConfiguration.SIMPLE_CONFIG_NAME),
KERBEROS(AuthMethod.KERBEROS, KERBEROS(AuthMethod.KERBEROS,
HadoopConfiguration.USER_KERBEROS_CONFIG_NAME), HadoopConfiguration.USER_KERBEROS_CONFIG_NAME),
TOKEN(AuthMethod.DIGEST), TOKEN(AuthMethod.TOKEN),
CERTIFICATE(null), CERTIFICATE(null),
KERBEROS_SSL(null), KERBEROS_SSL(null),
PROXY(null); PROXY(null);

View File

@ -94,7 +94,7 @@ public static IpcConnectionContextProto makeIpcConnectionContext(
// Real user was established as part of the connection. // Real user was established as part of the connection.
// Send effective user only. // Send effective user only.
ugiProto.setEffectiveUser(ugi.getUserName()); ugiProto.setEffectiveUser(ugi.getUserName());
} else if (authMethod == AuthMethod.DIGEST) { } else if (authMethod == AuthMethod.TOKEN) {
// With token, the connection itself establishes // With token, the connection itself establishes
// both real and effective user. Hence send none in header. // both real and effective user. Hence send none in header.
} else { // Simple authentication } else { // Simple authentication

View File

@ -127,3 +127,26 @@ message RpcResponseHeaderProto {
optional string errorMsg = 5; // if request fails, often contains strack trace optional string errorMsg = 5; // if request fails, often contains strack trace
optional RpcErrorCodeProto errorDetail = 6; // in case of error optional RpcErrorCodeProto errorDetail = 6; // in case of error
} }
message RpcSaslProto {
enum SaslState {
SUCCESS = 0;
NEGOTIATE = 1;
INITIATE = 2;
CHALLENGE = 3;
RESPONSE = 4;
}
message SaslAuth {
required string method = 1;
required string mechanism = 2;
optional string protocol = 3;
optional string serverId = 4;
optional bytes challenge = 5;
}
optional uint32 version = 1;
required SaslState state = 2;
optional bytes token = 3;
repeated SaslAuth auths = 4;
}

View File

@ -674,6 +674,7 @@ private String getAuthMethod(
try { try {
return internalGetAuthMethod(clientAuth, serverAuth, false, false); return internalGetAuthMethod(clientAuth, serverAuth, false, false);
} catch (Exception e) { } catch (Exception e) {
LOG.warn("Auth method failure", e);
return e.toString(); return e.toString();
} }
} }
@ -685,6 +686,7 @@ private String getAuthMethod(
try { try {
return internalGetAuthMethod(clientAuth, serverAuth, true, useValidToken); return internalGetAuthMethod(clientAuth, serverAuth, true, useValidToken);
} catch (Exception e) { } catch (Exception e) {
LOG.warn("Auth method failure", e);
return e.toString(); return e.toString();
} }
} }
@ -702,7 +704,7 @@ private String internalGetAuthMethod(
UserGroupInformation.setConfiguration(serverConf); UserGroupInformation.setConfiguration(serverConf);
final UserGroupInformation serverUgi = final UserGroupInformation serverUgi =
UserGroupInformation.createRemoteUser(currentUser + "-SERVER"); UserGroupInformation.createRemoteUser(currentUser + "-SERVER/localhost@NONE");
serverUgi.setAuthenticationMethod(serverAuth); serverUgi.setAuthenticationMethod(serverAuth);
final TestTokenSecretManager sm = new TestTokenSecretManager(); final TestTokenSecretManager sm = new TestTokenSecretManager();