HADOOP-13483. Optimize IPC server protobuf decoding. Contributed by Daryn Sharp.

This commit is contained in:
Kihwal Lee 2016-08-03 13:22:22 -05:00
parent 22ef5286bc
commit 580a833496
4 changed files with 362 additions and 63 deletions

View File

@ -32,6 +32,7 @@ import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.retry.RetryPolicy;
import org.apache.hadoop.ipc.Client.ConnectionId;
import org.apache.hadoop.ipc.RPC.RpcInvoker;
import org.apache.hadoop.ipc.RpcWritable;
import org.apache.hadoop.ipc.protobuf.ProtobufRpcEngineProtos.RequestHeaderProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
@ -68,7 +69,7 @@ public class ProtobufRpcEngine implements RpcEngine {
static { // Register the rpcRequest deserializer for WritableRpcEngine
org.apache.hadoop.ipc.Server.registerProtocolEngine(
RPC.RpcKind.RPC_PROTOCOL_BUFFER, RpcRequestWrapper.class,
RPC.RpcKind.RPC_PROTOCOL_BUFFER, RpcWritable.Buffer.class,
new Server.ProtoBufRpcInvoker());
}
@ -612,11 +613,11 @@ public class ProtobufRpcEngine implements RpcEngine {
*/
public Writable call(RPC.Server server, String connectionProtocolName,
Writable writableRequest, long receiveTime) throws Exception {
RpcRequestWrapper request = (RpcRequestWrapper) writableRequest;
RequestHeaderProto rpcRequest = request.requestHeader;
RpcWritable.Buffer request = (RpcWritable.Buffer) writableRequest;
RequestHeaderProto rpcRequest =
request.getValue(RequestHeaderProto.getDefaultInstance());
String methodName = rpcRequest.getMethodName();
/**
* RPCs for a particular interface (ie protocol) are done using a
* IPC connection that is setup using rpcProxy.
@ -652,8 +653,7 @@ public class ProtobufRpcEngine implements RpcEngine {
throw new RpcNoSuchMethodException(msg);
}
Message prototype = service.getRequestPrototype(methodDescriptor);
Message param = prototype.newBuilderForType()
.mergeFrom(request.theRequestRead).build();
Message param = request.getValue(prototype);
Message result;
long startTime = Time.now();
@ -683,7 +683,7 @@ public class ProtobufRpcEngine implements RpcEngine {
exception.getClass().getSimpleName();
server.updateMetrics(detailedMetricsName, qTime, processingTime);
}
return new RpcResponseWrapper(result);
return RpcWritable.wrap(result);
}
}
}

View File

@ -0,0 +1,184 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.ipc;
import java.io.ByteArrayInputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Writable;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.Message;
@InterfaceAudience.Private
public abstract class RpcWritable implements Writable {
static RpcWritable wrap(Object o) {
if (o instanceof RpcWritable) {
return (RpcWritable)o;
} else if (o instanceof Message) {
return new ProtobufWrapper((Message)o);
} else if (o instanceof Writable) {
return new WritableWrapper((Writable)o);
}
throw new IllegalArgumentException("Cannot wrap " + o.getClass());
}
// don't support old inefficient Writable methods.
@Override
public final void readFields(DataInput in) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public final void write(DataOutput out) throws IOException {
throw new UnsupportedOperationException();
}
// methods optimized for reduced intermediate byte[] allocations.
abstract void writeTo(ResponseBuffer out) throws IOException;
abstract <T> T readFrom(ByteBuffer bb) throws IOException;
// adapter for Writables.
static class WritableWrapper extends RpcWritable {
private final Writable writable;
WritableWrapper(Writable writable) {
this.writable = writable;
}
@Override
public void writeTo(ResponseBuffer out) throws IOException {
writable.write(out);
}
@SuppressWarnings("unchecked")
@Override
<T> T readFrom(ByteBuffer bb) throws IOException {
// create a stream that may consume up to the entire ByteBuffer.
DataInputStream in = new DataInputStream(new ByteArrayInputStream(
bb.array(), bb.position() + bb.arrayOffset(), bb.remaining()));
try {
writable.readFields(in);
} finally {
// advance over the bytes read.
bb.position(bb.limit() - in.available());
}
return (T)writable;
}
}
// adapter for Protobufs.
static class ProtobufWrapper extends RpcWritable {
private Message message;
ProtobufWrapper(Message message) {
this.message = message;
}
@Override
void writeTo(ResponseBuffer out) throws IOException {
int length = message.getSerializedSize();
length += CodedOutputStream.computeRawVarint32Size(length);
out.ensureCapacity(length);
message.writeDelimitedTo(out);
}
@SuppressWarnings("unchecked")
@Override
<T> T readFrom(ByteBuffer bb) throws IOException {
// using the parser with a byte[]-backed coded input stream is the
// most efficient way to deserialize a protobuf. it has a direct
// path to the PB ctor that doesn't create multi-layered streams
// that internally buffer.
CodedInputStream cis = CodedInputStream.newInstance(
bb.array(), bb.position() + bb.arrayOffset(), bb.remaining());
try {
cis.pushLimit(cis.readRawVarint32());
message = message.getParserForType().parseFrom(cis);
cis.checkLastTagWas(0);
} finally {
// advance over the bytes read.
bb.position(bb.position() + cis.getTotalBytesRead());
}
return (T)message;
}
}
// adapter to allow decoding of writables and protobufs from a byte buffer.
static class Buffer extends RpcWritable {
private ByteBuffer bb;
static Buffer wrap(ByteBuffer bb) {
return new Buffer(bb);
}
Buffer() {}
Buffer(ByteBuffer bb) {
this.bb = bb;
}
@Override
void writeTo(ResponseBuffer out) throws IOException {
out.ensureCapacity(bb.remaining());
out.write(bb.array(), bb.position() + bb.arrayOffset(), bb.remaining());
}
@SuppressWarnings("unchecked")
@Override
<T> T readFrom(ByteBuffer bb) throws IOException {
// effectively consume the rest of the buffer from the callers
// perspective.
this.bb = bb.slice();
bb.limit(bb.position());
return (T)this;
}
public <T> T newInstance(Class<T> valueClass,
Configuration conf) throws IOException {
T instance;
try {
// this is much faster than ReflectionUtils!
instance = valueClass.newInstance();
if (instance instanceof Configurable) {
((Configurable)instance).setConf(conf);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
return getValue(instance);
}
public <T> T getValue(T value) throws IOException {
return RpcWritable.wrap(value).readFrom(bb);
}
int remaining() {
return bb.remaining();
}
}
}

View File

@ -26,7 +26,6 @@ import static org.apache.hadoop.ipc.RpcConstants.PING_CALL_ID;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.lang.reflect.UndeclaredThrowableException;
@ -83,8 +82,6 @@ import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseWrapper;
import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcWrapper;
import org.apache.hadoop.ipc.RPC.RpcInvoker;
import org.apache.hadoop.ipc.RPC.VersionMismatch;
import org.apache.hadoop.ipc.metrics.RpcDetailedMetrics;
@ -114,7 +111,6 @@ import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.SecretManager.InvalidToken;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.util.ProtoUtil;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.util.Time;
import org.apache.htrace.core.SpanId;
@ -123,9 +119,7 @@ import org.apache.htrace.core.Tracer;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.Message;
import com.google.protobuf.Message.Builder;
/** An abstract IPC service. IPC calls take a single {@link Writable} as a
* parameter, and return a {@link Writable} as their value. A service runs on
@ -1346,6 +1340,7 @@ public abstract class Server {
* A WrappedRpcServerException that is suppressed altogether
* for the purposes of logging.
*/
@SuppressWarnings("serial")
private static class WrappedRpcServerExceptionSuppressed
extends WrappedRpcServerException {
public WrappedRpcServerExceptionSuppressed(
@ -1478,10 +1473,10 @@ public abstract class Server {
}
}
private void saslReadAndProcess(DataInputStream dis) throws
private void saslReadAndProcess(RpcWritable.Buffer buffer) throws
WrappedRpcServerException, IOException, InterruptedException {
final RpcSaslProto saslMessage =
decodeProtobufFromStream(RpcSaslProto.newBuilder(), dis);
getMessage(RpcSaslProto.getDefaultInstance(), buffer);
switch (saslMessage.getState()) {
case WRAP: {
if (!saslContextEstablished || !useWrap) {
@ -1713,7 +1708,7 @@ public abstract class Server {
RpcConstants.INVALID_RETRY_COUNT, null, this);
setupResponse(saslCall,
RpcStatusProto.SUCCESS, null,
new RpcResponseWrapper(message), null, null);
RpcWritable.wrap(message), null, null);
saslCall.sendResponse();
}
@ -1839,7 +1834,7 @@ public abstract class Server {
dataLengthBuffer.clear(); // to read length of future rpc packets
data.flip();
boolean isHeaderRead = connectionContextRead;
processOneRpc(data.array());
processOneRpc(data);
data = null;
// the last rpc-request we processed could have simply been the
// connectionContext; if so continue to read the first RPC.
@ -1966,7 +1961,7 @@ public abstract class Server {
* @throws WrappedRpcServerException - if the header cannot be
* deserialized, or the user is not authorized
*/
private void processConnectionContext(DataInputStream dis)
private void processConnectionContext(RpcWritable.Buffer buffer)
throws WrappedRpcServerException {
// allow only one connection context during a session
if (connectionContextRead) {
@ -1974,8 +1969,7 @@ public abstract class Server {
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
"Connection context already processed");
}
connectionContext = decodeProtobufFromStream(
IpcConnectionContextProto.newBuilder(), dis);
connectionContext = getMessage(IpcConnectionContextProto.getDefaultInstance(), buffer);
protocolName = connectionContext.hasProtocol() ? connectionContext
.getProtocol() : null;
@ -2053,7 +2047,7 @@ public abstract class Server {
if (unwrappedData.remaining() == 0) {
unwrappedDataLengthBuffer.clear();
unwrappedData.flip();
processOneRpc(unwrappedData.array());
processOneRpc(unwrappedData);
unwrappedData = null;
}
}
@ -2078,15 +2072,14 @@ public abstract class Server {
* Listener thread
* @throws InterruptedException
*/
private void processOneRpc(byte[] buf)
private void processOneRpc(ByteBuffer bb)
throws IOException, WrappedRpcServerException, InterruptedException {
int callId = -1;
int retry = RpcConstants.INVALID_RETRY_COUNT;
try {
final DataInputStream dis =
new DataInputStream(new ByteArrayInputStream(buf));
final RpcWritable.Buffer buffer = RpcWritable.Buffer.wrap(bb);
final RpcRequestHeaderProto header =
decodeProtobufFromStream(RpcRequestHeaderProto.newBuilder(), dis);
getMessage(RpcRequestHeaderProto.getDefaultInstance(), buffer);
callId = header.getCallId();
retry = header.getRetryCount();
if (LOG.isDebugEnabled()) {
@ -2095,13 +2088,13 @@ public abstract class Server {
checkRpcHeaders(header);
if (callId < 0) { // callIds typically used during connection setup
processRpcOutOfBandRequest(header, dis);
processRpcOutOfBandRequest(header, buffer);
} else if (!connectionContextRead) {
throw new WrappedRpcServerException(
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
"Connection context not established");
} else {
processRpcRequest(header, dis);
processRpcRequest(header, buffer);
}
} catch (WrappedRpcServerException wrse) { // inform client of error
Throwable ioe = wrse.getCause();
@ -2157,7 +2150,7 @@ public abstract class Server {
* @throws InterruptedException
*/
private void processRpcRequest(RpcRequestHeaderProto header,
DataInputStream dis) throws WrappedRpcServerException,
RpcWritable.Buffer buffer) throws WrappedRpcServerException,
InterruptedException {
Class<? extends Writable> rpcRequestClass =
getRpcRequestWrapper(header.getRpcKind());
@ -2171,8 +2164,7 @@ public abstract class Server {
}
Writable rpcRequest;
try { //Read the rpc request
rpcRequest = ReflectionUtils.newInstance(rpcRequestClass, conf);
rpcRequest.readFields(dis);
rpcRequest = buffer.newInstance(rpcRequestClass, conf);
} catch (Throwable t) { // includes runtime exception from newInstance
LOG.warn("Unable to read call parameters for client " +
getHostAddress() + "on connection protocol " +
@ -2253,8 +2245,8 @@ public abstract class Server {
* @throws InterruptedException
*/
private void processRpcOutOfBandRequest(RpcRequestHeaderProto header,
DataInputStream dis) throws WrappedRpcServerException, IOException,
InterruptedException {
RpcWritable.Buffer buffer) throws WrappedRpcServerException,
IOException, InterruptedException {
final int callId = header.getCallId();
if (callId == CONNECTION_CONTEXT_CALL_ID) {
// SASL must be established prior to connection context
@ -2264,7 +2256,7 @@ public abstract class Server {
"Connection header sent during SASL negotiation");
}
// read and authorize the user
processConnectionContext(dis);
processConnectionContext(buffer);
} else if (callId == AuthProtocol.SASL.callId) {
// if client was switched to simple, ignore first SASL message
if (authProtocol != AuthProtocol.SASL) {
@ -2272,7 +2264,7 @@ public abstract class Server {
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
"SASL protocol not requested by client");
}
saslReadAndProcess(dis);
saslReadAndProcess(buffer);
} else if (callId == PING_CALL_ID) {
LOG.debug("Received ping message");
} else {
@ -2319,13 +2311,12 @@ public abstract class Server {
* @throws WrappedRpcServerException - deserialization failed
*/
@SuppressWarnings("unchecked")
private <T extends Message> T decodeProtobufFromStream(Builder builder,
DataInputStream dis) throws WrappedRpcServerException {
<T extends Message> T getMessage(Message message,
RpcWritable.Buffer buffer) throws WrappedRpcServerException {
try {
builder.mergeDelimitedFrom(dis);
return (T)builder.build();
return (T)buffer.getValue(message);
} catch (Exception ioe) {
Class<?> protoClass = builder.getDefaultInstanceForType().getClass();
Class<?> protoClass = message.getClass();
throw new WrappedRpcServerException(
RpcErrorCodeProto.FATAL_DESERIALIZING_REQUEST,
"Error decoding " + protoClass.getSimpleName() + ": "+ ioe);
@ -2716,25 +2707,20 @@ public abstract class Server {
private void setupResponse(Call call,
RpcResponseHeaderProto header, Writable rv) throws IOException {
ResponseBuffer buf = responseBuffer.get().reset();
// adjust capacity on estimated length to reduce resizing copies
int estimatedLen = header.getSerializedSize();
estimatedLen += CodedOutputStream.computeRawVarint32Size(estimatedLen);
// if it's not a wrapped protobuf, just let it grow on its own
if (rv instanceof RpcWrapper) {
estimatedLen += ((RpcWrapper)rv).getLength();
}
buf.ensureCapacity(estimatedLen);
header.writeDelimitedTo(buf);
if (rv != null) { // null for exceptions
rv.write(buf);
}
call.setResponse(ByteBuffer.wrap(buf.toByteArray()));
// Discard a large buf and reset it back to smaller size
// to free up heap.
if (buf.capacity() > maxRespSize) {
LOG.warn("Large response size " + buf.size() + " for call "
+ call.toString());
buf.setCapacity(INITIAL_RESP_BUF_SIZE);
try {
RpcWritable.wrap(header).writeTo(buf);
if (rv != null) {
RpcWritable.wrap(rv).writeTo(buf);
}
call.setResponse(ByteBuffer.wrap(buf.toByteArray()));
} finally {
// Discard a large buf and reset it back to smaller size
// to free up heap.
if (buf.capacity() > maxRespSize) {
LOG.warn("Large response size " + buf.size() + " for call "
+ call.toString());
buf.setCapacity(INITIAL_RESP_BUF_SIZE);
}
}
}
@ -2785,7 +2771,7 @@ public abstract class Server {
.setState(SaslState.WRAP)
.setToken(ByteString.copyFrom(token))
.build();
setupResponse(call, saslHeader, new RpcResponseWrapper(saslMessage));
setupResponse(call, saslHeader, RpcWritable.wrap(saslMessage));
}
}

View File

@ -0,0 +1,129 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.ipc;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.ipc.protobuf.TestProtos.EchoRequestProto;
import org.apache.hadoop.util.Time;
import org.junit.Assert;
import org.junit.Test;
import com.google.protobuf.Message;
public class TestRpcWritable {//extends TestRpcBase {
static Writable writable = new LongWritable(Time.now());
static Message message1 =
EchoRequestProto.newBuilder().setMessage("testing1").build();
static Message message2 =
EchoRequestProto.newBuilder().setMessage("testing2").build();
@Test
public void testWritableWrapper() throws IOException {
// serial writable in byte buffer
ByteArrayOutputStream baos = new ByteArrayOutputStream();
writable.write(new DataOutputStream(baos));
ByteBuffer bb = ByteBuffer.wrap(baos.toByteArray());
// deserial
LongWritable actual = RpcWritable.wrap(new LongWritable())
.readFrom(bb);
Assert.assertEquals(writable, actual);
Assert.assertEquals(0, bb.remaining());
}
@Test
public void testProtobufWrapper() throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
message1.writeDelimitedTo(baos);
ByteBuffer bb = ByteBuffer.wrap(baos.toByteArray());
Message actual = RpcWritable.wrap(EchoRequestProto.getDefaultInstance())
.readFrom(bb);
Assert.assertEquals(message1, actual);
Assert.assertEquals(0, bb.remaining());
}
@Test
public void testBufferWrapper() throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
message1.writeDelimitedTo(dos);
message2.writeDelimitedTo(dos);
writable.write(dos);
ByteBuffer bb = ByteBuffer.wrap(baos.toByteArray());
RpcWritable.Buffer buf = RpcWritable.Buffer.wrap(bb);
Assert.assertEquals(baos.size(), bb.remaining());
Assert.assertEquals(baos.size(), buf.remaining());
Object actual = buf.getValue(EchoRequestProto.getDefaultInstance());
Assert.assertEquals(message1, actual);
Assert.assertTrue(bb.remaining() > 0);
Assert.assertEquals(bb.remaining(), buf.remaining());
actual = buf.getValue(EchoRequestProto.getDefaultInstance());
Assert.assertEquals(message2, actual);
Assert.assertTrue(bb.remaining() > 0);
Assert.assertEquals(bb.remaining(), buf.remaining());
actual = buf.newInstance(LongWritable.class, null);
Assert.assertEquals(writable, actual);
Assert.assertEquals(0, bb.remaining());
Assert.assertEquals(0, buf.remaining());
}
@Test
public void testBufferWrapperNested() throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
writable.write(dos);
message1.writeDelimitedTo(dos);
message2.writeDelimitedTo(dos);
ByteBuffer bb = ByteBuffer.wrap(baos.toByteArray());
RpcWritable.Buffer buf1 = RpcWritable.Buffer.wrap(bb);
Assert.assertEquals(baos.size(), bb.remaining());
Assert.assertEquals(baos.size(), buf1.remaining());
Object actual = buf1.newInstance(LongWritable.class, null);
Assert.assertEquals(writable, actual);
int left = bb.remaining();
Assert.assertTrue(left > 0);
Assert.assertEquals(left, buf1.remaining());
// original bb now appears empty, but rpc writable has a slice of the bb.
RpcWritable.Buffer buf2 = buf1.newInstance(RpcWritable.Buffer.class, null);
Assert.assertEquals(0, bb.remaining());
Assert.assertEquals(0, buf1.remaining());
Assert.assertEquals(left, buf2.remaining());
actual = buf2.getValue(EchoRequestProto.getDefaultInstance());
Assert.assertEquals(message1, actual);
Assert.assertTrue(buf2.remaining() > 0);
actual = buf2.getValue(EchoRequestProto.getDefaultInstance());
Assert.assertEquals(message2, actual);
Assert.assertEquals(0, buf2.remaining());
}
}