HADOOP-15327. Upgrade MR ShuffleHandler to use Netty4 #3259. Contributed by Szilard Nemeth.

Szilard Nemeth 2021-06-05 00:14:07 +02:00 committed by Benjamin Teke
parent 552ee44eba
commit 5bb11cecea
11 changed files with 1606 additions and 575 deletions

View File

@ -148,6 +148,7 @@
<!-- Leave javax APIs that are stable -->
<!-- the jdk ships part of the javax.annotation namespace, so if we want to relocate this we'll have to carve it out by class :( -->
<exclude>com.google.code.findbugs:jsr305</exclude>
<exclude>io.netty:*</exclude>
<exclude>io.dropwizard.metrics:metrics-core</exclude>
<exclude>org.eclipse.jetty:jetty-servlet</exclude>
<exclude>org.eclipse.jetty:jetty-security</exclude>

View File

@ -53,14 +53,15 @@
import org.apache.hadoop.classification.VisibleForTesting;
class Fetcher<K,V> extends Thread {
@VisibleForTesting
public class Fetcher<K, V> extends Thread {
private static final Logger LOG = LoggerFactory.getLogger(Fetcher.class);
/** Number of ms before timing out a copy */
/** Number of ms before timing out a copy. */
private static final int DEFAULT_STALLED_COPY_TIMEOUT = 3 * 60 * 1000;
/** Basic/unit connection timeout (in milliseconds) */
/** Basic/unit connection timeout (in milliseconds). */
private final static int UNIT_CONNECT_TIMEOUT = 60 * 1000;
/* Default read timeout (in milliseconds) */
@ -72,10 +73,12 @@ class Fetcher<K,V> extends Thread {
private static final String FETCH_RETRY_AFTER_HEADER = "Retry-After";
protected final Reporter reporter;
private enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP,
@VisibleForTesting
public enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP,
CONNECTION, WRONG_REDUCE}
private final static String SHUFFLE_ERR_GRP_NAME = "Shuffle Errors";
@VisibleForTesting
public final static String SHUFFLE_ERR_GRP_NAME = "Shuffle Errors";
private final JobConf jobConf;
private final Counters.Counter connectionErrs;
private final Counters.Counter ioErrs;
@ -83,8 +86,8 @@ private enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP,
private final Counters.Counter badIdErrs;
private final Counters.Counter wrongMapErrs;
private final Counters.Counter wrongReduceErrs;
protected final MergeManager<K,V> merger;
protected final ShuffleSchedulerImpl<K,V> scheduler;
protected final MergeManager<K, V> merger;
protected final ShuffleSchedulerImpl<K, V> scheduler;
protected final ShuffleClientMetrics metrics;
protected final ExceptionReporter exceptionReporter;
protected final int id;
@ -111,7 +114,7 @@ private enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP,
private static SSLFactory sslFactory;
public Fetcher(JobConf job, TaskAttemptID reduceId,
ShuffleSchedulerImpl<K,V> scheduler, MergeManager<K,V> merger,
ShuffleSchedulerImpl<K, V> scheduler, MergeManager<K, V> merger,
Reporter reporter, ShuffleClientMetrics metrics,
ExceptionReporter exceptionReporter, SecretKey shuffleKey) {
this(job, reduceId, scheduler, merger, reporter, metrics,
@ -120,7 +123,7 @@ public Fetcher(JobConf job, TaskAttemptID reduceId,
@VisibleForTesting
Fetcher(JobConf job, TaskAttemptID reduceId,
ShuffleSchedulerImpl<K,V> scheduler, MergeManager<K,V> merger,
ShuffleSchedulerImpl<K, V> scheduler, MergeManager<K, V> merger,
Reporter reporter, ShuffleClientMetrics metrics,
ExceptionReporter exceptionReporter, SecretKey shuffleKey,
int id) {
@ -315,9 +318,8 @@ protected void copyFromHost(MapHost host) throws IOException {
return;
}
if(LOG.isDebugEnabled()) {
LOG.debug("Fetcher " + id + " going to fetch from " + host + " for: "
+ maps);
if (LOG.isDebugEnabled()) {
LOG.debug("Fetcher " + id + " going to fetch from " + host + " for: " + maps);
}
// List of maps to be fetched yet
@ -411,8 +413,8 @@ private void openConnectionWithRetry(URL url) throws IOException {
shouldWait = false;
} catch (IOException e) {
if (!fetchRetryEnabled) {
// throw exception directly if fetch's retry is not enabled
throw e;
// throw exception directly if fetch's retry is not enabled
throw e;
}
if ((Time.monotonicNow() - startTime) >= this.fetchRetryTimeout) {
LOG.warn("Failed to connect to host: " + url + "after "
@ -489,7 +491,7 @@ private TaskAttemptID[] copyMapOutput(MapHost host,
DataInputStream input,
Set<TaskAttemptID> remaining,
boolean canRetry) throws IOException {
MapOutput<K,V> mapOutput = null;
MapOutput<K, V> mapOutput = null;
TaskAttemptID mapId = null;
long decompressedLength = -1;
long compressedLength = -1;
@ -611,7 +613,7 @@ private void checkTimeoutOrRetry(MapHost host, IOException ioe)
// First time to retry.
long currentTime = Time.monotonicNow();
if (retryStartTime == 0) {
retryStartTime = currentTime;
retryStartTime = currentTime;
}
// Retry is not timeout, let's do retry with throwing an exception.
@ -628,7 +630,7 @@ private void checkTimeoutOrRetry(MapHost host, IOException ioe)
}
/**
* Do some basic verification on the input received -- Being defensive
* Do some basic verification on the input received -- Being defensive.
* @param compressedLength
* @param decompressedLength
* @param forReduce
@ -695,8 +697,7 @@ private URL getMapOutputURL(MapHost host, Collection<TaskAttemptID> maps
* only on the last failure. Instead of connecting with a timeout of
* X, we try connecting with a timeout of x < X but multiple times.
*/
private void connect(URLConnection connection, int connectionTimeout)
throws IOException {
private void connect(URLConnection connection, int connectionTimeout) throws IOException {
int unit = 0;
if (connectionTimeout < 0) {
throw new IOException("Invalid timeout "

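The connect() helper above retries with small unit timeouts instead of one long blocking connect. A minimal hedged sketch of that idea, reusing only the UNIT_CONNECT_TIMEOUT constant from this class (the loop structure and variable names are illustrative, not the exact Fetcher implementation):

import java.io.IOException;
import java.net.URLConnection;

class UnitTimeoutConnectSketch {
  // Mirrors Fetcher.UNIT_CONNECT_TIMEOUT (60 seconds).
  private static final int UNIT_CONNECT_TIMEOUT = 60 * 1000;

  // Spend the overall budget X across several short attempts rather than
  // a single connect() call that blocks for the full X milliseconds.
  void connect(URLConnection connection, int connectionTimeout) throws IOException {
    if (connectionTimeout < 0) {
      throw new IOException("Invalid timeout [timeout = " + connectionTimeout + " ms]");
    }
    long remaining = connectionTimeout;
    while (true) {
      int unit = (int) Math.min(UNIT_CONNECT_TIMEOUT, remaining);
      connection.setConnectTimeout(unit);
      try {
        connection.connect();
        return;          // connected within this unit timeout
      } catch (IOException e) {
        remaining -= unit;
        if (remaining <= 0) {
          throw e;       // whole budget exhausted, surface the last failure
        }
        // otherwise loop and try again with another unit timeout
      }
    }
  }
}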
View File

@ -26,6 +26,7 @@
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.mapreduce.TaskCounter;
import org.apache.hadoop.mapreduce.task.reduce.Fetcher;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@ -37,6 +38,7 @@
import java.util.Formatter;
import java.util.Iterator;
import static org.apache.hadoop.mapreduce.task.reduce.Fetcher.SHUFFLE_ERR_GRP_NAME;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@ -87,6 +89,9 @@ public void testReduceFromPartialMem() throws Exception {
final long spill = c.findCounter(TaskCounter.SPILLED_RECORDS).getCounter();
assertTrue("Expected some records not spilled during reduce" + spill + ")",
spill < 2 * out); // spilled map records, some records at the reduce
long shuffleIoErrors =
c.getGroup(SHUFFLE_ERR_GRP_NAME).getCounter(Fetcher.ShuffleErrors.IO_ERROR.toString());
assertEquals(0, shuffleIoErrors);
}
/**

View File

@ -23,6 +23,9 @@
import java.io.RandomAccessFile;
import org.apache.hadoop.classification.VisibleForTesting;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.stream.ChunkedFile;
import org.apache.hadoop.io.ReadaheadPool;
import org.apache.hadoop.io.ReadaheadPool.ReadaheadRequest;
import org.apache.hadoop.io.nativeio.NativeIO;
@ -31,8 +34,6 @@
import static org.apache.hadoop.io.nativeio.NativeIO.POSIX.POSIX_FADV_DONTNEED;
import org.jboss.netty.handler.stream.ChunkedFile;
public class FadvisedChunkedFile extends ChunkedFile {
private static final Logger LOG =
@ -64,16 +65,16 @@ FileDescriptor getFd() {
}
@Override
public Object nextChunk() throws Exception {
public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception {
synchronized (closeLock) {
if (fd.valid()) {
if (manageOsCache && readaheadPool != null) {
readaheadRequest = readaheadPool
.readaheadStream(
identifier, fd, getCurrentOffset(), readaheadLength,
getEndOffset(), readaheadRequest);
identifier, fd, currentOffset(), readaheadLength,
endOffset(), readaheadRequest);
}
return super.nextChunk();
return super.readChunk(allocator);
} else {
return null;
}
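In Netty 4 the ChunkedFile hook changes from nextChunk() to readChunk(ByteBufAllocator), and chunks are pulled by a ChunkedWriteHandler in the pipeline rather than pushed. A small hedged sketch of how such a chunked file ends up on the wire (the file, offset and chunk size here are made up for illustration):

import java.io.File;
import java.io.RandomAccessFile;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.handler.stream.ChunkedFile;

class ChunkedFileWriteSketch {
  // Assumes a ChunkedWriteHandler ("chunking") is already in the channel
  // pipeline, as ShuffleHandler sets up; the handler repeatedly calls
  // readChunk(ByteBufAllocator) until isEndOfInput() reports the end.
  ChannelFuture writeWholeFile(Channel ch, File spill) throws Exception {
    RandomAccessFile raf = new RandomAccessFile(spill, "r");
    return ch.writeAndFlush(new ChunkedFile(raf, 0, raf.length(), 8 * 1024));
  }
}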
@ -88,12 +89,12 @@ public void close() throws Exception {
readaheadRequest = null;
}
if (fd.valid() &&
manageOsCache && getEndOffset() - getStartOffset() > 0) {
manageOsCache && endOffset() - startOffset() > 0) {
try {
NativeIO.POSIX.getCacheManipulator().posixFadviseIfPossible(
identifier,
fd,
getStartOffset(), getEndOffset() - getStartOffset(),
startOffset(), endOffset() - startOffset(),
POSIX_FADV_DONTNEED);
} catch (Throwable t) {
LOG.warn("Failed to manage OS cache for " + identifier +

View File

@ -25,6 +25,7 @@
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;
import io.netty.channel.DefaultFileRegion;
import org.apache.hadoop.io.ReadaheadPool;
import org.apache.hadoop.io.ReadaheadPool.ReadaheadRequest;
import org.apache.hadoop.io.nativeio.NativeIO;
@ -33,8 +34,6 @@
import static org.apache.hadoop.io.nativeio.NativeIO.POSIX.POSIX_FADV_DONTNEED;
import org.jboss.netty.channel.DefaultFileRegion;
import org.apache.hadoop.classification.VisibleForTesting;
public class FadvisedFileRegion extends DefaultFileRegion {
@ -77,8 +76,8 @@ public long transferTo(WritableByteChannel target, long position)
throws IOException {
if (readaheadPool != null && readaheadLength > 0) {
readaheadRequest = readaheadPool.readaheadStream(identifier, fd,
getPosition() + position, readaheadLength,
getPosition() + getCount(), readaheadRequest);
position() + position, readaheadLength,
position() + count(), readaheadRequest);
}
if(this.shuffleTransferToAllowed) {
@ -147,11 +146,11 @@ long customShuffleTransfer(WritableByteChannel target, long position)
@Override
public void releaseExternalResources() {
protected void deallocate() {
if (readaheadRequest != null) {
readaheadRequest.cancel();
}
super.releaseExternalResources();
super.deallocate();
}
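Netty 4 replaces releaseExternalResources() with reference counting: a FileRegion is a ReferenceCounted object whose deallocate() hook (overridden above) runs once the count drops to zero. A brief hedged sketch with a plain DefaultFileRegion (names are illustrative):

import java.io.File;
import java.io.RandomAccessFile;
import io.netty.channel.DefaultFileRegion;

class FileRegionRefCountSketch {
  void openAndDrop(File spill) throws Exception {
    RandomAccessFile raf = new RandomAccessFile(spill, "r");
    DefaultFileRegion region = new DefaultFileRegion(raf.getChannel(), 0, raf.length());
    // release() decrements the reference count; when it reaches zero Netty
    // invokes deallocate(), which closes the underlying file channel.
    region.release();
  }
}

When such a region is handed to writeAndFlush(), Netty drops its own reference once the write completes; the completion listener in ShuffleHandler.sendMapOutput (further down in this commit) also calls deallocate() directly after the transfer.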
/**
@ -159,10 +158,10 @@ public void releaseExternalResources() {
* we don't need the region to be cached anymore.
*/
public void transferSuccessful() {
if (manageOsCache && getCount() > 0) {
if (manageOsCache && count() > 0) {
try {
NativeIO.POSIX.getCacheManipulator().posixFadviseIfPossible(identifier,
fd, getPosition(), getCount(), POSIX_FADV_DONTNEED);
fd, position(), count(), POSIX_FADV_DONTNEED);
} catch (Throwable t) {
LOG.warn("Failed to manage OS cache for " + identifier, t);
}

View File

@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.mapred;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
class LoggingHttpResponseEncoder extends HttpResponseEncoder {
private static final Logger LOG = LoggerFactory.getLogger(LoggingHttpResponseEncoder.class);
private final boolean logStacktraceOfEncodingMethods;
LoggingHttpResponseEncoder(boolean logStacktraceOfEncodingMethods) {
this.logStacktraceOfEncodingMethods = logStacktraceOfEncodingMethods;
}
@Override
public boolean acceptOutboundMessage(Object msg) throws Exception {
printExecutingMethod();
LOG.info("OUTBOUND MESSAGE: " + msg);
return super.acceptOutboundMessage(msg);
}
@Override
protected void encodeInitialLine(ByteBuf buf, HttpResponse response) throws Exception {
LOG.debug("Executing method: {}, response: {}",
getExecutingMethodName(), response);
logStacktraceIfRequired();
super.encodeInitialLine(buf, response);
}
@Override
protected void encode(ChannelHandlerContext ctx, Object msg,
List<Object> out) throws Exception {
LOG.debug("Encoding to channel {}: {}", ctx.channel(), msg);
printExecutingMethod();
logStacktraceIfRequired();
super.encode(ctx, msg, out);
}
@Override
protected void encodeHeaders(HttpHeaders headers, ByteBuf buf) {
printExecutingMethod();
super.encodeHeaders(headers, buf);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise
promise) throws Exception {
LOG.debug("Writing to channel {}: {}", ctx.channel(), msg);
printExecutingMethod();
super.write(ctx, msg, promise);
}
private void logStacktraceIfRequired() {
if (logStacktraceOfEncodingMethods) {
LOG.debug("Stacktrace: ", new Throwable());
}
}
private void printExecutingMethod() {
String methodName = getExecutingMethodName(1);
LOG.debug("Executing method: {}", methodName);
}
private String getExecutingMethodName() {
return getExecutingMethodName(0);
}
private String getExecutingMethodName(int additionalSkipFrames) {
try {
StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
// Array items (indices):
// 0: java.lang.Thread.getStackTrace(...)
// 1: TestShuffleHandler$LoggingHttpResponseEncoder.getExecutingMethodName(...)
int skipFrames = 2 + additionalSkipFrames;
String methodName = stackTrace[skipFrames].getMethodName();
String className = this.getClass().getSimpleName();
return className + "#" + methodName;
} catch (Throwable t) {
LOG.error("Error while getting execution method name", t);
return "unknown";
}
}
}
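This encoder is only useful once swapped into a pipeline in place of the stock HttpResponseEncoder, which is what the useOutboundLogger flag in ShuffleHandler does. A minimal hedged sketch of wiring it up in an ordinary Netty 4 initializer, assuming the sketch lives in the same package so the package-private constructor is visible:

import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.HttpRequestDecoder;

class DebugHttpInitializer extends ChannelInitializer<SocketChannel> {
  @Override
  protected void initChannel(SocketChannel ch) {
    ChannelPipeline p = ch.pipeline();
    p.addLast("decoder", new HttpRequestDecoder());
    // true = also log the stack trace of each encode call
    p.addLast("encoder", new LoggingHttpResponseEncoder(true));
  }
}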

View File

@ -18,19 +18,20 @@
package org.apache.hadoop.mapred;
import static io.netty.buffer.Unpooled.wrappedBuffer;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE;
import static io.netty.handler.codec.http.HttpMethod.GET;
import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static io.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED;
import static io.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import static org.apache.hadoop.mapred.ShuffleHandler.NettyChannelHelper.*;
import static org.fusesource.leveldbjni.JniDBFactory.asString;
import static org.fusesource.leveldbjni.JniDBFactory.bytes;
import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE;
import static org.jboss.netty.handler.codec.http.HttpMethod.GET;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED;
import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import java.io.File;
import java.io.FileNotFoundException;
@ -54,6 +55,44 @@
import javax.crypto.SecretKey;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DataInputByteBuffer;
@ -79,7 +118,6 @@
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.util.DiskChecker;
import org.apache.hadoop.util.Shell;
import org.apache.hadoop.util.concurrent.HadoopExecutors;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.proto.YarnServerCommonProtos.VersionProto;
import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext;
@ -94,42 +132,6 @@
import org.iq80.leveldb.DB;
import org.iq80.leveldb.DBException;
import org.iq80.leveldb.Options;
import org.jboss.netty.bootstrap.ServerBootstrap;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFactory;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandler;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.channel.group.ChannelGroup;
import org.jboss.netty.channel.group.DefaultChannelGroup;
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory;
import org.jboss.netty.handler.codec.frame.TooLongFrameException;
import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
import org.jboss.netty.handler.codec.http.HttpChunkAggregator;
import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpRequestDecoder;
import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseEncoder;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.QueryStringDecoder;
import org.jboss.netty.handler.ssl.SslHandler;
import org.jboss.netty.handler.stream.ChunkedWriteHandler;
import org.jboss.netty.handler.timeout.IdleState;
import org.jboss.netty.handler.timeout.IdleStateAwareChannelHandler;
import org.jboss.netty.handler.timeout.IdleStateEvent;
import org.jboss.netty.handler.timeout.IdleStateHandler;
import org.jboss.netty.util.CharsetUtil;
import org.jboss.netty.util.HashedWheelTimer;
import org.jboss.netty.util.Timer;
import org.eclipse.jetty.http.HttpHeader;
import org.slf4j.LoggerFactory;
@ -182,19 +184,29 @@ public class ShuffleHandler extends AuxiliaryService {
public static final HttpResponseStatus TOO_MANY_REQ_STATUS =
new HttpResponseStatus(429, "TOO MANY REQUESTS");
// This should kept in sync with Fetcher.FETCH_RETRY_DELAY_DEFAULT
// This should be kept in sync with Fetcher.FETCH_RETRY_DELAY_DEFAULT
public static final long FETCH_RETRY_DELAY = 1000L;
public static final String RETRY_AFTER_HEADER = "Retry-After";
static final String ENCODER_HANDLER_NAME = "encoder";
private int port;
private ChannelFactory selector;
private final ChannelGroup accepted = new DefaultChannelGroup();
private EventLoopGroup bossGroup;
private EventLoopGroup workerGroup;
private ServerBootstrap bootstrap;
private Channel ch;
private final ChannelGroup accepted =
new DefaultChannelGroup(new DefaultEventExecutorGroup(5).next());
private final AtomicInteger activeConnections = new AtomicInteger();
protected HttpPipelineFactory pipelineFact;
private int sslFileBufferSize;
//TODO snemeth add a config option for these later, this is temporarily disabled for now.
private boolean useOutboundExceptionHandler = false;
private boolean useOutboundLogger = false;
/**
* Should the shuffle use posix_fadvise calls to manage the OS cache during
* sendfile
* sendfile.
*/
private boolean manageOsCache;
private int readaheadLength;
@ -204,7 +216,7 @@ public class ShuffleHandler extends AuxiliaryService {
private int maxSessionOpenFiles;
private ReadaheadPool readaheadPool = ReadaheadPool.getInstance();
private Map<String,String> userRsrc;
private Map<String, String> userRsrc;
private JobTokenSecretManager secretManager;
private DB stateDb = null;
@ -235,7 +247,7 @@ public class ShuffleHandler extends AuxiliaryService {
public static final String CONNECTION_CLOSE = "close";
public static final String SUFFLE_SSL_FILE_BUFFER_SIZE_KEY =
"mapreduce.shuffle.ssl.file.buffer.size";
"mapreduce.shuffle.ssl.file.buffer.size";
public static final int DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE = 60 * 1024;
@ -255,7 +267,7 @@ public class ShuffleHandler extends AuxiliaryService {
public static final boolean DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED = true;
public static final boolean WINDOWS_DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED =
false;
private static final String TIMEOUT_HANDLER = "timeout";
static final String TIMEOUT_HANDLER = "timeout";
/* the maximum number of files a single GET request can
open simultaneously during shuffle
@ -267,7 +279,6 @@ public class ShuffleHandler extends AuxiliaryService {
boolean connectionKeepAliveEnabled = false;
private int connectionKeepAliveTimeOut;
private int mapOutputMetaInfoCacheSize;
private Timer timer;
@Metrics(about="Shuffle output metrics", context="mapred")
static class ShuffleMetrics implements ChannelFutureListener {
@ -291,6 +302,49 @@ public void operationComplete(ChannelFuture future) throws Exception {
}
}
static class NettyChannelHelper {
static ChannelFuture writeToChannel(Channel ch, Object obj) {
LOG.debug("Writing {} to channel: {}", obj.getClass().getSimpleName(), ch.id());
return ch.writeAndFlush(obj);
}
static ChannelFuture writeToChannelAndClose(Channel ch, Object obj) {
return writeToChannel(ch, obj).addListener(ChannelFutureListener.CLOSE);
}
static ChannelFuture writeToChannelAndAddLastHttpContent(Channel ch, HttpResponse obj) {
writeToChannel(ch, obj);
return writeLastHttpContentToChannel(ch);
}
static ChannelFuture writeLastHttpContentToChannel(Channel ch) {
LOG.debug("Writing LastHttpContent, channel id: {}", ch.id());
return ch.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT);
}
static ChannelFuture closeChannel(Channel ch) {
LOG.debug("Closing channel, channel id: {}", ch.id());
return ch.close();
}
static void closeChannels(ChannelGroup channelGroup) {
channelGroup.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
}
public static ChannelFuture closeAsIdle(Channel ch, int timeout) {
LOG.debug("Closing channel as writer was idle for {} seconds", timeout);
return closeChannel(ch);
}
public static void channelActive(Channel ch) {
LOG.debug("Executing channelActive, channel id: {}", ch.id());
}
public static void channelInactive(Channel ch) {
LOG.debug("Executing channelInactive, channel id: {}", ch.id());
}
}
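The helper captures the Netty 4 convention that a response head, its body, and a terminating LastHttpContent are written as separate outbound messages; without the trailer a keep-alive fetcher cannot tell where the HTTP message ends. A hedged sketch of the intended call order (channel and body are placeholders, and the sketch assumes it sits in org.apache.hadoop.mapred so the package-private helper is accessible):

import io.netty.channel.Channel;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.HttpResponse;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;

class ShuffleResponseSequenceSketch {
  void respond(Channel ch, Object body) {
    HttpResponse head = new DefaultHttpResponse(HTTP_1_1, OK);
    ShuffleHandler.NettyChannelHelper.writeToChannel(ch, head);          // status line + headers
    ShuffleHandler.NettyChannelHelper.writeToChannel(ch, body);          // e.g. a FileRegion or ChunkedFile
    ShuffleHandler.NettyChannelHelper.writeLastHttpContentToChannel(ch); // marks the message boundary
  }
}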
private final MetricsSystem ms;
final ShuffleMetrics metrics;
@ -298,29 +352,36 @@ class ReduceMapFileCount implements ChannelFutureListener {
private ReduceContext reduceContext;
public ReduceMapFileCount(ReduceContext rc) {
ReduceMapFileCount(ReduceContext rc) {
this.reduceContext = rc;
}
@Override
public void operationComplete(ChannelFuture future) throws Exception {
LOG.trace("operationComplete");
if (!future.isSuccess()) {
future.getChannel().close();
LOG.error("Future is unsuccessful. Cause: ", future.cause());
closeChannel(future.channel());
return;
}
int waitCount = this.reduceContext.getMapsToWait().decrementAndGet();
if (waitCount == 0) {
LOG.trace("Finished with all map outputs");
//HADOOP-15327: Need to send an instance of LastHttpContent to define HTTP
//message boundaries. See details in jira.
writeLastHttpContentToChannel(future.channel());
metrics.operationComplete(future);
// Let the idle timer handler close keep-alive connections
if (reduceContext.getKeepAlive()) {
ChannelPipeline pipeline = future.getChannel().getPipeline();
ChannelPipeline pipeline = future.channel().pipeline();
TimeoutHandler timeoutHandler =
(TimeoutHandler)pipeline.get(TIMEOUT_HANDLER);
timeoutHandler.setEnabledTimeout(true);
} else {
future.getChannel().close();
closeChannel(future.channel());
}
} else {
LOG.trace("operationComplete, waitCount > 0, invoking sendMap with reduceContext");
pipelineFact.getSHUFFLE().sendMap(reduceContext);
}
}
@ -331,7 +392,6 @@ public void operationComplete(ChannelFuture future) throws Exception {
* Allows sendMapOutput calls from operationComplete()
*/
private static class ReduceContext {
private List<String> mapIds;
private AtomicInteger mapsToWait;
private AtomicInteger mapsToSend;
@ -342,7 +402,7 @@ private static class ReduceContext {
private String jobId;
private final boolean keepAlive;
public ReduceContext(List<String> mapIds, int rId,
ReduceContext(List<String> mapIds, int rId,
ChannelHandlerContext context, String usr,
Map<String, Shuffle.MapOutputInfo> mapOutputInfoMap,
String jobId, boolean keepAlive) {
@ -448,7 +508,8 @@ public static int deserializeMetaData(ByteBuffer meta) throws IOException {
* shuffle data requests.
* @return the serialized version of the jobToken.
*/
public static ByteBuffer serializeServiceData(Token<JobTokenIdentifier> jobToken) throws IOException {
public static ByteBuffer serializeServiceData(Token<JobTokenIdentifier> jobToken)
throws IOException {
//TODO these bytes should be versioned
DataOutputBuffer jobToken_dob = new DataOutputBuffer();
jobToken.write(jobToken_dob);
@ -505,6 +566,11 @@ protected void serviceInit(Configuration conf) throws Exception {
DEFAULT_MAX_SHUFFLE_CONNECTIONS);
int maxShuffleThreads = conf.getInt(MAX_SHUFFLE_THREADS,
DEFAULT_MAX_SHUFFLE_THREADS);
// Since Netty 4.x, a thread count of 0 would default to
// io.netty.channel.MultithreadEventLoopGroup.DEFAULT_EVENT_LOOP_THREADS
// simply by passing 0 to the NioEventLoopGroup constructor below.
// However, this logic to determine the thread count
// was already in place, so we can keep it for now.
if (maxShuffleThreads == 0) {
maxShuffleThreads = 2 * Runtime.getRuntime().availableProcessors();
}
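For reference, Netty 4 also accepts a thread count of 0 and then sizes the group itself; a tiny hedged sketch of what the comment above describes (the system property name is Netty's own, everything else is illustrative):

import io.netty.channel.nio.NioEventLoopGroup;

class EventLoopSizingSketch {
  NioEventLoopGroup defaultSizedGroup() {
    // 0 threads => MultithreadEventLoopGroup.DEFAULT_EVENT_LOOP_THREADS,
    // i.e. max(1, -Dio.netty.eventLoopThreads, default 2 * available processors).
    return new NioEventLoopGroup(0);
  }
}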
@ -520,16 +586,14 @@ protected void serviceInit(Configuration conf) throws Exception {
DEFAULT_SHUFFLE_MAX_SESSION_OPEN_FILES);
ThreadFactory bossFactory = new ThreadFactoryBuilder()
.setNameFormat("ShuffleHandler Netty Boss #%d")
.build();
.setNameFormat("ShuffleHandler Netty Boss #%d")
.build();
ThreadFactory workerFactory = new ThreadFactoryBuilder()
.setNameFormat("ShuffleHandler Netty Worker #%d")
.build();
.setNameFormat("ShuffleHandler Netty Worker #%d")
.build();
selector = new NioServerSocketChannelFactory(
HadoopExecutors.newCachedThreadPool(bossFactory),
HadoopExecutors.newCachedThreadPool(workerFactory),
maxShuffleThreads);
bossGroup = new NioEventLoopGroup(maxShuffleThreads, bossFactory);
workerGroup = new NioEventLoopGroup(maxShuffleThreads, workerFactory);
super.serviceInit(new Configuration(conf));
}
@ -540,22 +604,24 @@ protected void serviceStart() throws Exception {
userRsrc = new ConcurrentHashMap<String,String>();
secretManager = new JobTokenSecretManager();
recoverState(conf);
ServerBootstrap bootstrap = new ServerBootstrap(selector);
// Timer is shared across entire factory and must be released separately
timer = new HashedWheelTimer();
try {
pipelineFact = new HttpPipelineFactory(conf, timer);
pipelineFact = new HttpPipelineFactory(conf);
} catch (Exception ex) {
throw new RuntimeException(ex);
}
bootstrap.setOption("backlog", conf.getInt(SHUFFLE_LISTEN_QUEUE_SIZE,
DEFAULT_SHUFFLE_LISTEN_QUEUE_SIZE));
bootstrap.setOption("child.keepAlive", true);
bootstrap.setPipelineFactory(pipelineFact);
bootstrap = new ServerBootstrap();
bootstrap.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.option(ChannelOption.SO_BACKLOG,
conf.getInt(SHUFFLE_LISTEN_QUEUE_SIZE,
DEFAULT_SHUFFLE_LISTEN_QUEUE_SIZE))
.childOption(ChannelOption.SO_KEEPALIVE, true)
.childHandler(pipelineFact);
port = conf.getInt(SHUFFLE_PORT_CONFIG_KEY, DEFAULT_SHUFFLE_PORT);
Channel ch = bootstrap.bind(new InetSocketAddress(port));
ch = bootstrap.bind(new InetSocketAddress(port)).sync().channel();
accepted.add(ch);
port = ((InetSocketAddress)ch.getLocalAddress()).getPort();
port = ((InetSocketAddress)ch.localAddress()).getPort();
conf.set(SHUFFLE_PORT_CONFIG_KEY, Integer.toString(port));
pipelineFact.SHUFFLE.setPort(port);
LOG.info(getName() + " listening on port " + port);
@ -576,18 +642,12 @@ protected void serviceStart() throws Exception {
@Override
protected void serviceStop() throws Exception {
accepted.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
if (selector != null) {
ServerBootstrap bootstrap = new ServerBootstrap(selector);
bootstrap.releaseExternalResources();
}
closeChannels(accepted);
if (pipelineFact != null) {
pipelineFact.destroy();
}
if (timer != null) {
// Release this shared timer resource
timer.stop();
}
if (stateDb != null) {
stateDb.close();
}
@ -744,7 +804,7 @@ private void recoverJobShuffleInfo(String jobIdStr, byte[] data)
JobShuffleInfoProto proto = JobShuffleInfoProto.parseFrom(data);
String user = proto.getUser();
TokenProto tokenProto = proto.getJobToken();
Token<JobTokenIdentifier> jobToken = new Token<JobTokenIdentifier>(
Token<JobTokenIdentifier> jobToken = new Token<>(
tokenProto.getIdentifier().toByteArray(),
tokenProto.getPassword().toByteArray(),
new Text(tokenProto.getKind()), new Text(tokenProto.getService()));
@ -785,29 +845,47 @@ private void removeJobShuffleInfo(JobID jobId) throws IOException {
}
}
static class TimeoutHandler extends IdleStateAwareChannelHandler {
@VisibleForTesting
public void setUseOutboundExceptionHandler(boolean useHandler) {
this.useOutboundExceptionHandler = useHandler;
}
static class TimeoutHandler extends IdleStateHandler {
private final int connectionKeepAliveTimeOut;
private boolean enabledTimeout;
TimeoutHandler(int connectionKeepAliveTimeOut) {
//disable reader timeout
//set writer timeout to configured timeout value
//disable all idle timeout
super(0, connectionKeepAliveTimeOut, 0, TimeUnit.SECONDS);
this.connectionKeepAliveTimeOut = connectionKeepAliveTimeOut;
}
@VisibleForTesting
public int getConnectionKeepAliveTimeOut() {
return connectionKeepAliveTimeOut;
}
void setEnabledTimeout(boolean enabledTimeout) {
this.enabledTimeout = enabledTimeout;
}
@Override
public void channelIdle(ChannelHandlerContext ctx, IdleStateEvent e) {
if (e.getState() == IdleState.WRITER_IDLE && enabledTimeout) {
e.getChannel().close();
if (e.state() == IdleState.WRITER_IDLE && enabledTimeout) {
closeAsIdle(ctx.channel(), connectionKeepAliveTimeOut);
}
}
}
class HttpPipelineFactory implements ChannelPipelineFactory {
class HttpPipelineFactory extends ChannelInitializer<SocketChannel> {
private static final int MAX_CONTENT_LENGTH = 1 << 16;
final Shuffle SHUFFLE;
private SSLFactory sslFactory;
private final ChannelHandler idleStateHandler;
public HttpPipelineFactory(Configuration conf, Timer timer) throws Exception {
HttpPipelineFactory(Configuration conf) throws Exception {
SHUFFLE = getShuffle(conf);
if (conf.getBoolean(MRConfig.SHUFFLE_SSL_ENABLED_KEY,
MRConfig.SHUFFLE_SSL_ENABLED_DEFAULT)) {
@ -815,7 +893,6 @@ public HttpPipelineFactory(Configuration conf, Timer timer) throws Exception {
sslFactory = new SSLFactory(SSLFactory.Mode.SERVER, conf);
sslFactory.init();
}
this.idleStateHandler = new IdleStateHandler(timer, 0, connectionKeepAliveTimeOut, 0);
}
public Shuffle getSHUFFLE() {
@ -828,30 +905,39 @@ public void destroy() {
}
}
@Override
public ChannelPipeline getPipeline() throws Exception {
ChannelPipeline pipeline = Channels.pipeline();
@Override protected void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline pipeline = ch.pipeline();
if (sslFactory != null) {
pipeline.addLast("ssl", new SslHandler(sslFactory.createSSLEngine()));
}
pipeline.addLast("decoder", new HttpRequestDecoder());
pipeline.addLast("aggregator", new HttpChunkAggregator(1 << 16));
pipeline.addLast("encoder", new HttpResponseEncoder());
pipeline.addLast("aggregator", new HttpObjectAggregator(MAX_CONTENT_LENGTH));
pipeline.addLast(ENCODER_HANDLER_NAME, useOutboundLogger ?
new LoggingHttpResponseEncoder(false) : new HttpResponseEncoder());
pipeline.addLast("chunking", new ChunkedWriteHandler());
pipeline.addLast("shuffle", SHUFFLE);
pipeline.addLast("idle", idleStateHandler);
pipeline.addLast(TIMEOUT_HANDLER, new TimeoutHandler());
return pipeline;
if (useOutboundExceptionHandler) {
//https://stackoverflow.com/questions/50612403/catch-all-exception-handling-for-outbound-channelhandler
pipeline.addLast("outboundExceptionHandler", new ChannelOutboundHandlerAdapter() {
@Override
public void write(ChannelHandlerContext ctx, Object msg,
ChannelPromise promise) throws Exception {
promise.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
super.write(ctx, msg, promise);
}
});
}
pipeline.addLast(TIMEOUT_HANDLER, new TimeoutHandler(connectionKeepAliveTimeOut));
// TODO factor security manager into pipeline
// TODO factor out encode/decode to permit binary shuffle
// TODO factor out decode of index to permit alt. models
}
}
class Shuffle extends SimpleChannelUpstreamHandler {
@ChannelHandler.Sharable
class Shuffle extends ChannelInboundHandlerAdapter {
private final IndexCache indexCache;
private final
LoadingCache<AttemptPathIdentifier, AttemptPathInfo> pathCache;
private final LoadingCache<AttemptPathIdentifier, AttemptPathInfo> pathCache;
private int port;
@ -904,65 +990,84 @@ private List<String> splitMaps(List<String> mapq) {
}
@Override
public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent evt)
public void channelActive(ChannelHandlerContext ctx)
throws Exception {
super.channelOpen(ctx, evt);
if ((maxShuffleConnections > 0) && (accepted.size() >= maxShuffleConnections)) {
NettyChannelHelper.channelActive(ctx.channel());
int numConnections = activeConnections.incrementAndGet();
if ((maxShuffleConnections > 0) && (numConnections > maxShuffleConnections)) {
LOG.info(String.format("Current number of shuffle connections (%d) is " +
"greater than or equal to the max allowed shuffle connections (%d)",
"greater than the max allowed shuffle connections (%d)",
accepted.size(), maxShuffleConnections));
Map<String, String> headers = new HashMap<String, String>(1);
Map<String, String> headers = new HashMap<>(1);
// notify fetchers to backoff for a while before closing the connection
// if the shuffle connection limit is hit. Fetchers are expected to
// handle this notification gracefully, that is, not treating this as a
// fetch failure.
headers.put(RETRY_AFTER_HEADER, String.valueOf(FETCH_RETRY_DELAY));
sendError(ctx, "", TOO_MANY_REQ_STATUS, headers);
return;
} else {
super.channelActive(ctx);
accepted.add(ctx.channel());
LOG.debug("Added channel: {}, channel id: {}. Accepted number of connections={}",
ctx.channel(), ctx.channel().id(), activeConnections.get());
}
accepted.add(evt.getChannel());
}
@Override
public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt)
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
NettyChannelHelper.channelInactive(ctx.channel());
super.channelInactive(ctx);
int noOfConnections = activeConnections.decrementAndGet();
LOG.debug("New value of Accepted number of connections={}", noOfConnections);
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg)
throws Exception {
HttpRequest request = (HttpRequest) evt.getMessage();
if (request.getMethod() != GET) {
sendError(ctx, METHOD_NOT_ALLOWED);
return;
Channel channel = ctx.channel();
LOG.trace("Executing channelRead, channel id: {}", channel.id());
HttpRequest request = (HttpRequest) msg;
LOG.debug("Received HTTP request: {}, channel id: {}", request, channel.id());
if (request.method() != GET) {
sendError(ctx, METHOD_NOT_ALLOWED);
return;
}
// Check whether the shuffle version is compatible
if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(
request.headers() != null ?
request.headers().get(ShuffleHeader.HTTP_HEADER_NAME) : null)
|| !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals(
request.headers() != null ?
request.headers()
.get(ShuffleHeader.HTTP_HEADER_VERSION) : null)) {
String shuffleVersion = ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION;
String httpHeaderName = ShuffleHeader.DEFAULT_HTTP_HEADER_NAME;
if (request.headers() != null) {
shuffleVersion = request.headers().get(ShuffleHeader.HTTP_HEADER_VERSION);
httpHeaderName = request.headers().get(ShuffleHeader.HTTP_HEADER_NAME);
LOG.debug("Received from request header: ShuffleVersion={} header name={}, channel id: {}",
shuffleVersion, httpHeaderName, channel.id());
}
if (request.headers() == null ||
!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(httpHeaderName) ||
!ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals(shuffleVersion)) {
sendError(ctx, "Incompatible shuffle request version", BAD_REQUEST);
}
final Map<String,List<String>> q =
new QueryStringDecoder(request.getUri()).getParameters();
final Map<String, List<String>> q =
new QueryStringDecoder(request.uri()).parameters();
final List<String> keepAliveList = q.get("keepAlive");
boolean keepAliveParam = false;
if (keepAliveList != null && keepAliveList.size() == 1) {
keepAliveParam = Boolean.valueOf(keepAliveList.get(0));
if (LOG.isDebugEnabled()) {
LOG.debug("KeepAliveParam : " + keepAliveList
+ " : " + keepAliveParam);
LOG.debug("KeepAliveParam: {} : {}, channel id: {}",
keepAliveList, keepAliveParam, channel.id());
}
}
final List<String> mapIds = splitMaps(q.get("map"));
final List<String> reduceQ = q.get("reduce");
final List<String> jobQ = q.get("job");
if (LOG.isDebugEnabled()) {
LOG.debug("RECV: " + request.getUri() +
LOG.debug("RECV: " + request.uri() +
"\n mapId: " + mapIds +
"\n reduceId: " + reduceQ +
"\n jobId: " + jobQ +
"\n keepAlive: " + keepAliveParam);
"\n keepAlive: " + keepAliveParam +
"\n channel id: " + channel.id());
}
if (mapIds == null || reduceQ == null || jobQ == null) {
@ -986,7 +1091,7 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt)
sendError(ctx, "Bad job parameter", BAD_REQUEST);
return;
}
final String reqUri = request.getUri();
final String reqUri = request.uri();
if (null == reqUri) {
// TODO? add upstream?
sendError(ctx, FORBIDDEN);
@ -1004,8 +1109,7 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt)
Map<String, MapOutputInfo> mapOutputInfoMap =
new HashMap<String, MapOutputInfo>();
Channel ch = evt.getChannel();
ChannelPipeline pipeline = ch.getPipeline();
ChannelPipeline pipeline = channel.pipeline();
TimeoutHandler timeoutHandler =
(TimeoutHandler)pipeline.get(TIMEOUT_HANDLER);
timeoutHandler.setEnabledTimeout(false);
@ -1013,16 +1117,29 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt)
try {
populateHeaders(mapIds, jobId, user, reduceId, request,
response, keepAliveParam, mapOutputInfoMap);
response, keepAliveParam, mapOutputInfoMap);
} catch(IOException e) {
ch.write(response);
LOG.error("Shuffle error in populating headers :", e);
String errorMessage = getErrorMessage(e);
sendError(ctx,errorMessage , INTERNAL_SERVER_ERROR);
//HADOOP-15327
// Need to send an instance of LastHttpContent to define HTTP
// message boundaries.
// Sending an HTTP 200 OK + HTTP 500 later (sendError)
// is quite a non-standard way of crafting HTTP responses,
// but we need to keep backward compatibility.
// See more details in jira.
writeToChannelAndAddLastHttpContent(channel, response);
LOG.error("Shuffle error while populating headers. Channel id: " + channel.id(), e);
sendError(ctx, getErrorMessage(e), INTERNAL_SERVER_ERROR);
return;
}
ch.write(response);
//Initialize one ReduceContext object per messageReceived call
writeToChannel(channel, response).addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
LOG.debug("Written HTTP response object successfully. Channel id: {}", channel.id());
} else {
LOG.error("Error while writing HTTP response object: {}. " +
"Cause: {}, channel id: {}", response, future.cause(), channel.id());
}
});
//Initialize one ReduceContext object per channelRead call
boolean keepAlive = keepAliveParam || connectionKeepAliveEnabled;
ReduceContext reduceContext = new ReduceContext(mapIds, reduceId, ctx,
user, mapOutputInfoMap, jobId, keepAlive);
@ -1044,9 +1161,8 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt)
* @param reduceContext used to call sendMapOutput with correct params.
* @return the ChannelFuture of the sendMapOutput, can be null.
*/
public ChannelFuture sendMap(ReduceContext reduceContext)
throws Exception {
public ChannelFuture sendMap(ReduceContext reduceContext) {
LOG.trace("Executing sendMap");
ChannelFuture nextMap = null;
if (reduceContext.getMapsToSend().get() <
reduceContext.getMapIds().size()) {
@ -1059,21 +1175,24 @@ public ChannelFuture sendMap(ReduceContext reduceContext)
info = getMapOutputInfo(mapId, reduceContext.getReduceId(),
reduceContext.getJobId(), reduceContext.getUser());
}
LOG.trace("Calling sendMapOutput");
nextMap = sendMapOutput(
reduceContext.getCtx(),
reduceContext.getCtx().getChannel(),
reduceContext.getCtx().channel(),
reduceContext.getUser(), mapId,
reduceContext.getReduceId(), info);
if (null == nextMap) {
if (nextMap == null) {
//This can only happen if spill file was not found
sendError(reduceContext.getCtx(), NOT_FOUND);
LOG.trace("Returning nextMap: null");
return null;
}
nextMap.addListener(new ReduceMapFileCount(reduceContext));
} catch (IOException e) {
if (e instanceof DiskChecker.DiskErrorException) {
LOG.error("Shuffle error :" + e);
LOG.error("Shuffle error: " + e);
} else {
LOG.error("Shuffle error :", e);
LOG.error("Shuffle error: ", e);
}
String errorMessage = getErrorMessage(e);
sendError(reduceContext.getCtx(), errorMessage,
@ -1125,8 +1244,7 @@ protected MapOutputInfo getMapOutputInfo(String mapId, int reduce,
}
}
IndexRecord info =
indexCache.getIndexInformation(mapId, reduce, pathInfo.indexPath, user);
IndexRecord info = indexCache.getIndexInformation(mapId, reduce, pathInfo.indexPath, user);
if (LOG.isDebugEnabled()) {
LOG.debug("getMapOutputInfo: jobId=" + jobId + ", mapId=" + mapId +
@ -1155,7 +1273,6 @@ protected void populateHeaders(List<String> mapIds, String jobId,
outputInfo.indexRecord.rawLength, reduce);
DataOutputBuffer dob = new DataOutputBuffer();
header.write(dob);
contentLength += outputInfo.indexRecord.partLength;
contentLength += dob.getLength();
}
@ -1183,14 +1300,10 @@ protected void populateHeaders(List<String> mapIds, String jobId,
protected void setResponseHeaders(HttpResponse response,
boolean keepAliveParam, long contentLength) {
if (!connectionKeepAliveEnabled && !keepAliveParam) {
if (LOG.isDebugEnabled()) {
LOG.debug("Setting connection close header...");
}
response.headers().set(HttpHeader.CONNECTION.asString(),
CONNECTION_CLOSE);
response.headers().set(HttpHeader.CONNECTION.asString(), CONNECTION_CLOSE);
} else {
response.headers().set(HttpHeader.CONTENT_LENGTH.asString(),
String.valueOf(contentLength));
String.valueOf(contentLength));
response.headers().set(HttpHeader.CONNECTION.asString(),
HttpHeader.KEEP_ALIVE.asString());
response.headers().set(HttpHeader.KEEP_ALIVE.asString(),
@ -1214,29 +1327,29 @@ protected void verifyRequest(String appid, ChannelHandlerContext ctx,
throws IOException {
SecretKey tokenSecret = secretManager.retrieveTokenSecret(appid);
if (null == tokenSecret) {
LOG.info("Request for unknown token " + appid);
throw new IOException("could not find jobid");
LOG.info("Request for unknown token {}, channel id: {}", appid, ctx.channel().id());
throw new IOException("Could not find jobid");
}
// string to encrypt
String enc_str = SecureShuffleUtils.buildMsgFrom(requestUri);
// encrypting URL
String encryptedURL = SecureShuffleUtils.buildMsgFrom(requestUri);
// hash from the fetcher
String urlHashStr =
request.headers().get(SecureShuffleUtils.HTTP_HEADER_URL_HASH);
if (urlHashStr == null) {
LOG.info("Missing header hash for " + appid);
LOG.info("Missing header hash for {}, channel id: {}", appid, ctx.channel().id());
throw new IOException("fetcher cannot be authenticated");
}
if (LOG.isDebugEnabled()) {
int len = urlHashStr.length();
LOG.debug("verifying request. enc_str=" + enc_str + "; hash=..." +
urlHashStr.substring(len-len/2, len-1));
LOG.debug("Verifying request. encryptedURL:{}, hash:{}, channel id: " +
"{}", encryptedURL,
urlHashStr.substring(len - len / 2, len - 1), ctx.channel().id());
}
// verify - throws exception
SecureShuffleUtils.verifyReply(urlHashStr, enc_str, tokenSecret);
SecureShuffleUtils.verifyReply(urlHashStr, encryptedURL, tokenSecret);
// verification passed - encode the reply
String reply =
SecureShuffleUtils.generateHash(urlHashStr.getBytes(Charsets.UTF_8),
tokenSecret);
String reply = SecureShuffleUtils.generateHash(urlHashStr.getBytes(Charsets.UTF_8),
tokenSecret);
response.headers().set(
SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH, reply);
// Put shuffle version into http header
@ -1246,8 +1359,10 @@ protected void verifyRequest(String appid, ChannelHandlerContext ctx,
ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
if (LOG.isDebugEnabled()) {
int len = reply.length();
LOG.debug("Fetcher request verfied. enc_str=" + enc_str + ";reply=" +
reply.substring(len-len/2, len-1));
LOG.debug("Fetcher request verified. " +
"encryptedURL: {}, reply: {}, channel id: {}",
encryptedURL, reply.substring(len - len / 2, len - 1),
ctx.channel().id());
}
}
@ -1255,27 +1370,27 @@ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, Channel ch,
String user, String mapId, int reduce, MapOutputInfo mapOutputInfo)
throws IOException {
final IndexRecord info = mapOutputInfo.indexRecord;
final ShuffleHeader header =
new ShuffleHeader(mapId, info.partLength, info.rawLength, reduce);
final ShuffleHeader header = new ShuffleHeader(mapId, info.partLength, info.rawLength,
reduce);
final DataOutputBuffer dob = new DataOutputBuffer();
header.write(dob);
ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
writeToChannel(ch, wrappedBuffer(dob.getData(), 0, dob.getLength()));
final File spillfile =
new File(mapOutputInfo.mapOutputFileName.toString());
RandomAccessFile spill;
try {
spill = SecureIOUtils.openForRandomRead(spillfile, "r", user, null);
} catch (FileNotFoundException e) {
LOG.info(spillfile + " not found");
LOG.info("{} not found. Channel id: {}", spillfile, ctx.channel().id());
return null;
}
ChannelFuture writeFuture;
if (ch.getPipeline().get(SslHandler.class) == null) {
if (ch.pipeline().get(SslHandler.class) == null) {
final FadvisedFileRegion partition = new FadvisedFileRegion(spill,
info.startOffset, info.partLength, manageOsCache, readaheadLength,
readaheadPool, spillfile.getAbsolutePath(),
shuffleBufferSize, shuffleTransferToAllowed);
writeFuture = ch.write(partition);
writeFuture = writeToChannel(ch, partition);
writeFuture.addListener(new ChannelFutureListener() {
// TODO error handling; distinguish IO/connection failures,
// attribute to appropriate spill output
@ -1284,7 +1399,7 @@ public void operationComplete(ChannelFuture future) {
if (future.isSuccess()) {
partition.transferSuccessful();
}
partition.releaseExternalResources();
partition.deallocate();
}
});
} else {
@ -1293,7 +1408,7 @@ public void operationComplete(ChannelFuture future) {
info.startOffset, info.partLength, sslFileBufferSize,
manageOsCache, readaheadLength, readaheadPool,
spillfile.getAbsolutePath());
writeFuture = ch.write(chunk);
writeFuture = writeToChannel(ch, chunk);
}
metrics.shuffleConnections.incr();
metrics.shuffleOutputBytes.incr(info.partLength); // optimistic
@ -1307,12 +1422,13 @@ protected void sendError(ChannelHandlerContext ctx,
protected void sendError(ChannelHandlerContext ctx, String message,
HttpResponseStatus status) {
sendError(ctx, message, status, Collections.<String, String>emptyMap());
sendError(ctx, message, status, Collections.emptyMap());
}
protected void sendError(ChannelHandlerContext ctx, String msg,
HttpResponseStatus status, Map<String, String> headers) {
HttpResponse response = new DefaultHttpResponse(HTTP_1_1, status);
FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, status,
Unpooled.copiedBuffer(msg, CharsetUtil.UTF_8));
response.headers().set(CONTENT_TYPE, "text/plain; charset=UTF-8");
// Put shuffle version into http header
response.headers().set(ShuffleHeader.HTTP_HEADER_NAME,
@ -1322,48 +1438,45 @@ protected void sendError(ChannelHandlerContext ctx, String msg,
for (Map.Entry<String, String> header : headers.entrySet()) {
response.headers().set(header.getKey(), header.getValue());
}
response.setContent(
ChannelBuffers.copiedBuffer(msg, CharsetUtil.UTF_8));
// Close the connection as soon as the error message is sent.
ctx.getChannel().write(response).addListener(ChannelFutureListener.CLOSE);
writeToChannelAndClose(ctx.channel(), response);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e)
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
throws Exception {
Channel ch = e.getChannel();
Throwable cause = e.getCause();
Channel ch = ctx.channel();
if (cause instanceof TooLongFrameException) {
LOG.trace("TooLongFrameException, channel id: {}", ch.id());
sendError(ctx, BAD_REQUEST);
return;
} else if (cause instanceof IOException) {
if (cause instanceof ClosedChannelException) {
LOG.debug("Ignoring closed channel error", cause);
LOG.debug("Ignoring closed channel error, channel id: " + ch.id(), cause);
return;
}
String message = String.valueOf(cause.getMessage());
if (IGNORABLE_ERROR_MESSAGE.matcher(message).matches()) {
LOG.debug("Ignoring client socket close", cause);
LOG.debug("Ignoring client socket close, channel id: " + ch.id(), cause);
return;
}
}
LOG.error("Shuffle error: ", cause);
if (ch.isConnected()) {
LOG.error("Shuffle error " + e);
LOG.error("Shuffle error. Channel id: " + ch.id(), cause);
if (ch.isActive()) {
sendError(ctx, INTERNAL_SERVER_ERROR);
}
}
}
static class AttemptPathInfo {
// TODO Change this over to just store local dir indices, instead of the
// entire path. Far more efficient.
private final Path indexPath;
private final Path dataPath;
public AttemptPathInfo(Path indexPath, Path dataPath) {
AttemptPathInfo(Path indexPath, Path dataPath) {
this.indexPath = indexPath;
this.dataPath = dataPath;
}
@ -1374,7 +1487,7 @@ static class AttemptPathIdentifier {
private final String user;
private final String attemptId;
public AttemptPathIdentifier(String jobId, String user, String attemptId) {
AttemptPathIdentifier(String jobId, String user, String attemptId) {
this.jobId = jobId;
this.user = user;
this.attemptId = attemptId;

View File

@ -104,7 +104,7 @@ public void testCustomShuffleTransfer() throws IOException {
Assert.assertEquals(count, targetFile.length());
} finally {
if (fileRegion != null) {
fileRegion.releaseExternalResources();
fileRegion.deallocate();
}
IOUtils.cleanupWithLogger(LOG, target);
IOUtils.cleanupWithLogger(LOG, targetFile);

View File

@ -17,3 +17,5 @@ log4j.threshold=ALL
log4j.appender.stdout=org.apache.log4j.ConsoleAppender
log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
log4j.appender.stdout.layout.ConversionPattern=%d{ISO8601} %-5p [%t] %c{2} (%F:%M(%L)) - %m%n
log4j.logger.io.netty=INFO
log4j.logger.org.apache.hadoop.mapred=INFO

View File

@ -130,7 +130,7 @@
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty</artifactId>
<artifactId>netty-all</artifactId>
</dependency>
<dependency>
<groupId>commons-logging</groupId>