From cedebf1c27361077d0a9c6a6ec7cc9bb8ec200e5 Mon Sep 17 00:00:00 2001 From: hchaverr Date: Thu, 6 May 2021 16:40:45 -0700 Subject: [PATCH] HADOOP-17680. Allow ProtobufRpcEngine to be extensible (#2905) Contributed by Hector Chaverri. (cherry picked from commit f40e3eb0590f85bb42d2471992bf5d524628fdd6) --- .../apache/hadoop/ipc/ProtobufRpcEngine.java | 30 +++++++++++++++---- .../apache/hadoop/ipc/ProtobufRpcEngine2.java | 30 +++++++++++++++---- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java index b7b7ad4db6..d539bb2e81 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java @@ -122,7 +122,7 @@ public ProtocolProxy getProtocolMetaInfoProxy( factory)), false); } - private static class Invoker implements RpcInvocationHandler { + protected static class Invoker implements RpcInvocationHandler { private final Map returnTypes = new ConcurrentHashMap(); private boolean isClosed = false; @@ -133,7 +133,7 @@ private static class Invoker implements RpcInvocationHandler { private AtomicBoolean fallbackToSimpleAuth; private AlignmentContext alignmentContext; - private Invoker(Class protocol, InetSocketAddress addr, + protected Invoker(Class protocol, InetSocketAddress addr, UserGroupInformation ticket, Configuration conf, SocketFactory factory, int rpcTimeout, RetryPolicy connectionRetryPolicy, AtomicBoolean fallbackToSimpleAuth, AlignmentContext alignmentContext) @@ -148,7 +148,7 @@ private Invoker(Class protocol, InetSocketAddress addr, /** * This constructor takes a connectionId, instead of creating a new one. */ - private Invoker(Class protocol, Client.ConnectionId connId, + protected Invoker(Class protocol, Client.ConnectionId connId, Configuration conf, SocketFactory factory) { this.remoteId = connId; this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class); @@ -225,8 +225,6 @@ public Message invoke(Object proxy, final Method method, Object[] args) traceScope = tracer.newScope(RpcClientUtil.methodToTraceString(method)); } - RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method); - if (LOG.isTraceEnabled()) { LOG.trace(Thread.currentThread().getId() + ": Call -> " + remoteId + ": " + method.getName() + @@ -238,7 +236,7 @@ public Message invoke(Object proxy, final Method method, Object[] args) final RpcWritable.Buffer val; try { val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER, - new RpcProtobufRequest(rpcRequestHeader, theRequest), remoteId, + constructRpcRequest(method, theRequest), remoteId, fallbackToSimpleAuth, alignmentContext); } catch (Throwable e) { @@ -283,6 +281,11 @@ public boolean isDone() { } } + protected Writable constructRpcRequest(Method method, Message theRequest) { + RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method); + return new RpcProtobufRequest(rpcRequestHeader, theRequest); + } + private Message getReturnMessage(final Method method, final RpcWritable.Buffer buf) throws ServiceException { Message prototype = null; @@ -332,6 +335,14 @@ private Message getReturnProtoType(Method method) throws Exception { public ConnectionId getConnectionId() { return remoteId; } + + protected long getClientProtocolVersion() { + return clientProtocolVersion; + } + + protected String getProtocolName() { + return protocolName; + } } @VisibleForTesting @@ -518,6 +529,13 @@ public Writable call(RPC.Server server, String connectionProtocolName, String declaringClassProtoName = rpcRequest.getDeclaringClassProtocolName(); long clientVersion = rpcRequest.getClientProtocolVersion(); + return call(server, connectionProtocolName, request, receiveTime, + methodName, declaringClassProtoName, clientVersion); + } + + protected Writable call(RPC.Server server, String connectionProtocolName, + RpcWritable.Buffer request, long receiveTime, String methodName, + String declaringClassProtoName, long clientVersion) throws Exception { if (server.verbose) LOG.info("Call: connectionProtocolName=" + connectionProtocolName + ", method=" + methodName); diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine2.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine2.java index 5043051ce0..5fdf8540fd 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine2.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine2.java @@ -116,7 +116,7 @@ public ProtocolProxy getProtocolMetaInfoProxy( factory)), false); } - private static final class Invoker implements RpcInvocationHandler { + protected static class Invoker implements RpcInvocationHandler { private final Map returnTypes = new ConcurrentHashMap(); private boolean isClosed = false; @@ -127,7 +127,7 @@ private static final class Invoker implements RpcInvocationHandler { private AtomicBoolean fallbackToSimpleAuth; private AlignmentContext alignmentContext; - private Invoker(Class protocol, InetSocketAddress addr, + protected Invoker(Class protocol, InetSocketAddress addr, UserGroupInformation ticket, Configuration conf, SocketFactory factory, int rpcTimeout, RetryPolicy connectionRetryPolicy, AtomicBoolean fallbackToSimpleAuth, AlignmentContext alignmentContext) @@ -142,7 +142,7 @@ private Invoker(Class protocol, InetSocketAddress addr, /** * This constructor takes a connectionId, instead of creating a new one. */ - private Invoker(Class protocol, Client.ConnectionId connId, + protected Invoker(Class protocol, Client.ConnectionId connId, Configuration conf, SocketFactory factory) { this.remoteId = connId; this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class); @@ -219,8 +219,6 @@ public Message invoke(Object proxy, final Method method, Object[] args) traceScope = tracer.newScope(RpcClientUtil.methodToTraceString(method)); } - RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method); - if (LOG.isTraceEnabled()) { LOG.trace(Thread.currentThread().getId() + ": Call -> " + remoteId + ": " + method.getName() + @@ -232,7 +230,7 @@ public Message invoke(Object proxy, final Method method, Object[] args) final RpcWritable.Buffer val; try { val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER, - new RpcProtobufRequest(rpcRequestHeader, theRequest), remoteId, + constructRpcRequest(method, theRequest), remoteId, fallbackToSimpleAuth, alignmentContext); } catch (Throwable e) { @@ -279,6 +277,11 @@ public boolean isDone() { } } + protected Writable constructRpcRequest(Method method, Message theRequest) { + RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method); + return new RpcProtobufRequest(rpcRequestHeader, theRequest); + } + private Message getReturnMessage(final Method method, final RpcWritable.Buffer buf) throws ServiceException { Message prototype = null; @@ -328,6 +331,14 @@ private Message getReturnProtoType(Method method) throws Exception { public ConnectionId getConnectionId() { return remoteId; } + + protected long getClientProtocolVersion() { + return clientProtocolVersion; + } + + protected String getProtocolName() { + return protocolName; + } } @VisibleForTesting @@ -509,6 +520,13 @@ public Writable call(RPC.Server server, String connectionProtocolName, String declaringClassProtoName = rpcRequest.getDeclaringClassProtocolName(); long clientVersion = rpcRequest.getClientProtocolVersion(); + return call(server, connectionProtocolName, request, receiveTime, + methodName, declaringClassProtoName, clientVersion); + } + + protected Writable call(RPC.Server server, String connectionProtocolName, + RpcWritable.Buffer request, long receiveTime, String methodName, + String declaringClassProtoName, long clientVersion) throws Exception { if (server.verbose) { LOG.info("Call: connectionProtocolName=" + connectionProtocolName + ", method=" + methodName);