YARN-8821. [YARN-8851] GPU hierarchy/topology scheduling support based on pluggable device framework. Contributed by Zhankun Tang.

This commit is contained in:
Sunil G 2019-02-24 14:36:06 +05:30
parent 106bdc6c04
commit dddcfa4d9f
8 changed files with 4327 additions and 122 deletions

View File

@ -18,6 +18,7 @@
package org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin; package org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin;
import java.util.Map;
import java.util.Set; import java.util.Set;
/** /**
@ -29,10 +30,15 @@ public interface DevicePluginScheduler {
* Called when allocating devices. The framework will do all device book * Called when allocating devices. The framework will do all device book
* keeping and fail recovery. So this hook could be stateless and only do * keeping and fail recovery. So this hook could be stateless and only do
* scheduling based on available devices passed in. It could be * scheduling based on available devices passed in. It could be
* invoked multiple times by the framework. * invoked multiple times by the framework. The hint in environment variables
* passed in could be potentially used in making better scheduling decision.
* For instance, GPU scheduling might support different kind of policy. The
* container can set it through environment variables.
* @param availableDevices Devices allowed to be chosen from. * @param availableDevices Devices allowed to be chosen from.
* @param count Number of device to be allocated. * @param count Number of device to be allocated.
* @param env Environment variables of the container.
* @return A set of {@link Device} allocated * @return A set of {@link Device} allocated
* */ * */
Set<Device> allocateDevices(Set<Device> availableDevices, int count); Set<Device> allocateDevices(Set<Device> availableDevices, int count,
Map<String, String> env);
} }

View File

