diff --git a/hadoop-client-modules/hadoop-client-runtime/pom.xml b/hadoop-client-modules/hadoop-client-runtime/pom.xml index b2bd7a4fc4..ddafdab9b1 100644 --- a/hadoop-client-modules/hadoop-client-runtime/pom.xml +++ b/hadoop-client-modules/hadoop-client-runtime/pom.xml @@ -148,6 +148,7 @@ com.google.code.findbugs:jsr305 + io.netty:* io.dropwizard.metrics:metrics-core org.eclipse.jetty:jetty-servlet org.eclipse.jetty:jetty-security diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/Fetcher.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/Fetcher.java index 1da5b2f5d3..e013d017b1 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/Fetcher.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/Fetcher.java @@ -53,14 +53,15 @@ import org.apache.hadoop.classification.VisibleForTesting; -class Fetcher extends Thread { +@VisibleForTesting +public class Fetcher extends Thread { private static final Logger LOG = LoggerFactory.getLogger(Fetcher.class); - /** Number of ms before timing out a copy */ + /** Number of ms before timing out a copy. */ private static final int DEFAULT_STALLED_COPY_TIMEOUT = 3 * 60 * 1000; - /** Basic/unit connection timeout (in milliseconds) */ + /** Basic/unit connection timeout (in milliseconds). */ private final static int UNIT_CONNECT_TIMEOUT = 60 * 1000; /* Default read timeout (in milliseconds) */ @@ -72,10 +73,12 @@ class Fetcher extends Thread { private static final String FETCH_RETRY_AFTER_HEADER = "Retry-After"; protected final Reporter reporter; - private enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP, + @VisibleForTesting + public enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP, CONNECTION, WRONG_REDUCE} - - private final static String SHUFFLE_ERR_GRP_NAME = "Shuffle Errors"; + + @VisibleForTesting + public final static String SHUFFLE_ERR_GRP_NAME = "Shuffle Errors"; private final JobConf jobConf; private final Counters.Counter connectionErrs; private final Counters.Counter ioErrs; @@ -83,8 +86,8 @@ private enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP, private final Counters.Counter badIdErrs; private final Counters.Counter wrongMapErrs; private final Counters.Counter wrongReduceErrs; - protected final MergeManager merger; - protected final ShuffleSchedulerImpl scheduler; + protected final MergeManager merger; + protected final ShuffleSchedulerImpl scheduler; protected final ShuffleClientMetrics metrics; protected final ExceptionReporter exceptionReporter; protected final int id; @@ -111,7 +114,7 @@ private enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP, private static SSLFactory sslFactory; public Fetcher(JobConf job, TaskAttemptID reduceId, - ShuffleSchedulerImpl scheduler, MergeManager merger, + ShuffleSchedulerImpl scheduler, MergeManager merger, Reporter reporter, ShuffleClientMetrics metrics, ExceptionReporter exceptionReporter, SecretKey shuffleKey) { this(job, reduceId, scheduler, merger, reporter, metrics, @@ -120,7 +123,7 @@ public Fetcher(JobConf job, TaskAttemptID reduceId, @VisibleForTesting Fetcher(JobConf job, TaskAttemptID reduceId, - ShuffleSchedulerImpl scheduler, MergeManager merger, + ShuffleSchedulerImpl scheduler, 
MergeManager merger, Reporter reporter, ShuffleClientMetrics metrics, ExceptionReporter exceptionReporter, SecretKey shuffleKey, int id) { @@ -315,9 +318,8 @@ protected void copyFromHost(MapHost host) throws IOException { return; } - if(LOG.isDebugEnabled()) { - LOG.debug("Fetcher " + id + " going to fetch from " + host + " for: " - + maps); + if (LOG.isDebugEnabled()) { + LOG.debug("Fetcher " + id + " going to fetch from " + host + " for: " + maps); } // List of maps to be fetched yet @@ -411,8 +413,8 @@ private void openConnectionWithRetry(URL url) throws IOException { shouldWait = false; } catch (IOException e) { if (!fetchRetryEnabled) { - // throw exception directly if fetch's retry is not enabled - throw e; + // throw exception directly if fetch's retry is not enabled + throw e; } if ((Time.monotonicNow() - startTime) >= this.fetchRetryTimeout) { LOG.warn("Failed to connect to host: " + url + "after " @@ -489,7 +491,7 @@ private TaskAttemptID[] copyMapOutput(MapHost host, DataInputStream input, Set remaining, boolean canRetry) throws IOException { - MapOutput mapOutput = null; + MapOutput mapOutput = null; TaskAttemptID mapId = null; long decompressedLength = -1; long compressedLength = -1; @@ -611,7 +613,7 @@ private void checkTimeoutOrRetry(MapHost host, IOException ioe) // First time to retry. long currentTime = Time.monotonicNow(); if (retryStartTime == 0) { - retryStartTime = currentTime; + retryStartTime = currentTime; } // Retry is not timeout, let's do retry with throwing an exception. @@ -628,7 +630,7 @@ private void checkTimeoutOrRetry(MapHost host, IOException ioe) } /** - * Do some basic verification on the input received -- Being defensive + * Do some basic verification on the input received -- Being defensive. * @param compressedLength * @param decompressedLength * @param forReduce @@ -695,8 +697,7 @@ private URL getMapOutputURL(MapHost host, Collection maps * only on the last failure. Instead of connecting with a timeout of * X, we try connecting with a timeout of x < X but multiple times. 
*/ - private void connect(URLConnection connection, int connectionTimeout) - throws IOException { + private void connect(URLConnection connection, int connectionTimeout) throws IOException { int unit = 0; if (connectionTimeout < 0) { throw new IOException("Invalid timeout " diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/test/java/org/apache/hadoop/mapred/TestReduceFetchFromPartialMem.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/test/java/org/apache/hadoop/mapred/TestReduceFetchFromPartialMem.java index 9b04f64ac6..1b99ce0c0a 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/test/java/org/apache/hadoop/mapred/TestReduceFetchFromPartialMem.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/test/java/org/apache/hadoop/mapred/TestReduceFetchFromPartialMem.java @@ -26,6 +26,7 @@ import org.apache.hadoop.io.Text; import org.apache.hadoop.io.WritableComparator; import org.apache.hadoop.mapreduce.TaskCounter; +import org.apache.hadoop.mapreduce.task.reduce.Fetcher; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -37,6 +38,7 @@ import java.util.Formatter; import java.util.Iterator; +import static org.apache.hadoop.mapreduce.task.reduce.Fetcher.SHUFFLE_ERR_GRP_NAME; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -87,6 +89,9 @@ public void testReduceFromPartialMem() throws Exception { final long spill = c.findCounter(TaskCounter.SPILLED_RECORDS).getCounter(); assertTrue("Expected some records not spilled during reduce" + spill + ")", spill < 2 * out); // spilled map records, some records at the reduce + long shuffleIoErrors = + c.getGroup(SHUFFLE_ERR_GRP_NAME).getCounter(Fetcher.ShuffleErrors.IO_ERROR.toString()); + assertEquals(0, shuffleIoErrors); } /** diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/FadvisedChunkedFile.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/FadvisedChunkedFile.java index 99d4a4cb42..1f009a4919 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/FadvisedChunkedFile.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/FadvisedChunkedFile.java @@ -23,6 +23,9 @@ import java.io.RandomAccessFile; import org.apache.hadoop.classification.VisibleForTesting; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.handler.stream.ChunkedFile; import org.apache.hadoop.io.ReadaheadPool; import org.apache.hadoop.io.ReadaheadPool.ReadaheadRequest; import org.apache.hadoop.io.nativeio.NativeIO; @@ -31,8 +34,6 @@ import static org.apache.hadoop.io.nativeio.NativeIO.POSIX.POSIX_FADV_DONTNEED; -import org.jboss.netty.handler.stream.ChunkedFile; - public class FadvisedChunkedFile extends ChunkedFile { private static final Logger LOG = @@ -64,16 +65,16 @@ FileDescriptor getFd() { } @Override - public Object nextChunk() throws Exception { + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { synchronized (closeLock) { if (fd.valid()) { if (manageOsCache && readaheadPool != null) { readaheadRequest = readaheadPool 
.readaheadStream( - identifier, fd, getCurrentOffset(), readaheadLength, - getEndOffset(), readaheadRequest); + identifier, fd, currentOffset(), readaheadLength, + endOffset(), readaheadRequest); } - return super.nextChunk(); + return super.readChunk(allocator); } else { return null; } @@ -88,12 +89,12 @@ public void close() throws Exception { readaheadRequest = null; } if (fd.valid() && - manageOsCache && getEndOffset() - getStartOffset() > 0) { + manageOsCache && endOffset() - startOffset() > 0) { try { NativeIO.POSIX.getCacheManipulator().posixFadviseIfPossible( identifier, fd, - getStartOffset(), getEndOffset() - getStartOffset(), + startOffset(), endOffset() - startOffset(), POSIX_FADV_DONTNEED); } catch (Throwable t) { LOG.warn("Failed to manage OS cache for " + identifier + diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/FadvisedFileRegion.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/FadvisedFileRegion.java index 1d3f162c90..9290a282e3 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/FadvisedFileRegion.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/FadvisedFileRegion.java @@ -25,6 +25,7 @@ import java.nio.channels.FileChannel; import java.nio.channels.WritableByteChannel; +import io.netty.channel.DefaultFileRegion; import org.apache.hadoop.io.ReadaheadPool; import org.apache.hadoop.io.ReadaheadPool.ReadaheadRequest; import org.apache.hadoop.io.nativeio.NativeIO; @@ -33,8 +34,6 @@ import static org.apache.hadoop.io.nativeio.NativeIO.POSIX.POSIX_FADV_DONTNEED; -import org.jboss.netty.channel.DefaultFileRegion; - import org.apache.hadoop.classification.VisibleForTesting; public class FadvisedFileRegion extends DefaultFileRegion { @@ -77,8 +76,8 @@ public long transferTo(WritableByteChannel target, long position) throws IOException { if (readaheadPool != null && readaheadLength > 0) { readaheadRequest = readaheadPool.readaheadStream(identifier, fd, - getPosition() + position, readaheadLength, - getPosition() + getCount(), readaheadRequest); + position() + position, readaheadLength, + position() + count(), readaheadRequest); } if(this.shuffleTransferToAllowed) { @@ -147,11 +146,11 @@ long customShuffleTransfer(WritableByteChannel target, long position) @Override - public void releaseExternalResources() { + protected void deallocate() { if (readaheadRequest != null) { readaheadRequest.cancel(); } - super.releaseExternalResources(); + super.deallocate(); } /** @@ -159,10 +158,10 @@ public void releaseExternalResources() { * we don't need the region to be cached anymore. 
*/ public void transferSuccessful() { - if (manageOsCache && getCount() > 0) { + if (manageOsCache && count() > 0) { try { NativeIO.POSIX.getCacheManipulator().posixFadviseIfPossible(identifier, - fd, getPosition(), getCount(), POSIX_FADV_DONTNEED); + fd, position(), count(), POSIX_FADV_DONTNEED); } catch (Throwable t) { LOG.warn("Failed to manage OS cache for " + identifier, t); } diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/LoggingHttpResponseEncoder.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/LoggingHttpResponseEncoder.java new file mode 100644 index 0000000000..c7b98ce166 --- /dev/null +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/LoggingHttpResponseEncoder.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.mapred; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseEncoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +class LoggingHttpResponseEncoder extends HttpResponseEncoder { + private static final Logger LOG = LoggerFactory.getLogger(LoggingHttpResponseEncoder.class); + private final boolean logStacktraceOfEncodingMethods; + + LoggingHttpResponseEncoder(boolean logStacktraceOfEncodingMethods) { + this.logStacktraceOfEncodingMethods = logStacktraceOfEncodingMethods; + } + + @Override + public boolean acceptOutboundMessage(Object msg) throws Exception { + printExecutingMethod(); + LOG.info("OUTBOUND MESSAGE: " + msg); + return super.acceptOutboundMessage(msg); + } + + @Override + protected void encodeInitialLine(ByteBuf buf, HttpResponse response) throws Exception { + LOG.debug("Executing method: {}, response: {}", + getExecutingMethodName(), response); + logStacktraceIfRequired(); + super.encodeInitialLine(buf, response); + } + + @Override + protected void encode(ChannelHandlerContext ctx, Object msg, + List out) throws Exception { + LOG.debug("Encoding to channel {}: {}", ctx.channel(), msg); + printExecutingMethod(); + logStacktraceIfRequired(); + super.encode(ctx, msg, out); + } + + @Override + protected void encodeHeaders(HttpHeaders headers, ByteBuf buf) { + printExecutingMethod(); + super.encodeHeaders(headers, buf); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise + promise) throws Exception { + LOG.debug("Writing to channel {}: {}", 
ctx.channel(), msg); + printExecutingMethod(); + super.write(ctx, msg, promise); + } + + private void logStacktraceIfRequired() { + if (logStacktraceOfEncodingMethods) { + LOG.debug("Stacktrace: ", new Throwable()); + } + } + + private void printExecutingMethod() { + String methodName = getExecutingMethodName(1); + LOG.debug("Executing method: {}", methodName); + } + + private String getExecutingMethodName() { + return getExecutingMethodName(0); + } + + private String getExecutingMethodName(int additionalSkipFrames) { + try { + StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace(); + // Array items (indices): + // 0: java.lang.Thread.getStackTrace(...) + // 1: LoggingHttpResponseEncoder.getExecutingMethodName(...) + int skipFrames = 2 + additionalSkipFrames; + String methodName = stackTrace[skipFrames].getMethodName(); + String className = this.getClass().getSimpleName(); + return className + "#" + methodName; + } catch (Throwable t) { + LOG.error("Error while getting execution method name", t); + return "unknown"; + } + } +} diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/ShuffleHandler.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/ShuffleHandler.java index 448082f7fe..e4a43f85b9 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/ShuffleHandler.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/ShuffleHandler.java @@ -18,19 +18,20 @@ package org.apache.hadoop.mapred; +import static io.netty.buffer.Unpooled.wrappedBuffer; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE; +import static io.netty.handler.codec.http.HttpMethod.GET; +import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; +import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; +import static io.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR; +import static io.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED; +import static io.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND; +import static io.netty.handler.codec.http.HttpResponseStatus.OK; +import static io.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED; +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; +import static org.apache.hadoop.mapred.ShuffleHandler.NettyChannelHelper.*; import static org.fusesource.leveldbjni.JniDBFactory.asString; import static org.fusesource.leveldbjni.JniDBFactory.bytes; -import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer; -import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE; -import static org.jboss.netty.handler.codec.http.HttpMethod.GET; -import static org.jboss.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; -import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; -import static org.jboss.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR; -import static org.jboss.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED; -import static org.jboss.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND; -import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK; -import static org.jboss.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED; -import static
org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1; import java.io.File; import java.io.FileNotFoundException; @@ -54,6 +55,44 @@ import javax.crypto.SecretKey; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.group.ChannelGroup; +import io.netty.channel.group.DefaultChannelGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpRequestDecoder; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseEncoder; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.handler.codec.http.QueryStringDecoder; +import io.netty.handler.ssl.SslHandler; +import io.netty.handler.stream.ChunkedWriteHandler; +import io.netty.handler.timeout.IdleState; +import io.netty.handler.timeout.IdleStateEvent; +import io.netty.handler.timeout.IdleStateHandler; +import io.netty.util.CharsetUtil; +import io.netty.util.concurrent.DefaultEventExecutorGroup; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.DataInputByteBuffer; @@ -79,7 +118,6 @@ import org.apache.hadoop.security.token.Token; import org.apache.hadoop.util.DiskChecker; import org.apache.hadoop.util.Shell; -import org.apache.hadoop.util.concurrent.HadoopExecutors; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.proto.YarnServerCommonProtos.VersionProto; import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; @@ -94,42 +132,6 @@ import org.iq80.leveldb.DB; import org.iq80.leveldb.DBException; import org.iq80.leveldb.Options; -import org.jboss.netty.bootstrap.ServerBootstrap; -import org.jboss.netty.buffer.ChannelBuffers; -import org.jboss.netty.channel.Channel; -import org.jboss.netty.channel.ChannelFactory; -import org.jboss.netty.channel.ChannelFuture; -import org.jboss.netty.channel.ChannelFutureListener; -import org.jboss.netty.channel.ChannelHandler; -import org.jboss.netty.channel.ChannelHandlerContext; -import org.jboss.netty.channel.ChannelPipeline; -import org.jboss.netty.channel.ChannelPipelineFactory; -import org.jboss.netty.channel.ChannelStateEvent; -import org.jboss.netty.channel.Channels; -import org.jboss.netty.channel.ExceptionEvent; -import org.jboss.netty.channel.MessageEvent; -import org.jboss.netty.channel.SimpleChannelUpstreamHandler; -import org.jboss.netty.channel.group.ChannelGroup; -import org.jboss.netty.channel.group.DefaultChannelGroup; -import 
org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory; -import org.jboss.netty.handler.codec.frame.TooLongFrameException; -import org.jboss.netty.handler.codec.http.DefaultHttpResponse; -import org.jboss.netty.handler.codec.http.HttpChunkAggregator; -import org.jboss.netty.handler.codec.http.HttpRequest; -import org.jboss.netty.handler.codec.http.HttpRequestDecoder; -import org.jboss.netty.handler.codec.http.HttpResponse; -import org.jboss.netty.handler.codec.http.HttpResponseEncoder; -import org.jboss.netty.handler.codec.http.HttpResponseStatus; -import org.jboss.netty.handler.codec.http.QueryStringDecoder; -import org.jboss.netty.handler.ssl.SslHandler; -import org.jboss.netty.handler.stream.ChunkedWriteHandler; -import org.jboss.netty.handler.timeout.IdleState; -import org.jboss.netty.handler.timeout.IdleStateAwareChannelHandler; -import org.jboss.netty.handler.timeout.IdleStateEvent; -import org.jboss.netty.handler.timeout.IdleStateHandler; -import org.jboss.netty.util.CharsetUtil; -import org.jboss.netty.util.HashedWheelTimer; -import org.jboss.netty.util.Timer; import org.eclipse.jetty.http.HttpHeader; import org.slf4j.LoggerFactory; @@ -182,19 +184,29 @@ public class ShuffleHandler extends AuxiliaryService { public static final HttpResponseStatus TOO_MANY_REQ_STATUS = new HttpResponseStatus(429, "TOO MANY REQUESTS"); - // This should kept in sync with Fetcher.FETCH_RETRY_DELAY_DEFAULT + // This should be kept in sync with Fetcher.FETCH_RETRY_DELAY_DEFAULT public static final long FETCH_RETRY_DELAY = 1000L; public static final String RETRY_AFTER_HEADER = "Retry-After"; + static final String ENCODER_HANDLER_NAME = "encoder"; private int port; - private ChannelFactory selector; - private final ChannelGroup accepted = new DefaultChannelGroup(); + private EventLoopGroup bossGroup; + private EventLoopGroup workerGroup; + private ServerBootstrap bootstrap; + private Channel ch; + private final ChannelGroup accepted = + new DefaultChannelGroup(new DefaultEventExecutorGroup(5).next()); + private final AtomicInteger activeConnections = new AtomicInteger(); protected HttpPipelineFactory pipelineFact; private int sslFileBufferSize; + + //TODO snemeth add a config option for these later, this is temporarily disabled for now. + private boolean useOutboundExceptionHandler = false; + private boolean useOutboundLogger = false; /** * Should the shuffle use posix_fadvise calls to manage the OS cache during - * sendfile + * sendfile. 
*/ private boolean manageOsCache; private int readaheadLength; @@ -204,7 +216,7 @@ public class ShuffleHandler extends AuxiliaryService { private int maxSessionOpenFiles; private ReadaheadPool readaheadPool = ReadaheadPool.getInstance(); - private Map userRsrc; + private Map userRsrc; private JobTokenSecretManager secretManager; private DB stateDb = null; @@ -235,7 +247,7 @@ public class ShuffleHandler extends AuxiliaryService { public static final String CONNECTION_CLOSE = "close"; public static final String SUFFLE_SSL_FILE_BUFFER_SIZE_KEY = - "mapreduce.shuffle.ssl.file.buffer.size"; + "mapreduce.shuffle.ssl.file.buffer.size"; public static final int DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE = 60 * 1024; @@ -255,7 +267,7 @@ public class ShuffleHandler extends AuxiliaryService { public static final boolean DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED = true; public static final boolean WINDOWS_DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED = false; - private static final String TIMEOUT_HANDLER = "timeout"; + static final String TIMEOUT_HANDLER = "timeout"; /* the maximum number of files a single GET request can open simultaneously during shuffle @@ -267,7 +279,6 @@ public class ShuffleHandler extends AuxiliaryService { boolean connectionKeepAliveEnabled = false; private int connectionKeepAliveTimeOut; private int mapOutputMetaInfoCacheSize; - private Timer timer; @Metrics(about="Shuffle output metrics", context="mapred") static class ShuffleMetrics implements ChannelFutureListener { @@ -291,6 +302,49 @@ public void operationComplete(ChannelFuture future) throws Exception { } } + static class NettyChannelHelper { + static ChannelFuture writeToChannel(Channel ch, Object obj) { + LOG.debug("Writing {} to channel: {}", obj.getClass().getSimpleName(), ch.id()); + return ch.writeAndFlush(obj); + } + + static ChannelFuture writeToChannelAndClose(Channel ch, Object obj) { + return writeToChannel(ch, obj).addListener(ChannelFutureListener.CLOSE); + } + + static ChannelFuture writeToChannelAndAddLastHttpContent(Channel ch, HttpResponse obj) { + writeToChannel(ch, obj); + return writeLastHttpContentToChannel(ch); + } + + static ChannelFuture writeLastHttpContentToChannel(Channel ch) { + LOG.debug("Writing LastHttpContent, channel id: {}", ch.id()); + return ch.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT); + } + + static ChannelFuture closeChannel(Channel ch) { + LOG.debug("Closing channel, channel id: {}", ch.id()); + return ch.close(); + } + + static void closeChannels(ChannelGroup channelGroup) { + channelGroup.close().awaitUninterruptibly(10, TimeUnit.SECONDS); + } + + public static ChannelFuture closeAsIdle(Channel ch, int timeout) { + LOG.debug("Closing channel as writer was idle for {} seconds", timeout); + return closeChannel(ch); + } + + public static void channelActive(Channel ch) { + LOG.debug("Executing channelActive, channel id: {}", ch.id()); + } + + public static void channelInactive(Channel ch) { + LOG.debug("Executing channelInactive, channel id: {}", ch.id()); + } + } + private final MetricsSystem ms; final ShuffleMetrics metrics; @@ -298,29 +352,36 @@ class ReduceMapFileCount implements ChannelFutureListener { private ReduceContext reduceContext; - public ReduceMapFileCount(ReduceContext rc) { + ReduceMapFileCount(ReduceContext rc) { this.reduceContext = rc; } @Override public void operationComplete(ChannelFuture future) throws Exception { + LOG.trace("operationComplete"); if (!future.isSuccess()) { - future.getChannel().close(); + LOG.error("Future is unsuccessful. 
Cause: ", future.cause()); + closeChannel(future.channel()); return; } int waitCount = this.reduceContext.getMapsToWait().decrementAndGet(); if (waitCount == 0) { + LOG.trace("Finished with all map outputs"); + //HADOOP-15327: Need to send an instance of LastHttpContent to define HTTP + //message boundaries. See details in jira. + writeLastHttpContentToChannel(future.channel()); metrics.operationComplete(future); // Let the idle timer handler close keep-alive connections if (reduceContext.getKeepAlive()) { - ChannelPipeline pipeline = future.getChannel().getPipeline(); + ChannelPipeline pipeline = future.channel().pipeline(); TimeoutHandler timeoutHandler = (TimeoutHandler)pipeline.get(TIMEOUT_HANDLER); timeoutHandler.setEnabledTimeout(true); } else { - future.getChannel().close(); + closeChannel(future.channel()); } } else { + LOG.trace("operationComplete, waitCount > 0, invoking sendMap with reduceContext"); pipelineFact.getSHUFFLE().sendMap(reduceContext); } } @@ -331,7 +392,6 @@ public void operationComplete(ChannelFuture future) throws Exception { * Allows sendMapOutput calls from operationComplete() */ private static class ReduceContext { - private List mapIds; private AtomicInteger mapsToWait; private AtomicInteger mapsToSend; @@ -342,7 +402,7 @@ private static class ReduceContext { private String jobId; private final boolean keepAlive; - public ReduceContext(List mapIds, int rId, + ReduceContext(List mapIds, int rId, ChannelHandlerContext context, String usr, Map mapOutputInfoMap, String jobId, boolean keepAlive) { @@ -448,7 +508,8 @@ public static int deserializeMetaData(ByteBuffer meta) throws IOException { * shuffle data requests. * @return the serialized version of the jobToken. */ - public static ByteBuffer serializeServiceData(Token jobToken) throws IOException { + public static ByteBuffer serializeServiceData(Token jobToken) + throws IOException { //TODO these bytes should be versioned DataOutputBuffer jobToken_dob = new DataOutputBuffer(); jobToken.write(jobToken_dob); @@ -505,6 +566,11 @@ protected void serviceInit(Configuration conf) throws Exception { DEFAULT_MAX_SHUFFLE_CONNECTIONS); int maxShuffleThreads = conf.getInt(MAX_SHUFFLE_THREADS, DEFAULT_MAX_SHUFFLE_THREADS); + // Since Netty 4.x, the value of 0 threads would default to: + // io.netty.channel.MultithreadEventLoopGroup.DEFAULT_EVENT_LOOP_THREADS + // by simply passing 0 value to NioEventLoopGroup constructor below. + // However, this logic to determinte thread count + // was in place so we can keep it for now. 
if (maxShuffleThreads == 0) { maxShuffleThreads = 2 * Runtime.getRuntime().availableProcessors(); } @@ -520,16 +586,14 @@ protected void serviceInit(Configuration conf) throws Exception { DEFAULT_SHUFFLE_MAX_SESSION_OPEN_FILES); ThreadFactory bossFactory = new ThreadFactoryBuilder() - .setNameFormat("ShuffleHandler Netty Boss #%d") - .build(); + .setNameFormat("ShuffleHandler Netty Boss #%d") + .build(); ThreadFactory workerFactory = new ThreadFactoryBuilder() - .setNameFormat("ShuffleHandler Netty Worker #%d") - .build(); + .setNameFormat("ShuffleHandler Netty Worker #%d") + .build(); - selector = new NioServerSocketChannelFactory( - HadoopExecutors.newCachedThreadPool(bossFactory), - HadoopExecutors.newCachedThreadPool(workerFactory), - maxShuffleThreads); + bossGroup = new NioEventLoopGroup(maxShuffleThreads, bossFactory); + workerGroup = new NioEventLoopGroup(maxShuffleThreads, workerFactory); super.serviceInit(new Configuration(conf)); } @@ -540,22 +604,24 @@ protected void serviceStart() throws Exception { userRsrc = new ConcurrentHashMap(); secretManager = new JobTokenSecretManager(); recoverState(conf); - ServerBootstrap bootstrap = new ServerBootstrap(selector); - // Timer is shared across entire factory and must be released separately - timer = new HashedWheelTimer(); try { - pipelineFact = new HttpPipelineFactory(conf, timer); + pipelineFact = new HttpPipelineFactory(conf); } catch (Exception ex) { throw new RuntimeException(ex); } - bootstrap.setOption("backlog", conf.getInt(SHUFFLE_LISTEN_QUEUE_SIZE, - DEFAULT_SHUFFLE_LISTEN_QUEUE_SIZE)); - bootstrap.setOption("child.keepAlive", true); - bootstrap.setPipelineFactory(pipelineFact); + + bootstrap = new ServerBootstrap(); + bootstrap.group(bossGroup, workerGroup) + .channel(NioServerSocketChannel.class) + .option(ChannelOption.SO_BACKLOG, + conf.getInt(SHUFFLE_LISTEN_QUEUE_SIZE, + DEFAULT_SHUFFLE_LISTEN_QUEUE_SIZE)) + .childOption(ChannelOption.SO_KEEPALIVE, true) + .childHandler(pipelineFact); port = conf.getInt(SHUFFLE_PORT_CONFIG_KEY, DEFAULT_SHUFFLE_PORT); - Channel ch = bootstrap.bind(new InetSocketAddress(port)); + ch = bootstrap.bind(new InetSocketAddress(port)).sync().channel(); accepted.add(ch); - port = ((InetSocketAddress)ch.getLocalAddress()).getPort(); + port = ((InetSocketAddress)ch.localAddress()).getPort(); conf.set(SHUFFLE_PORT_CONFIG_KEY, Integer.toString(port)); pipelineFact.SHUFFLE.setPort(port); LOG.info(getName() + " listening on port " + port); @@ -576,18 +642,12 @@ protected void serviceStart() throws Exception { @Override protected void serviceStop() throws Exception { - accepted.close().awaitUninterruptibly(10, TimeUnit.SECONDS); - if (selector != null) { - ServerBootstrap bootstrap = new ServerBootstrap(selector); - bootstrap.releaseExternalResources(); - } + closeChannels(accepted); + if (pipelineFact != null) { pipelineFact.destroy(); } - if (timer != null) { - // Release this shared timer resource - timer.stop(); - } + if (stateDb != null) { stateDb.close(); } @@ -744,7 +804,7 @@ private void recoverJobShuffleInfo(String jobIdStr, byte[] data) JobShuffleInfoProto proto = JobShuffleInfoProto.parseFrom(data); String user = proto.getUser(); TokenProto tokenProto = proto.getJobToken(); - Token jobToken = new Token( + Token jobToken = new Token<>( tokenProto.getIdentifier().toByteArray(), tokenProto.getPassword().toByteArray(), new Text(tokenProto.getKind()), new Text(tokenProto.getService())); @@ -785,29 +845,47 @@ private void removeJobShuffleInfo(JobID jobId) throws IOException { } } - static class 
TimeoutHandler extends IdleStateAwareChannelHandler { + @VisibleForTesting + public void setUseOutboundExceptionHandler(boolean useHandler) { + this.useOutboundExceptionHandler = useHandler; + } + static class TimeoutHandler extends IdleStateHandler { + private final int connectionKeepAliveTimeOut; private boolean enabledTimeout; + TimeoutHandler(int connectionKeepAliveTimeOut) { + //disable reader timeout + //set writer timeout to configured timeout value + //disable all idle timeout + super(0, connectionKeepAliveTimeOut, 0, TimeUnit.SECONDS); + this.connectionKeepAliveTimeOut = connectionKeepAliveTimeOut; + } + + @VisibleForTesting + public int getConnectionKeepAliveTimeOut() { + return connectionKeepAliveTimeOut; + } + void setEnabledTimeout(boolean enabledTimeout) { this.enabledTimeout = enabledTimeout; } @Override public void channelIdle(ChannelHandlerContext ctx, IdleStateEvent e) { - if (e.getState() == IdleState.WRITER_IDLE && enabledTimeout) { - e.getChannel().close(); + if (e.state() == IdleState.WRITER_IDLE && enabledTimeout) { + closeAsIdle(ctx.channel(), connectionKeepAliveTimeOut); } } } - class HttpPipelineFactory implements ChannelPipelineFactory { + class HttpPipelineFactory extends ChannelInitializer { + private static final int MAX_CONTENT_LENGTH = 1 << 16; final Shuffle SHUFFLE; private SSLFactory sslFactory; - private final ChannelHandler idleStateHandler; - public HttpPipelineFactory(Configuration conf, Timer timer) throws Exception { + HttpPipelineFactory(Configuration conf) throws Exception { SHUFFLE = getShuffle(conf); if (conf.getBoolean(MRConfig.SHUFFLE_SSL_ENABLED_KEY, MRConfig.SHUFFLE_SSL_ENABLED_DEFAULT)) { @@ -815,7 +893,6 @@ public HttpPipelineFactory(Configuration conf, Timer timer) throws Exception { sslFactory = new SSLFactory(SSLFactory.Mode.SERVER, conf); sslFactory.init(); } - this.idleStateHandler = new IdleStateHandler(timer, 0, connectionKeepAliveTimeOut, 0); } public Shuffle getSHUFFLE() { @@ -828,30 +905,39 @@ public void destroy() { } } - @Override - public ChannelPipeline getPipeline() throws Exception { - ChannelPipeline pipeline = Channels.pipeline(); + @Override protected void initChannel(SocketChannel ch) throws Exception { + ChannelPipeline pipeline = ch.pipeline(); if (sslFactory != null) { pipeline.addLast("ssl", new SslHandler(sslFactory.createSSLEngine())); } pipeline.addLast("decoder", new HttpRequestDecoder()); - pipeline.addLast("aggregator", new HttpChunkAggregator(1 << 16)); - pipeline.addLast("encoder", new HttpResponseEncoder()); + pipeline.addLast("aggregator", new HttpObjectAggregator(MAX_CONTENT_LENGTH)); + pipeline.addLast(ENCODER_HANDLER_NAME, useOutboundLogger ? 
+ new LoggingHttpResponseEncoder(false) : new HttpResponseEncoder()); pipeline.addLast("chunking", new ChunkedWriteHandler()); pipeline.addLast("shuffle", SHUFFLE); - pipeline.addLast("idle", idleStateHandler); - pipeline.addLast(TIMEOUT_HANDLER, new TimeoutHandler()); - return pipeline; + if (useOutboundExceptionHandler) { + //https://stackoverflow.com/questions/50612403/catch-all-exception-handling-for-outbound-channelhandler + pipeline.addLast("outboundExceptionHandler", new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, + ChannelPromise promise) throws Exception { + promise.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); + super.write(ctx, msg, promise); + } + }); + } + pipeline.addLast(TIMEOUT_HANDLER, new TimeoutHandler(connectionKeepAliveTimeOut)); // TODO factor security manager into pipeline // TODO factor out encode/decode to permit binary shuffle // TODO factor out decode of index to permit alt. models } } - class Shuffle extends SimpleChannelUpstreamHandler { + @ChannelHandler.Sharable + class Shuffle extends ChannelInboundHandlerAdapter { private final IndexCache indexCache; - private final - LoadingCache pathCache; + private final LoadingCache pathCache; private int port; @@ -904,65 +990,84 @@ private List splitMaps(List mapq) { } @Override - public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent evt) + public void channelActive(ChannelHandlerContext ctx) throws Exception { - super.channelOpen(ctx, evt); - - if ((maxShuffleConnections > 0) && (accepted.size() >= maxShuffleConnections)) { + NettyChannelHelper.channelActive(ctx.channel()); + int numConnections = activeConnections.incrementAndGet(); + if ((maxShuffleConnections > 0) && (numConnections > maxShuffleConnections)) { LOG.info(String.format("Current number of shuffle connections (%d) is " + - "greater than or equal to the max allowed shuffle connections (%d)", + "greater than the max allowed shuffle connections (%d)", accepted.size(), maxShuffleConnections)); - Map headers = new HashMap(1); + Map headers = new HashMap<>(1); // notify fetchers to backoff for a while before closing the connection // if the shuffle connection limit is hit. Fetchers are expected to // handle this notification gracefully, that is, not treating this as a // fetch failure. headers.put(RETRY_AFTER_HEADER, String.valueOf(FETCH_RETRY_DELAY)); sendError(ctx, "", TOO_MANY_REQ_STATUS, headers); - return; + } else { + super.channelActive(ctx); + accepted.add(ctx.channel()); + LOG.debug("Added channel: {}, channel id: {}. 
Accepted number of connections={}", + ctx.channel(), ctx.channel().id(), activeConnections.get()); } - accepted.add(evt.getChannel()); } @Override - public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt) + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + NettyChannelHelper.channelInactive(ctx.channel()); + super.channelInactive(ctx); + int noOfConnections = activeConnections.decrementAndGet(); + LOG.debug("New value of Accepted number of connections={}", noOfConnections); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - HttpRequest request = (HttpRequest) evt.getMessage(); - if (request.getMethod() != GET) { - sendError(ctx, METHOD_NOT_ALLOWED); - return; + Channel channel = ctx.channel(); + LOG.trace("Executing channelRead, channel id: {}", channel.id()); + HttpRequest request = (HttpRequest) msg; + LOG.debug("Received HTTP request: {}, channel id: {}", request, channel.id()); + if (request.method() != GET) { + sendError(ctx, METHOD_NOT_ALLOWED); + return; } // Check whether the shuffle version is compatible - if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals( - request.headers() != null ? - request.headers().get(ShuffleHeader.HTTP_HEADER_NAME) : null) - || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals( - request.headers() != null ? - request.headers() - .get(ShuffleHeader.HTTP_HEADER_VERSION) : null)) { + String shuffleVersion = ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION; + String httpHeaderName = ShuffleHeader.DEFAULT_HTTP_HEADER_NAME; + if (request.headers() != null) { + shuffleVersion = request.headers().get(ShuffleHeader.HTTP_HEADER_VERSION); + httpHeaderName = request.headers().get(ShuffleHeader.HTTP_HEADER_NAME); + LOG.debug("Received from request header: ShuffleVersion={} header name={}, channel id: {}", + shuffleVersion, httpHeaderName, channel.id()); + } + if (request.headers() == null || + !ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(httpHeaderName) || + !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals(shuffleVersion)) { sendError(ctx, "Incompatible shuffle request version", BAD_REQUEST); } - final Map> q = - new QueryStringDecoder(request.getUri()).getParameters(); + final Map> q = + new QueryStringDecoder(request.uri()).parameters(); final List keepAliveList = q.get("keepAlive"); boolean keepAliveParam = false; if (keepAliveList != null && keepAliveList.size() == 1) { keepAliveParam = Boolean.valueOf(keepAliveList.get(0)); if (LOG.isDebugEnabled()) { - LOG.debug("KeepAliveParam : " + keepAliveList - + " : " + keepAliveParam); + LOG.debug("KeepAliveParam: {} : {}, channel id: {}", + keepAliveList, keepAliveParam, channel.id()); } } final List mapIds = splitMaps(q.get("map")); final List reduceQ = q.get("reduce"); final List jobQ = q.get("job"); if (LOG.isDebugEnabled()) { - LOG.debug("RECV: " + request.getUri() + + LOG.debug("RECV: " + request.uri() + "\n mapId: " + mapIds + "\n reduceId: " + reduceQ + "\n jobId: " + jobQ + - "\n keepAlive: " + keepAliveParam); + "\n keepAlive: " + keepAliveParam + + "\n channel id: " + channel.id()); } if (mapIds == null || reduceQ == null || jobQ == null) { @@ -986,7 +1091,7 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt) sendError(ctx, "Bad job parameter", BAD_REQUEST); return; } - final String reqUri = request.getUri(); + final String reqUri = request.uri(); if (null == reqUri) { // TODO? add upstream? 
sendError(ctx, FORBIDDEN); @@ -1004,8 +1109,7 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt) Map mapOutputInfoMap = new HashMap(); - Channel ch = evt.getChannel(); - ChannelPipeline pipeline = ch.getPipeline(); + ChannelPipeline pipeline = channel.pipeline(); TimeoutHandler timeoutHandler = (TimeoutHandler)pipeline.get(TIMEOUT_HANDLER); timeoutHandler.setEnabledTimeout(false); @@ -1013,16 +1117,29 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt) try { populateHeaders(mapIds, jobId, user, reduceId, request, - response, keepAliveParam, mapOutputInfoMap); + response, keepAliveParam, mapOutputInfoMap); } catch(IOException e) { - ch.write(response); - LOG.error("Shuffle error in populating headers :", e); - String errorMessage = getErrorMessage(e); - sendError(ctx,errorMessage , INTERNAL_SERVER_ERROR); + //HADOOP-15327 + // Need to send an instance of LastHttpContent to define HTTP + // message boundaries. + //Sending an HTTP 200 OK + HTTP 500 later (sendError) + // is quite a non-standard way of crafting HTTP responses, + // but we need to keep backward compatibility. + // See more details in jira. + writeToChannelAndAddLastHttpContent(channel, response); + LOG.error("Shuffle error while populating headers. Channel id: " + channel.id(), e); + sendError(ctx, getErrorMessage(e), INTERNAL_SERVER_ERROR); return; } - ch.write(response); - //Initialize one ReduceContext object per messageReceived call + writeToChannel(channel, response).addListener((ChannelFutureListener) future -> { + if (future.isSuccess()) { + LOG.debug("Written HTTP response object successfully. Channel id: {}", channel.id()); + } else { + LOG.error("Error while writing HTTP response object: {}. " + + "Cause: {}, channel id: {}", response, future.cause(), channel.id()); + } + }); + //Initialize one ReduceContext object per channelRead call boolean keepAlive = keepAliveParam || connectionKeepAliveEnabled; ReduceContext reduceContext = new ReduceContext(mapIds, reduceId, ctx, user, mapOutputInfoMap, jobId, keepAlive); @@ -1044,9 +1161,8 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt) * @param reduceContext used to call sendMapOutput with correct params. * @return the ChannelFuture of the sendMapOutput, can be null.
*/ - public ChannelFuture sendMap(ReduceContext reduceContext) - throws Exception { - + public ChannelFuture sendMap(ReduceContext reduceContext) { + LOG.trace("Executing sendMap"); ChannelFuture nextMap = null; if (reduceContext.getMapsToSend().get() < reduceContext.getMapIds().size()) { @@ -1059,21 +1175,24 @@ public ChannelFuture sendMap(ReduceContext reduceContext) info = getMapOutputInfo(mapId, reduceContext.getReduceId(), reduceContext.getJobId(), reduceContext.getUser()); } + LOG.trace("Calling sendMapOutput"); nextMap = sendMapOutput( reduceContext.getCtx(), - reduceContext.getCtx().getChannel(), + reduceContext.getCtx().channel(), reduceContext.getUser(), mapId, reduceContext.getReduceId(), info); - if (null == nextMap) { + if (nextMap == null) { + //This can only happen if spill file was not found sendError(reduceContext.getCtx(), NOT_FOUND); + LOG.trace("Returning nextMap: null"); return null; } nextMap.addListener(new ReduceMapFileCount(reduceContext)); } catch (IOException e) { if (e instanceof DiskChecker.DiskErrorException) { - LOG.error("Shuffle error :" + e); + LOG.error("Shuffle error: " + e); } else { - LOG.error("Shuffle error :", e); + LOG.error("Shuffle error: ", e); } String errorMessage = getErrorMessage(e); sendError(reduceContext.getCtx(), errorMessage, @@ -1125,8 +1244,7 @@ protected MapOutputInfo getMapOutputInfo(String mapId, int reduce, } } - IndexRecord info = - indexCache.getIndexInformation(mapId, reduce, pathInfo.indexPath, user); + IndexRecord info = indexCache.getIndexInformation(mapId, reduce, pathInfo.indexPath, user); if (LOG.isDebugEnabled()) { LOG.debug("getMapOutputInfo: jobId=" + jobId + ", mapId=" + mapId + @@ -1155,7 +1273,6 @@ protected void populateHeaders(List mapIds, String jobId, outputInfo.indexRecord.rawLength, reduce); DataOutputBuffer dob = new DataOutputBuffer(); header.write(dob); - contentLength += outputInfo.indexRecord.partLength; contentLength += dob.getLength(); } @@ -1183,14 +1300,10 @@ protected void populateHeaders(List mapIds, String jobId, protected void setResponseHeaders(HttpResponse response, boolean keepAliveParam, long contentLength) { if (!connectionKeepAliveEnabled && !keepAliveParam) { - if (LOG.isDebugEnabled()) { - LOG.debug("Setting connection close header..."); - } - response.headers().set(HttpHeader.CONNECTION.asString(), - CONNECTION_CLOSE); + response.headers().set(HttpHeader.CONNECTION.asString(), CONNECTION_CLOSE); } else { response.headers().set(HttpHeader.CONTENT_LENGTH.asString(), - String.valueOf(contentLength)); + String.valueOf(contentLength)); response.headers().set(HttpHeader.CONNECTION.asString(), HttpHeader.KEEP_ALIVE.asString()); response.headers().set(HttpHeader.KEEP_ALIVE.asString(), @@ -1214,29 +1327,29 @@ protected void verifyRequest(String appid, ChannelHandlerContext ctx, throws IOException { SecretKey tokenSecret = secretManager.retrieveTokenSecret(appid); if (null == tokenSecret) { - LOG.info("Request for unknown token " + appid); - throw new IOException("could not find jobid"); + LOG.info("Request for unknown token {}, channel id: {}", appid, ctx.channel().id()); + throw new IOException("Could not find jobid"); } - // string to encrypt - String enc_str = SecureShuffleUtils.buildMsgFrom(requestUri); + // encrypting URL + String encryptedURL = SecureShuffleUtils.buildMsgFrom(requestUri); // hash from the fetcher String urlHashStr = request.headers().get(SecureShuffleUtils.HTTP_HEADER_URL_HASH); if (urlHashStr == null) { - LOG.info("Missing header hash for " + appid); + LOG.info("Missing 
header hash for {}, channel id: {}", appid, ctx.channel().id()); throw new IOException("fetcher cannot be authenticated"); } if (LOG.isDebugEnabled()) { int len = urlHashStr.length(); - LOG.debug("verifying request. enc_str=" + enc_str + "; hash=..." + - urlHashStr.substring(len-len/2, len-1)); + LOG.debug("Verifying request. encryptedURL:{}, hash:{}, channel id: " + + "{}", encryptedURL, + urlHashStr.substring(len - len / 2, len - 1), ctx.channel().id()); } // verify - throws exception - SecureShuffleUtils.verifyReply(urlHashStr, enc_str, tokenSecret); + SecureShuffleUtils.verifyReply(urlHashStr, encryptedURL, tokenSecret); // verification passed - encode the reply - String reply = - SecureShuffleUtils.generateHash(urlHashStr.getBytes(Charsets.UTF_8), - tokenSecret); + String reply = SecureShuffleUtils.generateHash(urlHashStr.getBytes(Charsets.UTF_8), + tokenSecret); response.headers().set( SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH, reply); // Put shuffle version into http header @@ -1246,8 +1359,10 @@ protected void verifyRequest(String appid, ChannelHandlerContext ctx, ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); if (LOG.isDebugEnabled()) { int len = reply.length(); - LOG.debug("Fetcher request verfied. enc_str=" + enc_str + ";reply=" + - reply.substring(len-len/2, len-1)); + LOG.debug("Fetcher request verified. " + + "encryptedURL: {}, reply: {}, channel id: {}", + encryptedURL, reply.substring(len - len / 2, len - 1), + ctx.channel().id()); } } @@ -1255,27 +1370,27 @@ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, Channel ch, String user, String mapId, int reduce, MapOutputInfo mapOutputInfo) throws IOException { final IndexRecord info = mapOutputInfo.indexRecord; - final ShuffleHeader header = - new ShuffleHeader(mapId, info.partLength, info.rawLength, reduce); + final ShuffleHeader header = new ShuffleHeader(mapId, info.partLength, info.rawLength, + reduce); final DataOutputBuffer dob = new DataOutputBuffer(); header.write(dob); - ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength())); + writeToChannel(ch, wrappedBuffer(dob.getData(), 0, dob.getLength())); final File spillfile = new File(mapOutputInfo.mapOutputFileName.toString()); RandomAccessFile spill; try { spill = SecureIOUtils.openForRandomRead(spillfile, "r", user, null); } catch (FileNotFoundException e) { - LOG.info(spillfile + " not found"); + LOG.info("{} not found. 
Channel id: {}", spillfile, ctx.channel().id()); return null; } ChannelFuture writeFuture; - if (ch.getPipeline().get(SslHandler.class) == null) { + if (ch.pipeline().get(SslHandler.class) == null) { final FadvisedFileRegion partition = new FadvisedFileRegion(spill, info.startOffset, info.partLength, manageOsCache, readaheadLength, readaheadPool, spillfile.getAbsolutePath(), shuffleBufferSize, shuffleTransferToAllowed); - writeFuture = ch.write(partition); + writeFuture = writeToChannel(ch, partition); writeFuture.addListener(new ChannelFutureListener() { // TODO error handling; distinguish IO/connection failures, // attribute to appropriate spill output @@ -1284,7 +1399,7 @@ public void operationComplete(ChannelFuture future) { if (future.isSuccess()) { partition.transferSuccessful(); } - partition.releaseExternalResources(); + partition.deallocate(); } }); } else { @@ -1293,7 +1408,7 @@ public void operationComplete(ChannelFuture future) { info.startOffset, info.partLength, sslFileBufferSize, manageOsCache, readaheadLength, readaheadPool, spillfile.getAbsolutePath()); - writeFuture = ch.write(chunk); + writeFuture = writeToChannel(ch, chunk); } metrics.shuffleConnections.incr(); metrics.shuffleOutputBytes.incr(info.partLength); // optimistic @@ -1307,12 +1422,13 @@ protected void sendError(ChannelHandlerContext ctx, protected void sendError(ChannelHandlerContext ctx, String message, HttpResponseStatus status) { - sendError(ctx, message, status, Collections.emptyMap()); + sendError(ctx, message, status, Collections.emptyMap()); } protected void sendError(ChannelHandlerContext ctx, String msg, HttpResponseStatus status, Map headers) { - HttpResponse response = new DefaultHttpResponse(HTTP_1_1, status); + FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, status, + Unpooled.copiedBuffer(msg, CharsetUtil.UTF_8)); response.headers().set(CONTENT_TYPE, "text/plain; charset=UTF-8"); // Put shuffle version into http header response.headers().set(ShuffleHeader.HTTP_HEADER_NAME, @@ -1322,48 +1438,45 @@ protected void sendError(ChannelHandlerContext ctx, String msg, for (Map.Entry header : headers.entrySet()) { response.headers().set(header.getKey(), header.getValue()); } - response.setContent( - ChannelBuffers.copiedBuffer(msg, CharsetUtil.UTF_8)); // Close the connection as soon as the error message is sent. - ctx.getChannel().write(response).addListener(ChannelFutureListener.CLOSE); + writeToChannelAndClose(ctx.channel(), response); } @Override - public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - Channel ch = e.getChannel(); - Throwable cause = e.getCause(); + Channel ch = ctx.channel(); if (cause instanceof TooLongFrameException) { + LOG.trace("TooLongFrameException, channel id: {}", ch.id()); sendError(ctx, BAD_REQUEST); return; } else if (cause instanceof IOException) { if (cause instanceof ClosedChannelException) { - LOG.debug("Ignoring closed channel error", cause); + LOG.debug("Ignoring closed channel error, channel id: " + ch.id(), cause); return; } String message = String.valueOf(cause.getMessage()); if (IGNORABLE_ERROR_MESSAGE.matcher(message).matches()) { - LOG.debug("Ignoring client socket close", cause); + LOG.debug("Ignoring client socket close, channel id: " + ch.id(), cause); return; } } - LOG.error("Shuffle error: ", cause); - if (ch.isConnected()) { - LOG.error("Shuffle error " + e); + LOG.error("Shuffle error. 
Channel id: " + ch.id(), cause); + if (ch.isActive()) { sendError(ctx, INTERNAL_SERVER_ERROR); } } } - + static class AttemptPathInfo { // TODO Change this over to just store local dir indices, instead of the // entire path. Far more efficient. private final Path indexPath; private final Path dataPath; - public AttemptPathInfo(Path indexPath, Path dataPath) { + AttemptPathInfo(Path indexPath, Path dataPath) { this.indexPath = indexPath; this.dataPath = dataPath; } @@ -1374,7 +1487,7 @@ static class AttemptPathIdentifier { private final String user; private final String attemptId; - public AttemptPathIdentifier(String jobId, String user, String attemptId) { + AttemptPathIdentifier(String jobId, String user, String attemptId) { this.jobId = jobId; this.user = user; this.attemptId = attemptId; diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestFadvisedFileRegion.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestFadvisedFileRegion.java index 242382e06a..ce0c0d6aea 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestFadvisedFileRegion.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestFadvisedFileRegion.java @@ -104,7 +104,7 @@ public void testCustomShuffleTransfer() throws IOException { Assert.assertEquals(count, targetFile.length()); } finally { if (fileRegion != null) { - fileRegion.releaseExternalResources(); + fileRegion.deallocate(); } IOUtils.cleanupWithLogger(LOG, target); IOUtils.cleanupWithLogger(LOG, targetFile); diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestShuffleHandler.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestShuffleHandler.java index af3cb87760..38500032ef 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestShuffleHandler.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestShuffleHandler.java @@ -17,34 +17,65 @@ */ package org.apache.hadoop.mapred; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.DefaultFileRegion; +import org.apache.hadoop.thirdparty.com.google.common.collect.Maps; +import io.netty.channel.AbstractChannel; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseEncoder; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.timeout.IdleStateEvent; import org.apache.hadoop.test.GenericTestUtils; + +import static io.netty.buffer.Unpooled.wrappedBuffer; +import static java.util.stream.Collectors.toList; import static org.apache.hadoop.test.MetricsAsserts.assertCounter; import static org.apache.hadoop.test.MetricsAsserts.assertGauge; import static org.apache.hadoop.test.MetricsAsserts.getMetrics; 
+import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; -import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer; -import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK; -import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; import static org.junit.Assume.assumeTrue; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.EOFException; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; +import java.io.InputStream; import java.net.HttpURLConnection; +import java.net.InetSocketAddress; +import java.net.Proxy; +import java.net.Socket; import java.net.URL; import java.net.SocketAddress; +import java.net.URLConnection; import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; import java.util.zip.CheckedOutputStream; import java.util.zip.Checksum; @@ -71,6 +102,7 @@ import org.apache.hadoop.service.ServiceStateException; import org.apache.hadoop.util.DiskChecker; import org.apache.hadoop.util.PureJavaCrc32; +import org.apache.hadoop.util.Sets; import org.apache.hadoop.util.StringUtils; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.conf.YarnConfiguration; @@ -79,22 +111,13 @@ import org.apache.hadoop.yarn.server.api.AuxiliaryLocalPathHandler; import org.apache.hadoop.yarn.server.nodemanager.containermanager.localizer.ContainerLocalizer; import org.apache.hadoop.yarn.server.records.Version; -import org.jboss.netty.channel.Channel; -import org.jboss.netty.channel.ChannelFuture; -import org.jboss.netty.channel.ChannelHandlerContext; -import org.jboss.netty.channel.ChannelPipeline; -import org.jboss.netty.channel.socket.SocketChannel; -import org.jboss.netty.channel.MessageEvent; -import org.jboss.netty.channel.AbstractChannel; -import org.jboss.netty.handler.codec.http.DefaultHttpResponse; -import org.jboss.netty.handler.codec.http.HttpRequest; -import org.jboss.netty.handler.codec.http.HttpResponse; -import org.jboss.netty.handler.codec.http.HttpResponseStatus; -import org.jboss.netty.handler.codec.http.HttpMethod; +import org.hamcrest.CoreMatchers; +import org.junit.After; import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; +import org.junit.rules.TestName; import org.mockito.Mockito; import org.eclipse.jetty.http.HttpHeader; import org.slf4j.Logger; @@ -106,10 +129,583 @@ public class TestShuffleHandler { LoggerFactory.getLogger(TestShuffleHandler.class); private static final File ABS_LOG_DIR = GenericTestUtils.getTestDir( TestShuffleHandler.class.getSimpleName() + "LocDir"); + private static final long ATTEMPT_ID = 12345L; + private static final long ATTEMPT_ID_2 = 12346L; + private static final HttpResponseStatus OK_STATUS = new HttpResponseStatus(200, 
"OK"); + + + //Control test execution properties with these flags + private static final boolean DEBUG_MODE = false; + //WARNING: If this is set to true and proxy server is not running, tests will fail! + private static final boolean USE_PROXY = false; + private static final int HEADER_WRITE_COUNT = 100000; + private static final int ARBITRARY_NEGATIVE_TIMEOUT_SECONDS = -100; + private static TestExecution TEST_EXECUTION; + + private static class TestExecution { + private static final int DEFAULT_KEEP_ALIVE_TIMEOUT_SECONDS = 1; + private static final int DEBUG_KEEP_ALIVE_SECONDS = 1000; + private static final int DEFAULT_PORT = 0; //random port + private static final int FIXED_PORT = 8088; + private static final String PROXY_HOST = "127.0.0.1"; + private static final int PROXY_PORT = 8888; + private static final int CONNECTION_DEBUG_TIMEOUT = 1000000; + private final boolean debugMode; + private final boolean useProxy; + + TestExecution(boolean debugMode, boolean useProxy) { + this.debugMode = debugMode; + this.useProxy = useProxy; + } + + int getKeepAliveTimeout() { + if (debugMode) { + return DEBUG_KEEP_ALIVE_SECONDS; + } + return DEFAULT_KEEP_ALIVE_TIMEOUT_SECONDS; + } + + HttpURLConnection openConnection(URL url) throws IOException { + HttpURLConnection conn; + if (useProxy) { + Proxy proxy + = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(PROXY_HOST, PROXY_PORT)); + conn = (HttpURLConnection) url.openConnection(proxy); + } else { + conn = (HttpURLConnection) url.openConnection(); + } + return conn; + } + + int shuffleHandlerPort() { + if (debugMode) { + return FIXED_PORT; + } else { + return DEFAULT_PORT; + } + } + + void parameterizeConnection(URLConnection conn) { + if (DEBUG_MODE) { + conn.setReadTimeout(CONNECTION_DEBUG_TIMEOUT); + conn.setConnectTimeout(CONNECTION_DEBUG_TIMEOUT); + } + } + } + + private static class ResponseConfig { + private final int headerWriteCount; + private final int mapOutputCount; + private final int contentLengthOfOneMapOutput; + private long headerSize; + public long contentLengthOfResponse; + + ResponseConfig(int headerWriteCount, int mapOutputCount, + int contentLengthOfOneMapOutput) { + if (mapOutputCount <= 0 && contentLengthOfOneMapOutput > 0) { + throw new IllegalStateException("mapOutputCount should be at least 1"); + } + this.headerWriteCount = headerWriteCount; + this.mapOutputCount = mapOutputCount; + this.contentLengthOfOneMapOutput = contentLengthOfOneMapOutput; + } + + private void setHeaderSize(long headerSize) { + this.headerSize = headerSize; + long contentLengthOfAllHeaders = headerWriteCount * headerSize; + this.contentLengthOfResponse = computeContentLengthOfResponse(contentLengthOfAllHeaders); + LOG.debug("Content-length of all headers: {}", contentLengthOfAllHeaders); + LOG.debug("Content-length of one MapOutput: {}", contentLengthOfOneMapOutput); + LOG.debug("Content-length of final HTTP response: {}", contentLengthOfResponse); + } + + private long computeContentLengthOfResponse(long contentLengthOfAllHeaders) { + int mapOutputCountMultiplier = mapOutputCount; + if (mapOutputCount == 0) { + mapOutputCountMultiplier = 1; + } + return (contentLengthOfAllHeaders + contentLengthOfOneMapOutput) * mapOutputCountMultiplier; + } + } + + private enum ShuffleUrlType { + SIMPLE, WITH_KEEPALIVE, WITH_KEEPALIVE_MULTIPLE_MAP_IDS, WITH_KEEPALIVE_NO_MAP_IDS + } + + private static class InputStreamReadResult { + final String asString; + int totalBytesRead; + + InputStreamReadResult(byte[] bytes, int totalBytesRead) { + this.asString = new 
String(bytes, StandardCharsets.UTF_8); + this.totalBytesRead = totalBytesRead; + } + } + + private static abstract class AdditionalMapOutputSenderOperations { + public abstract ChannelFuture perform(ChannelHandlerContext ctx, Channel ch) throws IOException; + } + + private class ShuffleHandlerForKeepAliveTests extends ShuffleHandler { + final LastSocketAddress lastSocketAddress = new LastSocketAddress(); + final ArrayList failures = new ArrayList<>(); + final ShuffleHeaderProvider shuffleHeaderProvider; + final HeaderPopulator headerPopulator; + MapOutputSender mapOutputSender; + private Consumer channelIdleCallback; + private CustomTimeoutHandler customTimeoutHandler; + private boolean failImmediatelyOnErrors = false; + private boolean closeChannelOnError = true; + private ResponseConfig responseConfig; + + ShuffleHandlerForKeepAliveTests(long attemptId, ResponseConfig responseConfig, + Consumer channelIdleCallback) throws IOException { + this(attemptId, responseConfig); + this.channelIdleCallback = channelIdleCallback; + } + + ShuffleHandlerForKeepAliveTests(long attemptId, ResponseConfig responseConfig) + throws IOException { + this.responseConfig = responseConfig; + this.shuffleHeaderProvider = new ShuffleHeaderProvider(attemptId); + this.responseConfig.setHeaderSize(shuffleHeaderProvider.getShuffleHeaderSize()); + this.headerPopulator = new HeaderPopulator(this, responseConfig, shuffleHeaderProvider, true); + this.mapOutputSender = new MapOutputSender(responseConfig, lastSocketAddress, + shuffleHeaderProvider); + setUseOutboundExceptionHandler(true); + } + + public void setFailImmediatelyOnErrors(boolean failImmediatelyOnErrors) { + this.failImmediatelyOnErrors = failImmediatelyOnErrors; + } + + public void setCloseChannelOnError(boolean closeChannelOnError) { + this.closeChannelOnError = closeChannelOnError; + } + + @Override + protected Shuffle getShuffle(final Configuration conf) { + // replace the shuffle handler with one stubbed for testing + return new Shuffle(conf) { + @Override + protected MapOutputInfo getMapOutputInfo(String mapId, int reduce, + String jobId, String user) { + return null; + } + @Override + protected void verifyRequest(String appid, ChannelHandlerContext ctx, + HttpRequest request, HttpResponse response, URL requestUri) { + } + + @Override + protected void populateHeaders(List mapIds, String jobId, + String user, int reduce, HttpRequest request, + HttpResponse response, boolean keepAliveParam, + Map infoMap) throws IOException { + long contentLength = headerPopulator.populateHeaders( + keepAliveParam); + super.setResponseHeaders(response, keepAliveParam, contentLength); + } + + @Override + protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, + Channel ch, String user, String mapId, int reduce, + MapOutputInfo info) throws IOException { + return mapOutputSender.send(ctx, ch); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.pipeline().replace(HttpResponseEncoder.class, ENCODER_HANDLER_NAME, + new LoggingHttpResponseEncoder(false)); + replaceTimeoutHandlerWithCustom(ctx); + LOG.debug("Modified pipeline: {}", ctx.pipeline()); + super.channelActive(ctx); + } + + private void replaceTimeoutHandlerWithCustom(ChannelHandlerContext ctx) { + TimeoutHandler oldTimeoutHandler = + (TimeoutHandler)ctx.pipeline().get(TIMEOUT_HANDLER); + int timeoutValue = + oldTimeoutHandler.getConnectionKeepAliveTimeOut(); + customTimeoutHandler = new CustomTimeoutHandler(timeoutValue, channelIdleCallback); + 
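The replaceTimeoutHandlerWithCustom helper above (completed by the pipeline replace call just below) swaps the handler registered under TIMEOUT_HANDLER for a subclass that records when the channel went idle. In Netty 4 the underlying machinery is io.netty.handler.timeout.IdleStateHandler, whose protected channelIdle callback fires once the configured quiet period elapses. A minimal sketch of a close-on-idle keep-alive handler in that style, assuming Netty 4.1 (the class name and close policy are illustrative, not the patch's TimeoutHandler):

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;

class IdleCloseSketch extends IdleStateHandler {
  IdleCloseSketch(int keepAliveSeconds) {
    super(0, 0, keepAliveSeconds); // reader/writer idle off; all-idle after timeout
  }

  @Override
  protected void channelIdle(ChannelHandlerContext ctx, IdleStateEvent evt) {
    ctx.channel().close(); // keep-alive timeout: drop connections that sat quiet
  }
}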
ctx.pipeline().replace(TIMEOUT_HANDLER, TIMEOUT_HANDLER, customTimeoutHandler); + } + + @Override + protected void sendError(ChannelHandlerContext ctx, + HttpResponseStatus status) { + String message = "Error while processing request. Status: " + status; + handleError(ctx, message); + if (failImmediatelyOnErrors) { + stop(); + } + } + + @Override + protected void sendError(ChannelHandlerContext ctx, String message, + HttpResponseStatus status) { + String errMessage = String.format("Error while processing request. " + + "Status: " + + "%s, message: %s", status, message); + handleError(ctx, errMessage); + if (failImmediatelyOnErrors) { + stop(); + } + } + }; + } + + private void handleError(ChannelHandlerContext ctx, String message) { + LOG.error(message); + failures.add(new Error(message)); + if (closeChannelOnError) { + LOG.warn("sendError: Closing channel"); + ctx.channel().close(); + } + } + + private class CustomTimeoutHandler extends TimeoutHandler { + private boolean channelIdle = false; + private final Consumer channelIdleCallback; + + CustomTimeoutHandler(int connectionKeepAliveTimeOut, + Consumer channelIdleCallback) { + super(connectionKeepAliveTimeOut); + this.channelIdleCallback = channelIdleCallback; + } + + @Override + public void channelIdle(ChannelHandlerContext ctx, IdleStateEvent e) { + LOG.debug("Channel idle"); + this.channelIdle = true; + if (channelIdleCallback != null) { + LOG.debug("Calling channel idle callback.."); + channelIdleCallback.accept(e); + } + super.channelIdle(ctx, e); + } + } + } + + private static class MapOutputSender { + private final ResponseConfig responseConfig; + private final LastSocketAddress lastSocketAddress; + private final ShuffleHeaderProvider shuffleHeaderProvider; + private AdditionalMapOutputSenderOperations additionalMapOutputSenderOperations; + + MapOutputSender(ResponseConfig responseConfig, LastSocketAddress lastSocketAddress, + ShuffleHeaderProvider shuffleHeaderProvider) { + this.responseConfig = responseConfig; + this.lastSocketAddress = lastSocketAddress; + this.shuffleHeaderProvider = shuffleHeaderProvider; + } + + public ChannelFuture send(ChannelHandlerContext ctx, Channel ch) throws IOException { + LOG.debug("In MapOutputSender#send"); + lastSocketAddress.setAddress(ch.remoteAddress()); + ShuffleHeader header = shuffleHeaderProvider.createNewShuffleHeader(); + ChannelFuture future = writeHeaderNTimes(ch, header, responseConfig.headerWriteCount); + // This is the last operation + // It's safe to increment ShuffleHeader counter for better identification + shuffleHeaderProvider.incrementCounter(); + if (additionalMapOutputSenderOperations != null) { + return additionalMapOutputSenderOperations.perform(ctx, ch); + } + return future; + } + + private ChannelFuture writeHeaderNTimes(Channel ch, ShuffleHeader header, int iterations) + throws IOException { + DataOutputBuffer dob = new DataOutputBuffer(); + for (int i = 0; i < iterations; ++i) { + header.write(dob); + } + LOG.debug("MapOutputSender#writeHeaderNTimes WriteAndFlush big chunk of data, " + + "outputBufferSize: " + dob.size()); + return ch.writeAndFlush(wrappedBuffer(dob.getData(), 0, dob.getLength())); + } + } + + private static class ShuffleHeaderProvider { + private final long attemptId; + private int attemptCounter = 0; + private int cachedSize = Integer.MIN_VALUE; + + ShuffleHeaderProvider(long attemptId) { + this.attemptId = attemptId; + } + + ShuffleHeader createNewShuffleHeader() { + return new ShuffleHeader(String.format("attempt_%s_1_m_1_0%s", attemptId, 
attemptCounter), + 5678, 5678, 1); + } + + void incrementCounter() { + attemptCounter++; + } + + private int getShuffleHeaderSize() throws IOException { + if (cachedSize != Integer.MIN_VALUE) { + return cachedSize; + } + DataOutputBuffer dob = new DataOutputBuffer(); + ShuffleHeader header = createNewShuffleHeader(); + header.write(dob); + cachedSize = dob.size(); + return cachedSize; + } + } + + private static class HeaderPopulator { + private final ShuffleHandler shuffleHandler; + private final boolean disableKeepAliveConfig; + private final ShuffleHeaderProvider shuffleHeaderProvider; + private final ResponseConfig responseConfig; + + HeaderPopulator(ShuffleHandler shuffleHandler, + ResponseConfig responseConfig, + ShuffleHeaderProvider shuffleHeaderProvider, + boolean disableKeepAliveConfig) { + this.shuffleHandler = shuffleHandler; + this.responseConfig = responseConfig; + this.disableKeepAliveConfig = disableKeepAliveConfig; + this.shuffleHeaderProvider = shuffleHeaderProvider; + } + + public long populateHeaders(boolean keepAliveParam) throws IOException { + // Send some dummy data (populate content length details) + DataOutputBuffer dob = new DataOutputBuffer(); + for (int i = 0; i < responseConfig.headerWriteCount; ++i) { + ShuffleHeader header = + shuffleHeaderProvider.createNewShuffleHeader(); + header.write(dob); + } + // for testing purpose; + // disable connectionKeepAliveEnabled if keepAliveParam is available + if (keepAliveParam && disableKeepAliveConfig) { + shuffleHandler.connectionKeepAliveEnabled = false; + } + return responseConfig.contentLengthOfResponse; + } + } + + private static final class HttpConnectionData { + private final Map> headers; + private HttpURLConnection conn; + private final int payloadLength; + private final SocketAddress socket; + private int responseCode = -1; + + private HttpConnectionData(HttpURLConnection conn, int payloadLength, + SocketAddress socket) { + this.headers = conn.getHeaderFields(); + this.conn = conn; + this.payloadLength = payloadLength; + this.socket = socket; + try { + this.responseCode = conn.getResponseCode(); + } catch (IOException e) { + fail("Failed to read response code from connection: " + conn); + } + } + + static HttpConnectionData create(HttpURLConnection conn, int payloadLength, + SocketAddress socket) { + return new HttpConnectionData(conn, payloadLength, socket); + } + } + + private static final class HttpConnectionAssert { + private final HttpConnectionData connData; + + private HttpConnectionAssert(HttpConnectionData connData) { + this.connData = connData; + } + + static HttpConnectionAssert create(HttpConnectionData connData) { + return new HttpConnectionAssert(connData); + } + + public static void assertKeepAliveConnectionsAreSame( + HttpConnectionHelper httpConnectionHelper) { + assertTrue("At least two connection data " + + "is required to perform this assertion", + httpConnectionHelper.connectionData.size() >= 2); + SocketAddress firstAddress = httpConnectionHelper.getConnectionData(0).socket; + SocketAddress secondAddress = httpConnectionHelper.getConnectionData(1).socket; + Assert.assertNotNull("Initial shuffle address should not be null", + firstAddress); + Assert.assertNotNull("Keep-Alive shuffle address should not be null", + secondAddress); + assertEquals("Initial shuffle address and keep-alive shuffle " + + "address should be the same", firstAddress, secondAddress); + } + + public HttpConnectionAssert expectKeepAliveWithTimeout(long timeout) { + assertEquals(HttpURLConnection.HTTP_OK, 
connData.responseCode); + assertHeaderValue(HttpHeader.CONNECTION, HttpHeader.KEEP_ALIVE.asString()); + assertHeaderValue(HttpHeader.KEEP_ALIVE, "timeout=" + timeout); + return this; + } + + public HttpConnectionAssert expectBadRequest(long timeout) { + assertEquals(HttpURLConnection.HTTP_BAD_REQUEST, connData.responseCode); + assertHeaderValue(HttpHeader.CONNECTION, HttpHeader.KEEP_ALIVE.asString()); + assertHeaderValue(HttpHeader.KEEP_ALIVE, "timeout=" + timeout); + return this; + } + + public HttpConnectionAssert expectResponseContentLength(long size) { + assertEquals(size, connData.payloadLength); + return this; + } + + private void assertHeaderValue(HttpHeader header, String expectedValue) { + List headerList = connData.headers.get(header.asString()); + Assert.assertNotNull("Got null header value for header: " + header, headerList); + Assert.assertFalse("Got empty header value for header: " + header, headerList.isEmpty()); + assertEquals("Unexpected size of header list for header: " + header, 1, + headerList.size()); + assertEquals(expectedValue, headerList.get(0)); + } + } + + private static class HttpConnectionHelper { + private final LastSocketAddress lastSocketAddress; + List connectionData = new ArrayList<>(); + + HttpConnectionHelper(LastSocketAddress lastSocketAddress) { + this.lastSocketAddress = lastSocketAddress; + } + + public void connectToUrls(String[] urls, ResponseConfig responseConfig) throws IOException { + connectToUrlsInternal(urls, responseConfig, HttpURLConnection.HTTP_OK); + } + + public void connectToUrls(String[] urls, ResponseConfig responseConfig, int expectedHttpStatus) + throws IOException { + connectToUrlsInternal(urls, responseConfig, expectedHttpStatus); + } + + private void connectToUrlsInternal(String[] urls, ResponseConfig responseConfig, + int expectedHttpStatus) throws IOException { + int requests = urls.length; + int expectedConnections = urls.length; + LOG.debug("Will connect to URLs: {}", Arrays.toString(urls)); + for (int reqIdx = 0; reqIdx < requests; reqIdx++) { + String urlString = urls[reqIdx]; + LOG.debug("Connecting to URL: {}", urlString); + URL url = new URL(urlString); + HttpURLConnection conn = TEST_EXECUTION.openConnection(url); + conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME, + ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); + conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION, + ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); + TEST_EXECUTION.parameterizeConnection(conn); + conn.connect(); + if (expectedHttpStatus == HttpURLConnection.HTTP_BAD_REQUEST) { + //Catch exception as error are caught with overridden sendError method + //Caught errors will be validated later. + try { + DataInputStream input = new DataInputStream(conn.getInputStream()); + } catch (Exception e) { + expectedConnections--; + continue; + } + } + DataInputStream input = new DataInputStream(conn.getInputStream()); + LOG.debug("Opened DataInputStream for connection: {}/{}", (reqIdx + 1), requests); + ShuffleHeader header = new ShuffleHeader(); + header.readFields(input); + InputStreamReadResult result = readDataFromInputStream(input); + result.totalBytesRead += responseConfig.headerSize; + int expectedContentLength = + Integer.parseInt(conn.getHeaderField(HttpHeader.CONTENT_LENGTH.asString())); + + if (result.totalBytesRead != expectedContentLength) { + throw new IOException(String.format("Premature EOF InputStream. 
" + + "Expected content-length: %s, " + + "Actual content-length: %s", expectedContentLength, result.totalBytesRead)); + } + connectionData.add(HttpConnectionData + .create(conn, result.totalBytesRead, lastSocketAddress.getSocketAddres())); + input.close(); + LOG.debug("Finished all interactions with URL: {}. Progress: {}/{}", url, (reqIdx + 1), + requests); + } + assertEquals(expectedConnections, connectionData.size()); + } + + void validate(Consumer connDataValidator) { + for (int i = 0; i < connectionData.size(); i++) { + LOG.debug("Validating connection data #{}", (i + 1)); + HttpConnectionData connData = connectionData.get(i); + connDataValidator.accept(connData); + } + } + + HttpConnectionData getConnectionData(int i) { + return connectionData.get(i); + } + + private static InputStreamReadResult readDataFromInputStream( + InputStream input) throws IOException { + ByteArrayOutputStream dataStream = new ByteArrayOutputStream(); + byte[] buffer = new byte[1024]; + int bytesRead; + int totalBytesRead = 0; + while ((bytesRead = input.read(buffer)) != -1) { + dataStream.write(buffer, 0, bytesRead); + totalBytesRead += bytesRead; + } + LOG.debug("Read total bytes: " + totalBytesRead); + dataStream.flush(); + return new InputStreamReadResult(dataStream.toByteArray(), totalBytesRead); + } + } + + class ShuffleHandlerForTests extends ShuffleHandler { + public final ArrayList failures = new ArrayList<>(); + + ShuffleHandlerForTests() { + setUseOutboundExceptionHandler(true); + } + + ShuffleHandlerForTests(MetricsSystem ms) { + super(ms); + setUseOutboundExceptionHandler(true); + } + + @Override + protected Shuffle getShuffle(final Configuration conf) { + return new Shuffle(conf) { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) throws Exception { + LOG.debug("ExceptionCaught"); + failures.add(cause); + super.exceptionCaught(ctx, cause); + } + }; + } + } class MockShuffleHandler extends org.apache.hadoop.mapred.ShuffleHandler { - private AuxiliaryLocalPathHandler pathHandler = + final ArrayList failures = new ArrayList<>(); + + private final AuxiliaryLocalPathHandler pathHandler = new TestAuxiliaryLocalPathHandler(); + + MockShuffleHandler() { + setUseOutboundExceptionHandler(true); + } + + MockShuffleHandler(MetricsSystem ms) { + super(ms); + setUseOutboundExceptionHandler(true); + } + @Override protected Shuffle getShuffle(final Configuration conf) { return new Shuffle(conf) { @@ -120,7 +716,7 @@ protected void verifyRequest(String appid, ChannelHandlerContext ctx, } @Override protected MapOutputInfo getMapOutputInfo(String mapId, int reduce, - String jobId, String user) throws IOException { + String jobId, String user) { // Do nothing. return null; } @@ -128,7 +724,7 @@ protected MapOutputInfo getMapOutputInfo(String mapId, int reduce, protected void populateHeaders(List mapIds, String jobId, String user, int reduce, HttpRequest request, HttpResponse response, boolean keepAliveParam, - Map infoMap) throws IOException { + Map infoMap) { // Do nothing. 
} @Override @@ -140,12 +736,20 @@ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, new ShuffleHeader("attempt_12345_1_m_1_0", 5678, 5678, 1); DataOutputBuffer dob = new DataOutputBuffer(); header.write(dob); - ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength())); + ch.writeAndFlush(wrappedBuffer(dob.getData(), 0, dob.getLength())); dob = new DataOutputBuffer(); for (int i = 0; i < 100; ++i) { header.write(dob); } - return ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength())); + return ch.writeAndFlush(wrappedBuffer(dob.getData(), 0, dob.getLength())); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) throws Exception { + LOG.debug("ExceptionCaught"); + failures.add(cause); + super.exceptionCaught(ctx, cause); } }; } @@ -159,24 +763,22 @@ public AuxiliaryLocalPathHandler getAuxiliaryLocalPathHandler() { private class TestAuxiliaryLocalPathHandler implements AuxiliaryLocalPathHandler { @Override - public Path getLocalPathForRead(String path) throws IOException { + public Path getLocalPathForRead(String path) { return new Path(ABS_LOG_DIR.getAbsolutePath(), path); } @Override - public Path getLocalPathForWrite(String path) throws IOException { + public Path getLocalPathForWrite(String path) { return new Path(ABS_LOG_DIR.getAbsolutePath()); } @Override - public Path getLocalPathForWrite(String path, long size) - throws IOException { + public Path getLocalPathForWrite(String path, long size) { return new Path(ABS_LOG_DIR.getAbsolutePath()); } @Override - public Iterable getAllLocalPathsForRead(String path) - throws IOException { + public Iterable getAllLocalPathsForRead(String path) { ArrayList paths = new ArrayList<>(); paths.add(new Path(ABS_LOG_DIR.getAbsolutePath())); return paths; @@ -185,16 +787,34 @@ public Iterable getAllLocalPathsForRead(String path) private static class MockShuffleHandler2 extends org.apache.hadoop.mapred.ShuffleHandler { + final ArrayList failures = new ArrayList<>(1); boolean socketKeepAlive = false; + + MockShuffleHandler2() { + setUseOutboundExceptionHandler(true); + } + + MockShuffleHandler2(MetricsSystem ms) { + super(ms); + setUseOutboundExceptionHandler(true); + } + @Override protected Shuffle getShuffle(final Configuration conf) { return new Shuffle(conf) { @Override protected void verifyRequest(String appid, ChannelHandlerContext ctx, - HttpRequest request, HttpResponse response, URL requestUri) - throws IOException { - SocketChannel channel = (SocketChannel)(ctx.getChannel()); - socketKeepAlive = channel.getConfig().isKeepAlive(); + HttpRequest request, HttpResponse response, URL requestUri) { + SocketChannel channel = (SocketChannel)(ctx.channel()); + socketKeepAlive = channel.config().isKeepAlive(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) throws Exception { + LOG.debug("ExceptionCaught"); + failures.add(cause); + super.exceptionCaught(ctx, cause); } }; } @@ -204,6 +824,38 @@ protected boolean isSocketKeepAlive() { } } + @Rule + public TestName name = new TestName(); + + @Before + public void setup() { + TEST_EXECUTION = new TestExecution(DEBUG_MODE, USE_PROXY); + } + + @After + public void tearDown() { + int port = TEST_EXECUTION.shuffleHandlerPort(); + if (isPortUsed(port)) { + String msg = String.format("Port is being used: %d. 
" + + "Current testcase name: %s", + port, name.getMethodName()); + throw new IllegalStateException(msg); + } + } + + private static boolean isPortUsed(int port) { + if (port == 0) { + //Don't check if port is 0 + return false; + } + try (Socket ignored = new Socket("localhost", port)) { + return true; + } catch (IOException e) { + LOG.error("Port: {}, port check result: {}", port, e.getMessage()); + return false; + } + } + /** * Test the validation of ShuffleHandler's meta-data's serialization and * de-serialization. @@ -228,21 +880,23 @@ public void testSerializeMeta() throws Exception { @Test (timeout = 10000) public void testShuffleMetrics() throws Exception { MetricsSystem ms = new MetricsSystemImpl(); - ShuffleHandler sh = new ShuffleHandler(ms); + ShuffleHandler sh = new ShuffleHandlerForTests(ms); ChannelFuture cf = mock(ChannelFuture.class); when(cf.isSuccess()).thenReturn(true).thenReturn(false); sh.metrics.shuffleConnections.incr(); - sh.metrics.shuffleOutputBytes.incr(1*MiB); + sh.metrics.shuffleOutputBytes.incr(MiB); sh.metrics.shuffleConnections.incr(); sh.metrics.shuffleOutputBytes.incr(2*MiB); - checkShuffleMetrics(ms, 3*MiB, 0 , 0, 2); + checkShuffleMetrics(ms, 3*MiB, 0, 0, 2); sh.metrics.operationComplete(cf); sh.metrics.operationComplete(cf); checkShuffleMetrics(ms, 3*MiB, 1, 1, 0); + + sh.stop(); } static void checkShuffleMetrics(MetricsSystem ms, long bytes, int failed, @@ -262,57 +916,54 @@ static void checkShuffleMetrics(MetricsSystem ms, long bytes, int failed, */ @Test (timeout = 10000) public void testClientClosesConnection() throws Exception { - final ArrayList failures = new ArrayList(1); Configuration conf = new Configuration(); - conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0); - ShuffleHandler shuffleHandler = new ShuffleHandler() { + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); + ShuffleHandlerForTests shuffleHandler = new ShuffleHandlerForTests() { + @Override protected Shuffle getShuffle(Configuration conf) { // replace the shuffle handler with one stubbed for testing return new Shuffle(conf) { @Override protected MapOutputInfo getMapOutputInfo(String mapId, int reduce, - String jobId, String user) throws IOException { + String jobId, String user) { return null; } @Override protected void populateHeaders(List mapIds, String jobId, String user, int reduce, HttpRequest request, HttpResponse response, boolean keepAliveParam, - Map infoMap) throws IOException { + Map infoMap) { // Only set response headers and skip everything else // send some dummy value for content-length super.setResponseHeaders(response, keepAliveParam, 100); } @Override protected void verifyRequest(String appid, ChannelHandlerContext ctx, - HttpRequest request, HttpResponse response, URL requestUri) - throws IOException { + HttpRequest request, HttpResponse response, URL requestUri) { } @Override protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, Channel ch, String user, String mapId, int reduce, MapOutputInfo info) throws IOException { - // send a shuffle header and a lot of data down the channel - // to trigger a broken pipe ShuffleHeader header = new ShuffleHeader("attempt_12345_1_m_1_0", 5678, 5678, 1); DataOutputBuffer dob = new DataOutputBuffer(); header.write(dob); - ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength())); + ch.writeAndFlush(wrappedBuffer(dob.getData(), 0, dob.getLength())); dob = new DataOutputBuffer(); for (int i = 0; i < 100000; ++i) { header.write(dob); } - return 
ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength())); + return ch.writeAndFlush(wrappedBuffer(dob.getData(), 0, dob.getLength())); } @Override protected void sendError(ChannelHandlerContext ctx, HttpResponseStatus status) { if (failures.size() == 0) { failures.add(new Error()); - ctx.getChannel().close(); + ctx.channel().close(); } } @Override @@ -320,7 +971,7 @@ protected void sendError(ChannelHandlerContext ctx, String message, HttpResponseStatus status) { if (failures.size() == 0) { failures.add(new Error()); - ctx.getChannel().close(); + ctx.channel().close(); } } }; @@ -332,26 +983,30 @@ protected void sendError(ChannelHandlerContext ctx, String message, // simulate a reducer that closes early by reading a single shuffle header // then closing the connection URL url = new URL("http://127.0.0.1:" - + shuffleHandler.getConfig().get(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY) - + "/mapOutput?job=job_12345_1&reduce=1&map=attempt_12345_1_m_1_0"); - HttpURLConnection conn = (HttpURLConnection)url.openConnection(); + + shuffleHandler.getConfig().get(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY) + + "/mapOutput?job=job_12345_1&reduce=1&map=attempt_12345_1_m_1_0"); + HttpURLConnection conn = TEST_EXECUTION.openConnection(url); conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME, ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION, ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); conn.connect(); DataInputStream input = new DataInputStream(conn.getInputStream()); - Assert.assertEquals(HttpURLConnection.HTTP_OK, conn.getResponseCode()); - Assert.assertEquals("close", + assertEquals(HttpURLConnection.HTTP_OK, conn.getResponseCode()); + assertEquals("close", conn.getHeaderField(HttpHeader.CONNECTION.asString())); ShuffleHeader header = new ShuffleHeader(); header.readFields(input); input.close(); + assertEquals("sendError called when client closed connection", 0, + shuffleHandler.failures.size()); + assertEquals("Should have no caught exceptions", Collections.emptyList(), + shuffleHandler.failures); + shuffleHandler.stop(); - Assert.assertTrue("sendError called when client closed connection", - failures.size() == 0); } + static class LastSocketAddress { SocketAddress lastAddress; void setAddress(SocketAddress lastAddress) { @@ -363,161 +1018,180 @@ SocketAddress getSocketAddres() { } @Test(timeout = 10000) - public void testKeepAlive() throws Exception { - final ArrayList failures = new ArrayList(1); + public void testKeepAliveInitiallyEnabled() throws Exception { Configuration conf = new Configuration(); - conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0); + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); conf.setBoolean(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED, true); - // try setting to -ve keep alive timeout. 
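The LastSocketAddress holder above carries the crux of the keep-alive assertions: the server records the remote address of the channel each time it serves a request, and if HTTP keep-alive actually reused the connection, two consecutive requests observe the same address. A stripped-down restatement of the idea (names here are hypothetical, not the test's helpers):

import java.net.SocketAddress;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

final class KeepAliveProbe {
  private volatile SocketAddress last; // written from the server's I/O thread

  void record(SocketAddress remote) {
    last = remote;
  }

  SocketAddress last() {
    return last;
  }

  // Same remote address/port pair implies the second request reused the socket.
  static void assertSameSocket(SocketAddress first, SocketAddress second) {
    assertNotNull("initial shuffle address", first);
    assertNotNull("keep-alive shuffle address", second);
    assertEquals(first, second);
  }
}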
- conf.setInt(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT, -100); - final LastSocketAddress lastSocketAddress = new LastSocketAddress(); + conf.setInt(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT, + TEST_EXECUTION.getKeepAliveTimeout()); + ResponseConfig responseConfig = new ResponseConfig(HEADER_WRITE_COUNT, 0, 0); + ShuffleHandlerForKeepAliveTests shuffleHandler = new ShuffleHandlerForKeepAliveTests( + ATTEMPT_ID, responseConfig); + testKeepAliveWithHttpOk(conf, shuffleHandler, ShuffleUrlType.SIMPLE, + ShuffleUrlType.WITH_KEEPALIVE); + } - ShuffleHandler shuffleHandler = new ShuffleHandler() { - @Override - protected Shuffle getShuffle(final Configuration conf) { - // replace the shuffle handler with one stubbed for testing - return new Shuffle(conf) { + @Test(timeout = 1000000) + public void testKeepAliveInitiallyEnabledTwoKeepAliveUrls() throws Exception { + Configuration conf = new Configuration(); + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); + conf.setBoolean(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED, true); + conf.setInt(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT, + TEST_EXECUTION.getKeepAliveTimeout()); + ResponseConfig responseConfig = new ResponseConfig(HEADER_WRITE_COUNT, 0, 0); + ShuffleHandlerForKeepAliveTests shuffleHandler = new ShuffleHandlerForKeepAliveTests( + ATTEMPT_ID, responseConfig); + testKeepAliveWithHttpOk(conf, shuffleHandler, ShuffleUrlType.WITH_KEEPALIVE, + ShuffleUrlType.WITH_KEEPALIVE); + } + + //TODO snemeth implement keepalive test that used properly mocked ShuffleHandler + @Test(timeout = 10000) + public void testKeepAliveInitiallyDisabled() throws Exception { + Configuration conf = new Configuration(); + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); + conf.setBoolean(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED, false); + conf.setInt(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT, + TEST_EXECUTION.getKeepAliveTimeout()); + ResponseConfig responseConfig = new ResponseConfig(HEADER_WRITE_COUNT, 0, 0); + ShuffleHandlerForKeepAliveTests shuffleHandler = new ShuffleHandlerForKeepAliveTests( + ATTEMPT_ID, responseConfig); + testKeepAliveWithHttpOk(conf, shuffleHandler, ShuffleUrlType.WITH_KEEPALIVE, + ShuffleUrlType.WITH_KEEPALIVE); + } + + @Test(timeout = 10000) + public void testKeepAliveMultipleMapAttemptIds() throws Exception { + final int mapOutputContentLength = 11; + final int mapOutputCount = 2; + + Configuration conf = new Configuration(); + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); + conf.setBoolean(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED, true); + conf.setInt(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT, + TEST_EXECUTION.getKeepAliveTimeout()); + ResponseConfig responseConfig = new ResponseConfig(HEADER_WRITE_COUNT, + mapOutputCount, mapOutputContentLength); + ShuffleHandlerForKeepAliveTests shuffleHandler = new ShuffleHandlerForKeepAliveTests( + ATTEMPT_ID, responseConfig); + shuffleHandler.mapOutputSender.additionalMapOutputSenderOperations = + new AdditionalMapOutputSenderOperations() { @Override - protected MapOutputInfo getMapOutputInfo(String mapId, int reduce, - String jobId, String user) throws IOException { - return null; - } - @Override - protected void verifyRequest(String appid, ChannelHandlerContext ctx, - HttpRequest request, HttpResponse response, URL requestUri) - throws IOException { - } - - @Override - protected void 
populateHeaders(List mapIds, String jobId, - String user, int reduce, HttpRequest request, - HttpResponse response, boolean keepAliveParam, - Map infoMap) throws IOException { - // Send some dummy data (populate content length details) - ShuffleHeader header = - new ShuffleHeader("attempt_12345_1_m_1_0", 5678, 5678, 1); - DataOutputBuffer dob = new DataOutputBuffer(); - header.write(dob); - dob = new DataOutputBuffer(); - for (int i = 0; i < 100000; ++i) { - header.write(dob); - } - - long contentLength = dob.getLength(); - // for testing purpose; - // disable connectinKeepAliveEnabled if keepAliveParam is available - if (keepAliveParam) { - connectionKeepAliveEnabled = false; - } - - super.setResponseHeaders(response, keepAliveParam, contentLength); - } - - @Override - protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, - Channel ch, String user, String mapId, int reduce, - MapOutputInfo info) throws IOException { - lastSocketAddress.setAddress(ch.getRemoteAddress()); - HttpResponse response = new DefaultHttpResponse(HTTP_1_1, OK); - - // send a shuffle header and a lot of data down the channel - // to trigger a broken pipe - ShuffleHeader header = - new ShuffleHeader("attempt_12345_1_m_1_0", 5678, 5678, 1); - DataOutputBuffer dob = new DataOutputBuffer(); - header.write(dob); - ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength())); - dob = new DataOutputBuffer(); - for (int i = 0; i < 100000; ++i) { - header.write(dob); - } - return ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength())); - } - - @Override - protected void sendError(ChannelHandlerContext ctx, - HttpResponseStatus status) { - if (failures.size() == 0) { - failures.add(new Error()); - ctx.getChannel().close(); - } - } - - @Override - protected void sendError(ChannelHandlerContext ctx, String message, - HttpResponseStatus status) { - if (failures.size() == 0) { - failures.add(new Error()); - ctx.getChannel().close(); - } + public ChannelFuture perform(ChannelHandlerContext ctx, Channel ch) throws IOException { + File tmpFile = File.createTempFile("test", ".tmp"); + Files.write(tmpFile.toPath(), + "dummytestcontent123456".getBytes(StandardCharsets.UTF_8)); + final DefaultFileRegion partition = new DefaultFileRegion(tmpFile, 0, + mapOutputContentLength); + LOG.debug("Writing response partition: {}, channel: {}", + partition, ch.id()); + return ch.writeAndFlush(partition) + .addListener((ChannelFutureListener) future -> + LOG.debug("Finished Writing response partition: {}, channel: " + + "{}", partition, ch.id())); } }; - } - }; + testKeepAliveWithHttpOk(conf, shuffleHandler, + ShuffleUrlType.WITH_KEEPALIVE_MULTIPLE_MAP_IDS, + ShuffleUrlType.WITH_KEEPALIVE_MULTIPLE_MAP_IDS); + } + + @Test(timeout = 10000) + public void testKeepAliveWithoutMapAttemptIds() throws Exception { + Configuration conf = new Configuration(); + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); + conf.setBoolean(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED, true); + conf.setInt(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT, + TEST_EXECUTION.getKeepAliveTimeout()); + ResponseConfig responseConfig = new ResponseConfig(HEADER_WRITE_COUNT, 0, 0); + ShuffleHandlerForKeepAliveTests shuffleHandler = new ShuffleHandlerForKeepAliveTests( + ATTEMPT_ID, responseConfig); + shuffleHandler.setFailImmediatelyOnErrors(true); + //Closing channels caused Netty to open another channel + // so 1 request was handled with 2 separate channels, + // ultimately generating 2 * HTTP 400 errors. 
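The AdditionalMapOutputSenderOperations override above exercises Netty 4's FileRegion path: DefaultFileRegion hands the file slice to FileChannel.transferTo, the zero-copy branch the handler takes when no SslHandler is in the pipeline (with TLS it falls back to chunked reads instead, as in the main ShuffleHandler hunk). A minimal standalone sketch of serving a file slice this way, assuming Netty 4.1 (names and cleanup policy are illustrative):

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;

import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.DefaultFileRegion;

final class FileRegionSketch {
  private FileRegionSketch() {
  }

  static ChannelFuture sendSlice(Channel ch, byte[] content) throws IOException {
    File tmp = File.createTempFile("partition", ".tmp");
    Files.write(tmp.toPath(), content);
    // The region owns a FileChannel; Netty deallocates it after the transfer.
    DefaultFileRegion region = new DefaultFileRegion(tmp, 0, content.length);
    return ch.writeAndFlush(region)
        .addListener((ChannelFutureListener) future -> tmp.delete());
  }
}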
+ // We'd like to avoid this so disabling closing the channel here. + shuffleHandler.setCloseChannelOnError(false); + testKeepAliveWithHttpBadRequest(conf, shuffleHandler, ShuffleUrlType.WITH_KEEPALIVE_NO_MAP_IDS); + } + + private void testKeepAliveWithHttpOk( + Configuration conf, + ShuffleHandlerForKeepAliveTests shuffleHandler, + ShuffleUrlType... shuffleUrlTypes) throws IOException { + testKeepAliveWithHttpStatus(conf, shuffleHandler, shuffleUrlTypes, HttpURLConnection.HTTP_OK); + } + + private void testKeepAliveWithHttpBadRequest( + Configuration conf, + ShuffleHandlerForKeepAliveTests shuffleHandler, + ShuffleUrlType... shuffleUrlTypes) throws IOException { + testKeepAliveWithHttpStatus(conf, shuffleHandler, shuffleUrlTypes, + HttpURLConnection.HTTP_BAD_REQUEST); + } + + private void testKeepAliveWithHttpStatus(Configuration conf, + ShuffleHandlerForKeepAliveTests shuffleHandler, + ShuffleUrlType[] shuffleUrlTypes, + int expectedHttpStatus) throws IOException { + if (expectedHttpStatus != HttpURLConnection.HTTP_BAD_REQUEST) { + assertTrue("Expected at least two shuffle URL types ", + shuffleUrlTypes.length >= 2); + } shuffleHandler.init(conf); shuffleHandler.start(); - String shuffleBaseURL = "http://127.0.0.1:" - + shuffleHandler.getConfig().get( - ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY); - URL url = - new URL(shuffleBaseURL + "/mapOutput?job=job_12345_1&reduce=1&" - + "map=attempt_12345_1_m_1_0"); - HttpURLConnection conn = (HttpURLConnection) url.openConnection(); - conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME, - ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); - conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION, - ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); - conn.connect(); - DataInputStream input = new DataInputStream(conn.getInputStream()); - Assert.assertEquals(HttpHeader.KEEP_ALIVE.asString(), - conn.getHeaderField(HttpHeader.CONNECTION.asString())); - Assert.assertEquals("timeout=1", - conn.getHeaderField(HttpHeader.KEEP_ALIVE.asString())); - Assert.assertEquals(HttpURLConnection.HTTP_OK, conn.getResponseCode()); - ShuffleHeader header = new ShuffleHeader(); - header.readFields(input); - byte[] buffer = new byte[1024]; - while (input.read(buffer) != -1) {} - SocketAddress firstAddress = lastSocketAddress.getSocketAddres(); - input.close(); + String[] urls = new String[shuffleUrlTypes.length]; + for (int i = 0; i < shuffleUrlTypes.length; i++) { + ShuffleUrlType url = shuffleUrlTypes[i]; + if (url == ShuffleUrlType.SIMPLE) { + urls[i] = getShuffleUrl(shuffleHandler, ATTEMPT_ID, ATTEMPT_ID); + } else if (url == ShuffleUrlType.WITH_KEEPALIVE) { + urls[i] = getShuffleUrlWithKeepAlive(shuffleHandler, ATTEMPT_ID, ATTEMPT_ID); + } else if (url == ShuffleUrlType.WITH_KEEPALIVE_MULTIPLE_MAP_IDS) { + urls[i] = getShuffleUrlWithKeepAlive(shuffleHandler, ATTEMPT_ID, ATTEMPT_ID, ATTEMPT_ID_2); + } else if (url == ShuffleUrlType.WITH_KEEPALIVE_NO_MAP_IDS) { + urls[i] = getShuffleUrlWithKeepAlive(shuffleHandler, ATTEMPT_ID); + } + } + HttpConnectionHelper connHelper; + try { + connHelper = new HttpConnectionHelper(shuffleHandler.lastSocketAddress); + connHelper.connectToUrls(urls, shuffleHandler.responseConfig, expectedHttpStatus); + if (expectedHttpStatus == HttpURLConnection.HTTP_BAD_REQUEST) { + assertEquals(1, shuffleHandler.failures.size()); + assertThat(shuffleHandler.failures.get(0).getMessage(), + CoreMatchers.containsString("Status: 400 Bad Request, " + + "message: Required param job, map and reduce")); + } + } finally { + shuffleHandler.stop(); + } - // For keepAlive 
via URL - url = - new URL(shuffleBaseURL + "/mapOutput?job=job_12345_1&reduce=1&" - + "map=attempt_12345_1_m_1_0&keepAlive=true"); - conn = (HttpURLConnection) url.openConnection(); - conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME, - ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); - conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION, - ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); - conn.connect(); - input = new DataInputStream(conn.getInputStream()); - Assert.assertEquals(HttpHeader.KEEP_ALIVE.asString(), - conn.getHeaderField(HttpHeader.CONNECTION.asString())); - Assert.assertEquals("timeout=1", - conn.getHeaderField(HttpHeader.KEEP_ALIVE.asString())); - Assert.assertEquals(HttpURLConnection.HTTP_OK, conn.getResponseCode()); - header = new ShuffleHeader(); - header.readFields(input); - input.close(); - SocketAddress secondAddress = lastSocketAddress.getSocketAddres(); - Assert.assertNotNull("Initial shuffle address should not be null", - firstAddress); - Assert.assertNotNull("Keep-Alive shuffle address should not be null", - secondAddress); - Assert.assertEquals("Initial shuffle address and keep-alive shuffle " - + "address should be the same", firstAddress, secondAddress); + //Verify expectations + int configuredTimeout = TEST_EXECUTION.getKeepAliveTimeout(); + int expectedTimeout = configuredTimeout < 0 ? 1 : configuredTimeout; + connHelper.validate(connData -> { + HttpConnectionAssert.create(connData) + .expectKeepAliveWithTimeout(expectedTimeout) + .expectResponseContentLength(shuffleHandler.responseConfig.contentLengthOfResponse); + }); + if (expectedHttpStatus == HttpURLConnection.HTTP_OK) { + HttpConnectionAssert.assertKeepAliveConnectionsAreSame(connHelper); + assertEquals("Unexpected ShuffleHandler failure", Collections.emptyList(), + shuffleHandler.failures); + } } @Test(timeout = 10000) public void testSocketKeepAlive() throws Exception { Configuration conf = new Configuration(); - conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0); + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); conf.setBoolean(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED, true); - // try setting to -ve keep alive timeout. - conf.setInt(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT, -100); + // try setting to negative keep alive timeout. 
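testSocketKeepAlive, continuing below, checks TCP-level keep-alive rather than HTTP keep-alive: MockShuffleHandler2 above reads the flag back in the Netty 4 style via channel.config().isKeepAlive(). On the server the option is normally installed on accepted sockets through the bootstrap's child options; a sketch under that assumption (Netty 4.1, class name illustrative):

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;

final class KeepAliveBootstrapSketch {
  static ServerBootstrap build() {
    return new ServerBootstrap()
        .group(new NioEventLoopGroup(1), new NioEventLoopGroup())
        .channel(NioServerSocketChannel.class)
        // childOption applies to every accepted connection; this is what
        // channel.config().isKeepAlive() later reports inside a handler.
        .childOption(ChannelOption.SO_KEEPALIVE, true)
        .childHandler(new ChannelInitializer<SocketChannel>() {
          @Override
          protected void initChannel(SocketChannel ch) {
            // HTTP codec and shuffle handlers would be added here
          }
        });
  }
}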
+ conf.setInt(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT, + ARBITRARY_NEGATIVE_TIMEOUT_SECONDS); HttpURLConnection conn = null; MockShuffleHandler2 shuffleHandler = new MockShuffleHandler2(); AuxiliaryLocalPathHandler pathHandler = @@ -535,14 +1209,16 @@ public void testSocketKeepAlive() throws Exception { URL url = new URL(shuffleBaseURL + "/mapOutput?job=job_12345_1&reduce=1&" + "map=attempt_12345_1_m_1_0"); - conn = (HttpURLConnection) url.openConnection(); + conn = TEST_EXECUTION.openConnection(url); conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME, ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION, ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); conn.connect(); + int rc = conn.getResponseCode(); conn.getInputStream(); - Assert.assertTrue("socket should be set KEEP_ALIVE", + assertEquals(HttpURLConnection.HTTP_OK, rc); + assertTrue("socket should be set KEEP_ALIVE", shuffleHandler.isSocketKeepAlive()); } finally { if (conn != null) { @@ -550,11 +1226,13 @@ public void testSocketKeepAlive() throws Exception { } shuffleHandler.stop(); } + assertEquals("Should have no caught exceptions", + Collections.emptyList(), shuffleHandler.failures); } /** - * simulate a reducer that sends an invalid shuffle-header - sometimes a wrong - * header_name and sometimes a wrong version + * Simulate a reducer that sends an invalid shuffle-header - sometimes a wrong + * header_name and sometimes a wrong version. * * @throws Exception exception */ @@ -562,24 +1240,24 @@ public void testSocketKeepAlive() throws Exception { public void testIncompatibleShuffleVersion() throws Exception { final int failureNum = 3; Configuration conf = new Configuration(); - conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0); - ShuffleHandler shuffleHandler = new ShuffleHandler(); + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); + ShuffleHandler shuffleHandler = new ShuffleHandlerForTests(); shuffleHandler.init(conf); shuffleHandler.start(); // simulate a reducer that closes early by reading a single shuffle header // then closing the connection URL url = new URL("http://127.0.0.1:" - + shuffleHandler.getConfig().get(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY) - + "/mapOutput?job=job_12345_1&reduce=1&map=attempt_12345_1_m_1_0"); + + shuffleHandler.getConfig().get(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY) + + "/mapOutput?job=job_12345_1&reduce=1&map=attempt_12345_1_m_1_0"); for (int i = 0; i < failureNum; ++i) { - HttpURLConnection conn = (HttpURLConnection)url.openConnection(); + HttpURLConnection conn = TEST_EXECUTION.openConnection(url); conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME, i == 0 ? "mapreduce" : "other"); conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION, i == 1 ? 
"1.0.0" : "1.0.1"); conn.connect(); - Assert.assertEquals( + assertEquals( HttpURLConnection.HTTP_BAD_REQUEST, conn.getResponseCode()); } @@ -594,10 +1272,14 @@ public void testIncompatibleShuffleVersion() throws Exception { */ @Test (timeout = 10000) public void testMaxConnections() throws Exception { + final ArrayList failures = new ArrayList<>(); + final int maxAllowedConnections = 3; + final int notAcceptedConnections = 1; + final int connAttempts = maxAllowedConnections + notAcceptedConnections; Configuration conf = new Configuration(); - conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0); - conf.setInt(ShuffleHandler.MAX_SHUFFLE_CONNECTIONS, 3); + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); + conf.setInt(ShuffleHandler.MAX_SHUFFLE_CONNECTIONS, maxAllowedConnections); ShuffleHandler shuffleHandler = new ShuffleHandler() { @Override protected Shuffle getShuffle(Configuration conf) { @@ -605,7 +1287,7 @@ protected Shuffle getShuffle(Configuration conf) { return new Shuffle(conf) { @Override protected MapOutputInfo getMapOutputInfo(String mapId, int reduce, - String jobId, String user) throws IOException { + String jobId, String user) { // Do nothing. return null; } @@ -613,13 +1295,12 @@ protected MapOutputInfo getMapOutputInfo(String mapId, int reduce, protected void populateHeaders(List mapIds, String jobId, String user, int reduce, HttpRequest request, HttpResponse response, boolean keepAliveParam, - Map infoMap) throws IOException { + Map infoMap) { // Do nothing. } @Override protected void verifyRequest(String appid, ChannelHandlerContext ctx, - HttpRequest request, HttpResponse response, URL requestUri) - throws IOException { + HttpRequest request, HttpResponse response, URL requestUri) { // Do nothing. 
} @Override @@ -633,30 +1314,38 @@ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, new ShuffleHeader("dummy_header", 5678, 5678, 1); DataOutputBuffer dob = new DataOutputBuffer(); header.write(dob); - ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength())); + ch.writeAndFlush(wrappedBuffer(dob.getData(), 0, dob.getLength())); dob = new DataOutputBuffer(); for (int i=0; i<100000; ++i) { header.write(dob); } - return ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength())); + return ch.writeAndFlush(wrappedBuffer(dob.getData(), 0, dob.getLength())); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) throws Exception { + LOG.debug("ExceptionCaught"); + failures.add(cause); + super.exceptionCaught(ctx, cause); } }; } }; + shuffleHandler.setUseOutboundExceptionHandler(true); shuffleHandler.init(conf); shuffleHandler.start(); // setup connections - int connAttempts = 3; - HttpURLConnection conns[] = new HttpURLConnection[connAttempts]; + HttpURLConnection[] conns = new HttpURLConnection[connAttempts]; for (int i = 0; i < connAttempts; i++) { - String URLstring = "http://127.0.0.1:" + String urlString = "http://127.0.0.1:" + shuffleHandler.getConfig().get(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY) + "/mapOutput?job=job_12345_1&reduce=1&map=attempt_12345_1_m_" + i + "_0"; - URL url = new URL(URLstring); - conns[i] = (HttpURLConnection)url.openConnection(); + URL url = new URL(urlString); + conns[i] = TEST_EXECUTION.openConnection(url); conns[i].setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME, ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); conns[i].setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION, @@ -668,34 +1357,61 @@ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, conns[i].connect(); } - //Ensure first connections are okay - conns[0].getInputStream(); - int rc = conns[0].getResponseCode(); - Assert.assertEquals(HttpURLConnection.HTTP_OK, rc); - - conns[1].getInputStream(); - rc = conns[1].getResponseCode(); - Assert.assertEquals(HttpURLConnection.HTTP_OK, rc); - - // This connection should be closed because it to above the limit - try { - rc = conns[2].getResponseCode(); - Assert.assertEquals("Expected a too-many-requests response code", - ShuffleHandler.TOO_MANY_REQ_STATUS.getCode(), rc); - long backoff = Long.valueOf( - conns[2].getHeaderField(ShuffleHandler.RETRY_AFTER_HEADER)); - Assert.assertTrue("The backoff value cannot be negative.", backoff > 0); - conns[2].getInputStream(); - Assert.fail("Expected an IOException"); - } catch (IOException ioe) { - LOG.info("Expected - connection should not be open"); - } catch (NumberFormatException ne) { - Assert.fail("Expected a numerical value for RETRY_AFTER header field"); - } catch (Exception e) { - Assert.fail("Expected a IOException"); + Map> mapOfConnections = Maps.newHashMap(); + for (HttpURLConnection conn : conns) { + try { + conn.getInputStream(); + } catch (IOException ioe) { + LOG.info("Expected - connection should not be open"); + } catch (NumberFormatException ne) { + fail("Expected a numerical value for RETRY_AFTER header field"); + } catch (Exception e) { + fail("Expected a IOException"); + } + int statusCode = conn.getResponseCode(); + LOG.debug("Connection status code: {}", statusCode); + mapOfConnections.putIfAbsent(statusCode, new ArrayList<>()); + List connectionList = mapOfConnections.get(statusCode); + connectionList.add(conn); } + + assertEquals(String.format("Expected only %s and %s response", + OK_STATUS, ShuffleHandler.TOO_MANY_REQ_STATUS), + 
Sets.newHashSet( + HttpURLConnection.HTTP_OK, + ShuffleHandler.TOO_MANY_REQ_STATUS.code()), + mapOfConnections.keySet()); - shuffleHandler.stop(); + List successfulConnections = + mapOfConnections.get(HttpURLConnection.HTTP_OK); + assertEquals(String.format("Expected exactly %d requests " + + "with %s response", maxAllowedConnections, OK_STATUS), + maxAllowedConnections, successfulConnections.size()); + + //Ensure exactly one connection is HTTP 429 (TOO MANY REQUESTS) + List closedConnections = + mapOfConnections.get(ShuffleHandler.TOO_MANY_REQ_STATUS.code()); + assertEquals(String.format("Expected exactly %d %s response", + notAcceptedConnections, ShuffleHandler.TOO_MANY_REQ_STATUS), + notAcceptedConnections, closedConnections.size()); + + // This connection should be closed because it is above the maximum limit + HttpURLConnection conn = closedConnections.get(0); + assertEquals(String.format("Expected a %s response", + ShuffleHandler.TOO_MANY_REQ_STATUS), + ShuffleHandler.TOO_MANY_REQ_STATUS.code(), conn.getResponseCode()); + long backoff = Long.parseLong( + conn.getHeaderField(ShuffleHandler.RETRY_AFTER_HEADER)); + assertTrue("The backoff value cannot be negative.", backoff > 0); + + shuffleHandler.stop(); + + //It's okay to get a ClosedChannelException. + //All other kinds of exceptions means something went wrong + assertEquals("Should have no caught exceptions", + Collections.emptyList(), failures.stream() + .filter(f -> !(f instanceof ClosedChannelException)) + .collect(toList())); } /** @@ -706,10 +1422,11 @@ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, */ @Test(timeout = 100000) public void testMapFileAccess() throws IOException { + final ArrayList failures = new ArrayList<>(); // This will run only in NativeIO is enabled as SecureIOUtils need it assumeTrue(NativeIO.isAvailable()); Configuration conf = new Configuration(); - conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0); + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); conf.setInt(ShuffleHandler.MAX_SHUFFLE_CONNECTIONS, 3); conf.set(CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION, "kerberos"); @@ -720,7 +1437,7 @@ public void testMapFileAccess() throws IOException { String appAttemptId = "attempt_12345_1_m_1_0"; String user = "randomUser"; String reducerId = "0"; - List fileMap = new ArrayList(); + List fileMap = new ArrayList<>(); createShuffleHandlerFiles(ABS_LOG_DIR, user, appId.toString(), appAttemptId, conf, fileMap); ShuffleHandler shuffleHandler = new ShuffleHandler() { @@ -731,15 +1448,31 @@ protected Shuffle getShuffle(Configuration conf) { @Override protected void verifyRequest(String appid, ChannelHandlerContext ctx, - HttpRequest request, HttpResponse response, URL requestUri) - throws IOException { + HttpRequest request, HttpResponse response, URL requestUri) { // Do nothing. 
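The channelActive override just below swaps the stock HttpResponseEncoder for the patch's LoggingHttpResponseEncoder while the channel is already live. Netty 4's ChannelPipeline.replace can address the outgoing handler by class, so the test does not need to know the name the encoder was registered under. A sketch of the call shape, assuming Netty 4.1 (the replacement here is a plain encoder; the patch installs its logging subclass):

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.HttpResponseEncoder;

final class PipelineSwapSketch {
  // replace(oldHandlerClass, newName, newHandler) substitutes the first
  // handler of the given type in-place and keeps the rest of the pipeline.
  static void swapEncoder(ChannelHandlerContext ctx) {
    ctx.pipeline().replace(HttpResponseEncoder.class, "loggingResponseEncoder",
        new HttpResponseEncoder());
  }
}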
} + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) throws Exception { + LOG.debug("ExceptionCaught"); + failures.add(cause); + super.exceptionCaught(ctx, cause); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.pipeline().replace(HttpResponseEncoder.class, + "loggingResponseEncoder", + new LoggingHttpResponseEncoder(false)); + LOG.debug("Modified pipeline: {}", ctx.pipeline()); + super.channelActive(ctx); + } }; } }; AuxiliaryLocalPathHandler pathHandler = new TestAuxiliaryLocalPathHandler(); + shuffleHandler.setUseOutboundExceptionHandler(true); shuffleHandler.setAuxiliaryLocalPathHandler(pathHandler); shuffleHandler.init(conf); try { @@ -747,13 +1480,13 @@ protected void verifyRequest(String appid, ChannelHandlerContext ctx, DataOutputBuffer outputBuffer = new DataOutputBuffer(); outputBuffer.reset(); Token<JobTokenIdentifier> jt = - new Token<JobTokenIdentifier>("identifier".getBytes(), + new Token<>("identifier".getBytes(), "password".getBytes(), new Text(user), new Text("shuffleService")); jt.write(outputBuffer); shuffleHandler - .initializeApplication(new ApplicationInitializationContext(user, - appId, ByteBuffer.wrap(outputBuffer.getData(), 0, - outputBuffer.getLength()))); + .initializeApplication(new ApplicationInitializationContext(user, + appId, ByteBuffer.wrap(outputBuffer.getData(), 0, + outputBuffer.getLength()))); URL url = new URL( "http://127.0.0.1:" @@ -761,32 +1494,37 @@ protected void verifyRequest(String appid, ChannelHandlerContext ctx, ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY) + "/mapOutput?job=job_12345_0001&reduce=" + reducerId + "&map=attempt_12345_1_m_1_0"); - HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + HttpURLConnection conn = TEST_EXECUTION.openConnection(url); conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME, ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION, ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); conn.connect(); - byte[] byteArr = new byte[10000]; - try { - DataInputStream is = new DataInputStream(conn.getInputStream()); - is.readFully(byteArr); - } catch (EOFException e) { - // ignore - } - // Retrieve file owner name - FileInputStream is = new FileInputStream(fileMap.get(0)); - String owner = NativeIO.POSIX.getFstat(is.getFD()).getOwner(); - is.close(); + DataInputStream is = new DataInputStream(conn.getInputStream()); + InputStreamReadResult result = HttpConnectionHelper.readDataFromInputStream(is); + String receivedString = result.asString; + + //Retrieve file owner name + FileInputStream fis = new FileInputStream(fileMap.get(0)); + String owner = NativeIO.POSIX.getFstat(fis.getFD()).getOwner(); + fis.close(); String message = "Owner '" + owner + "' for path " + fileMap.get(0).getAbsolutePath() + " did not match expected owner '" + user + "'"; - Assert.assertTrue((new String(byteArr)).contains(message)); + assertTrue(String.format("Received string '%s' should contain " + + "message '%s'", receivedString, message), + receivedString.contains(message)); + assertEquals(HttpURLConnection.HTTP_OK, conn.getResponseCode()); + LOG.info("received: " + receivedString); + assertNotEquals("", receivedString); } finally { shuffleHandler.stop(); FileUtil.fullyDelete(ABS_LOG_DIR); } + + assertEquals("Should have no caught exceptions", + Collections.emptyList(), failures); } private static void createShuffleHandlerFiles(File logDir, String user, @@ -794,7 +1532,7 @@ private static void createShuffleHandlerFiles(File logDir, String user, List<File>
fileMap) throws IOException { String attemptDir = StringUtils.join(Path.SEPARATOR, - new String[] { logDir.getAbsolutePath(), + new String[] {logDir.getAbsolutePath(), ContainerLocalizer.USERCACHE, user, ContainerLocalizer.APPCACHE, appId, "output", appAttemptId }); File appAttemptDir = new File(attemptDir); @@ -808,8 +1546,7 @@ private static void createShuffleHandlerFiles(File logDir, String user, createMapOutputFile(mapOutputFile, conf); } - private static void - createMapOutputFile(File mapOutputFile, Configuration conf) + private static void createMapOutputFile(File mapOutputFile, Configuration conf) throws IOException { FileOutputStream out = new FileOutputStream(mapOutputFile); out.write("Creating new dummy map output file. Used only for testing" @@ -846,11 +1583,11 @@ public void testRecovery() throws IOException { final File tmpDir = new File(System.getProperty("test.build.data", System.getProperty("java.io.tmpdir")), TestShuffleHandler.class.getName()); - ShuffleHandler shuffle = new ShuffleHandler(); + ShuffleHandler shuffle = new ShuffleHandlerForTests(); AuxiliaryLocalPathHandler pathHandler = new TestAuxiliaryLocalPathHandler(); shuffle.setAuxiliaryLocalPathHandler(pathHandler); Configuration conf = new Configuration(); - conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0); + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); conf.setInt(ShuffleHandler.MAX_SHUFFLE_CONNECTIONS, 3); conf.set(YarnConfiguration.NM_LOCAL_DIRS, ABS_LOG_DIR.getAbsolutePath()); @@ -861,10 +1598,10 @@ public void testRecovery() throws IOException { shuffle.init(conf); shuffle.start(); - // setup a shuffle token for an application + // set up a shuffle token for an application DataOutputBuffer outputBuffer = new DataOutputBuffer(); outputBuffer.reset(); - Token<JobTokenIdentifier> jt = new Token<JobTokenIdentifier>( + Token<JobTokenIdentifier> jt = new Token<>( "identifier".getBytes(), "password".getBytes(), new Text(user), new Text("shuffleService")); jt.write(outputBuffer); @@ -874,11 +1611,11 @@ public void testRecovery() throws IOException { // verify we are authorized to shuffle int rc = getShuffleResponseCode(shuffle, jt); - Assert.assertEquals(HttpURLConnection.HTTP_OK, rc); + assertEquals(HttpURLConnection.HTTP_OK, rc); // emulate shuffle handler restart shuffle.close(); - shuffle = new ShuffleHandler(); + shuffle = new ShuffleHandlerForTests(); shuffle.setAuxiliaryLocalPathHandler(pathHandler); shuffle.setRecoveryPath(new Path(tmpDir.toString())); shuffle.init(conf); @@ -886,23 +1623,23 @@ public void testRecovery() throws IOException { // verify we are still authorized to shuffle to the old application rc = getShuffleResponseCode(shuffle, jt); - Assert.assertEquals(HttpURLConnection.HTTP_OK, rc); + assertEquals(HttpURLConnection.HTTP_OK, rc); // shutdown app and verify access is lost shuffle.stopApplication(new ApplicationTerminationContext(appId)); rc = getShuffleResponseCode(shuffle, jt); - Assert.assertEquals(HttpURLConnection.HTTP_UNAUTHORIZED, rc); + assertEquals(HttpURLConnection.HTTP_UNAUTHORIZED, rc); // emulate shuffle handler restart shuffle.close(); - shuffle = new ShuffleHandler(); + shuffle = new ShuffleHandlerForTests(); shuffle.setRecoveryPath(new Path(tmpDir.toString())); shuffle.init(conf); shuffle.start(); // verify we still don't have access rc = getShuffleResponseCode(shuffle, jt); - Assert.assertEquals(HttpURLConnection.HTTP_UNAUTHORIZED, rc); + assertEquals(HttpURLConnection.HTTP_UNAUTHORIZED, rc); } finally { if (shuffle != null) { shuffle.close(); @@ -919,9 +1656,9 @@ public void
testRecoveryFromOtherVersions() throws IOException { System.getProperty("java.io.tmpdir")), TestShuffleHandler.class.getName()); Configuration conf = new Configuration(); - conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0); + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); conf.setInt(ShuffleHandler.MAX_SHUFFLE_CONNECTIONS, 3); - ShuffleHandler shuffle = new ShuffleHandler(); + ShuffleHandler shuffle = new ShuffleHandlerForTests(); AuxiliaryLocalPathHandler pathHandler = new TestAuxiliaryLocalPathHandler(); shuffle.setAuxiliaryLocalPathHandler(pathHandler); conf.set(YarnConfiguration.NM_LOCAL_DIRS, ABS_LOG_DIR.getAbsolutePath()); @@ -932,10 +1669,10 @@ public void testRecoveryFromOtherVersions() throws IOException { shuffle.init(conf); shuffle.start(); - // setup a shuffle token for an application + // set up a shuffle token for an application DataOutputBuffer outputBuffer = new DataOutputBuffer(); outputBuffer.reset(); - Token<JobTokenIdentifier> jt = new Token<JobTokenIdentifier>( + Token<JobTokenIdentifier> jt = new Token<>( "identifier".getBytes(), "password".getBytes(), new Text(user), new Text("shuffleService")); jt.write(outputBuffer); @@ -945,11 +1682,11 @@ public void testRecoveryFromOtherVersions() throws IOException { // verify we are authorized to shuffle int rc = getShuffleResponseCode(shuffle, jt); - Assert.assertEquals(HttpURLConnection.HTTP_OK, rc); + assertEquals(HttpURLConnection.HTTP_OK, rc); // emulate shuffle handler restart shuffle.close(); - shuffle = new ShuffleHandler(); + shuffle = new ShuffleHandlerForTests(); shuffle.setAuxiliaryLocalPathHandler(pathHandler); shuffle.setRecoveryPath(new Path(tmpDir.toString())); shuffle.init(conf); @@ -957,44 +1694,44 @@ public void testRecoveryFromOtherVersions() throws IOException { // verify we are still authorized to shuffle to the old application rc = getShuffleResponseCode(shuffle, jt); - Assert.assertEquals(HttpURLConnection.HTTP_OK, rc); + assertEquals(HttpURLConnection.HTTP_OK, rc); Version version = Version.newInstance(1, 0); - Assert.assertEquals(version, shuffle.getCurrentVersion()); + assertEquals(version, shuffle.getCurrentVersion()); // emulate shuffle handler restart with compatible version Version version11 = Version.newInstance(1, 1); // update version info before closing the shuffle handler shuffle.storeVersion(version11); - Assert.assertEquals(version11, shuffle.loadVersion()); + assertEquals(version11, shuffle.loadVersion()); shuffle.close(); - shuffle = new ShuffleHandler(); + shuffle = new ShuffleHandlerForTests(); shuffle.setAuxiliaryLocalPathHandler(pathHandler); shuffle.setRecoveryPath(new Path(tmpDir.toString())); shuffle.init(conf); shuffle.start(); // shuffle version will be overridden by CURRENT_VERSION_INFO after restart // successfully.
- Assert.assertEquals(version, shuffle.loadVersion()); + assertEquals(version, shuffle.loadVersion()); // verify we are still authorized to shuffle to the old application rc = getShuffleResponseCode(shuffle, jt); - Assert.assertEquals(HttpURLConnection.HTTP_OK, rc); + assertEquals(HttpURLConnection.HTTP_OK, rc); // emulate shuffle handler restart with incompatible version Version version21 = Version.newInstance(2, 1); shuffle.storeVersion(version21); - Assert.assertEquals(version21, shuffle.loadVersion()); + assertEquals(version21, shuffle.loadVersion()); shuffle.close(); - shuffle = new ShuffleHandler(); + shuffle = new ShuffleHandlerForTests(); shuffle.setAuxiliaryLocalPathHandler(pathHandler); shuffle.setRecoveryPath(new Path(tmpDir.toString())); shuffle.init(conf); try { shuffle.start(); - Assert.fail("Incompatible version, should expect fail here."); + fail("Incompatible version, should have failed here."); } catch (ServiceStateException e) { - Assert.assertTrue("Exception message mismatch", - e.getMessage().contains("Incompatible version for state DB schema:")); + assertTrue("Exception message mismatch", + e.getMessage().contains("Incompatible version for state DB schema:")); } } finally { @@ -1010,7 +1747,7 @@ private static int getShuffleResponseCode(ShuffleHandler shuffle, URL url = new URL("http://127.0.0.1:" + shuffle.getConfig().get(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY) + "/mapOutput?job=job_12345_0001&reduce=0&map=attempt_12345_1_m_1_0"); - HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + HttpURLConnection conn = TEST_EXECUTION.openConnection(url); String encHash = SecureShuffleUtils.hashFromString( SecureShuffleUtils.buildMsgFrom(url), JobTokenSecretManager.createSecretKey(jt.getPassword())); @@ -1028,9 +1765,9 @@ private static int getShuffleResponseCode(ShuffleHandler shuffle, @Test(timeout = 100000) public void testGetMapOutputInfo() throws Exception { - final ArrayList<Throwable> failures = new ArrayList<Throwable>(1); + final ArrayList<Throwable> failures = new ArrayList<>(1); Configuration conf = new Configuration(); - conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0); + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); conf.setInt(ShuffleHandler.MAX_SHUFFLE_CONNECTIONS, 3); conf.set(CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION, "simple"); @@ -1040,7 +1777,7 @@ public void testGetMapOutputInfo() throws Exception { String appAttemptId = "attempt_12345_1_m_1_0"; String user = "randomUser"; String reducerId = "0"; - List<File> fileMap = new ArrayList<File>(); + List<File> fileMap = new ArrayList<>(); createShuffleHandlerFiles(ABS_LOG_DIR, user, appId.toString(), appAttemptId, conf, fileMap); AuxiliaryLocalPathHandler pathHandler = new TestAuxiliaryLocalPathHandler(); @@ -1062,7 +1799,7 @@ protected void populateHeaders(List<String> mapIds, @Override protected void verifyRequest(String appid, ChannelHandlerContext ctx, HttpRequest request, - HttpResponse response, URL requestUri) throws IOException { + HttpResponse response, URL requestUri) { // Do nothing.
} @Override @@ -1070,7 +1807,7 @@ protected void sendError(ChannelHandlerContext ctx, String message, HttpResponseStatus status) { if (failures.size() == 0) { failures.add(new Error(message)); - ctx.getChannel().close(); + ctx.channel().close(); } } @Override @@ -1082,11 +1819,12 @@ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, new ShuffleHeader("attempt_12345_1_m_1_0", 5678, 5678, 1); DataOutputBuffer dob = new DataOutputBuffer(); header.write(dob); - return ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength())); + return ch.writeAndFlush(wrappedBuffer(dob.getData(), 0, dob.getLength())); } }; } }; + shuffleHandler.setUseOutboundExceptionHandler(true); shuffleHandler.setAuxiliaryLocalPathHandler(pathHandler); shuffleHandler.init(conf); try { @@ -1094,8 +1832,8 @@ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, DataOutputBuffer outputBuffer = new DataOutputBuffer(); outputBuffer.reset(); Token<JobTokenIdentifier> jt = - new Token<JobTokenIdentifier>("identifier".getBytes(), - "password".getBytes(), new Text(user), new Text("shuffleService")); + new Token<>("identifier".getBytes(), + "password".getBytes(), new Text(user), new Text("shuffleService")); jt.write(outputBuffer); shuffleHandler .initializeApplication(new ApplicationInitializationContext(user, @@ -1108,7 +1846,7 @@ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY) + "/mapOutput?job=job_12345_0001&reduce=" + reducerId + "&map=attempt_12345_1_m_1_0"); - HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + HttpURLConnection conn = TEST_EXECUTION.openConnection(url); conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME, ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION, @@ -1122,7 +1860,7 @@ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, } catch (EOFException e) { // ignore } - Assert.assertEquals("sendError called due to shuffle error", + assertEquals("sendError called due to shuffle error", 0, failures.size()); } finally { shuffleHandler.stop(); @@ -1133,11 +1871,10 @@ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, @Test(timeout = 4000) public void testSendMapCount() throws Exception { final List<ShuffleHandler.ReduceMapFileCount> listenerList = - new ArrayList<ShuffleHandler.ReduceMapFileCount>(); + new ArrayList<>(); + int connectionKeepAliveTimeOut = 5; //arbitrary value final ChannelHandlerContext mockCtx = mock(ChannelHandlerContext.class); - final MessageEvent mockEvt = mock(MessageEvent.class); final Channel mockCh = mock(AbstractChannel.class); final ChannelPipeline mockPipeline = mock(ChannelPipeline.class); @@ -1146,29 +1883,23 @@ public void testSendMapCount() throws Exception { final ChannelFuture mockFuture = createMockChannelFuture(mockCh, listenerList); final ShuffleHandler.TimeoutHandler timerHandler = - new ShuffleHandler.TimeoutHandler(); + new ShuffleHandler.TimeoutHandler(connectionKeepAliveTimeOut); // Mock Netty Channel Context and Channel behavior - Mockito.doReturn(mockCh).when(mockCtx).getChannel(); - when(mockCh.getPipeline()).thenReturn(mockPipeline); + Mockito.doReturn(mockCh).when(mockCtx).channel(); + when(mockCh.pipeline()).thenReturn(mockPipeline); when(mockPipeline.get( Mockito.any(String.class))).thenReturn(timerHandler); - when(mockCtx.getChannel()).thenReturn(mockCh); - Mockito.doReturn(mockFuture).when(mockCh).write(Mockito.any(Object.class)); - when(mockCh.write(Object.class)).thenReturn(mockFuture); + when(mockCtx.channel()).thenReturn(mockCh); +
Mockito.doReturn(mockFuture).when(mockCh).writeAndFlush(Mockito.any(Object.class)); - //Mock MessageEvent behavior - Mockito.doReturn(mockCh).when(mockEvt).getChannel(); - when(mockEvt.getChannel()).thenReturn(mockCh); - Mockito.doReturn(mockHttpRequest).when(mockEvt).getMessage(); - - final ShuffleHandler sh = new MockShuffleHandler(); + final MockShuffleHandler sh = new MockShuffleHandler(); Configuration conf = new Configuration(); sh.init(conf); sh.start(); int maxOpenFiles = conf.getInt(ShuffleHandler.SHUFFLE_MAX_SESSION_OPEN_FILES, ShuffleHandler.DEFAULT_SHUFFLE_MAX_SESSION_OPEN_FILES); - sh.getShuffle(conf).messageReceived(mockCtx, mockEvt); + sh.getShuffle(conf).channelRead(mockCtx, mockHttpRequest); assertTrue("Number of Open files should not exceed the configured " + "value!-Not Expected", listenerList.size() <= maxOpenFiles); @@ -1179,23 +1910,97 @@ public void testSendMapCount() throws Exception { listenerList.size() <= maxOpenFiles); } sh.close(); + sh.stop(); + + assertEquals("Should have no caught exceptions", + Collections.emptyList(), sh.failures); + } + + @Test(timeout = 10000) + public void testIdleStateHandlingSpecifiedTimeout() throws Exception { + int timeoutSeconds = 4; + int expectedTimeoutSeconds = timeoutSeconds; + testHandlingIdleState(timeoutSeconds, expectedTimeoutSeconds); + } + + @Test(timeout = 10000) + public void testIdleStateHandlingNegativeTimeoutDefaultsTo1Second() throws Exception { + int expectedTimeoutSeconds = 1; //expected by production code + testHandlingIdleState(ARBITRARY_NEGATIVE_TIMEOUT_SECONDS, expectedTimeoutSeconds); + } + + private String getShuffleUrlWithKeepAlive(ShuffleHandler shuffleHandler, long jobId, + long... attemptIds) { + String url = getShuffleUrl(shuffleHandler, jobId, attemptIds); + return url + "&keepAlive=true"; + } + + private String getShuffleUrl(ShuffleHandler shuffleHandler, long jobId, long...
attemptIds) { + String port = shuffleHandler.getConfig().get(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY); + String shuffleBaseURL = "http://127.0.0.1:" + port; + + StringBuilder mapAttemptIds = new StringBuilder(); + for (int i = 0; i < attemptIds.length; i++) { + if (i == 0) { + mapAttemptIds.append("&map="); + } else { + mapAttemptIds.append(","); + } + mapAttemptIds.append(String.format("attempt_%s_1_m_1_0", attemptIds[i])); + } + + String location = String.format("/mapOutput" + + "?job=job_%s_1" + + "&reduce=1" + + "%s", jobId, mapAttemptIds); + return shuffleBaseURL + location; + } + + private void testHandlingIdleState(int configuredTimeoutSeconds, int expectedTimeoutSeconds) + throws IOException, + InterruptedException { + Configuration conf = new Configuration(); + conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, TEST_EXECUTION.shuffleHandlerPort()); + conf.setBoolean(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED, true); + conf.setInt(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT, configuredTimeoutSeconds); + + final CountDownLatch countdownLatch = new CountDownLatch(1); + ResponseConfig responseConfig = new ResponseConfig(HEADER_WRITE_COUNT, 0, 0); + ShuffleHandlerForKeepAliveTests shuffleHandler = new ShuffleHandlerForKeepAliveTests( + ATTEMPT_ID, responseConfig, + event -> countdownLatch.countDown()); + shuffleHandler.init(conf); + shuffleHandler.start(); + + String shuffleUrl = getShuffleUrl(shuffleHandler, ATTEMPT_ID, ATTEMPT_ID); + String[] urls = new String[] {shuffleUrl}; + HttpConnectionHelper httpConnectionHelper = new HttpConnectionHelper( + shuffleHandler.lastSocketAddress); + long beforeConnectionTimestamp = System.currentTimeMillis(); + httpConnectionHelper.connectToUrls(urls, shuffleHandler.responseConfig); + countdownLatch.await(); + long channelClosedTimestamp = System.currentTimeMillis(); + long secondsPassed = + TimeUnit.SECONDS.convert(channelClosedTimestamp - beforeConnectionTimestamp, + TimeUnit.MILLISECONDS); + assertTrue(String.format("Expected at least %s seconds of timeout. 
" + + "Actual timeout seconds: %s", expectedTimeoutSeconds, secondsPassed), + secondsPassed >= expectedTimeoutSeconds); + shuffleHandler.stop(); } public ChannelFuture createMockChannelFuture(Channel mockCh, final List listenerList) { final ChannelFuture mockFuture = mock(ChannelFuture.class); - when(mockFuture.getChannel()).thenReturn(mockCh); + when(mockFuture.channel()).thenReturn(mockCh); Mockito.doReturn(true).when(mockFuture).isSuccess(); - Mockito.doAnswer(new Answer() { - @Override - public Object answer(InvocationOnMock invocation) throws Throwable { - //Add ReduceMapFileCount listener to a list - if (invocation.getArguments()[0].getClass() == - ShuffleHandler.ReduceMapFileCount.class) - listenerList.add((ShuffleHandler.ReduceMapFileCount) - invocation.getArguments()[0]); - return null; + Mockito.doAnswer(invocation -> { + //Add ReduceMapFileCount listener to a list + if (invocation.getArguments()[0].getClass() == ShuffleHandler.ReduceMapFileCount.class) { + listenerList.add((ShuffleHandler.ReduceMapFileCount) + invocation.getArguments()[0]); } + return null; }).when(mockFuture).addListener(Mockito.any( ShuffleHandler.ReduceMapFileCount.class)); return mockFuture; @@ -1203,16 +2008,14 @@ public Object answer(InvocationOnMock invocation) throws Throwable { public HttpRequest createMockHttpRequest() { HttpRequest mockHttpRequest = mock(HttpRequest.class); - Mockito.doReturn(HttpMethod.GET).when(mockHttpRequest).getMethod(); - Mockito.doAnswer(new Answer() { - @Override - public Object answer(InvocationOnMock invocation) throws Throwable { - String uri = "/mapOutput?job=job_12345_1&reduce=1"; - for (int i = 0; i < 100; i++) - uri = uri.concat("&map=attempt_12345_1_m_" + i + "_0"); - return uri; + Mockito.doReturn(HttpMethod.GET).when(mockHttpRequest).method(); + Mockito.doAnswer(invocation -> { + String uri = "/mapOutput?job=job_12345_1&reduce=1"; + for (int i = 0; i < 100; i++) { + uri = uri.concat("&map=attempt_12345_1_m_" + i + "_0"); } - }).when(mockHttpRequest).getUri(); + return uri; + }).when(mockHttpRequest).uri(); return mockHttpRequest; } } diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/resources/log4j.properties b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/resources/log4j.properties index 81a3f6ad5d..b7d8ad36ef 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/resources/log4j.properties +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/resources/log4j.properties @@ -17,3 +17,5 @@ log4j.threshold=ALL log4j.appender.stdout=org.apache.log4j.ConsoleAppender log4j.appender.stdout.layout=org.apache.log4j.PatternLayout log4j.appender.stdout.layout.ConversionPattern=%d{ISO8601} %-5p [%t] %c{2} (%F:%M(%L)) - %m%n +log4j.logger.io.netty=INFO +log4j.logger.org.apache.hadoop.mapred=INFO \ No newline at end of file diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/pom.xml b/hadoop-mapreduce-project/hadoop-mapreduce-client/pom.xml index b394fe5be1..fdcab2f2ff 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/pom.xml +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/pom.xml @@ -130,7 +130,7 @@ io.netty - netty + netty-all commons-logging