diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java index 534824e204..20fc9efe57 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java @@ -419,7 +419,7 @@ public synchronized Writable getRpcResponse() { * socket: responses may be delivered out of order. */ private class Connection extends Thread { private InetSocketAddress server; // server ip:port - private final ConnectionId remoteId; // connection id + private final ConnectionId remoteId; // connection id private AuthMethod authMethod; // authentication method private AuthProtocol authProtocol; private int serviceClass; @@ -645,6 +645,9 @@ private synchronized boolean updateAddress() throws IOException { LOG.warn("Address change detected. Old: " + server.toString() + " New: " + currentAddr.toString()); server = currentAddr; + // Update the remote address so that reconnections are with the updated address. + // This avoids thrashing. + remoteId.setAddress(currentAddr); UserGroupInformation ticket = remoteId.getTicket(); this.setName("IPC Client (" + socketFactory.hashCode() + ") connection to " + server.toString() + " from " @@ -1700,9 +1703,9 @@ private Connection getConnection(ConnectionId remoteId, @InterfaceAudience.LimitedPrivate({"HDFS", "MapReduce"}) @InterfaceStability.Evolving public static class ConnectionId { - InetSocketAddress address; - UserGroupInformation ticket; - final Class protocol; + private InetSocketAddress address; + private final UserGroupInformation ticket; + private final Class protocol; private static final int PRIME = 16777619; private final int rpcTimeout; private final int maxIdleTime; //connections will be culled if it was idle for @@ -1717,7 +1720,7 @@ public static class ConnectionId { private final int pingInterval; // how often sends ping to the server in msecs private String saslQop; // here for testing private final Configuration conf; // used to get the expected kerberos principal name - + public ConnectionId(InetSocketAddress address, Class protocol, UserGroupInformation ticket, int rpcTimeout, RetryPolicy connectionRetryPolicy, Configuration conf) { @@ -1753,7 +1756,28 @@ public ConnectionId(InetSocketAddress address, Class protocol, InetSocketAddress getAddress() { return address; } - + + /** + * This is used to update the remote address when an address change is detected. This method + * ensures that the {@link #hashCode()} won't change. + * + * @param address the updated address + * @throws IllegalArgumentException if the hostname or port doesn't match + * @see Connection#updateAddress() + */ + void setAddress(InetSocketAddress address) { + if (!Objects.equals(this.address.getHostName(), address.getHostName())) { + throw new IllegalArgumentException("Hostname must match: " + this.address + " vs " + + address); + } + if (this.address.getPort() != address.getPort()) { + throw new IllegalArgumentException("Port must match: " + this.address + " vs " + address); + } + + this.address = address; + } + + Class getProtocol() { return protocol; } @@ -1864,7 +1888,11 @@ && isEqual(this.protocol, that.protocol) @Override public int hashCode() { int result = connectionRetryPolicy.hashCode(); - result = PRIME * result + ((address == null) ? 0 : address.hashCode()); + // We calculate based on the host name and port without the IP address, since the hashCode + // must be stable even if the IP address is updated. + result = PRIME * result + ((address == null || address.getHostName() == null) ? 0 : + address.getHostName().hashCode()); + result = PRIME * result + ((address == null) ? 0 : address.getPort()); result = PRIME * result + (doPing ? 1231 : 1237); result = PRIME * result + maxIdleTime; result = PRIME * result + pingInterval; diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/WritableRpcEngine.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/WritableRpcEngine.java index 2a19ad29a6..3e4ee707d4 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/WritableRpcEngine.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/WritableRpcEngine.java @@ -323,7 +323,7 @@ public ProtocolProxy getProxy(Class protocol, long clientVersion, Client.ConnectionId connId, Configuration conf, SocketFactory factory) throws IOException { return getProxy(protocol, clientVersion, connId.getAddress(), - connId.ticket, conf, factory, connId.getRpcTimeout(), + connId.getTicket(), conf, factory, connId.getRpcTimeout(), connId.getRetryPolicy(), null, null); } diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java index 95ff302103..1e780793a6 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java @@ -18,6 +18,7 @@ package org.apache.hadoop.ipc; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -93,6 +94,7 @@ import org.apache.hadoop.test.LambdaTestUtils; import org.apache.hadoop.test.Whitebox; import org.apache.hadoop.util.StringUtils; +import org.assertj.core.api.Condition; import org.junit.Assert; import org.junit.Assume; import org.junit.Before; @@ -815,6 +817,81 @@ public Void call() throws IOException { } } + /** + * The {@link ConnectionId#hashCode} has to be stable despite updates that occur as the the + * address evolves over time. The {@link ConnectionId} is used as a primary key in maps, so + * its hashCode can't change. + * + * @throws IOException if there is a client or server failure + */ + @Test + public void testStableHashCode() throws IOException { + Server server = new TestServer(5, false); + try { + server.start(); + + // Leave host unresolved to start. Use "localhost" as opposed + // to local IP from NetUtils.getConnectAddress(server) to force + // resolution later + InetSocketAddress unresolvedAddr = InetSocketAddress.createUnresolved( + "localhost", NetUtils.getConnectAddress(server).getPort()); + + // Setup: Create a ConnectionID using an unresolved address, and get it's hashCode to serve + // as a point of comparison. + int rpcTimeout = MIN_SLEEP_TIME * 2; + final ConnectionId remoteId = getConnectionId(unresolvedAddr, rpcTimeout, conf); + int expected = remoteId.hashCode(); + + // Start client + Client.setConnectTimeout(conf, 100); + Client client = new Client(LongWritable.class, conf); + try { + // Test: Call should re-resolve host and succeed + LongWritable param = new LongWritable(RANDOM.nextLong()); + client.call(RPC.RpcKind.RPC_BUILTIN, param, remoteId, + RPC.RPC_SERVICE_CLASS_DEFAULT, null); + int actual = remoteId.hashCode(); + + // Verify: The hashCode should match, although the InetAddress is different since it has + // now been resolved + assertThat(remoteId.getAddress()).isNotEqualTo(unresolvedAddr); + assertThat(remoteId.getAddress().getHostName()).isEqualTo(unresolvedAddr.getHostName()); + assertThat(remoteId.hashCode()).isEqualTo(expected); + + // Test: Call should succeed without having to re-resolve + InetSocketAddress expectedSocketAddress = remoteId.getAddress(); + param = new LongWritable(RANDOM.nextLong()); + client.call(RPC.RpcKind.RPC_BUILTIN, param, remoteId, + RPC.RPC_SERVICE_CLASS_DEFAULT, null); + + // Verify: The same instance of the InetSocketAddress has been used to make the second + // call + assertThat(remoteId.getAddress()).isSameAs(expectedSocketAddress); + + // Verify: The hashCode is protected against updates to the host name + String hostName = InetAddress.getLocalHost().getHostName(); + InetSocketAddress mismatchedHostName = NetUtils.createSocketAddr( + InetAddress.getLocalHost().getHostName(), + remoteId.getAddress().getPort()); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> remoteId.setAddress(mismatchedHostName)) + .withMessageStartingWith("Hostname must match"); + + // Verify: The hashCode is protected against updates to the port + InetSocketAddress mismatchedPort = NetUtils.createSocketAddr( + remoteId.getAddress().getHostName(), + remoteId.getAddress().getPort() + 1); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> remoteId.setAddress(mismatchedPort)) + .withMessageStartingWith("Port must match"); + } finally { + client.stop(); + } + } finally { + server.stop(); + } + } + @Test(timeout=60000) public void testIpcFlakyHostResolution() throws IOException { // start server