YARN-11684. Fix general contract violation in PriorityQueueComparator. (#6725) Contributed by Tamas Domok.

Signed-off-by: Shilun Fan <slfan1989@apache.org>
This commit is contained in:
Tamas Domok 2024-04-19 02:37:05 +02:00 committed by GitHub
parent e8b2c28dec
commit a386ac1f56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 128 additions and 41 deletions

View File

@ -20,6 +20,7 @@
import org.apache.hadoop.classification.VisibleForTesting; import org.apache.hadoop.classification.VisibleForTesting;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource; import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.server.resourcemanager.nodelabels import org.apache.hadoop.yarn.server.resourcemanager.nodelabels
.RMNodeLabelsManager; .RMNodeLabelsManager;
@ -32,7 +33,6 @@
import java.util.Comparator; import java.util.Comparator;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** /**
@ -54,17 +54,7 @@
public class PriorityUtilizationQueueOrderingPolicy public class PriorityUtilizationQueueOrderingPolicy
implements QueueOrderingPolicy { implements QueueOrderingPolicy {
private List<CSQueue> queues; private List<CSQueue> queues;
private boolean respectPriority; private final boolean respectPriority;
// This makes multiple threads can sort queues at the same time
// For different partitions.
private static ThreadLocal<String> partitionToLookAt =
ThreadLocal.withInitial(new Supplier<String>() {
@Override
public String get() {
return RMNodeLabelsManager.NO_LABEL;
}
});
/** /**
* Compare two queues with possibly different priority and assigned capacity, * Compare two queues with possibly different priority and assigned capacity,
@ -101,15 +91,21 @@ public static int compare(double relativeAssigned1, double relativeAssigned2,
/** /**
* Comparator that both looks at priority and utilization * Comparator that both looks at priority and utilization
*/ */
private class PriorityQueueComparator final private class PriorityQueueComparator
implements Comparator<PriorityQueueResourcesForSorting> { implements Comparator<PriorityQueueResourcesForSorting> {
final private String partition;
private PriorityQueueComparator(String partition) {
this.partition = partition;
}
@Override @Override
public int compare(PriorityQueueResourcesForSorting q1Sort, public int compare(PriorityQueueResourcesForSorting q1Sort,
PriorityQueueResourcesForSorting q2Sort) { PriorityQueueResourcesForSorting q2Sort) {
String p = partitionToLookAt.get(); int rc = compareQueueAccessToPartition(
q1Sort.nodeLabelAccessible,
int rc = compareQueueAccessToPartition(q1Sort.queue, q2Sort.queue, p); q2Sort.nodeLabelAccessible);
if (0 != rc) { if (0 != rc) {
return rc; return rc;
} }
@ -133,8 +129,8 @@ public int compare(PriorityQueueResourcesForSorting q1Sort,
float used2 = q2Sort.absoluteUsedCapacity; float used2 = q2Sort.absoluteUsedCapacity;
return compare(q1Sort, q2Sort, used1, used2, return compare(q1Sort, q2Sort, used1, used2,
q1Sort.queue.getPriority(). q1Sort.priority.
getPriority(), q2Sort.queue.getPriority().getPriority()); getPriority(), q2Sort.priority.getPriority());
} else{ } else{
// both q1 has positive abs capacity and q2 has positive abs // both q1 has positive abs capacity and q2 has positive abs
// capacity // capacity
@ -142,8 +138,8 @@ public int compare(PriorityQueueResourcesForSorting q1Sort,
float used2 = q2Sort.usedCapacity; float used2 = q2Sort.usedCapacity;
return compare(q1Sort, q2Sort, used1, used2, return compare(q1Sort, q2Sort, used1, used2,
q1Sort.queue.getPriority().getPriority(), q1Sort.priority.getPriority(),
q2Sort.queue.getPriority().getPriority()); q2Sort.priority.getPriority());
} }
} }
@ -181,8 +177,7 @@ private int compare(PriorityQueueResourcesForSorting q1Sort,
return rc; return rc;
} }
private int compareQueueAccessToPartition(CSQueue q1, CSQueue q2, private int compareQueueAccessToPartition(boolean q1Accessible, boolean q2Accessible) {
String partition) {
// Everybody has access to default partition // Everybody has access to default partition
if (StringUtils.equals(partition, RMNodeLabelsManager.NO_LABEL)) { if (StringUtils.equals(partition, RMNodeLabelsManager.NO_LABEL)) {
return 0; return 0;
@ -192,14 +187,6 @@ private int compareQueueAccessToPartition(CSQueue q1, CSQueue q2,
* Check accessible to given partition, if one queue accessible and * Check accessible to given partition, if one queue accessible and
* the other not, accessible queue goes first. * the other not, accessible queue goes first.
*/ */
boolean q1Accessible =
q1.getAccessibleNodeLabels() != null && q1.getAccessibleNodeLabels()
.contains(partition) || q1.getAccessibleNodeLabels().contains(
RMNodeLabelsManager.ANY);
boolean q2Accessible =
q2.getAccessibleNodeLabels() != null && q2.getAccessibleNodeLabels()
.contains(partition) || q2.getAccessibleNodeLabels().contains(
RMNodeLabelsManager.ANY);
if (q1Accessible && !q2Accessible) { if (q1Accessible && !q2Accessible) {
return -1; return -1;
} else if (!q1Accessible && q2Accessible) { } else if (!q1Accessible && q2Accessible) {
@ -218,22 +205,32 @@ public static class PriorityQueueResourcesForSorting {
private final float usedCapacity; private final float usedCapacity;
private final Resource configuredMinResource; private final Resource configuredMinResource;
private final float absoluteCapacity; private final float absoluteCapacity;
private final Priority priority;
private final boolean nodeLabelAccessible;
private final CSQueue queue; private final CSQueue queue;
PriorityQueueResourcesForSorting(CSQueue queue) { PriorityQueueResourcesForSorting(CSQueue queue, String partition) {
this.queue = queue; this.queue = queue;
this.absoluteUsedCapacity = this.absoluteUsedCapacity =
queue.getQueueCapacities(). queue.getQueueCapacities().
getAbsoluteUsedCapacity(partitionToLookAt.get()); getAbsoluteUsedCapacity(partition);
this.usedCapacity = this.usedCapacity =
queue.getQueueCapacities(). queue.getQueueCapacities().
getUsedCapacity(partitionToLookAt.get()); getUsedCapacity(partition);
this.absoluteCapacity = this.absoluteCapacity =
queue.getQueueCapacities(). queue.getQueueCapacities().
getAbsoluteCapacity(partitionToLookAt.get()); getAbsoluteCapacity(partition);
this.configuredMinResource = this.configuredMinResource =
queue.getQueueResourceQuotas(). queue.getQueueResourceQuotas().
getConfiguredMinResource(partitionToLookAt.get()); getConfiguredMinResource(partition);
this.priority = queue.getPriority();
this.nodeLabelAccessible = queue.getAccessibleNodeLabels() != null &&
queue.getAccessibleNodeLabels().contains(partition) ||
queue.getAccessibleNodeLabels().contains(RMNodeLabelsManager.ANY);
}
static PriorityQueueResourcesForSorting create(CSQueue queue, String partition) {
return new PriorityQueueResourcesForSorting(queue, partition);
} }
public CSQueue getQueue() { public CSQueue getQueue() {
@ -252,14 +249,13 @@ public void setQueues(List<CSQueue> queues) {
@Override @Override
public Iterator<CSQueue> getAssignmentIterator(String partition) { public Iterator<CSQueue> getAssignmentIterator(String partition) {
// partitionToLookAt is a thread local variable, therefore it is safe to mutate it.
PriorityUtilizationQueueOrderingPolicy.partitionToLookAt.set(partition);
// Copy (for thread safety) and sort the snapshot of the queues in order to avoid breaking // Copy (for thread safety) and sort the snapshot of the queues in order to avoid breaking
// the prerequisites of TimSort. See YARN-10178 for details. // the prerequisites of TimSort. See YARN-10178 for details.
return new ArrayList<>(queues).stream().map(PriorityQueueResourcesForSorting::new).sorted( return new ArrayList<>(queues).stream()
new PriorityQueueComparator()).map(PriorityQueueResourcesForSorting::getQueue).collect( .map(queue -> PriorityQueueResourcesForSorting.create(queue, partition))
Collectors.toList()).iterator(); .sorted(new PriorityQueueComparator(partition))
.map(PriorityQueueResourcesForSorting::getQueue)
.collect(Collectors.toList()).iterator();
} }
@Override @Override

View File

@ -21,6 +21,7 @@
import org.apache.hadoop.thirdparty.com.google.common.collect.ImmutableSet; import org.apache.hadoop.thirdparty.com.google.common.collect.ImmutableSet;
import org.apache.hadoop.yarn.api.records.Priority; import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.QueueResourceQuotas; import org.apache.hadoop.yarn.server.resourcemanager.scheduler.QueueResourceQuotas;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.CSQueue; import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.CSQueue;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.QueueCapacities; import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.QueueCapacities;
@ -28,9 +29,13 @@
import org.junit.Test; import org.junit.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -250,4 +255,90 @@ public void testPriorityUtilizationOrdering() {
verifyOrder(policy, "x", new String[] { "e", "c", "d", "b", "a" }); verifyOrder(policy, "x", new String[] { "e", "c", "d", "b", "a" });
} }
@Test
public void testComparatorDoesNotValidateGeneralContract() {
final String[] nodeLabels = {"x", "y", "z"};
PriorityUtilizationQueueOrderingPolicy policy =
new PriorityUtilizationQueueOrderingPolicy(true);
final String partition = nodeLabels[randInt(0, nodeLabels.length - 1)];
List<CSQueue> list = new ArrayList<>();
for (int i = 0; i < 1000; i++) {
CSQueue q = mock(CSQueue.class);
when(q.getQueuePath()).thenReturn(String.format("%d", i));
// simulating change in queueCapacities
when(q.getQueueCapacities())
.thenReturn(randomQueueCapacities(partition))
.thenReturn(randomQueueCapacities(partition))
.thenReturn(randomQueueCapacities(partition))
.thenReturn(randomQueueCapacities(partition))
.thenReturn(randomQueueCapacities(partition));
// simulating change in the priority
when(q.getPriority())
.thenReturn(Priority.newInstance(randInt(0, 10)))
.thenReturn(Priority.newInstance(randInt(0, 10)))
.thenReturn(Priority.newInstance(randInt(0, 10)))
.thenReturn(Priority.newInstance(randInt(0, 10)))
.thenReturn(Priority.newInstance(randInt(0, 10)));
if (randInt(0, nodeLabels.length) == 1) {
// simulating change in nodeLabels
when(q.getAccessibleNodeLabels())
.thenReturn(randomNodeLabels(nodeLabels))
.thenReturn(randomNodeLabels(nodeLabels))
.thenReturn(randomNodeLabels(nodeLabels))
.thenReturn(randomNodeLabels(nodeLabels))
.thenReturn(randomNodeLabels(nodeLabels));
}
// simulating change in configuredMinResource
when(q.getQueueResourceQuotas())
.thenReturn(randomResourceQuotas(partition))
.thenReturn(randomResourceQuotas(partition))
.thenReturn(randomResourceQuotas(partition))
.thenReturn(randomResourceQuotas(partition))
.thenReturn(randomResourceQuotas(partition));
list.add(q);
}
policy.setQueues(list);
// java.lang.IllegalArgumentException: Comparison method violates its general contract!
assertDoesNotThrow(() -> policy.getAssignmentIterator(partition));
}
private QueueCapacities randomQueueCapacities(String partition) {
QueueCapacities qc = new QueueCapacities(false);
qc.setAbsoluteCapacity(partition, (float) randFloat(0.0d, 100.0d));
qc.setUsedCapacity(partition, (float) randFloat(0.0d, 100.0d));
qc.setAbsoluteUsedCapacity(partition, (float) randFloat(0.0d, 100.0d));
return qc;
}
private Set<String> randomNodeLabels(String[] availableNodeLabels) {
Set<String> nodeLabels = new HashSet<>();
for (String label : availableNodeLabels) {
if (randInt(0, 1) == 1) {
nodeLabels.add(label);
}
}
return nodeLabels;
}
private QueueResourceQuotas randomResourceQuotas(String partition) {
QueueResourceQuotas qr = new QueueResourceQuotas();
qr.setConfiguredMinResource(partition,
Resource.newInstance(randInt(1, 10) * 1024, randInt(1, 10)));
return qr;
}
private static double randFloat(double min, double max) {
return min + ThreadLocalRandom.current().nextFloat() * (max - min);
}
private static int randInt(int min, int max) {
return ThreadLocalRandom.current().nextInt(min, max + 1);
}
} }