diff --git a/CHANGES.txt b/CHANGES.txt index 6455cfc591..0ed77a8566 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -226,6 +226,9 @@ Trunk (unreleased changes) HADOOP-6913. Circular initialization between UserGroupInformation and KerberosName (Kan Zhang via boryas) + HADOOP-6907. Rpc client doesn't use the per-connection conf to figure + out server's Kerberos principal (Kan Zhang via hairong) + Release 0.21.0 - Unreleased INCOMPATIBLE CHANGES diff --git a/src/java/org/apache/hadoop/ipc/Client.java b/src/java/org/apache/hadoop/ipc/Client.java index e65b2e812e..d6a3e17e79 100644 --- a/src/java/org/apache/hadoop/ipc/Client.java +++ b/src/java/org/apache/hadoop/ipc/Client.java @@ -37,6 +37,7 @@ import java.util.Hashtable; import java.util.Iterator; import java.util.Random; +import java.util.Set; import java.util.Map.Entry; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -45,6 +46,8 @@ import org.apache.commons.logging.*; +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.IOUtils; import org.apache.hadoop.io.Text; @@ -80,12 +83,6 @@ public class Client { private int counter; // counter for call ids private AtomicBoolean running = new AtomicBoolean(true); // if client runs final private Configuration conf; - final private int maxIdleTime; //connections will be culled if it was idle for - //maxIdleTime msecs - final private int maxRetries; //the max. no. of retries for socket connections - private boolean tcpNoDelay; // if T then disable Nagle's Algorithm - private int pingInterval; // how often sends ping to the server in msecs - final private boolean doPing; //do we need to send ping message private SocketFactory socketFactory; // how to create sockets private int refCount = 1; @@ -220,6 +217,12 @@ private class Connection extends Thread { private DataInputStream in; private DataOutputStream out; private int rpcTimeout; + private int maxIdleTime; //connections will be culled if it was idle for + //maxIdleTime msecs + private int maxRetries; //the max. no. of retries for socket connections + private boolean tcpNoDelay; // if T then disable Nagle's Algorithm + private boolean doPing; //do we need to send ping message + private int pingInterval; // how often sends ping to the server in msecs // currently active calls private Hashtable calls = new Hashtable(); @@ -235,6 +238,15 @@ public Connection(ConnectionId remoteId) throws IOException { remoteId.getAddress().getHostName()); } this.rpcTimeout = remoteId.getRpcTimeout(); + this.maxIdleTime = remoteId.getMaxIdleTime(); + this.maxRetries = remoteId.getMaxRetries(); + this.tcpNoDelay = remoteId.getTcpNoDelay(); + this.doPing = remoteId.getDoPing(); + this.pingInterval = remoteId.getPingInterval(); + if (LOG.isDebugEnabled()) { + LOG.debug("The ping interval is" + this.pingInterval + "ms."); + } + UserGroupInformation ticket = remoteId.getTicket(); Class protocol = remoteId.getProtocol(); this.useSasl = UserGroupInformation.isSecurityEnabled(); @@ -256,15 +268,9 @@ public Connection(ConnectionId remoteId) throws IOException { } KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class); if (krbInfo != null) { - String serverKey = krbInfo.serverPrincipal(); - if (serverKey == null) { - throw new IOException( - "Can't obtain server Kerberos config key from KerberosInfo"); - } - serverPrincipal = SecurityUtil.getServerPrincipal( - conf.get(serverKey), server.getAddress().getCanonicalHostName()); + serverPrincipal = remoteId.getServerPrincipal(); if (LOG.isDebugEnabled()) { - LOG.debug("RPC Server Kerberos principal name for protocol=" + LOG.debug("RPC Server's Kerberos principal name for protocol=" + protocol.getCanonicalName() + " is " + serverPrincipal); } } @@ -882,15 +888,6 @@ public synchronized void callComplete(ParallelCall call) { public Client(Class valueClass, Configuration conf, SocketFactory factory) { this.valueClass = valueClass; - this.maxIdleTime = - conf.getInt("ipc.client.connection.maxidletime", 10000); //10s - this.maxRetries = conf.getInt("ipc.client.connect.max.retries", 10); - this.tcpNoDelay = conf.getBoolean("ipc.client.tcpnodelay", false); - this.doPing = conf.getBoolean("ipc.client.ping", true); - this.pingInterval = getPingInterval(conf); - if (LOG.isDebugEnabled()) { - LOG.debug("The ping interval is" + this.pingInterval + "ms."); - } this.conf = conf; this.socketFactory = factory; } @@ -942,7 +939,7 @@ public void stop() { /** Make a call, passing param, to the IPC server running at * address, returning the value. Throws exceptions if there are * network problems or if the remote code threw an exception. - * @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation, int)} instead + * @deprecated Use {@link #call(Writable, ConnectionId)} instead */ @Deprecated public Writable call(Writable param, InetSocketAddress address) @@ -955,27 +952,60 @@ public Writable call(Writable param, InetSocketAddress address) * the value. * Throws exceptions if there are network problems or if the remote code * threw an exception. - * @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation, int)} instead + * @deprecated Use {@link #call(Writable, ConnectionId)} instead */ @Deprecated public Writable call(Writable param, InetSocketAddress addr, UserGroupInformation ticket) throws InterruptedException, IOException { - return call(param, addr, null, ticket, 0); + ConnectionId remoteId = ConnectionId.getConnectionId(addr, null, ticket, 0, + conf); + return call(param, remoteId); } /** Make a call, passing param, to the IPC server running at * address which is servicing the protocol protocol, - * with the ticket credentials, returning the value. + * with the ticket credentials and rpcTimeout as + * timeout, returning the value. * Throws exceptions if there are network problems or if the remote code - * threw an exception. */ + * threw an exception. + * @deprecated Use {@link #call(Writable, ConnectionId)} instead + */ + @Deprecated public Writable call(Writable param, InetSocketAddress addr, Class protocol, UserGroupInformation ticket, int rpcTimeout) throws InterruptedException, IOException { + ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol, + ticket, rpcTimeout, conf); + return call(param, remoteId); + } + + /** + * Make a call, passing param, to the IPC server running at + * address which is servicing the protocol protocol, + * with the ticket credentials, rpcTimeout as + * timeout and conf as conf for this connection, returning the + * value. Throws exceptions if there are network problems or if the remote + * code threw an exception. + */ + public Writable call(Writable param, InetSocketAddress addr, + Class protocol, UserGroupInformation ticket, + int rpcTimeout, Configuration conf) + throws InterruptedException, IOException { + ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol, + ticket, rpcTimeout, conf); + return call(param, remoteId); + } + + /** Make a call, passing param, to the IPC server defined by + * remoteId, returning the value. + * Throws exceptions if there are network problems or if the remote code + * threw an exception. */ + public Writable call(Writable param, ConnectionId remoteId) + throws InterruptedException, IOException { Call call = new Call(param); - Connection connection = getConnection( - addr, protocol, ticket, rpcTimeout, call); + Connection connection = getConnection(remoteId, call); connection.sendParam(call); // send the parameter boolean interrupted = false; synchronized (call) { @@ -998,7 +1028,7 @@ public Writable call(Writable param, InetSocketAddress addr, call.error.fillInStackTrace(); throw call.error; } else { // local exception - throw wrapException(addr, call.error); + throw wrapException(remoteId.getAddress(), call.error); } } else { return call.value; @@ -1038,25 +1068,34 @@ private IOException wrapException(InetSocketAddress addr, } /** - * Makes a set of calls in parallel. Each parameter is sent to the - * corresponding address. When all values are available, or have timed out - * or errored, the collected results are returned in an array. The array - * contains nulls for calls that timed out or errored. - * @deprecated Use {@link #call(Writable[], InetSocketAddress[], Class, UserGroupInformation)} instead + * @deprecated Use {@link #call(Writable[], InetSocketAddress[], + * Class, UserGroupInformation, Configuration)} instead */ @Deprecated public Writable[] call(Writable[] params, InetSocketAddress[] addresses) throws IOException, InterruptedException { - return call(params, addresses, null, null); + return call(params, addresses, null, null, conf); } + /** + * @deprecated Use {@link #call(Writable[], InetSocketAddress[], + * Class, UserGroupInformation, Configuration)} instead + */ + @Deprecated + public Writable[] call(Writable[] params, InetSocketAddress[] addresses, + Class protocol, UserGroupInformation ticket) + throws IOException, InterruptedException { + return call(params, addresses, protocol, ticket, conf); + } + + /** Makes a set of calls in parallel. Each parameter is sent to the * corresponding address. When all values are available, or have timed out * or errored, the collected results are returned in an array. The array * contains nulls for calls that timed out or errored. */ - public Writable[] call(Writable[] params, InetSocketAddress[] addresses, - Class protocol, UserGroupInformation ticket) - throws IOException, InterruptedException { + public Writable[] call(Writable[] params, InetSocketAddress[] addresses, + Class protocol, UserGroupInformation ticket, Configuration conf) + throws IOException, InterruptedException { if (addresses.length == 0) return new Writable[0]; ParallelResults results = new ParallelResults(params.length); @@ -1064,8 +1103,9 @@ public Writable[] call(Writable[] params, InetSocketAddress[] addresses, for (int i = 0; i < params.length; i++) { ParallelCall call = new ParallelCall(params[i], results, i); try { - Connection connection = - getConnection(addresses[i], protocol, ticket, 0, call); + ConnectionId remoteId = ConnectionId.getConnectionId(addresses[i], + protocol, ticket, 0, conf); + Connection connection = getConnection(remoteId, call); connection.sendParam(call); // send each parameter } catch (IOException e) { // log errors @@ -1084,12 +1124,18 @@ public Writable[] call(Writable[] params, InetSocketAddress[] addresses, } } + // for unit testing only + @InterfaceAudience.Private + @InterfaceStability.Unstable + Set getConnectionIds() { + synchronized (connections) { + return connections.keySet(); + } + } + /** Get a connection from the pool, or create a new one and add it to the - * pool. Connections to a given host/port are reused. */ - private Connection getConnection(InetSocketAddress addr, - Class protocol, - UserGroupInformation ticket, - int rpcTimeout, + * pool. Connections to a given ConnectionId are reused. */ + private Connection getConnection(ConnectionId remoteId, Call call) throws IOException, InterruptedException { if (!running.get()) { @@ -1101,8 +1147,6 @@ private Connection getConnection(InetSocketAddress addr, * connectionsId object and with set() method. We need to manage the * refs for keys in HashMap properly. For now its ok. */ - ConnectionId remoteId = new ConnectionId( - addr, protocol, ticket, rpcTimeout); do { synchronized (connections) { connection = connections.get(remoteId); @@ -1120,24 +1164,40 @@ private Connection getConnection(InetSocketAddress addr, connection.setupIOstreams(); return connection; } - + /** * This class holds the address and the user ticket. The client connections * to servers are uniquely identified by */ - private static class ConnectionId { + static class ConnectionId { InetSocketAddress address; UserGroupInformation ticket; Class protocol; private static final int PRIME = 16777619; private int rpcTimeout; + private String serverPrincipal; + private int maxIdleTime; //connections will be culled if it was idle for + //maxIdleTime msecs + private int maxRetries; //the max. no. of retries for socket connections + private boolean tcpNoDelay; // if T then disable Nagle's Algorithm + private boolean doPing; //do we need to send ping message + private int pingInterval; // how often sends ping to the server in msecs ConnectionId(InetSocketAddress address, Class protocol, - UserGroupInformation ticket, int rpcTimeout) { + UserGroupInformation ticket, int rpcTimeout, + String serverPrincipal, int maxIdleTime, + int maxRetries, boolean tcpNoDelay, + boolean doPing, int pingInterval) { this.protocol = protocol; this.address = address; this.ticket = ticket; this.rpcTimeout = rpcTimeout; + this.serverPrincipal = serverPrincipal; + this.maxIdleTime = maxIdleTime; + this.maxRetries = maxRetries; + this.tcpNoDelay = tcpNoDelay; + this.doPing = doPing; + this.pingInterval = pingInterval; } InetSocketAddress getAddress() { @@ -1156,25 +1216,102 @@ private int getRpcTimeout() { return rpcTimeout; } - @Override - public boolean equals(Object obj) { - if (obj instanceof ConnectionId) { - ConnectionId id = (ConnectionId) obj; - return address.equals(id.address) && protocol == id.protocol && - ((ticket != null && ticket.equals(id.ticket)) || - (ticket == id.ticket)) && rpcTimeout == id.rpcTimeout; - } - return false; + String getServerPrincipal() { + return serverPrincipal; } - @Override // simply use the default Object#hashcode() ? + int getMaxIdleTime() { + return maxIdleTime; + } + + int getMaxRetries() { + return maxRetries; + } + + boolean getTcpNoDelay() { + return tcpNoDelay; + } + + boolean getDoPing() { + return doPing; + } + + int getPingInterval() { + return pingInterval; + } + + static ConnectionId getConnectionId(InetSocketAddress addr, + Class protocol, UserGroupInformation ticket, int rpcTimeout, + Configuration conf) throws IOException { + String remotePrincipal = getRemotePrincipal(conf, addr, protocol); + return new ConnectionId(addr, protocol, ticket, + rpcTimeout, remotePrincipal, + conf.getInt("ipc.client.connection.maxidletime", 10000), // 10s + conf.getInt("ipc.client.connect.max.retries", 10), + conf.getBoolean("ipc.client.tcpnodelay", false), + conf.getBoolean("ipc.client.ping", true), + Client.getPingInterval(conf)); + } + + private static String getRemotePrincipal(Configuration conf, + InetSocketAddress address, Class protocol) throws IOException { + if (protocol == null) { + return null; + } + KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class); + if (krbInfo != null) { + String serverKey = krbInfo.serverPrincipal(); + if (serverKey == null) { + throw new IOException( + "Can't obtain server Kerberos config key from protocol=" + + protocol.getCanonicalName()); + } + return SecurityUtil.getServerPrincipal(conf.get(serverKey), address + .getAddress().getCanonicalHostName()); + } + return null; + } + + static boolean isEqual(Object a, Object b) { + return a == null ? b == null : a.equals(b); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj instanceof ConnectionId) { + ConnectionId that = (ConnectionId) obj; + return isEqual(this.address, that.address) + && this.doPing == that.doPing + && this.maxIdleTime == that.maxIdleTime + && this.maxRetries == that.maxRetries + && this.pingInterval == that.pingInterval + && isEqual(this.protocol, that.protocol) + && this.rpcTimeout == that.rpcTimeout + && isEqual(this.serverPrincipal, that.serverPrincipal) + && this.tcpNoDelay == that.tcpNoDelay + && isEqual(this.ticket, that.ticket); + } + return false; + } + + @Override public int hashCode() { - return (address.hashCode() + PRIME * ( - PRIME * ( - PRIME * System.identityHashCode(protocol) ^ - System.identityHashCode(ticket) - ) ^ System.identityHashCode(rpcTimeout) - )); + int result = 1; + result = PRIME * result + ((address == null) ? 0 : address.hashCode()); + result = PRIME * result + (doPing ? 1231 : 1237); + result = PRIME * result + maxIdleTime; + result = PRIME * result + maxRetries; + result = PRIME * result + pingInterval; + result = PRIME * result + ((protocol == null) ? 0 : protocol.hashCode()); + result = PRIME * result + rpcTimeout; + result = PRIME * result + + ((serverPrincipal == null) ? 0 : serverPrincipal.hashCode()); + result = PRIME * result + (tcpNoDelay ? 1231 : 1237); + result = PRIME * result + ((ticket == null) ? 0 : ticket.hashCode()); + return result; } } } diff --git a/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java b/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java index 48b23acbb6..0b86bc9540 100644 --- a/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java +++ b/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java @@ -38,6 +38,8 @@ import org.apache.hadoop.security.authorize.ServiceAuthorizationManager; import org.apache.hadoop.security.token.SecretManager; import org.apache.hadoop.security.token.TokenIdentifier; +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; import org.apache.hadoop.conf.*; import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate; @@ -172,21 +174,16 @@ private void stopClient(Client client) { private static ClientCache CLIENTS=new ClientCache(); private static class Invoker implements InvocationHandler { - private Class protocol; - private InetSocketAddress address; - private UserGroupInformation ticket; - private int rpcTimeout; + private Client.ConnectionId remoteId; private Client client; private boolean isClosed = false; public Invoker(Class protocol, InetSocketAddress address, UserGroupInformation ticket, Configuration conf, SocketFactory factory, - int rpcTimeout) { - this.protocol = protocol; - this.address = address; - this.ticket = ticket; - this.rpcTimeout = rpcTimeout; + int rpcTimeout) throws IOException { + this.remoteId = Client.ConnectionId.getConnectionId(address, protocol, + ticket, rpcTimeout, conf); this.client = CLIENTS.getClient(conf, factory); } @@ -198,8 +195,7 @@ public Object invoke(Object proxy, Method method, Object[] args) } ObjectWritable value = (ObjectWritable) - client.call(new Invocation(method, args), address, - protocol, ticket, rpcTimeout); + client.call(new Invocation(method, args), remoteId); if (LOG.isDebugEnabled()) { long callTime = System.currentTimeMillis() - startTime; LOG.debug("Call: " + method.getName() + " " + callTime); @@ -216,6 +212,13 @@ synchronized private void close() { } } + // for unit testing only + @InterfaceAudience.Private + @InterfaceStability.Unstable + static Client getClient(Configuration conf) { + return CLIENTS.getClient(conf); + } + /** Construct a client-side proxy object that implements the named protocol, * talking to a server at the named address. */ public Object getProxy(Class protocol, long clientVersion, @@ -259,7 +262,7 @@ public Object[] call(Method method, Object[][] params, Client client = CLIENTS.getClient(conf); try { Writable[] wrappedValues = - client.call(invocations, addrs, method.getDeclaringClass(), ticket); + client.call(invocations, addrs, method.getDeclaringClass(), ticket, conf); if (method.getReturnType() == Void.TYPE) { return null; diff --git a/src/test/core/org/apache/hadoop/ipc/TestIPC.java b/src/test/core/org/apache/hadoop/ipc/TestIPC.java index 4ef3f204a0..9844595cde 100644 --- a/src/test/core/org/apache/hadoop/ipc/TestIPC.java +++ b/src/test/core/org/apache/hadoop/ipc/TestIPC.java @@ -94,7 +94,7 @@ public void run() { try { LongWritable param = new LongWritable(RANDOM.nextLong()); LongWritable value = - (LongWritable)client.call(param, server, null, null, 0); + (LongWritable)client.call(param, server, null, null, 0, conf); if (!param.equals(value)) { LOG.fatal("Call failed!"); failed = true; @@ -127,7 +127,7 @@ public void run() { Writable[] params = new Writable[addresses.length]; for (int j = 0; j < addresses.length; j++) params[j] = new LongWritable(RANDOM.nextLong()); - Writable[] values = client.call(params, addresses, null, null); + Writable[] values = client.call(params, addresses, null, null, conf); for (int j = 0; j < addresses.length; j++) { if (!params[j].equals(values[j])) { LOG.fatal("Call failed!"); @@ -223,7 +223,7 @@ public void testStandAloneClient() throws Exception { InetSocketAddress address = new InetSocketAddress("127.0.0.1", 10); try { client.call(new LongWritable(RANDOM.nextLong()), - address, null, null, 0); + address, null, null, 0, conf); fail("Expected an exception to have been thrown"); } catch (IOException e) { String message = e.getMessage(); @@ -280,7 +280,7 @@ public void testErrorClient() throws Exception { Client client = new Client(LongErrorWritable.class, conf); try { client.call(new LongErrorWritable(RANDOM.nextLong()), - addr, null, null, 0); + addr, null, null, 0, conf); fail("Expected an exception to have been thrown"); } catch (IOException e) { // check error @@ -300,7 +300,7 @@ public void testRuntimeExceptionWritable() throws Exception { Client client = new Client(LongRTEWritable.class, conf); try { client.call(new LongRTEWritable(RANDOM.nextLong()), - addr, null, null, 0); + addr, null, null, 0, conf); fail("Expected an exception to have been thrown"); } catch (IOException e) { // check error @@ -326,7 +326,7 @@ public void testSocketFactoryException() throws Exception { InetSocketAddress address = new InetSocketAddress("127.0.0.1", 10); try { client.call(new LongWritable(RANDOM.nextLong()), - address, null, null, 0); + address, null, null, 0, conf); fail("Expected an exception to have been thrown"); } catch (IOException e) { assertTrue(e.getMessage().contains("Injected fault")); @@ -344,14 +344,14 @@ public void testIpcTimeout() throws Exception { // set timeout to be less than MIN_SLEEP_TIME try { client.call(new LongWritable(RANDOM.nextLong()), - addr, null, null, MIN_SLEEP_TIME/2); + addr, null, null, MIN_SLEEP_TIME/2, conf); fail("Expected an exception to have been thrown"); } catch (SocketTimeoutException e) { LOG.info("Get a SocketTimeoutException ", e); } // set timeout to be bigger than 3*ping interval client.call(new LongWritable(RANDOM.nextLong()), - addr, null, null, 3*PING_INTERVAL+MIN_SLEEP_TIME); + addr, null, null, 3*PING_INTERVAL+MIN_SLEEP_TIME, conf); } public static void main(String[] args) throws Exception { diff --git a/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java b/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java index bb34babc78..d2cf76f27e 100644 --- a/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java +++ b/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java @@ -27,6 +27,7 @@ import java.net.InetSocketAddress; import java.security.PrivilegedExceptionAction; import java.util.Collection; +import java.util.Set; import javax.security.sasl.Sasl; @@ -37,6 +38,7 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.Text; +import org.apache.hadoop.ipc.Client.ConnectionId; import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.security.KerberosInfo; import org.apache.hadoop.security.token.SecretManager; @@ -66,6 +68,9 @@ public class TestSaslRPC { static final String ERROR_MESSAGE = "Token is invalid"; static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal"; static final String SERVER_KEYTAB_KEY = "test.ipc.server.keytab"; + static final String SERVER_PRINCIPAL_1 = "p1/foo@BAR"; + static final String SERVER_PRINCIPAL_2 = "p2/foo@BAR"; + private static Configuration conf; static { conf = new Configuration(); @@ -249,6 +254,63 @@ private void doDigestRpc(Server server, TestTokenSecretManager sm) } } + @Test + public void testPerConnectionConf() throws Exception { + TestTokenSecretManager sm = new TestTokenSecretManager(); + final Server server = RPC.getServer(TestSaslProtocol.class, + new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm); + server.start(); + final UserGroupInformation current = UserGroupInformation.getCurrentUser(); + final InetSocketAddress addr = NetUtils.getConnectAddress(server); + TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current + .getUserName())); + Token token = new Token(tokenId, + sm); + Text host = new Text(addr.getAddress().getHostAddress() + ":" + + addr.getPort()); + token.setService(host); + LOG.info("Service IP address for token is " + host); + current.addToken(token); + + Configuration newConf = new Configuration(conf); + newConf.set("hadoop.rpc.socket.factory.class.default", ""); + newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1); + + TestSaslProtocol proxy1 = null; + TestSaslProtocol proxy2 = null; + TestSaslProtocol proxy3 = null; + try { + proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, + TestSaslProtocol.versionID, addr, newConf); + Client client = WritableRpcEngine.getClient(conf); + Set conns = client.getConnectionIds(); + assertEquals("number of connections in cache is wrong", 1, conns.size()); + // same conf, connection should be re-used + proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, + TestSaslProtocol.versionID, addr, newConf); + assertEquals("number of connections in cache is wrong", 1, conns.size()); + // different conf, new connection should be set up + newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2); + proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, + TestSaslProtocol.versionID, addr, newConf); + ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]); + assertEquals("number of connections in cache is wrong", 2, + connsArray.length); + String p1 = connsArray[0].getServerPrincipal(); + String p2 = connsArray[1].getServerPrincipal(); + assertFalse("should have different principals", p1.equals(p2)); + assertTrue("principal not as expected", p1.equals(SERVER_PRINCIPAL_1) + || p1.equals(SERVER_PRINCIPAL_2)); + assertTrue("principal not as expected", p2.equals(SERVER_PRINCIPAL_1) + || p2.equals(SERVER_PRINCIPAL_2)); + } finally { + server.stop(); + RPC.stopProxy(proxy1); + RPC.stopProxy(proxy2); + RPC.stopProxy(proxy3); + } + } + static void testKerberosRpc(String principal, String keytab) throws Exception { final Configuration newConf = new Configuration(conf); newConf.set(SERVER_PRINCIPAL_KEY, principal);