diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/TestFederationRMFailoverProxyProvider.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/TestFederationRMFailoverProxyProvider.java index fa3523c9f7..e3f91557ee 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/TestFederationRMFailoverProxyProvider.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-client/src/test/java/org/apache/hadoop/yarn/client/TestFederationRMFailoverProxyProvider.java @@ -19,17 +19,21 @@ import java.io.IOException; import java.net.InetSocketAddress; +import java.security.PrivilegedExceptionAction; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.ha.HAServiceProtocol; +import org.apache.hadoop.io.retry.FailoverProxyProvider.ProxyInfo; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.yarn.api.ApplicationClientProtocol; +import org.apache.hadoop.yarn.api.ApplicationMasterProtocol; import org.apache.hadoop.yarn.api.protocolrecords.GetClusterMetricsRequest; import org.apache.hadoop.yarn.api.protocolrecords.GetClusterMetricsResponse; import org.apache.hadoop.yarn.conf.YarnConfiguration; import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.server.MiniYARNCluster; import org.apache.hadoop.yarn.server.federation.failover.FederationProxyProviderUtil; +import org.apache.hadoop.yarn.server.federation.failover.FederationRMFailoverProxyProvider; import org.apache.hadoop.yarn.server.federation.store.FederationStateStore; import org.apache.hadoop.yarn.server.federation.store.impl.MemoryFederationStateStore; import org.apache.hadoop.yarn.server.federation.store.records.SubClusterId; @@ -44,6 +48,10 @@ import org.junit.Before; import org.junit.Test; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static 
org.mockito.Mockito.when; + /** * Unit tests for FederationRMFailoverProxyProvider. */ @@ -151,4 +159,65 @@ private void makeRMActive(final SubClusterId subClusterId, } } + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testUGIForProxyCreation() + throws IOException, InterruptedException { + conf.set(YarnConfiguration.RM_CLUSTER_ID, "cluster1"); + + UserGroupInformation currentUser = UserGroupInformation.getCurrentUser(); + UserGroupInformation user1 = + UserGroupInformation.createProxyUser("user1", currentUser); + UserGroupInformation user2 = + UserGroupInformation.createProxyUser("user2", currentUser); + + final TestableFederationRMFailoverProxyProvider provider = + new TestableFederationRMFailoverProxyProvider(); + + InetSocketAddress addr = + conf.getSocketAddr(YarnConfiguration.RM_SCHEDULER_ADDRESS, + YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS, + YarnConfiguration.DEFAULT_RM_SCHEDULER_PORT); + final ClientRMProxy rmProxy = mock(ClientRMProxy.class); + when(rmProxy.getRMAddress(any(YarnConfiguration.class), any(Class.class))) + .thenReturn(addr); + + user1.doAs(new PrivilegedExceptionAction() { + @Override + public Object run() { + provider.init(conf, rmProxy, ApplicationMasterProtocol.class); + return null; + } + }); + + final ProxyInfo currentProxy = provider.getProxy(); + Assert.assertEquals("user1", provider.getLastProxyUGI().getUserName()); + + user2.doAs(new PrivilegedExceptionAction() { + @Override + public Object run() { + provider.performFailover(currentProxy.proxy); + return null; + } + }); + Assert.assertEquals("user1", provider.getLastProxyUGI().getUserName()); + + provider.close(); + } + + protected static class TestableFederationRMFailoverProxyProvider + extends FederationRMFailoverProxyProvider { + + private UserGroupInformation lastProxyUGI = null; + + @Override + protected T createRMProxy(InetSocketAddress rmAddress) throws IOException { + lastProxyUGI = UserGroupInformation.getCurrentUser(); + return 
super.createRMProxy(rmAddress); + } + + public UserGroupInformation getLastProxyUGI() { + return lastProxyUGI; + } + } } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/failover/FederationRMFailoverProxyProvider.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/failover/FederationRMFailoverProxyProvider.java index 1915f67ae8..e00f8d15bf 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/failover/FederationRMFailoverProxyProvider.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/failover/FederationRMFailoverProxyProvider.java @@ -21,7 +21,7 @@ import java.io.Closeable; import java.io.IOException; import java.net.InetSocketAddress; -import java.util.Collection; +import java.security.PrivilegedExceptionAction; import org.apache.hadoop.classification.InterfaceAudience.Private; import org.apache.hadoop.classification.InterfaceStability.Unstable; @@ -29,14 +29,12 @@ import org.apache.hadoop.fs.CommonConfigurationKeysPublic; import org.apache.hadoop.ipc.RPC; import org.apache.hadoop.security.UserGroupInformation; -import org.apache.hadoop.security.token.Token; -import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.yarn.api.ApplicationClientProtocol; import org.apache.hadoop.yarn.api.ApplicationMasterProtocol; import org.apache.hadoop.yarn.client.RMFailoverProxyProvider; import org.apache.hadoop.yarn.client.RMProxy; import org.apache.hadoop.yarn.conf.YarnConfiguration; -import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.exceptions.YarnRuntimeException; import org.apache.hadoop.yarn.server.api.ResourceManagerAdministrationProtocol; import 
org.apache.hadoop.yarn.server.federation.store.records.SubClusterId; import org.apache.hadoop.yarn.server.federation.store.records.SubClusterInfo; @@ -44,6 +42,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; /** @@ -64,7 +63,7 @@ public class FederationRMFailoverProxyProvider private YarnConfiguration conf; private FederationStateStoreFacade facade; private SubClusterId subClusterId; - private Collection> originalTokens; + private UserGroupInformation originalUser; private boolean federationFailoverEnabled = false; @Override @@ -97,59 +96,67 @@ public void init(Configuration configuration, RMProxy proxy, YarnConfiguration.DEFAULT_CLIENT_FAILOVER_RETRIES_ON_SOCKET_TIMEOUTS)); try { - UserGroupInformation currentUser = UserGroupInformation.getCurrentUser(); - originalTokens = currentUser.getTokens(); + this.originalUser = UserGroupInformation.getCurrentUser(); LOG.info("Initialized Federation proxy for user: {}", - currentUser.getUserName()); + this.originalUser.getUserName()); } catch (IOException e) { LOG.warn("Could not get information of requester, ignoring for now."); + this.originalUser = null; } } - private void addOriginalTokens(UserGroupInformation currentUser) { - if (originalTokens == null || originalTokens.isEmpty()) { - return; - } - for (Token token : originalTokens) { - currentUser.addToken(token); - } + @VisibleForTesting + protected T createRMProxy(InetSocketAddress rmAddress) throws IOException { + return rmProxy.getProxy(conf, protocol, rmAddress); } private T getProxyInternal(boolean isFailover) { SubClusterInfo subClusterInfo; - UserGroupInformation currentUser = null; + // Use the existing proxy as a backup in case getting the new proxy fails. + // Note that if the very first attempt fails, the backup is also null. In + // that case a YarnRuntimeException is thrown back to the AM. 
+ T proxy = this.current; try { LOG.info("Failing over to the ResourceManager for SubClusterId: {}", subClusterId); subClusterInfo = facade.getSubCluster(subClusterId, isFailover); // updating the conf with the refreshed RM addresses as proxy - // creations - // are based out of conf + // creations are based out of conf updateRMAddress(subClusterInfo); - currentUser = UserGroupInformation.getCurrentUser(); - addOriginalTokens(currentUser); - } catch (YarnException e) { + if (this.originalUser == null) { + InetSocketAddress rmAddress = rmProxy.getRMAddress(conf, protocol); + LOG.info( + "Connecting to {} subClusterId {} with protocol {}" + + " without a proxy user", + rmAddress, subClusterId, protocol.getSimpleName()); + proxy = createRMProxy(rmAddress); + } else { + // If the original ugi exists, always use that to create proxy because + // it contains up-to-date AMRMToken + proxy = this.originalUser.doAs(new PrivilegedExceptionAction() { + @Override + public T run() throws IOException { + InetSocketAddress rmAddress = rmProxy.getRMAddress(conf, protocol); + LOG.info( + "Connecting to {} subClusterId {} with protocol {} as user {}", + rmAddress, subClusterId, protocol.getSimpleName(), + originalUser); + return createRMProxy(rmAddress); + } + }); + } + } catch (Exception e) { LOG.error("Exception while trying to create proxy to the ResourceManager" + " for SubClusterId: {}", subClusterId, e); - return null; - } catch (IOException e) { - LOG.warn("Could not get information of requester, ignoring for now."); - } - try { - final InetSocketAddress rmAddress = rmProxy.getRMAddress(conf, protocol); - LOG.info("Connecting to {} with protocol {} as user: {}", rmAddress, - protocol.getSimpleName(), currentUser); - LOG.info("Failed over to the RM at {} for SubClusterId: {}", rmAddress, - subClusterId); - return rmProxy.getProxy(conf, protocol, rmAddress); - } catch (IOException ioe) { - LOG.error( - "IOException while trying to create proxy to the ResourceManager" - + " for 
SubClusterId: {}", - subClusterId, ioe); - return null; + if (proxy == null) { + throw new YarnRuntimeException( + String.format("Create initial proxy to the ResourceManager for" + + " SubClusterId %s failed", subClusterId), + e); + } } + return proxy; } private void updateRMAddress(SubClusterInfo subClusterInfo) { @@ -177,8 +184,11 @@ public synchronized ProxyInfo getProxy() { @Override public synchronized void performFailover(T currentProxy) { - closeInternal(currentProxy); + // getProxyInternal never returns null here: it returns a proxy or throws. current = getProxyInternal(federationFailoverEnabled); + if (current != currentProxy) { + closeInternal(currentProxy); + } } @Override