HADOOP-18324. Interrupting RPC Client calls can lead to thread exhaustion. (#4527)

* Exactly 1 sending thread per an RPC connection.
* If the calling thread is interrupted before the socket write, it will be skipped instead of sending it anyways.
* If the calling thread is interrupted during the socket write, the write will finish.
* RPC requests will be written to the socket in the order received.
* Sending thread is only started by the receiving thread.
* The sending thread periodically checks the shouldCloseConnection flag.
This commit is contained in:
Owen O'Malley 2022-11-18 16:24:45 +00:00 committed by GitHub
parent 7d39abd799
commit 1ea5db52dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 255 additions and 131 deletions

View File

@ -18,10 +18,10 @@
package org.apache.hadoop.ipc;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.security.AccessControlException;
import org.apache.hadoop.classification.VisibleForTesting;
import org.apache.hadoop.util.Preconditions;
import org.apache.hadoop.thirdparty.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceAudience.Public;
import org.apache.hadoop.classification.InterfaceStability;
@ -166,73 +166,6 @@ public static Object getExternalHandler() {
private final int maxAsyncCalls;
private final AtomicInteger asyncCallCounter = new AtomicInteger(0);
/**
* Executor on which IPC calls' parameters are sent.
* Deferring the sending of parameters to a separate
* thread isolates them from thread interruptions in the
* calling code.
*/
private final ExecutorService sendParamsExecutor;
private final static ClientExecutorServiceFactory clientExcecutorFactory =
new ClientExecutorServiceFactory();
private static class ClientExecutorServiceFactory {
private int executorRefCount = 0;
private ExecutorService clientExecutor = null;
/**
* Get Executor on which IPC calls' parameters are sent.
* If the internal reference counter is zero, this method
* creates the instance of Executor. If not, this method
* just returns the reference of clientExecutor.
*
* @return An ExecutorService instance
*/
synchronized ExecutorService refAndGetInstance() {
if (executorRefCount == 0) {
clientExecutor = Executors.newCachedThreadPool(
new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("IPC Parameter Sending Thread #%d")
.build());
}
executorRefCount++;
return clientExecutor;
}
/**
* Cleanup Executor on which IPC calls' parameters are sent.
* If reference counter is zero, this method discards the
* instance of the Executor. If not, this method
* just decrements the internal reference counter.
*
* @return An ExecutorService instance if it exists.
* Null is returned if not.
*/
synchronized ExecutorService unrefAndCleanup() {
executorRefCount--;
assert(executorRefCount >= 0);
if (executorRefCount == 0) {
clientExecutor.shutdown();
try {
if (!clientExecutor.awaitTermination(1, TimeUnit.MINUTES)) {
clientExecutor.shutdownNow();
}
} catch (InterruptedException e) {
LOG.warn("Interrupted while waiting for clientExecutor" +
" to stop");
clientExecutor.shutdownNow();
Thread.currentThread().interrupt();
}
clientExecutor = null;
}
return clientExecutor;
}
}
/**
* set the ping interval value in configuration
*
@ -301,11 +234,6 @@ public static final void setConnectTimeout(Configuration conf, int timeout) {
conf.setInt(CommonConfigurationKeys.IPC_CLIENT_CONNECT_TIMEOUT_KEY, timeout);
}
@VisibleForTesting
public static final ExecutorService getClientExecutor() {
return Client.clientExcecutorFactory.clientExecutor;
}
/**
* Increment this client's reference count
*/
@ -462,8 +390,10 @@ private class Connection extends Thread {
private AtomicLong lastActivity = new AtomicLong();// last I/O activity time
private AtomicBoolean shouldCloseConnection = new AtomicBoolean(); // indicate if the connection is closed
private IOException closeException; // close reason
private final Object sendRpcRequestLock = new Object();
private final Thread rpcRequestThread;
private final SynchronousQueue<Pair<Call, ResponseBuffer>> rpcRequestQueue =
new SynchronousQueue<>(true);
private AtomicReference<Thread> connectingThread = new AtomicReference<>();
private final Consumer<Connection> removeMethod;
@ -472,6 +402,9 @@ private class Connection extends Thread {
Consumer<Connection> removeMethod) {
this.remoteId = remoteId;
this.server = remoteId.getAddress();
this.rpcRequestThread = new Thread(new RpcRequestSender(),
"IPC Parameter Sending Thread for " + remoteId);
this.rpcRequestThread.setDaemon(true);
this.maxResponseLength = remoteId.conf.getInt(
CommonConfigurationKeys.IPC_MAXIMUM_RESPONSE_LENGTH,
@ -1150,6 +1083,10 @@ private synchronized void sendPing() throws IOException {
@Override
public void run() {
// Don't start the ipc parameter sending thread until we start this
// thread, because the shutdown logic only gets triggered if this
// thread is started.
rpcRequestThread.start();
if (LOG.isDebugEnabled())
LOG.debug(getName() + ": starting, having connections "
+ connections.size());
@ -1173,9 +1110,52 @@ public void run() {
+ connections.size());
}
/**
* A thread to write rpc requests to the socket.
*/
private class RpcRequestSender implements Runnable {
@Override
public void run() {
while (!shouldCloseConnection.get()) {
ResponseBuffer buf = null;
try {
Pair<Call, ResponseBuffer> pair =
rpcRequestQueue.poll(maxIdleTime, TimeUnit.MILLISECONDS);
if (pair == null || shouldCloseConnection.get()) {
continue;
}
buf = pair.getRight();
synchronized (ipcStreams.out) {
if (LOG.isDebugEnabled()) {
Call call = pair.getLeft();
LOG.debug(getName() + "{} sending #{} {}", getName(), call.id,
call.rpcRequest);
}
// RpcRequestHeader + RpcRequest
ipcStreams.sendRequest(buf.toByteArray());
ipcStreams.flush();
}
} catch (InterruptedException ie) {
// stop this thread
return;
} catch (IOException e) {
// exception at this point would leave the connection in an
// unrecoverable state (eg half a call left on the wire).
// So, close the connection, killing any outstanding calls
markClosed(e);
} finally {
//the buffer is just an in-memory buffer, but it is still polite to
// close early
IOUtils.closeStream(buf);
}
}
}
}
/** Initiates a rpc call by sending the rpc request to the remote server.
* Note: this is not called from the Connection thread, but by other
* threads.
* Note: this is not called from the current thread, but by another
* thread, so that if the current thread is interrupted that the socket
* state isn't corrupted with a partially written message.
* @param call - the rpc request
*/
public void sendRpcRequest(final Call call)
@ -1185,8 +1165,7 @@ public void sendRpcRequest(final Call call)
}
// Serialize the call to be sent. This is done from the actual
// caller thread, rather than the sendParamsExecutor thread,
// caller thread, rather than the rpcRequestThread in the connection,
// so that if the serialization throws an error, it is reported
// properly. This also parallelizes the serialization.
//
@ -1203,51 +1182,7 @@ public void sendRpcRequest(final Call call)
final ResponseBuffer buf = new ResponseBuffer();
header.writeDelimitedTo(buf);
RpcWritable.wrap(call.rpcRequest).writeTo(buf);
synchronized (sendRpcRequestLock) {
Future<?> senderFuture = sendParamsExecutor.submit(new Runnable() {
@Override
public void run() {
try {
synchronized (ipcStreams.out) {
if (shouldCloseConnection.get()) {
return;
}
if (LOG.isDebugEnabled()) {
LOG.debug(getName() + " sending #" + call.id
+ " " + call.rpcRequest);
}
// RpcRequestHeader + RpcRequest
ipcStreams.sendRequest(buf.toByteArray());
ipcStreams.flush();
}
} catch (IOException e) {
// exception at this point would leave the connection in an
// unrecoverable state (eg half a call left on the wire).
// So, close the connection, killing any outstanding calls
markClosed(e);
} finally {
//the buffer is just an in-memory buffer, but it is still polite to
// close early
IOUtils.closeStream(buf);
}
}
});
try {
senderFuture.get();
} catch (ExecutionException e) {
Throwable cause = e.getCause();
// cause should only be a RuntimeException as the Runnable above
// catches IOException
if (cause instanceof RuntimeException) {
throw (RuntimeException) cause;
} else {
throw new RuntimeException("unexpected checked exception", cause);
}
}
}
rpcRequestQueue.put(Pair.of(call, buf));
}
/* Receive a response.
@ -1396,7 +1331,6 @@ public Client(Class<? extends Writable> valueClass, Configuration conf,
CommonConfigurationKeys.IPC_CLIENT_BIND_WILDCARD_ADDR_DEFAULT);
this.clientId = ClientId.getClientId();
this.sendParamsExecutor = clientExcecutorFactory.refAndGetInstance();
this.maxAsyncCalls = conf.getInt(
CommonConfigurationKeys.IPC_CLIENT_ASYNC_CALLS_MAX_KEY,
CommonConfigurationKeys.IPC_CLIENT_ASYNC_CALLS_MAX_DEFAULT);
@ -1440,6 +1374,7 @@ public void stop() {
// wake up all connections
for (Connection conn : connections.values()) {
conn.interrupt();
conn.rpcRequestThread.interrupt();
conn.interruptConnectingThread();
}
@ -1456,7 +1391,6 @@ public void stop() {
}
}
}
clientExcecutorFactory.unrefAndCleanup();
}
/**

View File

@ -1216,11 +1216,6 @@ public void testSocketLeak() throws IOException {
@Test(timeout=30000)
public void testInterrupted() {
Client client = new Client(LongWritable.class, conf);
Client.getClientExecutor().submit(new Runnable() {
public void run() {
while(true);
}
});
Thread.currentThread().interrupt();
client.stop();
try {

View File

@ -55,6 +55,7 @@
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.event.Level;
@ -62,13 +63,16 @@
import javax.net.SocketFactory;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.ConnectException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.security.PrivilegedAction;
@ -89,6 +93,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import static org.assertj.core.api.Assertions.assertThat;
import static org.apache.hadoop.test.MetricsAsserts.assertCounter;
@ -993,6 +998,196 @@ public void run() {
}
}
/**
* This tests the case where the server isn't receiving new data and
* multiple threads queue up to send rpc requests. Only one of the requests
* should be written and all of the calling threads should be interrupted.
*
* We use a mock SocketFactory so that we can control when the input and
* output streams are frozen.
*/
@Test(timeout=30000)
public void testSlowConnection() throws Exception {
SocketFactory mockFactory = Mockito.mock(SocketFactory.class);
Socket mockSocket = Mockito.mock(Socket.class);
Mockito.when(mockFactory.createSocket()).thenReturn(mockSocket);
Mockito.when(mockSocket.getPort()).thenReturn(1234);
Mockito.when(mockSocket.getLocalPort()).thenReturn(2345);
MockOutputStream mockOutputStream = new MockOutputStream();
Mockito.when(mockSocket.getOutputStream()).thenReturn(mockOutputStream);
// Use an input stream that always blocks
Mockito.when(mockSocket.getInputStream()).thenReturn(new InputStream() {
@Override
public int read() throws IOException {
// wait forever
while (true) {
try {
Thread.sleep(TimeUnit.DAYS.toMillis(1));
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new InterruptedIOException("test");
}
}
}
});
Configuration clientConf = new Configuration();
// disable ping & timeout to minimize traffic
clientConf.setBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, false);
clientConf.setInt(CommonConfigurationKeys.IPC_CLIENT_RPC_TIMEOUT_KEY, 0);
RPC.setProtocolEngine(clientConf, TestRpcService.class, ProtobufRpcEngine.class);
// set async mode so that we don't need to implement the input stream
final boolean wasAsync = Client.isAsynchronousMode();
TestRpcService client = null;
try {
Client.setAsynchronousMode(true);
client = RPC.getProtocolProxy(
TestRpcService.class,
0,
new InetSocketAddress("localhost", 1234),
UserGroupInformation.getCurrentUser(),
clientConf,
mockFactory).getProxy();
// The connection isn't actually made until the first call.
client.ping(null, newEmptyRequest());
mockOutputStream.waitForFlush(1);
final long headerAndFirst = mockOutputStream.getBytesWritten();
client.ping(null, newEmptyRequest());
mockOutputStream.waitForFlush(2);
final long second = mockOutputStream.getBytesWritten() - headerAndFirst;
// pause the writer thread
mockOutputStream.pause();
// create a set of threads to create calls that will back up
ExecutorService pool = Executors.newCachedThreadPool();
Future[] futures = new Future[numThreads];
final AtomicInteger doneThreads = new AtomicInteger(0);
for(int thread = 0; thread < numThreads; ++thread) {
final TestRpcService finalClient = client;
futures[thread] = pool.submit(new Callable<Void>() {
@Override
public Void call() throws Exception {
finalClient.ping(null, newEmptyRequest());
doneThreads.incrementAndGet();
return null;
}
});
}
// wait until the threads have started writing
mockOutputStream.waitForWriters();
// interrupt all the threads
for(int thread=0; thread < numThreads; ++thread) {
assertTrue("cancel thread " + thread,
futures[thread].cancel(true));
}
// wait until all the writers are cancelled
pool.shutdown();
pool.awaitTermination(10, TimeUnit.SECONDS);
mockOutputStream.resume();
// wait for the in flight rpc request to be flushed
mockOutputStream.waitForFlush(3);
// All the threads should have been interrupted
assertEquals(0, doneThreads.get());
// make sure that only one additional rpc request was sent
assertEquals(headerAndFirst + second * 2,
mockOutputStream.getBytesWritten());
} finally {
Client.setAsynchronousMode(wasAsync);
if (client != null) {
RPC.stopProxy(client);
}
}
}
private static final class MockOutputStream extends OutputStream {
private long bytesWritten = 0;
private AtomicInteger flushCount = new AtomicInteger(0);
private ReentrantLock lock = new ReentrantLock(true);
@Override
public synchronized void write(int b) throws IOException {
lock.lock();
bytesWritten += 1;
lock.unlock();
}
@Override
public void flush() {
flushCount.incrementAndGet();
}
public synchronized long getBytesWritten() {
return bytesWritten;
}
public void pause() {
lock.lock();
}
public void resume() {
lock.unlock();
}
private static final int DELAY_MS = 250;
/**
* Wait for the Nth flush, which we assume will happen exactly when the
* Nth RPC request is sent.
* @param flush the total flush count to wait for
* @throws InterruptedException
*/
public void waitForFlush(int flush) throws InterruptedException {
while (flushCount.get() < flush) {
Thread.sleep(DELAY_MS);
}
}
public void waitForWriters() throws InterruptedException {
while (!lock.hasQueuedThreads()) {
Thread.sleep(DELAY_MS);
}
}
}
/**
* This test causes an exception in the RPC connection setup to make
* sure that threads aren't leaked.
*/
@Test(timeout=30000)
public void testBadSetup() throws Exception {
SocketFactory mockFactory = Mockito.mock(SocketFactory.class);
Mockito.when(mockFactory.createSocket())
.thenThrow(new IOException("can't connect"));
Configuration clientConf = new Configuration();
// Set an illegal value to cause an exception in the constructor
clientConf.set(CommonConfigurationKeys.IPC_MAXIMUM_RESPONSE_LENGTH,
"xxx");
RPC.setProtocolEngine(clientConf, TestRpcService.class,
ProtobufRpcEngine.class);
TestRpcService client = null;
int threadCount = Thread.getAllStackTraces().size();
try {
try {
client = RPC.getProtocolProxy(
TestRpcService.class,
0,
new InetSocketAddress("localhost", 1234),
UserGroupInformation.getCurrentUser(),
clientConf,
mockFactory).getProxy();
client.ping(null, newEmptyRequest());
assertTrue("Didn't throw exception!", false);
} catch (ServiceException nfe) {
// ensure no extra threads are running.
assertEquals(threadCount, Thread.getAllStackTraces().size());
} catch (Throwable t) {
assertTrue("wrong exception: " + t, false);
}
} finally {
if (client != null) {
RPC.stopProxy(client);
}
}
}
@Test
public void testConnectionPing() throws Exception {
Server server;