HADOOP-6907. Rpc client doesn't use the per-connection conf to figure out server's Kerberos principal. Contributed by Kan Zhang.
git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@991780 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
5c8d9aecf7
commit
1c75bcc76b
@ -226,6 +226,9 @@ Trunk (unreleased changes)
|
||||
HADOOP-6913. Circular initialization between UserGroupInformation and
|
||||
KerberosName (Kan Zhang via boryas)
|
||||
|
||||
HADOOP-6907. Rpc client doesn't use the per-connection conf to figure
|
||||
out server's Kerberos principal (Kan Zhang via hairong)
|
||||
|
||||
Release 0.21.0 - Unreleased
|
||||
|
||||
INCOMPATIBLE CHANGES
|
||||
|
@ -37,6 +37,7 @@
|
||||
import java.util.Hashtable;
|
||||
import java.util.Iterator;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
@ -45,6 +46,8 @@
|
||||
|
||||
import org.apache.commons.logging.*;
|
||||
|
||||
import org.apache.hadoop.classification.InterfaceAudience;
|
||||
import org.apache.hadoop.classification.InterfaceStability;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.io.IOUtils;
|
||||
import org.apache.hadoop.io.Text;
|
||||
@ -80,12 +83,6 @@ public class Client {
|
||||
private int counter; // counter for call ids
|
||||
private AtomicBoolean running = new AtomicBoolean(true); // if client runs
|
||||
final private Configuration conf;
|
||||
final private int maxIdleTime; //connections will be culled if it was idle for
|
||||
//maxIdleTime msecs
|
||||
final private int maxRetries; //the max. no. of retries for socket connections
|
||||
private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
|
||||
private int pingInterval; // how often sends ping to the server in msecs
|
||||
final private boolean doPing; //do we need to send ping message
|
||||
|
||||
private SocketFactory socketFactory; // how to create sockets
|
||||
private int refCount = 1;
|
||||
@ -220,6 +217,12 @@ private class Connection extends Thread {
|
||||
private DataInputStream in;
|
||||
private DataOutputStream out;
|
||||
private int rpcTimeout;
|
||||
private int maxIdleTime; //connections will be culled if it was idle for
|
||||
//maxIdleTime msecs
|
||||
private int maxRetries; //the max. no. of retries for socket connections
|
||||
private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
|
||||
private boolean doPing; //do we need to send ping message
|
||||
private int pingInterval; // how often sends ping to the server in msecs
|
||||
|
||||
// currently active calls
|
||||
private Hashtable<Integer, Call> calls = new Hashtable<Integer, Call>();
|
||||
@ -235,6 +238,15 @@ public Connection(ConnectionId remoteId) throws IOException {
|
||||
remoteId.getAddress().getHostName());
|
||||
}
|
||||
this.rpcTimeout = remoteId.getRpcTimeout();
|
||||
this.maxIdleTime = remoteId.getMaxIdleTime();
|
||||
this.maxRetries = remoteId.getMaxRetries();
|
||||
this.tcpNoDelay = remoteId.getTcpNoDelay();
|
||||
this.doPing = remoteId.getDoPing();
|
||||
this.pingInterval = remoteId.getPingInterval();
|
||||
if (LOG.isDebugEnabled()) {
|
||||
LOG.debug("The ping interval is" + this.pingInterval + "ms.");
|
||||
}
|
||||
|
||||
UserGroupInformation ticket = remoteId.getTicket();
|
||||
Class<?> protocol = remoteId.getProtocol();
|
||||
this.useSasl = UserGroupInformation.isSecurityEnabled();
|
||||
@ -256,15 +268,9 @@ public Connection(ConnectionId remoteId) throws IOException {
|
||||
}
|
||||
KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
|
||||
if (krbInfo != null) {
|
||||
String serverKey = krbInfo.serverPrincipal();
|
||||
if (serverKey == null) {
|
||||
throw new IOException(
|
||||
"Can't obtain server Kerberos config key from KerberosInfo");
|
||||
}
|
||||
serverPrincipal = SecurityUtil.getServerPrincipal(
|
||||
conf.get(serverKey), server.getAddress().getCanonicalHostName());
|
||||
serverPrincipal = remoteId.getServerPrincipal();
|
||||
if (LOG.isDebugEnabled()) {
|
||||
LOG.debug("RPC Server Kerberos principal name for protocol="
|
||||
LOG.debug("RPC Server's Kerberos principal name for protocol="
|
||||
+ protocol.getCanonicalName() + " is " + serverPrincipal);
|
||||
}
|
||||
}
|
||||
@ -882,15 +888,6 @@ public synchronized void callComplete(ParallelCall call) {
|
||||
public Client(Class<? extends Writable> valueClass, Configuration conf,
|
||||
SocketFactory factory) {
|
||||
this.valueClass = valueClass;
|
||||
this.maxIdleTime =
|
||||
conf.getInt("ipc.client.connection.maxidletime", 10000); //10s
|
||||
this.maxRetries = conf.getInt("ipc.client.connect.max.retries", 10);
|
||||
this.tcpNoDelay = conf.getBoolean("ipc.client.tcpnodelay", false);
|
||||
this.doPing = conf.getBoolean("ipc.client.ping", true);
|
||||
this.pingInterval = getPingInterval(conf);
|
||||
if (LOG.isDebugEnabled()) {
|
||||
LOG.debug("The ping interval is" + this.pingInterval + "ms.");
|
||||
}
|
||||
this.conf = conf;
|
||||
this.socketFactory = factory;
|
||||
}
|
||||
@ -942,7 +939,7 @@ public void stop() {
|
||||
/** Make a call, passing <code>param</code>, to the IPC server running at
|
||||
* <code>address</code>, returning the value. Throws exceptions if there are
|
||||
* network problems or if the remote code threw an exception.
|
||||
* @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation, int)} instead
|
||||
* @deprecated Use {@link #call(Writable, ConnectionId)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public Writable call(Writable param, InetSocketAddress address)
|
||||
@ -955,27 +952,60 @@ public Writable call(Writable param, InetSocketAddress address)
|
||||
* the value.
|
||||
* Throws exceptions if there are network problems or if the remote code
|
||||
* threw an exception.
|
||||
* @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation, int)} instead
|
||||
* @deprecated Use {@link #call(Writable, ConnectionId)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public Writable call(Writable param, InetSocketAddress addr,
|
||||
UserGroupInformation ticket)
|
||||
throws InterruptedException, IOException {
|
||||
return call(param, addr, null, ticket, 0);
|
||||
ConnectionId remoteId = ConnectionId.getConnectionId(addr, null, ticket, 0,
|
||||
conf);
|
||||
return call(param, remoteId);
|
||||
}
|
||||
|
||||
/** Make a call, passing <code>param</code>, to the IPC server running at
|
||||
* <code>address</code> which is servicing the <code>protocol</code> protocol,
|
||||
* with the <code>ticket</code> credentials, returning the value.
|
||||
* with the <code>ticket</code> credentials and <code>rpcTimeout</code> as
|
||||
* timeout, returning the value.
|
||||
* Throws exceptions if there are network problems or if the remote code
|
||||
* threw an exception. */
|
||||
* threw an exception.
|
||||
* @deprecated Use {@link #call(Writable, ConnectionId)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public Writable call(Writable param, InetSocketAddress addr,
|
||||
Class<?> protocol, UserGroupInformation ticket,
|
||||
int rpcTimeout)
|
||||
throws InterruptedException, IOException {
|
||||
ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol,
|
||||
ticket, rpcTimeout, conf);
|
||||
return call(param, remoteId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Make a call, passing <code>param</code>, to the IPC server running at
|
||||
* <code>address</code> which is servicing the <code>protocol</code> protocol,
|
||||
* with the <code>ticket</code> credentials, <code>rpcTimeout</code> as
|
||||
* timeout and <code>conf</code> as conf for this connection, returning the
|
||||
* value. Throws exceptions if there are network problems or if the remote
|
||||
* code threw an exception.
|
||||
*/
|
||||
public Writable call(Writable param, InetSocketAddress addr,
|
||||
Class<?> protocol, UserGroupInformation ticket,
|
||||
int rpcTimeout, Configuration conf)
|
||||
throws InterruptedException, IOException {
|
||||
ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol,
|
||||
ticket, rpcTimeout, conf);
|
||||
return call(param, remoteId);
|
||||
}
|
||||
|
||||
/** Make a call, passing <code>param</code>, to the IPC server defined by
|
||||
* <code>remoteId</code>, returning the value.
|
||||
* Throws exceptions if there are network problems or if the remote code
|
||||
* threw an exception. */
|
||||
public Writable call(Writable param, ConnectionId remoteId)
|
||||
throws InterruptedException, IOException {
|
||||
Call call = new Call(param);
|
||||
Connection connection = getConnection(
|
||||
addr, protocol, ticket, rpcTimeout, call);
|
||||
Connection connection = getConnection(remoteId, call);
|
||||
connection.sendParam(call); // send the parameter
|
||||
boolean interrupted = false;
|
||||
synchronized (call) {
|
||||
@ -998,7 +1028,7 @@ public Writable call(Writable param, InetSocketAddress addr,
|
||||
call.error.fillInStackTrace();
|
||||
throw call.error;
|
||||
} else { // local exception
|
||||
throw wrapException(addr, call.error);
|
||||
throw wrapException(remoteId.getAddress(), call.error);
|
||||
}
|
||||
} else {
|
||||
return call.value;
|
||||
@ -1038,25 +1068,34 @@ private IOException wrapException(InetSocketAddress addr,
|
||||
}
|
||||
|
||||
/**
|
||||
* Makes a set of calls in parallel. Each parameter is sent to the
|
||||
* corresponding address. When all values are available, or have timed out
|
||||
* or errored, the collected results are returned in an array. The array
|
||||
* contains nulls for calls that timed out or errored.
|
||||
* @deprecated Use {@link #call(Writable[], InetSocketAddress[], Class, UserGroupInformation)} instead
|
||||
* @deprecated Use {@link #call(Writable[], InetSocketAddress[],
|
||||
* Class, UserGroupInformation, Configuration)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public Writable[] call(Writable[] params, InetSocketAddress[] addresses)
|
||||
throws IOException, InterruptedException {
|
||||
return call(params, addresses, null, null);
|
||||
return call(params, addresses, null, null, conf);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #call(Writable[], InetSocketAddress[],
|
||||
* Class, UserGroupInformation, Configuration)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
|
||||
Class<?> protocol, UserGroupInformation ticket)
|
||||
throws IOException, InterruptedException {
|
||||
return call(params, addresses, protocol, ticket, conf);
|
||||
}
|
||||
|
||||
|
||||
/** Makes a set of calls in parallel. Each parameter is sent to the
|
||||
* corresponding address. When all values are available, or have timed out
|
||||
* or errored, the collected results are returned in an array. The array
|
||||
* contains nulls for calls that timed out or errored. */
|
||||
public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
|
||||
Class<?> protocol, UserGroupInformation ticket)
|
||||
throws IOException, InterruptedException {
|
||||
public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
|
||||
Class<?> protocol, UserGroupInformation ticket, Configuration conf)
|
||||
throws IOException, InterruptedException {
|
||||
if (addresses.length == 0) return new Writable[0];
|
||||
|
||||
ParallelResults results = new ParallelResults(params.length);
|
||||
@ -1064,8 +1103,9 @@ public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
|
||||
for (int i = 0; i < params.length; i++) {
|
||||
ParallelCall call = new ParallelCall(params[i], results, i);
|
||||
try {
|
||||
Connection connection =
|
||||
getConnection(addresses[i], protocol, ticket, 0, call);
|
||||
ConnectionId remoteId = ConnectionId.getConnectionId(addresses[i],
|
||||
protocol, ticket, 0, conf);
|
||||
Connection connection = getConnection(remoteId, call);
|
||||
connection.sendParam(call); // send each parameter
|
||||
} catch (IOException e) {
|
||||
// log errors
|
||||
@ -1084,12 +1124,18 @@ public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
|
||||
}
|
||||
}
|
||||
|
||||
// for unit testing only
|
||||
@InterfaceAudience.Private
|
||||
@InterfaceStability.Unstable
|
||||
Set<ConnectionId> getConnectionIds() {
|
||||
synchronized (connections) {
|
||||
return connections.keySet();
|
||||
}
|
||||
}
|
||||
|
||||
/** Get a connection from the pool, or create a new one and add it to the
|
||||
* pool. Connections to a given host/port are reused. */
|
||||
private Connection getConnection(InetSocketAddress addr,
|
||||
Class<?> protocol,
|
||||
UserGroupInformation ticket,
|
||||
int rpcTimeout,
|
||||
* pool. Connections to a given ConnectionId are reused. */
|
||||
private Connection getConnection(ConnectionId remoteId,
|
||||
Call call)
|
||||
throws IOException, InterruptedException {
|
||||
if (!running.get()) {
|
||||
@ -1101,8 +1147,6 @@ private Connection getConnection(InetSocketAddress addr,
|
||||
* connectionsId object and with set() method. We need to manage the
|
||||
* refs for keys in HashMap properly. For now its ok.
|
||||
*/
|
||||
ConnectionId remoteId = new ConnectionId(
|
||||
addr, protocol, ticket, rpcTimeout);
|
||||
do {
|
||||
synchronized (connections) {
|
||||
connection = connections.get(remoteId);
|
||||
@ -1120,24 +1164,40 @@ private Connection getConnection(InetSocketAddress addr,
|
||||
connection.setupIOstreams();
|
||||
return connection;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* This class holds the address and the user ticket. The client connections
|
||||
* to servers are uniquely identified by <remoteAddress, protocol, ticket>
|
||||
*/
|
||||
private static class ConnectionId {
|
||||
static class ConnectionId {
|
||||
InetSocketAddress address;
|
||||
UserGroupInformation ticket;
|
||||
Class<?> protocol;
|
||||
private static final int PRIME = 16777619;
|
||||
private int rpcTimeout;
|
||||
private String serverPrincipal;
|
||||
private int maxIdleTime; //connections will be culled if it was idle for
|
||||
//maxIdleTime msecs
|
||||
private int maxRetries; //the max. no. of retries for socket connections
|
||||
private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
|
||||
private boolean doPing; //do we need to send ping message
|
||||
private int pingInterval; // how often sends ping to the server in msecs
|
||||
|
||||
ConnectionId(InetSocketAddress address, Class<?> protocol,
|
||||
UserGroupInformation ticket, int rpcTimeout) {
|
||||
UserGroupInformation ticket, int rpcTimeout,
|
||||
String serverPrincipal, int maxIdleTime,
|
||||
int maxRetries, boolean tcpNoDelay,
|
||||
boolean doPing, int pingInterval) {
|
||||
this.protocol = protocol;
|
||||
this.address = address;
|
||||
this.ticket = ticket;
|
||||
this.rpcTimeout = rpcTimeout;
|
||||
this.serverPrincipal = serverPrincipal;
|
||||
this.maxIdleTime = maxIdleTime;
|
||||
this.maxRetries = maxRetries;
|
||||
this.tcpNoDelay = tcpNoDelay;
|
||||
this.doPing = doPing;
|
||||
this.pingInterval = pingInterval;
|
||||
}
|
||||
|
||||
InetSocketAddress getAddress() {
|
||||
@ -1156,25 +1216,102 @@ private int getRpcTimeout() {
|
||||
return rpcTimeout;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (obj instanceof ConnectionId) {
|
||||
ConnectionId id = (ConnectionId) obj;
|
||||
return address.equals(id.address) && protocol == id.protocol &&
|
||||
((ticket != null && ticket.equals(id.ticket)) ||
|
||||
(ticket == id.ticket)) && rpcTimeout == id.rpcTimeout;
|
||||
}
|
||||
return false;
|
||||
String getServerPrincipal() {
|
||||
return serverPrincipal;
|
||||
}
|
||||
|
||||
@Override // simply use the default Object#hashcode() ?
|
||||
int getMaxIdleTime() {
|
||||
return maxIdleTime;
|
||||
}
|
||||
|
||||
int getMaxRetries() {
|
||||
return maxRetries;
|
||||
}
|
||||
|
||||
boolean getTcpNoDelay() {
|
||||
return tcpNoDelay;
|
||||
}
|
||||
|
||||
boolean getDoPing() {
|
||||
return doPing;
|
||||
}
|
||||
|
||||
int getPingInterval() {
|
||||
return pingInterval;
|
||||
}
|
||||
|
||||
static ConnectionId getConnectionId(InetSocketAddress addr,
|
||||
Class<?> protocol, UserGroupInformation ticket, int rpcTimeout,
|
||||
Configuration conf) throws IOException {
|
||||
String remotePrincipal = getRemotePrincipal(conf, addr, protocol);
|
||||
return new ConnectionId(addr, protocol, ticket,
|
||||
rpcTimeout, remotePrincipal,
|
||||
conf.getInt("ipc.client.connection.maxidletime", 10000), // 10s
|
||||
conf.getInt("ipc.client.connect.max.retries", 10),
|
||||
conf.getBoolean("ipc.client.tcpnodelay", false),
|
||||
conf.getBoolean("ipc.client.ping", true),
|
||||
Client.getPingInterval(conf));
|
||||
}
|
||||
|
||||
private static String getRemotePrincipal(Configuration conf,
|
||||
InetSocketAddress address, Class<?> protocol) throws IOException {
|
||||
if (protocol == null) {
|
||||
return null;
|
||||
}
|
||||
KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
|
||||
if (krbInfo != null) {
|
||||
String serverKey = krbInfo.serverPrincipal();
|
||||
if (serverKey == null) {
|
||||
throw new IOException(
|
||||
"Can't obtain server Kerberos config key from protocol="
|
||||
+ protocol.getCanonicalName());
|
||||
}
|
||||
return SecurityUtil.getServerPrincipal(conf.get(serverKey), address
|
||||
.getAddress().getCanonicalHostName());
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
static boolean isEqual(Object a, Object b) {
|
||||
return a == null ? b == null : a.equals(b);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (obj == this) {
|
||||
return true;
|
||||
}
|
||||
if (obj instanceof ConnectionId) {
|
||||
ConnectionId that = (ConnectionId) obj;
|
||||
return isEqual(this.address, that.address)
|
||||
&& this.doPing == that.doPing
|
||||
&& this.maxIdleTime == that.maxIdleTime
|
||||
&& this.maxRetries == that.maxRetries
|
||||
&& this.pingInterval == that.pingInterval
|
||||
&& isEqual(this.protocol, that.protocol)
|
||||
&& this.rpcTimeout == that.rpcTimeout
|
||||
&& isEqual(this.serverPrincipal, that.serverPrincipal)
|
||||
&& this.tcpNoDelay == that.tcpNoDelay
|
||||
&& isEqual(this.ticket, that.ticket);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return (address.hashCode() + PRIME * (
|
||||
PRIME * (
|
||||
PRIME * System.identityHashCode(protocol) ^
|
||||
System.identityHashCode(ticket)
|
||||
) ^ System.identityHashCode(rpcTimeout)
|
||||
));
|
||||
int result = 1;
|
||||
result = PRIME * result + ((address == null) ? 0 : address.hashCode());
|
||||
result = PRIME * result + (doPing ? 1231 : 1237);
|
||||
result = PRIME * result + maxIdleTime;
|
||||
result = PRIME * result + maxRetries;
|
||||
result = PRIME * result + pingInterval;
|
||||
result = PRIME * result + ((protocol == null) ? 0 : protocol.hashCode());
|
||||
result = PRIME * result + rpcTimeout;
|
||||
result = PRIME * result
|
||||
+ ((serverPrincipal == null) ? 0 : serverPrincipal.hashCode());
|
||||
result = PRIME * result + (tcpNoDelay ? 1231 : 1237);
|
||||
result = PRIME * result + ((ticket == null) ? 0 : ticket.hashCode());
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -38,6 +38,8 @@
|
||||
import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
|
||||
import org.apache.hadoop.security.token.SecretManager;
|
||||
import org.apache.hadoop.security.token.TokenIdentifier;
|
||||
import org.apache.hadoop.classification.InterfaceAudience;
|
||||
import org.apache.hadoop.classification.InterfaceStability;
|
||||
import org.apache.hadoop.conf.*;
|
||||
import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
|
||||
|
||||
@ -172,21 +174,16 @@ private void stopClient(Client client) {
|
||||
private static ClientCache CLIENTS=new ClientCache();
|
||||
|
||||
private static class Invoker implements InvocationHandler {
|
||||
private Class<?> protocol;
|
||||
private InetSocketAddress address;
|
||||
private UserGroupInformation ticket;
|
||||
private int rpcTimeout;
|
||||
private Client.ConnectionId remoteId;
|
||||
private Client client;
|
||||
private boolean isClosed = false;
|
||||
|
||||
public Invoker(Class<?> protocol,
|
||||
InetSocketAddress address, UserGroupInformation ticket,
|
||||
Configuration conf, SocketFactory factory,
|
||||
int rpcTimeout) {
|
||||
this.protocol = protocol;
|
||||
this.address = address;
|
||||
this.ticket = ticket;
|
||||
this.rpcTimeout = rpcTimeout;
|
||||
int rpcTimeout) throws IOException {
|
||||
this.remoteId = Client.ConnectionId.getConnectionId(address, protocol,
|
||||
ticket, rpcTimeout, conf);
|
||||
this.client = CLIENTS.getClient(conf, factory);
|
||||
}
|
||||
|
||||
@ -198,8 +195,7 @@ public Object invoke(Object proxy, Method method, Object[] args)
|
||||
}
|
||||
|
||||
ObjectWritable value = (ObjectWritable)
|
||||
client.call(new Invocation(method, args), address,
|
||||
protocol, ticket, rpcTimeout);
|
||||
client.call(new Invocation(method, args), remoteId);
|
||||
if (LOG.isDebugEnabled()) {
|
||||
long callTime = System.currentTimeMillis() - startTime;
|
||||
LOG.debug("Call: " + method.getName() + " " + callTime);
|
||||
@ -216,6 +212,13 @@ synchronized private void close() {
|
||||
}
|
||||
}
|
||||
|
||||
// for unit testing only
|
||||
@InterfaceAudience.Private
|
||||
@InterfaceStability.Unstable
|
||||
static Client getClient(Configuration conf) {
|
||||
return CLIENTS.getClient(conf);
|
||||
}
|
||||
|
||||
/** Construct a client-side proxy object that implements the named protocol,
|
||||
* talking to a server at the named address. */
|
||||
public Object getProxy(Class<?> protocol, long clientVersion,
|
||||
@ -259,7 +262,7 @@ public Object[] call(Method method, Object[][] params,
|
||||
Client client = CLIENTS.getClient(conf);
|
||||
try {
|
||||
Writable[] wrappedValues =
|
||||
client.call(invocations, addrs, method.getDeclaringClass(), ticket);
|
||||
client.call(invocations, addrs, method.getDeclaringClass(), ticket, conf);
|
||||
|
||||
if (method.getReturnType() == Void.TYPE) {
|
||||
return null;
|
||||
|
@ -94,7 +94,7 @@ public void run() {
|
||||
try {
|
||||
LongWritable param = new LongWritable(RANDOM.nextLong());
|
||||
LongWritable value =
|
||||
(LongWritable)client.call(param, server, null, null, 0);
|
||||
(LongWritable)client.call(param, server, null, null, 0, conf);
|
||||
if (!param.equals(value)) {
|
||||
LOG.fatal("Call failed!");
|
||||
failed = true;
|
||||
@ -127,7 +127,7 @@ public void run() {
|
||||
Writable[] params = new Writable[addresses.length];
|
||||
for (int j = 0; j < addresses.length; j++)
|
||||
params[j] = new LongWritable(RANDOM.nextLong());
|
||||
Writable[] values = client.call(params, addresses, null, null);
|
||||
Writable[] values = client.call(params, addresses, null, null, conf);
|
||||
for (int j = 0; j < addresses.length; j++) {
|
||||
if (!params[j].equals(values[j])) {
|
||||
LOG.fatal("Call failed!");
|
||||
@ -223,7 +223,7 @@ public void testStandAloneClient() throws Exception {
|
||||
InetSocketAddress address = new InetSocketAddress("127.0.0.1", 10);
|
||||
try {
|
||||
client.call(new LongWritable(RANDOM.nextLong()),
|
||||
address, null, null, 0);
|
||||
address, null, null, 0, conf);
|
||||
fail("Expected an exception to have been thrown");
|
||||
} catch (IOException e) {
|
||||
String message = e.getMessage();
|
||||
@ -280,7 +280,7 @@ public void testErrorClient() throws Exception {
|
||||
Client client = new Client(LongErrorWritable.class, conf);
|
||||
try {
|
||||
client.call(new LongErrorWritable(RANDOM.nextLong()),
|
||||
addr, null, null, 0);
|
||||
addr, null, null, 0, conf);
|
||||
fail("Expected an exception to have been thrown");
|
||||
} catch (IOException e) {
|
||||
// check error
|
||||
@ -300,7 +300,7 @@ public void testRuntimeExceptionWritable() throws Exception {
|
||||
Client client = new Client(LongRTEWritable.class, conf);
|
||||
try {
|
||||
client.call(new LongRTEWritable(RANDOM.nextLong()),
|
||||
addr, null, null, 0);
|
||||
addr, null, null, 0, conf);
|
||||
fail("Expected an exception to have been thrown");
|
||||
} catch (IOException e) {
|
||||
// check error
|
||||
@ -326,7 +326,7 @@ public void testSocketFactoryException() throws Exception {
|
||||
InetSocketAddress address = new InetSocketAddress("127.0.0.1", 10);
|
||||
try {
|
||||
client.call(new LongWritable(RANDOM.nextLong()),
|
||||
address, null, null, 0);
|
||||
address, null, null, 0, conf);
|
||||
fail("Expected an exception to have been thrown");
|
||||
} catch (IOException e) {
|
||||
assertTrue(e.getMessage().contains("Injected fault"));
|
||||
@ -344,14 +344,14 @@ public void testIpcTimeout() throws Exception {
|
||||
// set timeout to be less than MIN_SLEEP_TIME
|
||||
try {
|
||||
client.call(new LongWritable(RANDOM.nextLong()),
|
||||
addr, null, null, MIN_SLEEP_TIME/2);
|
||||
addr, null, null, MIN_SLEEP_TIME/2, conf);
|
||||
fail("Expected an exception to have been thrown");
|
||||
} catch (SocketTimeoutException e) {
|
||||
LOG.info("Get a SocketTimeoutException ", e);
|
||||
}
|
||||
// set timeout to be bigger than 3*ping interval
|
||||
client.call(new LongWritable(RANDOM.nextLong()),
|
||||
addr, null, null, 3*PING_INTERVAL+MIN_SLEEP_TIME);
|
||||
addr, null, null, 3*PING_INTERVAL+MIN_SLEEP_TIME, conf);
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
|
@ -27,6 +27,7 @@
|
||||
import java.net.InetSocketAddress;
|
||||
import java.security.PrivilegedExceptionAction;
|
||||
import java.util.Collection;
|
||||
import java.util.Set;
|
||||
|
||||
import javax.security.sasl.Sasl;
|
||||
|
||||
@ -37,6 +38,7 @@
|
||||
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.io.Text;
|
||||
import org.apache.hadoop.ipc.Client.ConnectionId;
|
||||
import org.apache.hadoop.net.NetUtils;
|
||||
import org.apache.hadoop.security.KerberosInfo;
|
||||
import org.apache.hadoop.security.token.SecretManager;
|
||||
@ -66,6 +68,9 @@ public class TestSaslRPC {
|
||||
static final String ERROR_MESSAGE = "Token is invalid";
|
||||
static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal";
|
||||
static final String SERVER_KEYTAB_KEY = "test.ipc.server.keytab";
|
||||
static final String SERVER_PRINCIPAL_1 = "p1/foo@BAR";
|
||||
static final String SERVER_PRINCIPAL_2 = "p2/foo@BAR";
|
||||
|
||||
private static Configuration conf;
|
||||
static {
|
||||
conf = new Configuration();
|
||||
@ -249,6 +254,63 @@ private void doDigestRpc(Server server, TestTokenSecretManager sm)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPerConnectionConf() throws Exception {
|
||||
TestTokenSecretManager sm = new TestTokenSecretManager();
|
||||
final Server server = RPC.getServer(TestSaslProtocol.class,
|
||||
new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);
|
||||
server.start();
|
||||
final UserGroupInformation current = UserGroupInformation.getCurrentUser();
|
||||
final InetSocketAddress addr = NetUtils.getConnectAddress(server);
|
||||
TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current
|
||||
.getUserName()));
|
||||
Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId,
|
||||
sm);
|
||||
Text host = new Text(addr.getAddress().getHostAddress() + ":"
|
||||
+ addr.getPort());
|
||||
token.setService(host);
|
||||
LOG.info("Service IP address for token is " + host);
|
||||
current.addToken(token);
|
||||
|
||||
Configuration newConf = new Configuration(conf);
|
||||
newConf.set("hadoop.rpc.socket.factory.class.default", "");
|
||||
newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);
|
||||
|
||||
TestSaslProtocol proxy1 = null;
|
||||
TestSaslProtocol proxy2 = null;
|
||||
TestSaslProtocol proxy3 = null;
|
||||
try {
|
||||
proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
|
||||
TestSaslProtocol.versionID, addr, newConf);
|
||||
Client client = WritableRpcEngine.getClient(conf);
|
||||
Set<ConnectionId> conns = client.getConnectionIds();
|
||||
assertEquals("number of connections in cache is wrong", 1, conns.size());
|
||||
// same conf, connection should be re-used
|
||||
proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
|
||||
TestSaslProtocol.versionID, addr, newConf);
|
||||
assertEquals("number of connections in cache is wrong", 1, conns.size());
|
||||
// different conf, new connection should be set up
|
||||
newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2);
|
||||
proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
|
||||
TestSaslProtocol.versionID, addr, newConf);
|
||||
ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]);
|
||||
assertEquals("number of connections in cache is wrong", 2,
|
||||
connsArray.length);
|
||||
String p1 = connsArray[0].getServerPrincipal();
|
||||
String p2 = connsArray[1].getServerPrincipal();
|
||||
assertFalse("should have different principals", p1.equals(p2));
|
||||
assertTrue("principal not as expected", p1.equals(SERVER_PRINCIPAL_1)
|
||||
|| p1.equals(SERVER_PRINCIPAL_2));
|
||||
assertTrue("principal not as expected", p2.equals(SERVER_PRINCIPAL_1)
|
||||
|| p2.equals(SERVER_PRINCIPAL_2));
|
||||
} finally {
|
||||
server.stop();
|
||||
RPC.stopProxy(proxy1);
|
||||
RPC.stopProxy(proxy2);
|
||||
RPC.stopProxy(proxy3);
|
||||
}
|
||||
}
|
||||
|
||||
static void testKerberosRpc(String principal, String keytab) throws Exception {
|
||||
final Configuration newConf = new Configuration(conf);
|
||||
newConf.set(SERVER_PRINCIPAL_KEY, principal);
|
||||
|
Loading…
Reference in New Issue
Block a user