HADOOP-6907. Rpc client doesn't use the per-connection conf to figure out server's Kerberos principal. Contributed by Kan Zhang.

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@991780 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Hairong Kuang 2010-09-02 00:35:30 +00:00
parent 5c8d9aecf7
commit 1c75bcc76b
5 changed files with 293 additions and 88 deletions

View File

@ -226,6 +226,9 @@ Trunk (unreleased changes)
HADOOP-6913. Circular initialization between UserGroupInformation and
KerberosName (Kan Zhang via boryas)
HADOOP-6907. Rpc client doesn't use the per-connection conf to figure
out server's Kerberos principal (Kan Zhang via hairong)
Release 0.21.0 - Unreleased
INCOMPATIBLE CHANGES

View File

@ -37,6 +37,7 @@
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Random;
import java.util.Set;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
@ -45,6 +46,8 @@
import org.apache.commons.logging.*;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Text;
@ -80,12 +83,6 @@ public class Client {
private int counter; // counter for call ids
private AtomicBoolean running = new AtomicBoolean(true); // if client runs
final private Configuration conf;
final private int maxIdleTime; //connections will be culled if it was idle for
//maxIdleTime msecs
final private int maxRetries; //the max. no. of retries for socket connections
private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
private int pingInterval; // how often sends ping to the server in msecs
final private boolean doPing; //do we need to send ping message
private SocketFactory socketFactory; // how to create sockets
private int refCount = 1;
@ -220,6 +217,12 @@ private class Connection extends Thread {
private DataInputStream in;
private DataOutputStream out;
private int rpcTimeout;
private int maxIdleTime; //connections will be culled if it was idle for
//maxIdleTime msecs
private int maxRetries; //the max. no. of retries for socket connections
private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
private boolean doPing; //do we need to send ping message
private int pingInterval; // how often sends ping to the server in msecs
// currently active calls
private Hashtable<Integer, Call> calls = new Hashtable<Integer, Call>();
@ -235,6 +238,15 @@ public Connection(ConnectionId remoteId) throws IOException {
remoteId.getAddress().getHostName());
}
this.rpcTimeout = remoteId.getRpcTimeout();
this.maxIdleTime = remoteId.getMaxIdleTime();
this.maxRetries = remoteId.getMaxRetries();
this.tcpNoDelay = remoteId.getTcpNoDelay();
this.doPing = remoteId.getDoPing();
this.pingInterval = remoteId.getPingInterval();
if (LOG.isDebugEnabled()) {
LOG.debug("The ping interval is" + this.pingInterval + "ms.");
}
UserGroupInformation ticket = remoteId.getTicket();
Class<?> protocol = remoteId.getProtocol();
this.useSasl = UserGroupInformation.isSecurityEnabled();
@ -256,15 +268,9 @@ public Connection(ConnectionId remoteId) throws IOException {
}
KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
if (krbInfo != null) {
String serverKey = krbInfo.serverPrincipal();
if (serverKey == null) {
throw new IOException(
"Can't obtain server Kerberos config key from KerberosInfo");
}
serverPrincipal = SecurityUtil.getServerPrincipal(
conf.get(serverKey), server.getAddress().getCanonicalHostName());
serverPrincipal = remoteId.getServerPrincipal();
if (LOG.isDebugEnabled()) {
LOG.debug("RPC Server Kerberos principal name for protocol="
LOG.debug("RPC Server's Kerberos principal name for protocol="
+ protocol.getCanonicalName() + " is " + serverPrincipal);
}
}
@ -882,15 +888,6 @@ public synchronized void callComplete(ParallelCall call) {
public Client(Class<? extends Writable> valueClass, Configuration conf,
SocketFactory factory) {
this.valueClass = valueClass;
this.maxIdleTime =
conf.getInt("ipc.client.connection.maxidletime", 10000); //10s
this.maxRetries = conf.getInt("ipc.client.connect.max.retries", 10);
this.tcpNoDelay = conf.getBoolean("ipc.client.tcpnodelay", false);
this.doPing = conf.getBoolean("ipc.client.ping", true);
this.pingInterval = getPingInterval(conf);
if (LOG.isDebugEnabled()) {
LOG.debug("The ping interval is" + this.pingInterval + "ms.");
}
this.conf = conf;
this.socketFactory = factory;
}
@ -942,7 +939,7 @@ public void stop() {
/** Make a call, passing <code>param</code>, to the IPC server running at
* <code>address</code>, returning the value. Throws exceptions if there are
* network problems or if the remote code threw an exception.
* @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation, int)} instead
* @deprecated Use {@link #call(Writable, ConnectionId)} instead
*/
@Deprecated
public Writable call(Writable param, InetSocketAddress address)
@ -955,27 +952,60 @@ public Writable call(Writable param, InetSocketAddress address)
* the value.
* Throws exceptions if there are network problems or if the remote code
* threw an exception.
* @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation, int)} instead
* @deprecated Use {@link #call(Writable, ConnectionId)} instead
*/
@Deprecated
public Writable call(Writable param, InetSocketAddress addr,
UserGroupInformation ticket)
throws InterruptedException, IOException {
return call(param, addr, null, ticket, 0);
ConnectionId remoteId = ConnectionId.getConnectionId(addr, null, ticket, 0,
conf);
return call(param, remoteId);
}
/** Make a call, passing <code>param</code>, to the IPC server running at
* <code>address</code> which is servicing the <code>protocol</code> protocol,
* with the <code>ticket</code> credentials, returning the value.
* with the <code>ticket</code> credentials and <code>rpcTimeout</code> as
* timeout, returning the value.
* Throws exceptions if there are network problems or if the remote code
* threw an exception. */
* threw an exception.
* @deprecated Use {@link #call(Writable, ConnectionId)} instead
*/
@Deprecated
public Writable call(Writable param, InetSocketAddress addr,
Class<?> protocol, UserGroupInformation ticket,
int rpcTimeout)
throws InterruptedException, IOException {
ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol,
ticket, rpcTimeout, conf);
return call(param, remoteId);
}
/**
* Make a call, passing <code>param</code>, to the IPC server running at
* <code>address</code> which is servicing the <code>protocol</code> protocol,
* with the <code>ticket</code> credentials, <code>rpcTimeout</code> as
* timeout and <code>conf</code> as conf for this connection, returning the
* value. Throws exceptions if there are network problems or if the remote
* code threw an exception.
*/
public Writable call(Writable param, InetSocketAddress addr,
Class<?> protocol, UserGroupInformation ticket,
int rpcTimeout, Configuration conf)
throws InterruptedException, IOException {
ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol,
ticket, rpcTimeout, conf);
return call(param, remoteId);
}
/** Make a call, passing <code>param</code>, to the IPC server defined by
* <code>remoteId</code>, returning the value.
* Throws exceptions if there are network problems or if the remote code
* threw an exception. */
public Writable call(Writable param, ConnectionId remoteId)
throws InterruptedException, IOException {
Call call = new Call(param);
Connection connection = getConnection(
addr, protocol, ticket, rpcTimeout, call);
Connection connection = getConnection(remoteId, call);
connection.sendParam(call); // send the parameter
boolean interrupted = false;
synchronized (call) {
@ -998,7 +1028,7 @@ public Writable call(Writable param, InetSocketAddress addr,
call.error.fillInStackTrace();
throw call.error;
} else { // local exception
throw wrapException(addr, call.error);
throw wrapException(remoteId.getAddress(), call.error);
}
} else {
return call.value;
@ -1038,25 +1068,34 @@ private IOException wrapException(InetSocketAddress addr,
}
/**
* Makes a set of calls in parallel. Each parameter is sent to the
* corresponding address. When all values are available, or have timed out
* or errored, the collected results are returned in an array. The array
* contains nulls for calls that timed out or errored.
* @deprecated Use {@link #call(Writable[], InetSocketAddress[], Class, UserGroupInformation)} instead
* @deprecated Use {@link #call(Writable[], InetSocketAddress[],
* Class, UserGroupInformation, Configuration)} instead
*/
@Deprecated
public Writable[] call(Writable[] params, InetSocketAddress[] addresses)
throws IOException, InterruptedException {
return call(params, addresses, null, null);
return call(params, addresses, null, null, conf);
}
/**
* @deprecated Use {@link #call(Writable[], InetSocketAddress[],
* Class, UserGroupInformation, Configuration)} instead
*/
@Deprecated
public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
Class<?> protocol, UserGroupInformation ticket)
throws IOException, InterruptedException {
return call(params, addresses, protocol, ticket, conf);
}
/** Makes a set of calls in parallel. Each parameter is sent to the
* corresponding address. When all values are available, or have timed out
* or errored, the collected results are returned in an array. The array
* contains nulls for calls that timed out or errored. */
public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
Class<?> protocol, UserGroupInformation ticket)
throws IOException, InterruptedException {
public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
Class<?> protocol, UserGroupInformation ticket, Configuration conf)
throws IOException, InterruptedException {
if (addresses.length == 0) return new Writable[0];
ParallelResults results = new ParallelResults(params.length);
@ -1064,8 +1103,9 @@ public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
for (int i = 0; i < params.length; i++) {
ParallelCall call = new ParallelCall(params[i], results, i);
try {
Connection connection =
getConnection(addresses[i], protocol, ticket, 0, call);
ConnectionId remoteId = ConnectionId.getConnectionId(addresses[i],
protocol, ticket, 0, conf);
Connection connection = getConnection(remoteId, call);
connection.sendParam(call); // send each parameter
} catch (IOException e) {
// log errors
@ -1084,12 +1124,18 @@ public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
}
}
// for unit testing only
@InterfaceAudience.Private
@InterfaceStability.Unstable
Set<ConnectionId> getConnectionIds() {
synchronized (connections) {
return connections.keySet();
}
}
/** Get a connection from the pool, or create a new one and add it to the
* pool. Connections to a given host/port are reused. */
private Connection getConnection(InetSocketAddress addr,
Class<?> protocol,
UserGroupInformation ticket,
int rpcTimeout,
* pool. Connections to a given ConnectionId are reused. */
private Connection getConnection(ConnectionId remoteId,
Call call)
throws IOException, InterruptedException {
if (!running.get()) {
@ -1101,8 +1147,6 @@ private Connection getConnection(InetSocketAddress addr,
* connectionsId object and with set() method. We need to manage the
* refs for keys in HashMap properly. For now its ok.
*/
ConnectionId remoteId = new ConnectionId(
addr, protocol, ticket, rpcTimeout);
do {
synchronized (connections) {
connection = connections.get(remoteId);
@ -1120,24 +1164,40 @@ private Connection getConnection(InetSocketAddress addr,
connection.setupIOstreams();
return connection;
}
/**
* This class holds the address and the user ticket. The client connections
* to servers are uniquely identified by <remoteAddress, protocol, ticket>
*/
private static class ConnectionId {
static class ConnectionId {
InetSocketAddress address;
UserGroupInformation ticket;
Class<?> protocol;
private static final int PRIME = 16777619;
private int rpcTimeout;
private String serverPrincipal;
private int maxIdleTime; //connections will be culled if it was idle for
//maxIdleTime msecs
private int maxRetries; //the max. no. of retries for socket connections
private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
private boolean doPing; //do we need to send ping message
private int pingInterval; // how often sends ping to the server in msecs
ConnectionId(InetSocketAddress address, Class<?> protocol,
UserGroupInformation ticket, int rpcTimeout) {
UserGroupInformation ticket, int rpcTimeout,
String serverPrincipal, int maxIdleTime,
int maxRetries, boolean tcpNoDelay,
boolean doPing, int pingInterval) {
this.protocol = protocol;
this.address = address;
this.ticket = ticket;
this.rpcTimeout = rpcTimeout;
this.serverPrincipal = serverPrincipal;
this.maxIdleTime = maxIdleTime;
this.maxRetries = maxRetries;
this.tcpNoDelay = tcpNoDelay;
this.doPing = doPing;
this.pingInterval = pingInterval;
}
InetSocketAddress getAddress() {
@ -1156,25 +1216,102 @@ private int getRpcTimeout() {
return rpcTimeout;
}
@Override
public boolean equals(Object obj) {
if (obj instanceof ConnectionId) {
ConnectionId id = (ConnectionId) obj;
return address.equals(id.address) && protocol == id.protocol &&
((ticket != null && ticket.equals(id.ticket)) ||
(ticket == id.ticket)) && rpcTimeout == id.rpcTimeout;
}
return false;
String getServerPrincipal() {
return serverPrincipal;
}
@Override // simply use the default Object#hashcode() ?
int getMaxIdleTime() {
return maxIdleTime;
}
int getMaxRetries() {
return maxRetries;
}
boolean getTcpNoDelay() {
return tcpNoDelay;
}
boolean getDoPing() {
return doPing;
}
int getPingInterval() {
return pingInterval;
}
static ConnectionId getConnectionId(InetSocketAddress addr,
Class<?> protocol, UserGroupInformation ticket, int rpcTimeout,
Configuration conf) throws IOException {
String remotePrincipal = getRemotePrincipal(conf, addr, protocol);
return new ConnectionId(addr, protocol, ticket,
rpcTimeout, remotePrincipal,
conf.getInt("ipc.client.connection.maxidletime", 10000), // 10s
conf.getInt("ipc.client.connect.max.retries", 10),
conf.getBoolean("ipc.client.tcpnodelay", false),
conf.getBoolean("ipc.client.ping", true),
Client.getPingInterval(conf));
}
private static String getRemotePrincipal(Configuration conf,
InetSocketAddress address, Class<?> protocol) throws IOException {
if (protocol == null) {
return null;
}
KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
if (krbInfo != null) {
String serverKey = krbInfo.serverPrincipal();
if (serverKey == null) {
throw new IOException(
"Can't obtain server Kerberos config key from protocol="
+ protocol.getCanonicalName());
}
return SecurityUtil.getServerPrincipal(conf.get(serverKey), address
.getAddress().getCanonicalHostName());
}
return null;
}
static boolean isEqual(Object a, Object b) {
return a == null ? b == null : a.equals(b);
}
@Override
public boolean equals(Object obj) {
if (obj == this) {
return true;
}
if (obj instanceof ConnectionId) {
ConnectionId that = (ConnectionId) obj;
return isEqual(this.address, that.address)
&& this.doPing == that.doPing
&& this.maxIdleTime == that.maxIdleTime
&& this.maxRetries == that.maxRetries
&& this.pingInterval == that.pingInterval
&& isEqual(this.protocol, that.protocol)
&& this.rpcTimeout == that.rpcTimeout
&& isEqual(this.serverPrincipal, that.serverPrincipal)
&& this.tcpNoDelay == that.tcpNoDelay
&& isEqual(this.ticket, that.ticket);
}
return false;
}
@Override
public int hashCode() {
return (address.hashCode() + PRIME * (
PRIME * (
PRIME * System.identityHashCode(protocol) ^
System.identityHashCode(ticket)
) ^ System.identityHashCode(rpcTimeout)
));
int result = 1;
result = PRIME * result + ((address == null) ? 0 : address.hashCode());
result = PRIME * result + (doPing ? 1231 : 1237);
result = PRIME * result + maxIdleTime;
result = PRIME * result + maxRetries;
result = PRIME * result + pingInterval;
result = PRIME * result + ((protocol == null) ? 0 : protocol.hashCode());
result = PRIME * result + rpcTimeout;
result = PRIME * result
+ ((serverPrincipal == null) ? 0 : serverPrincipal.hashCode());
result = PRIME * result + (tcpNoDelay ? 1231 : 1237);
result = PRIME * result + ((ticket == null) ? 0 : ticket.hashCode());
return result;
}
}
}

View File

@ -38,6 +38,8 @@
import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.*;
import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
@ -172,21 +174,16 @@ private void stopClient(Client client) {
private static ClientCache CLIENTS=new ClientCache();
private static class Invoker implements InvocationHandler {
private Class<?> protocol;
private InetSocketAddress address;
private UserGroupInformation ticket;
private int rpcTimeout;
private Client.ConnectionId remoteId;
private Client client;
private boolean isClosed = false;
public Invoker(Class<?> protocol,
InetSocketAddress address, UserGroupInformation ticket,
Configuration conf, SocketFactory factory,
int rpcTimeout) {
this.protocol = protocol;
this.address = address;
this.ticket = ticket;
this.rpcTimeout = rpcTimeout;
int rpcTimeout) throws IOException {
this.remoteId = Client.ConnectionId.getConnectionId(address, protocol,
ticket, rpcTimeout, conf);
this.client = CLIENTS.getClient(conf, factory);
}
@ -198,8 +195,7 @@ public Object invoke(Object proxy, Method method, Object[] args)
}
ObjectWritable value = (ObjectWritable)
client.call(new Invocation(method, args), address,
protocol, ticket, rpcTimeout);
client.call(new Invocation(method, args), remoteId);
if (LOG.isDebugEnabled()) {
long callTime = System.currentTimeMillis() - startTime;
LOG.debug("Call: " + method.getName() + " " + callTime);
@ -216,6 +212,13 @@ synchronized private void close() {
}
}
// for unit testing only
@InterfaceAudience.Private
@InterfaceStability.Unstable
static Client getClient(Configuration conf) {
return CLIENTS.getClient(conf);
}
/** Construct a client-side proxy object that implements the named protocol,
* talking to a server at the named address. */
public Object getProxy(Class<?> protocol, long clientVersion,
@ -259,7 +262,7 @@ public Object[] call(Method method, Object[][] params,
Client client = CLIENTS.getClient(conf);
try {
Writable[] wrappedValues =
client.call(invocations, addrs, method.getDeclaringClass(), ticket);
client.call(invocations, addrs, method.getDeclaringClass(), ticket, conf);
if (method.getReturnType() == Void.TYPE) {
return null;

View File

@ -94,7 +94,7 @@ public void run() {
try {
LongWritable param = new LongWritable(RANDOM.nextLong());
LongWritable value =
(LongWritable)client.call(param, server, null, null, 0);
(LongWritable)client.call(param, server, null, null, 0, conf);
if (!param.equals(value)) {
LOG.fatal("Call failed!");
failed = true;
@ -127,7 +127,7 @@ public void run() {
Writable[] params = new Writable[addresses.length];
for (int j = 0; j < addresses.length; j++)
params[j] = new LongWritable(RANDOM.nextLong());
Writable[] values = client.call(params, addresses, null, null);
Writable[] values = client.call(params, addresses, null, null, conf);
for (int j = 0; j < addresses.length; j++) {
if (!params[j].equals(values[j])) {
LOG.fatal("Call failed!");
@ -223,7 +223,7 @@ public void testStandAloneClient() throws Exception {
InetSocketAddress address = new InetSocketAddress("127.0.0.1", 10);
try {
client.call(new LongWritable(RANDOM.nextLong()),
address, null, null, 0);
address, null, null, 0, conf);
fail("Expected an exception to have been thrown");
} catch (IOException e) {
String message = e.getMessage();
@ -280,7 +280,7 @@ public void testErrorClient() throws Exception {
Client client = new Client(LongErrorWritable.class, conf);
try {
client.call(new LongErrorWritable(RANDOM.nextLong()),
addr, null, null, 0);
addr, null, null, 0, conf);
fail("Expected an exception to have been thrown");
} catch (IOException e) {
// check error
@ -300,7 +300,7 @@ public void testRuntimeExceptionWritable() throws Exception {
Client client = new Client(LongRTEWritable.class, conf);
try {
client.call(new LongRTEWritable(RANDOM.nextLong()),
addr, null, null, 0);
addr, null, null, 0, conf);
fail("Expected an exception to have been thrown");
} catch (IOException e) {
// check error
@ -326,7 +326,7 @@ public void testSocketFactoryException() throws Exception {
InetSocketAddress address = new InetSocketAddress("127.0.0.1", 10);
try {
client.call(new LongWritable(RANDOM.nextLong()),
address, null, null, 0);
address, null, null, 0, conf);
fail("Expected an exception to have been thrown");
} catch (IOException e) {
assertTrue(e.getMessage().contains("Injected fault"));
@ -344,14 +344,14 @@ public void testIpcTimeout() throws Exception {
// set timeout to be less than MIN_SLEEP_TIME
try {
client.call(new LongWritable(RANDOM.nextLong()),
addr, null, null, MIN_SLEEP_TIME/2);
addr, null, null, MIN_SLEEP_TIME/2, conf);
fail("Expected an exception to have been thrown");
} catch (SocketTimeoutException e) {
LOG.info("Get a SocketTimeoutException ", e);
}
// set timeout to be bigger than 3*ping interval
client.call(new LongWritable(RANDOM.nextLong()),
addr, null, null, 3*PING_INTERVAL+MIN_SLEEP_TIME);
addr, null, null, 3*PING_INTERVAL+MIN_SLEEP_TIME, conf);
}
public static void main(String[] args) throws Exception {

View File

@ -27,6 +27,7 @@
import java.net.InetSocketAddress;
import java.security.PrivilegedExceptionAction;
import java.util.Collection;
import java.util.Set;
import javax.security.sasl.Sasl;
@ -37,6 +38,7 @@
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.ipc.Client.ConnectionId;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.token.SecretManager;
@ -66,6 +68,9 @@ public class TestSaslRPC {
static final String ERROR_MESSAGE = "Token is invalid";
static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal";
static final String SERVER_KEYTAB_KEY = "test.ipc.server.keytab";
static final String SERVER_PRINCIPAL_1 = "p1/foo@BAR";
static final String SERVER_PRINCIPAL_2 = "p2/foo@BAR";
private static Configuration conf;
static {
conf = new Configuration();
@ -249,6 +254,63 @@ private void doDigestRpc(Server server, TestTokenSecretManager sm)
}
}
@Test
public void testPerConnectionConf() throws Exception {
TestTokenSecretManager sm = new TestTokenSecretManager();
final Server server = RPC.getServer(TestSaslProtocol.class,
new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);
server.start();
final UserGroupInformation current = UserGroupInformation.getCurrentUser();
final InetSocketAddress addr = NetUtils.getConnectAddress(server);
TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current
.getUserName()));
Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId,
sm);
Text host = new Text(addr.getAddress().getHostAddress() + ":"
+ addr.getPort());
token.setService(host);
LOG.info("Service IP address for token is " + host);
current.addToken(token);
Configuration newConf = new Configuration(conf);
newConf.set("hadoop.rpc.socket.factory.class.default", "");
newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);
TestSaslProtocol proxy1 = null;
TestSaslProtocol proxy2 = null;
TestSaslProtocol proxy3 = null;
try {
proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
TestSaslProtocol.versionID, addr, newConf);
Client client = WritableRpcEngine.getClient(conf);
Set<ConnectionId> conns = client.getConnectionIds();
assertEquals("number of connections in cache is wrong", 1, conns.size());
// same conf, connection should be re-used
proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
TestSaslProtocol.versionID, addr, newConf);
assertEquals("number of connections in cache is wrong", 1, conns.size());
// different conf, new connection should be set up
newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2);
proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
TestSaslProtocol.versionID, addr, newConf);
ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]);
assertEquals("number of connections in cache is wrong", 2,
connsArray.length);
String p1 = connsArray[0].getServerPrincipal();
String p2 = connsArray[1].getServerPrincipal();
assertFalse("should have different principals", p1.equals(p2));
assertTrue("principal not as expected", p1.equals(SERVER_PRINCIPAL_1)
|| p1.equals(SERVER_PRINCIPAL_2));
assertTrue("principal not as expected", p2.equals(SERVER_PRINCIPAL_1)
|| p2.equals(SERVER_PRINCIPAL_2));
} finally {
server.stop();
RPC.stopProxy(proxy1);
RPC.stopProxy(proxy2);
RPC.stopProxy(proxy3);
}
}
static void testKerberosRpc(String principal, String keytab) throws Exception {
final Configuration newConf = new Configuration(conf);
newConf.set(SERVER_PRINCIPAL_KEY, principal);