From 1ea5db52dd930a3b9ba3dbe04f564a38bbc5bf68 Mon Sep 17 00:00:00 2001 From: Owen O'Malley Date: Fri, 18 Nov 2022 16:24:45 +0000 Subject: [PATCH] HADOOP-18324. Interrupting RPC Client calls can lead to thread exhaustion. (#4527) * Exactly 1 sending thread per an RPC connection. * If the calling thread is interrupted before the socket write, it will be skipped instead of sending it anyways. * If the calling thread is interrupted during the socket write, the write will finish. * RPC requests will be written to the socket in the order received. * Sending thread is only started by the receiving thread. * The sending thread periodically checks the shouldCloseConnection flag. --- .../java/org/apache/hadoop/ipc/Client.java | 186 ++++++----------- .../java/org/apache/hadoop/ipc/TestIPC.java | 5 - .../java/org/apache/hadoop/ipc/TestRPC.java | 195 ++++++++++++++++++ 3 files changed, 255 insertions(+), 131 deletions(-) diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java index f0d4f8921a..c43f922477 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java @@ -18,10 +18,10 @@ package org.apache.hadoop.ipc; +import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.security.AccessControlException; import org.apache.hadoop.classification.VisibleForTesting; import org.apache.hadoop.util.Preconditions; -import org.apache.hadoop.thirdparty.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.classification.InterfaceAudience.Public; import org.apache.hadoop.classification.InterfaceStability; @@ -166,73 +166,6 @@ public static Object getExternalHandler() { private final int maxAsyncCalls; private final AtomicInteger asyncCallCounter = new AtomicInteger(0); - /** - * Executor on which IPC calls' parameters are sent. - * Deferring the sending of parameters to a separate - * thread isolates them from thread interruptions in the - * calling code. - */ - private final ExecutorService sendParamsExecutor; - private final static ClientExecutorServiceFactory clientExcecutorFactory = - new ClientExecutorServiceFactory(); - - private static class ClientExecutorServiceFactory { - private int executorRefCount = 0; - private ExecutorService clientExecutor = null; - - /** - * Get Executor on which IPC calls' parameters are sent. - * If the internal reference counter is zero, this method - * creates the instance of Executor. If not, this method - * just returns the reference of clientExecutor. - * - * @return An ExecutorService instance - */ - synchronized ExecutorService refAndGetInstance() { - if (executorRefCount == 0) { - clientExecutor = Executors.newCachedThreadPool( - new ThreadFactoryBuilder() - .setDaemon(true) - .setNameFormat("IPC Parameter Sending Thread #%d") - .build()); - } - executorRefCount++; - - return clientExecutor; - } - - /** - * Cleanup Executor on which IPC calls' parameters are sent. - * If reference counter is zero, this method discards the - * instance of the Executor. If not, this method - * just decrements the internal reference counter. - * - * @return An ExecutorService instance if it exists. - * Null is returned if not. - */ - synchronized ExecutorService unrefAndCleanup() { - executorRefCount--; - assert(executorRefCount >= 0); - - if (executorRefCount == 0) { - clientExecutor.shutdown(); - try { - if (!clientExecutor.awaitTermination(1, TimeUnit.MINUTES)) { - clientExecutor.shutdownNow(); - } - } catch (InterruptedException e) { - LOG.warn("Interrupted while waiting for clientExecutor" + - " to stop"); - clientExecutor.shutdownNow(); - Thread.currentThread().interrupt(); - } - clientExecutor = null; - } - - return clientExecutor; - } - } - /** * set the ping interval value in configuration * @@ -301,11 +234,6 @@ public static final void setConnectTimeout(Configuration conf, int timeout) { conf.setInt(CommonConfigurationKeys.IPC_CLIENT_CONNECT_TIMEOUT_KEY, timeout); } - @VisibleForTesting - public static final ExecutorService getClientExecutor() { - return Client.clientExcecutorFactory.clientExecutor; - } - /** * Increment this client's reference count */ @@ -462,8 +390,10 @@ private class Connection extends Thread { private AtomicLong lastActivity = new AtomicLong();// last I/O activity time private AtomicBoolean shouldCloseConnection = new AtomicBoolean(); // indicate if the connection is closed private IOException closeException; // close reason - - private final Object sendRpcRequestLock = new Object(); + + private final Thread rpcRequestThread; + private final SynchronousQueue> rpcRequestQueue = + new SynchronousQueue<>(true); private AtomicReference connectingThread = new AtomicReference<>(); private final Consumer removeMethod; @@ -472,6 +402,9 @@ private class Connection extends Thread { Consumer removeMethod) { this.remoteId = remoteId; this.server = remoteId.getAddress(); + this.rpcRequestThread = new Thread(new RpcRequestSender(), + "IPC Parameter Sending Thread for " + remoteId); + this.rpcRequestThread.setDaemon(true); this.maxResponseLength = remoteId.conf.getInt( CommonConfigurationKeys.IPC_MAXIMUM_RESPONSE_LENGTH, @@ -1150,6 +1083,10 @@ private synchronized void sendPing() throws IOException { @Override public void run() { + // Don't start the ipc parameter sending thread until we start this + // thread, because the shutdown logic only gets triggered if this + // thread is started. + rpcRequestThread.start(); if (LOG.isDebugEnabled()) LOG.debug(getName() + ": starting, having connections " + connections.size()); @@ -1173,9 +1110,52 @@ public void run() { + connections.size()); } + /** + * A thread to write rpc requests to the socket. + */ + private class RpcRequestSender implements Runnable { + @Override + public void run() { + while (!shouldCloseConnection.get()) { + ResponseBuffer buf = null; + try { + Pair pair = + rpcRequestQueue.poll(maxIdleTime, TimeUnit.MILLISECONDS); + if (pair == null || shouldCloseConnection.get()) { + continue; + } + buf = pair.getRight(); + synchronized (ipcStreams.out) { + if (LOG.isDebugEnabled()) { + Call call = pair.getLeft(); + LOG.debug(getName() + "{} sending #{} {}", getName(), call.id, + call.rpcRequest); + } + // RpcRequestHeader + RpcRequest + ipcStreams.sendRequest(buf.toByteArray()); + ipcStreams.flush(); + } + } catch (InterruptedException ie) { + // stop this thread + return; + } catch (IOException e) { + // exception at this point would leave the connection in an + // unrecoverable state (eg half a call left on the wire). + // So, close the connection, killing any outstanding calls + markClosed(e); + } finally { + //the buffer is just an in-memory buffer, but it is still polite to + // close early + IOUtils.closeStream(buf); + } + } + } + } + /** Initiates a rpc call by sending the rpc request to the remote server. - * Note: this is not called from the Connection thread, but by other - * threads. + * Note: this is not called from the current thread, but by another + * thread, so that if the current thread is interrupted that the socket + * state isn't corrupted with a partially written message. * @param call - the rpc request */ public void sendRpcRequest(final Call call) @@ -1185,8 +1165,7 @@ public void sendRpcRequest(final Call call) } // Serialize the call to be sent. This is done from the actual - // caller thread, rather than the sendParamsExecutor thread, - + // caller thread, rather than the rpcRequestThread in the connection, // so that if the serialization throws an error, it is reported // properly. This also parallelizes the serialization. // @@ -1203,51 +1182,7 @@ public void sendRpcRequest(final Call call) final ResponseBuffer buf = new ResponseBuffer(); header.writeDelimitedTo(buf); RpcWritable.wrap(call.rpcRequest).writeTo(buf); - - synchronized (sendRpcRequestLock) { - Future senderFuture = sendParamsExecutor.submit(new Runnable() { - @Override - public void run() { - try { - synchronized (ipcStreams.out) { - if (shouldCloseConnection.get()) { - return; - } - if (LOG.isDebugEnabled()) { - LOG.debug(getName() + " sending #" + call.id - + " " + call.rpcRequest); - } - // RpcRequestHeader + RpcRequest - ipcStreams.sendRequest(buf.toByteArray()); - ipcStreams.flush(); - } - } catch (IOException e) { - // exception at this point would leave the connection in an - // unrecoverable state (eg half a call left on the wire). - // So, close the connection, killing any outstanding calls - markClosed(e); - } finally { - //the buffer is just an in-memory buffer, but it is still polite to - // close early - IOUtils.closeStream(buf); - } - } - }); - - try { - senderFuture.get(); - } catch (ExecutionException e) { - Throwable cause = e.getCause(); - - // cause should only be a RuntimeException as the Runnable above - // catches IOException - if (cause instanceof RuntimeException) { - throw (RuntimeException) cause; - } else { - throw new RuntimeException("unexpected checked exception", cause); - } - } - } + rpcRequestQueue.put(Pair.of(call, buf)); } /* Receive a response. @@ -1396,7 +1331,6 @@ public Client(Class valueClass, Configuration conf, CommonConfigurationKeys.IPC_CLIENT_BIND_WILDCARD_ADDR_DEFAULT); this.clientId = ClientId.getClientId(); - this.sendParamsExecutor = clientExcecutorFactory.refAndGetInstance(); this.maxAsyncCalls = conf.getInt( CommonConfigurationKeys.IPC_CLIENT_ASYNC_CALLS_MAX_KEY, CommonConfigurationKeys.IPC_CLIENT_ASYNC_CALLS_MAX_DEFAULT); @@ -1440,6 +1374,7 @@ public void stop() { // wake up all connections for (Connection conn : connections.values()) { conn.interrupt(); + conn.rpcRequestThread.interrupt(); conn.interruptConnectingThread(); } @@ -1456,7 +1391,6 @@ public void stop() { } } } - clientExcecutorFactory.unrefAndCleanup(); } /** diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java index ffa17224b0..25c6976549 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java @@ -1216,11 +1216,6 @@ public void testSocketLeak() throws IOException { @Test(timeout=30000) public void testInterrupted() { Client client = new Client(LongWritable.class, conf); - Client.getClientExecutor().submit(new Runnable() { - public void run() { - while(true); - } - }); Thread.currentThread().interrupt(); client.stop(); try { diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestRPC.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestRPC.java index 101750d72c..084a3dbd4a 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestRPC.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestRPC.java @@ -55,6 +55,7 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.mockito.Mockito; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.slf4j.event.Level; @@ -62,13 +63,16 @@ import javax.net.SocketFactory; import java.io.Closeable; import java.io.IOException; +import java.io.InputStream; import java.io.InterruptedIOException; +import java.io.OutputStream; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.net.ConnectException; import java.net.InetAddress; import java.net.InetSocketAddress; +import java.net.Socket; import java.net.SocketTimeoutException; import java.nio.ByteBuffer; import java.security.PrivilegedAction; @@ -89,6 +93,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; import static org.assertj.core.api.Assertions.assertThat; import static org.apache.hadoop.test.MetricsAsserts.assertCounter; @@ -993,6 +998,196 @@ public void run() { } } + /** + * This tests the case where the server isn't receiving new data and + * multiple threads queue up to send rpc requests. Only one of the requests + * should be written and all of the calling threads should be interrupted. + * + * We use a mock SocketFactory so that we can control when the input and + * output streams are frozen. + */ + @Test(timeout=30000) + public void testSlowConnection() throws Exception { + SocketFactory mockFactory = Mockito.mock(SocketFactory.class); + Socket mockSocket = Mockito.mock(Socket.class); + Mockito.when(mockFactory.createSocket()).thenReturn(mockSocket); + Mockito.when(mockSocket.getPort()).thenReturn(1234); + Mockito.when(mockSocket.getLocalPort()).thenReturn(2345); + MockOutputStream mockOutputStream = new MockOutputStream(); + Mockito.when(mockSocket.getOutputStream()).thenReturn(mockOutputStream); + // Use an input stream that always blocks + Mockito.when(mockSocket.getInputStream()).thenReturn(new InputStream() { + @Override + public int read() throws IOException { + // wait forever + while (true) { + try { + Thread.sleep(TimeUnit.DAYS.toMillis(1)); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new InterruptedIOException("test"); + } + } + } + }); + Configuration clientConf = new Configuration(); + // disable ping & timeout to minimize traffic + clientConf.setBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, false); + clientConf.setInt(CommonConfigurationKeys.IPC_CLIENT_RPC_TIMEOUT_KEY, 0); + RPC.setProtocolEngine(clientConf, TestRpcService.class, ProtobufRpcEngine.class); + // set async mode so that we don't need to implement the input stream + final boolean wasAsync = Client.isAsynchronousMode(); + TestRpcService client = null; + try { + Client.setAsynchronousMode(true); + client = RPC.getProtocolProxy( + TestRpcService.class, + 0, + new InetSocketAddress("localhost", 1234), + UserGroupInformation.getCurrentUser(), + clientConf, + mockFactory).getProxy(); + // The connection isn't actually made until the first call. + client.ping(null, newEmptyRequest()); + mockOutputStream.waitForFlush(1); + final long headerAndFirst = mockOutputStream.getBytesWritten(); + client.ping(null, newEmptyRequest()); + mockOutputStream.waitForFlush(2); + final long second = mockOutputStream.getBytesWritten() - headerAndFirst; + // pause the writer thread + mockOutputStream.pause(); + // create a set of threads to create calls that will back up + ExecutorService pool = Executors.newCachedThreadPool(); + Future[] futures = new Future[numThreads]; + final AtomicInteger doneThreads = new AtomicInteger(0); + for(int thread = 0; thread < numThreads; ++thread) { + final TestRpcService finalClient = client; + futures[thread] = pool.submit(new Callable() { + @Override + public Void call() throws Exception { + finalClient.ping(null, newEmptyRequest()); + doneThreads.incrementAndGet(); + return null; + } + }); + } + // wait until the threads have started writing + mockOutputStream.waitForWriters(); + // interrupt all the threads + for(int thread=0; thread < numThreads; ++thread) { + assertTrue("cancel thread " + thread, + futures[thread].cancel(true)); + } + // wait until all the writers are cancelled + pool.shutdown(); + pool.awaitTermination(10, TimeUnit.SECONDS); + mockOutputStream.resume(); + // wait for the in flight rpc request to be flushed + mockOutputStream.waitForFlush(3); + // All the threads should have been interrupted + assertEquals(0, doneThreads.get()); + // make sure that only one additional rpc request was sent + assertEquals(headerAndFirst + second * 2, + mockOutputStream.getBytesWritten()); + } finally { + Client.setAsynchronousMode(wasAsync); + if (client != null) { + RPC.stopProxy(client); + } + } + } + + private static final class MockOutputStream extends OutputStream { + private long bytesWritten = 0; + private AtomicInteger flushCount = new AtomicInteger(0); + private ReentrantLock lock = new ReentrantLock(true); + + @Override + public synchronized void write(int b) throws IOException { + lock.lock(); + bytesWritten += 1; + lock.unlock(); + } + + @Override + public void flush() { + flushCount.incrementAndGet(); + } + + public synchronized long getBytesWritten() { + return bytesWritten; + } + + public void pause() { + lock.lock(); + } + + public void resume() { + lock.unlock(); + } + + private static final int DELAY_MS = 250; + + /** + * Wait for the Nth flush, which we assume will happen exactly when the + * Nth RPC request is sent. + * @param flush the total flush count to wait for + * @throws InterruptedException + */ + public void waitForFlush(int flush) throws InterruptedException { + while (flushCount.get() < flush) { + Thread.sleep(DELAY_MS); + } + } + + public void waitForWriters() throws InterruptedException { + while (!lock.hasQueuedThreads()) { + Thread.sleep(DELAY_MS); + } + } + } + + /** + * This test causes an exception in the RPC connection setup to make + * sure that threads aren't leaked. + */ + @Test(timeout=30000) + public void testBadSetup() throws Exception { + SocketFactory mockFactory = Mockito.mock(SocketFactory.class); + Mockito.when(mockFactory.createSocket()) + .thenThrow(new IOException("can't connect")); + Configuration clientConf = new Configuration(); + // Set an illegal value to cause an exception in the constructor + clientConf.set(CommonConfigurationKeys.IPC_MAXIMUM_RESPONSE_LENGTH, + "xxx"); + RPC.setProtocolEngine(clientConf, TestRpcService.class, + ProtobufRpcEngine.class); + TestRpcService client = null; + int threadCount = Thread.getAllStackTraces().size(); + try { + try { + client = RPC.getProtocolProxy( + TestRpcService.class, + 0, + new InetSocketAddress("localhost", 1234), + UserGroupInformation.getCurrentUser(), + clientConf, + mockFactory).getProxy(); + client.ping(null, newEmptyRequest()); + assertTrue("Didn't throw exception!", false); + } catch (ServiceException nfe) { + // ensure no extra threads are running. + assertEquals(threadCount, Thread.getAllStackTraces().size()); + } catch (Throwable t) { + assertTrue("wrong exception: " + t, false); + } + } finally { + if (client != null) { + RPC.stopProxy(client); + } + } + } + @Test public void testConnectionPing() throws Exception { Server server;