HADOOP-14578. Bind IPC connections to kerberos UPN host for proxy users. Contributed by Daryn Sharp.
This commit is contained in:
parent
a92bf39e23
commit
27a1a5fde9
@ -633,7 +633,8 @@ private synchronized boolean updateAddress() throws IOException {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
private synchronized void setupConnection() throws IOException {
|
private synchronized void setupConnection(
|
||||||
|
UserGroupInformation ticket) throws IOException {
|
||||||
short ioFailures = 0;
|
short ioFailures = 0;
|
||||||
short timeoutFailures = 0;
|
short timeoutFailures = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
@ -661,24 +662,26 @@ private synchronized void setupConnection() throws IOException {
|
|||||||
* client, to ensure Server matching address of the client connection
|
* client, to ensure Server matching address of the client connection
|
||||||
* to host name in principal passed.
|
* to host name in principal passed.
|
||||||
*/
|
*/
|
||||||
UserGroupInformation ticket = remoteId.getTicket();
|
InetSocketAddress bindAddr = null;
|
||||||
if (ticket != null && ticket.hasKerberosCredentials()) {
|
if (ticket != null && ticket.hasKerberosCredentials()) {
|
||||||
KerberosInfo krbInfo =
|
KerberosInfo krbInfo =
|
||||||
remoteId.getProtocol().getAnnotation(KerberosInfo.class);
|
remoteId.getProtocol().getAnnotation(KerberosInfo.class);
|
||||||
if (krbInfo != null && krbInfo.clientPrincipal() != null) {
|
if (krbInfo != null) {
|
||||||
String host =
|
String principal = ticket.getUserName();
|
||||||
SecurityUtil.getHostFromPrincipal(remoteId.getTicket().getUserName());
|
String host = SecurityUtil.getHostFromPrincipal(principal);
|
||||||
|
|
||||||
// If host name is a valid local address then bind socket to it
|
// If host name is a valid local address then bind socket to it
|
||||||
InetAddress localAddr = NetUtils.getLocalInetAddress(host);
|
InetAddress localAddr = NetUtils.getLocalInetAddress(host);
|
||||||
if (localAddr != null) {
|
if (localAddr != null) {
|
||||||
this.socket.setReuseAddress(true);
|
this.socket.setReuseAddress(true);
|
||||||
this.socket.bind(new InetSocketAddress(localAddr, 0));
|
if (LOG.isDebugEnabled()) {
|
||||||
|
LOG.debug("Binding " + principal + " to " + localAddr);
|
||||||
|
}
|
||||||
|
bindAddr = new InetSocketAddress(localAddr, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
NetUtils.connect(this.socket, server, connectionTimeout);
|
NetUtils.connect(this.socket, server, bindAddr, connectionTimeout);
|
||||||
this.socket.setSoTimeout(soTimeout);
|
this.socket.setSoTimeout(soTimeout);
|
||||||
return;
|
return;
|
||||||
} catch (ConnectTimeoutException toe) {
|
} catch (ConnectTimeoutException toe) {
|
||||||
@ -763,6 +766,13 @@ private synchronized void setupIOstreams(
|
|||||||
if (socket != null || shouldCloseConnection.get()) {
|
if (socket != null || shouldCloseConnection.get()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
UserGroupInformation ticket = remoteId.getTicket();
|
||||||
|
if (ticket != null) {
|
||||||
|
final UserGroupInformation realUser = ticket.getRealUser();
|
||||||
|
if (realUser != null) {
|
||||||
|
ticket = realUser;
|
||||||
|
}
|
||||||
|
}
|
||||||
try {
|
try {
|
||||||
if (LOG.isDebugEnabled()) {
|
if (LOG.isDebugEnabled()) {
|
||||||
LOG.debug("Connecting to "+server);
|
LOG.debug("Connecting to "+server);
|
||||||
@ -774,14 +784,10 @@ private synchronized void setupIOstreams(
|
|||||||
short numRetries = 0;
|
short numRetries = 0;
|
||||||
Random rand = null;
|
Random rand = null;
|
||||||
while (true) {
|
while (true) {
|
||||||
setupConnection();
|
setupConnection(ticket);
|
||||||
ipcStreams = new IpcStreams(socket, maxResponseLength);
|
ipcStreams = new IpcStreams(socket, maxResponseLength);
|
||||||
writeConnectionHeader(ipcStreams);
|
writeConnectionHeader(ipcStreams);
|
||||||
if (authProtocol == AuthProtocol.SASL) {
|
if (authProtocol == AuthProtocol.SASL) {
|
||||||
UserGroupInformation ticket = remoteId.getTicket();
|
|
||||||
if (ticket.getRealUser() != null) {
|
|
||||||
ticket = ticket.getRealUser();
|
|
||||||
}
|
|
||||||
try {
|
try {
|
||||||
authMethod = ticket
|
authMethod = ticket
|
||||||
.doAs(new PrivilegedExceptionAction<AuthMethod>() {
|
.doAs(new PrivilegedExceptionAction<AuthMethod>() {
|
||||||
|
@ -39,9 +39,12 @@
|
|||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
import java.lang.reflect.Method;
|
import java.lang.reflect.Method;
|
||||||
import java.lang.reflect.Proxy;
|
import java.lang.reflect.Proxy;
|
||||||
|
import java.net.InetAddress;
|
||||||
import java.net.InetSocketAddress;
|
import java.net.InetSocketAddress;
|
||||||
import java.net.ServerSocket;
|
import java.net.ServerSocket;
|
||||||
import java.net.Socket;
|
import java.net.Socket;
|
||||||
|
import java.net.SocketAddress;
|
||||||
|
import java.net.SocketException;
|
||||||
import java.net.SocketTimeoutException;
|
import java.net.SocketTimeoutException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
@ -76,6 +79,7 @@
|
|||||||
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
|
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
|
||||||
import org.apache.hadoop.net.ConnectTimeoutException;
|
import org.apache.hadoop.net.ConnectTimeoutException;
|
||||||
import org.apache.hadoop.net.NetUtils;
|
import org.apache.hadoop.net.NetUtils;
|
||||||
|
import org.apache.hadoop.security.KerberosInfo;
|
||||||
import org.apache.hadoop.security.SecurityUtil;
|
import org.apache.hadoop.security.SecurityUtil;
|
||||||
import org.apache.hadoop.security.UserGroupInformation;
|
import org.apache.hadoop.security.UserGroupInformation;
|
||||||
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
|
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
|
||||||
@ -1484,6 +1488,78 @@ public void testRpcResponseLimit() throws Throwable {
|
|||||||
Assert.fail("didn't get limit exceeded");
|
Assert.fail("didn't get limit exceeded");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUserBinding() throws Exception {
|
||||||
|
checkUserBinding(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testProxyUserBinding() throws Exception {
|
||||||
|
checkUserBinding(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void checkUserBinding(boolean asProxy) throws Exception {
|
||||||
|
Socket s;
|
||||||
|
// don't attempt bind with no service host.
|
||||||
|
s = checkConnect(null, asProxy);
|
||||||
|
Mockito.verify(s, Mockito.never()).bind(Mockito.any(SocketAddress.class));
|
||||||
|
|
||||||
|
// don't attempt bind with service host not belonging to this host.
|
||||||
|
s = checkConnect("1.2.3.4", asProxy);
|
||||||
|
Mockito.verify(s, Mockito.never()).bind(Mockito.any(SocketAddress.class));
|
||||||
|
|
||||||
|
// do attempt bind when service host is this host.
|
||||||
|
InetAddress addr = InetAddress.getLocalHost();
|
||||||
|
s = checkConnect(addr.getHostAddress(), asProxy);
|
||||||
|
Mockito.verify(s).bind(new InetSocketAddress(addr, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
// dummy protocol that claims to support kerberos.
|
||||||
|
@KerberosInfo(serverPrincipal = "server@REALM")
|
||||||
|
private static class TestBindingProtocol {
|
||||||
|
}
|
||||||
|
|
||||||
|
private Socket checkConnect(String addr, boolean asProxy) throws Exception {
|
||||||
|
// create a fake ugi that claims to have kerberos credentials.
|
||||||
|
StringBuilder principal = new StringBuilder();
|
||||||
|
principal.append("client");
|
||||||
|
if (addr != null) {
|
||||||
|
principal.append("/").append(addr);
|
||||||
|
}
|
||||||
|
principal.append("@REALM");
|
||||||
|
UserGroupInformation ugi =
|
||||||
|
spy(UserGroupInformation.createRemoteUser(principal.toString()));
|
||||||
|
Mockito.doReturn(true).when(ugi).hasKerberosCredentials();
|
||||||
|
if (asProxy) {
|
||||||
|
ugi = UserGroupInformation.createProxyUser("proxy", ugi);
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a mock socket that throws on connect.
|
||||||
|
SocketException expectedConnectEx =
|
||||||
|
new SocketException("Expected connect failure");
|
||||||
|
Socket s = Mockito.mock(Socket.class);
|
||||||
|
SocketFactory mockFactory = Mockito.mock(SocketFactory.class);
|
||||||
|
Mockito.doReturn(s).when(mockFactory).createSocket();
|
||||||
|
doThrow(expectedConnectEx).when(s).connect(
|
||||||
|
Mockito.any(SocketAddress.class), Mockito.anyInt());
|
||||||
|
|
||||||
|
// do a dummy call and expect it to throw an exception on connect.
|
||||||
|
// tests should verify if/how a bind occurred.
|
||||||
|
try (Client client = new Client(LongWritable.class, conf, mockFactory)) {
|
||||||
|
final InetSocketAddress sockAddr = new InetSocketAddress(0);
|
||||||
|
final LongWritable param = new LongWritable(RANDOM.nextLong());
|
||||||
|
final ConnectionId remoteId = new ConnectionId(
|
||||||
|
sockAddr, TestBindingProtocol.class, ugi, 0,
|
||||||
|
RetryPolicies.TRY_ONCE_THEN_FAIL, conf);
|
||||||
|
client.call(RPC.RpcKind.RPC_BUILTIN, param, remoteId, null);
|
||||||
|
fail("call didn't throw connect exception");
|
||||||
|
} catch (SocketException se) {
|
||||||
|
// ipc layer re-wraps exceptions, so check the cause.
|
||||||
|
Assert.assertSame(expectedConnectEx, se.getCause());
|
||||||
|
}
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
private void doIpcVersionTest(
|
private void doIpcVersionTest(
|
||||||
byte[] requestData,
|
byte[] requestData,
|
||||||
byte[] expectedResponse) throws IOException {
|
byte[] expectedResponse) throws IOException {
|
||||||
|
Loading…
Reference in New Issue
Block a user