Hadoop 16890. Change in expiry calculation for MSI token provider.
Contributed by Bilahari T H
This commit is contained in:
parent
cf9cf83a43
commit
0b931f36ec
@ -283,6 +283,12 @@
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.hamcrest</groupId>
|
||||
<artifactId>hamcrest-library</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
@ -72,7 +72,7 @@ public abstract class AccessTokenProvider {
|
||||
*
|
||||
* @return true if the token is expiring in next 5 minutes
|
||||
*/
|
||||
private boolean isTokenAboutToExpire() {
|
||||
protected boolean isTokenAboutToExpire() {
|
||||
if (token == null) {
|
||||
LOG.debug("AADToken: no token. Returning expiring=true");
|
||||
return true; // no token should have same response as expired token
|
||||
|
@ -137,7 +137,7 @@ public final class AzureADAuthenticator {
|
||||
headers.put("Metadata", "true");
|
||||
|
||||
LOG.debug("AADToken: starting to fetch token using MSI");
|
||||
return getTokenCall(authEndpoint, qp.serialize(), headers, "GET");
|
||||
return getTokenCall(authEndpoint, qp.serialize(), headers, "GET", true);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -258,8 +258,13 @@ public final class AzureADAuthenticator {
|
||||
}
|
||||
|
||||
private static AzureADToken getTokenCall(String authEndpoint, String body,
|
||||
Hashtable<String, String> headers, String httpMethod)
|
||||
throws IOException {
|
||||
Hashtable<String, String> headers, String httpMethod) throws IOException {
|
||||
return getTokenCall(authEndpoint, body, headers, httpMethod, false);
|
||||
}
|
||||
|
||||
private static AzureADToken getTokenCall(String authEndpoint, String body,
|
||||
Hashtable<String, String> headers, String httpMethod, boolean isMsi)
|
||||
throws IOException {
|
||||
AzureADToken token = null;
|
||||
ExponentialRetryPolicy retryPolicy
|
||||
= new ExponentialRetryPolicy(3, 0, 1000, 2);
|
||||
@ -272,7 +277,7 @@ public final class AzureADAuthenticator {
|
||||
httperror = 0;
|
||||
ex = null;
|
||||
try {
|
||||
token = getTokenSingleCall(authEndpoint, body, headers, httpMethod);
|
||||
token = getTokenSingleCall(authEndpoint, body, headers, httpMethod, isMsi);
|
||||
} catch (HttpException e) {
|
||||
httperror = e.httpErrorCode;
|
||||
ex = e;
|
||||
@ -288,8 +293,9 @@ public final class AzureADAuthenticator {
|
||||
return token;
|
||||
}
|
||||
|
||||
private static AzureADToken getTokenSingleCall(
|
||||
String authEndpoint, String payload, Hashtable<String, String> headers, String httpMethod)
|
||||
private static AzureADToken getTokenSingleCall(String authEndpoint,
|
||||
String payload, Hashtable<String, String> headers, String httpMethod,
|
||||
boolean isMsi)
|
||||
throws IOException {
|
||||
|
||||
AzureADToken token = null;
|
||||
@ -336,7 +342,7 @@ public final class AzureADAuthenticator {
|
||||
if (httpResponseCode == HttpURLConnection.HTTP_OK
|
||||
&& responseContentType.startsWith("application/json") && responseContentLength > 0) {
|
||||
InputStream httpResponseStream = conn.getInputStream();
|
||||
token = parseTokenFromStream(httpResponseStream);
|
||||
token = parseTokenFromStream(httpResponseStream, isMsi);
|
||||
} else {
|
||||
InputStream stream = conn.getErrorStream();
|
||||
if (stream == null) {
|
||||
@ -390,10 +396,12 @@ public final class AzureADAuthenticator {
|
||||
return token;
|
||||
}
|
||||
|
||||
private static AzureADToken parseTokenFromStream(InputStream httpResponseStream) throws IOException {
|
||||
private static AzureADToken parseTokenFromStream(
|
||||
InputStream httpResponseStream, boolean isMsi) throws IOException {
|
||||
AzureADToken token = new AzureADToken();
|
||||
try {
|
||||
int expiryPeriod = 0;
|
||||
int expiryPeriodInSecs = 0;
|
||||
long expiresOnInSecs = -1;
|
||||
|
||||
JsonFactory jf = new JsonFactory();
|
||||
JsonParser jp = jf.createJsonParser(httpResponseStream);
|
||||
@ -408,17 +416,38 @@ public final class AzureADAuthenticator {
|
||||
if (fieldName.equals("access_token")) {
|
||||
token.setAccessToken(fieldValue);
|
||||
}
|
||||
|
||||
if (fieldName.equals("expires_in")) {
|
||||
expiryPeriod = Integer.parseInt(fieldValue);
|
||||
expiryPeriodInSecs = Integer.parseInt(fieldValue);
|
||||
}
|
||||
|
||||
if (fieldName.equals("expires_on")) {
|
||||
expiresOnInSecs = Long.parseLong(fieldValue);
|
||||
}
|
||||
|
||||
}
|
||||
jp.nextToken();
|
||||
}
|
||||
jp.close();
|
||||
long expiry = System.currentTimeMillis();
|
||||
expiry = expiry + expiryPeriod * 1000L; // convert expiryPeriod to milliseconds and add
|
||||
token.setExpiry(new Date(expiry));
|
||||
LOG.debug("AADToken: fetched token with expiry " + token.getExpiry().toString());
|
||||
if (expiresOnInSecs > 0) {
|
||||
LOG.debug("Expiry based on expires_on: {}", expiresOnInSecs);
|
||||
token.setExpiry(new Date(expiresOnInSecs * 1000));
|
||||
} else {
|
||||
if (isMsi) {
|
||||
// Currently there is a known issue that MSI does not update expires_in
|
||||
// for refresh and will have the value from first AAD token fetch request.
|
||||
// Due to this known limitation, expires_in is not supported for MSI token fetch flow.
|
||||
throw new UnsupportedOperationException("MSI Responded with invalid expires_on");
|
||||
}
|
||||
|
||||
LOG.debug("Expiry based on expires_in: {}", expiryPeriodInSecs);
|
||||
long expiry = System.currentTimeMillis();
|
||||
expiry = expiry + expiryPeriodInSecs * 1000L; // convert expiryPeriod to milliseconds and add
|
||||
token.setExpiry(new Date(expiry));
|
||||
}
|
||||
|
||||
LOG.debug("AADToken: fetched token with expiry {}, expiresOn passed: {}",
|
||||
token.getExpiry().toString(), expiresOnInSecs);
|
||||
} catch (Exception ex) {
|
||||
LOG.debug("AADToken: got exception when parsing json token " + ex.toString());
|
||||
throw ex;
|
||||
|
@ -36,6 +36,10 @@ public class MsiTokenProvider extends AccessTokenProvider {
|
||||
|
||||
private final String clientId;
|
||||
|
||||
private long tokenFetchTime = -1;
|
||||
|
||||
private static final long ONE_HOUR = 3600 * 1000;
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(AccessTokenProvider.class);
|
||||
|
||||
public MsiTokenProvider(final String authEndpoint, final String tenantGuid,
|
||||
@ -51,6 +55,36 @@ public class MsiTokenProvider extends AccessTokenProvider {
|
||||
LOG.debug("AADToken: refreshing token from MSI");
|
||||
AzureADToken token = AzureADAuthenticator
|
||||
.getTokenFromMsi(authEndpoint, tenantGuid, clientId, authority, false);
|
||||
tokenFetchTime = System.currentTimeMillis();
|
||||
return token;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the token is about to expire as per base expiry logic.
|
||||
* Otherwise try to expire every 1 hour
|
||||
*
|
||||
* @return true if the token is expiring in next 1 hour or if a token has
|
||||
* never been fetched
|
||||
*/
|
||||
@Override
|
||||
protected boolean isTokenAboutToExpire() {
|
||||
if (tokenFetchTime == -1 || super.isTokenAboutToExpire()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
boolean expiring = false;
|
||||
long elapsedTimeSinceLastTokenRefreshInMillis =
|
||||
System.currentTimeMillis() - tokenFetchTime;
|
||||
expiring = elapsedTimeSinceLastTokenRefreshInMillis >= ONE_HOUR
|
||||
|| elapsedTimeSinceLastTokenRefreshInMillis < 0;
|
||||
// In case of, Token is not refreshed for 1 hr or any clock skew issues,
|
||||
// refresh token.
|
||||
if (expiring) {
|
||||
LOG.debug("MSIToken: token renewing. Time elapsed since last token fetch:"
|
||||
+ " {} milli seconds", elapsedTimeSinceLastTokenRefreshInMillis);
|
||||
}
|
||||
|
||||
return expiring;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,93 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.fs.azurebfs;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Date;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.hadoop.fs.azurebfs.oauth2.AccessTokenProvider;
|
||||
import org.apache.hadoop.fs.azurebfs.oauth2.AzureADToken;
|
||||
import org.apache.hadoop.fs.azurebfs.oauth2.MsiTokenProvider;
|
||||
|
||||
import static org.junit.Assume.assumeThat;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.CoreMatchers.not;
|
||||
import static org.hamcrest.Matchers.isEmptyOrNullString;
|
||||
import static org.hamcrest.Matchers.isEmptyString;
|
||||
|
||||
import static org.apache.hadoop.fs.azurebfs.constants.AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY;
|
||||
import static org.apache.hadoop.fs.azurebfs.constants.AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT;
|
||||
import static org.apache.hadoop.fs.azurebfs.constants.ConfigurationKeys.FS_AZURE_ACCOUNT_OAUTH_CLIENT_ID;
|
||||
import static org.apache.hadoop.fs.azurebfs.constants.ConfigurationKeys.FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY;
|
||||
import static org.apache.hadoop.fs.azurebfs.constants.ConfigurationKeys.FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT;
|
||||
import static org.apache.hadoop.fs.azurebfs.constants.ConfigurationKeys.FS_AZURE_ACCOUNT_OAUTH_MSI_TENANT;
|
||||
|
||||
/**
|
||||
* Test MsiTokenProvider.
|
||||
*/
|
||||
public final class ITestAbfsMsiTokenProvider
|
||||
extends AbstractAbfsIntegrationTest {
|
||||
|
||||
public ITestAbfsMsiTokenProvider() throws Exception {
|
||||
super();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test() throws IOException {
|
||||
AbfsConfiguration conf = getConfiguration();
|
||||
assumeThat(conf.get(FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT),
|
||||
not(isEmptyOrNullString()));
|
||||
assumeThat(conf.get(FS_AZURE_ACCOUNT_OAUTH_MSI_TENANT),
|
||||
not(isEmptyOrNullString()));
|
||||
assumeThat(conf.get(FS_AZURE_ACCOUNT_OAUTH_CLIENT_ID),
|
||||
not(isEmptyOrNullString()));
|
||||
assumeThat(conf.get(FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY),
|
||||
not(isEmptyOrNullString()));
|
||||
|
||||
String tenantGuid = conf
|
||||
.getPasswordString(FS_AZURE_ACCOUNT_OAUTH_MSI_TENANT);
|
||||
String clientId = conf.getPasswordString(FS_AZURE_ACCOUNT_OAUTH_CLIENT_ID);
|
||||
String authEndpoint = getTrimmedPasswordString(conf,
|
||||
FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT,
|
||||
DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT);
|
||||
String authority = getTrimmedPasswordString(conf,
|
||||
FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY,
|
||||
DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY);
|
||||
AccessTokenProvider tokenProvider = new MsiTokenProvider(authEndpoint,
|
||||
tenantGuid, clientId, authority);
|
||||
|
||||
AzureADToken token = null;
|
||||
token = tokenProvider.getToken();
|
||||
assertThat(token.getAccessToken(), not(isEmptyString()));
|
||||
assertThat(token.getExpiry().after(new Date()), is(true));
|
||||
}
|
||||
|
||||
private String getTrimmedPasswordString(AbfsConfiguration conf, String key,
|
||||
String defaultValue) throws IOException {
|
||||
String value = conf.getPasswordString(key);
|
||||
if (StringUtils.isBlank(value)) {
|
||||
value = defaultValue;
|
||||
}
|
||||
return value.trim();
|
||||
}
|
||||
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user