@ -24,6 +24,7 @@
import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
@ -32,7 +33,12 @@
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.TreeSet; import java.util.TreeSet;
@ -40,8 +46,10 @@
/** /**
* Nvidia GPU plugin supporting both Nvidia container runtime v2 for Docker and * Nvidia GPU plugin supporting both Nvidia container runtime v2 for Docker and
* non-Docker container. * non-Docker container.
* It has topology aware as well as simple scheduling ability.
* */ * */
public class NvidiaGPUPluginForRuntimeV2 implements DevicePlugin { public class NvidiaGPUPluginForRuntimeV2 implements DevicePlugin,
DevicePluginScheduler {
public static final Logger LOG = LoggerFactory.getLogger( public static final Logger LOG = LoggerFactory.getLogger(
NvidiaGPUPluginForRuntimeV2.class); NvidiaGPUPluginForRuntimeV2.class);
@ -69,6 +77,47 @@ public class NvidiaGPUPluginForRuntimeV2 implements DevicePlugin {
private static final Set<String> DEFAULT_BINARY_SEARCH_DIRS = ImmutableSet.of( private static final Set<String> DEFAULT_BINARY_SEARCH_DIRS = ImmutableSet.of(
"/usr/bin", "/bin", "/usr/local/nvidia/bin"); "/usr/bin", "/bin", "/usr/local/nvidia/bin");
/**
 * True once {@code initCostTable()} has run to completion; it is the last
 * statement of that method, so an early return leaves it false.
 * */
private boolean topoInitialized = false;
/**
 * The device set produced by the most recent successful
 * {@code getDevices()} call. Cached so the cost table can be built without
 * probing the devices again.
 * */
private Set<Device> lastTimeFoundDevices;
/**
 * Caches every combination of devices together with its communication cost.
 * The key is the number of devices in the combination.
 * The value is a list of map entries whose key is the device combination and
 * whose value is its cost. The list is sorted by cost in ascending order.
 * For instance:
 * { 2 => [[device1,device2]=>0, [device1,device3]=>10],
 *   3 => [[device1,device2,device3]=>10, [device2,device3,device5]=>20],
 * }
 * */
private Map<Integer, List<Map.Entry<Set<Device>, Integer>>> costTable
= new HashMap<>();
/**
 * Link weight between two devices.
 * The key is a pair of minor numbers joined by "-". For instance, "0-1"
 * means the link from minor 0 to minor 1.
 * The value is the weight of the link between the two devices.
 * */
private Map<String, Integer> devicePairToWeight = new HashMap<>();
/**
 * Name of the environment variable a container can set to tell the
 * scheduler which topology policy to use when allocating devices.
 * */
public static final String TOPOLOGY_POLICY_ENV_KEY = "NVIDIA_TOPO_POLICY";
/**
 * Scheduling policy that prefers the faster GPU-GPU communication.
 * Generally suitable for heavy GPU computation workloads.
 * It is the policy used when the environment variable is not set.
 * */
public static final String TOPOLOGY_POLICY_PACK = "PACK";
/**
 * Scheduling policy that prefers the faster CPU-GPU communication.
 * Generally suitable for workloads with heavy CPU-GPU IO.
 * */
public static final String TOPOLOGY_POLICY_SPREAD = "SPREAD";
@Override @Override
public DeviceRegisterRequest getRegisterRequestInfo() throws Exception { public DeviceRegisterRequest getRegisterRequestInfo() throws Exception {
return DeviceRegisterRequest.Builder.newInstance() return DeviceRegisterRequest.Builder.newInstance()
@ -106,6 +155,8 @@ public Set<Device> getDevices() throws Exception {
id++; id++;
} }
} }
// cache it which help to topology scheduling
lastTimeFoundDevices = r;
return r; return r;
} catch (IOException e) { } catch (IOException e) {
if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
@ -170,6 +221,422 @@ private String getMajorNumber(String devName) {
return output; return output;
} }
/**
 * Choose {@code count} devices out of {@code availableDevices}.
 *
 * Topology aware scheduling is attempted first. Whenever it cannot be used
 * or fails for any reason, the method falls back to basic scheduling so that
 * a topology problem never prevents an allocation.
 *
 * @param availableDevices devices allowed to be chosen from
 * @param count number of devices to allocate
 * @param envs environment variables of the container; may carry the
 *             topology policy hint. A null map is treated as empty.
 * @return the allocated devices
 */
@Override
public Set<Device> allocateDevices(Set<Device> availableDevices, int count,
    Map<String, String> envs) {
  Set<Device> allocation = new TreeSet<>();
  /**
   * corner cases.
   * if allocate 1 device or all devices, no topo scheduling needed.
   * if total available devices is less than 3, no topo scheduling needed.
   * */
  if (availableDevices.size() < 3
      || count == 1
      || availableDevices.size() == count) {
    basicSchedule(allocation, count, availableDevices);
    return allocation;
  }
  Map<String, String> safeEnvs =
      envs == null ? Collections.<String, String>emptyMap() : envs;
  try {
    if (!topoInitialized) {
      initCostTable();
    }
    // initCostTable() may return without finishing (e.g. no devices found),
    // in which case there is no usable cost table to schedule with.
    if (topoInitialized) {
      // topology aware scheduling
      topologyAwareSchedule(allocation, count,
          safeEnvs, availableDevices, this.costTable);
      if (allocation.size() == count) {
        return allocation;
      }
    }
    LOG.error("Failed to do topology scheduling. Skip to use basic "
        + "scheduling");
  } catch (IOException e) {
    LOG.error("Error in getting GPU topology info. "
        + "Skip topology aware scheduling", e);
  } catch (RuntimeException e) {
    // Topology scheduling is only an optimization: an unexpected failure
    // (malformed topology output, missing link weight, ...) must not abort
    // the allocation.
    LOG.error("Unexpected error in topology aware scheduling. "
        + "Skip topology aware scheduling", e);
  }
  // basic scheduling, starting from a clean result
  allocation.clear();
  basicSchedule(allocation, count, availableDevices);
  return allocation;
}
/**
 * Build the topology cost table used by topology aware scheduling.
 *
 * It reads the topology matrix through the shell executor, converts it into
 * pair weights, then computes the cost of every device combination. The
 * {@code topoInitialized} flag is only set when all steps succeed.
 *
 * @throws IOException if the topology info cannot be obtained
 */
@VisibleForTesting
public void initCostTable() throws IOException {
  // get topology
  String topo = shellExecutor.getTopologyInfo();
  if (topo == null) {
    throw new IOException("No GPU topology info returned");
  }
  // build the graph
  parseTopo(topo, devicePairToWeight);
  // build the cost table of different device combinations
  if (lastTimeFoundDevices == null) {
    try {
      getDevices();
    } catch (Exception e) {
      LOG.error("Failed to get devices!", e);
      return;
    }
  }
  // getDevices() can return without caching anything, so check again
  // instead of handing a null set to buildCostTable().
  if (lastTimeFoundDevices == null) {
    LOG.error("No GPU devices found. Skip building the topology cost table");
    return;
  }
  buildCostTable(costTable, lastTimeFoundDevices);
  loggingCostTable(costTable);
  this.topoInitialized = true;
}
/**
 * Dump the whole cost table to the debug log. Does nothing unless debug
 * logging is enabled.
 *
 * @param cTable the cost table to print
 */
private void loggingCostTable(
    Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable) {
  if (!LOG.isDebugEnabled()) {
    return;
  }
  StringBuilder text = new StringBuilder("The costTable is:");
  text.append("\n{");
  for (Map.Entry<Integer, List<Map.Entry<Set<Device>, Integer>>> perCount
      : cTable.entrySet()) {
    // one section per device count
    text.append("\n\t");
    text.append(perCount.getKey());
    text.append(" => [");
    for (Map.Entry<Set<Device>, Integer> combination : perCount.getValue()) {
      text.append("\n\t\t");
      text.append(combination.toString());
      text.append(",\n");
    }
    text.append("\t\t]\n");
  }
  text.append("}\n");
  LOG.debug(text.toString());
}
/**
 * Fill the cost table with every combination of the given devices and the
 * cost of each combination.
 *
 * @param cTable the cost table to populate
 * @param ltfDevices the devices to combine
 * */
private void buildCostTable(
    Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable,
    Set<Device> ltfDevices) {
  Device[] everyDevice = ltfDevices.toArray(new Device[0]);
  generateAllDeviceCombination(cTable, everyDevice, everyDevice.length);
}
/**
 * For every combination size, compute all device combinations and their
 * cost, then store them in the cost table sorted by ascending cost.
 *
 * @param cTable the cost table to populate, keyed by combination size
 * @param allDevices all known devices
 * @param n number of devices in {@code allDevices}
 */
private void generateAllDeviceCombination(
    Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable,
    Device[] allDevices, int n) {
  // combination sizes range from 2 to n-1
  for (int size = 2; size < n; size++) {
    Map<Set<Device>, Integer> costOfCombination = new HashMap<>();
    buildCombination(costOfCombination, allDevices, n, size);
    // order the entries from the cheapest to the most expensive
    List<Map.Entry<Set<Device>, Integer>> cheapestFirst =
        new LinkedList<>(costOfCombination.entrySet());
    cheapestFirst.sort(
        Map.Entry.<Set<Device>, Integer>comparingByValue());
    cTable.put(size, cheapestFirst);
  }
}
/**
 * Compute the cost of every combination of {@code r} devices.
 *
 * @param combinationToCost receives each combination and its cost
 * @param allDevices all known devices
 * @param n number of devices in {@code allDevices}
 * @param r number of devices per combination
 */
private void buildCombination(Map<Set<Device>, Integer> combinationToCost,
    Device[] allDevices, int n, int r) {
  // scratch buffer holding the combination currently being assembled
  Device[] scratch = new Device[r];
  int lastIndex = n - 1;
  combinationRecursive(combinationToCost, allDevices, scratch,
      0, lastIndex, 0, r);
}
/**
 * Recursively enumerate device combinations and record the cost of each.
 *
 * @param cTc combinationToCost map.
 *            The key is a device set, the value is its cost
 * @param allDevices all devices the combinations are drawn from
 * @param subDeviceList scratch array holding the combination being built
 * @param start first index in allDevices that may still be picked
 * @param end last index in allDevices
 * @param index next slot of subDeviceList to fill
 * @param r the length of subDeviceList
 */
void combinationRecursive(Map<Set<Device>, Integer> cTc,
    Device[] allDevices, Device[] subDeviceList,
    int start, int end, int index, int r) {
  if (index != r) {
    // still slots to fill: try each remaining device in this slot
    for (int pick = start; pick <= end; pick++) {
      subDeviceList[index] = allDevices[pick];
      combinationRecursive(cTc, allDevices, subDeviceList,
          pick + 1, end, index + 1, r);
    }
    return;
  }
  // the combination is complete: snapshot it and store its cost
  Set<Device> combination = new TreeSet<>(Arrays.asList(subDeviceList));
  cTc.put(combination, computeCostOfDevices(subDeviceList));
}
/**
 * The cost function used to calculate the cost of a subset of devices.
 * It sums the link weight of every unordered pair of devices in the array.
 *
 * The weight of a pair is looked up by "minorA-minorB"; if that direction is
 * not known, the reverse direction is used.
 *
 * @param devices the devices whose total link cost is wanted
 * @return the sum of the pair weights
 * @throws IllegalStateException if no weight is known for a pair, instead
 *         of failing with an uninformative NullPointerException
 */
@VisibleForTesting
public int computeCostOfDevices(Device[] devices) {
  int cost = 0;
  for (int i = 0; i < devices.length; i++) {
    String gpuIndex0 = String.valueOf(devices[i].getMinorNumber());
    for (int j = i + 1; j < devices.length; j++) {
      String gpuIndex1 = String.valueOf(devices[j].getMinorNumber());
      Integer weight =
          this.devicePairToWeight.get(gpuIndex0 + "-" + gpuIndex1);
      if (weight == null) {
        weight = this.devicePairToWeight.get(gpuIndex1 + "-" + gpuIndex0);
      }
      if (weight == null) {
        throw new IllegalStateException("No link weight known between GPU "
            + gpuIndex0 + " and GPU " + gpuIndex1);
      }
      cost += weight;
    }
  }
  return cost;
}
/**
 * Topology aware schedule algorithm.
 * It doesn't consider CPU affinity or NUMA or bus bandwidths.
 * It supports two policies, "spread" and "pack", which can be set through
 * the container's environment variable. Pack is used by default, which means
 * preferring the faster GPU-GPU communication. "Spread" means preferring the
 * faster CPU-GPU communication.
 * It can potentially be extended to take GPU attributes like GPU chip memory
 * into consideration.
 *
 * On success the chosen devices are added to {@code allocation}; otherwise
 * {@code allocation} is left untouched.
 *
 * @param allocation receives the chosen devices
 * @param count number of devices to allocate
 * @param envs environment variables of the container, may be null
 * @param availableDevices devices allowed to be chosen from
 * @param cTable cost table keyed by device count, each list sorted by
 *               ascending cost
 * */
@VisibleForTesting
public void topologyAwareSchedule(Set<Device> allocation, int count,
    Map<String, String> envs,
    Set<Device> availableDevices,
    Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable) {
  String policy = envs == null ? null : envs.get(TOPOLOGY_POLICY_ENV_KEY);
  if (policy == null) {
    policy = TOPOLOGY_POLICY_PACK;
  }
  /**
   * Get combinations from costTable given the count of device want to
   * allocate.
   * */
  if (cTable == null) {
    LOG.error("No cost table initialized!");
    return;
  }
  List<Map.Entry<Set<Device>, Integer>> combinationsToCost =
      cTable.get(count);
  if (combinationsToCost == null) {
    // The table only holds the counts it was built for.
    LOG.error("No device combination of size " + count
        + " in the cost table");
    return;
  }
  // Pack walks from low cost to high cost; spread walks from high cost to
  // low cost. A list iterator works in both directions for any List.
  boolean spread = policy.equalsIgnoreCase(TOPOLOGY_POLICY_SPREAD);
  java.util.ListIterator<Map.Entry<Set<Device>, Integer>> cursor =
      combinationsToCost.listIterator(
          spread ? combinationsToCost.size() : 0);
  while (spread ? cursor.hasPrevious() : cursor.hasNext()) {
    Map.Entry<Set<Device>, Integer> element =
        spread ? cursor.previous() : cursor.next();
    // the first combination that is entirely available wins
    if (availableDevices.containsAll(element.getKey())) {
      allocation.addAll(element.getKey());
      LOG.info("Topology scheduler allocated: " + allocation);
      return;
    }
  }
  LOG.error("Unknown error happened in topology scheduler");
}
/**
 * Basic scheduling: take the first {@code count} devices in the iteration
 * order of {@code availableDevices}.
 *
 * A non-positive {@code count} allocates nothing. A {@code count} equal to
 * or larger than the number of available devices allocates all of them.
 *
 * @param allocation receives the chosen devices
 * @param count number of devices to allocate
 * @param availableDevices devices allowed to be chosen from
 */
@VisibleForTesting
public void basicSchedule(Set<Device> allocation, int count,
    Set<Device> availableDevices) {
  // Nothing requested. Without this guard the loop below would never hit
  // its stop condition and would hand out every available device.
  if (count <= 0) {
    return;
  }
  // allocate all available
  if (count >= availableDevices.size()) {
    allocation.addAll(availableDevices);
    return;
  }
  int number = 0;
  for (Device d : availableDevices) {
    allocation.add(d);
    number++;
    if (number == count) {
      break;
    }
  }
}
/**
 * Parse the topology matrix and record the weight of every device pair.
 *
 * A typical sample topo output:
 *       GPU0 GPU1 GPU2 GPU3 CPU Affinity
 * GPU0  X    PHB  SOC  SOC  0-31
 * GPU1  PHB  X    SOC  SOC  0-31
 * GPU2  SOC  SOC  X    PHB  0-31
 * GPU3  SOC  SOC  PHB  X    0-31
 *
 *
 * Legend:
 *
 * X = Self
 * SOC = Connection traversing PCIe as well as the SMP link between
 * CPU sockets(e.g. QPI)
 * PHB = Connection traversing PCIe as well as a PCIe Host Bridge
 * (typically the CPU)
 * PXB = Connection traversing multiple PCIe switches
 * (without traversing the PCIe Host Bridge)
 * PIX = Connection traversing a single PCIe switch
 * NV# = Connection traversing a bonded set of # NVLinks
 *
 * Rows whose name is not "GPU" followed by a number are skipped, and so are
 * cells holding an unknown link type. Parsing stops at the legend.
 *
 * @param topo the raw topology matrix text; null is treated as empty
 * @param deviceLinkToWeight receives "rowMinor-colMinor" to weight entries
 * */
public void parseTopo(String topo,
    Map<String, Integer> deviceLinkToWeight) {
  if (topo == null) {
    return;
  }
  for (String rawLine : topo.split("\n")) {
    String oneLine = rawLine.trim();
    if (oneLine.isEmpty()) {
      continue;
    }
    // To the end. No more metrics info
    if (oneLine.startsWith("Legend")) {
      break;
    }
    // Skip header
    if (oneLine.contains("Affinity")) {
      continue;
    }
    String[] tokens = oneLine.split("\\s+");
    String name = tokens[0];
    // Only GPU rows describe GPU links; other rows would otherwise make the
    // minor number parsing below fail.
    if (!name.startsWith("GPU")) {
      continue;
    }
    int rowMinor;
    try {
      rowMinor = Integer.parseInt(name.substring(name.lastIndexOf("U") + 1));
    } catch (NumberFormatException e) {
      LOG.warn("Skip unrecognized topology row: " + oneLine);
      continue;
    }
    for (int i = 1; i < tokens.length; i++) {
      // the column position, not the header text, gives the column minor
      int colMinor = i - 1;
      DeviceLinkType linkType = toLinkType(tokens[i]);
      // self ("X"), the affinity column or an unknown link type
      if (linkType == null) {
        continue;
      }
      populateGraphEdgeWeight(linkType, rowMinor, colMinor,
          deviceLinkToWeight);
    } // end one line handling
  }
}

/**
 * Map one cell of the topology matrix to its link type.
 *
 * @param token the cell text, e.g. "PHB" or "NV2"
 * @return the link type, or null when the cell is not a known link type
 */
private static DeviceLinkType toLinkType(String token) {
  switch (token) {
  case "SOC":
  case "SYS":
    return DeviceLinkType.P2PLinkCrossCPUSocket;
  case "PHB":
  case "NODE":
    return DeviceLinkType.P2PLinkSameCPUSocket;
  case "PXB":
    return DeviceLinkType.P2PLinkMultiSwitch;
  case "PIX":
    return DeviceLinkType.P2PLinkSingleSwitch;
  default:
    break;
  }
  if (!token.startsWith("NV") || token.length() < 3) {
    return null;
  }
  int links;
  try {
    links = Integer.parseInt(token.substring(2));
  } catch (NumberFormatException e) {
    return null;
  }
  if (links < 1) {
    return null;
  }
  // Only NV1 to NV9 have a dedicated type. A larger bonded set is treated
  // as the fastest known NVLink type rather than being dropped.
  return DeviceLinkType.valueOf("P2PLinkNVLink" + Math.min(links, 9));
}
/**
 * Record the weight of the link going from one device to another.
 *
 * @param linkType the kind of link between the two devices
 * @param leftVertex minor number of the source device
 * @param rightVertex minor number of the destination device
 * @param deviceLinkToWeight receives the "left-right" to weight entry
 */
private void populateGraphEdgeWeight(
    DeviceLinkType linkType,
    int leftVertex,
    int rightVertex,
    Map<String, Integer> deviceLinkToWeight) {
  String pairKey = leftVertex + "-" + rightVertex;
  int pairWeight = linkType.getWeight();
  deviceLinkToWeight.put(pairKey, pairWeight);
}
/**
 * The kinds of link that can connect two GPUs.
 * Each kind carries a relative weight: the higher the weight, the higher
 * the communication cost between the two GPUs.
 * */
public enum DeviceLinkType {
  /**
   * Nvidia NVLink, one constant per number of bonded links.
   * */
  P2PLinkNVLink9(10),
  P2PLinkNVLink8(20),
  P2PLinkNVLink7(30),
  P2PLinkNVLink6(40),
  P2PLinkNVLink5(50),
  P2PLinkNVLink4(60),
  P2PLinkNVLink3(70),
  P2PLinkNVLink2(80),
  P2PLinkNVLink1(90),

