diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java index 183bad41fa..567b932e1b 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java @@ -21,7 +21,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import com.google.protobuf.CodedOutputStream; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.classification.InterfaceAudience; @@ -31,13 +30,11 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.CommonConfigurationKeys; import org.apache.hadoop.fs.CommonConfigurationKeysPublic; -import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.io.IOUtils; import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.retry.RetryPolicies; import org.apache.hadoop.io.retry.RetryPolicy; import org.apache.hadoop.io.retry.RetryPolicy.RetryAction; -import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcRequestMessageWrapper; import org.apache.hadoop.ipc.RPC.RpcKind; import org.apache.hadoop.ipc.Server.AuthProtocol; import org.apache.hadoop.ipc.protobuf.IpcConnectionContextProtos.IpcConnectionContextProto; @@ -54,7 +51,6 @@ import org.apache.hadoop.security.SecurityUtil; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.util.ProtoUtil; -import org.apache.hadoop.util.ReflectionUtils; import org.apache.hadoop.util.StringUtils; import org.apache.hadoop.util.Time; import org.apache.hadoop.util.concurrent.AsyncGet; @@ -65,6 +61,7 @@ import javax.security.sasl.Sasl; import java.io.*; import java.net.*; +import java.nio.ByteBuffer; import java.security.PrivilegedExceptionAction; import java.util.*; import java.util.Map.Entry; @@ -429,7 +426,7 @@ private class Connection extends Thread { private final boolean doPing; //do we need to send ping message private final int pingInterval; // how often sends ping to the server private final int soTimeout; // used by ipc ping and rpc timeout - private ByteArrayOutputStream pingRequest; // ping message + private ResponseBuffer pingRequest; // ping message // currently active calls private Hashtable calls = new Hashtable(); @@ -459,7 +456,7 @@ public Connection(ConnectionId remoteId, int serviceClass) throws IOException { this.doPing = remoteId.getDoPing(); if (doPing) { // construct a RPC header with the callId as the ping callId - pingRequest = new ByteArrayOutputStream(); + pingRequest = new ResponseBuffer(); RpcRequestHeaderProto pingHeader = ProtoUtil .makeRpcRequestHeader(RpcKind.RPC_PROTOCOL_BUFFER, OperationProto.RPC_FINAL_PACKET, PING_CALL_ID, @@ -979,12 +976,10 @@ private void writeConnectionContext(ConnectionId remoteId, .makeRpcRequestHeader(RpcKind.RPC_PROTOCOL_BUFFER, OperationProto.RPC_FINAL_PACKET, CONNECTION_CONTEXT_CALL_ID, RpcConstants.INVALID_RETRY_COUNT, clientId); - RpcRequestMessageWrapper request = - new RpcRequestMessageWrapper(connectionContextHeader, message); - - // Write out the packet length - out.writeInt(request.getLength()); - request.write(out); + final ResponseBuffer buf = new ResponseBuffer(); + connectionContextHeader.writeDelimitedTo(buf); + message.writeDelimitedTo(buf); + buf.writeTo(out); } /* wait till someone signals us to start reading RPC response or @@ -1030,7 +1025,6 @@ private synchronized void sendPing() throws IOException { if ( curTime - lastActivity.get() >= pingInterval) { lastActivity.set(curTime); synchronized (out) { - out.writeInt(pingRequest.size()); pingRequest.writeTo(out); out.flush(); } @@ -1085,12 +1079,13 @@ public void sendRpcRequest(final Call call) // 2) RpcRequest // // Items '1' and '2' are prepared here. - final DataOutputBuffer d = new DataOutputBuffer(); RpcRequestHeaderProto header = ProtoUtil.makeRpcRequestHeader( call.rpcKind, OperationProto.RPC_FINAL_PACKET, call.id, call.retry, clientId); - header.writeDelimitedTo(d); - call.rpcRequest.write(d); + + final ResponseBuffer buf = new ResponseBuffer(); + header.writeDelimitedTo(buf); + RpcWritable.wrap(call.rpcRequest).writeTo(buf); synchronized (sendRpcRequestLock) { Future senderFuture = sendParamsExecutor.submit(new Runnable() { @@ -1101,14 +1096,10 @@ public void run() { if (shouldCloseConnection.get()) { return; } - - if (LOG.isDebugEnabled()) + if (LOG.isDebugEnabled()) { LOG.debug(getName() + " sending #" + call.id); - - byte[] data = d.getData(); - int totalLength = d.getLength(); - out.writeInt(totalLength); // Total Length - out.write(data, 0, totalLength);// RpcRequestHeader + RpcRequest + } + buf.writeTo(out); // RpcRequestHeader + RpcRequest out.flush(); } } catch (IOException e) { @@ -1119,7 +1110,7 @@ public void run() { } finally { //the buffer is just an in-memory buffer, but it is still polite to // close early - IOUtils.closeStream(d); + IOUtils.closeStream(buf); } } }); @@ -1151,12 +1142,13 @@ private void receiveRpcResponse() { try { int totalLen = in.readInt(); - RpcResponseHeaderProto header = - RpcResponseHeaderProto.parseDelimitedFrom(in); - checkResponse(header); + ByteBuffer bb = ByteBuffer.allocate(totalLen); + in.readFully(bb.array()); - int headerLen = header.getSerializedSize(); - headerLen += CodedOutputStream.computeRawVarint32Size(headerLen); + RpcWritable.Buffer packet = RpcWritable.Buffer.wrap(bb); + RpcResponseHeaderProto header = + packet.getValue(RpcResponseHeaderProto.getDefaultInstance()); + checkResponse(header); int callId = header.getCallId(); if (LOG.isDebugEnabled()) @@ -1164,28 +1156,15 @@ private void receiveRpcResponse() { RpcStatusProto status = header.getStatus(); if (status == RpcStatusProto.SUCCESS) { - Writable value = ReflectionUtils.newInstance(valueClass, conf); - value.readFields(in); // read value + Writable value = packet.newInstance(valueClass, conf); final Call call = calls.remove(callId); call.setRpcResponse(value); - - // verify that length was correct - // only for ProtobufEngine where len can be verified easily - if (call.getRpcResponse() instanceof ProtobufRpcEngine.RpcWrapper) { - ProtobufRpcEngine.RpcWrapper resWrapper = - (ProtobufRpcEngine.RpcWrapper) call.getRpcResponse(); - if (totalLen != headerLen + resWrapper.getLength()) { - throw new RpcClientException( - "RPC response length mismatch on rpc success"); - } - } - } else { // Rpc Request failed - // Verify that length was correct - if (totalLen != headerLen) { - throw new RpcClientException( - "RPC response length mismatch on rpc error"); - } - + } + // verify that packet length was correct + if (packet.remaining() > 0) { + throw new RpcClientException("RPC response length mismatch"); + } + if (status != RpcStatusProto.SUCCESS) { // Rpc Request failed final String exceptionClassName = header.hasExceptionClassName() ? header.getExceptionClassName() : "ServerDidNotSetExceptionClassName"; diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java index eb30aa207e..83e4b9ec8b 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java @@ -27,29 +27,22 @@ import org.apache.hadoop.classification.InterfaceStability; import org.apache.hadoop.classification.InterfaceStability.Unstable; import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.io.DataOutputOutputStream; import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.retry.RetryPolicy; import org.apache.hadoop.ipc.Client.ConnectionId; import org.apache.hadoop.ipc.RPC.RpcInvoker; import org.apache.hadoop.ipc.RpcWritable; import org.apache.hadoop.ipc.protobuf.ProtobufRpcEngineProtos.RequestHeaderProto; -import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto; -import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.token.SecretManager; import org.apache.hadoop.security.token.TokenIdentifier; -import org.apache.hadoop.util.ProtoUtil; import org.apache.hadoop.util.Time; import org.apache.hadoop.util.concurrent.AsyncGet; import org.apache.htrace.core.TraceScope; import org.apache.htrace.core.Tracer; import javax.net.SocketFactory; -import java.io.DataInput; -import java.io.DataOutput; import java.io.IOException; -import java.io.OutputStream; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.net.InetSocketAddress; @@ -146,7 +139,7 @@ private Invoker(Class protocol, InetSocketAddress addr, private Invoker(Class protocol, Client.ConnectionId connId, Configuration conf, SocketFactory factory) { this.remoteId = connId; - this.client = CLIENTS.getClient(conf, factory, RpcResponseWrapper.class); + this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class); this.protocolName = RPC.getProtocolName(protocol); this.clientProtocolVersion = RPC .getProtocolVersion(protocol); @@ -193,7 +186,7 @@ private RequestHeaderProto constructRpcRequestHeader(Method method) { * the server. */ @Override - public Object invoke(Object proxy, final Method method, Object[] args) + public Message invoke(Object proxy, final Method method, Object[] args) throws ServiceException { long startTime = 0; if (LOG.isDebugEnabled()) { @@ -228,11 +221,11 @@ public Object invoke(Object proxy, final Method method, Object[] args) } - Message theRequest = (Message) args[1]; - final RpcResponseWrapper val; + final Message theRequest = (Message) args[1]; + final RpcWritable.Buffer val; try { - val = (RpcResponseWrapper) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER, - new RpcRequestWrapper(rpcRequestHeader, theRequest), remoteId, + val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER, + new RpcProtobufRequest(rpcRequestHeader, theRequest), remoteId, fallbackToSimpleAuth); } catch (Throwable e) { @@ -256,7 +249,7 @@ public Object invoke(Object proxy, final Method method, Object[] args) } if (Client.isAsynchronousMode()) { - final AsyncGet arr + final AsyncGet arr = Client.getAsyncRpcResponse(); final AsyncGet asyncGet = new AsyncGet() { @@ -278,7 +271,7 @@ public boolean isDone() { } private Message getReturnMessage(final Method method, - final RpcResponseWrapper rrw) throws ServiceException { + final RpcWritable.Buffer buf) throws ServiceException { Message prototype = null; try { prototype = getReturnProtoType(method); @@ -287,8 +280,7 @@ private Message getReturnMessage(final Method method, } Message returnMessage; try { - returnMessage = prototype.newBuilderForType() - .mergeFrom(rrw.theResponseRead).build(); + returnMessage = buf.getValue(prototype.getDefaultInstanceForType()); if (LOG.isTraceEnabled()) { LOG.trace(Thread.currentThread().getId() + ": Response <- " + @@ -329,201 +321,12 @@ public ConnectionId getConnectionId() { } } - interface RpcWrapper extends Writable { - int getLength(); - } - /** - * Wrapper for Protocol Buffer Requests - * - * Note while this wrapper is writable, the request on the wire is in - * Protobuf. Several methods on {@link org.apache.hadoop.ipc.Server and RPC} - * use type Writable as a wrapper to work across multiple RpcEngine kinds. - */ - private static abstract class RpcMessageWithHeader - implements RpcWrapper { - T requestHeader; - Message theRequest; // for clientSide, the request is here - byte[] theRequestRead; // for server side, the request is here - - public RpcMessageWithHeader() { - } - - public RpcMessageWithHeader(T requestHeader, Message theRequest) { - this.requestHeader = requestHeader; - this.theRequest = theRequest; - } - - @Override - public void write(DataOutput out) throws IOException { - OutputStream os = DataOutputOutputStream.constructOutputStream(out); - - ((Message)requestHeader).writeDelimitedTo(os); - theRequest.writeDelimitedTo(os); - } - - @Override - public void readFields(DataInput in) throws IOException { - requestHeader = parseHeaderFrom(readVarintBytes(in)); - theRequestRead = readMessageRequest(in); - } - - abstract T parseHeaderFrom(byte[] bytes) throws IOException; - - byte[] readMessageRequest(DataInput in) throws IOException { - return readVarintBytes(in); - } - - private static byte[] readVarintBytes(DataInput in) throws IOException { - final int length = ProtoUtil.readRawVarint32(in); - final byte[] bytes = new byte[length]; - in.readFully(bytes); - return bytes; - } - - public T getMessageHeader() { - return requestHeader; - } - - public byte[] getMessageBytes() { - return theRequestRead; - } - - @Override - public int getLength() { - int headerLen = requestHeader.getSerializedSize(); - int reqLen; - if (theRequest != null) { - reqLen = theRequest.getSerializedSize(); - } else if (theRequestRead != null ) { - reqLen = theRequestRead.length; - } else { - throw new IllegalArgumentException( - "getLength on uninitialized RpcWrapper"); - } - return CodedOutputStream.computeRawVarint32Size(headerLen) + headerLen - + CodedOutputStream.computeRawVarint32Size(reqLen) + reqLen; - } - } - - private static class RpcRequestWrapper - extends RpcMessageWithHeader { - @SuppressWarnings("unused") - public RpcRequestWrapper() {} - - public RpcRequestWrapper( - RequestHeaderProto requestHeader, Message theRequest) { - super(requestHeader, theRequest); - } - - @Override - RequestHeaderProto parseHeaderFrom(byte[] bytes) throws IOException { - return RequestHeaderProto.parseFrom(bytes); - } - - @Override - public String toString() { - return requestHeader.getDeclaringClassProtocolName() + "." + - requestHeader.getMethodName(); - } - } - - @InterfaceAudience.LimitedPrivate({"RPC"}) - public static class RpcRequestMessageWrapper - extends RpcMessageWithHeader { - public RpcRequestMessageWrapper() {} - - public RpcRequestMessageWrapper( - RpcRequestHeaderProto requestHeader, Message theRequest) { - super(requestHeader, theRequest); - } - - @Override - RpcRequestHeaderProto parseHeaderFrom(byte[] bytes) throws IOException { - return RpcRequestHeaderProto.parseFrom(bytes); - } - } - - @InterfaceAudience.LimitedPrivate({"RPC"}) - public static class RpcResponseMessageWrapper - extends RpcMessageWithHeader { - public RpcResponseMessageWrapper() {} - - public RpcResponseMessageWrapper( - RpcResponseHeaderProto responseHeader, Message theRequest) { - super(responseHeader, theRequest); - } - - @Override - byte[] readMessageRequest(DataInput in) throws IOException { - // error message contain no message body - switch (requestHeader.getStatus()) { - case ERROR: - case FATAL: - return null; - default: - return super.readMessageRequest(in); - } - } - - @Override - RpcResponseHeaderProto parseHeaderFrom(byte[] bytes) throws IOException { - return RpcResponseHeaderProto.parseFrom(bytes); - } - } - - /** - * Wrapper for Protocol Buffer Responses - * - * Note while this wrapper is writable, the request on the wire is in - * Protobuf. Several methods on {@link org.apache.hadoop.ipc.Server and RPC} - * use type Writable as a wrapper to work across multiple RpcEngine kinds. - */ - @InterfaceAudience.LimitedPrivate({"RPC"}) // temporarily exposed - public static class RpcResponseWrapper implements RpcWrapper { - Message theResponse; // for senderSide, the response is here - byte[] theResponseRead; // for receiver side, the response is here - - public RpcResponseWrapper() { - } - - public RpcResponseWrapper(Message message) { - this.theResponse = message; - } - - @Override - public void write(DataOutput out) throws IOException { - OutputStream os = DataOutputOutputStream.constructOutputStream(out); - theResponse.writeDelimitedTo(os); - } - - @Override - public void readFields(DataInput in) throws IOException { - int length = ProtoUtil.readRawVarint32(in); - theResponseRead = new byte[length]; - in.readFully(theResponseRead); - } - - @Override - public int getLength() { - int resLen; - if (theResponse != null) { - resLen = theResponse.getSerializedSize(); - } else if (theResponseRead != null ) { - resLen = theResponseRead.length; - } else { - throw new IllegalArgumentException( - "getLength on uninitialized RpcWrapper"); - } - return CodedOutputStream.computeRawVarint32Size(resLen) + resLen; - } - } - @VisibleForTesting @InterfaceAudience.Private @InterfaceStability.Unstable static Client getClient(Configuration conf) { return CLIENTS.getClient(conf, SocketFactory.getDefault(), - RpcResponseWrapper.class); + RpcWritable.Buffer.class); } @@ -691,16 +494,30 @@ public Writable call(RPC.Server server, String connectionProtocolName, // which uses the rpc header. in the normal case we want to defer decoding // the rpc header until needed by the rpc engine. static class RpcProtobufRequest extends RpcWritable.Buffer { - private RequestHeaderProto lazyHeader; + private volatile RequestHeaderProto requestHeader; + private Message payload; public RpcProtobufRequest() { } - synchronized RequestHeaderProto getRequestHeader() throws IOException { - if (lazyHeader == null) { - lazyHeader = getValue(RequestHeaderProto.getDefaultInstance()); + RpcProtobufRequest(RequestHeaderProto header, Message payload) { + this.requestHeader = header; + this.payload = payload; + } + + RequestHeaderProto getRequestHeader() throws IOException { + if (getByteBuffer() != null && requestHeader == null) { + requestHeader = getValue(RequestHeaderProto.getDefaultInstance()); + } + return requestHeader; + } + + @Override + public void writeTo(ResponseBuffer out) throws IOException { + requestHeader.writeDelimitedTo(out); + if (payload != null) { + payload.writeDelimitedTo(out); } - return lazyHeader; } // this is used by htrace to name the span. diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ResponseBuffer.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ResponseBuffer.java index ac96a24178..a789d83dfd 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ResponseBuffer.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ResponseBuffer.java @@ -27,8 +27,14 @@ import org.apache.hadoop.classification.InterfaceAudience; @InterfaceAudience.Private -class ResponseBuffer extends DataOutputStream { - ResponseBuffer(int capacity) { +/** generates byte-length framed buffers. */ +public class ResponseBuffer extends DataOutputStream { + + public ResponseBuffer() { + this(1024); + } + + public ResponseBuffer(int capacity) { super(new FramedBuffer(capacity)); } @@ -39,7 +45,7 @@ private FramedBuffer getFramedBuffer() { return buf; } - void writeTo(OutputStream out) throws IOException { + public void writeTo(OutputStream out) throws IOException { getFramedBuffer().writeTo(out); } diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/RpcWritable.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/RpcWritable.java index 5125939e05..54fb98e80d 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/RpcWritable.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/RpcWritable.java @@ -24,7 +24,6 @@ import java.io.DataOutput; import java.io.IOException; import java.nio.ByteBuffer; - import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.conf.Configurable; import org.apache.hadoop.conf.Configuration; @@ -34,6 +33,7 @@ import com.google.protobuf.CodedOutputStream; import com.google.protobuf.Message; +// note anything marked public is solely for access by SaslRpcClient @InterfaceAudience.Private public abstract class RpcWritable implements Writable { @@ -99,6 +99,10 @@ static class ProtobufWrapper extends RpcWritable { this.message = message; } + Message getMessage() { + return message; + } + @Override void writeTo(ResponseBuffer out) throws IOException { int length = message.getSerializedSize(); @@ -128,11 +132,13 @@ T readFrom(ByteBuffer bb) throws IOException { } } - // adapter to allow decoding of writables and protobufs from a byte buffer. - static class Buffer extends RpcWritable { + /** + * adapter to allow decoding of writables and protobufs from a byte buffer. + */ + public static class Buffer extends RpcWritable { private ByteBuffer bb; - static Buffer wrap(ByteBuffer bb) { + public static Buffer wrap(ByteBuffer bb) { return new Buffer(bb); } @@ -142,6 +148,10 @@ static Buffer wrap(ByteBuffer bb) { this.bb = bb; } + ByteBuffer getByteBuffer() { + return bb; + } + @Override void writeTo(ResponseBuffer out) throws IOException { out.ensureCapacity(bb.remaining()); @@ -177,7 +187,7 @@ public T getValue(T value) throws IOException { return RpcWritable.wrap(value).readFrom(bb); } - int remaining() { + public int remaining() { return bb.remaining(); } } diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java index c360937baf..60ae3b04da 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java @@ -53,11 +53,11 @@ import org.apache.hadoop.classification.InterfaceStability; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.GlobPattern; -import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcRequestMessageWrapper; -import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseMessageWrapper; import org.apache.hadoop.ipc.RPC.RpcKind; import org.apache.hadoop.ipc.RemoteException; +import org.apache.hadoop.ipc.ResponseBuffer; import org.apache.hadoop.ipc.RpcConstants; +import org.apache.hadoop.ipc.RpcWritable; import org.apache.hadoop.ipc.Server.AuthProtocol; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto.OperationProto; @@ -368,11 +368,13 @@ public AuthMethod saslConnect(InputStream inS, OutputStream outS) // loop until sasl is complete or a rpc error occurs boolean done = false; do { - int totalLen = inStream.readInt(); - RpcResponseMessageWrapper responseWrapper = - new RpcResponseMessageWrapper(); - responseWrapper.readFields(inStream); - RpcResponseHeaderProto header = responseWrapper.getMessageHeader(); + int rpcLen = inStream.readInt(); + ByteBuffer bb = ByteBuffer.allocate(rpcLen); + inStream.readFully(bb.array()); + + RpcWritable.Buffer saslPacket = RpcWritable.Buffer.wrap(bb); + RpcResponseHeaderProto header = + saslPacket.getValue(RpcResponseHeaderProto.getDefaultInstance()); switch (header.getStatus()) { case ERROR: // might get a RPC error during case FATAL: @@ -380,15 +382,14 @@ public AuthMethod saslConnect(InputStream inS, OutputStream outS) header.getErrorMsg()); default: break; } - if (totalLen != responseWrapper.getLength()) { - throw new SaslException("Received malformed response length"); - } - if (header.getCallId() != AuthProtocol.SASL.callId) { throw new SaslException("Non-SASL response during negotiation"); } RpcSaslProto saslMessage = - RpcSaslProto.parseFrom(responseWrapper.getMessageBytes()); + saslPacket.getValue(RpcSaslProto.getDefaultInstance()); + if (saslPacket.remaining() > 0) { + throw new SaslException("Received malformed response length"); + } // handle sasl negotiation process RpcSaslProto.Builder response = null; switch (saslMessage.getState()) { @@ -452,16 +453,16 @@ public AuthMethod saslConnect(InputStream inS, OutputStream outS) return authMethod; } - private void sendSaslMessage(DataOutputStream out, RpcSaslProto message) + private void sendSaslMessage(OutputStream out, RpcSaslProto message) throws IOException { if (LOG.isDebugEnabled()) { LOG.debug("Sending sasl message "+message); } - RpcRequestMessageWrapper request = - new RpcRequestMessageWrapper(saslHeader, message); - out.writeInt(request.getLength()); - request.write(out); - out.flush(); + ResponseBuffer buf = new ResponseBuffer(); + saslHeader.writeDelimitedTo(buf); + message.writeDelimitedTo(buf); + buf.writeTo(out); + out.flush(); } /** @@ -634,12 +635,8 @@ public void write(byte[] buf, int off, int len) throws IOException { .setState(SaslState.WRAP) .setToken(ByteString.copyFrom(buf, 0, buf.length)) .build(); - RpcRequestMessageWrapper request = - new RpcRequestMessageWrapper(saslHeader, saslMessage); - DataOutputStream dob = new DataOutputStream(out); - dob.writeInt(request.getLength()); - request.write(dob); - } + sendSaslMessage(out, saslMessage); + } } /** Release resources used by wrapped saslClient */