HADOOP-15327. Upgrade MR ShuffleHandler to use Netty4 #3259. Contributed by Szilard Nemeth.

Szilard Nemeth 2021-06-05 00:14:07 +02:00 committed by Benjamin Teke
parent 552ee44eba
commit 5bb11cecea
11 changed files with 1606 additions and 575 deletions
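The core of the migration, visible in every hunk below, is the move from Netty 3's org.jboss.netty channel factories and pipeline factories to Netty 4's io.netty event loop groups and channel initializers. For orientation, here is a minimal, self-contained sketch of that Netty 4 server pattern (not part of the commit; it assumes Netty 4.1 on the classpath, and the class name and port are illustrative):

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;

public class Netty4BootstrapSketch {
  public static void main(String[] args) throws InterruptedException {
    // Netty 4 replaces NioServerSocketChannelFactory with explicit event loop groups.
    EventLoopGroup bossGroup = new NioEventLoopGroup(1);
    EventLoopGroup workerGroup = new NioEventLoopGroup();
    try {
      ServerBootstrap b = new ServerBootstrap();
      b.group(bossGroup, workerGroup)
          .channel(NioServerSocketChannel.class)
          .option(ChannelOption.SO_BACKLOG, 128)
          .childOption(ChannelOption.SO_KEEPALIVE, true)
          // ChannelInitializer replaces Netty 3's ChannelPipelineFactory.
          .childHandler(new ChannelInitializer<SocketChannel>() {
            @Override
            protected void initChannel(SocketChannel ch) {
              ch.pipeline()
                  .addLast("decoder", new HttpRequestDecoder())
                  .addLast("aggregator", new HttpObjectAggregator(1 << 16))
                  .addLast("encoder", new HttpResponseEncoder());
            }
          });
      // bind(...).sync() mirrors the serviceStart() change in the ShuffleHandler diff.
      b.bind(8080).sync().channel().closeFuture().sync();
    } finally {
      bossGroup.shutdownGracefully();
      workerGroup.shutdownGracefully();
    }
  }
}

The production wiring corresponding to this sketch appears in the ShuffleHandler hunks further below.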


@@ -148,6 +148,7 @@
   <!-- Leave javax APIs that are stable -->
   <!-- the jdk ships part of the javax.annotation namespace, so if we want to relocate this we'll have to care it out by class :( -->
   <exclude>com.google.code.findbugs:jsr305</exclude>
+  <exclude>io.netty:*</exclude>
   <exclude>io.dropwizard.metrics:metrics-core</exclude>
   <exclude>org.eclipse.jetty:jetty-servlet</exclude>
   <exclude>org.eclipse.jetty:jetty-security</exclude>


@@ -53,14 +53,15 @@
 import org.apache.hadoop.classification.VisibleForTesting;

-class Fetcher<K,V> extends Thread {
+@VisibleForTesting
+public class Fetcher<K, V> extends Thread {

   private static final Logger LOG = LoggerFactory.getLogger(Fetcher.class);

-  /** Number of ms before timing out a copy */
+  /** Number of ms before timing out a copy. */
   private static final int DEFAULT_STALLED_COPY_TIMEOUT = 3 * 60 * 1000;

-  /** Basic/unit connection timeout (in milliseconds) */
+  /** Basic/unit connection timeout (in milliseconds). */
   private final static int UNIT_CONNECT_TIMEOUT = 60 * 1000;

   /* Default read timeout (in milliseconds) */
@@ -72,10 +73,12 @@ class Fetcher<K,V> extends Thread {
   private static final String FETCH_RETRY_AFTER_HEADER = "Retry-After";

   protected final Reporter reporter;
-  private enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP,
+  @VisibleForTesting
+  public enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP,
                              CONNECTION, WRONG_REDUCE}
-  private final static String SHUFFLE_ERR_GRP_NAME = "Shuffle Errors";
+  @VisibleForTesting
+  public final static String SHUFFLE_ERR_GRP_NAME = "Shuffle Errors";
   private final JobConf jobConf;
   private final Counters.Counter connectionErrs;
   private final Counters.Counter ioErrs;
@@ -316,8 +319,7 @@ protected void copyFromHost(MapHost host) throws IOException {
     }

     if (LOG.isDebugEnabled()) {
-      LOG.debug("Fetcher " + id + " going to fetch from " + host + " for: "
-          + maps);
+      LOG.debug("Fetcher " + id + " going to fetch from " + host + " for: " + maps);
     }

     // List of maps to be fetched yet
@@ -628,7 +630,7 @@ private void checkTimeoutOrRetry(MapHost host, IOException ioe)
   }

   /**
-   * Do some basic verification on the input received -- Being defensive
+   * Do some basic verification on the input received -- Being defensive.
    * @param compressedLength
    * @param decompressedLength
    * @param forReduce
@@ -695,8 +697,7 @@ private URL getMapOutputURL(MapHost host, Collection<TaskAttemptID> maps
    * only on the last failure. Instead of connecting with a timeout of
    * X, we try connecting with a timeout of x < X but multiple times.
    */
-  private void connect(URLConnection connection, int connectionTimeout)
-      throws IOException {
+  private void connect(URLConnection connection, int connectionTimeout) throws IOException {
     int unit = 0;
     if (connectionTimeout < 0) {
       throw new IOException("Invalid timeout "


@@ -26,6 +26,7 @@
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.WritableComparator;
 import org.apache.hadoop.mapreduce.TaskCounter;
+import org.apache.hadoop.mapreduce.task.reduce.Fetcher;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -37,6 +38,7 @@
 import java.util.Formatter;
 import java.util.Iterator;

+import static org.apache.hadoop.mapreduce.task.reduce.Fetcher.SHUFFLE_ERR_GRP_NAME;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
@@ -87,6 +89,9 @@ public void testReduceFromPartialMem() throws Exception {
     final long spill = c.findCounter(TaskCounter.SPILLED_RECORDS).getCounter();
     assertTrue("Expected some records not spilled during reduce" + spill + ")",
         spill < 2 * out); // spilled map records, some records at the reduce
+    long shuffleIoErrors =
+        c.getGroup(SHUFFLE_ERR_GRP_NAME).getCounter(Fetcher.ShuffleErrors.IO_ERROR.toString());
+    assertEquals(0, shuffleIoErrors);
   }

   /**


@@ -23,6 +23,9 @@
 import java.io.RandomAccessFile;

 import org.apache.hadoop.classification.VisibleForTesting;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.handler.stream.ChunkedFile;
 import org.apache.hadoop.io.ReadaheadPool;
 import org.apache.hadoop.io.ReadaheadPool.ReadaheadRequest;
 import org.apache.hadoop.io.nativeio.NativeIO;
@@ -31,8 +34,6 @@
 import static org.apache.hadoop.io.nativeio.NativeIO.POSIX.POSIX_FADV_DONTNEED;

-import org.jboss.netty.handler.stream.ChunkedFile;
-
 public class FadvisedChunkedFile extends ChunkedFile {

   private static final Logger LOG =
@@ -64,16 +65,16 @@ FileDescriptor getFd() {
   }

   @Override
-  public Object nextChunk() throws Exception {
+  public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception {
     synchronized (closeLock) {
       if (fd.valid()) {
         if (manageOsCache && readaheadPool != null) {
           readaheadRequest = readaheadPool
               .readaheadStream(
-                  identifier, fd, getCurrentOffset(), readaheadLength,
-                  getEndOffset(), readaheadRequest);
+                  identifier, fd, currentOffset(), readaheadLength,
+                  endOffset(), readaheadRequest);
         }
-        return super.nextChunk();
+        return super.readChunk(allocator);
       } else {
         return null;
       }
@@ -88,12 +89,12 @@ public void close() throws Exception {
       readaheadRequest = null;
     }
     if (fd.valid() &&
-        manageOsCache && getEndOffset() - getStartOffset() > 0) {
+        manageOsCache && endOffset() - startOffset() > 0) {
       try {
         NativeIO.POSIX.getCacheManipulator().posixFadviseIfPossible(
             identifier,
             fd,
-            getStartOffset(), getEndOffset() - getStartOffset(),
+            startOffset(), endOffset() - startOffset(),
             POSIX_FADV_DONTNEED);
       } catch (Throwable t) {
         LOG.warn("Failed to manage OS cache for " + identifier +
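The hunk above tracks two Netty 4 API changes in ChunkedFile: nextChunk() became readChunk(ByteBufAllocator), and the offset accessors lost their get- prefixes. A hypothetical minimal subclass (not from the commit; assumes Netty 4.1) exercising the same override points:

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.stream.ChunkedFile;

import java.io.IOException;
import java.io.RandomAccessFile;

public class TracingChunkedFile extends ChunkedFile {

  public TracingChunkedFile(RandomAccessFile file) throws IOException {
    super(file);
  }

  @Override
  public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception {
    // Netty 4 accessors: currentOffset()/startOffset()/endOffset()
    // replace getCurrentOffset()/getStartOffset()/getEndOffset().
    System.out.printf("chunk at %d in [%d, %d)%n",
        currentOffset(), startOffset(), endOffset());
    return super.readChunk(allocator);
  }
}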


@@ -25,6 +25,7 @@
 import java.nio.channels.FileChannel;
 import java.nio.channels.WritableByteChannel;

+import io.netty.channel.DefaultFileRegion;
 import org.apache.hadoop.io.ReadaheadPool;
 import org.apache.hadoop.io.ReadaheadPool.ReadaheadRequest;
 import org.apache.hadoop.io.nativeio.NativeIO;
@@ -33,8 +34,6 @@
 import static org.apache.hadoop.io.nativeio.NativeIO.POSIX.POSIX_FADV_DONTNEED;

-import org.jboss.netty.channel.DefaultFileRegion;
-
 import org.apache.hadoop.classification.VisibleForTesting;

 public class FadvisedFileRegion extends DefaultFileRegion {
@@ -77,8 +76,8 @@ public long transferTo(WritableByteChannel target, long position)
       throws IOException {
     if (readaheadPool != null && readaheadLength > 0) {
       readaheadRequest = readaheadPool.readaheadStream(identifier, fd,
-          getPosition() + position, readaheadLength,
-          getPosition() + getCount(), readaheadRequest);
+          position() + position, readaheadLength,
+          position() + count(), readaheadRequest);
     }

     if(this.shuffleTransferToAllowed) {
@@ -147,11 +146,11 @@ long customShuffleTransfer(WritableByteChannel target, long position)

   @Override
-  public void releaseExternalResources() {
+  protected void deallocate() {
     if (readaheadRequest != null) {
       readaheadRequest.cancel();
     }
-    super.releaseExternalResources();
+    super.deallocate();
   }

   /**
@@ -159,10 +158,10 @@ public void releaseExternalResources() {
    * we don't need the region to be cached anymore.
    */
   public void transferSuccessful() {
-    if (manageOsCache && getCount() > 0) {
+    if (manageOsCache && count() > 0) {
       try {
         NativeIO.POSIX.getCacheManipulator().posixFadviseIfPossible(identifier,
-            fd, getPosition(), getCount(), POSIX_FADV_DONTNEED);
+            fd, position(), count(), POSIX_FADV_DONTNEED);
       } catch (Throwable t) {
         LOG.warn("Failed to manage OS cache for " + identifier, t);
       }
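Similarly, Netty 4's DefaultFileRegion is reference-counted: cleanup moves from Netty 3's releaseExternalResources() to deallocate(), which Netty invokes once the reference count drops to zero, and getPosition()/getCount() become position()/count(). A hypothetical sketch (not from the commit; assumes Netty 4.1) of those override points:

import io.netty.channel.DefaultFileRegion;

import java.io.RandomAccessFile;

public class TracingFileRegion extends DefaultFileRegion {

  public TracingFileRegion(RandomAccessFile file, long position, long count) {
    super(file.getChannel(), position, count);
  }

  @Override
  protected void deallocate() {
    // Called by Netty when the reference count reaches zero, e.g. after
    // region.release(); this replaces Netty 3's releaseExternalResources().
    System.out.printf("released region [%d, %d)%n", position(), position() + count());
    super.deallocate();
  }
}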


@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.mapred;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
class LoggingHttpResponseEncoder extends HttpResponseEncoder {
private static final Logger LOG = LoggerFactory.getLogger(LoggingHttpResponseEncoder.class);
private final boolean logStacktraceOfEncodingMethods;
LoggingHttpResponseEncoder(boolean logStacktraceOfEncodingMethods) {
this.logStacktraceOfEncodingMethods = logStacktraceOfEncodingMethods;
}
@Override
public boolean acceptOutboundMessage(Object msg) throws Exception {
printExecutingMethod();
LOG.info("OUTBOUND MESSAGE: " + msg);
return super.acceptOutboundMessage(msg);
}
@Override
protected void encodeInitialLine(ByteBuf buf, HttpResponse response) throws Exception {
LOG.debug("Executing method: {}, response: {}",
getExecutingMethodName(), response);
logStacktraceIfRequired();
super.encodeInitialLine(buf, response);
}
@Override
protected void encode(ChannelHandlerContext ctx, Object msg,
List<Object> out) throws Exception {
LOG.debug("Encoding to channel {}: {}", ctx.channel(), msg);
printExecutingMethod();
logStacktraceIfRequired();
super.encode(ctx, msg, out);
}
@Override
protected void encodeHeaders(HttpHeaders headers, ByteBuf buf) {
printExecutingMethod();
super.encodeHeaders(headers, buf);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise
promise) throws Exception {
LOG.debug("Writing to channel {}: {}", ctx.channel(), msg);
printExecutingMethod();
super.write(ctx, msg, promise);
}
private void logStacktraceIfRequired() {
if (logStacktraceOfEncodingMethods) {
LOG.debug("Stacktrace: ", new Throwable());
}
}
private void printExecutingMethod() {
String methodName = getExecutingMethodName(1);
LOG.debug("Executing method: {}", methodName);
}
private String getExecutingMethodName() {
return getExecutingMethodName(0);
}
private String getExecutingMethodName(int additionalSkipFrames) {
try {
StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
// Array items (indices):
// 0: java.lang.Thread.getStackTrace(...)
// 1: TestShuffleHandler$LoggingHttpResponseEncoder.getExecutingMethodName(...)
int skipFrames = 2 + additionalSkipFrames;
String methodName = stackTrace[skipFrames].getMethodName();
String className = this.getClass().getSimpleName();
return className + "#" + methodName;
} catch (Throwable t) {
LOG.error("Error while getting execution method name", t);
return "unknown";
}
}
}
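A quick way to see the encoder above in action is to drive it through an EmbeddedChannel. The demo below is hypothetical (not part of the commit; assumes Netty 4.1) and must sit in org.apache.hadoop.mapred, since LoggingHttpResponseEncoder and its constructor are package-private:

package org.apache.hadoop.mapred;

import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;

public class LoggingEncoderDemo {
  public static void main(String[] args) {
    // true: also log stack traces of the encoding methods
    EmbeddedChannel ch = new EmbeddedChannel(new LoggingHttpResponseEncoder(true));
    // Each outbound HTTP object is logged before being encoded to bytes.
    ch.writeOutbound(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK));
    ch.finishAndReleaseAll();
  }
}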


@@ -18,19 +18,20 @@
 package org.apache.hadoop.mapred;

+import static io.netty.buffer.Unpooled.wrappedBuffer;
+import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE;
+import static io.netty.handler.codec.http.HttpMethod.GET;
+import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
+import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
+import static io.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR;
+import static io.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED;
+import static io.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND;
+import static io.netty.handler.codec.http.HttpResponseStatus.OK;
+import static io.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED;
+import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
+import static org.apache.hadoop.mapred.ShuffleHandler.NettyChannelHelper.*;
 import static org.fusesource.leveldbjni.JniDBFactory.asString;
 import static org.fusesource.leveldbjni.JniDBFactory.bytes;
-import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer;
-import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE;
-import static org.jboss.netty.handler.codec.http.HttpMethod.GET;
-import static org.jboss.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
-import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
-import static org.jboss.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR;
-import static org.jboss.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED;
-import static org.jboss.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND;
-import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK;
-import static org.jboss.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED;
-import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1;

 import java.io.File;
 import java.io.FileNotFoundException;
@@ -54,6 +55,44 @@
 import javax.crypto.SecretKey;

+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.ChannelPromise;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.group.ChannelGroup;
+import io.netty.channel.group.DefaultChannelGroup;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.handler.codec.TooLongFrameException;
+import io.netty.handler.codec.http.DefaultFullHttpResponse;
+import io.netty.handler.codec.http.DefaultHttpResponse;
+import io.netty.handler.codec.http.FullHttpResponse;
+import io.netty.handler.codec.http.HttpObjectAggregator;
+import io.netty.handler.codec.http.HttpRequest;
+import io.netty.handler.codec.http.HttpRequestDecoder;
+import io.netty.handler.codec.http.HttpResponse;
+import io.netty.handler.codec.http.HttpResponseEncoder;
+import io.netty.handler.codec.http.HttpResponseStatus;
+import io.netty.handler.codec.http.LastHttpContent;
+import io.netty.handler.codec.http.QueryStringDecoder;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.handler.stream.ChunkedWriteHandler;
+import io.netty.handler.timeout.IdleState;
+import io.netty.handler.timeout.IdleStateEvent;
+import io.netty.handler.timeout.IdleStateHandler;
+import io.netty.util.CharsetUtil;
+import io.netty.util.concurrent.DefaultEventExecutorGroup;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.DataInputByteBuffer;
@@ -79,7 +118,6 @@
 import org.apache.hadoop.security.token.Token;
 import org.apache.hadoop.util.DiskChecker;
 import org.apache.hadoop.util.Shell;
-import org.apache.hadoop.util.concurrent.HadoopExecutors;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.proto.YarnServerCommonProtos.VersionProto;
 import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext;
@@ -94,42 +132,6 @@
 import org.iq80.leveldb.DB;
 import org.iq80.leveldb.DBException;
 import org.iq80.leveldb.Options;
-import org.jboss.netty.bootstrap.ServerBootstrap;
-import org.jboss.netty.buffer.ChannelBuffers;
-import org.jboss.netty.channel.Channel;
-import org.jboss.netty.channel.ChannelFactory;
-import org.jboss.netty.channel.ChannelFuture;
-import org.jboss.netty.channel.ChannelFutureListener;
-import org.jboss.netty.channel.ChannelHandler;
-import org.jboss.netty.channel.ChannelHandlerContext;
-import org.jboss.netty.channel.ChannelPipeline;
-import org.jboss.netty.channel.ChannelPipelineFactory;
-import org.jboss.netty.channel.ChannelStateEvent;
-import org.jboss.netty.channel.Channels;
-import org.jboss.netty.channel.ExceptionEvent;
-import org.jboss.netty.channel.MessageEvent;
-import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
-import org.jboss.netty.channel.group.ChannelGroup;
-import org.jboss.netty.channel.group.DefaultChannelGroup;
-import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory;
-import org.jboss.netty.handler.codec.frame.TooLongFrameException;
-import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
-import org.jboss.netty.handler.codec.http.HttpChunkAggregator;
-import org.jboss.netty.handler.codec.http.HttpRequest;
-import org.jboss.netty.handler.codec.http.HttpRequestDecoder;
-import org.jboss.netty.handler.codec.http.HttpResponse;
-import org.jboss.netty.handler.codec.http.HttpResponseEncoder;
-import org.jboss.netty.handler.codec.http.HttpResponseStatus;
-import org.jboss.netty.handler.codec.http.QueryStringDecoder;
-import org.jboss.netty.handler.ssl.SslHandler;
-import org.jboss.netty.handler.stream.ChunkedWriteHandler;
-import org.jboss.netty.handler.timeout.IdleState;
-import org.jboss.netty.handler.timeout.IdleStateAwareChannelHandler;
-import org.jboss.netty.handler.timeout.IdleStateEvent;
-import org.jboss.netty.handler.timeout.IdleStateHandler;
-import org.jboss.netty.util.CharsetUtil;
-import org.jboss.netty.util.HashedWheelTimer;
-import org.jboss.netty.util.Timer;
 import org.eclipse.jetty.http.HttpHeader;
 import org.slf4j.LoggerFactory;
@@ -182,19 +184,29 @@ public class ShuffleHandler extends AuxiliaryService {
   public static final HttpResponseStatus TOO_MANY_REQ_STATUS =
       new HttpResponseStatus(429, "TOO MANY REQUESTS");
-  // This should kept in sync with Fetcher.FETCH_RETRY_DELAY_DEFAULT
+  // This should be kept in sync with Fetcher.FETCH_RETRY_DELAY_DEFAULT
   public static final long FETCH_RETRY_DELAY = 1000L;
   public static final String RETRY_AFTER_HEADER = "Retry-After";
+  static final String ENCODER_HANDLER_NAME = "encoder";

   private int port;
-  private ChannelFactory selector;
-  private final ChannelGroup accepted = new DefaultChannelGroup();
+  private EventLoopGroup bossGroup;
+  private EventLoopGroup workerGroup;
+  private ServerBootstrap bootstrap;
+  private Channel ch;
+  private final ChannelGroup accepted =
+      new DefaultChannelGroup(new DefaultEventExecutorGroup(5).next());
+  private final AtomicInteger activeConnections = new AtomicInteger();
   protected HttpPipelineFactory pipelineFact;
   private int sslFileBufferSize;
+
+  //TODO snemeth add a config option for these later, this is temporarily disabled for now.
+  private boolean useOutboundExceptionHandler = false;
+  private boolean useOutboundLogger = false;

   /**
    * Should the shuffle use posix_fadvise calls to manage the OS cache during
-   * sendfile
+   * sendfile.
    */
   private boolean manageOsCache;
   private int readaheadLength;
@@ -255,7 +267,7 @@ public class ShuffleHandler extends AuxiliaryService {
   public static final boolean DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED = true;
   public static final boolean WINDOWS_DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED =
       false;
-  private static final String TIMEOUT_HANDLER = "timeout";
+  static final String TIMEOUT_HANDLER = "timeout";

   /* the maximum number of files a single GET request can
      open simultaneously during shuffle
@@ -267,7 +279,6 @@ public class ShuffleHandler extends AuxiliaryService {
   boolean connectionKeepAliveEnabled = false;
   private int connectionKeepAliveTimeOut;
   private int mapOutputMetaInfoCacheSize;
-  private Timer timer;

   @Metrics(about="Shuffle output metrics", context="mapred")
   static class ShuffleMetrics implements ChannelFutureListener {
@@ -291,6 +302,49 @@ public void operationComplete(ChannelFuture future) throws Exception {
     }
   }

+  static class NettyChannelHelper {
+    static ChannelFuture writeToChannel(Channel ch, Object obj) {
+      LOG.debug("Writing {} to channel: {}", obj.getClass().getSimpleName(), ch.id());
+      return ch.writeAndFlush(obj);
+    }
+
+    static ChannelFuture writeToChannelAndClose(Channel ch, Object obj) {
+      return writeToChannel(ch, obj).addListener(ChannelFutureListener.CLOSE);
+    }
+
+    static ChannelFuture writeToChannelAndAddLastHttpContent(Channel ch, HttpResponse obj) {
+      writeToChannel(ch, obj);
+      return writeLastHttpContentToChannel(ch);
+    }
+
+    static ChannelFuture writeLastHttpContentToChannel(Channel ch) {
+      LOG.debug("Writing LastHttpContent, channel id: {}", ch.id());
+      return ch.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT);
+    }
+
+    static ChannelFuture closeChannel(Channel ch) {
+      LOG.debug("Closing channel, channel id: {}", ch.id());
+      return ch.close();
+    }
+
+    static void closeChannels(ChannelGroup channelGroup) {
+      channelGroup.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+    }
+
+    public static ChannelFuture closeAsIdle(Channel ch, int timeout) {
+      LOG.debug("Closing channel as writer was idle for {} seconds", timeout);
+      return closeChannel(ch);
+    }
+
+    public static void channelActive(Channel ch) {
+      LOG.debug("Executing channelActive, channel id: {}", ch.id());
+    }
+
+    public static void channelInactive(Channel ch) {
+      LOG.debug("Executing channelInactive, channel id: {}", ch.id());
+    }
+  }
+
   private final MetricsSystem ms;
   final ShuffleMetrics metrics;
@@ -298,29 +352,36 @@ class ReduceMapFileCount implements ChannelFutureListener {

     private ReduceContext reduceContext;

-    public ReduceMapFileCount(ReduceContext rc) {
+    ReduceMapFileCount(ReduceContext rc) {
       this.reduceContext = rc;
     }

     @Override
     public void operationComplete(ChannelFuture future) throws Exception {
+      LOG.trace("operationComplete");
       if (!future.isSuccess()) {
-        future.getChannel().close();
+        LOG.error("Future is unsuccessful. Cause: ", future.cause());
+        closeChannel(future.channel());
         return;
       }
       int waitCount = this.reduceContext.getMapsToWait().decrementAndGet();
       if (waitCount == 0) {
+        LOG.trace("Finished with all map outputs");
+        //HADOOP-15327: Need to send an instance of LastHttpContent to define HTTP
+        //message boundaries. See details in jira.
+        writeLastHttpContentToChannel(future.channel());
         metrics.operationComplete(future);
         // Let the idle timer handler close keep-alive connections
         if (reduceContext.getKeepAlive()) {
-          ChannelPipeline pipeline = future.getChannel().getPipeline();
+          ChannelPipeline pipeline = future.channel().pipeline();
           TimeoutHandler timeoutHandler =
               (TimeoutHandler)pipeline.get(TIMEOUT_HANDLER);
           timeoutHandler.setEnabledTimeout(true);
         } else {
-          future.getChannel().close();
+          closeChannel(future.channel());
         }
       } else {
+        LOG.trace("operationComplete, waitCount > 0, invoking sendMap with reduceContext");
         pipelineFact.getSHUFFLE().sendMap(reduceContext);
       }
     }
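The HADOOP-15327 comment in the hunk above refers to Netty 4's HTTP framing: a plain HttpResponse followed by content chunks is only a complete message once a LastHttpContent is written. A runnable illustration of that rule (not from the commit; assumes Netty 4.1), using an EmbeddedChannel in place of a real connection:

import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.CharsetUtil;

public class LastHttpContentSketch {
  public static void main(String[] args) {
    EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseEncoder());
    // A non-full HttpResponse carries only the status line and headers...
    ch.writeOutbound(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK));
    // ...followed by any number of content chunks...
    ch.writeOutbound(new DefaultHttpContent(Unpooled.copiedBuffer("data", CharsetUtil.UTF_8)));
    // ...and the message only ends once LastHttpContent is written. Without it,
    // a keep-alive peer cannot tell where this response stops and the next begins.
    ch.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT);
    ch.finishAndReleaseAll();
  }
}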
@@ -331,7 +392,6 @@ public void operationComplete(ChannelFuture future) throws Exception {
    * Allows sendMapOutput calls from operationComplete()
    */
   private static class ReduceContext {
-
     private List<String> mapIds;
     private AtomicInteger mapsToWait;
     private AtomicInteger mapsToSend;
@@ -342,7 +402,7 @@ private static class ReduceContext {
     private String jobId;
     private final boolean keepAlive;

-    public ReduceContext(List<String> mapIds, int rId,
+    ReduceContext(List<String> mapIds, int rId,
                          ChannelHandlerContext context, String usr,
                          Map<String, Shuffle.MapOutputInfo> mapOutputInfoMap,
                          String jobId, boolean keepAlive) {
@@ -448,7 +508,8 @@ public static int deserializeMetaData(ByteBuffer meta) throws IOException {
    * shuffle data requests.
    * @return the serialized version of the jobToken.
    */
-  public static ByteBuffer serializeServiceData(Token<JobTokenIdentifier> jobToken) throws IOException {
+  public static ByteBuffer serializeServiceData(Token<JobTokenIdentifier> jobToken)
+      throws IOException {
     //TODO these bytes should be versioned
     DataOutputBuffer jobToken_dob = new DataOutputBuffer();
     jobToken.write(jobToken_dob);
@@ -505,6 +566,11 @@ protected void serviceInit(Configuration conf) throws Exception {
         DEFAULT_MAX_SHUFFLE_CONNECTIONS);
     int maxShuffleThreads = conf.getInt(MAX_SHUFFLE_THREADS,
         DEFAULT_MAX_SHUFFLE_THREADS);
+    // Since Netty 4.x, the value of 0 threads would default to:
+    // io.netty.channel.MultithreadEventLoopGroup.DEFAULT_EVENT_LOOP_THREADS
+    // by simply passing 0 value to NioEventLoopGroup constructor below.
+    // However, this logic to determine thread count
+    // was in place so we can keep it for now.
     if (maxShuffleThreads == 0) {
       maxShuffleThreads = 2 * Runtime.getRuntime().availableProcessors();
     }
@@ -526,10 +592,8 @@ protected void serviceInit(Configuration conf) throws Exception {
         .setNameFormat("ShuffleHandler Netty Worker #%d")
         .build();

-    selector = new NioServerSocketChannelFactory(
-        HadoopExecutors.newCachedThreadPool(bossFactory),
-        HadoopExecutors.newCachedThreadPool(workerFactory),
-        maxShuffleThreads);
+    bossGroup = new NioEventLoopGroup(maxShuffleThreads, bossFactory);
+    workerGroup = new NioEventLoopGroup(maxShuffleThreads, workerFactory);
     super.serviceInit(new Configuration(conf));
   }
@@ -540,22 +604,24 @@ protected void serviceStart() throws Exception {
     userRsrc = new ConcurrentHashMap<String,String>();
     secretManager = new JobTokenSecretManager();
     recoverState(conf);
-    ServerBootstrap bootstrap = new ServerBootstrap(selector);
-    // Timer is shared across entire factory and must be released separately
-    timer = new HashedWheelTimer();
     try {
-      pipelineFact = new HttpPipelineFactory(conf, timer);
+      pipelineFact = new HttpPipelineFactory(conf);
     } catch (Exception ex) {
       throw new RuntimeException(ex);
     }
-    bootstrap.setOption("backlog", conf.getInt(SHUFFLE_LISTEN_QUEUE_SIZE,
-        DEFAULT_SHUFFLE_LISTEN_QUEUE_SIZE));
-    bootstrap.setOption("child.keepAlive", true);
-    bootstrap.setPipelineFactory(pipelineFact);
+
+    bootstrap = new ServerBootstrap();
+    bootstrap.group(bossGroup, workerGroup)
+        .channel(NioServerSocketChannel.class)
+        .option(ChannelOption.SO_BACKLOG,
+            conf.getInt(SHUFFLE_LISTEN_QUEUE_SIZE,
+                DEFAULT_SHUFFLE_LISTEN_QUEUE_SIZE))
+        .childOption(ChannelOption.SO_KEEPALIVE, true)
+        .childHandler(pipelineFact);
     port = conf.getInt(SHUFFLE_PORT_CONFIG_KEY, DEFAULT_SHUFFLE_PORT);
-    Channel ch = bootstrap.bind(new InetSocketAddress(port));
+    ch = bootstrap.bind(new InetSocketAddress(port)).sync().channel();
     accepted.add(ch);
-    port = ((InetSocketAddress)ch.getLocalAddress()).getPort();
+    port = ((InetSocketAddress)ch.localAddress()).getPort();
     conf.set(SHUFFLE_PORT_CONFIG_KEY, Integer.toString(port));
     pipelineFact.SHUFFLE.setPort(port);
     LOG.info(getName() + " listening on port " + port);
@@ -576,18 +642,12 @@ protected void serviceStart() throws Exception {

   @Override
   protected void serviceStop() throws Exception {
-    accepted.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
-    if (selector != null) {
-      ServerBootstrap bootstrap = new ServerBootstrap(selector);
-      bootstrap.releaseExternalResources();
-    }
+    closeChannels(accepted);
     if (pipelineFact != null) {
       pipelineFact.destroy();
     }
-    if (timer != null) {
-      // Release this shared timer resource
-      timer.stop();
-    }
     if (stateDb != null) {
       stateDb.close();
     }
@@ -744,7 +804,7 @@ private void recoverJobShuffleInfo(String jobIdStr, byte[] data)
     JobShuffleInfoProto proto = JobShuffleInfoProto.parseFrom(data);
     String user = proto.getUser();
     TokenProto tokenProto = proto.getJobToken();
-    Token<JobTokenIdentifier> jobToken = new Token<JobTokenIdentifier>(
+    Token<JobTokenIdentifier> jobToken = new Token<>(
         tokenProto.getIdentifier().toByteArray(),
         tokenProto.getPassword().toByteArray(),
         new Text(tokenProto.getKind()), new Text(tokenProto.getService()));
@@ -785,29 +845,47 @@ private void removeJobShuffleInfo(JobID jobId) throws IOException {
     }
   }

-  static class TimeoutHandler extends IdleStateAwareChannelHandler {
+  @VisibleForTesting
+  public void setUseOutboundExceptionHandler(boolean useHandler) {
+    this.useOutboundExceptionHandler = useHandler;
+  }
+
+  static class TimeoutHandler extends IdleStateHandler {
+    private final int connectionKeepAliveTimeOut;
     private boolean enabledTimeout;

+    TimeoutHandler(int connectionKeepAliveTimeOut) {
+      //disable reader timeout
+      //set writer timeout to configured timeout value
+      //disable all idle timeout
+      super(0, connectionKeepAliveTimeOut, 0, TimeUnit.SECONDS);
+      this.connectionKeepAliveTimeOut = connectionKeepAliveTimeOut;
+    }
+
+    @VisibleForTesting
+    public int getConnectionKeepAliveTimeOut() {
+      return connectionKeepAliveTimeOut;
+    }
+
     void setEnabledTimeout(boolean enabledTimeout) {
       this.enabledTimeout = enabledTimeout;
     }

     @Override
     public void channelIdle(ChannelHandlerContext ctx, IdleStateEvent e) {
-      if (e.getState() == IdleState.WRITER_IDLE && enabledTimeout) {
-        e.getChannel().close();
+      if (e.state() == IdleState.WRITER_IDLE && enabledTimeout) {
+        closeAsIdle(ctx.channel(), connectionKeepAliveTimeOut);
       }
     }
   }

-  class HttpPipelineFactory implements ChannelPipelineFactory {
+  class HttpPipelineFactory extends ChannelInitializer<SocketChannel> {
+    private static final int MAX_CONTENT_LENGTH = 1 << 16;

     final Shuffle SHUFFLE;
     private SSLFactory sslFactory;
-    private final ChannelHandler idleStateHandler;

-    public HttpPipelineFactory(Configuration conf, Timer timer) throws Exception {
+    HttpPipelineFactory(Configuration conf) throws Exception {
       SHUFFLE = getShuffle(conf);
       if (conf.getBoolean(MRConfig.SHUFFLE_SSL_ENABLED_KEY,
                           MRConfig.SHUFFLE_SSL_ENABLED_DEFAULT)) {
@@ -815,7 +893,6 @@ public HttpPipelineFactory(Configuration conf, Timer timer) throws Exception {
         sslFactory = new SSLFactory(SSLFactory.Mode.SERVER, conf);
         sslFactory.init();
       }
-      this.idleStateHandler = new IdleStateHandler(timer, 0, connectionKeepAliveTimeOut, 0);
     }

     public Shuffle getSHUFFLE() {
@@ -828,30 +905,39 @@ public void destroy() {
       }
     }

-    @Override
-    public ChannelPipeline getPipeline() throws Exception {
-      ChannelPipeline pipeline = Channels.pipeline();
+    @Override protected void initChannel(SocketChannel ch) throws Exception {
+      ChannelPipeline pipeline = ch.pipeline();
       if (sslFactory != null) {
         pipeline.addLast("ssl", new SslHandler(sslFactory.createSSLEngine()));
       }
       pipeline.addLast("decoder", new HttpRequestDecoder());
-      pipeline.addLast("aggregator", new HttpChunkAggregator(1 << 16));
-      pipeline.addLast("encoder", new HttpResponseEncoder());
+      pipeline.addLast("aggregator", new HttpObjectAggregator(MAX_CONTENT_LENGTH));
+      pipeline.addLast(ENCODER_HANDLER_NAME, useOutboundLogger ?
+          new LoggingHttpResponseEncoder(false) : new HttpResponseEncoder());
       pipeline.addLast("chunking", new ChunkedWriteHandler());
       pipeline.addLast("shuffle", SHUFFLE);
-      pipeline.addLast("idle", idleStateHandler);
-      pipeline.addLast(TIMEOUT_HANDLER, new TimeoutHandler());
-      return pipeline;
+      if (useOutboundExceptionHandler) {
+        //https://stackoverflow.com/questions/50612403/catch-all-exception-handling-for-outbound-channelhandler
+        pipeline.addLast("outboundExceptionHandler", new ChannelOutboundHandlerAdapter() {
+          @Override
+          public void write(ChannelHandlerContext ctx, Object msg,
+              ChannelPromise promise) throws Exception {
+            promise.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
+            super.write(ctx, msg, promise);
+          }
+        });
+      }
+      pipeline.addLast(TIMEOUT_HANDLER, new TimeoutHandler(connectionKeepAliveTimeOut));
       // TODO factor security manager into pipeline
       // TODO factor out encode/decode to permit binary shuffle
       // TODO factor out decode of index to permit alt. models
     }
   }

-  class Shuffle extends SimpleChannelUpstreamHandler {
+  @ChannelHandler.Sharable
+  class Shuffle extends ChannelInboundHandlerAdapter {

     private final IndexCache indexCache;
-    private final
-    LoadingCache<AttemptPathIdentifier, AttemptPathInfo> pathCache;
+    private final LoadingCache<AttemptPathIdentifier, AttemptPathInfo> pathCache;

     private int port;
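The TimeoutHandler hunk above folds Netty 3's shared HashedWheelTimer plus IdleStateAwareChannelHandler into a single IdleStateHandler subclass. A stripped-down sketch of that Netty 4 idiom (hypothetical class, not from the commit; assumes Netty 4.1):

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;

import java.util.concurrent.TimeUnit;

public class WriterIdleCloser extends IdleStateHandler {

  public WriterIdleCloser(int writerIdleSeconds) {
    // reader and all-idle timeouts disabled, writer timeout configured
    super(0, writerIdleSeconds, 0, TimeUnit.SECONDS);
  }

  @Override
  protected void channelIdle(ChannelHandlerContext ctx, IdleStateEvent evt) {
    if (evt.state() == IdleState.WRITER_IDLE) {
      ctx.channel().close(); // close keep-alive connections that went quiet
    }
  }
}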
@@ -904,65 +990,84 @@ private List<String> splitMaps(List<String> mapq) {
     }

     @Override
-    public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent evt)
+    public void channelActive(ChannelHandlerContext ctx)
         throws Exception {
-      super.channelOpen(ctx, evt);
-      if ((maxShuffleConnections > 0) && (accepted.size() >= maxShuffleConnections)) {
+      NettyChannelHelper.channelActive(ctx.channel());
+      int numConnections = activeConnections.incrementAndGet();
+      if ((maxShuffleConnections > 0) && (numConnections > maxShuffleConnections)) {
         LOG.info(String.format("Current number of shuffle connections (%d) is " +
-            "greater than or equal to the max allowed shuffle connections (%d)",
+            "greater than the max allowed shuffle connections (%d)",
             accepted.size(), maxShuffleConnections));

-        Map<String, String> headers = new HashMap<String, String>(1);
+        Map<String, String> headers = new HashMap<>(1);
         // notify fetchers to backoff for a while before closing the connection
         // if the shuffle connection limit is hit. Fetchers are expected to
         // handle this notification gracefully, that is, not treating this as a
         // fetch failure.
         headers.put(RETRY_AFTER_HEADER, String.valueOf(FETCH_RETRY_DELAY));
         sendError(ctx, "", TOO_MANY_REQ_STATUS, headers);
-        return;
+      } else {
+        super.channelActive(ctx);
+        accepted.add(ctx.channel());
+        LOG.debug("Added channel: {}, channel id: {}. Accepted number of connections={}",
+            ctx.channel(), ctx.channel().id(), activeConnections.get());
       }
-      accepted.add(evt.getChannel());
     }

     @Override
-    public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt)
+    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+      NettyChannelHelper.channelInactive(ctx.channel());
+      super.channelInactive(ctx);
+      int noOfConnections = activeConnections.decrementAndGet();
+      LOG.debug("New value of Accepted number of connections={}", noOfConnections);
+    }
+
+    @Override
+    public void channelRead(ChannelHandlerContext ctx, Object msg)
         throws Exception {
-      HttpRequest request = (HttpRequest) evt.getMessage();
-      if (request.getMethod() != GET) {
+      Channel channel = ctx.channel();
+      LOG.trace("Executing channelRead, channel id: {}", channel.id());
+      HttpRequest request = (HttpRequest) msg;
+      LOG.debug("Received HTTP request: {}, channel id: {}", request, channel.id());
+      if (request.method() != GET) {
         sendError(ctx, METHOD_NOT_ALLOWED);
         return;
       }
       // Check whether the shuffle version is compatible
-      if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(
-          request.headers() != null ?
-              request.headers().get(ShuffleHeader.HTTP_HEADER_NAME) : null)
-          || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals(
-              request.headers() != null ?
-                  request.headers()
-                      .get(ShuffleHeader.HTTP_HEADER_VERSION) : null)) {
+      String shuffleVersion = ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION;
+      String httpHeaderName = ShuffleHeader.DEFAULT_HTTP_HEADER_NAME;
+      if (request.headers() != null) {
+        shuffleVersion = request.headers().get(ShuffleHeader.HTTP_HEADER_VERSION);
+        httpHeaderName = request.headers().get(ShuffleHeader.HTTP_HEADER_NAME);
+        LOG.debug("Received from request header: ShuffleVersion={} header name={}, channel id: {}",
+            shuffleVersion, httpHeaderName, channel.id());
+      }
+      if (request.headers() == null ||
+          !ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(httpHeaderName) ||
+          !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals(shuffleVersion)) {
         sendError(ctx, "Incompatible shuffle request version", BAD_REQUEST);
       }
       final Map<String, List<String>> q =
-          new QueryStringDecoder(request.getUri()).getParameters();
+          new QueryStringDecoder(request.uri()).parameters();

       final List<String> keepAliveList = q.get("keepAlive");
       boolean keepAliveParam = false;
       if (keepAliveList != null && keepAliveList.size() == 1) {
         keepAliveParam = Boolean.valueOf(keepAliveList.get(0));
         if (LOG.isDebugEnabled()) {
-          LOG.debug("KeepAliveParam : " + keepAliveList
-              + " : " + keepAliveParam);
+          LOG.debug("KeepAliveParam: {} : {}, channel id: {}",
+              keepAliveList, keepAliveParam, channel.id());
         }
       }
       final List<String> mapIds = splitMaps(q.get("map"));
       final List<String> reduceQ = q.get("reduce");
       final List<String> jobQ = q.get("job");
       if (LOG.isDebugEnabled()) {
-        LOG.debug("RECV: " + request.getUri() +
+        LOG.debug("RECV: " + request.uri() +
             "\n  mapId: " + mapIds +
             "\n  reduceId: " + reduceQ +
             "\n  jobId: " + jobQ +
-            "\n  keepAlive: " + keepAliveParam);
+            "\n  keepAlive: " + keepAliveParam +
+            "\n  channel id: " + channel.id());
       }

       if (mapIds == null || reduceQ == null || jobQ == null) {
@@ -986,7 +1091,7 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt)
         sendError(ctx, "Bad job parameter", BAD_REQUEST);
         return;
       }
-      final String reqUri = request.getUri();
+      final String reqUri = request.uri();
       if (null == reqUri) {
         // TODO? add upstream?
         sendError(ctx, FORBIDDEN);
@@ -1004,8 +1109,7 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt)
       Map<String, MapOutputInfo> mapOutputInfoMap =
           new HashMap<String, MapOutputInfo>();
-      Channel ch = evt.getChannel();
-      ChannelPipeline pipeline = ch.getPipeline();
+      ChannelPipeline pipeline = channel.pipeline();
       TimeoutHandler timeoutHandler =
           (TimeoutHandler)pipeline.get(TIMEOUT_HANDLER);
       timeoutHandler.setEnabledTimeout(false);
@@ -1015,14 +1119,27 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt)
         populateHeaders(mapIds, jobId, user, reduceId, request,
             response, keepAliveParam, mapOutputInfoMap);
       } catch(IOException e) {
-        ch.write(response);
-        LOG.error("Shuffle error in populating headers :", e);
-        String errorMessage = getErrorMessage(e);
-        sendError(ctx,errorMessage , INTERNAL_SERVER_ERROR);
+        //HADOOP-15327
+        // Need to send an instance of LastHttpContent to define HTTP
+        // message boundaries.
+        //Sending a HTTP 200 OK + HTTP 500 later (sendError)
+        // is quite a non-standard way of crafting HTTP responses,
+        // but we need to keep backward compatibility.
+        // See more details in jira.
+        writeToChannelAndAddLastHttpContent(channel, response);
+        LOG.error("Shuffle error while populating headers. Channel id: " + channel.id(), e);
+        sendError(ctx, getErrorMessage(e), INTERNAL_SERVER_ERROR);
         return;
       }
-      ch.write(response);
-      //Initialize one ReduceContext object per messageReceived call
+      writeToChannel(channel, response).addListener((ChannelFutureListener) future -> {
+        if (future.isSuccess()) {
+          LOG.debug("Written HTTP response object successfully. Channel id: {}", channel.id());
+        } else {
+          LOG.error("Error while writing HTTP response object: {}. " +
+              "Cause: {}, channel id: {}", response, future.cause(), channel.id());
+        }
+      });
+      //Initialize one ReduceContext object per channelRead call
       boolean keepAlive = keepAliveParam || connectionKeepAliveEnabled;
       ReduceContext reduceContext = new ReduceContext(mapIds, reduceId, ctx,
           user, mapOutputInfoMap, jobId, keepAlive);
@@ -1044,9 +1161,8 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt)
      * @param reduceContext used to call sendMapOutput with correct params.
      * @return the ChannelFuture of the sendMapOutput, can be null.
      */
-    public ChannelFuture sendMap(ReduceContext reduceContext)
-        throws Exception {
+    public ChannelFuture sendMap(ReduceContext reduceContext) {
+      LOG.trace("Executing sendMap");
       ChannelFuture nextMap = null;
       if (reduceContext.getMapsToSend().get() <
           reduceContext.getMapIds().size()) {
@@ -1059,13 +1175,16 @@ public ChannelFuture sendMap(ReduceContext reduceContext)
           info = getMapOutputInfo(mapId, reduceContext.getReduceId(),
               reduceContext.getJobId(), reduceContext.getUser());
         }
+        LOG.trace("Calling sendMapOutput");
         nextMap = sendMapOutput(
             reduceContext.getCtx(),
-            reduceContext.getCtx().getChannel(),
+            reduceContext.getCtx().channel(),
             reduceContext.getUser(), mapId,
             reduceContext.getReduceId(), info);
-        if (null == nextMap) {
+        if (nextMap == null) {
+          //This can only happen if spill file was not found
           sendError(reduceContext.getCtx(), NOT_FOUND);
+          LOG.trace("Returning nextMap: null");
           return null;
         }
         nextMap.addListener(new ReduceMapFileCount(reduceContext));
@@ -1125,8 +1244,7 @@ protected MapOutputInfo getMapOutputInfo(String mapId, int reduce,
         }
       }

-      IndexRecord info =
-          indexCache.getIndexInformation(mapId, reduce, pathInfo.indexPath, user);
+      IndexRecord info = indexCache.getIndexInformation(mapId, reduce, pathInfo.indexPath, user);

       if (LOG.isDebugEnabled()) {
         LOG.debug("getMapOutputInfo: jobId=" + jobId + ", mapId=" + mapId +
@@ -1155,7 +1273,6 @@ protected void populateHeaders(List<String> mapIds, String jobId,
             outputInfo.indexRecord.rawLength, reduce);
         DataOutputBuffer dob = new DataOutputBuffer();
         header.write(dob);
-
         contentLength += outputInfo.indexRecord.partLength;
         contentLength += dob.getLength();
       }
@@ -1183,11 +1300,7 @@ protected void populateHeaders(List<String> mapIds, String jobId,
     protected void setResponseHeaders(HttpResponse response,
         boolean keepAliveParam, long contentLength) {
       if (!connectionKeepAliveEnabled && !keepAliveParam) {
-        if (LOG.isDebugEnabled()) {
-          LOG.debug("Setting connection close header...");
-        }
-        response.headers().set(HttpHeader.CONNECTION.asString(),
-            CONNECTION_CLOSE);
+        response.headers().set(HttpHeader.CONNECTION.asString(), CONNECTION_CLOSE);
       } else {
         response.headers().set(HttpHeader.CONTENT_LENGTH.asString(),
             String.valueOf(contentLength));
@@ -1214,28 +1327,28 @@ protected void verifyRequest(String appid, ChannelHandlerContext ctx,
         throws IOException {
       SecretKey tokenSecret = secretManager.retrieveTokenSecret(appid);
       if (null == tokenSecret) {
-        LOG.info("Request for unknown token " + appid);
-        throw new IOException("could not find jobid");
+        LOG.info("Request for unknown token {}, channel id: {}", appid, ctx.channel().id());
+        throw new IOException("Could not find jobid");
       }
-      // string to encrypt
-      String enc_str = SecureShuffleUtils.buildMsgFrom(requestUri);
+      // encrypting URL
+      String encryptedURL = SecureShuffleUtils.buildMsgFrom(requestUri);
       // hash from the fetcher
       String urlHashStr =
           request.headers().get(SecureShuffleUtils.HTTP_HEADER_URL_HASH);
       if (urlHashStr == null) {
-        LOG.info("Missing header hash for " + appid);
+        LOG.info("Missing header hash for {}, channel id: {}", appid, ctx.channel().id());
         throw new IOException("fetcher cannot be authenticated");
       }
       if (LOG.isDebugEnabled()) {
         int len = urlHashStr.length();
-        LOG.debug("verifying request. enc_str=" + enc_str + "; hash=..." +
-            urlHashStr.substring(len-len/2, len-1));
+        LOG.debug("Verifying request. encryptedURL:{}, hash:{}, channel id: " +
+            "{}", encryptedURL,
+            urlHashStr.substring(len - len / 2, len - 1), ctx.channel().id());
       }
       // verify - throws exception
-      SecureShuffleUtils.verifyReply(urlHashStr, enc_str, tokenSecret);
+      SecureShuffleUtils.verifyReply(urlHashStr, encryptedURL, tokenSecret);
       // verification passed - encode the reply
-      String reply =
-          SecureShuffleUtils.generateHash(urlHashStr.getBytes(Charsets.UTF_8),
+      String reply = SecureShuffleUtils.generateHash(urlHashStr.getBytes(Charsets.UTF_8),
           tokenSecret);
       response.headers().set(
           SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH, reply);
@@ -1246,8 +1359,10 @@ protected void verifyRequest(String appid, ChannelHandlerContext ctx,
           ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
       if (LOG.isDebugEnabled()) {
         int len = reply.length();
-        LOG.debug("Fetcher request verfied. enc_str=" + enc_str + ";reply=" +
-            reply.substring(len-len/2, len-1));
+        LOG.debug("Fetcher request verified. " +
+            "encryptedURL: {}, reply: {}, channel id: {}",
+            encryptedURL, reply.substring(len - len / 2, len - 1),
+            ctx.channel().id());
       }
     }
@@ -1255,27 +1370,27 @@ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, Channel ch,
       String user, String mapId, int reduce, MapOutputInfo mapOutputInfo)
       throws IOException {
     final IndexRecord info = mapOutputInfo.indexRecord;
-    final ShuffleHeader header =
-        new ShuffleHeader(mapId, info.partLength, info.rawLength, reduce);
+    final ShuffleHeader header = new ShuffleHeader(mapId, info.partLength, info.rawLength,
+        reduce);
     final DataOutputBuffer dob = new DataOutputBuffer();
     header.write(dob);
-    ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+    writeToChannel(ch, wrappedBuffer(dob.getData(), 0, dob.getLength()));
     final File spillfile =
         new File(mapOutputInfo.mapOutputFileName.toString());
     RandomAccessFile spill;
     try {
       spill = SecureIOUtils.openForRandomRead(spillfile, "r", user, null);
     } catch (FileNotFoundException e) {
-      LOG.info(spillfile + " not found");
+      LOG.info("{} not found. Channel id: {}", spillfile, ctx.channel().id());
       return null;
     }
     ChannelFuture writeFuture;
-    if (ch.getPipeline().get(SslHandler.class) == null) {
+    if (ch.pipeline().get(SslHandler.class) == null) {
       final FadvisedFileRegion partition = new FadvisedFileRegion(spill,
           info.startOffset, info.partLength, manageOsCache, readaheadLength,
           readaheadPool, spillfile.getAbsolutePath(),
           shuffleBufferSize, shuffleTransferToAllowed);
-      writeFuture = ch.write(partition);
+      writeFuture = writeToChannel(ch, partition);
       writeFuture.addListener(new ChannelFutureListener() {
         // TODO error handling; distinguish IO/connection failures,
         //      attribute to appropriate spill output
@@ -1284,7 +1399,7 @@ public void operationComplete(ChannelFuture future) {
           if (future.isSuccess()) {
             partition.transferSuccessful();
           }
-          partition.releaseExternalResources();
+          partition.deallocate();
         }
       });
     } else {
@@ -1293,7 +1408,7 @@ public void operationComplete(ChannelFuture future) {
           info.startOffset, info.partLength, sslFileBufferSize,
           manageOsCache, readaheadLength, readaheadPool,
           spillfile.getAbsolutePath());
-      writeFuture = ch.write(chunk);
+      writeFuture = writeToChannel(ch, chunk);
     }
     metrics.shuffleConnections.incr();
     metrics.shuffleOutputBytes.incr(info.partLength); // optimistic
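The writes above now go through the patch's writeToChannel helper, and the completion callback calls deallocate(), the Netty 4 counterpart of Netty 3's releaseExternalResources(). For comparison, here is a minimal zero-copy write with stock Netty 4 types (FadvisedFileRegion layers fadvise/readahead on the same idea); this is a hedged sketch, not the ShuffleHandler code. Once a region is written, Netty owns it and releases it when the transfer completes or fails, so an explicit release() is only needed on paths where it was never written:

```java
// Hedged sketch with stock Netty 4 types; names here are illustrative.
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.DefaultFileRegion;
import java.io.File;
import java.nio.channels.ClosedChannelException;

final class FileRegionWriteSketch {
  static ChannelFuture sendFileSection(Channel ch, File file, long offset,
      long length) {
    DefaultFileRegion region = new DefaultFileRegion(file, offset, length);
    if (!ch.isActive()) {
      region.release(); // never handed to Netty, so drop our reference
      return ch.newFailedFuture(new ClosedChannelException());
    }
    // writeAndFlush transfers ownership: the transport performs a zero-copy
    // sendfile where possible and releases the region when done or on failure.
    return ch.writeAndFlush(region);
  }
}
```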
@@ -1307,12 +1422,13 @@ protected void sendError(ChannelHandlerContext ctx,
   protected void sendError(ChannelHandlerContext ctx, String message,
       HttpResponseStatus status) {
-    sendError(ctx, message, status, Collections.<String, String>emptyMap());
+    sendError(ctx, message, status, Collections.emptyMap());
   }

   protected void sendError(ChannelHandlerContext ctx, String msg,
       HttpResponseStatus status, Map<String, String> headers) {
-    HttpResponse response = new DefaultHttpResponse(HTTP_1_1, status);
+    FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, status,
+        Unpooled.copiedBuffer(msg, CharsetUtil.UTF_8));
     response.headers().set(CONTENT_TYPE, "text/plain; charset=UTF-8");
     // Put shuffle version into http header
     response.headers().set(ShuffleHeader.HTTP_HEADER_NAME,
@@ -1322,36 +1438,33 @@ protected void sendError(ChannelHandlerContext ctx, String msg,
     for (Map.Entry<String, String> header : headers.entrySet()) {
       response.headers().set(header.getKey(), header.getValue());
     }
-    response.setContent(
-        ChannelBuffers.copiedBuffer(msg, CharsetUtil.UTF_8));
     // Close the connection as soon as the error message is sent.
-    ctx.getChannel().write(response).addListener(ChannelFutureListener.CLOSE);
+    writeToChannelAndClose(ctx.channel(), response);
   }

   @Override
-  public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e)
+  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
       throws Exception {
-    Channel ch = e.getChannel();
-    Throwable cause = e.getCause();
+    Channel ch = ctx.channel();
     if (cause instanceof TooLongFrameException) {
+      LOG.trace("TooLongFrameException, channel id: {}", ch.id());
       sendError(ctx, BAD_REQUEST);
       return;
     } else if (cause instanceof IOException) {
       if (cause instanceof ClosedChannelException) {
-        LOG.debug("Ignoring closed channel error", cause);
+        LOG.debug("Ignoring closed channel error, channel id: " + ch.id(), cause);
         return;
       }
       String message = String.valueOf(cause.getMessage());
       if (IGNORABLE_ERROR_MESSAGE.matcher(message).matches()) {
-        LOG.debug("Ignoring client socket close", cause);
+        LOG.debug("Ignoring client socket close, channel id: " + ch.id(), cause);
         return;
       }
     }
-    LOG.error("Shuffle error: ", cause);
-    if (ch.isConnected()) {
-      LOG.error("Shuffle error " + e);
+    LOG.error("Shuffle error. Channel id: " + ch.id(), cause);
+    if (ch.isActive()) {
       sendError(ctx, INTERNAL_SERVER_ERROR);
     }
   }
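The exceptionCaught change is the canonical Netty 3 to 4 migration: the ExceptionEvent wrapper is gone, the handler receives the Throwable directly, the Channel comes from the context, and isConnected() becomes isActive(). A generic Netty 4 handler showing the same shape (a sketch, not the ShuffleHandler class itself):

```java
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.TooLongFrameException;
import java.io.IOException;

final class ErrorMappingHandler extends ChannelInboundHandlerAdapter {
  @Override
  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
    Channel ch = ctx.channel();          // Netty 3: e.getChannel()
    if (cause instanceof TooLongFrameException) {
      // map protocol violations to an HTTP 400 here
    } else if (cause instanceof IOException) {
      // client went away; usually safe to ignore
    }
    if (ch.isActive()) {                 // Netty 3: ch.isConnected()
      ctx.close();                       // or write an error response first
    }
  }
}
```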
@@ -1363,7 +1476,7 @@ static class AttemptPathInfo {
     private final Path indexPath;
     private final Path dataPath;

-    public AttemptPathInfo(Path indexPath, Path dataPath) {
+    AttemptPathInfo(Path indexPath, Path dataPath) {
       this.indexPath = indexPath;
       this.dataPath = dataPath;
     }
@@ -1374,7 +1487,7 @@ static class AttemptPathIdentifier {
     private final String user;
     private final String attemptId;

-    public AttemptPathIdentifier(String jobId, String user, String attemptId) {
+    AttemptPathIdentifier(String jobId, String user, String attemptId) {
       this.jobId = jobId;
       this.user = user;
       this.attemptId = attemptId;
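AttemptPathIdentifier and AttemptPathInfo serve as the key and value types of the handler's path cache (ShuffleHandler keeps a Guava LoadingCache for resolved spill paths); the constructors merely drop to package-private here. As a reminder of why the key type must carry proper equals/hashCode, a hedged stand-alone sketch with stand-in names and an illustrative cache size:

```java
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import java.util.Objects;
import java.util.concurrent.ExecutionException;

final class PathCacheSketch {
  /** Stand-in for AttemptPathIdentifier: equal keys must hit the same entry. */
  static final class Key {
    final String jobId;
    final String user;
    final String attemptId;

    Key(String jobId, String user, String attemptId) {
      this.jobId = jobId;
      this.user = user;
      this.attemptId = attemptId;
    }

    @Override
    public boolean equals(Object o) {
      if (!(o instanceof Key)) {
        return false;
      }
      Key k = (Key) o;
      return jobId.equals(k.jobId) && user.equals(k.user)
          && attemptId.equals(k.attemptId);
    }

    @Override
    public int hashCode() {
      return Objects.hash(jobId, user, attemptId);
    }
  }

  static LoadingCache<Key, String> newPathCache() {
    return CacheBuilder.newBuilder()
        .maximumSize(10_000) // illustrative bound, not the real config
        .build(new CacheLoader<Key, String>() {
          @Override
          public String load(Key key) {
            // The real loader resolves index/data file paths on local disks.
            return "output/" + key.jobId + "/" + key.user + "/" + key.attemptId;
          }
        });
  }

  public static void main(String[] args) throws ExecutionException {
    LoadingCache<Key, String> cache = newPathCache();
    // Two equal keys resolve to the same cached entry.
    System.out.println(cache.get(new Key("job_1", "alice", "attempt_1")));
    System.out.println(cache.get(new Key("job_1", "alice", "attempt_1")));
  }
}
```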

View File

@@ -104,7 +104,7 @@ public void testCustomShuffleTransfer() throws IOException {
       Assert.assertEquals(count, targetFile.length());
     } finally {
       if (fileRegion != null) {
-        fileRegion.releaseExternalResources();
+        fileRegion.deallocate();
       }
       IOUtils.cleanupWithLogger(LOG, target);
       IOUtils.cleanupWithLogger(LOG, targetFile);
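The test keeps its try/finally cleanup but swaps the Netty 3 release call for deallocate(). With stock Netty 4 regions the public counterpart is release() (deallocate() is a protected hook on AbstractReferenceCounted that runs when the count hits zero); a hedged sketch of the same transfer-then-release pattern, assuming count is no larger than the source file:

```java
import io.netty.channel.DefaultFileRegion;
import java.io.File;
import java.io.FileOutputStream;
import java.nio.channels.WritableByteChannel;

final class FileRegionCleanupSketch {
  static void copyPrefix(File src, File dst, long count) throws Exception {
    DefaultFileRegion region = new DefaultFileRegion(src, 0, count);
    try (FileOutputStream out = new FileOutputStream(dst)) {
      WritableByteChannel target = out.getChannel();
      long written = 0;
      while (written < count) {
        written += region.transferTo(target, written);
      }
    } finally {
      region.release(); // closes the underlying file once refCnt hits zero
    }
  }
}
```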

View File

@@ -17,3 +17,5 @@ log4j.threshold=ALL
 log4j.appender.stdout=org.apache.log4j.ConsoleAppender
 log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
 log4j.appender.stdout.layout.ConversionPattern=%d{ISO8601} %-5p [%t] %c{2} (%F:%M(%L)) - %m%n
+log4j.logger.io.netty=INFO
+log4j.logger.org.apache.hadoop.mapred=INFO

View File

@@ -130,7 +130,7 @@
     </dependency>
     <dependency>
       <groupId>io.netty</groupId>
-      <artifactId>netty</artifactId>
+      <artifactId>netty-all</artifactId>
     </dependency>
     <dependency>
       <groupId>commons-logging</groupId>
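Swapping the netty artifact for netty-all moves the dependency from the 3.x line to 4.x, which also renames the root package from org.jboss.netty to io.netty. A typical import migration for handler code like the above (illustrative pairs, not an exhaustive mapping):

```java
// Netty 3 imports this kind of handler used to need:
//   import org.jboss.netty.buffer.ChannelBuffers;
//   import org.jboss.netty.channel.Channel;
//   import org.jboss.netty.channel.ExceptionEvent;

// Netty 4 equivalents referenced by the new code:
import io.netty.buffer.Unpooled;                        // replaces ChannelBuffers
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;          // now supplies channel()
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;

final class NettyImportMigrationSketch {
}
```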