  /**
   * Both GPUs hang off the same CPU (same NUMA node).
   * */
  P2PLinkSameCPUSocket(200),

  /**
   * The GPUs talk across CPUs through a socket-level link (e.g. QPI).
   * Usually this crosses NUMA nodes.
   * */
  P2PLinkCrossCPUSocket(300),

  /**
   * Only one PCIe switch has to be traversed.
   * */
  P2PLinkSingleSwitch(600),

  /**
   * Several PCIe switches have to be traversed.
   * */
  P2PLinkMultiSwitch(1200);

  // Relative cost of the link; a higher value means slower communication.
  private final int weight;

  DeviceLinkType(int linkWeight) {
    weight = linkWeight;
  }

  public int getWeight() {
    return weight;
  }
}
/** /**
* A shell wrapper class easy for test. * A shell wrapper class easy for test.
* */ * */
@ -189,6 +656,13 @@ public String getMajorMinorInfo(String devName) throws IOException {
return shexec.getOutput(); return shexec.getOutput();
} }
/**
 * Get the topology matrix from nvidia-smi by running it with "topo -m".
 *
 * @return the raw output of the command
 * @throws IOException if the command fails to run
 */
public String getTopologyInfo() throws IOException {
  String[] topoCommand = new String[]{pathOfGpuBinary, "topo", "-m"};
  return Shell.execCommand(environment, topoCommand, MAX_EXEC_TIMEOUT_MS);
}
public void searchBinary() throws Exception { public void searchBinary() throws Exception {
if (pathOfGpuBinary != null) { if (pathOfGpuBinary != null) {
LOG.info("Skip searching, the nvidia gpu binary is already set: " LOG.info("Skip searching, the nvidia gpu binary is already set: "
@ -228,8 +702,8 @@ public void searchBinary() throws Exception {
} }
@VisibleForTesting @VisibleForTesting
public void setPathOfGpuBinary(String pathOfGpuBinary) { public void setPathOfGpuBinary(String pOfGpuBinary) {
this.pathOfGpuBinary = pathOfGpuBinary; this.pathOfGpuBinary = pOfGpuBinary;
} }
@VisibleForTesting @VisibleForTesting
@ -237,4 +711,20 @@ public void setShellExecutor(
NvidiaCommandExecutor shellExecutor) { NvidiaCommandExecutor shellExecutor) {
this.shellExecutor = shellExecutor; this.shellExecutor = shellExecutor;
} }
/**
 * @return the current value of the topoInitialized flag
 */
@VisibleForTesting
public boolean isTopoInitialized() {
return topoInitialized;
}
/**
 * @return the cost table itself (not a copy), keyed by device count
 */
@VisibleForTesting
public Map<Integer, List<Map.Entry<Set<Device>, Integer>>> getCostTable() {
return costTable;
}
/**
 * @return the device pair to weight map itself (not a copy)
 */
@VisibleForTesting
public Map<String, Integer> getDevicePairToWeight() {
return devicePairToWeight;
}
} }

View File

@ -19,6 +19,7 @@
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework; package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
@ -192,7 +193,7 @@ private synchronized DeviceAllocation internalAssignDevices(
DevicePluginScheduler dps = devicePluginSchedulers.get(resourceName); DevicePluginScheduler dps = devicePluginSchedulers.get(resourceName);
// Prefer DevicePluginScheduler logic // Prefer DevicePluginScheduler logic
pickAndDoSchedule(allowedDevices, usedDevices, assignedDevices, pickAndDoSchedule(allowedDevices, usedDevices, assignedDevices,
containerId, requestedDeviceCount, resourceName, dps); container, requestedDeviceCount, resourceName, dps);
// Record in state store if we allocated anything // Record in state store if we allocated anything
if (!assignedDevices.isEmpty()) { if (!assignedDevices.isEmpty()) {
@ -310,9 +311,11 @@ private long getReleasingDevices(String resourceName) {
* */ * */
private void pickAndDoSchedule(Set<Device> allowed, private void pickAndDoSchedule(Set<Device> allowed,
Map<Device, ContainerId> used, Set<Device> assigned, Map<Device, ContainerId> used, Set<Device> assigned,
ContainerId containerId, int count, String resourceName, Container c, int count, String resourceName,
DevicePluginScheduler dps) throws ResourceHandlerException { DevicePluginScheduler dps)
throws ResourceHandlerException {
ContainerId containerId = c.getContainerId();
Map<String, String> env = c.getLaunchContext().getEnvironment();
if (null == dps) { if (null == dps) {
if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
LOG.debug("Customized device plugin scheduler is preferred " LOG.debug("Customized device plugin scheduler is preferred "
@ -331,7 +334,8 @@ private void pickAndDoSchedule(Set<Device> allowed,
// Pass in unmodifiable set // Pass in unmodifiable set
Set<Device> dpsAllocated = dps.allocateDevices( Set<Device> dpsAllocated = dps.allocateDevices(
Sets.difference(allowed, used.keySet()), Sets.difference(allowed, used.keySet()),
count); count,
ImmutableMap.copyOf(env));
if (dpsAllocated.size() != count) { if (dpsAllocated.size() != count) {
throw new ResourceHandlerException(dps.getClass() throw new ResourceHandlerException(dps.getClass()
+ " should allocate " + count + " should allocate " + count

View File

@ -0,0 +1,848 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.anySet;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
* Test case for NvidiaGPUPluginForRuntimeV2 device plugin.
* */
public class TestNvidiaGPUPluginForRuntimeV2 {
private static final Logger LOG =
LoggerFactory.getLogger(TestNvidiaGPUPluginForRuntimeV2.class);
/**
 * The plugin must turn the device list and the major/minor numbers reported
 * by the shell into the expected set of devices.
 */
@Test
public void testGetNvidiaDevices() throws Exception {
  String[] busIds = {"00000000:04:00.0", "00000000:82:00.0"};

  // stub every shell call the plugin makes while discovering devices
  NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor shell =
      mock(NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor.class);
  when(shell.getDeviceInfo())
      .thenReturn("0, 00000000:04:00.0\n1, 00000000:82:00.0");
  for (int minor = 0; minor < busIds.length; minor++) {
    when(shell.getMajorMinorInfo("nvidia" + minor))
        .thenReturn("c3:" + minor);
  }

  NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
  plugin.setShellExecutor(shell);
  plugin.setPathOfGpuBinary("/fake/nvidia-smi");

  // one healthy device per bus id, major number 0xc3 = 195
  Set<Device> expectedDevices = new TreeSet<>();
  for (int index = 0; index < busIds.length; index++) {
    expectedDevices.add(Device.Builder.newInstance()
        .setId(index).setHealthy(true)
        .setBusID(busIds[index])
        .setDevPath("/dev/nvidia" + index)
        .setMajorNumber(195)
        .setMinorNumber(index).build());
  }

  Assert.assertEquals(expectedDevices, plugin.getDevices());
}
/**
 * onDevicesAllocated must return no spec for the default runtime and, for
 * Docker, the nvidia runtime with the minor numbers of the allocated
 * devices in NVIDIA_VISIBLE_DEVICES.
 */
@Test
public void testOnDeviceAllocated() throws Exception {
  NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
  Set<Device> allocatedDevices = new TreeSet<>();

  DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices,
      YarnRuntimeType.RUNTIME_DEFAULT);
  Assert.assertNull(spec);

  // allocate one device
  allocatedDevices.add(Device.Builder.newInstance()
      .setId(0).setHealthy(true)
      .setBusID("00000000:04:00.0")
      .setDevPath("/dev/nvidia0")
      .setMajorNumber(195)
      .setMinorNumber(0).build());
  spec = plugin.onDevicesAllocated(allocatedDevices,
      YarnRuntimeType.RUNTIME_DOCKER);
  Assert.assertEquals("nvidia", spec.getContainerRuntime());
  Assert.assertEquals("0", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));

  // two devices allocated; the second one is a distinct device, so it gets
  // its own id rather than reusing the id of the first
  allocatedDevices.add(Device.Builder.newInstance()
      .setId(1).setHealthy(true)
      .setBusID("00000000:82:00.0")
      .setDevPath("/dev/nvidia1")
      .setMajorNumber(195)
      .setMinorNumber(1).build());
  spec = plugin.onDevicesAllocated(allocatedDevices,
      YarnRuntimeType.RUNTIME_DOCKER);
  Assert.assertEquals("nvidia", spec.getContainerRuntime());
  Assert.assertEquals("0,1", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));
}
/**
 * Create a plugin backed by a mocked shell that reports eight GPUs linked
 * by NVLink and PHB connections.
 */
private NvidiaGPUPluginForRuntimeV2 mockEightGPUPlugin() throws IOException {
  final int gpuCount = 8;
  String topoInfo =
      "\tGPU0\tGPU1\tGPU2\tGPU3\tGPU4\tGPU5\tGPU6\tGPU7\tCPU Affinity\n"
      + "GPU0\t X \tNV1\tNV1\tNV2\tNV2\tPHB\tPHB\tPHB\t0-63\n"
      + "GPU1\tNV1\t X \tNV2\tNV1\tPHB\tNV2\tPHB\tPHB\t0-63\n"
      + "GPU2\tNV1\tNV2\t X \tNV2\tPHB\tPHB\tNV1\tPHB\t0-63\n"
      + "GPU3\tNV2\tNV1\tNV2\t X \tPHB\tPHB\tPHB\tNV1\t0-63\n"
      + "GPU4\tNV2\tPHB\tPHB\tPHB\t X \tNV1\tNV1\tNV2\t0-63\n"
      + "GPU5\tPHB\tNV2\tPHB\tPHB\tNV1\t X \tNV2\tNV1\t0-63\n"
      + "GPU6\tPHB\tPHB\tNV1\tPHB\tNV1\tNV2\t X \tNV2\t0-63\n"
      + "GPU7\tPHB\tPHB\tPHB\tNV1\tNV2\tNV1\tNV2\t X \t0-63\n"
      + "\n"
      + "Legend:\n"
      + "\n"
      + " X = Self\n"
      + " SYS = Connection traversing PCIe as well as the SMP interconnect"
      + " between NUMA nodes (e.g., QPI/UPI)\n"
      + " NODE = Connection traversing PCIe as well as the interconnect"
      + " between PCIe Host Bridges within a NUMA node\n"
      + " PHB = Connection traversing PCIe as well as a PCIe Host Bridge"
      + " (typically the CPU)\n"
      + " PXB = Connection traversing multiple PCIe switches"
      + " (without traversing the PCIe Host Bridge)\n"
      + " PIX = Connection traversing a single PCIe switch\n"
      + " NV# = Connection traversing a bonded set of # NVLinks\n";
  String deviceInfoShellOutput = "0, 00000000:04:00.0\n"
      + "1, 00000000:82:00.0\n"
      + "2, 00000000:83:00.0\n"
      + "3, 00000000:84:00.0\n"
      + "4, 00000000:85:00.0\n"
      + "5, 00000000:86:00.0\n"
      + "6, 00000000:87:00.0\n"
      + "7, 00000000:88:00.0";

  NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor shell =
      mock(NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor.class);
  when(shell.getDeviceInfo()).thenReturn(deviceInfoShellOutput);
  when(shell.getTopologyInfo()).thenReturn(topoInfo);
  // every device node reports major 0xc3 and its own minor number
  for (int minor = 0; minor < gpuCount; minor++) {
    when(shell.getMajorMinorInfo("nvidia" + minor))
        .thenReturn("c3:" + minor);
  }

  NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
  plugin.setShellExecutor(shell);
  plugin.setPathOfGpuBinary("/fake/nvidia-smi");
  return plugin;
}
/**
 * Create a plugin backed by a mocked shell that reports four GPUs: two
 * pairs linked by PHB, with SOC links between the pairs.
 */
private NvidiaGPUPluginForRuntimeV2 mockFourGPUPlugin() throws IOException {
  final int gpuCount = 4;
  String topoInfo = "\tGPU0\tGPU1\tGPU2\tGPU3\tCPU Affinity\n"
      + "GPU0\t X \tPHB\tSOC\tSOC\t0-31\n"
      + "GPU1\tPHB\t X \tSOC\tSOC\t0-31\n"
      + "GPU2\tSOC\tSOC\t X \tPHB\t0-31\n"
      + "GPU3\tSOC\tSOC\tPHB\t X \t0-31\n"
      + "\n"
      + "\n"
      + " Legend:\n"
      + "\n"
      + " X = Self\n"
      + " SOC = Connection traversing PCIe as well as the SMP link between\n"
      + " CPU sockets(e.g. QPI)\n"
      + " PHB = Connection traversing PCIe as well as a PCIe Host Bridge\n"
      + " (typically the CPU)\n"
      + " PXB = Connection traversing multiple PCIe switches\n"
      + " (without traversing the PCIe Host Bridge)\n"
      + " PIX = Connection traversing a single PCIe switch\n"
      + " NV# = Connection traversing a bonded set of # NVLinks";
  String deviceInfoShellOutput = "0, 00000000:04:00.0\n"
      + "1, 00000000:82:00.0\n"
      + "2, 00000000:83:00.0\n"
      + "3, 00000000:84:00.0";

  NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor shell =
      mock(NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor.class);
  when(shell.getDeviceInfo()).thenReturn(deviceInfoShellOutput);
  when(shell.getTopologyInfo()).thenReturn(topoInfo);
  // every device node reports major 0xc3 and its own minor number
  for (int minor = 0; minor < gpuCount; minor++) {
    when(shell.getMajorMinorInfo("nvidia" + minor))
        .thenReturn("c3:" + minor);
  }

  NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
  plugin.setShellExecutor(shell);
  plugin.setPathOfGpuBinary("/fake/nvidia-smi");
  return plugin;
}
/**
 * Verifies device allocation on the mocked four-GPU node when the container
 * asks for the PACK topology policy through its environment variables.
 * Requests for a single device or for every available device are expected
 * to go through basic scheduling; requests in between are expected to go
 * through topology aware scheduling, with the cost table initialized only
 * on the first such request.
 * */
@Test
public void testTopologySchedulingWithPackPolicy() throws Exception {
  NvidiaGPUPluginForRuntimeV2 plugin = mockFourGPUPlugin();
  NvidiaGPUPluginForRuntimeV2 spyPlugin = spy(plugin);
  // cache the total devices
  Set<Device> allDevices = spyPlugin.getDevices();
  // environment variable to use PACK policy
  Map<String, String> env = new HashMap<>();
  env.put(NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_ENV_KEY,
      NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_PACK);

  // Case 0. if available devices is less than 3, no topo scheduling needed.
  // Work on a copy with the first two devices removed.
  Set<Device> copyAvailableDevices = new TreeSet<>(allDevices);
  Iterator<Device> iterator0 = copyAvailableDevices.iterator();
  iterator0.next();
  iterator0.remove();
  iterator0.next();
  iterator0.remove();
  // Case 0. allocate 1 device
  reset(spyPlugin);
  Set<Device> allocation = spyPlugin.allocateDevices(copyAvailableDevices,
      1, env);
  Assert.assertEquals(1, allocation.size());
  verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet());
  Assert.assertFalse(spyPlugin.isTopoInitialized());

  // Case 1. allocate 1 device
  reset(spyPlugin);
  allocation = spyPlugin.allocateDevices(allDevices, 1, env);
  // ensure no topology scheduling needed
  Assert.assertEquals(1, allocation.size());
  verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet());

  // Case 2. allocate all available
  reset(spyPlugin);
  allocation = spyPlugin.allocateDevices(allDevices, allDevices.size(), env);
  Assert.assertEquals(allDevices.size(), allocation.size());
  verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet());

  // Case 3. allocate 2 devices
  reset(spyPlugin);
  int count = 2;
  allocation = spyPlugin.allocateDevices(allDevices, count, env);
  Assert.assertEquals(count, allocation.size());
  // the costTable should be initialized by this first topology request
  verify(spyPlugin).initCostTable();
  Assert.assertTrue(spyPlugin.isTopoInitialized());
  verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(),
      anySet(), anyMap());
  Device[] allocatedDevices =
      allocation.toArray(new Device[count]);
  // Check weights
  Assert.assertEquals(NvidiaGPUPluginForRuntimeV2.DeviceLinkType
          .P2PLinkSameCPUSocket.getWeight(),
      spyPlugin.computeCostOfDevices(allocatedDevices));

  // Case 4. allocate 3 devices
  reset(spyPlugin);
  count = 3;
  allocation = spyPlugin.allocateDevices(allDevices, count, env);
  Assert.assertEquals(count, allocation.size());
  // the costTable is already initialized, so it must not be built again
  verify(spyPlugin, times(0)).initCostTable();
  Assert.assertTrue(spyPlugin.isTopoInitialized());
  verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(),
      anySet(), anyMap());
  allocatedDevices =
      allocation.toArray(new Device[count]);
  // check weights
  int expectedWeight =
      NvidiaGPUPluginForRuntimeV2.DeviceLinkType
          .P2PLinkSameCPUSocket.getWeight()
          + 2 * NvidiaGPUPluginForRuntimeV2.DeviceLinkType
          .P2PLinkCrossCPUSocket.getWeight();
  Assert.assertEquals(expectedWeight,
      spyPlugin.computeCostOfDevices(allocatedDevices));

  // Case 5. allocate 2 GPUs from three available devices
  reset(spyPlugin);
  Iterator<Device> iterator = allDevices.iterator();
  iterator.next();
  // remove GPU0
  iterator.remove();
  count = 2;
  allocation = spyPlugin.allocateDevices(allDevices, count, env);
  Assert.assertEquals(count, allocation.size());
  // the costTable is already initialized, so it must not be built again
  verify(spyPlugin, times(0)).initCostTable();
  Assert.assertTrue(spyPlugin.isTopoInitialized());
  verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(),
      anySet(), anyMap());
  allocatedDevices =
      allocation.toArray(new Device[count]);
  // check weights
  Assert.assertEquals(NvidiaGPUPluginForRuntimeV2.DeviceLinkType
          .P2PLinkSameCPUSocket.getWeight(),
      spyPlugin.computeCostOfDevices(allocatedDevices));
  // it should allocate GPU 2 and 3
  for (Device device : allocation) {
    int minorNumber = device.getMinorNumber();
    Assert.assertTrue("Should allocate GPU 2 and 3",
        minorNumber == 2 || minorNumber == 3);
  }
}
/**
 * Verifies device allocation on the mocked four-GPU node when the container
 * asks for the SPREAD topology policy through its environment variables.
 * Requests for a single device or for every available device are expected
 * to go through basic scheduling; requests in between are expected to go
 * through topology aware scheduling, with the cost table initialized only
 * on the first such request.
 * */
@Test
public void testTopologySchedulingWithSpreadPolicy() throws Exception {
  NvidiaGPUPluginForRuntimeV2 plugin = mockFourGPUPlugin();
  NvidiaGPUPluginForRuntimeV2 spyPlugin = spy(plugin);
  // cache the total devices
  Set<Device> allDevices = spyPlugin.getDevices();
  // environment variable to use SPREAD policy
  Map<String, String> env = new HashMap<>();
  env.put(NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_ENV_KEY,
      NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_SPREAD);

  // Case 1. allocate 1 device
  Set<Device> allocation = spyPlugin.allocateDevices(allDevices, 1, env);
  // ensure no topology scheduling needed
  Assert.assertEquals(1, allocation.size());
  verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet());

  // Case 2. allocate all available
  reset(spyPlugin);
  allocation = spyPlugin.allocateDevices(allDevices, allDevices.size(), env);
  Assert.assertEquals(allDevices.size(), allocation.size());
  verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet());

  // Case 3. allocate 2 devices
  reset(spyPlugin);
  int count = 2;
  allocation = spyPlugin.allocateDevices(allDevices, count, env);
  Assert.assertEquals(count, allocation.size());
  // the costTable should be initialized by this first topology request
  verify(spyPlugin).initCostTable();
  Assert.assertTrue(spyPlugin.isTopoInitialized());
  verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(),
      anySet(), anyMap());
  Device[] allocatedDevices =
      allocation.toArray(new Device[count]);
  // Check weights
  Assert.assertEquals(NvidiaGPUPluginForRuntimeV2.DeviceLinkType
          .P2PLinkCrossCPUSocket.getWeight(),
      spyPlugin.computeCostOfDevices(allocatedDevices));

  // Case 4. allocate 3 devices
  reset(spyPlugin);
  count = 3;
  allocation = spyPlugin.allocateDevices(allDevices, count, env);
  Assert.assertEquals(count, allocation.size());
  // the costTable is already initialized, so it must not be built again
  verify(spyPlugin, times(0)).initCostTable();
  Assert.assertTrue(spyPlugin.isTopoInitialized());
  verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(),
      anySet(), anyMap());
  allocatedDevices =
      allocation.toArray(new Device[count]);
  // check weights
  int expectedWeight =
      NvidiaGPUPluginForRuntimeV2.DeviceLinkType
          .P2PLinkSameCPUSocket.getWeight()
          + 2 * NvidiaGPUPluginForRuntimeV2.DeviceLinkType
          .P2PLinkCrossCPUSocket.getWeight();
  Assert.assertEquals(expectedWeight,
      spyPlugin.computeCostOfDevices(allocatedDevices));

  // Case 5. allocate 2 GPUs from three available devices
  reset(spyPlugin);
  Iterator<Device> iterator = allDevices.iterator();
  iterator.next();
  // remove GPU0
  iterator.remove();
  count = 2;
  allocation = spyPlugin.allocateDevices(allDevices, count, env);
  Assert.assertEquals(count, allocation.size());
  // the costTable is already initialized, so it must not be built again
  verify(spyPlugin, times(0)).initCostTable();
  Assert.assertTrue(spyPlugin.isTopoInitialized());
  verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(),
      anySet(), anyMap());
  allocatedDevices =
      allocation.toArray(new Device[count]);
  // check weights: the chosen pair must be linked across CPU sockets
  Assert.assertEquals(NvidiaGPUPluginForRuntimeV2.DeviceLinkType
          .P2PLinkCrossCPUSocket.getWeight(),
      spyPlugin.computeCostOfDevices(allocatedDevices));
  // GPU 0 was removed from the available set and must not be handed out
  for (Device device : allocation) {
    if (device.getMinorNumber() == 0) {
      Assert.fail("Shouldn't allocate GPU 0");
    }
  }
}
/**
 * Test the device pair weights and the cost table built for the mocked
 * eight-GPU node whose topology contains NVLink connections, then run two
 * PACK policy allocations against it.
 * */
@Test
public void testCostTableWithNVlink() throws Exception {
  NvidiaGPUPluginForRuntimeV2 plugin = mockEightGPUPlugin();
  NvidiaGPUPluginForRuntimeV2 spyPlugin = spy(plugin);
  // verify the device pair to weight map
  spyPlugin.initCostTable();
  Map<String, Integer> devicePairToWeight = spyPlugin.getDevicePairToWeight();
  // 56 combinations when choose 2 GPUs from 8 respect the order. 8!/6!
  Assert.assertEquals(56, devicePairToWeight.size());
  int sameCPUWeight =
      NvidiaGPUPluginForRuntimeV2.DeviceLinkType
          .P2PLinkSameCPUSocket.getWeight();
  int nv1Weight =
      NvidiaGPUPluginForRuntimeV2.DeviceLinkType
          .P2PLinkNVLink1.getWeight();
  int nv2Weight =
      NvidiaGPUPluginForRuntimeV2.DeviceLinkType
          .P2PLinkNVLink2.getWeight();
  Assert.assertEquals(nv1Weight, (int)devicePairToWeight.get("0-1"));
  Assert.assertEquals(nv1Weight, (int)devicePairToWeight.get("1-0"));
  Assert.assertEquals(nv2Weight, (int)devicePairToWeight.get("0-4"));
  Assert.assertEquals(nv2Weight, (int)devicePairToWeight.get("4-0"));
  Assert.assertEquals(nv2Weight, (int)devicePairToWeight.get("0-3"));
  Assert.assertEquals(nv2Weight, (int)devicePairToWeight.get("3-0"));
  Assert.assertEquals(sameCPUWeight, (int)devicePairToWeight.get("6-3"));
  Assert.assertEquals(sameCPUWeight, (int)devicePairToWeight.get("3-6"));
  Assert.assertEquals(nv2Weight, (int)devicePairToWeight.get("6-7"));
  Assert.assertEquals(nv2Weight, (int)devicePairToWeight.get("7-6"));
  Assert.assertEquals(nv1Weight, (int)devicePairToWeight.get("1-3"));
  Assert.assertEquals(nv1Weight, (int)devicePairToWeight.get("3-1"));
  // verify cost Table
  Map<Integer, List<Map.Entry<Set<Device>, Integer>>> costTable =
      spyPlugin.getCostTable();
  Assert.assertNull(costTable.get(1));
  // C8:2 = 8!/2!/6! = 28
  Assert.assertEquals(28, costTable.get(2).size());
  // C8:4 = 8!/4!/4! = 70
  Assert.assertEquals(70, costTable.get(4).size());
  Assert.assertNull(costTable.get(8));
  // allocations on top of the initialized cost table must return exactly
  // the requested number of devices
  Set<Device> allDevices = spyPlugin.getDevices();
  Map<String, String> env = new HashMap<>();
  env.put(NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_ENV_KEY,
      NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_PACK);
  Set<Device> allocation = spyPlugin.allocateDevices(allDevices, 3, env);
  Assert.assertEquals(3, allocation.size());
  allocation = spyPlugin.allocateDevices(allDevices, 2, env);
  Assert.assertEquals(2, allocation.size());
}
/**
 * Checks the pair weight map and the cost table that
 * {@code initCostTable()} builds for the mocked four-GPU node.
 * */
@Test
public void testCostTable() throws IOException {
  NvidiaGPUPluginForRuntimeV2 spyPlugin = spy(mockFourGPUPlugin());
  spyPlugin.initCostTable();

  // Ordered pairs of 2 GPUs out of 4: 4!/2! = 12 entries.
  Map<String, Integer> weightByPair = spyPlugin.getDevicePairToWeight();
  Assert.assertEquals(12, weightByPair.size());

  final int same = NvidiaGPUPluginForRuntimeV2.DeviceLinkType
      .P2PLinkSameCPUSocket.getWeight();
  final int cross = NvidiaGPUPluginForRuntimeV2.DeviceLinkType
      .P2PLinkCrossCPUSocket.getWeight();
  // Each pair is listed next to the weight expected at the same index.
  String[] pairs = {
      "0-1", "1-0",
      "0-2", "2-0",
      "0-3", "3-0",
      "1-2", "2-1",
      "1-3", "3-1",
      "2-3", "3-2",
  };
  int[] expectedWeights = {
      same, same,
      cross, cross,
      cross, cross,
      cross, cross,
      cross, cross,
      same, same,
  };
  for (int i = 0; i < pairs.length; i++) {
    Assert.assertEquals(expectedWeights[i],
        (int) weightByPair.get(pairs[i]));
  }

  // The cost table has entries for 2 and 3 devices only; there is none
  // for a single device or for all four devices.
  Map<Integer, List<Map.Entry<Set<Device>, Integer>>> table =
      spyPlugin.getCostTable();
  Assert.assertNull(table.get(1));
  Assert.assertEquals(6, table.get(2).size());
  Assert.assertEquals(4, table.get(3).size());
  Assert.assertNull(table.get(4));
}
/**
 * Test GPU topology allocation.
 * And analyze the GPU allocation's performance against the actual
 * performance data using tensorflow benchmarks.
 * https://github.com/tensorflow/benchmarks
 * The boosts are only logged; the single assertion is on the number of
 * records loaded from the report.
 * */
@Test
public void testTopologySchedulingPerformanceWithPackPolicyWithNVLink()
    throws Exception {
  NvidiaGPUPluginForRuntimeV2 plugin = mockEightGPUPlugin();
  NvidiaGPUPluginForRuntimeV2 spyPlugin = spy(plugin);
  Set<Device> allDevices = spyPlugin.getDevices();
  Map<String, String> env = new HashMap<>();
  env.put(NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_ENV_KEY,
      NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_PACK);
  // Analyze performance against the real data.
  // For every model, batch size and GPU count combination, take the
  // allocation chosen by the topology scheduling algorithm and compute its
  // imagePerSecond boost against the median and against the minimum
  // imagePerSecond of all allocations with the same GPU count. Then
  // aggregate the average and the best boost, per model and overall.
  // The average boost against the median value shows how far the
  // scheduler's allocation outperforms half of the possible allocations.
  // The average boost against the min value compares against the worst
  // allocation in each scenario, which gives the more impressive number.
  ActualPerformanceReport report = new ActualPerformanceReport();
  report.readFromFile();
  List<ActualPerformanceReport.DataRecord> dataSet =
      report.getDataSet();
  Assert.assertEquals(2952, dataSet.size());
  String[] allModels = {"alexnet", "resnet50", "vgg16", "inception3"};
  int[] batchSizes = {32, 64, 128};
  int[] gpuCounts = {2, 3, 4, 5, 6, 7};
  float totalBoostAgainstMedian = 0;
  int count = 0;
  float maxBoostAgainstMedian = 0;
  float totalBoostAgainstMin = 0;
  float maxBoostAgainstMin = 0;
  for (String model : allModels) {
    float totalBoostAgainstMinCertainModel = 0;
    float totalBoostAgainstMedianCertainModel = 0;
    float maxBoostAgainstMinCertainModel = 0;
    float maxBoostAgainstMedianCertainModel = 0;
    int countOfEachModel = 0;
    for (int bs : batchSizes) {
      for (int gpuCount : gpuCounts) {
        float bstAgainstMedian = calculatePerformanceBoostAgainstMedian(
            report, model, bs, gpuCount, plugin, allDevices, env);
        float bstAgainstMinimum = calculatePerformanceBoostAgainstMinimum(
            report, model, bs, gpuCount, plugin, allDevices, env);
        // overall aggregation
        totalBoostAgainstMedian += bstAgainstMedian;
        totalBoostAgainstMin += bstAgainstMinimum;
        count++;
        if (maxBoostAgainstMedian < bstAgainstMedian) {
          maxBoostAgainstMedian = bstAgainstMedian;
        }
        if (maxBoostAgainstMin < bstAgainstMinimum) {
          maxBoostAgainstMin = bstAgainstMinimum;
        }
        // per-model aggregation
        totalBoostAgainstMinCertainModel += bstAgainstMinimum;
        totalBoostAgainstMedianCertainModel += bstAgainstMedian;
        if (maxBoostAgainstMinCertainModel < bstAgainstMinimum) {
          maxBoostAgainstMinCertainModel = bstAgainstMinimum;
        }
        if (maxBoostAgainstMedianCertainModel < bstAgainstMedian) {
          maxBoostAgainstMedianCertainModel = bstAgainstMedian;
        }
        countOfEachModel++;
      }
    }
    LOG.info("Model:{}, The best performance boost against median value is "
        + "{}", model, maxBoostAgainstMedianCertainModel);
    LOG.info("Model:{}, The aggregated average performance boost against "
        + "median value is {}",
        model, totalBoostAgainstMedianCertainModel/countOfEachModel);
    LOG.info("Model:{}, The best performance boost against min value is {}",
        model, maxBoostAgainstMinCertainModel);
    LOG.info("Model:{}, The aggregated average performance boost against "
        + "min value is {}",
        model, totalBoostAgainstMinCertainModel/countOfEachModel);
  }
  LOG.info("For all, the best performance boost against median value is {}",
      maxBoostAgainstMedian);
  LOG.info("For all, the aggregated average performance boost against median "
      + "value is {}", totalBoostAgainstMedian/count);
  LOG.info("For all, the best performance boost against min value is {}",
      maxBoostAgainstMin);
  LOG.info("For all, the aggregated average performance boost against min "
      + "value is {}", totalBoostAgainstMin/count);
}
/**
 * Lets the plugin allocate <code>gpuCount</code> GPUs from
 * <code>allDevice</code> and looks up, in the report, the imagePerSecond
 * boost of that allocation relative to the median value for the same
 * model, batch size and GPU count.
 *
 * @return the boost against the median value
 * */
private float calculatePerformanceBoostAgainstMedian(
    ActualPerformanceReport report,
    String model, int bs, int gpuCount,
    NvidiaGPUPluginForRuntimeV2 plugin, Set<Device> allDevice,
    Map<String, String> env) {
  // Slot of "wanted imagePerSecond / median - 1" in the metrics array.
  final int boostAgainstMedianSlot = 7;
  String chosenGpus = convertAllocationToGpuString(
      plugin.allocateDevices(allDevice, gpuCount, env));
  float[] imagePerSecondMetrics = report.getVariousImagePerSecond(
      model, bs, gpuCount, chosenGpus);
  return imagePerSecondMetrics[boostAgainstMedianSlot];
}
/**
 * Lets the plugin allocate <code>gpuCount</code> GPUs from
 * <code>allDevice</code> and looks up, in the report, the imagePerSecond
 * boost of that allocation relative to the minimum value for the same
 * model, batch size and GPU count.
 *
 * @return the boost against the minimum value
 * */
private float calculatePerformanceBoostAgainstMinimum(
    ActualPerformanceReport report,
    String model, int bs, int gpuCount,
    NvidiaGPUPluginForRuntimeV2 plugin, Set<Device> allDevice,
    Map<String, String> env) {
  // Slot of "wanted imagePerSecond / min - 1" in the metrics array.
  final int boostAgainstMinimumSlot = 5;
  String chosenGpus = convertAllocationToGpuString(
      plugin.allocateDevices(allDevice, gpuCount, env));
  float[] imagePerSecondMetrics = report.getVariousImagePerSecond(
      model, bs, gpuCount, chosenGpus);
  return imagePerSecondMetrics[boostAgainstMinimumSlot];
}
/**
 * Joins the minor numbers of the allocated devices with "_" in the set's
 * iteration order, e.g. "0_1_3". This is the form that is compared against
 * the GPU combination column of the performance report.
 *
 * @param allocation devices returned by the scheduler
 * @return the joined minor numbers, or an empty string for an empty set
 * */
private String convertAllocationToGpuString(Set<Device> allocation) {
  StringBuilder sb = new StringBuilder();
  for (Device device : allocation) {
    // put the separator only between elements, so there is no trailing
    // "_" to strip and an empty allocation no longer throws
    if (sb.length() > 0) {
      sb.append('_');
    }
    sb.append(device.getMinorNumber());
  }
  return sb.toString();
}
/**
* Representation of the performance data report.
* */
private class ActualPerformanceReport {
private ArrayList<DataRecord> dataSet = new ArrayList<>();
public ArrayList<DataRecord> getDataSet() {
return dataSet;
}
/**
* One line in the report.
* */
private class DataRecord {
DataRecord(String model, int bs, String combination, float fps,
int count) {
this.batchSize = bs;
this.gpuCombination = combination;
this.gpuCount = count;
this.model = model;
this.imagePerSecond = fps;
}
public String getModel() {
return model;
}
public int getBatchSize() {
return batchSize;
}
public String getGpuCombination() {
return gpuCombination;
}
public float getImagePerSecond() {
return imagePerSecond;
}
public int getGpuCount() {
return gpuCount;
}
private String model;
private int batchSize;
private String gpuCombination;
private float imagePerSecond;
private int gpuCount;
}
/**
* The file is a real performance report got from a 8 GPUs AWS instance.
* It contains every combination GPUs' training performance of Tensorflow
* benchmark.
* The columns are the model name, batch size, gpu ids and imagesPerSecond
* */
public void readFromFile() {
String csvReportFilePath = getClass().getClassLoader()
.getResource("tensorflow-bench-result-for-GPU.csv").getFile();
BufferedReader br = null;
String line = "";
try {
br = new BufferedReader(new FileReader(csvReportFilePath));
String model;
int batchSize;
String gpuCombination;
float imagePerSecond;
int gpuCount;
while ((line = br.readLine()) != null) {
// skip the licence content
if (line.startsWith("#")) {
continue;
}
String[] tokens = line.replaceAll("\"", "").split(",");
if (tokens.length != 4) {
LOG.error("unexpected performance data format!");
break;
}
model = tokens[0];
batchSize = Integer.parseInt(tokens[1].trim());
gpuCombination = tokens[2];
imagePerSecond = Float.parseFloat(tokens[3]);
gpuCount = getGpuCount(gpuCombination);
this.dataSet.add(new DataRecord(model, batchSize, gpuCombination,
imagePerSecond, gpuCount));
}
} catch (Exception e) {
e.printStackTrace();
} finally {
if (br != null) {
try {
br.close();
} catch (IOException e) {
e.printStackTrace();
}
}
} // end finally
}
/**
* Return the maximum, minimum, mean and median performance for model &
* bs & gpuCount. And the imagePerSecond for model & bs & gpuCount &
* gpuCombinations. And imagePerSecond performance boost comparing to
* minimum, mean and media value.
* */
private float[] getVariousImagePerSecond(String model, int bs,
int gpuCount, String gpuCombinations) {
float[] result = new float[8];
float max = 0;
float min = Float.MAX_VALUE;
float sum = 0;
int count = 0;
float wantedImagePerSecond = 0;
float currentImagePerSecond;
ArrayList<Float> allFps = new ArrayList<>();
for (DataRecord dr : getDataSet()) {
currentImagePerSecond = dr.getImagePerSecond();
if (dr.batchSize == bs
&& model.equals(dr.getModel())
&& gpuCount == dr.getGpuCount()) {
sum += currentImagePerSecond;
count++;
if (max < currentImagePerSecond) {
max = currentImagePerSecond;
}
if (min > currentImagePerSecond) {
min = currentImagePerSecond;
}
if (gpuCombinations.equals(dr.getGpuCombination())) {
wantedImagePerSecond = dr.getImagePerSecond();
}
allFps.add(dr.getImagePerSecond());
}
}
float median = getMedian(allFps);
float mean = sum/count;
result[0] = max;
result[1] = min;
result[2] = mean;
result[3] = median;
result[4] = wantedImagePerSecond;
result[5] = wantedImagePerSecond/min - 1;
result[6] = wantedImagePerSecond/mean - 1;
result[7] = wantedImagePerSecond/median - 1;
return result;
}
private float getMedian(ArrayList<Float> allFps) {
float[] all = ArrayUtils.toPrimitive(allFps.toArray(new Float[0]), 0);
Arrays.sort(all);
float median;
int size = all.length;
if (allFps.size() % 2 == 0) {
median = (all[size/2] + all[size/2 - 1])/2;
} else {
median = all[size/2];
}
return median;
}
private int getGpuCount(String gpuCombination) {
String[] tokens = gpuCombination.split("_");
return tokens.length;
}
}
}

View File

@ -20,6 +20,7 @@
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.*; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.*;
import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.TreeSet; import java.util.TreeSet;
@ -62,7 +63,7 @@ public void onDevicesReleased(Set<Device> allocatedDevices) {
@Override @Override
public Set<Device> allocateDevices(Set<Device> availableDevices, public Set<Device> allocateDevices(Set<Device> availableDevices,
int count) { int count, Map<String, String> env) {
Set<Device> allocated = new TreeSet<Device>(); Set<Device> allocated = new TreeSet<Device>();
int number = 0; int number = 0;
for (Device d : availableDevices) { for (Device d : availableDevices) {

View File

@ -74,10 +74,11 @@
import java.util.TreeSet; import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.anySet;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.isA;
@ -571,7 +572,7 @@ public void testPreferPluginScheduler() throws IOException, YarnException {
adapter.getDeviceResourceHandler().preStart(c1); adapter.getDeviceResourceHandler().preStart(c1);
// Use customized scheduler // Use customized scheduler
verify(spyPlugin, times(1)).allocateDevices( verify(spyPlugin, times(1)).allocateDevices(
any(Set.class), anyInt()); anySet(), anyInt(), anyMap());
Assert.assertEquals(2, Assert.assertEquals(2,
dmm.getAvailableDevices(resourceName)); dmm.getAvailableDevices(resourceName));
Assert.assertEquals(1, Assert.assertEquals(1,
@ -995,7 +996,7 @@ public void onDevicesReleased(Set<Device> releasedDevices) {
@Override @Override
public Set<Device> allocateDevices(Set<Device> availableDevices, public Set<Device> allocateDevices(Set<Device> availableDevices,
int count) { int count, Map<String, String> env) {
Set<Device> allocated = new TreeSet<>(); Set<Device> allocated = new TreeSet<>();
int number = 0; int number = 0;
for (Device d : availableDevices) { for (Device d : availableDevices) {

View File

@ -1,108 +0,0 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.nvidia.com;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia.NvidiaGPUPluginForRuntimeV2;
import org.junit.Assert;
import org.junit.Test;
import java.util.Set;
import java.util.TreeSet;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Test case for Nvidia GPU device plugin.
 * */
public class TestNvidiaGpuPlugin {

  /**
   * Build a healthy GPU device whose id, minor number and
   * <code>/dev/nvidiaN</code> path all use the same index, with major
   * number 195.
   * */
  private static Device newGpuDevice(int index, String busId) {
    return Device.Builder.newInstance()
        .setId(index).setHealthy(true)
        .setBusID(busId)
        .setDevPath("/dev/nvidia" + index)
        .setMajorNumber(195)
        .setMinorNumber(index).build();
  }

  /**
   * The devices reported by the plugin must match the mocked shell output:
   * two GPUs whose major:minor info is "c3:0" and "c3:1".
   * */
  @Test
  public void testGetNvidiaDevices() throws Exception {
    NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor mockShell =
        mock(NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor.class);
    String deviceInfoShellOutput =
        "0, 00000000:04:00.0\n" +
        "1, 00000000:82:00.0";
    String majorMinorNumber0 = "c3:0";
    String majorMinorNumber1 = "c3:1";
    when(mockShell.getDeviceInfo()).thenReturn(deviceInfoShellOutput);
    when(mockShell.getMajorMinorInfo("nvidia0"))
        .thenReturn(majorMinorNumber0);
    when(mockShell.getMajorMinorInfo("nvidia1"))
        .thenReturn(majorMinorNumber1);
    NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
    plugin.setShellExecutor(mockShell);
    plugin.setPathOfGpuBinary("/fake/nvidia-smi");

    Set<Device> expectedDevices = new TreeSet<>();
    expectedDevices.add(newGpuDevice(0, "00000000:04:00.0"));
    expectedDevices.add(newGpuDevice(1, "00000000:82:00.0"));
    Set<Device> devices = plugin.getDevices();
    Assert.assertEquals(expectedDevices, devices);
  }

  /**
   * The default runtime gets no runtime spec for an empty allocation; the
   * docker runtime gets the "nvidia" container runtime and the allocated
   * GPUs listed in NVIDIA_VISIBLE_DEVICES.
   * */
  @Test
  public void testOnDeviceAllocated() throws Exception {
    NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
    Set<Device> allocatedDevices = new TreeSet<>();

    DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices,
        YarnRuntimeType.RUNTIME_DEFAULT);
    Assert.assertNull(spec);

    // allocate one device
    allocatedDevices.add(newGpuDevice(0, "00000000:04:00.0"));
    spec = plugin.onDevicesAllocated(allocatedDevices,
        YarnRuntimeType.RUNTIME_DOCKER);
    Assert.assertEquals("nvidia", spec.getContainerRuntime());
    Assert.assertEquals("0", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));

    // two devices allowed; the second GPU carries id 1, matching the
    // devices built in testGetNvidiaDevices
    allocatedDevices.add(newGpuDevice(1, "00000000:82:00.0"));
    spec = plugin.onDevicesAllocated(allocatedDevices,
        YarnRuntimeType.RUNTIME_DOCKER);
    Assert.assertEquals("nvidia", spec.getContainerRuntime());
    Assert.assertEquals("0,1", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));
  }
}