diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/net/TestClusterTopology.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/net/TestClusterTopology.java index fbed6052a5..6b07d4a455 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/net/TestClusterTopology.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/net/TestClusterTopology.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Arrays; +import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.math3.stat.inference.ChiSquareTest; import org.apache.hadoop.conf.Configuration; import org.junit.Assert; @@ -234,4 +235,41 @@ private NodeElement getNewNode(String name, String rackLocation) { node.setNetworkLocation(rackLocation); return node; } + + private NodeElement getNewNode(NetworkTopology cluster, + String name, String rackLocation) { + NodeElement node = getNewNode(name, rackLocation); + cluster.add(node); + return node; + } + + @Test + @SuppressWarnings("unchecked") + public void testWeights() { + // create the topology + NetworkTopology cluster = NetworkTopology.getInstance(new Configuration()); + NodeElement node1 = getNewNode(cluster, "node1", "/r1"); + NodeElement node2 = getNewNode(cluster, "node2", "/r1"); + NodeElement node3 = getNewNode(cluster, "node3", "/r2"); + for (Pair test: new Pair[]{Pair.of(0, node1), + Pair.of(2, node2), Pair.of(4, node3)}) { + int expect = test.getLeft(); + assertEquals(test.toString(), expect, cluster.getWeight(node1, test.getRight())); + assertEquals(test.toString(), expect, + cluster.getWeightUsingNetworkLocation(node1, test.getRight())); + } + // Reset so that we can have 2 levels + cluster = NetworkTopology.getInstance(new Configuration()); + NodeElement node5 = getNewNode(cluster, "node5", "/pod1/r1"); + NodeElement node6 = getNewNode(cluster, "node6", "/pod1/r1"); + NodeElement node7 = getNewNode(cluster, "node7", "/pod1/r2"); + NodeElement node8 = getNewNode(cluster, "node8", "/pod2/r3"); + for (Pair test: new Pair[]{Pair.of(0, node5), + Pair.of(2, node6), Pair.of(4, node7), Pair.of(6, node8)}) { + int expect = test.getLeft(); + assertEquals(test.toString(), expect, cluster.getWeight(node5, test.getRight())); + assertEquals(test.toString(), expect, + cluster.getWeightUsingNetworkLocation(node5, test.getRight())); + } + } }