HADOOP-14578. Bind IPC connections to kerberos UPN host for proxy users. Contributed by Daryn Sharp.

This commit is contained in:
Kihwal Lee 2017-07-26 13:12:39 -05:00
parent a92bf39e23
commit 27a1a5fde9
2 changed files with 96 additions and 14 deletions

View File

@ -633,7 +633,8 @@ private synchronized boolean updateAddress() throws IOException {
return false;
}
private synchronized void setupConnection() throws IOException {
private synchronized void setupConnection(
UserGroupInformation ticket) throws IOException {
short ioFailures = 0;
short timeoutFailures = 0;
while (true) {
@ -661,24 +662,26 @@ private synchronized void setupConnection() throws IOException {
* client, to ensure Server matching address of the client connection
* to host name in principal passed.
*/
UserGroupInformation ticket = remoteId.getTicket();
InetSocketAddress bindAddr = null;
if (ticket != null && ticket.hasKerberosCredentials()) {
KerberosInfo krbInfo =
remoteId.getProtocol().getAnnotation(KerberosInfo.class);
if (krbInfo != null && krbInfo.clientPrincipal() != null) {
String host =
SecurityUtil.getHostFromPrincipal(remoteId.getTicket().getUserName());
if (krbInfo != null) {
String principal = ticket.getUserName();
String host = SecurityUtil.getHostFromPrincipal(principal);
// If host name is a valid local address then bind socket to it
InetAddress localAddr = NetUtils.getLocalInetAddress(host);
if (localAddr != null) {
this.socket.setReuseAddress(true);
this.socket.bind(new InetSocketAddress(localAddr, 0));
if (LOG.isDebugEnabled()) {
LOG.debug("Binding " + principal + " to " + localAddr);
}
bindAddr = new InetSocketAddress(localAddr, 0);
}
}
}
NetUtils.connect(this.socket, server, connectionTimeout);
NetUtils.connect(this.socket, server, bindAddr, connectionTimeout);
this.socket.setSoTimeout(soTimeout);
return;
} catch (ConnectTimeoutException toe) {
@ -762,7 +765,14 @@ private synchronized void setupIOstreams(
AtomicBoolean fallbackToSimpleAuth) {
if (socket != null || shouldCloseConnection.get()) {
return;
}
}
UserGroupInformation ticket = remoteId.getTicket();
if (ticket != null) {
final UserGroupInformation realUser = ticket.getRealUser();
if (realUser != null) {
ticket = realUser;
}
}
try {
if (LOG.isDebugEnabled()) {
LOG.debug("Connecting to "+server);
@ -774,14 +784,10 @@ private synchronized void setupIOstreams(
short numRetries = 0;
Random rand = null;
while (true) {
setupConnection();
setupConnection(ticket);
ipcStreams = new IpcStreams(socket, maxResponseLength);
writeConnectionHeader(ipcStreams);
if (authProtocol == AuthProtocol.SASL) {
UserGroupInformation ticket = remoteId.getTicket();
if (ticket.getRealUser() != null) {
ticket = ticket.getRealUser();
}
try {
authMethod = ticket
.doAs(new PrivilegedExceptionAction<AuthMethod>() {

View File

@ -39,9 +39,12 @@
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.Collections;
@ -76,6 +79,7 @@
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
import org.apache.hadoop.net.ConnectTimeoutException;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
@ -1484,6 +1488,78 @@ public void testRpcResponseLimit() throws Throwable {
Assert.fail("didn't get limit exceeded");
}
@Test
public void testUserBinding() throws Exception {
checkUserBinding(false);
}
@Test
public void testProxyUserBinding() throws Exception {
checkUserBinding(true);
}
private void checkUserBinding(boolean asProxy) throws Exception {
Socket s;
// don't attempt bind with no service host.
s = checkConnect(null, asProxy);
Mockito.verify(s, Mockito.never()).bind(Mockito.any(SocketAddress.class));
// don't attempt bind with service host not belonging to this host.
s = checkConnect("1.2.3.4", asProxy);
Mockito.verify(s, Mockito.never()).bind(Mockito.any(SocketAddress.class));
// do attempt bind when service host is this host.
InetAddress addr = InetAddress.getLocalHost();
s = checkConnect(addr.getHostAddress(), asProxy);
Mockito.verify(s).bind(new InetSocketAddress(addr, 0));
}
// dummy protocol that claims to support kerberos.
@KerberosInfo(serverPrincipal = "server@REALM")
private static class TestBindingProtocol {
}
private Socket checkConnect(String addr, boolean asProxy) throws Exception {
// create a fake ugi that claims to have kerberos credentials.
StringBuilder principal = new StringBuilder();
principal.append("client");
if (addr != null) {
principal.append("/").append(addr);
}
principal.append("@REALM");
UserGroupInformation ugi =
spy(UserGroupInformation.createRemoteUser(principal.toString()));
Mockito.doReturn(true).when(ugi).hasKerberosCredentials();
if (asProxy) {
ugi = UserGroupInformation.createProxyUser("proxy", ugi);
}
// create a mock socket that throws on connect.
SocketException expectedConnectEx =
new SocketException("Expected connect failure");
Socket s = Mockito.mock(Socket.class);
SocketFactory mockFactory = Mockito.mock(SocketFactory.class);
Mockito.doReturn(s).when(mockFactory).createSocket();
doThrow(expectedConnectEx).when(s).connect(
Mockito.any(SocketAddress.class), Mockito.anyInt());
// do a dummy call and expect it to throw an exception on connect.
// tests should verify if/how a bind occurred.
try (Client client = new Client(LongWritable.class, conf, mockFactory)) {
final InetSocketAddress sockAddr = new InetSocketAddress(0);
final LongWritable param = new LongWritable(RANDOM.nextLong());
final ConnectionId remoteId = new ConnectionId(
sockAddr, TestBindingProtocol.class, ugi, 0,
RetryPolicies.TRY_ONCE_THEN_FAIL, conf);
client.call(RPC.RpcKind.RPC_BUILTIN, param, remoteId, null);
fail("call didn't throw connect exception");
} catch (SocketException se) {
// ipc layer re-wraps exceptions, so check the cause.
Assert.assertSame(expectedConnectEx, se.getCause());
}
return s;
}
private void doIpcVersionTest(
byte[] requestData,
byte[] expectedResponse) throws IOException {