diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java index 09fe889790..4c73f6a60a 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java @@ -354,9 +354,10 @@ public static int getCallRetryCount() { */ public static InetAddress getRemoteIp() { Call call = CurCall.get(); - return (call != null ) ? call.getHostInetAddress() : null; + return (call != null && call.connection != null) ? call.connection + .getHostInetAddress() : null; } - + /** * Returns the clientId from the current RPC request */ @@ -379,9 +380,10 @@ public static String getRemoteAddress() { */ public static UserGroupInformation getRemoteUser() { Call call = CurCall.get(); - return (call != null) ? call.getRemoteUser() : null; + return (call != null && call.connection != null) ? call.connection.user + : null; } - + /** Return true if the invocation was through an RPC. */ public static boolean isRpcInvocation() { @@ -481,7 +483,7 @@ void logSlowRpcCalls(String methodName, int processingTime) { if ((rpcMetrics.getProcessingSampleCount() > minSampleSize) && (processingTime > threeSigma)) { if(LOG.isWarnEnabled()) { - String client = CurCall.get().toString(); + String client = CurCall.get().connection.toString(); LOG.warn( "Slow RPC : " + methodName + " took " + processingTime + " milliseconds to process from client " + client); @@ -655,65 +657,62 @@ static boolean getClientBackoffEnable( CommonConfigurationKeys.IPC_BACKOFF_ENABLE_DEFAULT); } - /** A generic call queued for handling. */ - public static class Call implements Schedulable, - PrivilegedExceptionAction { - final int callId; // the client's call id - final int retryCount; // the retry count of the call - long timestamp; // time received when response is null - // time served when response is not null + /** A call queued for handling. */ + public static class Call implements Schedulable { + private final int callId; // the client's call id + private final int retryCount; // the retry count of the call + private final Writable rpcRequest; // Serialized Rpc request from client + private final Connection connection; // connection to client + private long timestamp; // time received when response is null + // time served when response is not null + private ByteBuffer rpcResponse; // the response for this call private AtomicInteger responseWaitCount = new AtomicInteger(1); - final RPC.RpcKind rpcKind; - final byte[] clientId; + private final RPC.RpcKind rpcKind; + private final byte[] clientId; private final TraceScope traceScope; // the HTrace scope on the server side private final CallerContext callerContext; // the call context private int priorityLevel; // the priority level assigned by scheduler, 0 by default - Call(Call call) { - this(call.callId, call.retryCount, call.rpcKind, call.clientId, - call.traceScope, call.callerContext); + private Call(Call call) { + this(call.callId, call.retryCount, call.rpcRequest, call.connection, + call.rpcKind, call.clientId, call.traceScope, call.callerContext); } - Call(int id, int retryCount, RPC.RpcKind kind, byte[] clientId) { - this(id, retryCount, kind, clientId, null, null); + public Call(int id, int retryCount, Writable param, + Connection connection) { + this(id, retryCount, param, connection, RPC.RpcKind.RPC_BUILTIN, + RpcConstants.DUMMY_CLIENT_ID); } - @VisibleForTesting // primarily TestNamenodeRetryCache - public Call(int id, int retryCount, Void ignore1, Void ignore2, + public Call(int id, int retryCount, Writable param, Connection connection, RPC.RpcKind kind, byte[] clientId) { - this(id, retryCount, kind, clientId, null, null); + this(id, retryCount, param, connection, kind, clientId, null, null); } - Call(int id, int retryCount, RPC.RpcKind kind, byte[] clientId, - TraceScope traceScope, CallerContext callerContext) { + public Call(int id, int retryCount, Writable param, Connection connection, + RPC.RpcKind kind, byte[] clientId, TraceScope traceScope, + CallerContext callerContext) { this.callId = id; this.retryCount = retryCount; + this.rpcRequest = param; + this.connection = connection; this.timestamp = Time.now(); + this.rpcResponse = null; this.rpcKind = kind; this.clientId = clientId; this.traceScope = traceScope; this.callerContext = callerContext; } - + @Override public String toString() { - return "Call#" + callId + " Retry#" + retryCount; + return rpcRequest + " from " + connection + " Call#" + callId + " Retry#" + + retryCount; } - public Void run() throws Exception { - return null; - } - // should eventually be abstract but need to avoid breaking tests - public UserGroupInformation getRemoteUser() { - return null; - } - public InetAddress getHostInetAddress() { - return null; - } - public String getHostAddress() { - InetAddress addr = getHostInetAddress(); - return (addr != null) ? addr.getHostAddress() : null; + public void setResponse(ByteBuffer response) { + this.rpcResponse = response; } /** @@ -725,36 +724,34 @@ public String getHostAddress() { */ @InterfaceStability.Unstable @InterfaceAudience.LimitedPrivate({"HDFS"}) - public final void postponeResponse() { + public void postponeResponse() { int count = responseWaitCount.incrementAndGet(); assert count > 0 : "response has already been sent"; } @InterfaceStability.Unstable @InterfaceAudience.LimitedPrivate({"HDFS"}) - public final void sendResponse() throws IOException { + public void sendResponse() throws IOException { int count = responseWaitCount.decrementAndGet(); assert count >= 0 : "response has already been sent"; if (count == 0) { - doResponse(null); + connection.sendResponse(this); } } @InterfaceStability.Unstable @InterfaceAudience.LimitedPrivate({"HDFS"}) - public final void abortResponse(Throwable t) throws IOException { + public void abortResponse(Throwable t) throws IOException { // don't send response if the call was already sent or aborted. if (responseWaitCount.getAndSet(-1) > 0) { - doResponse(t); + connection.abortResponse(this, t); } } - void doResponse(Throwable t) throws IOException {} - // For Schedulable @Override public UserGroupInformation getUserGroupInformation() { - return getRemoteUser(); + return connection.user; } @Override @@ -767,114 +764,6 @@ public void setPriorityLevel(int priorityLevel) { } } - /** A RPC extended call queued for handling. */ - private class RpcCall extends Call { - final Connection connection; // connection to client - final Writable rpcRequest; // Serialized Rpc request from client - ByteBuffer rpcResponse; // the response for this call - - RpcCall(RpcCall call) { - super(call); - this.connection = call.connection; - this.rpcRequest = call.rpcRequest; - } - - RpcCall(Connection connection, int id) { - this(connection, id, RpcConstants.INVALID_RETRY_COUNT); - } - - RpcCall(Connection connection, int id, int retryCount) { - this(connection, id, retryCount, null, - RPC.RpcKind.RPC_BUILTIN, RpcConstants.DUMMY_CLIENT_ID, - null, null); - } - - RpcCall(Connection connection, int id, int retryCount, - Writable param, RPC.RpcKind kind, byte[] clientId, - TraceScope traceScope, CallerContext context) { - super(id, retryCount, kind, clientId, traceScope, context); - this.connection = connection; - this.rpcRequest = param; - } - - @Override - public UserGroupInformation getRemoteUser() { - return connection.user; - } - - @Override - public InetAddress getHostInetAddress() { - return connection.getHostInetAddress(); - } - - @Override - public Void run() throws Exception { - if (!connection.channel.isOpen()) { - Server.LOG.info(Thread.currentThread().getName() + ": skipped " + this); - return null; - } - String errorClass = null; - String error = null; - RpcStatusProto returnStatus = RpcStatusProto.SUCCESS; - RpcErrorCodeProto detailedErr = null; - Writable value = null; - - try { - value = call( - rpcKind, connection.protocolName, rpcRequest, timestamp); - } catch (Throwable e) { - if (e instanceof UndeclaredThrowableException) { - e = e.getCause(); - } - logException(Server.LOG, e, this); - if (e instanceof RpcServerException) { - RpcServerException rse = ((RpcServerException)e); - returnStatus = rse.getRpcStatusProto(); - detailedErr = rse.getRpcErrorCodeProto(); - } else { - returnStatus = RpcStatusProto.ERROR; - detailedErr = RpcErrorCodeProto.ERROR_APPLICATION; - } - errorClass = e.getClass().getName(); - error = StringUtils.stringifyException(e); - // Remove redundant error class name from the beginning of the - // stack trace - String exceptionHdr = errorClass + ": "; - if (error.startsWith(exceptionHdr)) { - error = error.substring(exceptionHdr.length()); - } - } - setupResponse(this, returnStatus, detailedErr, - value, errorClass, error); - sendResponse(); - return null; - } - - void setResponse(ByteBuffer response) throws IOException { - this.rpcResponse = response; - } - - @Override - void doResponse(Throwable t) throws IOException { - RpcCall call = this; - if (t != null) { - // clone the call to prevent a race with another thread stomping - // on the response while being sent. the original call is - // effectively discarded since the wait count won't hit zero - call = new RpcCall(this); - setupResponse(call, - RpcStatusProto.FATAL, RpcErrorCodeProto.ERROR_RPC_SERVER, - null, t.getClass().getName(), StringUtils.stringifyException(t)); - } - connection.sendResponse(this); - } - - @Override - public String toString() { - return super.toString() + " " + rpcRequest + " from " + connection; - } - } - /** Listens on the socket. Creates jobs for the handler threads*/ private class Listener extends Thread { @@ -1205,22 +1094,22 @@ private void doRunLoop() { if(LOG.isDebugEnabled()) { LOG.debug("Checking for old call responses."); } - ArrayList calls; + ArrayList calls; // get the list of channels from list of keys. synchronized (writeSelector.keys()) { - calls = new ArrayList(writeSelector.keys().size()); + calls = new ArrayList(writeSelector.keys().size()); iter = writeSelector.keys().iterator(); while (iter.hasNext()) { SelectionKey key = iter.next(); - RpcCall call = (RpcCall)key.attachment(); + Call call = (Call)key.attachment(); if (call != null && key.channel() == call.connection.channel) { calls.add(call); } } } - - for (RpcCall call : calls) { + + for(Call call : calls) { doPurge(call, now); } } catch (OutOfMemoryError e) { @@ -1238,7 +1127,7 @@ private void doRunLoop() { } private void doAsyncWrite(SelectionKey key) throws IOException { - RpcCall call = (RpcCall)key.attachment(); + Call call = (Call)key.attachment(); if (call == null) { return; } @@ -1266,10 +1155,10 @@ private void doAsyncWrite(SelectionKey key) throws IOException { // Remove calls that have been pending in the responseQueue // for a long time. // - private void doPurge(RpcCall call, long now) { - LinkedList responseQueue = call.connection.responseQueue; + private void doPurge(Call call, long now) { + LinkedList responseQueue = call.connection.responseQueue; synchronized (responseQueue) { - Iterator iter = responseQueue.listIterator(0); + Iterator iter = responseQueue.listIterator(0); while (iter.hasNext()) { call = iter.next(); if (now > call.timestamp + PURGE_INTERVAL) { @@ -1283,12 +1172,12 @@ private void doPurge(RpcCall call, long now) { // Processes one response. Returns true if there are no more pending // data for this channel. // - private boolean processResponse(LinkedList responseQueue, + private boolean processResponse(LinkedList responseQueue, boolean inHandler) throws IOException { boolean error = true; boolean done = false; // there is more data for this channel. int numElements = 0; - RpcCall call = null; + Call call = null; try { synchronized (responseQueue) { // @@ -1371,7 +1260,7 @@ private boolean processResponse(LinkedList responseQueue, // // Enqueue a response from the application. // - void doRespond(RpcCall call) throws IOException { + void doRespond(Call call) throws IOException { synchronized (call.connection.responseQueue) { // must only wrap before adding to the responseQueue to prevent // postponed responses from being encrypted and sent out of order. @@ -1469,7 +1358,7 @@ public class Connection { private SocketChannel channel; private ByteBuffer data; private ByteBuffer dataLengthBuffer; - private LinkedList responseQueue; + private LinkedList responseQueue; // number of outstanding rpcs private AtomicInteger rpcCount = new AtomicInteger(); private long lastContact; @@ -1496,8 +1385,8 @@ public class Connection { public UserGroupInformation attemptingUser = null; // user name before auth // Fake 'call' for failed authorization response - private final RpcCall authFailedCall = - new RpcCall(this, AUTHORIZATION_FAILED_CALL_ID); + private final Call authFailedCall = new Call(AUTHORIZATION_FAILED_CALL_ID, + RpcConstants.INVALID_RETRY_COUNT, null, this); private boolean sentNegotiate = false; private boolean useWrap = false; @@ -1520,7 +1409,7 @@ public Connection(SocketChannel channel, long lastContact) { this.hostAddress = addr.getHostAddress(); } this.remotePort = socket.getPort(); - this.responseQueue = new LinkedList(); + this.responseQueue = new LinkedList(); if (socketSendBufferSize != 0) { try { socket.setSendBufferSize(socketSendBufferSize); @@ -1815,7 +1704,8 @@ private RpcSaslProto buildSaslResponse(SaslState state, byte[] replyToken) { } private void doSaslReply(Message message) throws IOException { - final RpcCall saslCall = new RpcCall(this, AuthProtocol.SASL.callId); + final Call saslCall = new Call(AuthProtocol.SASL.callId, + RpcConstants.INVALID_RETRY_COUNT, null, this); setupResponse(saslCall, RpcStatusProto.SUCCESS, null, RpcWritable.wrap(message), null, null); @@ -2032,20 +1922,23 @@ private void setupBadVersionResponse(int clientVersion) throws IOException { if (clientVersion >= 9) { // Versions >>9 understand the normal response - RpcCall fakeCall = new RpcCall(this, -1); + Call fakeCall = new Call(-1, RpcConstants.INVALID_RETRY_COUNT, null, + this); setupResponse(fakeCall, RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_VERSION_MISMATCH, null, VersionMismatch.class.getName(), errMsg); fakeCall.sendResponse(); } else if (clientVersion >= 3) { - RpcCall fakeCall = new RpcCall(this, -1); + Call fakeCall = new Call(-1, RpcConstants.INVALID_RETRY_COUNT, null, + this); // Versions 3 to 8 use older response setupResponseOldVersionFatal(buffer, fakeCall, null, VersionMismatch.class.getName(), errMsg); fakeCall.sendResponse(); } else if (clientVersion == 2) { // Hadoop 0.18.3 - RpcCall fakeCall = new RpcCall(this, 0); + Call fakeCall = new Call(0, RpcConstants.INVALID_RETRY_COUNT, null, + this); DataOutputStream out = new DataOutputStream(buffer); out.writeInt(0); // call ID out.writeBoolean(true); // error @@ -2057,7 +1950,7 @@ private void setupBadVersionResponse(int clientVersion) throws IOException { } private void setupHttpRequestOnIpcPortResponse() throws IOException { - RpcCall fakeCall = new RpcCall(this, 0); + Call fakeCall = new Call(0, RpcConstants.INVALID_RETRY_COUNT, null, this); fakeCall.setResponse(ByteBuffer.wrap( RECEIVED_HTTP_REQ_RESPONSE.getBytes(StandardCharsets.UTF_8))); fakeCall.sendResponse(); @@ -2205,7 +2098,7 @@ private void processOneRpc(ByteBuffer bb) } } catch (WrappedRpcServerException wrse) { // inform client of error Throwable ioe = wrse.getCause(); - final RpcCall call = new RpcCall(this, callId, retry); + final Call call = new Call(callId, retry, null, this); setupResponse(call, RpcStatusProto.FATAL, wrse.getRpcErrorCodeProto(), null, ioe.getClass().getName(), ioe.getMessage()); @@ -2305,9 +2198,8 @@ private void processRpcRequest(RpcRequestHeaderProto header, .build(); } - RpcCall call = new RpcCall(this, header.getCallId(), - header.getRetryCount(), rpcRequest, - ProtoUtil.convert(header.getRpcKind()), + Call call = new Call(header.getCallId(), header.getRetryCount(), + rpcRequest, this, ProtoUtil.convert(header.getRpcKind()), header.getClientId().toByteArray(), traceScope, callerContext); // Save the priority level assignment by the scheduler @@ -2431,10 +2323,21 @@ T getMessage(Message message, } } - private void sendResponse(RpcCall call) throws IOException { + private void sendResponse(Call call) throws IOException { responder.doRespond(call); } + private void abortResponse(Call call, Throwable t) throws IOException { + // clone the call to prevent a race with the other thread stomping + // on the response while being sent. the original call is + // effectively discarded since the wait count won't hit zero + call = new Call(call); + setupResponse(call, + RpcStatusProto.FATAL, RpcErrorCodeProto.ERROR_RPC_SERVER, + null, t.getClass().getName(), StringUtils.stringifyException(t)); + call.sendResponse(); + } + /** * Get service class for connection * @return the serviceClass @@ -2485,6 +2388,16 @@ public void run() { if (LOG.isDebugEnabled()) { LOG.debug(Thread.currentThread().getName() + ": " + call + " for RpcKind " + call.rpcKind); } + if (!call.connection.channel.isOpen()) { + LOG.info(Thread.currentThread().getName() + ": skipped " + call); + continue; + } + String errorClass = null; + String error = null; + RpcStatusProto returnStatus = RpcStatusProto.SUCCESS; + RpcErrorCodeProto detailedErr = null; + Writable value = null; + CurCall.set(call); if (call.traceScope != null) { call.traceScope.reattach(); @@ -2493,11 +2406,53 @@ public void run() { } // always update the current call context CallerContext.setCurrent(call.callerContext); - UserGroupInformation remoteUser = call.getRemoteUser(); - if (remoteUser != null) { - remoteUser.doAs(call); - } else { - call.run(); + + try { + // Make the call as the user via Subject.doAs, thus associating + // the call with the Subject + if (call.connection.user == null) { + value = call(call.rpcKind, call.connection.protocolName, call.rpcRequest, + call.timestamp); + } else { + value = + call.connection.user.doAs + (new PrivilegedExceptionAction() { + @Override + public Writable run() throws Exception { + // make the call + return call(call.rpcKind, call.connection.protocolName, + call.rpcRequest, call.timestamp); + + } + } + ); + } + } catch (Throwable e) { + if (e instanceof UndeclaredThrowableException) { + e = e.getCause(); + } + logException(LOG, e, call); + if (e instanceof RpcServerException) { + RpcServerException rse = ((RpcServerException)e); + returnStatus = rse.getRpcStatusProto(); + detailedErr = rse.getRpcErrorCodeProto(); + } else { + returnStatus = RpcStatusProto.ERROR; + detailedErr = RpcErrorCodeProto.ERROR_APPLICATION; + } + errorClass = e.getClass().getName(); + error = StringUtils.stringifyException(e); + // Remove redundant error class name from the beginning of the stack trace + String exceptionHdr = errorClass + ": "; + if (error.startsWith(exceptionHdr)) { + error = error.substring(exceptionHdr.length()); + } + } + CurCall.set(null); + synchronized (call.connection.responseQueue) { + setupResponse(call, returnStatus, detailedErr, + value, errorClass, error); + call.sendResponse(); } } catch (InterruptedException e) { if (running) { // unexpected -- log it @@ -2514,7 +2469,6 @@ public void run() { StringUtils.stringifyException(e)); } } finally { - CurCall.set(null); IOUtils.cleanup(LOG, traceScope); } } @@ -2716,7 +2670,7 @@ private void closeConnection(Connection connection) { * @throws IOException */ private void setupResponse( - RpcCall call, RpcStatusProto status, RpcErrorCodeProto erCode, + Call call, RpcStatusProto status, RpcErrorCodeProto erCode, Writable rv, String errorClass, String error) throws IOException { RpcResponseHeaderProto.Builder headerBuilder = @@ -2750,7 +2704,7 @@ private void setupResponse( } } - private void setupResponse(RpcCall call, + private void setupResponse(Call call, RpcResponseHeaderProto header, Writable rv) throws IOException { ResponseBuffer buf = responseBuffer.get().reset(); try { @@ -2784,7 +2738,7 @@ private void setupResponse(RpcCall call, * @throws IOException */ private void setupResponseOldVersionFatal(ByteArrayOutputStream response, - RpcCall call, + Call call, Writable rv, String errorClass, String error) throws IOException { final int OLD_VERSION_FATAL_STATUS = -1; @@ -2797,7 +2751,7 @@ private void setupResponseOldVersionFatal(ByteArrayOutputStream response, call.setResponse(ByteBuffer.wrap(response.toByteArray())); } - private void wrapWithSasl(RpcCall call) throws IOException { + private void wrapWithSasl(Call call) throws IOException { if (call.connection.saslServer != null) { byte[] token = call.rpcResponse.array(); // synchronization may be needed since there can be multiple Handler