From a5db6831bc674a24a3251cf1b20f22a4fd4fac9f Mon Sep 17 00:00:00 2001 From: liangxs Date: Tue, 6 Jul 2021 09:11:03 +0800 Subject: [PATCH] HADOOP-17749. Remove lock contention in SelectorPool of SocketIOWithTimeout (#3080) --- .../hadoop/net/SocketIOWithTimeout.java | 103 ++++++++---------- .../hadoop/net/TestSocketIOWithTimeout.java | 79 ++++++++++++++ 2 files changed, 124 insertions(+), 58 deletions(-) diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/net/SocketIOWithTimeout.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/net/SocketIOWithTimeout.java index 312a481f25..d117bb8a6b 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/net/SocketIOWithTimeout.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/net/SocketIOWithTimeout.java @@ -28,8 +28,9 @@ import java.nio.channels.Selector; import java.nio.channels.SocketChannel; import java.nio.channels.spi.SelectorProvider; -import java.util.Iterator; -import java.util.LinkedList; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.hadoop.util.Time; import org.slf4j.Logger; @@ -48,8 +49,6 @@ abstract class SocketIOWithTimeout { private long timeout; private boolean closed = false; - private static SelectorPool selector = new SelectorPool(); - /* A timeout value of 0 implies wait for ever. * We should have a value of timeout that implies zero wait.. i.e. * read or write returns immediately. @@ -154,7 +153,7 @@ int doIO(ByteBuffer buf, int ops) throws IOException { //now wait for socket to be ready. int count = 0; try { - count = selector.select(channel, ops, timeout); + count = SelectorPool.select(channel, ops, timeout); } catch (IOException e) { //unexpected IOException. closed = true; throw e; @@ -200,7 +199,7 @@ static void connect(SocketChannel channel, // we might have to call finishConnect() more than once // for some channels (with user level protocols) - int ret = selector.select((SelectableChannel)channel, + int ret = SelectorPool.select(channel, SelectionKey.OP_CONNECT, timeoutLeft); if (ret > 0 && channel.finishConnect()) { @@ -242,7 +241,7 @@ static void connect(SocketChannel channel, */ void waitForIO(int ops) throws IOException { - if (selector.select(channel, ops, timeout) == 0) { + if (SelectorPool.select(channel, ops, timeout) == 0) { throw new SocketTimeoutException(timeoutExceptionString(channel, timeout, ops)); } @@ -280,12 +279,17 @@ private static String timeoutExceptionString(SelectableChannel channel, * This maintains a pool of selectors. These selectors are closed * once they are idle (unused) for a few seconds. */ - private static class SelectorPool { + private static final class SelectorPool { - private static class SelectorInfo { - Selector selector; - long lastActivityTime; - LinkedList queue; + private static final class SelectorInfo { + private final SelectorProvider provider; + private final Selector selector; + private long lastActivityTime; + + private SelectorInfo(SelectorProvider provider, Selector selector) { + this.provider = provider; + this.selector = selector; + } void close() { if (selector != null) { @@ -298,16 +302,11 @@ void close() { } } - private static class ProviderInfo { - SelectorProvider provider; - LinkedList queue; // lifo - ProviderInfo next; - } + private static ConcurrentHashMap> providerMap = new ConcurrentHashMap<>(); private static final long IDLE_TIMEOUT = 10 * 1000; // 10 seconds. - private ProviderInfo providerList = null; - /** * Waits on the channel with the given timeout using one of the * cached selectors. It also removes any cached selectors that are @@ -319,7 +318,7 @@ private static class ProviderInfo { * @return * @throws IOException */ - int select(SelectableChannel channel, int ops, long timeout) + static int select(SelectableChannel channel, int ops, long timeout) throws IOException { SelectorInfo info = get(channel); @@ -385,35 +384,18 @@ int select(SelectableChannel channel, int ops, long timeout) * @return * @throws IOException */ - private synchronized SelectorInfo get(SelectableChannel channel) + private static SelectorInfo get(SelectableChannel channel) throws IOException { - SelectorInfo selInfo = null; - SelectorProvider provider = channel.provider(); - // pick the list : rarely there is more than one provider in use. - ProviderInfo pList = providerList; - while (pList != null && pList.provider != provider) { - pList = pList.next; - } - if (pList == null) { - //LOG.info("Creating new ProviderInfo : " + provider.toString()); - pList = new ProviderInfo(); - pList.provider = provider; - pList.queue = new LinkedList(); - pList.next = providerList; - providerList = pList; - } - - LinkedList queue = pList.queue; - - if (queue.isEmpty()) { + ConcurrentLinkedDeque infoQ = providerMap.computeIfAbsent( + provider, k -> new ConcurrentLinkedDeque<>()); + + SelectorInfo selInfo = infoQ.pollLast(); // last in first out + if (selInfo == null) { Selector selector = provider.openSelector(); - selInfo = new SelectorInfo(); - selInfo.selector = selector; - selInfo.queue = queue; - } else { - selInfo = queue.removeLast(); + // selInfo will be put into infoQ after `#release()` + selInfo = new SelectorInfo(provider, selector); } trimIdleSelectors(Time.now()); @@ -426,34 +408,39 @@ private synchronized SelectorInfo get(SelectableChannel channel) * * @param info */ - private synchronized void release(SelectorInfo info) { + private static void release(SelectorInfo info) { long now = Time.now(); trimIdleSelectors(now); info.lastActivityTime = now; - info.queue.addLast(info); + // SelectorInfos in queue are sorted by lastActivityTime + providerMap.get(info.provider).addLast(info); } + private static AtomicBoolean trimming = new AtomicBoolean(false); + /** * Closes selectors that are idle for IDLE_TIMEOUT (10 sec). It does not * traverse the whole list, just over the one that have crossed * the timeout. */ - private void trimIdleSelectors(long now) { + private static void trimIdleSelectors(long now) { + if (!trimming.compareAndSet(false, true)) { + return; + } + long cutoff = now - IDLE_TIMEOUT; - - for(ProviderInfo pList=providerList; pList != null; pList=pList.next) { - if (pList.queue.isEmpty()) { - continue; - } - for(Iterator it = pList.queue.iterator(); it.hasNext();) { - SelectorInfo info = it.next(); - if (info.lastActivityTime > cutoff) { + for (ConcurrentLinkedDeque infoQ : providerMap.values()) { + SelectorInfo oldest; + while ((oldest = infoQ.peekFirst()) != null) { + if (oldest.lastActivityTime <= cutoff && infoQ.remove(oldest)) { + oldest.close(); + } else { break; } - it.remove(); - info.close(); } } + + trimming.set(false); } } } diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/net/TestSocketIOWithTimeout.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/net/TestSocketIOWithTimeout.java index 76c74a37a0..008d842937 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/net/TestSocketIOWithTimeout.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/net/TestSocketIOWithTimeout.java @@ -24,6 +24,11 @@ import java.net.SocketTimeoutException; import java.nio.channels.Pipe; import java.util.Arrays; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import org.apache.hadoop.test.GenericTestUtils; import org.apache.hadoop.test.MultithreadedTestUtil; @@ -186,6 +191,46 @@ public void doWork() throws Exception { } } + @Test + public void testSocketIOWithTimeoutByMultiThread() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + Runnable ioTask = () -> { + try { + Pipe pipe = Pipe.open(); + try (Pipe.SourceChannel source = pipe.source(); + InputStream in = new SocketInputStream(source, TIMEOUT); + Pipe.SinkChannel sink = pipe.sink(); + OutputStream out = new SocketOutputStream(sink, TIMEOUT)) { + + byte[] writeBytes = TEST_STRING.getBytes(); + byte[] readBytes = new byte[writeBytes.length]; + latch.await(); + + out.write(writeBytes); + doIO(null, out, TIMEOUT); + + in.read(readBytes); + assertArrayEquals(writeBytes, readBytes); + doIO(in, null, TIMEOUT); + } + } catch (Exception e) { + fail(e.getMessage()); + } + }; + + int threadCnt = 64; + ExecutorService threadPool = Executors.newFixedThreadPool(threadCnt); + for (int i = 0; i < threadCnt; ++i) { + threadPool.submit(ioTask); + } + + Thread.sleep(1000); + latch.countDown(); + + threadPool.shutdown(); + assertTrue(threadPool.awaitTermination(3, TimeUnit.SECONDS)); + } + @Test public void testSocketIOWithTimeoutInterrupted() throws Exception { Pipe pipe = Pipe.open(); @@ -223,4 +268,38 @@ public void doWork() throws Exception { ctx.stop(); } } + + @Test + public void testSocketIOWithTimeoutInterruptedByMultiThread() + throws Exception { + final int timeout = TIMEOUT * 10; + AtomicLong readCount = new AtomicLong(); + AtomicLong exceptionCount = new AtomicLong(); + Runnable ioTask = () -> { + try { + Pipe pipe = Pipe.open(); + try (Pipe.SourceChannel source = pipe.source(); + InputStream in = new SocketInputStream(source, timeout)) { + in.read(); + readCount.incrementAndGet(); + } catch (InterruptedIOException ste) { + exceptionCount.incrementAndGet(); + } + } catch (Exception e) { + fail(e.getMessage()); + } + }; + + int threadCnt = 64; + ExecutorService threadPool = Executors.newFixedThreadPool(threadCnt); + for (int i = 0; i < threadCnt; ++i) { + threadPool.submit(ioTask); + } + Thread.sleep(1000); + threadPool.shutdownNow(); + threadPool.awaitTermination(1, TimeUnit.SECONDS); + + assertEquals(0, readCount.get()); + assertEquals(threadCnt, exceptionCount.get()); + } }