HADOOP-17680. Allow ProtobufRpcEngine to be extensible (#2905) Contributed by Hector Chaverri.

(cherry picked from commit f40e3eb059)
This commit is contained in:
hchaverr 2021-05-06 16:40:45 -07:00 committed by Konstantin V Shvachko
parent 217655269a
commit cedebf1c27
2 changed files with 48 additions and 12 deletions

View File

@ -122,7 +122,7 @@ public ProtocolProxy<ProtocolMetaInfoPB> getProtocolMetaInfoProxy(
factory)), false); factory)), false);
} }
private static class Invoker implements RpcInvocationHandler { protected static class Invoker implements RpcInvocationHandler {
private final Map<String, Message> returnTypes = private final Map<String, Message> returnTypes =
new ConcurrentHashMap<String, Message>(); new ConcurrentHashMap<String, Message>();
private boolean isClosed = false; private boolean isClosed = false;
@ -133,7 +133,7 @@ private static class Invoker implements RpcInvocationHandler {
private AtomicBoolean fallbackToSimpleAuth; private AtomicBoolean fallbackToSimpleAuth;
private AlignmentContext alignmentContext; private AlignmentContext alignmentContext;
private Invoker(Class<?> protocol, InetSocketAddress addr, protected Invoker(Class<?> protocol, InetSocketAddress addr,
UserGroupInformation ticket, Configuration conf, SocketFactory factory, UserGroupInformation ticket, Configuration conf, SocketFactory factory,
int rpcTimeout, RetryPolicy connectionRetryPolicy, int rpcTimeout, RetryPolicy connectionRetryPolicy,
AtomicBoolean fallbackToSimpleAuth, AlignmentContext alignmentContext) AtomicBoolean fallbackToSimpleAuth, AlignmentContext alignmentContext)
@ -148,7 +148,7 @@ private Invoker(Class<?> protocol, InetSocketAddress addr,
/** /**
* This constructor takes a connectionId, instead of creating a new one. * This constructor takes a connectionId, instead of creating a new one.
*/ */
private Invoker(Class<?> protocol, Client.ConnectionId connId, protected Invoker(Class<?> protocol, Client.ConnectionId connId,
Configuration conf, SocketFactory factory) { Configuration conf, SocketFactory factory) {
this.remoteId = connId; this.remoteId = connId;
this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class); this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class);
@ -225,8 +225,6 @@ public Message invoke(Object proxy, final Method method, Object[] args)
traceScope = tracer.newScope(RpcClientUtil.methodToTraceString(method)); traceScope = tracer.newScope(RpcClientUtil.methodToTraceString(method));
} }
RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method);
if (LOG.isTraceEnabled()) { if (LOG.isTraceEnabled()) {
LOG.trace(Thread.currentThread().getId() + ": Call -> " + LOG.trace(Thread.currentThread().getId() + ": Call -> " +
remoteId + ": " + method.getName() + remoteId + ": " + method.getName() +
@ -238,7 +236,7 @@ public Message invoke(Object proxy, final Method method, Object[] args)
final RpcWritable.Buffer val; final RpcWritable.Buffer val;
try { try {
val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER, val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER,
new RpcProtobufRequest(rpcRequestHeader, theRequest), remoteId, constructRpcRequest(method, theRequest), remoteId,
fallbackToSimpleAuth, alignmentContext); fallbackToSimpleAuth, alignmentContext);
} catch (Throwable e) { } catch (Throwable e) {
@ -283,6 +281,11 @@ public boolean isDone() {
} }
} }
protected Writable constructRpcRequest(Method method, Message theRequest) {
RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method);
return new RpcProtobufRequest(rpcRequestHeader, theRequest);
}
private Message getReturnMessage(final Method method, private Message getReturnMessage(final Method method,
final RpcWritable.Buffer buf) throws ServiceException { final RpcWritable.Buffer buf) throws ServiceException {
Message prototype = null; Message prototype = null;
@ -332,6 +335,14 @@ private Message getReturnProtoType(Method method) throws Exception {
public ConnectionId getConnectionId() { public ConnectionId getConnectionId() {
return remoteId; return remoteId;
} }
protected long getClientProtocolVersion() {
return clientProtocolVersion;
}
protected String getProtocolName() {
return protocolName;
}
} }
@VisibleForTesting @VisibleForTesting
@ -518,6 +529,13 @@ public Writable call(RPC.Server server, String connectionProtocolName,
String declaringClassProtoName = String declaringClassProtoName =
rpcRequest.getDeclaringClassProtocolName(); rpcRequest.getDeclaringClassProtocolName();
long clientVersion = rpcRequest.getClientProtocolVersion(); long clientVersion = rpcRequest.getClientProtocolVersion();
return call(server, connectionProtocolName, request, receiveTime,
methodName, declaringClassProtoName, clientVersion);
}
protected Writable call(RPC.Server server, String connectionProtocolName,
RpcWritable.Buffer request, long receiveTime, String methodName,
String declaringClassProtoName, long clientVersion) throws Exception {
if (server.verbose) if (server.verbose)
LOG.info("Call: connectionProtocolName=" + connectionProtocolName + LOG.info("Call: connectionProtocolName=" + connectionProtocolName +
", method=" + methodName); ", method=" + methodName);

View File

@ -116,7 +116,7 @@ public ProtocolProxy<ProtocolMetaInfoPB> getProtocolMetaInfoProxy(
factory)), false); factory)), false);
} }
private static final class Invoker implements RpcInvocationHandler { protected static class Invoker implements RpcInvocationHandler {
private final Map<String, Message> returnTypes = private final Map<String, Message> returnTypes =
new ConcurrentHashMap<String, Message>(); new ConcurrentHashMap<String, Message>();
private boolean isClosed = false; private boolean isClosed = false;
@ -127,7 +127,7 @@ private static final class Invoker implements RpcInvocationHandler {
private AtomicBoolean fallbackToSimpleAuth; private AtomicBoolean fallbackToSimpleAuth;
private AlignmentContext alignmentContext; private AlignmentContext alignmentContext;
private Invoker(Class<?> protocol, InetSocketAddress addr, protected Invoker(Class<?> protocol, InetSocketAddress addr,
UserGroupInformation ticket, Configuration conf, SocketFactory factory, UserGroupInformation ticket, Configuration conf, SocketFactory factory,
int rpcTimeout, RetryPolicy connectionRetryPolicy, int rpcTimeout, RetryPolicy connectionRetryPolicy,
AtomicBoolean fallbackToSimpleAuth, AlignmentContext alignmentContext) AtomicBoolean fallbackToSimpleAuth, AlignmentContext alignmentContext)
@ -142,7 +142,7 @@ private Invoker(Class<?> protocol, InetSocketAddress addr,
/** /**
* This constructor takes a connectionId, instead of creating a new one. * This constructor takes a connectionId, instead of creating a new one.
*/ */
private Invoker(Class<?> protocol, Client.ConnectionId connId, protected Invoker(Class<?> protocol, Client.ConnectionId connId,
Configuration conf, SocketFactory factory) { Configuration conf, SocketFactory factory) {
this.remoteId = connId; this.remoteId = connId;
this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class); this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class);
@ -219,8 +219,6 @@ public Message invoke(Object proxy, final Method method, Object[] args)
traceScope = tracer.newScope(RpcClientUtil.methodToTraceString(method)); traceScope = tracer.newScope(RpcClientUtil.methodToTraceString(method));
} }
RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method);
if (LOG.isTraceEnabled()) { if (LOG.isTraceEnabled()) {
LOG.trace(Thread.currentThread().getId() + ": Call -> " + LOG.trace(Thread.currentThread().getId() + ": Call -> " +
remoteId + ": " + method.getName() + remoteId + ": " + method.getName() +
@ -232,7 +230,7 @@ public Message invoke(Object proxy, final Method method, Object[] args)
final RpcWritable.Buffer val; final RpcWritable.Buffer val;
try { try {
val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER, val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER,
new RpcProtobufRequest(rpcRequestHeader, theRequest), remoteId, constructRpcRequest(method, theRequest), remoteId,
fallbackToSimpleAuth, alignmentContext); fallbackToSimpleAuth, alignmentContext);
} catch (Throwable e) { } catch (Throwable e) {
@ -279,6 +277,11 @@ public boolean isDone() {
} }
} }
protected Writable constructRpcRequest(Method method, Message theRequest) {
RequestHeaderProto rpcRequestHeader = constructRpcRequestHeader(method);
return new RpcProtobufRequest(rpcRequestHeader, theRequest);
}
private Message getReturnMessage(final Method method, private Message getReturnMessage(final Method method,
final RpcWritable.Buffer buf) throws ServiceException { final RpcWritable.Buffer buf) throws ServiceException {
Message prototype = null; Message prototype = null;
@ -328,6 +331,14 @@ private Message getReturnProtoType(Method method) throws Exception {
public ConnectionId getConnectionId() { public ConnectionId getConnectionId() {
return remoteId; return remoteId;
} }
protected long getClientProtocolVersion() {
return clientProtocolVersion;
}
protected String getProtocolName() {
return protocolName;
}
} }
@VisibleForTesting @VisibleForTesting
@ -509,6 +520,13 @@ public Writable call(RPC.Server server, String connectionProtocolName,
String declaringClassProtoName = String declaringClassProtoName =
rpcRequest.getDeclaringClassProtocolName(); rpcRequest.getDeclaringClassProtocolName();
long clientVersion = rpcRequest.getClientProtocolVersion(); long clientVersion = rpcRequest.getClientProtocolVersion();
return call(server, connectionProtocolName, request, receiveTime,
methodName, declaringClassProtoName, clientVersion);
}
protected Writable call(RPC.Server server, String connectionProtocolName,
RpcWritable.Buffer request, long receiveTime, String methodName,
String declaringClassProtoName, long clientVersion) throws Exception {
if (server.verbose) { if (server.verbose) {
LOG.info("Call: connectionProtocolName=" + connectionProtocolName + LOG.info("Call: connectionProtocolName=" + connectionProtocolName +
", method=" + methodName); ", method=" + methodName);