HADOOP-14578. Bind IPC connections to kerberos UPN host for proxy users. Contributed by Daryn Sharp.

This commit is contained in:
Kihwal Lee 2017-07-26 13:12:39 -05:00
parent a92bf39e23
commit 27a1a5fde9
2 changed files with 96 additions and 14 deletions

View File

@ -633,7 +633,8 @@ private synchronized boolean updateAddress() throws IOException {
return false; return false;
} }
private synchronized void setupConnection() throws IOException { private synchronized void setupConnection(
UserGroupInformation ticket) throws IOException {
short ioFailures = 0; short ioFailures = 0;
short timeoutFailures = 0; short timeoutFailures = 0;
while (true) { while (true) {
@ -661,24 +662,26 @@ private synchronized void setupConnection() throws IOException {
* client, to ensure Server matching address of the client connection * client, to ensure Server matching address of the client connection
* to host name in principal passed. * to host name in principal passed.
*/ */
UserGroupInformation ticket = remoteId.getTicket(); InetSocketAddress bindAddr = null;
if (ticket != null && ticket.hasKerberosCredentials()) { if (ticket != null && ticket.hasKerberosCredentials()) {
KerberosInfo krbInfo = KerberosInfo krbInfo =
remoteId.getProtocol().getAnnotation(KerberosInfo.class); remoteId.getProtocol().getAnnotation(KerberosInfo.class);
if (krbInfo != null && krbInfo.clientPrincipal() != null) { if (krbInfo != null) {
String host = String principal = ticket.getUserName();
SecurityUtil.getHostFromPrincipal(remoteId.getTicket().getUserName()); String host = SecurityUtil.getHostFromPrincipal(principal);
// If host name is a valid local address then bind socket to it // If host name is a valid local address then bind socket to it
InetAddress localAddr = NetUtils.getLocalInetAddress(host); InetAddress localAddr = NetUtils.getLocalInetAddress(host);
if (localAddr != null) { if (localAddr != null) {
this.socket.setReuseAddress(true); this.socket.setReuseAddress(true);
this.socket.bind(new InetSocketAddress(localAddr, 0)); if (LOG.isDebugEnabled()) {
LOG.debug("Binding " + principal + " to " + localAddr);
}
bindAddr = new InetSocketAddress(localAddr, 0);
} }
} }
} }
NetUtils.connect(this.socket, server, connectionTimeout); NetUtils.connect(this.socket, server, bindAddr, connectionTimeout);
this.socket.setSoTimeout(soTimeout); this.socket.setSoTimeout(soTimeout);
return; return;
} catch (ConnectTimeoutException toe) { } catch (ConnectTimeoutException toe) {
@ -763,6 +766,13 @@ private synchronized void setupIOstreams(
if (socket != null || shouldCloseConnection.get()) { if (socket != null || shouldCloseConnection.get()) {
return; return;
} }
UserGroupInformation ticket = remoteId.getTicket();
if (ticket != null) {
final UserGroupInformation realUser = ticket.getRealUser();
if (realUser != null) {
ticket = realUser;
}
}
try { try {
if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
LOG.debug("Connecting to "+server); LOG.debug("Connecting to "+server);
@ -774,14 +784,10 @@ private synchronized void setupIOstreams(
short numRetries = 0; short numRetries = 0;
Random rand = null; Random rand = null;
while (true) { while (true) {
setupConnection(); setupConnection(ticket);
ipcStreams = new IpcStreams(socket, maxResponseLength); ipcStreams = new IpcStreams(socket, maxResponseLength);
writeConnectionHeader(ipcStreams); writeConnectionHeader(ipcStreams);
if (authProtocol == AuthProtocol.SASL) { if (authProtocol == AuthProtocol.SASL) {
UserGroupInformation ticket = remoteId.getTicket();
if (ticket.getRealUser() != null) {
ticket = ticket.getRealUser();
}
try { try {
authMethod = ticket authMethod = ticket
.doAs(new PrivilegedExceptionAction<AuthMethod>() { .doAs(new PrivilegedExceptionAction<AuthMethod>() {

View File

@ -39,9 +39,12 @@
import java.io.OutputStream; import java.io.OutputStream;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.lang.reflect.Proxy; import java.lang.reflect.Proxy;
import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.net.Socket; import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.SocketTimeoutException; import java.net.SocketTimeoutException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
@ -76,6 +79,7 @@
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
import org.apache.hadoop.net.ConnectTimeoutException; import org.apache.hadoop.net.ConnectTimeoutException;
import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.SecurityUtil; import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod; import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
@ -1484,6 +1488,78 @@ public void testRpcResponseLimit() throws Throwable {
Assert.fail("didn't get limit exceeded"); Assert.fail("didn't get limit exceeded");
} }
@Test
public void testUserBinding() throws Exception {
checkUserBinding(false);
}
@Test
public void testProxyUserBinding() throws Exception {
checkUserBinding(true);
}
private void checkUserBinding(boolean asProxy) throws Exception {
Socket s;
// don't attempt bind with no service host.
s = checkConnect(null, asProxy);
Mockito.verify(s, Mockito.never()).bind(Mockito.any(SocketAddress.class));
// don't attempt bind with service host not belonging to this host.
s = checkConnect("1.2.3.4", asProxy);
Mockito.verify(s, Mockito.never()).bind(Mockito.any(SocketAddress.class));
// do attempt bind when service host is this host.
InetAddress addr = InetAddress.getLocalHost();
s = checkConnect(addr.getHostAddress(), asProxy);
Mockito.verify(s).bind(new InetSocketAddress(addr, 0));
}
// dummy protocol that claims to support kerberos.
@KerberosInfo(serverPrincipal = "server@REALM")
private static class TestBindingProtocol {
}
private Socket checkConnect(String addr, boolean asProxy) throws Exception {
// create a fake ugi that claims to have kerberos credentials.
StringBuilder principal = new StringBuilder();
principal.append("client");
if (addr != null) {
principal.append("/").append(addr);
}
principal.append("@REALM");
UserGroupInformation ugi =
spy(UserGroupInformation.createRemoteUser(principal.toString()));
Mockito.doReturn(true).when(ugi).hasKerberosCredentials();
if (asProxy) {
ugi = UserGroupInformation.createProxyUser("proxy", ugi);
}
// create a mock socket that throws on connect.
SocketException expectedConnectEx =
new SocketException("Expected connect failure");
Socket s = Mockito.mock(Socket.class);
SocketFactory mockFactory = Mockito.mock(SocketFactory.class);
Mockito.doReturn(s).when(mockFactory).createSocket();
doThrow(expectedConnectEx).when(s).connect(
Mockito.any(SocketAddress.class), Mockito.anyInt());
// do a dummy call and expect it to throw an exception on connect.
// tests should verify if/how a bind occurred.
try (Client client = new Client(LongWritable.class, conf, mockFactory)) {
final InetSocketAddress sockAddr = new InetSocketAddress(0);
final LongWritable param = new LongWritable(RANDOM.nextLong());
final ConnectionId remoteId = new ConnectionId(
sockAddr, TestBindingProtocol.class, ugi, 0,
RetryPolicies.TRY_ONCE_THEN_FAIL, conf);
client.call(RPC.RpcKind.RPC_BUILTIN, param, remoteId, null);
fail("call didn't throw connect exception");
} catch (SocketException se) {
// ipc layer re-wraps exceptions, so check the cause.
Assert.assertSame(expectedConnectEx, se.getCause());
}
return s;
}
private void doIpcVersionTest( private void doIpcVersionTest(
byte[] requestData, byte[] requestData,
byte[] expectedResponse) throws IOException { byte[] expectedResponse) throws IOException {