HADOOP-19079. HttpExceptionUtils to verify that loaded class is really an exception before instantiation (#6557)

Security hardening

+ Adds new interceptAndValidateMessageContains() method in LambdaTestUtils to verify a list of strings
  can all be found in the toString() value of a raised exception

Contributed by PJ Fanning
This commit is contained in:
PJ Fanning 2024-04-11 20:38:15 +02:00 committed by GitHub
parent 81b05977f2
commit d194ad0242
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 177 additions and 59 deletions

View File

@ -26,7 +26,9 @@
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.Writer; import java.io.Writer;
import java.lang.reflect.Constructor; import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.net.HttpURLConnection; import java.net.HttpURLConnection;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
@ -54,6 +56,10 @@ public class HttpExceptionUtils {
private static final String ENTER = System.getProperty("line.separator"); private static final String ENTER = System.getProperty("line.separator");
private static final MethodHandles.Lookup PUBLIC_LOOKUP = MethodHandles.publicLookup();
private static final MethodType EXCEPTION_CONSTRUCTOR_TYPE =
MethodType.methodType(void.class, String.class);
/** /**
* Creates a HTTP servlet response serializing the exception in it as JSON. * Creates a HTTP servlet response serializing the exception in it as JSON.
* *
@ -150,9 +156,12 @@ public static void validateResponse(HttpURLConnection conn,
try { try {
ClassLoader cl = HttpExceptionUtils.class.getClassLoader(); ClassLoader cl = HttpExceptionUtils.class.getClassLoader();
Class klass = cl.loadClass(exClass); Class klass = cl.loadClass(exClass);
Constructor constr = klass.getConstructor(String.class); Preconditions.checkState(Exception.class.isAssignableFrom(klass),
toThrow = (Exception) constr.newInstance(exMsg); "Class [%s] is not a subclass of Exception", klass);
} catch (Exception ex) { MethodHandle methodHandle = PUBLIC_LOOKUP.findConstructor(
klass, EXCEPTION_CONSTRUCTOR_TYPE);
toThrow = (Exception) methodHandle.invoke(exMsg);
} catch (Throwable t) {
toThrow = new IOException(String.format( toThrow = new IOException(String.format(
"HTTP status [%d], exception [%s], message [%s], URL [%s]", "HTTP status [%d], exception [%s], message [%s], URL [%s]",
conn.getResponseCode(), exClass, exMsg, conn.getURL())); conn.getResponseCode(), exClass, exMsg, conn.getURL()));

View File

@ -54,7 +54,7 @@ public void testArgChecks() throws Exception {
() -> cache.put(42, null, null, null)); () -> cache.put(42, null, null, null));
intercept(NullPointerException.class, null, intercept(NullPointerException.class,
() -> new SingleFilePerBlockCache(null, 2, null)); () -> new SingleFilePerBlockCache(null, 2, null));
} }

View File

@ -18,16 +18,17 @@
package org.apache.hadoop.test; package org.apache.hadoop.test;
import org.apache.hadoop.util.Preconditions;
import org.junit.Assert; import org.junit.Assert;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.util.Preconditions;
import org.apache.hadoop.util.Time; import org.apache.hadoop.util.Time;
import java.io.IOException; import java.io.IOException;
import java.security.PrivilegedExceptionAction; import java.security.PrivilegedExceptionAction;
import java.util.Collection;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException; import java.util.concurrent.CancellationException;
@ -35,6 +36,7 @@
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
/** /**
* Class containing methods and associated classes to make the most of Lambda * Class containing methods and associated classes to make the most of Lambda
@ -476,7 +478,7 @@ public static <T, E extends Throwable> E intercept(
* <i>or a subclass</i>. * <i>or a subclass</i>.
* @param contained string which must be in the {@code toString()} value * @param contained string which must be in the {@code toString()} value
* of the exception * of the exception
* @param message any message tho include in exception/log messages * @param message any message to include in exception/log messages
* @param eval expression to eval * @param eval expression to eval
* @param <T> return type of expression * @param <T> return type of expression
* @param <E> exception class * @param <E> exception class
@ -543,7 +545,7 @@ public static <E extends Throwable> E intercept(
* <i>or a subclass</i>. * <i>or a subclass</i>.
* @param contained string which must be in the {@code toString()} value * @param contained string which must be in the {@code toString()} value
* of the exception * of the exception
* @param message any message tho include in exception/log messages * @param message any message to include in exception/log messages
* @param eval expression to eval * @param eval expression to eval
* @param <E> exception class * @param <E> exception class
* @return the caught exception if it was of the expected type * @return the caught exception if it was of the expected type
@ -563,6 +565,105 @@ public static <E extends Throwable> E intercept(
}); });
} }
/**
* Intercept an exception; throw an {@code AssertionError} if one not raised.
* The caught exception is rethrown if it is of the wrong class or
* does not contain the text defined in {@code contained}.
* <p>
* Example: expect deleting a nonexistent file to raise a
* {@code FileNotFoundException} with the {@code toString()} value
* containing the text {@code "missing"}.
* <pre>
* FileNotFoundException ioe = interceptAndValidateMessageContains(
* FileNotFoundException.class,
* "missing",
* "path should not be found",
* () -> {
* filesystem.delete(new Path("/missing"), false);
* });
* </pre>
*
* @param clazz class of exception; the raised exception must be this class
* <i>or a subclass</i>.
* @param contains strings which must be in the {@code toString()} value
* of the exception (order does not matter)
* @param eval expression to eval
* @param <T> return type of expression
* @param <E> exception class
* @return the caught exception if it was of the expected type and contents
* @throws Exception any other exception raised
* @throws AssertionError if the evaluation call didn't raise an exception.
* The error includes the {@code toString()} value of the result, if this
* can be determined.
* @see GenericTestUtils#assertExceptionContains(String, Throwable)
*/
public static <T, E extends Throwable> E interceptAndValidateMessageContains(
Class<E> clazz,
Collection<String> contains,
VoidCallable eval)
throws Exception {
String message = "Expecting " + clazz.getName()
+ (contains.isEmpty() ? "" : (" with text values " + toString(contains)))
+ " but got ";
return interceptAndValidateMessageContains(clazz, contains, message, eval);
}
/**
* Intercept an exception; throw an {@code AssertionError} if one not raised.
* The caught exception is rethrown if it is of the wrong class or
* does not contain the text defined in {@code contained}.
* <p>
* Example: expect deleting a nonexistent file to raise a
* {@code FileNotFoundException} with the {@code toString()} value
* containing the text {@code "missing"}.
* <pre>
* FileNotFoundException ioe = interceptAndValidateMessageContains(
* FileNotFoundException.class,
* "missing",
* "path should not be found",
* () -> {
* filesystem.delete(new Path("/missing"), false);
* });
* </pre>
*
* @param clazz class of exception; the raised exception must be this class
* <i>or a subclass</i>.
* @param contains strings which must be in the {@code toString()} value
* of the exception (order does not matter)
* @param message any message to include in exception/log messages
* @param eval expression to eval
* @param <T> return type of expression
* @param <E> exception class
* @return the caught exception if it was of the expected type and contents
* @throws Exception any other exception raised
* @throws AssertionError if the evaluation call didn't raise an exception.
* The error includes the {@code toString()} value of the result, if this
* can be determined.
* @see GenericTestUtils#assertExceptionContains(String, Throwable)
*/
public static <T, E extends Throwable> E interceptAndValidateMessageContains(
Class<E> clazz,
Collection<String> contains,
String message,
VoidCallable eval)
throws Exception {
E ex;
try {
eval.call();
throw new AssertionError(message);
} catch (Throwable e) {
if (!clazz.isAssignableFrom(e.getClass())) {
throw e;
} else {
ex = (E) e;
}
}
for (String contained : contains) {
GenericTestUtils.assertExceptionContains(contained, ex, message);
}
return ex;
}
/** /**
* Robust string converter for exception messages; if the {@code toString()} * Robust string converter for exception messages; if the {@code toString()}
* method throws an exception then that exception is caught and logged, * method throws an exception then that exception is caught and logged,
@ -607,7 +708,6 @@ public static <T> void assertOptionalEquals(String message,
* Assert that an optional value matches an expected one; * Assert that an optional value matches an expected one;
* checks include null and empty on the actual value. * checks include null and empty on the actual value.
* @param message message text * @param message message text
* @param expected expected value
* @param actual actual optional value * @param actual actual optional value
* @param <T> type * @param <T> type
*/ */
@ -641,7 +741,6 @@ public static <T> T eval(Callable<T> closure) {
* Invoke a callable; wrap all checked exceptions with an * Invoke a callable; wrap all checked exceptions with an
* AssertionError. * AssertionError.
* @param closure closure to execute * @param closure closure to execute
* @return the value of the closure
* @throws AssertionError if the operation raised an IOE or * @throws AssertionError if the operation raised an IOE or
* other checked exception. * other checked exception.
*/ */
@ -823,6 +922,11 @@ public static <E extends Throwable> E verifyCause(
} }
} }
private static String toString(Collection<String> strings) {
return strings.stream()
.collect(Collectors.joining(",", "[", "]"));
}
/** /**
* Returns {@code TimeoutException} on a timeout. If * Returns {@code TimeoutException} on a timeout. If
* there was a inner class passed in, includes it as the * there was a inner class passed in, includes it as the
@ -1037,3 +1141,4 @@ public Void run() throws Exception {
} }
} }
} }

View File

@ -18,6 +18,7 @@
package org.apache.hadoop.util; package org.apache.hadoop.util;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.hadoop.test.LambdaTestUtils;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.mockito.Mockito; import org.mockito.Mockito;
@ -31,6 +32,7 @@
import java.io.PrintWriter; import java.io.PrintWriter;
import java.io.StringWriter; import java.io.StringWriter;
import java.net.HttpURLConnection; import java.net.HttpURLConnection;
import java.nio.charset.StandardCharsets;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -82,40 +84,34 @@ public void testCreateJerseyException() throws IOException {
@Test @Test
public void testValidateResponseOK() throws IOException { public void testValidateResponseOK() throws IOException {
HttpURLConnection conn = Mockito.mock(HttpURLConnection.class); HttpURLConnection conn = Mockito.mock(HttpURLConnection.class);
Mockito.when(conn.getResponseCode()).thenReturn( Mockito.when(conn.getResponseCode()).thenReturn(HttpURLConnection.HTTP_CREATED);
HttpURLConnection.HTTP_CREATED);
HttpExceptionUtils.validateResponse(conn, HttpURLConnection.HTTP_CREATED);
}
@Test(expected = IOException.class)
public void testValidateResponseFailNoErrorMessage() throws IOException {
HttpURLConnection conn = Mockito.mock(HttpURLConnection.class);
Mockito.when(conn.getResponseCode()).thenReturn(
HttpURLConnection.HTTP_BAD_REQUEST);
HttpExceptionUtils.validateResponse(conn, HttpURLConnection.HTTP_CREATED); HttpExceptionUtils.validateResponse(conn, HttpURLConnection.HTTP_CREATED);
} }
@Test @Test
public void testValidateResponseNonJsonErrorMessage() throws IOException { public void testValidateResponseFailNoErrorMessage() throws Exception {
HttpURLConnection conn = Mockito.mock(HttpURLConnection.class);
Mockito.when(conn.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST);
LambdaTestUtils.intercept(IOException.class,
() -> HttpExceptionUtils.validateResponse(conn, HttpURLConnection.HTTP_CREATED));
}
@Test
public void testValidateResponseNonJsonErrorMessage() throws Exception {
String msg = "stream"; String msg = "stream";
InputStream is = new ByteArrayInputStream(msg.getBytes()); InputStream is = new ByteArrayInputStream(msg.getBytes(StandardCharsets.UTF_8));
HttpURLConnection conn = Mockito.mock(HttpURLConnection.class); HttpURLConnection conn = Mockito.mock(HttpURLConnection.class);
Mockito.when(conn.getErrorStream()).thenReturn(is); Mockito.when(conn.getErrorStream()).thenReturn(is);
Mockito.when(conn.getResponseMessage()).thenReturn("msg"); Mockito.when(conn.getResponseMessage()).thenReturn("msg");
Mockito.when(conn.getResponseCode()).thenReturn( Mockito.when(conn.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST);
HttpURLConnection.HTTP_BAD_REQUEST); LambdaTestUtils.interceptAndValidateMessageContains(IOException.class,
try { Arrays.asList(Integer.toString(HttpURLConnection.HTTP_BAD_REQUEST), "msg",
HttpExceptionUtils.validateResponse(conn, HttpURLConnection.HTTP_CREATED); "com.fasterxml.jackson.core.JsonParseException"),
Assert.fail(); () -> HttpExceptionUtils.validateResponse(conn, HttpURLConnection.HTTP_CREATED));
} catch (IOException ex) {
Assert.assertTrue(ex.getMessage().contains("msg"));
Assert.assertTrue(ex.getMessage().contains("" +
HttpURLConnection.HTTP_BAD_REQUEST));
}
} }
@Test @Test
public void testValidateResponseJsonErrorKnownException() throws IOException { public void testValidateResponseJsonErrorKnownException() throws Exception {
Map<String, Object> json = new HashMap<String, Object>(); Map<String, Object> json = new HashMap<String, Object>();
json.put(HttpExceptionUtils.ERROR_EXCEPTION_JSON, IllegalStateException.class.getSimpleName()); json.put(HttpExceptionUtils.ERROR_EXCEPTION_JSON, IllegalStateException.class.getSimpleName());
json.put(HttpExceptionUtils.ERROR_CLASSNAME_JSON, IllegalStateException.class.getName()); json.put(HttpExceptionUtils.ERROR_CLASSNAME_JSON, IllegalStateException.class.getName());
@ -124,23 +120,19 @@ public void testValidateResponseJsonErrorKnownException() throws IOException {
response.put(HttpExceptionUtils.ERROR_JSON, json); response.put(HttpExceptionUtils.ERROR_JSON, json);
ObjectMapper jsonMapper = new ObjectMapper(); ObjectMapper jsonMapper = new ObjectMapper();
String msg = jsonMapper.writeValueAsString(response); String msg = jsonMapper.writeValueAsString(response);
InputStream is = new ByteArrayInputStream(msg.getBytes()); InputStream is = new ByteArrayInputStream(msg.getBytes(StandardCharsets.UTF_8));
HttpURLConnection conn = Mockito.mock(HttpURLConnection.class); HttpURLConnection conn = Mockito.mock(HttpURLConnection.class);
Mockito.when(conn.getErrorStream()).thenReturn(is); Mockito.when(conn.getErrorStream()).thenReturn(is);
Mockito.when(conn.getResponseMessage()).thenReturn("msg"); Mockito.when(conn.getResponseMessage()).thenReturn("msg");
Mockito.when(conn.getResponseCode()).thenReturn( Mockito.when(conn.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST);
HttpURLConnection.HTTP_BAD_REQUEST); LambdaTestUtils.intercept(IllegalStateException.class,
try { "EX",
HttpExceptionUtils.validateResponse(conn, HttpURLConnection.HTTP_CREATED); () -> HttpExceptionUtils.validateResponse(conn, HttpURLConnection.HTTP_CREATED));
Assert.fail();
} catch (IllegalStateException ex) {
Assert.assertEquals("EX", ex.getMessage());
}
} }
@Test @Test
public void testValidateResponseJsonErrorUnknownException() public void testValidateResponseJsonErrorUnknownException()
throws IOException { throws Exception {
Map<String, Object> json = new HashMap<String, Object>(); Map<String, Object> json = new HashMap<String, Object>();
json.put(HttpExceptionUtils.ERROR_EXCEPTION_JSON, "FooException"); json.put(HttpExceptionUtils.ERROR_EXCEPTION_JSON, "FooException");
json.put(HttpExceptionUtils.ERROR_CLASSNAME_JSON, "foo.FooException"); json.put(HttpExceptionUtils.ERROR_CLASSNAME_JSON, "foo.FooException");
@ -149,19 +141,36 @@ public void testValidateResponseJsonErrorUnknownException()
response.put(HttpExceptionUtils.ERROR_JSON, json); response.put(HttpExceptionUtils.ERROR_JSON, json);
ObjectMapper jsonMapper = new ObjectMapper(); ObjectMapper jsonMapper = new ObjectMapper();
String msg = jsonMapper.writeValueAsString(response); String msg = jsonMapper.writeValueAsString(response);
InputStream is = new ByteArrayInputStream(msg.getBytes()); InputStream is = new ByteArrayInputStream(msg.getBytes(StandardCharsets.UTF_8));
HttpURLConnection conn = Mockito.mock(HttpURLConnection.class); HttpURLConnection conn = Mockito.mock(HttpURLConnection.class);
Mockito.when(conn.getErrorStream()).thenReturn(is); Mockito.when(conn.getErrorStream()).thenReturn(is);
Mockito.when(conn.getResponseMessage()).thenReturn("msg"); Mockito.when(conn.getResponseMessage()).thenReturn("msg");
Mockito.when(conn.getResponseCode()).thenReturn( Mockito.when(conn.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST);
HttpURLConnection.HTTP_BAD_REQUEST); LambdaTestUtils.interceptAndValidateMessageContains(IOException.class,
try { Arrays.asList(Integer.toString(HttpURLConnection.HTTP_BAD_REQUEST),
HttpExceptionUtils.validateResponse(conn, HttpURLConnection.HTTP_CREATED); "foo.FooException", "EX"),
Assert.fail(); () -> HttpExceptionUtils.validateResponse(conn, HttpURLConnection.HTTP_CREATED));
} catch (IOException ex) {
Assert.assertTrue(ex.getMessage().contains("EX"));
Assert.assertTrue(ex.getMessage().contains("foo.FooException"));
}
} }
@Test
public void testValidateResponseJsonErrorNonException() throws Exception {
Map<String, Object> json = new HashMap<String, Object>();
json.put(HttpExceptionUtils.ERROR_EXCEPTION_JSON, "invalid");
// test case where the exception classname is not a valid exception class
json.put(HttpExceptionUtils.ERROR_CLASSNAME_JSON, String.class.getName());
json.put(HttpExceptionUtils.ERROR_MESSAGE_JSON, "EX");
Map<String, Object> response = new HashMap<String, Object>();
response.put(HttpExceptionUtils.ERROR_JSON, json);
ObjectMapper jsonMapper = new ObjectMapper();
String msg = jsonMapper.writeValueAsString(response);
InputStream is = new ByteArrayInputStream(msg.getBytes(StandardCharsets.UTF_8));
HttpURLConnection conn = Mockito.mock(HttpURLConnection.class);
Mockito.when(conn.getErrorStream()).thenReturn(is);
Mockito.when(conn.getResponseMessage()).thenReturn("msg");
Mockito.when(conn.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST);
LambdaTestUtils.interceptAndValidateMessageContains(IOException.class,
Arrays.asList(Integer.toString(HttpURLConnection.HTTP_BAD_REQUEST),
"java.lang.String", "EX"),
() -> HttpExceptionUtils.validateResponse(conn, HttpURLConnection.HTTP_CREATED));
}
} }

View File

@ -73,7 +73,6 @@ public void testCheckNotNullFailure() throws Exception {
// failure with Null message // failure with Null message
LambdaTestUtils.intercept(NullPointerException.class, LambdaTestUtils.intercept(NullPointerException.class,
null,
() -> Preconditions.checkNotNull(null, errorMessage)); () -> Preconditions.checkNotNull(null, errorMessage));
// failure with message format // failure with message format
@ -162,7 +161,6 @@ public void testCheckArgumentWithFailure() throws Exception {
errorMessage = null; errorMessage = null;
// failure with Null message // failure with Null message
LambdaTestUtils.intercept(IllegalArgumentException.class, LambdaTestUtils.intercept(IllegalArgumentException.class,
null,
() -> Preconditions.checkArgument(false, errorMessage)); () -> Preconditions.checkArgument(false, errorMessage));
// failure with message // failure with message
errorMessage = EXPECTED_ERROR_MSG; errorMessage = EXPECTED_ERROR_MSG;
@ -200,7 +198,6 @@ public void testCheckArgumentWithFailure() throws Exception {
// failure with Null supplier // failure with Null supplier
final Supplier<String> nullSupplier = null; final Supplier<String> nullSupplier = null;
LambdaTestUtils.intercept(IllegalArgumentException.class, LambdaTestUtils.intercept(IllegalArgumentException.class,
null,
() -> Preconditions.checkArgument(false, nullSupplier)); () -> Preconditions.checkArgument(false, nullSupplier));
// ignore illegal format in supplier // ignore illegal format in supplier
@ -262,7 +259,6 @@ public void testCheckStateWithFailure() throws Exception {
errorMessage = null; errorMessage = null;
// failure with Null message // failure with Null message
LambdaTestUtils.intercept(IllegalStateException.class, LambdaTestUtils.intercept(IllegalStateException.class,
null,
() -> Preconditions.checkState(false, errorMessage)); () -> Preconditions.checkState(false, errorMessage));
// failure with message // failure with message
errorMessage = EXPECTED_ERROR_MSG; errorMessage = EXPECTED_ERROR_MSG;
@ -300,7 +296,6 @@ public void testCheckStateWithFailure() throws Exception {
// failure with Null supplier // failure with Null supplier
final Supplier<String> nullSupplier = null; final Supplier<String> nullSupplier = null;
LambdaTestUtils.intercept(IllegalStateException.class, LambdaTestUtils.intercept(IllegalStateException.class,
null,
() -> Preconditions.checkState(false, nullSupplier)); () -> Preconditions.checkState(false, nullSupplier));
// ignore illegal format in supplier // ignore illegal format in supplier