YARN-8821. [YARN-8851] GPU hierarchy/topology scheduling support based on pluggable device framework. Contributed by Zhankun Tang.
This commit is contained in:
parent
106bdc6c04
commit
dddcfa4d9f
@ -18,6 +18,7 @@
|
|||||||
|
|
||||||
package org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin;
|
package org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -29,10 +30,15 @@ public interface DevicePluginScheduler {
|
|||||||
* Called when allocating devices. The framework will do all device book
|
* Called when allocating devices. The framework will do all device book
|
||||||
* keeping and fail recovery. So this hook could be stateless and only do
|
* keeping and fail recovery. So this hook could be stateless and only do
|
||||||
* scheduling based on available devices passed in. It could be
|
* scheduling based on available devices passed in. It could be
|
||||||
* invoked multiple times by the framework.
|
* invoked multiple times by the framework. The hint in environment variables
|
||||||
|
* passed in could be potentially used in making better scheduling decision.
|
||||||
|
* For instance, GPU scheduling might support different kinds of policies. The
|
||||||
|
* container can set it through environment variables.
|
||||||
* @param availableDevices Devices allowed to be chosen from.
|
* @param availableDevices Devices allowed to be chosen from.
|
||||||
* @param count Number of device to be allocated.
|
* @param count Number of device to be allocated.
|
||||||
|
* @param env Environment variables of the container.
|
||||||
* @return A set of {@link Device} allocated
|
* @return A set of {@link Device} allocated
|
||||||
* */
|
* */
|
||||||
Set<Device> allocateDevices(Set<Device> availableDevices, int count);
|
Set<Device> allocateDevices(Set<Device> availableDevices, int count,
|
||||||
|
Map<String, String> env);
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,7 @@
|
|||||||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||||
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
|
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
|
||||||
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
|
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
|
||||||
|
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
|
||||||
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
|
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
|
||||||
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
|
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
|
||||||
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
|
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
|
||||||
@ -32,7 +33,12 @@
|
|||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.LinkedList;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.TreeSet;
|
import java.util.TreeSet;
|
||||||
@ -40,8 +46,10 @@
|
|||||||
/**
|
/**
|
||||||
* Nvidia GPU plugin supporting both Nvidia container runtime v2 for Docker and
|
* Nvidia GPU plugin supporting both Nvidia container runtime v2 for Docker and
|
||||||
* non-Docker container.
|
* non-Docker container.
|
||||||
|
* It supports topology-aware scheduling as well as simple scheduling.
|
||||||
* */
|
* */
|
||||||
public class NvidiaGPUPluginForRuntimeV2 implements DevicePlugin {
|
public class NvidiaGPUPluginForRuntimeV2 implements DevicePlugin,
|
||||||
|
DevicePluginScheduler {
|
||||||
public static final Logger LOG = LoggerFactory.getLogger(
|
public static final Logger LOG = LoggerFactory.getLogger(
|
||||||
NvidiaGPUPluginForRuntimeV2.class);
|
NvidiaGPUPluginForRuntimeV2.class);
|
||||||
|
|
||||||
@ -69,6 +77,47 @@ public class NvidiaGPUPluginForRuntimeV2 implements DevicePlugin {
|
|||||||
private static final Set<String> DEFAULT_BINARY_SEARCH_DIRS = ImmutableSet.of(
|
private static final Set<String> DEFAULT_BINARY_SEARCH_DIRS = ImmutableSet.of(
|
||||||
"/usr/bin", "/bin", "/usr/local/nvidia/bin");
|
"/usr/bin", "/bin", "/usr/local/nvidia/bin");
|
||||||
|
|
||||||
|
private boolean topoInitialized = false;
|
||||||
|
|
||||||
|
private Set<Device> lastTimeFoundDevices;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* It caches the combination of different devices and the communication cost.
|
||||||
|
* The key is device count
|
||||||
|
* The value is an ordered list of map entry whose key is device combination,
|
||||||
|
* value is cost. The list is sorted by cost in ascending order.
|
||||||
|
* For instance:
|
||||||
|
* { 2=> [[device1,device2]=>0, [device1,device3]=>10]
|
||||||
|
* 3 => [[device1,device2,device3]=>10, [device2,device3,device5]=>20],
|
||||||
|
* }
|
||||||
|
* */
|
||||||
|
private Map<Integer, List<Map.Entry<Set<Device>, Integer>>> costTable
|
||||||
|
= new HashMap<>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The key is a pair of minors. For instance, "0-1" indicates 0 to 1
|
||||||
|
* The value is weight between the two devices.
|
||||||
|
* */
|
||||||
|
private Map<String, Integer> devicePairToWeight = new HashMap<>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The container can set this environment variable.
|
||||||
|
* It tells the scheduler which policy to use when scheduling.
|
||||||
|
* */
|
||||||
|
public static final String TOPOLOGY_POLICY_ENV_KEY = "NVIDIA_TOPO_POLICY";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Schedule policy that prefers the faster GPU-GPU communication.
|
||||||
|
* Suitable for heavy GPU computation workload generally.
|
||||||
|
* */
|
||||||
|
public static final String TOPOLOGY_POLICY_PACK = "PACK";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Schedule policy that prefers the faster CPU-GPU communication.
|
||||||
|
* Suitable for heavy CPU-GPU IO operations generally.
|
||||||
|
* */
|
||||||
|
public static final String TOPOLOGY_POLICY_SPREAD = "SPREAD";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DeviceRegisterRequest getRegisterRequestInfo() throws Exception {
|
public DeviceRegisterRequest getRegisterRequestInfo() throws Exception {
|
||||||
return DeviceRegisterRequest.Builder.newInstance()
|
return DeviceRegisterRequest.Builder.newInstance()
|
||||||
@ -106,6 +155,8 @@ public Set<Device> getDevices() throws Exception {
|
|||||||
id++;
|
id++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// cache the discovered devices to help topology scheduling
|
||||||
|
lastTimeFoundDevices = r;
|
||||||
return r;
|
return r;
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
if (LOG.isDebugEnabled()) {
|
if (LOG.isDebugEnabled()) {
|
||||||
@ -170,6 +221,422 @@ private String getMajorNumber(String devName) {
|
|||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Set<Device> allocateDevices(Set<Device> availableDevices, int count,
|
||||||
|
Map<String, String> envs) {
|
||||||
|
Set<Device> allocation = new TreeSet<>();
|
||||||
|
/**
|
||||||
|
* corner cases.
|
||||||
|
* if allocate 1 device or all devices, no topo scheduling needed.
|
||||||
|
* if total available devices is less than 3, no topo scheduling needed.
|
||||||
|
* */
|
||||||
|
if (availableDevices.size() < 3
|
||||||
|
|| count == 1
|
||||||
|
|| availableDevices.size() == count) {
|
||||||
|
basicSchedule(allocation, count, availableDevices);
|
||||||
|
return allocation;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (!topoInitialized) {
|
||||||
|
initCostTable();
|
||||||
|
}
|
||||||
|
// topology aware scheduling
|
||||||
|
topologyAwareSchedule(allocation, count,
|
||||||
|
envs, availableDevices, this.costTable);
|
||||||
|
if (allocation.size() == count) {
|
||||||
|
return allocation;
|
||||||
|
} else {
|
||||||
|
LOG.error("Failed to do topology scheduling. Skip to use basic "
|
||||||
|
+ "scheduling");
|
||||||
|
}
|
||||||
|
} catch (IOException e) {
|
||||||
|
LOG.error("Error in getting GPU topology info. "
|
||||||
|
+ "Skip topology aware scheduling", e);
|
||||||
|
}
|
||||||
|
// basic scheduling
|
||||||
|
basicSchedule(allocation, count, availableDevices);
|
||||||
|
return allocation;
|
||||||
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
public void initCostTable() throws IOException {
|
||||||
|
// get topology
|
||||||
|
String topo = shellExecutor.getTopologyInfo();
|
||||||
|
// build the graph
|
||||||
|
parseTopo(topo, devicePairToWeight);
|
||||||
|
// build the cost table of different device combinations
|
||||||
|
if (lastTimeFoundDevices == null) {
|
||||||
|
try {
|
||||||
|
getDevices();
|
||||||
|
} catch (Exception e) {
|
||||||
|
LOG.error("Failed to get devices!", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
buildCostTable(costTable, lastTimeFoundDevices);
|
||||||
|
loggingCostTable(costTable);
|
||||||
|
this.topoInitialized = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void loggingCostTable(
|
||||||
|
Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable) {
|
||||||
|
if (LOG.isDebugEnabled()) {
|
||||||
|
StringBuilder sb = new StringBuilder("The costTable is:");
|
||||||
|
sb.append("\n{");
|
||||||
|
for (Map.Entry<Integer, List<Map.Entry<Set<Device>, Integer>>> entry
|
||||||
|
: cTable.entrySet()) {
|
||||||
|
sb.append("\n\t")
|
||||||
|
.append(entry.getKey())
|
||||||
|
.append(" => [");
|
||||||
|
for (Map.Entry<Set<Device>, Integer> e : entry.getValue()) {
|
||||||
|
sb.append("\n\t\t").append(e.toString()).append(",\n");
|
||||||
|
}
|
||||||
|
sb.append("\t\t]\n");
|
||||||
|
}
|
||||||
|
sb.append("}\n");
|
||||||
|
LOG.debug(sb.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate combination of devices and its cost.
|
||||||
|
* costTable
|
||||||
|
* */
|
||||||
|
private void buildCostTable(
|
||||||
|
Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable,
|
||||||
|
Set<Device> ltfDevices) {
|
||||||
|
Device[] deviceList = new Device[ltfDevices.size()];
|
||||||
|
ltfDevices.toArray(deviceList);
|
||||||
|
generateAllDeviceCombination(cTable, deviceList, deviceList.length);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* For every possible combination of i elements.
|
||||||
|
* We generate a map whose key is the combination, value is cost.
|
||||||
|
*/
|
||||||
|
private void generateAllDeviceCombination(
|
||||||
|
Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable,
|
||||||
|
Device[] allDevices, int n) {
|
||||||
|
// allocated devices count range from 1 to n-1
|
||||||
|
for (int i = 2; i < n; i++) {
|
||||||
|
Map<Set<Device>, Integer> combinationToCost =
|
||||||
|
new HashMap<>();
|
||||||
|
buildCombination(combinationToCost, allDevices, n, i);
|
||||||
|
// sort the map entry by cost ascending order
|
||||||
|
List<Map.Entry<Set<Device>, Integer>> listSortedByCost =
|
||||||
|
new LinkedList<>(combinationToCost.entrySet());
|
||||||
|
Collections.sort(listSortedByCost,
|
||||||
|
(o1, o2) -> (o1.getValue()).compareTo(o2.getValue()));
|
||||||
|
cTable.put(i, listSortedByCost);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void buildCombination(Map<Set<Device>, Integer> combinationToCost,
|
||||||
|
Device[] allDevices, int n, int r) {
|
||||||
|
// A temporary list to store all combination one by one
|
||||||
|
Device[] subDeviceList = new Device[r];
|
||||||
|
combinationRecursive(combinationToCost, allDevices, subDeviceList,
|
||||||
|
0, n - 1, 0, r);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Populate combination to cost map recursively.
|
||||||
|
*
|
||||||
|
* @param cTc combinationToCost map.
|
||||||
|
* The key is device set, the value is cost
|
||||||
|
* @param allDevices all devices used to assign value to subDevicelist
|
||||||
|
* @param subDeviceList store a subset of devices temporary
|
||||||
|
* @param start start index in the allDevices
|
||||||
|
* @param end last index in the allDevices
|
||||||
|
* @param index dynamic index in subDeviceList need to be assigned
|
||||||
|
* @param r the length of the subDeviceList
|
||||||
|
*/
|
||||||
|
void combinationRecursive(Map<Set<Device>, Integer> cTc,
|
||||||
|
Device[] allDevices, Device[] subDeviceList,
|
||||||
|
int start, int end, int index, int r) {
|
||||||
|
// sub device list's length is ready to compute the cost
|
||||||
|
if (index == r) {
|
||||||
|
Set<Device> oneSet = new TreeSet<>(Arrays.asList(subDeviceList));
|
||||||
|
int cost = computeCostOfDevices(subDeviceList);
|
||||||
|
cTc.put(oneSet, cost);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (int i = start; i <= end; i++) {
|
||||||
|
subDeviceList[index] = allDevices[i];
|
||||||
|
combinationRecursive(cTc, allDevices, subDeviceList,
|
||||||
|
i + 1, end, index + 1, r);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The cost function used to calculate costs of a sub set of devices.
|
||||||
|
* It calculate link weight of each pair in non-duplicated combination of
|
||||||
|
* devices.
|
||||||
|
*/
|
||||||
|
@VisibleForTesting
|
||||||
|
public int computeCostOfDevices(Device[] devices) {
|
||||||
|
int cost = 0;
|
||||||
|
String gpuIndex0;
|
||||||
|
String gpuIndex1;
|
||||||
|
for (int i = 0; i < devices.length; i++) {
|
||||||
|
gpuIndex0 = String.valueOf(devices[i].getMinorNumber());
|
||||||
|
for (int j = i + 1; j < devices.length; j++) {
|
||||||
|
gpuIndex1 = String.valueOf(devices[j].getMinorNumber());
|
||||||
|
cost += this.devicePairToWeight.get(gpuIndex0 + "-" + gpuIndex1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cost;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Topology Aware schedule algorithm.
|
||||||
|
* It doesn't consider CPU affinity or NUMA or bus bandwidths.
|
||||||
|
* It support two plicy: "spread" and "pack" which can be set by container's
|
||||||
|
* environment variable. Use pack by default which means prefer the faster
|
||||||
|
* GPU-GPU. "Spread" means prefer the faster CPU-GPU.
|
||||||
|
* It can potentially be extend to take GPU attribute like GPU chip memory
|
||||||
|
* into consideration.
|
||||||
|
* */
|
||||||
|
@VisibleForTesting
|
||||||
|
public void topologyAwareSchedule(Set<Device> allocation, int count,
|
||||||
|
Map<String, String> envs,
|
||||||
|
Set<Device> availableDevices,
|
||||||
|
Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable) {
|
||||||
|
int num = 0;
|
||||||
|
String policy = envs.get(TOPOLOGY_POLICY_ENV_KEY);
|
||||||
|
if (policy == null) {
|
||||||
|
policy = TOPOLOGY_POLICY_PACK;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get combinations from costTable given the count of device want to
|
||||||
|
* allocate.
|
||||||
|
* */
|
||||||
|
if (cTable == null) {
|
||||||
|
LOG.error("No cost table initialized!");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
List<Map.Entry<Set<Device>, Integer>> combinationsToCost =
|
||||||
|
cTable.get(count);
|
||||||
|
Iterator<Map.Entry<Set<Device>, Integer>> iterator =
|
||||||
|
combinationsToCost.iterator();
|
||||||
|
// the container needs spread policy
|
||||||
|
if (policy.equalsIgnoreCase(TOPOLOGY_POLICY_SPREAD)) {
|
||||||
|
// loop from high cost to low cost
|
||||||
|
iterator = ((LinkedList) combinationsToCost).descendingIterator();
|
||||||
|
}
|
||||||
|
while (iterator.hasNext()) {
|
||||||
|
Map.Entry<Set<Device>, Integer> element = iterator.next();
|
||||||
|
if (availableDevices.containsAll(element.getKey())) {
|
||||||
|
allocation.addAll(element.getKey());
|
||||||
|
LOG.info("Topology scheduler allocated: " + allocation);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG.error("Unknown error happened in topology scheduler");
|
||||||
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
public void basicSchedule(Set<Device> allocation, int count,
|
||||||
|
Set<Device> availableDevices) {
|
||||||
|
// Basic scheduling
|
||||||
|
// allocate all available
|
||||||
|
if (count == availableDevices.size()) {
|
||||||
|
allocation.addAll(availableDevices);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
int number = 0;
|
||||||
|
for (Device d : availableDevices) {
|
||||||
|
allocation.add(d);
|
||||||
|
number++;
|
||||||
|
if (number == count) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A typical sample topo output:
|
||||||
|
* GPU0 GPU1 GPU2 GPU3 CPU Affinity
|
||||||
|
* GPU0 X PHB SOC SOC 0-31
|
||||||
|
* GPU1 PHB X SOC SOC 0-31
|
||||||
|
* GPU2 SOC SOC X PHB 0-31
|
||||||
|
* GPU3 SOC SOC PHB X 0-31
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* Legend:
|
||||||
|
*
|
||||||
|
* X = Self
|
||||||
|
* SOC = Connection traversing PCIe as well as the SMP link between
|
||||||
|
* CPU sockets(e.g. QPI)
|
||||||
|
* PHB = Connection traversing PCIe as well as a PCIe Host Bridge
|
||||||
|
* (typically the CPU)
|
||||||
|
* PXB = Connection traversing multiple PCIe switches
|
||||||
|
* (without traversing the PCIe Host Bridge)
|
||||||
|
* PIX = Connection traversing a single PCIe switch
|
||||||
|
* NV# = Connection traversing a bonded set of # NVLinks」
|
||||||
|
* */
|
||||||
|
public void parseTopo(String topo,
|
||||||
|
Map<String, Integer> deviceLinkToWeight) {
|
||||||
|
String[] lines = topo.split("\n");
|
||||||
|
int rowMinor;
|
||||||
|
int colMinor;
|
||||||
|
String legend;
|
||||||
|
String tempType;
|
||||||
|
for (String oneLine : lines) {
|
||||||
|
oneLine = oneLine.trim();
|
||||||
|
if (oneLine.isEmpty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// To the end. No more metrics info
|
||||||
|
if (oneLine.startsWith("Legend")) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// Skip header
|
||||||
|
if (oneLine.contains("Affinity")) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String[] tokens = oneLine.split(("\\s+"));
|
||||||
|
String name = tokens[0];
|
||||||
|
rowMinor = Integer.parseInt(name.substring(name.lastIndexOf("U") + 1));
|
||||||
|
for (int i = 1; i < tokens.length; i++) {
|
||||||
|
tempType = tokens[i];
|
||||||
|
colMinor = i - 1;
|
||||||
|
// self, skip
|
||||||
|
if (tempType.equals("X")) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tempType.equals("SOC") || tempType.equals("SYS")) {
|
||||||
|
populateGraphEdgeWeight(DeviceLinkType.P2PLinkCrossCPUSocket,
|
||||||
|
rowMinor, colMinor, deviceLinkToWeight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tempType.equals("PHB") || tempType.equals("NODE")) {
|
||||||
|
populateGraphEdgeWeight(DeviceLinkType.P2PLinkSameCPUSocket,
|
||||||
|
rowMinor, colMinor, deviceLinkToWeight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tempType.equals("PXB")) {
|
||||||
|
populateGraphEdgeWeight(DeviceLinkType.P2PLinkMultiSwitch,
|
||||||
|
rowMinor, colMinor, deviceLinkToWeight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tempType.equals("PIX")) {
|
||||||
|
populateGraphEdgeWeight(DeviceLinkType.P2PLinkSingleSwitch,
|
||||||
|
rowMinor, colMinor, deviceLinkToWeight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tempType.equals("NV1")) {
|
||||||
|
populateGraphEdgeWeight(DeviceLinkType.P2PLinkNVLink1,
|
||||||
|
rowMinor, colMinor, deviceLinkToWeight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tempType.equals("NV2")) {
|
||||||
|
populateGraphEdgeWeight(DeviceLinkType.P2PLinkNVLink2,
|
||||||
|
rowMinor, colMinor, deviceLinkToWeight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tempType.equals("NV3")) {
|
||||||
|
populateGraphEdgeWeight(DeviceLinkType.P2PLinkNVLink3,
|
||||||
|
rowMinor, colMinor, deviceLinkToWeight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tempType.equals("NV4")) {
|
||||||
|
populateGraphEdgeWeight(DeviceLinkType.P2PLinkNVLink4,
|
||||||
|
rowMinor, colMinor, deviceLinkToWeight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tempType.equals("NV5")) {
|
||||||
|
populateGraphEdgeWeight(DeviceLinkType.P2PLinkNVLink5,
|
||||||
|
rowMinor, colMinor, deviceLinkToWeight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tempType.equals("NV6")) {
|
||||||
|
populateGraphEdgeWeight(DeviceLinkType.P2PLinkNVLink6,
|
||||||
|
rowMinor, colMinor, deviceLinkToWeight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tempType.equals("NV7")) {
|
||||||
|
populateGraphEdgeWeight(DeviceLinkType.P2PLinkNVLink7,
|
||||||
|
rowMinor, colMinor, deviceLinkToWeight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tempType.equals("NV8")) {
|
||||||
|
populateGraphEdgeWeight(DeviceLinkType.P2PLinkNVLink8,
|
||||||
|
rowMinor, colMinor, deviceLinkToWeight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tempType.equals("NV9")) {
|
||||||
|
populateGraphEdgeWeight(DeviceLinkType.P2PLinkNVLink9,
|
||||||
|
rowMinor, colMinor, deviceLinkToWeight);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} // end one line handling
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void populateGraphEdgeWeight(
|
||||||
|
DeviceLinkType linkType,
|
||||||
|
int leftVertex,
|
||||||
|
int rightVertex,
|
||||||
|
Map<String, Integer> deviceLinkToWeight) {
|
||||||
|
deviceLinkToWeight.put(leftVertex + "-" + rightVertex,
|
||||||
|
linkType.getWeight());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * The kinds of link that can connect two devices.
 * Each link carries a relative weight: the higher the weight, the higher
 * the cost of communication between the two GPUs.
 *
 * NOTE(review): the PCIe switch links are weighted higher than the CPU
 * socket links - confirm this ordering is intended.
 * */
public enum DeviceLinkType {
  /**
   * For Nvidia GPU NVLink. The digit is the number of bonded links.
   * */
  P2PLinkNVLink9(10),
  P2PLinkNVLink8(20),
  P2PLinkNVLink7(30),
  P2PLinkNVLink6(40),
  P2PLinkNVLink5(50),
  P2PLinkNVLink4(60),
  P2PLinkNVLink3(70),
  P2PLinkNVLink2(80),
  P2PLinkNVLink1(90),

  /**
   * Connected to the same CPU (same NUMA node).
   * */
  P2PLinkSameCPUSocket(200),

  /**
   * Crosses CPUs through a socket-level link (e.g. QPI).
   * Usually crosses NUMA nodes.
   * */
  P2PLinkCrossCPUSocket(300),

  /**
   * Only one PCIe switch has to be traversed.
   * */
  P2PLinkSingleSwitch(600),

  /**
   * Multiple PCIe switches have to be traversed.
   * */
  P2PLinkMultiSwitch(1200);

  // A higher weight means slower communication.
  private final int weight;

  DeviceLinkType(int w) {
    this.weight = w;
  }

  /**
   * @return the relative weight of this link type
   */
  public int getWeight() {
    return this.weight;
  }
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A shell wrapper class easy for test.
|
* A shell wrapper class easy for test.
|
||||||
* */
|
* */
|
||||||
@ -189,6 +656,13 @@ public String getMajorMinorInfo(String devName) throws IOException {
|
|||||||
return shexec.getOutput();
|
return shexec.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the topology metrics info from nvdia-smi
|
||||||
|
public String getTopologyInfo() throws IOException {
|
||||||
|
return Shell.execCommand(environment,
|
||||||
|
new String[]{pathOfGpuBinary, "topo",
|
||||||
|
"-m"}, MAX_EXEC_TIMEOUT_MS);
|
||||||
|
}
|
||||||
|
|
||||||
public void searchBinary() throws Exception {
|
public void searchBinary() throws Exception {
|
||||||
if (pathOfGpuBinary != null) {
|
if (pathOfGpuBinary != null) {
|
||||||
LOG.info("Skip searching, the nvidia gpu binary is already set: "
|
LOG.info("Skip searching, the nvidia gpu binary is already set: "
|
||||||
@ -228,8 +702,8 @@ public void searchBinary() throws Exception {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
public void setPathOfGpuBinary(String pathOfGpuBinary) {
|
public void setPathOfGpuBinary(String pOfGpuBinary) {
|
||||||
this.pathOfGpuBinary = pathOfGpuBinary;
|
this.pathOfGpuBinary = pOfGpuBinary;
|
||||||
}
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
@ -237,4 +711,20 @@ public void setShellExecutor(
|
|||||||
NvidiaCommandExecutor shellExecutor) {
|
NvidiaCommandExecutor shellExecutor) {
|
||||||
this.shellExecutor = shellExecutor;
|
this.shellExecutor = shellExecutor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
public boolean isTopoInitialized() {
|
||||||
|
return topoInitialized;
|
||||||
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
public Map<Integer, List<Map.Entry<Set<Device>, Integer>>> getCostTable() {
|
||||||
|
return costTable;
|
||||||
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
public Map<String, Integer> getDevicePairToWeight() {
|
||||||
|
return devicePairToWeight;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@
|
|||||||
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
|
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
|
import com.google.common.collect.ImmutableMap;
|
||||||
import com.google.common.collect.ImmutableSet;
|
import com.google.common.collect.ImmutableSet;
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
@ -192,7 +193,7 @@ private synchronized DeviceAllocation internalAssignDevices(
|
|||||||
DevicePluginScheduler dps = devicePluginSchedulers.get(resourceName);
|
DevicePluginScheduler dps = devicePluginSchedulers.get(resourceName);
|
||||||
// Prefer DevicePluginScheduler logic
|
// Prefer DevicePluginScheduler logic
|
||||||
pickAndDoSchedule(allowedDevices, usedDevices, assignedDevices,
|
pickAndDoSchedule(allowedDevices, usedDevices, assignedDevices,
|
||||||
containerId, requestedDeviceCount, resourceName, dps);
|
container, requestedDeviceCount, resourceName, dps);
|
||||||
|
|
||||||
// Record in state store if we allocated anything
|
// Record in state store if we allocated anything
|
||||||
if (!assignedDevices.isEmpty()) {
|
if (!assignedDevices.isEmpty()) {
|
||||||
@ -310,9 +311,11 @@ private long getReleasingDevices(String resourceName) {
|
|||||||
* */
|
* */
|
||||||
private void pickAndDoSchedule(Set<Device> allowed,
|
private void pickAndDoSchedule(Set<Device> allowed,
|
||||||
Map<Device, ContainerId> used, Set<Device> assigned,
|
Map<Device, ContainerId> used, Set<Device> assigned,
|
||||||
ContainerId containerId, int count, String resourceName,
|
Container c, int count, String resourceName,
|
||||||
DevicePluginScheduler dps) throws ResourceHandlerException {
|
DevicePluginScheduler dps)
|
||||||
|
throws ResourceHandlerException {
|
||||||
|
ContainerId containerId = c.getContainerId();
|
||||||
|
Map<String, String> env = c.getLaunchContext().getEnvironment();
|
||||||
if (null == dps) {
|
if (null == dps) {
|
||||||
if (LOG.isDebugEnabled()) {
|
if (LOG.isDebugEnabled()) {
|
||||||
LOG.debug("Customized device plugin scheduler is preferred "
|
LOG.debug("Customized device plugin scheduler is preferred "
|
||||||
@ -331,7 +334,8 @@ private void pickAndDoSchedule(Set<Device> allowed,
|
|||||||
// Pass in unmodifiable set
|
// Pass in unmodifiable set
|
||||||
Set<Device> dpsAllocated = dps.allocateDevices(
|
Set<Device> dpsAllocated = dps.allocateDevices(
|
||||||
Sets.difference(allowed, used.keySet()),
|
Sets.difference(allowed, used.keySet()),
|
||||||
count);
|
count,
|
||||||
|
ImmutableMap.copyOf(env));
|
||||||
if (dpsAllocated.size() != count) {
|
if (dpsAllocated.size() != count) {
|
||||||
throw new ResourceHandlerException(dps.getClass()
|
throw new ResourceHandlerException(dps.getClass()
|
||||||
+ " should allocate " + count
|
+ " should allocate " + count
|
||||||
|
@ -0,0 +1,848 @@
|
|||||||
|
/**
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
* distributed with this work for additional information
|
||||||
|
* regarding copyright ownership. The ASF licenses this file
|
||||||
|
* to you under the Apache License, Version 2.0 (the
|
||||||
|
* "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at
|
||||||
|
* <p>
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* <p>
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia;
|
||||||
|
|
||||||
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
|
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
|
||||||
|
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
|
||||||
|
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.io.BufferedReader;
|
||||||
|
import java.io.FileReader;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.TreeSet;
|
||||||
|
|
||||||
|
import static org.mockito.ArgumentMatchers.anyInt;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyMap;
|
||||||
|
import static org.mockito.ArgumentMatchers.anySet;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.reset;
|
||||||
|
import static org.mockito.Mockito.spy;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Test case for NvidiaGPUPluginForRuntimeV2 device plugin.
|
||||||
|
* */
|
||||||
|
public class TestNvidiaGPUPluginForRuntimeV2 {
|
||||||
|
|
||||||
|
private static final Logger LOG =
|
||||||
|
LoggerFactory.getLogger(TestNvidiaGPUPluginForRuntimeV2.class);
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGetNvidiaDevices() throws Exception {
|
||||||
|
NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor mockShell =
|
||||||
|
mock(NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor.class);
|
||||||
|
String deviceInfoShellOutput =
|
||||||
|
"0, 00000000:04:00.0\n" +
|
||||||
|
"1, 00000000:82:00.0";
|
||||||
|
String majorMinorNumber0 = "c3:0";
|
||||||
|
String majorMinorNumber1 = "c3:1";
|
||||||
|
when(mockShell.getDeviceInfo()).thenReturn(deviceInfoShellOutput);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia0"))
|
||||||
|
.thenReturn(majorMinorNumber0);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia1"))
|
||||||
|
.thenReturn(majorMinorNumber1);
|
||||||
|
NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
|
||||||
|
plugin.setShellExecutor(mockShell);
|
||||||
|
plugin.setPathOfGpuBinary("/fake/nvidia-smi");
|
||||||
|
|
||||||
|
Set<Device> expectedDevices = new TreeSet<>();
|
||||||
|
expectedDevices.add(Device.Builder.newInstance()
|
||||||
|
.setId(0).setHealthy(true)
|
||||||
|
.setBusID("00000000:04:00.0")
|
||||||
|
.setDevPath("/dev/nvidia0")
|
||||||
|
.setMajorNumber(195)
|
||||||
|
.setMinorNumber(0).build());
|
||||||
|
expectedDevices.add(Device.Builder.newInstance()
|
||||||
|
.setId(1).setHealthy(true)
|
||||||
|
.setBusID("00000000:82:00.0")
|
||||||
|
.setDevPath("/dev/nvidia1")
|
||||||
|
.setMajorNumber(195)
|
||||||
|
.setMinorNumber(1).build());
|
||||||
|
Set<Device> devices = plugin.getDevices();
|
||||||
|
Assert.assertEquals(expectedDevices, devices);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOnDeviceAllocated() throws Exception {
|
||||||
|
NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
|
||||||
|
Set<Device> allocatedDevices = new TreeSet<>();
|
||||||
|
|
||||||
|
DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices,
|
||||||
|
YarnRuntimeType.RUNTIME_DEFAULT);
|
||||||
|
Assert.assertNull(spec);
|
||||||
|
|
||||||
|
// allocate one device
|
||||||
|
allocatedDevices.add(Device.Builder.newInstance()
|
||||||
|
.setId(0).setHealthy(true)
|
||||||
|
.setBusID("00000000:04:00.0")
|
||||||
|
.setDevPath("/dev/nvidia0")
|
||||||
|
.setMajorNumber(195)
|
||||||
|
.setMinorNumber(0).build());
|
||||||
|
spec = plugin.onDevicesAllocated(allocatedDevices,
|
||||||
|
YarnRuntimeType.RUNTIME_DOCKER);
|
||||||
|
Assert.assertEquals("nvidia", spec.getContainerRuntime());
|
||||||
|
Assert.assertEquals("0", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));
|
||||||
|
|
||||||
|
// two device allowed
|
||||||
|
allocatedDevices.add(Device.Builder.newInstance()
|
||||||
|
.setId(0).setHealthy(true)
|
||||||
|
.setBusID("00000000:82:00.0")
|
||||||
|
.setDevPath("/dev/nvidia1")
|
||||||
|
.setMajorNumber(195)
|
||||||
|
.setMinorNumber(1).build());
|
||||||
|
spec = plugin.onDevicesAllocated(allocatedDevices,
|
||||||
|
YarnRuntimeType.RUNTIME_DOCKER);
|
||||||
|
Assert.assertEquals("nvidia", spec.getContainerRuntime());
|
||||||
|
Assert.assertEquals("0,1", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));
|
||||||
|
}
|
||||||
|
|
||||||
|
private NvidiaGPUPluginForRuntimeV2 mockEightGPUPlugin() throws IOException {
|
||||||
|
String topoInfo =
|
||||||
|
"\tGPU0\tGPU1\tGPU2\tGPU3\tGPU4\tGPU5\tGPU6\tGPU7\tCPU Affinity\n"
|
||||||
|
+ "GPU0\t X \tNV1\tNV1\tNV2\tNV2\tPHB\tPHB\tPHB\t0-63\n"
|
||||||
|
+ "GPU1\tNV1\t X \tNV2\tNV1\tPHB\tNV2\tPHB\tPHB\t0-63\n"
|
||||||
|
+ "GPU2\tNV1\tNV2\t X \tNV2\tPHB\tPHB\tNV1\tPHB\t0-63\n"
|
||||||
|
+ "GPU3\tNV2\tNV1\tNV2\t X \tPHB\tPHB\tPHB\tNV1\t0-63\n"
|
||||||
|
+ "GPU4\tNV2\tPHB\tPHB\tPHB\t X \tNV1\tNV1\tNV2\t0-63\n"
|
||||||
|
+ "GPU5\tPHB\tNV2\tPHB\tPHB\tNV1\t X \tNV2\tNV1\t0-63\n"
|
||||||
|
+ "GPU6\tPHB\tPHB\tNV1\tPHB\tNV1\tNV2\t X \tNV2\t0-63\n"
|
||||||
|
+ "GPU7\tPHB\tPHB\tPHB\tNV1\tNV2\tNV1\tNV2\t X \t0-63\n"
|
||||||
|
+ "\n"
|
||||||
|
+ "Legend:\n"
|
||||||
|
+ "\n"
|
||||||
|
+ " X = Self\n"
|
||||||
|
+ " SYS = Connection traversing PCIe as well as the SMP interconnect"
|
||||||
|
+ " between NUMA nodes (e.g., QPI/UPI)\n"
|
||||||
|
+ " NODE = Connection traversing PCIe as well as the interconnect"
|
||||||
|
+ " between PCIe Host Bridges within a NUMA node\n"
|
||||||
|
+ " PHB = Connection traversing PCIe as well as a PCIe Host Bridge"
|
||||||
|
+ " (typically the CPU)\n"
|
||||||
|
+ " PXB = Connection traversing multiple PCIe switches"
|
||||||
|
+ " (without traversing the PCIe Host Bridge)\n"
|
||||||
|
+ " PIX = Connection traversing a single PCIe switch\n"
|
||||||
|
+ " NV# = Connection traversing a bonded set of # NVLinks\n";
|
||||||
|
|
||||||
|
String deviceInfoShellOutput = "0, 00000000:04:00.0\n"
|
||||||
|
+ "1, 00000000:82:00.0\n"
|
||||||
|
+ "2, 00000000:83:00.0\n"
|
||||||
|
+ "3, 00000000:84:00.0\n"
|
||||||
|
+ "4, 00000000:85:00.0\n"
|
||||||
|
+ "5, 00000000:86:00.0\n"
|
||||||
|
+ "6, 00000000:87:00.0\n"
|
||||||
|
+ "7, 00000000:88:00.0";
|
||||||
|
String majorMinorNumber0 = "c3:0";
|
||||||
|
String majorMinorNumber1 = "c3:1";
|
||||||
|
String majorMinorNumber2 = "c3:2";
|
||||||
|
String majorMinorNumber3 = "c3:3";
|
||||||
|
String majorMinorNumber4 = "c3:4";
|
||||||
|
String majorMinorNumber5 = "c3:5";
|
||||||
|
String majorMinorNumber6 = "c3:6";
|
||||||
|
String majorMinorNumber7 = "c3:7";
|
||||||
|
NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor mockShell =
|
||||||
|
mock(NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor.class);
|
||||||
|
when(mockShell.getDeviceInfo()).thenReturn(deviceInfoShellOutput);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia0"))
|
||||||
|
.thenReturn(majorMinorNumber0);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia1"))
|
||||||
|
.thenReturn(majorMinorNumber1);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia2"))
|
||||||
|
.thenReturn(majorMinorNumber2);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia3"))
|
||||||
|
.thenReturn(majorMinorNumber3);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia4"))
|
||||||
|
.thenReturn(majorMinorNumber4);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia5"))
|
||||||
|
.thenReturn(majorMinorNumber5);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia6"))
|
||||||
|
.thenReturn(majorMinorNumber6);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia7"))
|
||||||
|
.thenReturn(majorMinorNumber7);
|
||||||
|
when(mockShell.getTopologyInfo()).thenReturn(topoInfo);
|
||||||
|
when(mockShell.getDeviceInfo()).thenReturn(deviceInfoShellOutput);
|
||||||
|
|
||||||
|
NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
|
||||||
|
plugin.setShellExecutor(mockShell);
|
||||||
|
plugin.setPathOfGpuBinary("/fake/nvidia-smi");
|
||||||
|
return plugin;
|
||||||
|
}
|
||||||
|
|
||||||
|
private NvidiaGPUPluginForRuntimeV2 mockFourGPUPlugin() throws IOException {
|
||||||
|
String topoInfo = "\tGPU0\tGPU1\tGPU2\tGPU3\tCPU Affinity\n"
|
||||||
|
+ "GPU0\t X \tPHB\tSOC\tSOC\t0-31\n"
|
||||||
|
+ "GPU1\tPHB\t X \tSOC\tSOC\t0-31\n"
|
||||||
|
+ "GPU2\tSOC\tSOC\t X \tPHB\t0-31\n"
|
||||||
|
+ "GPU3\tSOC\tSOC\tPHB\t X \t0-31\n"
|
||||||
|
+ "\n"
|
||||||
|
+ "\n"
|
||||||
|
+ " Legend:\n"
|
||||||
|
+ "\n"
|
||||||
|
+ " X = Self\n"
|
||||||
|
+ " SOC = Connection traversing PCIe as well as the SMP link between\n"
|
||||||
|
+ " CPU sockets(e.g. QPI)\n"
|
||||||
|
+ " PHB = Connection traversing PCIe as well as a PCIe Host Bridge\n"
|
||||||
|
+ " (typically the CPU)\n"
|
||||||
|
+ " PXB = Connection traversing multiple PCIe switches\n"
|
||||||
|
+ " (without traversing the PCIe Host Bridge)\n"
|
||||||
|
+ " PIX = Connection traversing a single PCIe switch\n"
|
||||||
|
+ " NV# = Connection traversing a bonded set of # NVLinks";
|
||||||
|
|
||||||
|
String deviceInfoShellOutput = "0, 00000000:04:00.0\n"
|
||||||
|
+ "1, 00000000:82:00.0\n"
|
||||||
|
+ "2, 00000000:83:00.0\n"
|
||||||
|
+ "3, 00000000:84:00.0";
|
||||||
|
String majorMinorNumber0 = "c3:0";
|
||||||
|
String majorMinorNumber1 = "c3:1";
|
||||||
|
String majorMinorNumber2 = "c3:2";
|
||||||
|
String majorMinorNumber3 = "c3:3";
|
||||||
|
NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor mockShell =
|
||||||
|
mock(NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor.class);
|
||||||
|
when(mockShell.getDeviceInfo()).thenReturn(deviceInfoShellOutput);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia0"))
|
||||||
|
.thenReturn(majorMinorNumber0);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia1"))
|
||||||
|
.thenReturn(majorMinorNumber1);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia2"))
|
||||||
|
.thenReturn(majorMinorNumber2);
|
||||||
|
when(mockShell.getMajorMinorInfo("nvidia3"))
|
||||||
|
.thenReturn(majorMinorNumber3);
|
||||||
|
when(mockShell.getTopologyInfo()).thenReturn(topoInfo);
|
||||||
|
when(mockShell.getDeviceInfo()).thenReturn(deviceInfoShellOutput);
|
||||||
|
|
||||||
|
NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
|
||||||
|
plugin.setShellExecutor(mockShell);
|
||||||
|
plugin.setPathOfGpuBinary("/fake/nvidia-smi");
|
||||||
|
return plugin;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testTopologySchedulingWithPackPolicy() throws Exception {
|
||||||
|
NvidiaGPUPluginForRuntimeV2 plugin = mockFourGPUPlugin();
|
||||||
|
NvidiaGPUPluginForRuntimeV2 spyPlugin = spy(plugin);
|
||||||
|
// cache the total devices
|
||||||
|
Set<Device> allDevices = spyPlugin.getDevices();
|
||||||
|
// environment variable to use PACK policy
|
||||||
|
Map<String, String> env = new HashMap<>();
|
||||||
|
env.put(NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_ENV_KEY,
|
||||||
|
NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_PACK);
|
||||||
|
// Case 0. if available devices is less than 3, no topo scheduling needed
|
||||||
|
Set<Device> copyAvailableDevices = new TreeSet<>(allDevices);
|
||||||
|
Iterator<Device> iterator0 = copyAvailableDevices.iterator();
|
||||||
|
iterator0.next();
|
||||||
|
iterator0.remove();
|
||||||
|
iterator0.next();
|
||||||
|
iterator0.remove();
|
||||||
|
// Case 0. allocate 1 device
|
||||||
|
reset(spyPlugin);
|
||||||
|
Set<Device> allocation = spyPlugin.allocateDevices(copyAvailableDevices,
|
||||||
|
1, env);
|
||||||
|
Assert.assertEquals(allocation.size(), 1);
|
||||||
|
verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet());
|
||||||
|
Assert.assertFalse(spyPlugin.isTopoInitialized());
|
||||||
|
|
||||||
|
// Case 1. allocate 1 device
|
||||||
|
reset(spyPlugin);
|
||||||
|
allocation = spyPlugin.allocateDevices(allDevices, 1, env);
|
||||||
|
// ensure no topology scheduling needed
|
||||||
|
Assert.assertEquals(allocation.size(), 1);
|
||||||
|
verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet());
|
||||||
|
reset(spyPlugin);
|
||||||
|
// Case 2. allocate all available
|
||||||
|
allocation = spyPlugin.allocateDevices(allDevices, allDevices.size(), env);
|
||||||
|
Assert.assertEquals(allocation.size(), allDevices.size());
|
||||||
|
verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet());
|
||||||
|
// Case 3. allocate 2 devices
|
||||||
|
reset(spyPlugin);
|
||||||
|
int count = 2;
|
||||||
|
Map<String, Integer> pairToWeight = spyPlugin.getDevicePairToWeight();
|
||||||
|
allocation = spyPlugin.allocateDevices(allDevices, count, env);
|
||||||
|
Assert.assertEquals(allocation.size(), count);
|
||||||
|
// the costTable should be init and used topology scheduling
|
||||||
|
verify(spyPlugin).initCostTable();
|
||||||
|
Assert.assertTrue(spyPlugin.isTopoInitialized());
|
||||||
|
verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(),
|
||||||
|
anySet(), anyMap());
|
||||||
|
Assert.assertEquals(allocation.size(), count);
|
||||||
|
Device[] allocatedDevices =
|
||||||
|
allocation.toArray(new Device[count]);
|
||||||
|
// Check weights
|
||||||
|
Assert.assertEquals(NvidiaGPUPluginForRuntimeV2.DeviceLinkType
|
||||||
|
.P2PLinkSameCPUSocket.getWeight(),
|
||||||
|
spyPlugin.computeCostOfDevices(allocatedDevices));
|
||||||
|
// Case 4. allocate 3 devices
|
||||||
|
reset(spyPlugin);
|
||||||
|
count = 3;
|
||||||
|
allocation = spyPlugin.allocateDevices(allDevices, count, env);
|
||||||
|
Assert.assertEquals(allocation.size(), count);
|
||||||
|
// the costTable should be init and used topology scheduling
|
||||||
|
verify(spyPlugin, times(0)).initCostTable();
|
||||||
|
Assert.assertTrue(spyPlugin.isTopoInitialized());
|
||||||
|
verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(),
|
||||||
|
anySet(), anyMap());
|
||||||
|
Assert.assertEquals(allocation.size(), count);
|
||||||
|
allocatedDevices =
|
||||||
|
allocation.toArray(new Device[count]);
|
||||||
|
// check weights
|
||||||
|
int expectedWeight =
|
||||||
|
NvidiaGPUPluginForRuntimeV2.DeviceLinkType
|
||||||
|
.P2PLinkSameCPUSocket.getWeight()
|
||||||
|
+ 2 * NvidiaGPUPluginForRuntimeV2.DeviceLinkType
|
||||||
|
.P2PLinkCrossCPUSocket.getWeight();
|
||||||
|
Assert.assertEquals(expectedWeight,
|
||||||
|
spyPlugin.computeCostOfDevices(allocatedDevices));
|
||||||
|
// Case 5. allocate 2 GPUs from three available devices
|
||||||
|
reset(spyPlugin);
|
||||||
|
Iterator<Device> iterator = allDevices.iterator();
|
||||||
|
iterator.next();
|
||||||
|
// remove GPU0
|
||||||
|
iterator.remove();
|
||||||
|
count = 2;
|
||||||
|
allocation = spyPlugin.allocateDevices(allDevices, count, env);
|
||||||
|
Assert.assertEquals(allocation.size(), count);
|
||||||
|
// the costTable should be init and used topology scheduling
|
||||||
|
verify(spyPlugin, times(0)).initCostTable();
|
||||||
|
Assert.assertTrue(spyPlugin.isTopoInitialized());
|
||||||
|
verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(),
|
||||||
|
anySet(), anyMap());
|
||||||
|
Assert.assertEquals(allocation.size(), count);
|
||||||
|
allocatedDevices =
|
||||||
|
allocation.toArray(new Device[count]);
|
||||||
|
// check weights
|
||||||
|
Assert.assertEquals(NvidiaGPUPluginForRuntimeV2.DeviceLinkType
|
||||||
|
.P2PLinkSameCPUSocket.getWeight(),
|
||||||
|
spyPlugin.computeCostOfDevices(allocatedDevices));
|
||||||
|
// it should allocate GPU 2 and 3
|
||||||
|
for (Device device : allocation) {
|
||||||
|
if (device.getMinorNumber() == 2) {
|
||||||
|
Assert.assertTrue(true);
|
||||||
|
} else if (device.getMinorNumber() == 3) {
|
||||||
|
Assert.assertTrue(true);
|
||||||
|
} else {
|
||||||
|
Assert.assertTrue("Should allocate GPU 2 and 3", false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testTopologySchedulingWithSpreadPolicy() throws Exception {
|
||||||
|
NvidiaGPUPluginForRuntimeV2 plugin = mockFourGPUPlugin();
|
||||||
|
NvidiaGPUPluginForRuntimeV2 spyPlugin = spy(plugin);
|
||||||
|
// cache the total devices
|
||||||
|
Set<Device> allDevices = spyPlugin.getDevices();
|
||||||
|
// environment variable to use PACK policy
|
||||||
|
Map<String, String> env = new HashMap<>();
|
||||||
|
env.put(NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_ENV_KEY,
|
||||||
|
NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_SPREAD);
|
||||||
|
// Case 1. allocate 1 device
|
||||||
|
Set<Device> allocation = spyPlugin.allocateDevices(allDevices, 1, env);
|
||||||
|
// ensure no topology scheduling needed
|
||||||
|
Assert.assertEquals(allocation.size(), 1);
|
||||||
|
verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet());
|
||||||
|
reset(spyPlugin);
|
||||||
|
// Case 2. allocate all available
|
||||||
|
allocation = spyPlugin.allocateDevices(allDevices, allDevices.size(), env);
|
||||||
|
Assert.assertEquals(allocation.size(), allDevices.size());
|
||||||
|
verify(spyPlugin).basicSchedule(anySet(), anyInt(), anySet());
|
||||||
|
// Case 3. allocate 2 devices
|
||||||
|
reset(spyPlugin);
|
||||||
|
int count = 2;
|
||||||
|
Map<String, Integer> pairToWeight = spyPlugin.getDevicePairToWeight();
|
||||||
|
allocation = spyPlugin.allocateDevices(allDevices, count, env);
|
||||||
|
Assert.assertEquals(allocation.size(), count);
|
||||||
|
// the costTable should be init and used topology scheduling
|
||||||
|
verify(spyPlugin).initCostTable();
|
||||||
|
Assert.assertTrue(spyPlugin.isTopoInitialized());
|
||||||
|
verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(),
|
||||||
|
anySet(), anyMap());
|
||||||
|
Assert.assertEquals(allocation.size(), count);
|
||||||
|
Device[] allocatedDevices =
|
||||||
|
allocation.toArray(new Device[count]);
|
||||||
|
// Check weights
|
||||||
|
Assert.assertEquals(NvidiaGPUPluginForRuntimeV2.DeviceLinkType
|
||||||
|
.P2PLinkCrossCPUSocket.getWeight(),
|
||||||
|
spyPlugin.computeCostOfDevices(allocatedDevices));
|
||||||
|
// Case 4. allocate 3 devices
|
||||||
|
reset(spyPlugin);
|
||||||
|
count = 3;
|
||||||
|
allocation = spyPlugin.allocateDevices(allDevices, count, env);
|
||||||
|
Assert.assertEquals(allocation.size(), count);
|
||||||
|
// the costTable should be init and used topology scheduling
|
||||||
|
verify(spyPlugin, times(0)).initCostTable();
|
||||||
|
Assert.assertTrue(spyPlugin.isTopoInitialized());
|
||||||
|
verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(),
|
||||||
|
anySet(), anyMap());
|
||||||
|
Assert.assertEquals(allocation.size(), count);
|
||||||
|
allocatedDevices =
|
||||||
|
allocation.toArray(new Device[count]);
|
||||||
|
// check weights
|
||||||
|
int expectedWeight =
|
||||||
|
NvidiaGPUPluginForRuntimeV2.DeviceLinkType
|
||||||
|
.P2PLinkSameCPUSocket.getWeight()
|
||||||
|
+ 2 * NvidiaGPUPluginForRuntimeV2.DeviceLinkType
|
||||||
|
.P2PLinkCrossCPUSocket.getWeight();
|
||||||
|
Assert.assertEquals(expectedWeight,
|
||||||
|
spyPlugin.computeCostOfDevices(allocatedDevices));
|
||||||
|
// Case 5. allocate 2 GPUs from three available devices
|
||||||
|
reset(spyPlugin);
|
||||||
|
Iterator<Device> iterator = allDevices.iterator();
|
||||||
|
iterator.next();
|
||||||
|
// remove GPU0
|
||||||
|
iterator.remove();
|
||||||
|
count = 2;
|
||||||
|
allocation = spyPlugin.allocateDevices(allDevices, count, env);
|
||||||
|
Assert.assertEquals(allocation.size(), count);
|
||||||
|
// the costTable should be init and used topology scheduling
|
||||||
|
verify(spyPlugin, times(0)).initCostTable();
|
||||||
|
Assert.assertTrue(spyPlugin.isTopoInitialized());
|
||||||
|
verify(spyPlugin).topologyAwareSchedule(anySet(), anyInt(), anyMap(),
|
||||||
|
anySet(), anyMap());
|
||||||
|
Assert.assertEquals(allocation.size(), count);
|
||||||
|
allocatedDevices =
|
||||||
|
allocation.toArray(new Device[count]);
|
||||||
|
// check weights
|
||||||
|
Assert.assertEquals(NvidiaGPUPluginForRuntimeV2.DeviceLinkType
|
||||||
|
.P2PLinkCrossCPUSocket.getWeight(),
|
||||||
|
spyPlugin.computeCostOfDevices(allocatedDevices));
|
||||||
|
// it should allocate GPU 1 and 2
|
||||||
|
for (Device device : allocation) {
|
||||||
|
if (device.getMinorNumber() == 0) {
|
||||||
|
Assert.assertTrue("Shouldn't allocate GPU 0", false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCostTableWithNVlink() throws Exception {
|
||||||
|
NvidiaGPUPluginForRuntimeV2 plugin = mockEightGPUPlugin();
|
||||||
|
NvidiaGPUPluginForRuntimeV2 spyPlugin = spy(plugin);
|
||||||
|
// verify the device pair to weight map
|
||||||
|
spyPlugin.initCostTable();
|
||||||
|
Map<String, Integer> devicePairToWeight = spyPlugin.getDevicePairToWeight();
|
||||||
|
// 12 combinations when choose 2 GPUs from 8 respect the order. 8!/6!
|
||||||
|
Assert.assertEquals(56, devicePairToWeight.size());
|
||||||
|
int sameCPUWeight =
|
||||||
|
NvidiaGPUPluginForRuntimeV2.DeviceLinkType
|
||||||
|
.P2PLinkSameCPUSocket.getWeight();
|
||||||
|
int nv1Weight =
|
||||||
|
NvidiaGPUPluginForRuntimeV2.DeviceLinkType
|
||||||
|
.P2PLinkNVLink1.getWeight();
|
||||||
|
int nv2Weight =
|
||||||
|
NvidiaGPUPluginForRuntimeV2.DeviceLinkType
|
||||||
|
.P2PLinkNVLink2.getWeight();
|
||||||
|
|
||||||
|
Assert.assertEquals(nv1Weight, (int)devicePairToWeight.get("0-1"));
|
||||||
|
Assert.assertEquals(nv1Weight, (int)devicePairToWeight.get("1-0"));
|
||||||
|
|
||||||
|
Assert.assertEquals(nv2Weight, (int)devicePairToWeight.get("0-4"));
|
||||||
|
Assert.assertEquals(nv2Weight, (int)devicePairToWeight.get("4-0"));
|
||||||
|
|
||||||
|
Assert.assertEquals(nv2Weight, (int)devicePairToWeight.get("0-3"));
|
||||||
|
Assert.assertEquals(nv2Weight, (int)devicePairToWeight.get("3-0"));
|
||||||
|
|
||||||
|
Assert.assertEquals(sameCPUWeight, (int)devicePairToWeight.get("6-3"));
|
||||||
|
Assert.assertEquals(sameCPUWeight, (int)devicePairToWeight.get("3-6"));
|
||||||
|
|
||||||
|
Assert.assertEquals(nv2Weight, (int)devicePairToWeight.get("6-7"));
|
||||||
|
Assert.assertEquals(nv2Weight, (int)devicePairToWeight.get("7-6"));
|
||||||
|
|
||||||
|
Assert.assertEquals(nv1Weight, (int)devicePairToWeight.get("1-3"));
|
||||||
|
Assert.assertEquals(nv1Weight, (int)devicePairToWeight.get("3-1"));
|
||||||
|
|
||||||
|
// verify cost Table
|
||||||
|
Map<Integer, List<Map.Entry<Set<Device>, Integer>>> costTable =
|
||||||
|
spyPlugin.getCostTable();
|
||||||
|
Assert.assertNull(costTable.get(1));
|
||||||
|
// C8:2 = 8!/2!/6! = 28
|
||||||
|
Assert.assertEquals(28, costTable.get(2).size());
|
||||||
|
// C8:4 = 8!/4!/4! = 70
|
||||||
|
Assert.assertEquals(70, costTable.get(4).size());
|
||||||
|
Assert.assertNull(costTable.get(8));
|
||||||
|
|
||||||
|
Set<Device> allDevices = spyPlugin.getDevices();
|
||||||
|
Map<String, String> env = new HashMap<>();
|
||||||
|
env.put(NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_ENV_KEY,
|
||||||
|
NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_PACK);
|
||||||
|
spyPlugin.allocateDevices(allDevices, 3, env);
|
||||||
|
spyPlugin.allocateDevices(allDevices, 2, env);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Test the key cost table used for topology scheduling.
|
||||||
|
* */
|
||||||
|
@Test
|
||||||
|
public void testCostTable() throws IOException {
|
||||||
|
NvidiaGPUPluginForRuntimeV2 plugin = mockFourGPUPlugin();
|
||||||
|
NvidiaGPUPluginForRuntimeV2 spyPlugin = spy(plugin);
|
||||||
|
// verify the device pair to weight map
|
||||||
|
spyPlugin.initCostTable();
|
||||||
|
Map<String, Integer> devicePairToWeight = spyPlugin.getDevicePairToWeight();
|
||||||
|
// 12 combinations when choose 2 GPUs from 4 respect the order. 4!/2!
|
||||||
|
Assert.assertEquals(12, devicePairToWeight.size());
|
||||||
|
int sameCPUWeight =
|
||||||
|
NvidiaGPUPluginForRuntimeV2.DeviceLinkType
|
||||||
|
.P2PLinkSameCPUSocket.getWeight();
|
||||||
|
int crossCPUWeight =
|
||||||
|
NvidiaGPUPluginForRuntimeV2.DeviceLinkType
|
||||||
|
.P2PLinkCrossCPUSocket.getWeight();
|
||||||
|
Assert.assertEquals(sameCPUWeight, (int)devicePairToWeight.get("0-1"));
|
||||||
|
Assert.assertEquals(sameCPUWeight, (int)devicePairToWeight.get("1-0"));
|
||||||
|
|
||||||
|
Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("0-2"));
|
||||||
|
Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("2-0"));
|
||||||
|
|
||||||
|
Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("0-3"));
|
||||||
|
Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("3-0"));
|
||||||
|
|
||||||
|
Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("1-2"));
|
||||||
|
Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("2-1"));
|
||||||
|
|
||||||
|
Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("1-3"));
|
||||||
|
Assert.assertEquals(crossCPUWeight, (int)devicePairToWeight.get("3-1"));
|
||||||
|
|
||||||
|
Assert.assertEquals(sameCPUWeight, (int)devicePairToWeight.get("2-3"));
|
||||||
|
Assert.assertEquals(sameCPUWeight, (int)devicePairToWeight.get("3-2"));
|
||||||
|
|
||||||
|
// verify cost Table
|
||||||
|
Map<Integer, List<Map.Entry<Set<Device>, Integer>>> costTable =
|
||||||
|
spyPlugin.getCostTable();
|
||||||
|
Assert.assertNull(costTable.get(1));
|
||||||
|
Assert.assertEquals(6, costTable.get(2).size());
|
||||||
|
Assert.assertEquals(4, costTable.get(3).size());
|
||||||
|
Assert.assertNull(costTable.get(4));
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Test GPU topology allocation.
|
||||||
|
* And analysis the GPU allocation's performance against the actual
|
||||||
|
* performance data using tensorflow benchmarks.
|
||||||
|
* https://github.com/tensorflow/benchmarks
|
||||||
|
* */
|
||||||
|
@Test
|
||||||
|
public void testTopologySchedulingPerformanceWithPackPolicyWithNVLink()
|
||||||
|
throws Exception {
|
||||||
|
NvidiaGPUPluginForRuntimeV2 plugin = mockEightGPUPlugin();
|
||||||
|
NvidiaGPUPluginForRuntimeV2 spyPlugin = spy(plugin);
|
||||||
|
Set<Device> allDevices = spyPlugin.getDevices();
|
||||||
|
Map<String, String> env = new HashMap<>();
|
||||||
|
env.put(NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_ENV_KEY,
|
||||||
|
NvidiaGPUPluginForRuntimeV2.TOPOLOGY_POLICY_PACK);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Analyze performance against the real data.
|
||||||
|
* Get the topology scheduling algorithm's allocation's
|
||||||
|
* average performance boost against median imagePerSecond and minimum
|
||||||
|
* imagePerSecond in certain model and batch size combinations.
|
||||||
|
* And then calculate the average performance boost.
|
||||||
|
* The average performance boost against
|
||||||
|
* median value means topology scheduler's allocation can stably
|
||||||
|
* outperforms 50% of possible allocations.
|
||||||
|
* The average performance boost against min value means the average boost
|
||||||
|
* comparing to the worst allocations in various scenarios. Which is more
|
||||||
|
* beautiful number for public promotion.
|
||||||
|
* And also the analysis shows the best performance boost against median
|
||||||
|
* and min value.
|
||||||
|
* */
|
||||||
|
ActualPerformanceReport report = new ActualPerformanceReport();
|
||||||
|
report.readFromFile();
|
||||||
|
ArrayList<ActualPerformanceReport.DataRecord> dataSet =
|
||||||
|
report.getDataSet();
|
||||||
|
Assert.assertEquals(dataSet.size(), 2952);
|
||||||
|
String[] allModels = {"alexnet", "resnet50", "vgg16", "inception3"};
|
||||||
|
int[] batchSizes = {32, 64, 128};
|
||||||
|
int[] gpuCounts = {2, 3, 4, 5, 6, 7};
|
||||||
|
float totalBoostAgainstMedian = 0;
|
||||||
|
int count = 0;
|
||||||
|
float maxBoostAgainstMedian = 0;
|
||||||
|
float totalBoostAgainstMin = 0;
|
||||||
|
float maxBoostAgainstMin = 0;
|
||||||
|
for (String model : allModels) {
|
||||||
|
float totalBoostAgainstMinCertainModel = 0;
|
||||||
|
float totalBoostAgainstMedianCertainModel = 0;
|
||||||
|
float maxBoostAgainstMinCertainModel = 0;
|
||||||
|
float maxBoostAgainstMedianCertainModel = 0;
|
||||||
|
int countOfEachModel = 0;
|
||||||
|
for (int bs : batchSizes) {
|
||||||
|
for (int gpuCount: gpuCounts) {
|
||||||
|
float bstAgainstMedian = calculatePerformanceBoostAgainstMedian(
|
||||||
|
report, model, bs, gpuCount, plugin, allDevices, env);
|
||||||
|
float bstAgainstMinimum = calculatePerformanceBoostAgainstMinimum(
|
||||||
|
report, model, bs, gpuCount, plugin, allDevices, env);
|
||||||
|
totalBoostAgainstMedian += bstAgainstMedian;
|
||||||
|
totalBoostAgainstMin += bstAgainstMinimum;
|
||||||
|
count++;
|
||||||
|
if (maxBoostAgainstMedian < bstAgainstMedian) {
|
||||||
|
maxBoostAgainstMedian = bstAgainstMedian;
|
||||||
|
}
|
||||||
|
if (maxBoostAgainstMin < bstAgainstMinimum) {
|
||||||
|
maxBoostAgainstMin = bstAgainstMinimum;
|
||||||
|
}
|
||||||
|
totalBoostAgainstMinCertainModel += bstAgainstMinimum;
|
||||||
|
totalBoostAgainstMedianCertainModel += bstAgainstMedian;
|
||||||
|
if (maxBoostAgainstMinCertainModel < bstAgainstMinimum) {
|
||||||
|
maxBoostAgainstMinCertainModel = bstAgainstMinimum;
|
||||||
|
}
|
||||||
|
if (maxBoostAgainstMedianCertainModel < bstAgainstMedian) {
|
||||||
|
maxBoostAgainstMedianCertainModel = bstAgainstMedian;
|
||||||
|
}
|
||||||
|
countOfEachModel++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG.info("Model:{}, The best performance boost against median value is "
|
||||||
|
+ "{}", model, maxBoostAgainstMedianCertainModel);
|
||||||
|
LOG.info("Model:{}, The aggregated average performance boost against "
|
||||||
|
+ "median value is {}",
|
||||||
|
model, totalBoostAgainstMedianCertainModel/countOfEachModel);
|
||||||
|
LOG.info("Model:{}, The best performance boost against min value is {}",
|
||||||
|
model, maxBoostAgainstMinCertainModel);
|
||||||
|
LOG.info("Model:{}, The aggregated average performance boost against "
|
||||||
|
+ "min value is {}",
|
||||||
|
model, totalBoostAgainstMinCertainModel/countOfEachModel);
|
||||||
|
}
|
||||||
|
LOG.info("For all, the best performance boost against median value is "
|
||||||
|
+ maxBoostAgainstMedian);
|
||||||
|
LOG.info("For all, the aggregated average performance boost against median "
|
||||||
|
+ "value is " + totalBoostAgainstMedian/count);
|
||||||
|
LOG.info("For all, the best performance boost against min value is "
|
||||||
|
+ maxBoostAgainstMin);
|
||||||
|
LOG.info("For all, the aggregated average performance boost against min "
|
||||||
|
+ "value is " + totalBoostAgainstMin/count);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* For <code>gpuCount</code> GPUs allocated by the topology algorithm, return
|
||||||
|
* its performance boost against the median value.
|
||||||
|
*
|
||||||
|
* */
|
||||||
|
private float calculatePerformanceBoostAgainstMedian(
|
||||||
|
ActualPerformanceReport report,
|
||||||
|
String model, int bs, int gpuCount,
|
||||||
|
NvidiaGPUPluginForRuntimeV2 plugin, Set<Device> allDevice,
|
||||||
|
Map<String, String> env) {
|
||||||
|
Set<Device> allocation = plugin.allocateDevices(allDevice, gpuCount, env);
|
||||||
|
String gpuAllocationString = convertAllocationToGpuString(allocation);
|
||||||
|
float[] metrics = report.getVariousImagePerSecond(model, bs,
|
||||||
|
gpuCount, gpuAllocationString);
|
||||||
|
return metrics[7];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* For <code>gpuCount</code> GPUs allocated by the topology algorithm, return
|
||||||
|
* its performance boost against the minimum value.
|
||||||
|
*
|
||||||
|
* */
|
||||||
|
private float calculatePerformanceBoostAgainstMinimum(
|
||||||
|
ActualPerformanceReport report,
|
||||||
|
String model, int bs, int gpuCount,
|
||||||
|
NvidiaGPUPluginForRuntimeV2 plugin, Set<Device> allDevice,
|
||||||
|
Map<String, String> env) {
|
||||||
|
Set<Device> allocation = plugin.allocateDevices(allDevice, gpuCount, env);
|
||||||
|
String gpuAllocationString = convertAllocationToGpuString(allocation);
|
||||||
|
float[] metrics = report.getVariousImagePerSecond(model, bs,
|
||||||
|
gpuCount, gpuAllocationString);
|
||||||
|
return metrics[5];
|
||||||
|
}
|
||||||
|
|
||||||
|
private String convertAllocationToGpuString(Set<Device> allocation) {
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
for (Device device : allocation) {
|
||||||
|
sb.append(device.getMinorNumber() + "_");
|
||||||
|
}
|
||||||
|
return sb.toString().substring(0, sb.lastIndexOf("_"));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Representation of the performance data report.
|
||||||
|
* */
|
||||||
|
private class ActualPerformanceReport {
|
||||||
|
|
||||||
|
private ArrayList<DataRecord> dataSet = new ArrayList<>();
|
||||||
|
|
||||||
|
public ArrayList<DataRecord> getDataSet() {
|
||||||
|
return dataSet;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* One line in the report.
|
||||||
|
* */
|
||||||
|
private class DataRecord {
|
||||||
|
DataRecord(String model, int bs, String combination, float fps,
|
||||||
|
int count) {
|
||||||
|
this.batchSize = bs;
|
||||||
|
this.gpuCombination = combination;
|
||||||
|
this.gpuCount = count;
|
||||||
|
this.model = model;
|
||||||
|
this.imagePerSecond = fps;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getModel() {
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getBatchSize() {
|
||||||
|
return batchSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getGpuCombination() {
|
||||||
|
return gpuCombination;
|
||||||
|
}
|
||||||
|
|
||||||
|
public float getImagePerSecond() {
|
||||||
|
return imagePerSecond;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getGpuCount() {
|
||||||
|
return gpuCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String model;
|
||||||
|
private int batchSize;
|
||||||
|
private String gpuCombination;
|
||||||
|
private float imagePerSecond;
|
||||||
|
private int gpuCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The file is a real performance report got from a 8 GPUs AWS instance.
|
||||||
|
* It contains every combination GPUs' training performance of Tensorflow
|
||||||
|
* benchmark.
|
||||||
|
* The columns are the model name, batch size, gpu ids and imagesPerSecond
|
||||||
|
* */
|
||||||
|
public void readFromFile() {
|
||||||
|
String csvReportFilePath = getClass().getClassLoader()
|
||||||
|
.getResource("tensorflow-bench-result-for-GPU.csv").getFile();
|
||||||
|
BufferedReader br = null;
|
||||||
|
String line = "";
|
||||||
|
try {
|
||||||
|
br = new BufferedReader(new FileReader(csvReportFilePath));
|
||||||
|
String model;
|
||||||
|
int batchSize;
|
||||||
|
String gpuCombination;
|
||||||
|
float imagePerSecond;
|
||||||
|
int gpuCount;
|
||||||
|
while ((line = br.readLine()) != null) {
|
||||||
|
// skip the licence content
|
||||||
|
if (line.startsWith("#")) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String[] tokens = line.replaceAll("\"", "").split(",");
|
||||||
|
if (tokens.length != 4) {
|
||||||
|
LOG.error("unexpected performance data format!");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
model = tokens[0];
|
||||||
|
batchSize = Integer.parseInt(tokens[1].trim());
|
||||||
|
gpuCombination = tokens[2];
|
||||||
|
imagePerSecond = Float.parseFloat(tokens[3]);
|
||||||
|
gpuCount = getGpuCount(gpuCombination);
|
||||||
|
this.dataSet.add(new DataRecord(model, batchSize, gpuCombination,
|
||||||
|
imagePerSecond, gpuCount));
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
} finally {
|
||||||
|
if (br != null) {
|
||||||
|
try {
|
||||||
|
br.close();
|
||||||
|
} catch (IOException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // end finally
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the maximum, minimum, mean and median performance for model &
|
||||||
|
* bs & gpuCount. And the imagePerSecond for model & bs & gpuCount &
|
||||||
|
* gpuCombinations. And imagePerSecond performance boost comparing to
|
||||||
|
* minimum, mean and media value.
|
||||||
|
* */
|
||||||
|
private float[] getVariousImagePerSecond(String model, int bs,
|
||||||
|
int gpuCount, String gpuCombinations) {
|
||||||
|
float[] result = new float[8];
|
||||||
|
float max = 0;
|
||||||
|
float min = Float.MAX_VALUE;
|
||||||
|
float sum = 0;
|
||||||
|
int count = 0;
|
||||||
|
float wantedImagePerSecond = 0;
|
||||||
|
float currentImagePerSecond;
|
||||||
|
ArrayList<Float> allFps = new ArrayList<>();
|
||||||
|
for (DataRecord dr : getDataSet()) {
|
||||||
|
currentImagePerSecond = dr.getImagePerSecond();
|
||||||
|
if (dr.batchSize == bs
|
||||||
|
&& model.equals(dr.getModel())
|
||||||
|
&& gpuCount == dr.getGpuCount()) {
|
||||||
|
sum += currentImagePerSecond;
|
||||||
|
count++;
|
||||||
|
if (max < currentImagePerSecond) {
|
||||||
|
max = currentImagePerSecond;
|
||||||
|
}
|
||||||
|
if (min > currentImagePerSecond) {
|
||||||
|
min = currentImagePerSecond;
|
||||||
|
}
|
||||||
|
if (gpuCombinations.equals(dr.getGpuCombination())) {
|
||||||
|
wantedImagePerSecond = dr.getImagePerSecond();
|
||||||
|
}
|
||||||
|
allFps.add(dr.getImagePerSecond());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
float median = getMedian(allFps);
|
||||||
|
float mean = sum/count;
|
||||||
|
result[0] = max;
|
||||||
|
result[1] = min;
|
||||||
|
result[2] = mean;
|
||||||
|
result[3] = median;
|
||||||
|
result[4] = wantedImagePerSecond;
|
||||||
|
result[5] = wantedImagePerSecond/min - 1;
|
||||||
|
result[6] = wantedImagePerSecond/mean - 1;
|
||||||
|
result[7] = wantedImagePerSecond/median - 1;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private float getMedian(ArrayList<Float> allFps) {
|
||||||
|
float[] all = ArrayUtils.toPrimitive(allFps.toArray(new Float[0]), 0);
|
||||||
|
Arrays.sort(all);
|
||||||
|
float median;
|
||||||
|
int size = all.length;
|
||||||
|
if (allFps.size() % 2 == 0) {
|
||||||
|
median = (all[size/2] + all[size/2 - 1])/2;
|
||||||
|
} else {
|
||||||
|
median = all[size/2];
|
||||||
|
}
|
||||||
|
return median;
|
||||||
|
}
|
||||||
|
|
||||||
|
private int getGpuCount(String gpuCombination) {
|
||||||
|
String[] tokens = gpuCombination.split("_");
|
||||||
|
return tokens.length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -20,6 +20,7 @@
|
|||||||
|
|
||||||
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.*;
|
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.*;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.TreeSet;
|
import java.util.TreeSet;
|
||||||
|
|
||||||
@ -62,7 +63,7 @@ public void onDevicesReleased(Set<Device> allocatedDevices) {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Set<Device> allocateDevices(Set<Device> availableDevices,
|
public Set<Device> allocateDevices(Set<Device> availableDevices,
|
||||||
int count) {
|
int count, Map<String, String> env) {
|
||||||
Set<Device> allocated = new TreeSet<Device>();
|
Set<Device> allocated = new TreeSet<Device>();
|
||||||
int number = 0;
|
int number = 0;
|
||||||
for (Device d : availableDevices) {
|
for (Device d : availableDevices) {
|
||||||
|
@ -74,10 +74,11 @@
|
|||||||
import java.util.TreeSet;
|
import java.util.TreeSet;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.ArgumentMatchers.anyBoolean;
|
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||||
import static org.mockito.ArgumentMatchers.anyInt;
|
import static org.mockito.ArgumentMatchers.anyInt;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyMap;
|
||||||
|
import static org.mockito.ArgumentMatchers.anySet;
|
||||||
import static org.mockito.ArgumentMatchers.anyString;
|
import static org.mockito.ArgumentMatchers.anyString;
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
import static org.mockito.ArgumentMatchers.isA;
|
import static org.mockito.ArgumentMatchers.isA;
|
||||||
@ -571,7 +572,7 @@ public void testPreferPluginScheduler() throws IOException, YarnException {
|
|||||||
adapter.getDeviceResourceHandler().preStart(c1);
|
adapter.getDeviceResourceHandler().preStart(c1);
|
||||||
// Use customized scheduler
|
// Use customized scheduler
|
||||||
verify(spyPlugin, times(1)).allocateDevices(
|
verify(spyPlugin, times(1)).allocateDevices(
|
||||||
any(Set.class), anyInt());
|
anySet(), anyInt(), anyMap());
|
||||||
Assert.assertEquals(2,
|
Assert.assertEquals(2,
|
||||||
dmm.getAvailableDevices(resourceName));
|
dmm.getAvailableDevices(resourceName));
|
||||||
Assert.assertEquals(1,
|
Assert.assertEquals(1,
|
||||||
@ -995,7 +996,7 @@ public void onDevicesReleased(Set<Device> releasedDevices) {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Set<Device> allocateDevices(Set<Device> availableDevices,
|
public Set<Device> allocateDevices(Set<Device> availableDevices,
|
||||||
int count) {
|
int count, Map<String, String> env) {
|
||||||
Set<Device> allocated = new TreeSet<>();
|
Set<Device> allocated = new TreeSet<>();
|
||||||
int number = 0;
|
int number = 0;
|
||||||
for (Device d : availableDevices) {
|
for (Device d : availableDevices) {
|
||||||
|
@ -1,108 +0,0 @@
|
|||||||
/**
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
* or more contributor license agreements. See the NOTICE file
|
|
||||||
* distributed with this work for additional information
|
|
||||||
* regarding copyright ownership. The ASF licenses this file
|
|
||||||
* to you under the Apache License, Version 2.0 (the
|
|
||||||
* "License"); you may not use this file except in compliance
|
|
||||||
* with the License. You may obtain a copy of the License at
|
|
||||||
* <p>
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
* <p>
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.nvidia.com;
|
|
||||||
|
|
||||||
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
|
|
||||||
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
|
|
||||||
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
|
|
||||||
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia.NvidiaGPUPluginForRuntimeV2;
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.TreeSet;
|
|
||||||
|
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Test case for Nvidia GPU device plugin.
|
|
||||||
* */
|
|
||||||
public class TestNvidiaGpuPlugin {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testGetNvidiaDevices() throws Exception {
|
|
||||||
NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor mockShell =
|
|
||||||
mock(NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor.class);
|
|
||||||
String deviceInfoShellOutput =
|
|
||||||
"0, 00000000:04:00.0\n" +
|
|
||||||
"1, 00000000:82:00.0";
|
|
||||||
String majorMinorNumber0 = "c3:0";
|
|
||||||
String majorMinorNumber1 = "c3:1";
|
|
||||||
when(mockShell.getDeviceInfo()).thenReturn(deviceInfoShellOutput);
|
|
||||||
when(mockShell.getMajorMinorInfo("nvidia0"))
|
|
||||||
.thenReturn(majorMinorNumber0);
|
|
||||||
when(mockShell.getMajorMinorInfo("nvidia1"))
|
|
||||||
.thenReturn(majorMinorNumber1);
|
|
||||||
NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
|
|
||||||
plugin.setShellExecutor(mockShell);
|
|
||||||
plugin.setPathOfGpuBinary("/fake/nvidia-smi");
|
|
||||||
|
|
||||||
Set<Device> expectedDevices = new TreeSet<>();
|
|
||||||
expectedDevices.add(Device.Builder.newInstance()
|
|
||||||
.setId(0).setHealthy(true)
|
|
||||||
.setBusID("00000000:04:00.0")
|
|
||||||
.setDevPath("/dev/nvidia0")
|
|
||||||
.setMajorNumber(195)
|
|
||||||
.setMinorNumber(0).build());
|
|
||||||
expectedDevices.add(Device.Builder.newInstance()
|
|
||||||
.setId(1).setHealthy(true)
|
|
||||||
.setBusID("00000000:82:00.0")
|
|
||||||
.setDevPath("/dev/nvidia1")
|
|
||||||
.setMajorNumber(195)
|
|
||||||
.setMinorNumber(1).build());
|
|
||||||
Set<Device> devices = plugin.getDevices();
|
|
||||||
Assert.assertEquals(expectedDevices, devices);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testOnDeviceAllocated() throws Exception {
|
|
||||||
NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
|
|
||||||
Set<Device> allocatedDevices = new TreeSet<>();
|
|
||||||
|
|
||||||
DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices,
|
|
||||||
YarnRuntimeType.RUNTIME_DEFAULT);
|
|
||||||
Assert.assertNull(spec);
|
|
||||||
|
|
||||||
// allocate one device
|
|
||||||
allocatedDevices.add(Device.Builder.newInstance()
|
|
||||||
.setId(0).setHealthy(true)
|
|
||||||
.setBusID("00000000:04:00.0")
|
|
||||||
.setDevPath("/dev/nvidia0")
|
|
||||||
.setMajorNumber(195)
|
|
||||||
.setMinorNumber(0).build());
|
|
||||||
spec = plugin.onDevicesAllocated(allocatedDevices,
|
|
||||||
YarnRuntimeType.RUNTIME_DOCKER);
|
|
||||||
Assert.assertEquals("nvidia", spec.getContainerRuntime());
|
|
||||||
Assert.assertEquals("0", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));
|
|
||||||
|
|
||||||
// two device allowed
|
|
||||||
allocatedDevices.add(Device.Builder.newInstance()
|
|
||||||
.setId(0).setHealthy(true)
|
|
||||||
.setBusID("00000000:82:00.0")
|
|
||||||
.setDevPath("/dev/nvidia1")
|
|
||||||
.setMajorNumber(195)
|
|
||||||
.setMinorNumber(1).build());
|
|
||||||
spec = plugin.onDevicesAllocated(allocatedDevices,
|
|
||||||
YarnRuntimeType.RUNTIME_DOCKER);
|
|
||||||
Assert.assertEquals("nvidia", spec.getContainerRuntime());
|
|
||||||
Assert.assertEquals("0,1", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user