SUBMARINE-54. Add test coverage for YarnServiceJobSubmitter and make it ready for extension for PyTorch. Contributed by Szilard Nemeth.

This commit is contained in:
Zhankun Tang 2019-04-25 12:52:24 +08:00
parent afe6613ee6
commit 0b3d41bdee
55 changed files with 5614 additions and 1829 deletions

View File

@ -293,10 +293,18 @@ public String getPsDockerImage() {
return psDockerImage;
}
public void setPsDockerImage(String psDockerImage) {
this.psDockerImage = psDockerImage;
}
public String getWorkerDockerImage() {
return workerDockerImage;
}
public void setWorkerDockerImage(String workerDockerImage) {
this.workerDockerImage = workerDockerImage;
}
public boolean isDistributed() {
return distributed;
}
@ -313,6 +321,10 @@ public String getTensorboardDockerImage() {
return tensorboardDockerImage;
}
public void setTensorboardDockerImage(String tensorboardDockerImage) {
this.tensorboardDockerImage = tensorboardDockerImage;
}
public List<Quicklink> getQuicklinks() {
return quicklinks;
}
@ -366,6 +378,10 @@ public RunJobParameters setConfPairs(List<String> confPairs) {
return this;
}
public void setDistributed(boolean distributed) {
this.distributed = distributed;
}
@VisibleForTesting
public static class UnderscoreConverterPropertyUtils extends PropertyUtils {
@Override

View File

@ -177,6 +177,25 @@ public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
Assert.assertTrue(success);
}
@Test
public void testJobWithoutName() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
String expectedErrorMessage =
"--" + CliConstants.NAME + " is absent";
String actualMessage = "";
try {
runJobCli.run(
new String[]{"--docker_image", "tf-docker:1.1.0",
"--num_workers", "0", "--tensorboard", "--verbose",
"--tensorboard_resources", "memory=2G,vcores=2",
"--tensorboard_docker_image", "tb_docker_image:001"});
} catch (ParseException e) {
actualMessage = e.getMessage();
e.printStackTrace();
}
assertEquals(expectedErrorMessage, actualMessage);
}
@Test
public void testLaunchCommandPatternReplace() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());

View File

@ -26,6 +26,7 @@
import java.io.File;
import java.io.IOException;
import java.util.Objects;
public class MockRemoteDirectoryManager implements RemoteDirectoryManager {
private File jobsParentDir = null;
@ -35,6 +36,7 @@ public class MockRemoteDirectoryManager implements RemoteDirectoryManager {
@Override
public Path getJobStagingArea(String jobName, boolean create)
throws IOException {
Objects.requireNonNull(jobName, "Job name must not be null!");
if (jobsParentDir == null && create) {
jobsParentDir = new File(
"target/_staging_area_" + System.currentTimeMillis());

View File

@ -115,6 +115,12 @@
<artifactId>hadoop-yarn-services-core</artifactId>
<version>3.3.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-yarn-common</artifactId>
<type>test-jar</type>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import java.io.IOException;
import java.util.Objects;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getScriptFileName;
/**
* Abstract base class for Component classes.
* The implementations of this class are act like factories for
* {@link Component} instances.
* All dependencies are passed to the constructor so that child classes
* are obliged to provide matching constructors.
*/
public abstract class AbstractComponent {
private final FileSystemOperations fsOperations;
protected final RunJobParameters parameters;
protected final TaskType taskType;
private final RemoteDirectoryManager remoteDirectoryManager;
protected final Configuration yarnConfig;
private final LaunchCommandFactory launchCommandFactory;
/**
* This is only required for testing.
*/
private String localScriptFile;
public AbstractComponent(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager,
RunJobParameters parameters, TaskType taskType,
Configuration yarnConfig,
LaunchCommandFactory launchCommandFactory) {
this.fsOperations = fsOperations;
this.remoteDirectoryManager = remoteDirectoryManager;
this.parameters = parameters;
this.taskType = taskType;
this.launchCommandFactory = launchCommandFactory;
this.yarnConfig = yarnConfig;
}
protected abstract Component createComponent() throws IOException;
/**
* Generates a command launch script on local disk,
* returns path to the script.
*/
protected void generateLaunchCommand(Component component)
throws IOException {
AbstractLaunchCommand launchCommand =
launchCommandFactory.createLaunchCommand(taskType, component);
this.localScriptFile = launchCommand.generateLaunchScript();
String remoteLaunchCommand = uploadLaunchCommand(component);
component.setLaunchCommand(remoteLaunchCommand);
}
private String uploadLaunchCommand(Component component)
throws IOException {
Objects.requireNonNull(localScriptFile, "localScriptFile should be " +
"set before calling this method!");
Path stagingDir =
remoteDirectoryManager.getJobStagingArea(parameters.getName(), true);
String destScriptFileName = getScriptFileName(taskType);
fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
localScriptFile, destScriptFileName, component);
return "./" + destScriptFileName;
}
String getLocalScriptFile() {
return localScriptFile;
}
}

View File

@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import com.google.common.annotations.VisibleForTesting;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.ConfigFile;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineConfiguration;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.utils.ZipUtilities;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
/**
* Contains methods to perform file system operations. Almost all of the methods
* are regular non-static methods as the operations are performed with the help
* of a {@link RemoteDirectoryManager} instance passed in as a constructor
* dependency. Please note that some operations require to read config settings
* as well, so that we have Submarine and YARN config objects as dependencies as
* well.
*/
public class FileSystemOperations {
private static final Logger LOG =
LoggerFactory.getLogger(FileSystemOperations.class);
private final Configuration submarineConfig;
private final Configuration yarnConfig;
private Set<Path> uploadedFiles = new HashSet<>();
private RemoteDirectoryManager remoteDirectoryManager;
public FileSystemOperations(ClientContext clientContext) {
this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
this.submarineConfig = clientContext.getSubmarineConfig();
this.yarnConfig = clientContext.getYarnConfig();
}
/**
* May download a remote uri(file/dir) and zip.
* Skip download if local dir
* Remote uri can be a local dir(won't download)
* or remote HDFS dir, s3 dir/file .etc
* */
public String downloadAndZip(String remoteDir, String zipFileName,
boolean doZip)
throws IOException {
//Append original modification time and size to zip file name
String suffix;
String srcDir = remoteDir;
String zipDirPath =
System.getProperty("java.io.tmpdir") + "/" + zipFileName;
boolean needDeleteTempDir = false;
if (remoteDirectoryManager.isRemote(remoteDir)) {
//Append original modification time and size to zip file name
FileStatus status =
remoteDirectoryManager.getRemoteFileStatus(new Path(remoteDir));
suffix = "_" + status.getModificationTime()
+ "-" + remoteDirectoryManager.getRemoteFileSize(remoteDir);
// Download them to temp dir
boolean downloaded =
remoteDirectoryManager.copyRemoteToLocal(remoteDir, zipDirPath);
if (!downloaded) {
throw new IOException("Failed to download files from "
+ remoteDir);
}
LOG.info("Downloaded remote: {} to local: {}", remoteDir, zipDirPath);
srcDir = zipDirPath;
needDeleteTempDir = true;
} else {
File localDir = new File(remoteDir);
suffix = "_" + localDir.lastModified()
+ "-" + localDir.length();
}
if (!doZip) {
return srcDir;
}
// zip a local dir
String zipFileUri =
ZipUtilities.zipDir(srcDir, zipDirPath + suffix + ".zip");
// delete downloaded temp dir
if (needDeleteTempDir) {
deleteFiles(srcDir);
}
return zipFileUri;
}
public void deleteFiles(String localUri) {
boolean success = FileUtil.fullyDelete(new File(localUri));
if (!success) {
LOG.warn("Failed to delete {}", localUri);
}
LOG.info("Deleted {}", localUri);
}
@VisibleForTesting
public void uploadToRemoteFileAndLocalizeToContainerWorkDir(Path stagingDir,
String fileToUpload, String destFilename, Component comp)
throws IOException {
Path uploadedFilePath = uploadToRemoteFile(stagingDir, fileToUpload);
locateRemoteFileToContainerWorkDir(destFilename, comp, uploadedFilePath);
}
private void locateRemoteFileToContainerWorkDir(String destFilename,
Component comp, Path uploadedFilePath)
throws IOException {
FileSystem fs = FileSystem.get(yarnConfig);
FileStatus fileStatus = fs.getFileStatus(uploadedFilePath);
LOG.info("Uploaded file path = " + fileStatus.getPath());
// Set it to component's files list
comp.getConfiguration().getFiles().add(new ConfigFile().srcFile(
fileStatus.getPath().toUri().toString()).destFile(destFilename)
.type(ConfigFile.TypeEnum.STATIC));
}
public Path uploadToRemoteFile(Path stagingDir, String fileToUpload) throws
IOException {
FileSystem fs = remoteDirectoryManager.getDefaultFileSystem();
// Upload to remote FS under staging area
File localFile = new File(fileToUpload);
if (!localFile.exists()) {
throw new FileNotFoundException(
"Trying to upload file=" + localFile.getAbsolutePath()
+ " to remote, but couldn't find local file.");
}
String filename = new File(fileToUpload).getName();
Path uploadedFilePath = new Path(stagingDir, filename);
if (!uploadedFiles.contains(uploadedFilePath)) {
if (SubmarineLogs.isVerbose()) {
LOG.info("Copying local file=" + fileToUpload + " to remote="
+ uploadedFilePath);
}
fs.copyFromLocalFile(new Path(fileToUpload), uploadedFilePath);
uploadedFiles.add(uploadedFilePath);
}
return uploadedFilePath;
}
public void validFileSize(String uri) throws IOException {
long actualSizeByte;
String locationType = "Local";
if (remoteDirectoryManager.isRemote(uri)) {
actualSizeByte = remoteDirectoryManager.getRemoteFileSize(uri);
locationType = "Remote";
} else {
actualSizeByte = FileUtil.getDU(new File(uri));
}
long maxFileSizeMB = submarineConfig
.getLong(SubmarineConfiguration.LOCALIZATION_MAX_ALLOWED_FILE_SIZE_MB,
SubmarineConfiguration.DEFAULT_MAX_ALLOWED_REMOTE_URI_SIZE_MB);
LOG.info("{} fie/dir: {}, size(Byte):{},"
+ " Allowed max file/dir size: {}",
locationType, uri, actualSizeByte, maxFileSizeMB * 1024 * 1024);
if (actualSizeByte > maxFileSizeMB * 1024 * 1024) {
throw new IOException(uri + " size(Byte): "
+ actualSizeByte + " exceeds configured max size:"
+ maxFileSizeMB * 1024 * 1024);
}
}
public void setPermission(Path destPath, FsPermission permission) throws
IOException {
FileSystem fs = FileSystem.get(yarnConfig);
fs.setPermission(destPath, new FsPermission(permission));
}
public static boolean needHdfs(String content) {
return content != null && content.contains("hdfs://");
}
}

View File

@ -0,0 +1,161 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations.needHdfs;
import static org.apache.hadoop.yarn.submarine.utils.ClassPathUtilities.findFileOnClassPath;
import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.getValueOfEnvironment;
/**
* This class contains helper methods to fill HDFS and Java environment
* variables into scripts.
*/
public class HadoopEnvironmentSetup {
private static final Logger LOG =
LoggerFactory.getLogger(HadoopEnvironmentSetup.class);
private static final String CORE_SITE_XML = "core-site.xml";
private static final String HDFS_SITE_XML = "hdfs-site.xml";
public static final String DOCKER_HADOOP_HDFS_HOME =
"DOCKER_HADOOP_HDFS_HOME";
public static final String DOCKER_JAVA_HOME = "DOCKER_JAVA_HOME";
private final RemoteDirectoryManager remoteDirectoryManager;
private final FileSystemOperations fsOperations;
public HadoopEnvironmentSetup(ClientContext clientContext,
FileSystemOperations fsOperations) {
this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
this.fsOperations = fsOperations;
}
public void addHdfsClassPath(RunJobParameters parameters,
PrintWriter fw, Component comp) throws IOException {
// Find envs to use HDFS
String hdfsHome = null;
String javaHome = null;
boolean hadoopEnv = false;
for (String envVar : parameters.getEnvars()) {
if (envVar.startsWith(DOCKER_HADOOP_HDFS_HOME + "=")) {
hdfsHome = getValueOfEnvironment(envVar);
hadoopEnv = true;
} else if (envVar.startsWith(DOCKER_JAVA_HOME + "=")) {
javaHome = getValueOfEnvironment(envVar);
}
}
boolean hasHdfsEnvs = hdfsHome != null && javaHome != null;
boolean needHdfs = doesNeedHdfs(parameters, hadoopEnv);
if (needHdfs) {
// HDFS is asked either in input or output, set LD_LIBRARY_PATH
// and classpath
if (hdfsHome != null) {
appendHdfsHome(fw, hdfsHome);
}
// hadoop confs will be uploaded to HDFS and localized to container's
// local folder, so here set $HADOOP_CONF_DIR to $WORK_DIR.
fw.append("export HADOOP_CONF_DIR=$WORK_DIR\n");
if (javaHome != null) {
appendJavaHome(fw, javaHome);
}
fw.append(
"export CLASSPATH=`$HADOOP_HDFS_HOME/bin/hadoop classpath --glob`\n");
}
if (needHdfs && !hasHdfsEnvs) {
LOG.error("When HDFS is being used to read/write models/data, " +
"the following environment variables are required: " +
"1) {}=<HDFS_HOME inside docker container> " +
"2) {}=<JAVA_HOME inside docker container>. " +
"You can use --env to pass these environment variables.",
DOCKER_HADOOP_HDFS_HOME, DOCKER_JAVA_HOME);
throw new IOException("Failed to detect HDFS-related environments.");
}
// Trying to upload core-site.xml and hdfs-site.xml
Path stagingDir =
remoteDirectoryManager.getJobStagingArea(
parameters.getName(), true);
File coreSite = findFileOnClassPath(CORE_SITE_XML);
File hdfsSite = findFileOnClassPath(HDFS_SITE_XML);
if (coreSite == null || hdfsSite == null) {
LOG.error("HDFS is being used, however we could not locate " +
"{} nor {} on classpath! " +
"Please double check your classpath setting and make sure these " +
"setting files are included!", CORE_SITE_XML, HDFS_SITE_XML);
throw new IOException(
"Failed to locate core-site.xml / hdfs-site.xml on classpath!");
}
fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
coreSite.getAbsolutePath(), CORE_SITE_XML, comp);
fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
hdfsSite.getAbsolutePath(), HDFS_SITE_XML, comp);
// DEBUG
if (SubmarineLogs.isVerbose()) {
appendEchoOfEnvVars(fw);
}
}
private boolean doesNeedHdfs(RunJobParameters parameters, boolean hadoopEnv) {
return needHdfs(parameters.getInputPath()) ||
needHdfs(parameters.getPSLaunchCmd()) ||
needHdfs(parameters.getWorkerLaunchCmd()) ||
hadoopEnv;
}
private void appendHdfsHome(PrintWriter fw, String hdfsHome) {
// Unset HADOOP_HOME/HADOOP_YARN_HOME to make sure host machine's envs
// won't pollute docker's env.
fw.append("export HADOOP_HOME=\n");
fw.append("export HADOOP_YARN_HOME=\n");
fw.append("export HADOOP_HDFS_HOME=" + hdfsHome + "\n");
fw.append("export HADOOP_COMMON_HOME=" + hdfsHome + "\n");
}
private void appendJavaHome(PrintWriter fw, String javaHome) {
fw.append("export JAVA_HOME=" + javaHome + "\n");
fw.append("export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"
+ "$JAVA_HOME/lib/amd64/server\n");
}
private void appendEchoOfEnvVars(PrintWriter fw) {
fw.append("echo \"CLASSPATH:$CLASSPATH\"\n");
fw.append("echo \"HADOOP_CONF_DIR:$HADOOP_CONF_DIR\"\n");
fw.append(
"echo \"HADOOP_TOKEN_FILE_LOCATION:$HADOOP_TOKEN_FILE_LOCATION\"\n");
fw.append("echo \"JAVA_HOME:$JAVA_HOME\"\n");
fw.append("echo \"LD_LIBRARY_PATH:$LD_LIBRARY_PATH\"\n");
fw.append("echo \"HADOOP_HDFS_HOME:$HADOOP_HDFS_HOME\"\n");
}
}

View File

@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import java.io.IOException;
/**
* This interface is to provide means of creating wrappers around
* {@link org.apache.hadoop.yarn.service.api.records.Service} instances.
*/
public interface ServiceSpec {
ServiceWrapper create() throws IOException;
}

View File

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import org.apache.hadoop.yarn.service.api.records.Service;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import static org.apache.hadoop.yarn.service.utils.ServiceApiUtil.jsonSerDeser;
/**
* This class is merely responsible for creating Json representation of
* {@link Service} instances.
*/
public final class ServiceSpecFileGenerator {
private ServiceSpecFileGenerator() {
throw new UnsupportedOperationException("This class should not be " +
"instantiated!");
}
static String generateJson(Service service) throws IOException {
File serviceSpecFile = File.createTempFile(service.getName(), ".json");
String buffer = jsonSerDeser.toJson(service);
Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile),
StandardCharsets.UTF_8);
try (PrintWriter pw = new PrintWriter(w)) {
pw.append(buffer);
}
return serviceSpecFile.getAbsolutePath();
}
}

View File

@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.Service;
import java.io.IOException;
import java.util.Map;
/**
* This class is only existing because we need a component name to
* local launch command mapping from the test code.
* Once this is solved in more clean or different way, we can delete this class.
*/
public class ServiceWrapper {
private final Service service;
@VisibleForTesting
private Map<String, String> componentToLocalLaunchCommand = Maps.newHashMap();
public ServiceWrapper(Service service) {
this.service = service;
}
public void addComponent(AbstractComponent abstractComponent)
throws IOException {
Component component = abstractComponent.createComponent();
service.addComponent(component);
storeComponentName(abstractComponent, component.getName());
}
private void storeComponentName(
AbstractComponent component, String name) {
componentToLocalLaunchCommand.put(name,
component.getLocalScriptFile());
}
public Service getService() {
return service;
}
public String getLocalLaunchCommandPathForComponent(String componentName) {
return componentToLocalLaunchCommand.get(componentName);
}
}

View File

@ -15,858 +15,59 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.client.api.AppAdminClient;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
import org.apache.hadoop.yarn.service.api.records.Artifact;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.ConfigFile;
import org.apache.hadoop.yarn.service.api.records.Resource;
import org.apache.hadoop.yarn.service.api.records.ResourceInformation;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
import org.apache.hadoop.yarn.service.utils.ServiceApiUtil;
import org.apache.hadoop.yarn.submarine.client.cli.param.Localization;
import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.Envs;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineConfiguration;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowServiceSpec;
import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.StringTokenizer;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION;
import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants
.CONTAINER_STATE_REPORT_AS_SERVICE_STATE;
import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS;
import static org.apache.hadoop.yarn.service.utils.ServiceApiUtil.jsonSerDeser;
/**
* Submit a job to cluster
* Submit a job to cluster.
*/
public class YarnServiceJobSubmitter implements JobSubmitter {
public static final String TENSORBOARD_QUICKLINK_LABEL = "Tensorboard";
private static final Logger LOG =
LoggerFactory.getLogger(YarnServiceJobSubmitter.class);
ClientContext clientContext;
Service serviceSpec;
private Set<Path> uploadedFiles = new HashSet<>();
private ClientContext clientContext;
private ServiceWrapper serviceWrapper;
// Used by testing
private Map<String, String> componentToLocalLaunchScriptPath =
new HashMap<>();
public YarnServiceJobSubmitter(ClientContext clientContext) {
YarnServiceJobSubmitter(ClientContext clientContext) {
this.clientContext = clientContext;
}
private Resource getServiceResourceFromYarnResource(
org.apache.hadoop.yarn.api.records.Resource yarnResource) {
Resource serviceResource = new Resource();
serviceResource.setCpus(yarnResource.getVirtualCores());
serviceResource.setMemory(String.valueOf(yarnResource.getMemorySize()));
Map<String, ResourceInformation> riMap = new HashMap<>();
for (org.apache.hadoop.yarn.api.records.ResourceInformation ri : yarnResource
.getAllResourcesListCopy()) {
ResourceInformation serviceRi =
new ResourceInformation();
serviceRi.setValue(ri.getValue());
serviceRi.setUnit(ri.getUnits());
riMap.put(ri.getName(), serviceRi);
}
serviceResource.setResourceInformations(riMap);
return serviceResource;
}
private String getValueOfEnvironment(String envar) {
// extract value from "key=value" form
if (envar == null || !envar.contains("=")) {
return "";
} else {
return envar.substring(envar.indexOf("=") + 1);
}
}
private boolean needHdfs(String content) {
return content != null && content.contains("hdfs://");
}
private void addHdfsClassPathIfNeeded(RunJobParameters parameters,
PrintWriter fw, Component comp) throws IOException {
// Find envs to use HDFS
String hdfsHome = null;
String javaHome = null;
boolean hadoopEnv = false;
for (String envar : parameters.getEnvars()) {
if (envar.startsWith("DOCKER_HADOOP_HDFS_HOME=")) {
hdfsHome = getValueOfEnvironment(envar);
hadoopEnv = true;
} else if (envar.startsWith("DOCKER_JAVA_HOME=")) {
javaHome = getValueOfEnvironment(envar);
}
}
boolean lackingEnvs = false;
if (needHdfs(parameters.getInputPath()) || needHdfs(
parameters.getPSLaunchCmd()) || needHdfs(
parameters.getWorkerLaunchCmd()) || hadoopEnv) {
// HDFS is asked either in input or output, set LD_LIBRARY_PATH
// and classpath
if (hdfsHome != null) {
// Unset HADOOP_HOME/HADOOP_YARN_HOME to make sure host machine's envs
// won't pollute docker's env.
fw.append("export HADOOP_HOME=\n");
fw.append("export HADOOP_YARN_HOME=\n");
fw.append("export HADOOP_HDFS_HOME=" + hdfsHome + "\n");
fw.append("export HADOOP_COMMON_HOME=" + hdfsHome + "\n");
} else{
lackingEnvs = true;
}
// hadoop confs will be uploaded to HDFS and localized to container's
// local folder, so here set $HADOOP_CONF_DIR to $WORK_DIR.
fw.append("export HADOOP_CONF_DIR=$WORK_DIR\n");
if (javaHome != null) {
fw.append("export JAVA_HOME=" + javaHome + "\n");
fw.append("export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"
+ "$JAVA_HOME/lib/amd64/server\n");
} else {
lackingEnvs = true;
}
fw.append("export CLASSPATH=`$HADOOP_HDFS_HOME/bin/hadoop classpath --glob`\n");
}
if (lackingEnvs) {
LOG.error("When hdfs is being used to read/write models/data. Following"
+ "envs are required: 1) DOCKER_HADOOP_HDFS_HOME=<HDFS_HOME inside"
+ "docker container> 2) DOCKER_JAVA_HOME=<JAVA_HOME inside docker"
+ "container>. You can use --env to pass these envars.");
throw new IOException("Failed to detect HDFS-related environments.");
}
// Trying to upload core-site.xml and hdfs-site.xml
Path stagingDir =
clientContext.getRemoteDirectoryManager().getJobStagingArea(
parameters.getName(), true);
File coreSite = findFileOnClassPath("core-site.xml");
File hdfsSite = findFileOnClassPath("hdfs-site.xml");
if (coreSite == null || hdfsSite == null) {
LOG.error("hdfs is being used, however we couldn't locate core-site.xml/"
+ "hdfs-site.xml from classpath, please double check you classpath"
+ "setting and make sure they're included.");
throw new IOException(
"Failed to locate core-site.xml / hdfs-site.xml from class path");
}
uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
coreSite.getAbsolutePath(), "core-site.xml", comp);
uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
hdfsSite.getAbsolutePath(), "hdfs-site.xml", comp);
// DEBUG
if (SubmarineLogs.isVerbose()) {
fw.append("echo \"CLASSPATH:$CLASSPATH\"\n");
fw.append("echo \"HADOOP_CONF_DIR:$HADOOP_CONF_DIR\"\n");
fw.append("echo \"HADOOP_TOKEN_FILE_LOCATION:$HADOOP_TOKEN_FILE_LOCATION\"\n");
fw.append("echo \"JAVA_HOME:$JAVA_HOME\"\n");
fw.append("echo \"LD_LIBRARY_PATH:$LD_LIBRARY_PATH\"\n");
fw.append("echo \"HADOOP_HDFS_HOME:$HADOOP_HDFS_HOME\"\n");
}
}
private void addCommonEnvironments(Component component, TaskType taskType) {
Map<String, String> envs = component.getConfiguration().getEnv();
envs.put(Envs.TASK_INDEX_ENV, ServiceApiConstants.COMPONENT_ID);
envs.put(Envs.TASK_TYPE_ENV, taskType.name());
}
@VisibleForTesting
protected String getUserName() {
return System.getProperty("user.name");
}
private String getDNSDomain() {
return clientContext.getYarnConfig().get("hadoop.registry.dns.domain-name");
}
/*
* Generate a command launch script on local disk, returns patch to the script
*/
private String generateCommandLaunchScript(RunJobParameters parameters,
TaskType taskType, Component comp) throws IOException {
File file = File.createTempFile(taskType.name() + "-launch-script", ".sh");
Writer w = new OutputStreamWriter(new FileOutputStream(file),
StandardCharsets.UTF_8);
PrintWriter pw = new PrintWriter(w);
try {
pw.append("#!/bin/bash\n");
addHdfsClassPathIfNeeded(parameters, pw, comp);
if (taskType.equals(TaskType.TENSORBOARD)) {
String tbCommand =
"export LC_ALL=C && tensorboard --logdir=" + parameters
.getCheckpointPath();
pw.append(tbCommand + "\n");
LOG.info("Tensorboard command=" + tbCommand);
} else{
// When distributed training is required
if (parameters.isDistributed()) {
// Generated TF_CONFIG
String tfConfigEnv = YarnServiceUtils.getTFConfigEnv(
taskType.getComponentName(), parameters.getNumWorkers(),
parameters.getNumPS(), parameters.getName(), getUserName(),
getDNSDomain());
pw.append("export TF_CONFIG=\"" + tfConfigEnv + "\"\n");
}
// Print launch command
if (taskType.equals(TaskType.WORKER) || taskType.equals(
TaskType.PRIMARY_WORKER)) {
pw.append(parameters.getWorkerLaunchCmd() + '\n');
if (SubmarineLogs.isVerbose()) {
LOG.info(
"Worker command =[" + parameters.getWorkerLaunchCmd() + "]");
}
} else if (taskType.equals(TaskType.PS)) {
pw.append(parameters.getPSLaunchCmd() + '\n');
if (SubmarineLogs.isVerbose()) {
LOG.info("PS command =[" + parameters.getPSLaunchCmd() + "]");
}
}
}
} finally {
pw.close();
}
return file.getAbsolutePath();
}
private String getScriptFileName(TaskType taskType) {
return "run-" + taskType.name() + ".sh";
}
private File findFileOnClassPath(final String fileName) {
final String classpath = System.getProperty("java.class.path");
final String pathSeparator = System.getProperty("path.separator");
final StringTokenizer tokenizer = new StringTokenizer(classpath,
pathSeparator);
while (tokenizer.hasMoreTokens()) {
final String pathElement = tokenizer.nextToken();
final File directoryOrJar = new File(pathElement);
final File absoluteDirectoryOrJar = directoryOrJar.getAbsoluteFile();
if (absoluteDirectoryOrJar.isFile()) {
final File target = new File(absoluteDirectoryOrJar.getParent(),
fileName);
if (target.exists()) {
return target;
}
} else{
final File target = new File(directoryOrJar, fileName);
if (target.exists()) {
return target;
}
}
}
return null;
}
private void uploadToRemoteFileAndLocalizeToContainerWorkDir(Path stagingDir,
String fileToUpload, String destFilename, Component comp)
throws IOException {
Path uploadedFilePath = uploadToRemoteFile(stagingDir, fileToUpload);
locateRemoteFileToContainerWorkDir(destFilename, comp, uploadedFilePath);
}
private void locateRemoteFileToContainerWorkDir(String destFilename,
Component comp, Path uploadedFilePath)
throws IOException {
FileSystem fs = FileSystem.get(clientContext.getYarnConfig());
FileStatus fileStatus = fs.getFileStatus(uploadedFilePath);
LOG.info("Uploaded file path = " + fileStatus.getPath());
// Set it to component's files list
comp.getConfiguration().getFiles().add(new ConfigFile().srcFile(
fileStatus.getPath().toUri().toString()).destFile(destFilename)
.type(ConfigFile.TypeEnum.STATIC));
}
private Path uploadToRemoteFile(Path stagingDir, String fileToUpload) throws
IOException {
FileSystem fs = clientContext.getRemoteDirectoryManager()
.getDefaultFileSystem();
// Upload to remote FS under staging area
File localFile = new File(fileToUpload);
if (!localFile.exists()) {
throw new FileNotFoundException(
"Trying to upload file=" + localFile.getAbsolutePath()
+ " to remote, but couldn't find local file.");
}
String filename = new File(fileToUpload).getName();
Path uploadedFilePath = new Path(stagingDir, filename);
if (!uploadedFiles.contains(uploadedFilePath)) {
if (SubmarineLogs.isVerbose()) {
LOG.info("Copying local file=" + fileToUpload + " to remote="
+ uploadedFilePath);
}
fs.copyFromLocalFile(new Path(fileToUpload), uploadedFilePath);
uploadedFiles.add(uploadedFilePath);
}
return uploadedFilePath;
}
private void setPermission(Path destPath, FsPermission permission) throws
IOException {
FileSystem fs = FileSystem.get(clientContext.getYarnConfig());
fs.setPermission(destPath, new FsPermission(permission));
}
private void handleLaunchCommand(RunJobParameters parameters,
TaskType taskType, Component component) throws IOException {
// Get staging area directory
Path stagingDir =
clientContext.getRemoteDirectoryManager().getJobStagingArea(
parameters.getName(), true);
// Generate script file in the local disk
String localScriptFile = generateCommandLaunchScript(parameters, taskType,
component);
String destScriptFileName = getScriptFileName(taskType);
uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir, localScriptFile,
destScriptFileName, component);
component.setLaunchCommand("./" + destScriptFileName);
componentToLocalLaunchScriptPath.put(taskType.getComponentName(),
localScriptFile);
}
private String getLastNameFromPath(String srcFileStr) {
return new Path(srcFileStr).getName();
}
/**
* May download a remote uri(file/dir) and zip.
* Skip download if local dir
* Remote uri can be a local dir(won't download)
* or remote HDFS dir, s3 dir/file .etc
* */
private String mayDownloadAndZipIt(String remoteDir, String zipFileName,
boolean doZip)
throws IOException {
RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager();
//Append original modification time and size to zip file name
String suffix;
String srcDir = remoteDir;
String zipDirPath =
System.getProperty("java.io.tmpdir") + "/" + zipFileName;
boolean needDeleteTempDir = false;
if (rdm.isRemote(remoteDir)) {
//Append original modification time and size to zip file name
FileStatus status = rdm.getRemoteFileStatus(new Path(remoteDir));
suffix = "_" + status.getModificationTime()
+ "-" + rdm.getRemoteFileSize(remoteDir);
// Download them to temp dir
boolean downloaded = rdm.copyRemoteToLocal(remoteDir, zipDirPath);
if (!downloaded) {
throw new IOException("Failed to download files from "
+ remoteDir);
}
LOG.info("Downloaded remote: {} to local: {}", remoteDir, zipDirPath);
srcDir = zipDirPath;
needDeleteTempDir = true;
} else {
File localDir = new File(remoteDir);
suffix = "_" + localDir.lastModified()
+ "-" + localDir.length();
}
if (!doZip) {
return srcDir;
}
// zip a local dir
String zipFileUri = zipDir(srcDir, zipDirPath + suffix + ".zip");
// delete downloaded temp dir
if (needDeleteTempDir) {
deleteFiles(srcDir);
}
return zipFileUri;
}
@VisibleForTesting
public String zipDir(String srcDir, String dstFile) throws IOException {
FileOutputStream fos = new FileOutputStream(dstFile);
ZipOutputStream zos = new ZipOutputStream(fos);
File srcFile = new File(srcDir);
LOG.info("Compressing {}", srcDir);
addDirToZip(zos, srcFile, srcFile);
// close the ZipOutputStream
zos.close();
LOG.info("Compressed {} to {}", srcDir, dstFile);
return dstFile;
}
private void deleteFiles(String localUri) {
boolean success = FileUtil.fullyDelete(new File(localUri));
if (!success) {
LOG.warn("Fail to delete {}", localUri);
}
LOG.info("Deleted {}", localUri);
}
private void addDirToZip(ZipOutputStream zos, File srcFile, File base)
throws IOException {
File[] files = srcFile.listFiles();
if (null == files) {
return;
}
FileInputStream fis = null;
for (int i = 0; i < files.length; i++) {
// if it's directory, add recursively
if (files[i].isDirectory()) {
addDirToZip(zos, files[i], base);
continue;
}
byte[] buffer = new byte[1024];
try {
fis = new FileInputStream(files[i]);
String name = base.toURI().relativize(files[i].toURI()).getPath();
LOG.info(" Zip adding: " + name);
zos.putNextEntry(new ZipEntry(name));
int length;
while ((length = fis.read(buffer)) > 0) {
zos.write(buffer, 0, length);
}
zos.flush();
} finally {
if (fis != null) {
fis.close();
}
zos.closeEntry();
}
}
}
private void addWorkerComponent(Service service,
RunJobParameters parameters, TaskType taskType) throws IOException {
Component workerComponent = new Component();
addCommonEnvironments(workerComponent, taskType);
workerComponent.setName(taskType.getComponentName());
if (taskType.equals(TaskType.PRIMARY_WORKER)) {
workerComponent.setNumberOfContainers(1L);
workerComponent.getConfiguration().setProperty(
CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true");
} else{
workerComponent.setNumberOfContainers(
(long) parameters.getNumWorkers() - 1);
}
if (parameters.getWorkerDockerImage() != null) {
workerComponent.setArtifact(
getDockerArtifact(parameters.getWorkerDockerImage()));
}
workerComponent.setResource(
getServiceResourceFromYarnResource(parameters.getWorkerResource()));
handleLaunchCommand(parameters, taskType, workerComponent);
workerComponent.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
service.addComponent(workerComponent);
}
// Handle worker and primary_worker.
private void addWorkerComponents(Service service, RunJobParameters parameters)
throws IOException {
addWorkerComponent(service, parameters, TaskType.PRIMARY_WORKER);
if (parameters.getNumWorkers() > 1) {
addWorkerComponent(service, parameters, TaskType.WORKER);
}
}
private void appendToEnv(Service service, String key, String value,
String delim) {
Map<String, String> env = service.getConfiguration().getEnv();
if (!env.containsKey(key)) {
env.put(key, value);
} else {
if (!value.isEmpty()) {
String existingValue = env.get(key);
if (!existingValue.endsWith(delim)) {
env.put(key, existingValue + delim + value);
} else {
env.put(key, existingValue + value);
}
}
}
}
private void handleServiceEnvs(Service service, RunJobParameters parameters) {
if (parameters.getEnvars() != null) {
for (String envarPair : parameters.getEnvars()) {
String key, value;
if (envarPair.contains("=")) {
int idx = envarPair.indexOf('=');
key = envarPair.substring(0, idx);
value = envarPair.substring(idx + 1);
} else{
// No "=" found so use the whole key
key = envarPair;
value = "";
}
appendToEnv(service, key, value, ":");
}
}
// Append other configs like /etc/passwd, /etc/krb5.conf
appendToEnv(service, "YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS",
"/etc/passwd:/etc/passwd:ro", ",");
String authenication = clientContext.getYarnConfig().get(
HADOOP_SECURITY_AUTHENTICATION);
if (authenication != null && authenication.equals("kerberos")) {
appendToEnv(service, "YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS",
"/etc/krb5.conf:/etc/krb5.conf:ro", ",");
}
}
private Artifact getDockerArtifact(String dockerImageName) {
return new Artifact().type(Artifact.TypeEnum.DOCKER).id(dockerImageName);
}
private void handleQuicklinks(RunJobParameters runJobParameters)
throws IOException {
List<Quicklink> quicklinks = runJobParameters.getQuicklinks();
if (null != quicklinks && !quicklinks.isEmpty()) {
for (Quicklink ql : quicklinks) {
// Make sure it is a valid instance name
String instanceName = ql.getComponentInstanceName();
boolean found = false;
for (Component comp : serviceSpec.getComponents()) {
for (int i = 0; i < comp.getNumberOfContainers(); i++) {
String possibleInstanceName = comp.getName() + "-" + i;
if (possibleInstanceName.equals(instanceName)) {
found = true;
break;
}
}
}
if (!found) {
throw new IOException(
"Couldn't find a component instance = " + instanceName
+ " while adding quicklink");
}
String link = ql.getProtocol() + YarnServiceUtils.getDNSName(
serviceSpec.getName(), instanceName, getUserName(), getDNSDomain(),
ql.getPort());
YarnServiceUtils.addQuicklink(serviceSpec, ql.getLabel(), link);
}
}
}
private Service createServiceByParameters(RunJobParameters parameters)
throws IOException {
componentToLocalLaunchScriptPath.clear();
serviceSpec = new Service();
serviceSpec.setName(parameters.getName());
serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
handleKerberosPrincipal(parameters);
handleServiceEnvs(serviceSpec, parameters);
handleLocalizations(parameters);
if (parameters.getNumWorkers() > 0) {
addWorkerComponents(serviceSpec, parameters);
}
if (parameters.getNumPS() > 0) {
Component psComponent = new Component();
psComponent.setName(TaskType.PS.getComponentName());
addCommonEnvironments(psComponent, TaskType.PS);
psComponent.setNumberOfContainers((long) parameters.getNumPS());
psComponent.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
psComponent.setResource(
getServiceResourceFromYarnResource(parameters.getPsResource()));
// Override global docker image if needed.
if (parameters.getPsDockerImage() != null) {
psComponent.setArtifact(
getDockerArtifact(parameters.getPsDockerImage()));
}
handleLaunchCommand(parameters, TaskType.PS, psComponent);
serviceSpec.addComponent(psComponent);
}
if (parameters.isTensorboardEnabled()) {
Component tbComponent = new Component();
tbComponent.setName(TaskType.TENSORBOARD.getComponentName());
addCommonEnvironments(tbComponent, TaskType.TENSORBOARD);
tbComponent.setNumberOfContainers(1L);
tbComponent.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
tbComponent.setResource(getServiceResourceFromYarnResource(
parameters.getTensorboardResource()));
if (parameters.getTensorboardDockerImage() != null) {
tbComponent.setArtifact(
getDockerArtifact(parameters.getTensorboardDockerImage()));
}
handleLaunchCommand(parameters, TaskType.TENSORBOARD, tbComponent);
// Add tensorboard to quicklink
String tensorboardLink = "http://" + YarnServiceUtils.getDNSName(
parameters.getName(),
TaskType.TENSORBOARD.getComponentName() + "-" + 0, getUserName(),
getDNSDomain(), 6006);
LOG.info("Link to tensorboard:" + tensorboardLink);
serviceSpec.addComponent(tbComponent);
YarnServiceUtils.addQuicklink(serviceSpec, TENSORBOARD_QUICKLINK_LABEL,
tensorboardLink);
}
// After all components added, handle quicklinks
handleQuicklinks(parameters);
return serviceSpec;
}
/**
* Localize dependencies for all containers.
* If remoteUri is a local directory,
* we'll zip it, upload to HDFS staging dir HDFS.
* If remoteUri is directory, we'll download it, zip it and upload
* to HDFS.
* If localFilePath is ".", we'll use remoteUri's file/dir name
* */
private void handleLocalizations(RunJobParameters parameters)
throws IOException {
// Handle localizations
Path stagingDir =
clientContext.getRemoteDirectoryManager().getJobStagingArea(
parameters.getName(), true);
List<Localization> locs = parameters.getLocalizations();
String remoteUri;
String containerLocalPath;
RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager();
// Check to fail fast
for (Localization loc : locs) {
remoteUri = loc.getRemoteUri();
Path resourceToLocalize = new Path(remoteUri);
// Check if remoteUri exists
if (rdm.isRemote(remoteUri)) {
// check if exists
if (!rdm.existsRemoteFile(resourceToLocalize)) {
throw new FileNotFoundException(
"File " + remoteUri + " doesn't exists.");
}
} else {
// Check if exists
File localFile = new File(remoteUri);
if (!localFile.exists()) {
throw new FileNotFoundException(
"File " + remoteUri + " doesn't exists.");
}
}
// check remote file size
validFileSize(remoteUri);
}
// Start download remote if needed and upload to HDFS
for (Localization loc : locs) {
remoteUri = loc.getRemoteUri();
containerLocalPath = loc.getLocalPath();
String srcFileStr = remoteUri;
ConfigFile.TypeEnum destFileType = ConfigFile.TypeEnum.STATIC;
Path resourceToLocalize = new Path(remoteUri);
boolean needUploadToHDFS = true;
/**
* Special handling for remoteUri directory.
* */
boolean needDeleteTempFile = false;
if (rdm.isDir(remoteUri)) {
destFileType = ConfigFile.TypeEnum.ARCHIVE;
srcFileStr = mayDownloadAndZipIt(
remoteUri, getLastNameFromPath(srcFileStr), true);
} else if (rdm.isRemote(remoteUri)) {
if (!needHdfs(remoteUri)) {
// Non HDFS remote uri. Non directory, no need to zip
srcFileStr = mayDownloadAndZipIt(
remoteUri, getLastNameFromPath(srcFileStr), false);
needDeleteTempFile = true;
} else {
// HDFS file, no need to upload
needUploadToHDFS = false;
}
}
// Upload file to HDFS
if (needUploadToHDFS) {
resourceToLocalize = uploadToRemoteFile(stagingDir, srcFileStr);
}
if (needDeleteTempFile) {
deleteFiles(srcFileStr);
}
// Remove .zip from zipped dir name
if (destFileType == ConfigFile.TypeEnum.ARCHIVE
&& srcFileStr.endsWith(".zip")) {
// Delete local zip file
deleteFiles(srcFileStr);
int suffixIndex = srcFileStr.lastIndexOf('_');
srcFileStr = srcFileStr.substring(0, suffixIndex);
}
// If provided, use the name of local uri
if (!containerLocalPath.equals(".")
&& !containerLocalPath.equals("./")) {
// Change the YARN localized file name to what'll used in container
srcFileStr = getLastNameFromPath(containerLocalPath);
}
String localizedName = getLastNameFromPath(srcFileStr);
LOG.info("The file/dir to be localized is {}",
resourceToLocalize.toString());
LOG.info("Its localized file name will be {}", localizedName);
serviceSpec.getConfiguration().getFiles().add(new ConfigFile().srcFile(
resourceToLocalize.toUri().toString()).destFile(localizedName)
.type(destFileType));
// set mounts
// if mount path is absolute, just use it.
// if relative, no need to mount explicitly
if (containerLocalPath.startsWith("/")) {
String mountStr = getLastNameFromPath(srcFileStr) + ":"
+ containerLocalPath + ":" + loc.getMountPermission();
LOG.info("Add bind-mount string {}", mountStr);
appendToEnv(serviceSpec, "YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS",
mountStr, ",");
}
}
}
private void validFileSize(String uri) throws IOException {
RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager();
long actualSizeByte;
String locationType = "Local";
if (rdm.isRemote(uri)) {
actualSizeByte = clientContext.getRemoteDirectoryManager()
.getRemoteFileSize(uri);
locationType = "Remote";
} else {
actualSizeByte = FileUtil.getDU(new File(uri));
}
long maxFileSizeMB = clientContext.getSubmarineConfig()
.getLong(SubmarineConfiguration.LOCALIZATION_MAX_ALLOWED_FILE_SIZE_MB,
SubmarineConfiguration.DEFAULT_MAX_ALLOWED_REMOTE_URI_SIZE_MB);
LOG.info("{} fie/dir: {}, size(Byte):{},"
+ " Allowed max file/dir size: {}",
locationType, uri, actualSizeByte, maxFileSizeMB * 1024 * 1024);
if (actualSizeByte > maxFileSizeMB * 1024 * 1024) {
throw new IOException(uri + " size(Byte): "
+ actualSizeByte + " exceeds configured max size:"
+ maxFileSizeMB * 1024 * 1024);
}
}
private String generateServiceSpecFile(Service service) throws IOException {
File serviceSpecFile = File.createTempFile(service.getName(), ".json");
String buffer = jsonSerDeser.toJson(service);
Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile),
"UTF-8");
PrintWriter pw = new PrintWriter(w);
try {
pw.append(buffer);
} finally {
pw.close();
}
return serviceSpecFile.getAbsolutePath();
}
private void handleKerberosPrincipal(RunJobParameters parameters) throws
IOException {
if(StringUtils.isNotBlank(parameters.getKeytab()) && StringUtils
.isNotBlank(parameters.getPrincipal())) {
String keytab = parameters.getKeytab();
String principal = parameters.getPrincipal();
if(parameters.isDistributeKeytab()) {
Path stagingDir =
clientContext.getRemoteDirectoryManager().getJobStagingArea(
parameters.getName(), true);
Path remoteKeytabPath = uploadToRemoteFile(stagingDir, keytab);
//only the owner has read access
setPermission(remoteKeytabPath,
FsPermission.createImmutable((short)Integer.parseInt("400", 8)));
serviceSpec.setKerberosPrincipal(new KerberosPrincipal().keytab(
remoteKeytabPath.toString()).principalName(principal));
} else {
if(!keytab.startsWith("file")) {
keytab = "file://" + keytab;
}
serviceSpec.setKerberosPrincipal(new KerberosPrincipal().keytab(
keytab).principalName(principal));
}
}
}
/**
* {@inheritDoc}
*/
@Override
public ApplicationId submitJob(RunJobParameters parameters)
throws IOException, YarnException {
createServiceByParameters(parameters);
String serviceSpecFile = generateServiceSpecFile(serviceSpec);
FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
HadoopEnvironmentSetup hadoopEnvSetup =
new HadoopEnvironmentSetup(clientContext, fsOperations);
AppAdminClient appAdminClient = YarnServiceUtils.createServiceClient(
clientContext.getYarnConfig());
Service serviceSpec = createTensorFlowServiceSpec(parameters,
fsOperations, hadoopEnvSetup);
String serviceSpecFile = ServiceSpecFileGenerator.generateJson(serviceSpec);
AppAdminClient appAdminClient =
YarnServiceUtils.createServiceClient(clientContext.getYarnConfig());
int code = appAdminClient.actionLaunch(serviceSpecFile,
serviceSpec.getName(), null, null);
if(code != EXIT_SUCCESS) {
throw new YarnException("Fail to launch application with exit code:" +
code);
if (code != EXIT_SUCCESS) {
throw new YarnException(
"Fail to launch application with exit code:" + code);
}
String appStatus=appAdminClient.getStatusString(serviceSpec.getName());
@ -896,13 +97,24 @@ public ApplicationId submitJob(RunJobParameters parameters)
return appid;
}
@VisibleForTesting
public Service getServiceSpec() {
return serviceSpec;
private Service createTensorFlowServiceSpec(RunJobParameters parameters,
FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
throws IOException {
LaunchCommandFactory launchCommandFactory =
new LaunchCommandFactory(hadoopEnvSetup, parameters,
clientContext.getYarnConfig());
Localizer localizer = new Localizer(fsOperations,
clientContext.getRemoteDirectoryManager(), parameters);
TensorFlowServiceSpec tensorFlowServiceSpec = new TensorFlowServiceSpec(
parameters, this.clientContext, fsOperations, launchCommandFactory,
localizer);
serviceWrapper = tensorFlowServiceSpec.create();
return serviceWrapper.getService();
}
@VisibleForTesting
public Map<String, String> getComponentToLocalLaunchScriptPath() {
return componentToLocalLaunchScriptPath;
public ServiceWrapper getServiceWrapper() {
return serviceWrapper;
}
}

View File

@ -17,33 +17,27 @@
import com.google.common.annotations.VisibleForTesting;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.client.api.AppAdminClient;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.common.Envs;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
import static org.apache.hadoop.yarn.client.api.AppAdminClient.DEFAULT_TYPE;
public class YarnServiceUtils {
private static final Logger LOG =
LoggerFactory.getLogger(YarnServiceUtils.class);
/**
* This class contains some static helper methods to query DNS data
* based on the provided parameters.
*/
public final class YarnServiceUtils {
private YarnServiceUtils() {
}
// This will be true only in UT.
private static AppAdminClient stubServiceClient = null;
public static AppAdminClient createServiceClient(
static AppAdminClient createServiceClient(
Configuration yarnConfiguration) {
if (stubServiceClient != null) {
return stubServiceClient;
}
AppAdminClient serviceClient = AppAdminClient.createAppAdminClient(
DEFAULT_TYPE, yarnConfiguration);
return serviceClient;
return AppAdminClient.createAppAdminClient(DEFAULT_TYPE, yarnConfiguration);
}
@VisibleForTesting
@ -57,77 +51,9 @@ public static String getDNSName(String serviceName,
domain, port);
}
private static String getDNSNameCommonSuffix(String serviceName,
public static String getDNSNameCommonSuffix(String serviceName,
String userName, String domain, int port) {
return "." + serviceName + "." + userName + "." + domain + ":" + port;
}
public static String getTFConfigEnv(String curCommponentName, int nWorkers,
int nPs, String serviceName, String userName, String domain) {
String commonEndpointSuffix = getDNSNameCommonSuffix(serviceName, userName,
domain, 8000);
String json = "{\\\"cluster\\\":{";
String master = getComponentArrayJson("master", 1, commonEndpointSuffix)
+ ",";
String worker = getComponentArrayJson("worker", nWorkers - 1,
commonEndpointSuffix) + ",";
String ps = getComponentArrayJson("ps", nPs, commonEndpointSuffix) + "},";
StringBuilder sb = new StringBuilder();
sb.append("\\\"task\\\":{");
sb.append(" \\\"type\\\":\\\"");
sb.append(curCommponentName);
sb.append("\\\",");
sb.append(" \\\"index\\\":");
sb.append('$');
sb.append(Envs.TASK_INDEX_ENV + "},");
String task = sb.toString();
String environment = "\\\"environment\\\":\\\"cloud\\\"}";
sb = new StringBuilder();
sb.append(json);
sb.append(master);
sb.append(worker);
sb.append(ps);
sb.append(task);
sb.append(environment);
return sb.toString();
}
public static void addQuicklink(Service serviceSpec, String label,
String link) {
Map<String, String> quicklinks = serviceSpec.getQuicklinks();
if (null == quicklinks) {
quicklinks = new HashMap<>();
serviceSpec.setQuicklinks(quicklinks);
}
if (SubmarineLogs.isVerbose()) {
LOG.info("Added quicklink, " + label + "=" + link);
}
quicklinks.put(label, link);
}
private static String getComponentArrayJson(String componentName, int count,
String endpointSuffix) {
String component = "\\\"" + componentName + "\\\":";
StringBuilder array = new StringBuilder();
array.append("[");
for (int i = 0; i < count; i++) {
array.append("\\\"");
array.append(componentName);
array.append("-");
array.append(i);
array.append(endpointSuffix);
array.append("\\\"");
if (i != count - 1) {
array.append(",");
}
}
array.append("]");
return component + array.toString();
}
}

View File

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import java.io.IOException;
import java.util.Objects;
/**
* Abstract base class for Launch command implementations for Services.
* Currently we have launch command implementations
* for TensorFlow PS, worker and Tensorboard instances.
*/
public abstract class AbstractLaunchCommand {
private final LaunchScriptBuilder builder;
public AbstractLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
TaskType taskType, Component component, RunJobParameters parameters)
throws IOException {
Objects.requireNonNull(taskType, "TaskType must not be null!");
this.builder = new LaunchScriptBuilder(taskType.name(), hadoopEnvSetup,
parameters, component);
}
protected LaunchScriptBuilder getBuilder() {
return builder;
}
/**
* Subclasses need to defined this method and return a valid launch script.
* Implementors can utilize the {@link LaunchScriptBuilder} using
* the getBuilder method of this class.
* @return The contents of a script.
* @throws IOException If any IO issue happens.
*/
public abstract String generateLaunchScript() throws IOException;
/**
* Subclasses need to provide a service-specific launch command
* of the service.
* Please note that this method should only return the launch command
* but not the whole script.
* @return The launch command
*/
public abstract String createLaunchCommand();
}

View File

@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
import java.io.IOException;
import java.util.Objects;
/**
* Simple factory to create instances of {@link AbstractLaunchCommand}
* based on the {@link TaskType}.
* All dependencies are passed to this factory that could be required
* by any implementor of {@link AbstractLaunchCommand}.
*/
public class LaunchCommandFactory {
private final HadoopEnvironmentSetup hadoopEnvSetup;
private final RunJobParameters parameters;
private final Configuration yarnConfig;
public LaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
RunJobParameters parameters, Configuration yarnConfig) {
this.hadoopEnvSetup = hadoopEnvSetup;
this.parameters = parameters;
this.yarnConfig = yarnConfig;
}
public AbstractLaunchCommand createLaunchCommand(TaskType taskType,
Component component) throws IOException {
Objects.requireNonNull(taskType, "TaskType must not be null!");
if (taskType == TaskType.WORKER || taskType == TaskType.PRIMARY_WORKER) {
return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, taskType,
component, parameters, yarnConfig);
} else if (taskType == TaskType.PS) {
return new TensorFlowPsLaunchCommand(hadoopEnvSetup, taskType, component,
parameters, yarnConfig);
} else if (taskType == TaskType.TENSORBOARD) {
return new TensorBoardLaunchCommand(hadoopEnvSetup, taskType, component,
parameters);
}
throw new IllegalStateException("Unknown task type: " + taskType);
}
}

View File

@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import static java.nio.charset.StandardCharsets.UTF_8;
/**
* This class is a builder to conveniently create launch scripts.
* All dependencies are provided with the constructor except
* the launch command.
*/
public class LaunchScriptBuilder {
private static final Logger LOG = LoggerFactory.getLogger(
LaunchScriptBuilder.class);
private final File file;
private final HadoopEnvironmentSetup hadoopEnvSetup;
private final RunJobParameters parameters;
private final Component component;
private final OutputStreamWriter writer;
private final StringBuilder scriptBuffer;
private String launchCommand;
LaunchScriptBuilder(String namePrefix,
HadoopEnvironmentSetup hadoopEnvSetup, RunJobParameters parameters,
Component component) throws IOException {
this.file = File.createTempFile(namePrefix + "-launch-script", ".sh");
this.hadoopEnvSetup = hadoopEnvSetup;
this.parameters = parameters;
this.component = component;
this.writer = new OutputStreamWriter(new FileOutputStream(file), UTF_8);
this.scriptBuffer = new StringBuilder();
}
public void append(String s) {
scriptBuffer.append(s);
}
public LaunchScriptBuilder withLaunchCommand(String command) {
this.launchCommand = command;
return this;
}
public String build() throws IOException {
if (launchCommand != null) {
append(launchCommand);
} else {
LOG.warn("LaunchScript object was null!");
if (LOG.isDebugEnabled()) {
LOG.debug("LaunchScript's Builder object: {}", this);
}
}
try (PrintWriter pw = new PrintWriter(writer)) {
writeBashHeader(pw);
hadoopEnvSetup.addHdfsClassPath(parameters, pw, component);
if (LOG.isDebugEnabled()) {
LOG.debug("Appending command to launch script: {}", scriptBuffer);
}
pw.append(scriptBuffer);
}
return file.getAbsolutePath();
}
@Override
public String toString() {
return "LaunchScriptBuilder{" +
"file=" + file +
", hadoopEnvSetup=" + hadoopEnvSetup +
", parameters=" + parameters +
", component=" + component +
", writer=" + writer +
", scriptBuffer=" + scriptBuffer +
", launchCommand='" + launchCommand + '\'' +
'}';
}
private void writeBashHeader(PrintWriter pw) {
pw.append("#!/bin/bash\n");
}
}

View File

@ -0,0 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes to produce launch commands and scripts.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;

View File

@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.common.Envs;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
import java.util.Map;
/**
* This class has common helper methods for TensorFlow.
*/
public final class TensorFlowCommons {
private TensorFlowCommons() {
throw new UnsupportedOperationException("This class should not be " +
"instantiated!");
}
public static void addCommonEnvironments(Component component,
TaskType taskType) {
Map<String, String> envs = component.getConfiguration().getEnv();
envs.put(Envs.TASK_INDEX_ENV, ServiceApiConstants.COMPONENT_ID);
envs.put(Envs.TASK_TYPE_ENV, taskType.name());
}
public static String getUserName() {
return System.getProperty("user.name");
}
public static String getDNSDomain(Configuration yarnConfig) {
return yarnConfig.get("hadoop.registry.dns.domain-name");
}
public static String getScriptFileName(TaskType taskType) {
return "run-" + taskType.name() + ".sh";
}
public static String getTFConfigEnv(String componentName, int nWorkers,
int nPs, String serviceName, String userName, String domain) {
String commonEndpointSuffix = YarnServiceUtils
.getDNSNameCommonSuffix(serviceName, userName, domain, 8000);
String json = "{\\\"cluster\\\":{";
String master = getComponentArrayJson("master", 1, commonEndpointSuffix)
+ ",";
String worker = getComponentArrayJson("worker", nWorkers - 1,
commonEndpointSuffix) + ",";
String ps = getComponentArrayJson("ps", nPs, commonEndpointSuffix) + "},";
StringBuilder sb = new StringBuilder();
sb.append("\\\"task\\\":{");
sb.append(" \\\"type\\\":\\\"");
sb.append(componentName);
sb.append("\\\",");
sb.append(" \\\"index\\\":");
sb.append('$');
sb.append(Envs.TASK_INDEX_ENV + "},");
String task = sb.toString();
String environment = "\\\"environment\\\":\\\"cloud\\\"}";
sb = new StringBuilder();
sb.append(json);
sb.append(master);
sb.append(worker);
sb.append(ps);
sb.append(task);
sb.append(environment);
return sb.toString();
}
private static String getComponentArrayJson(String componentName, int count,
String endpointSuffix) {
String component = "\\\"" + componentName + "\\\":";
StringBuilder array = new StringBuilder();
array.append("[");
for (int i = 0; i < count; i++) {
array.append("\\\"");
array.append(componentName);
array.append("-");
array.append(i);
array.append(endpointSuffix);
array.append("\\\"");
if (i != count - 1) {
array.append(",");
}
}
array.append("]");
return component + array.toString();
}
}

View File

@ -0,0 +1,203 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceSpec;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowPsComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;
import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL;
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;
/**
* This class contains all the logic to create an instance
* of a {@link Service} object for TensorFlow.
* Worker,PS and Tensorboard components are added to the Service
* based on the value of the received {@link RunJobParameters}.
*/
public class TensorFlowServiceSpec implements ServiceSpec {
private static final Logger LOG =
LoggerFactory.getLogger(TensorFlowServiceSpec.class);
private final RemoteDirectoryManager remoteDirectoryManager;
private final RunJobParameters parameters;
private final Configuration yarnConfig;
private final FileSystemOperations fsOperations;
private final LaunchCommandFactory launchCommandFactory;
private final Localizer localizer;
public TensorFlowServiceSpec(RunJobParameters parameters,
ClientContext clientContext, FileSystemOperations fsOperations,
LaunchCommandFactory launchCommandFactory, Localizer localizer) {
this.parameters = parameters;
this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
this.yarnConfig = clientContext.getYarnConfig();
this.fsOperations = fsOperations;
this.launchCommandFactory = launchCommandFactory;
this.localizer = localizer;
}
@Override
public ServiceWrapper create() throws IOException {
ServiceWrapper serviceWrapper = createServiceSpecWrapper();
if (parameters.getNumWorkers() > 0) {
addWorkerComponents(serviceWrapper);
}
if (parameters.getNumPS() > 0) {
addPsComponent(serviceWrapper);
}
if (parameters.isTensorboardEnabled()) {
createTensorBoardComponent(serviceWrapper);
}
// After all components added, handle quicklinks
handleQuicklinks(serviceWrapper.getService());
return serviceWrapper;
}
private ServiceWrapper createServiceSpecWrapper() throws IOException {
Service serviceSpec = new Service();
serviceSpec.setName(parameters.getName());
serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
.create(fsOperations, remoteDirectoryManager, parameters);
if (kerberosPrincipal != null) {
serviceSpec.setKerberosPrincipal(kerberosPrincipal);
}
handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
localizer.handleLocalizations(serviceSpec);
return new ServiceWrapper(serviceSpec);
}
private void createTensorBoardComponent(ServiceWrapper serviceWrapper)
throws IOException {
TensorBoardComponent tbComponent = new TensorBoardComponent(fsOperations,
remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
serviceWrapper.addComponent(tbComponent);
addQuicklink(serviceWrapper.getService(), TENSORBOARD_QUICKLINK_LABEL,
tbComponent.getTensorboardLink());
}
private static void addQuicklink(Service serviceSpec, String label,
String link) {
Map<String, String> quicklinks = serviceSpec.getQuicklinks();
if (quicklinks == null) {
quicklinks = new HashMap<>();
serviceSpec.setQuicklinks(quicklinks);
}
if (SubmarineLogs.isVerbose()) {
LOG.info("Added quicklink, " + label + "=" + link);
}
quicklinks.put(label, link);
}
private void handleQuicklinks(Service serviceSpec)
throws IOException {
List<Quicklink> quicklinks = parameters.getQuicklinks();
if (quicklinks != null && !quicklinks.isEmpty()) {
for (Quicklink ql : quicklinks) {
// Make sure it is a valid instance name
String instanceName = ql.getComponentInstanceName();
boolean found = false;
for (Component comp : serviceSpec.getComponents()) {
for (int i = 0; i < comp.getNumberOfContainers(); i++) {
String possibleInstanceName = comp.getName() + "-" + i;
if (possibleInstanceName.equals(instanceName)) {
found = true;
break;
}
}
}
if (!found) {
throw new IOException(
"Couldn't find a component instance = " + instanceName
+ " while adding quicklink");
}
String link = ql.getProtocol()
+ YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
getUserName(), getDNSDomain(yarnConfig), ql.getPort());
addQuicklink(serviceSpec, ql.getLabel(), link);
}
}
}
// Handle worker and primary_worker.
private void addWorkerComponents(ServiceWrapper serviceWrapper)
throws IOException {
addWorkerComponent(serviceWrapper, parameters, TaskType.PRIMARY_WORKER);
if (parameters.getNumWorkers() > 1) {
addWorkerComponent(serviceWrapper, parameters, TaskType.WORKER);
}
}
private void addWorkerComponent(ServiceWrapper serviceWrapper,
RunJobParameters parameters, TaskType taskType) throws IOException {
serviceWrapper.addComponent(
new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
parameters, taskType, launchCommandFactory, yarnConfig));
}
private void addPsComponent(ServiceWrapper serviceWrapper)
throws IOException {
serviceWrapper.addComponent(
new TensorFlowPsComponent(fsOperations, remoteDirectoryManager,
launchCommandFactory, parameters, yarnConfig));
}
}

View File

@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Objects;
/**
* Launch command implementation for Tensorboard.
*/
public class TensorBoardLaunchCommand extends AbstractLaunchCommand {
private static final Logger LOG =
LoggerFactory.getLogger(TensorBoardLaunchCommand.class);
private final String checkpointPath;
public TensorBoardLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
TaskType taskType, Component component, RunJobParameters parameters)
throws IOException {
super(hadoopEnvSetup, taskType, component, parameters);
Objects.requireNonNull(parameters.getCheckpointPath(),
"CheckpointPath must not be null as it is part "
+ "of the tensorboard command!");
if (StringUtils.isEmpty(parameters.getCheckpointPath())) {
throw new IllegalArgumentException("CheckpointPath must not be empty!");
}
this.checkpointPath = parameters.getCheckpointPath();
}
@Override
public String generateLaunchScript() throws IOException {
return getBuilder()
.withLaunchCommand(createLaunchCommand())
.build();
}
@Override
public String createLaunchCommand() {
String tbCommand = String.format("export LC_ALL=C && tensorboard " +
"--logdir=%s%n", checkpointPath);
LOG.info("Tensorboard command=" + tbCommand);
return tbCommand;
}
}

View File

@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
/**
* Launch command implementation for
* TensorFlow PS and Worker Service components.
*/
public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
private static final Logger LOG =
LoggerFactory.getLogger(TensorFlowLaunchCommand.class);
private final Configuration yarnConfig;
private final boolean distributed;
private final int numberOfWorkers;
private final int numberOfPS;
private final String name;
private final TaskType taskType;
TensorFlowLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
TaskType taskType, Component component, RunJobParameters parameters,
Configuration yarnConfig) throws IOException {
super(hadoopEnvSetup, taskType, component, parameters);
this.taskType = taskType;
this.name = parameters.getName();
this.distributed = parameters.isDistributed();
this.numberOfWorkers = parameters.getNumWorkers();
this.numberOfPS = parameters.getNumPS();
this.yarnConfig = yarnConfig;
logReceivedParameters();
}
private void logReceivedParameters() {
if (this.numberOfWorkers <= 0) {
LOG.warn("Received number of workers: {}", this.numberOfWorkers);
}
if (this.numberOfPS <= 0) {
LOG.warn("Received number of PS: {}", this.numberOfPS);
}
}
@Override
public String generateLaunchScript() throws IOException {
LaunchScriptBuilder builder = getBuilder();
// When distributed training is required
if (distributed) {
String tfConfigEnvValue = TensorFlowCommons.getTFConfigEnv(
taskType.getComponentName(), numberOfWorkers,
numberOfPS, name,
TensorFlowCommons.getUserName(),
TensorFlowCommons.getDNSDomain(yarnConfig));
String tfConfig = "export TF_CONFIG=\"" + tfConfigEnvValue + "\"\n";
builder.append(tfConfig);
}
return builder
.withLaunchCommand(createLaunchCommand())
.build();
}
}

View File

@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
/**
* Launch command implementation for Tensorboard's PS component.
*/
public class TensorFlowPsLaunchCommand extends TensorFlowLaunchCommand {
private static final Logger LOG =
LoggerFactory.getLogger(TensorFlowPsLaunchCommand.class);
private final String launchCommand;
public TensorFlowPsLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
TaskType taskType, Component component, RunJobParameters parameters,
Configuration yarnConfig) throws IOException {
super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
this.launchCommand = parameters.getPSLaunchCmd();
if (StringUtils.isEmpty(this.launchCommand)) {
throw new IllegalArgumentException("LaunchCommand must not be null " +
"or empty!");
}
}
@Override
public String createLaunchCommand() {
if (SubmarineLogs.isVerbose()) {
LOG.info("PS command =[" + launchCommand + "]");
}
return launchCommand + '\n';
}
}

View File

@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
/**
* Launch command implementation for Tensorboard's Worker component.
*/
public class TensorFlowWorkerLaunchCommand extends TensorFlowLaunchCommand {
private static final Logger LOG =
LoggerFactory.getLogger(TensorFlowWorkerLaunchCommand.class);
private final String launchCommand;
public TensorFlowWorkerLaunchCommand(
HadoopEnvironmentSetup hadoopEnvSetup, TaskType taskType,
Component component, RunJobParameters parameters,
Configuration yarnConfig) throws IOException {
super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
this.launchCommand = parameters.getWorkerLaunchCmd();
if (StringUtils.isEmpty(this.launchCommand)) {
throw new IllegalArgumentException("LaunchCommand must not be null " +
"or empty!");
}
}
@Override
public String createLaunchCommand() {
if (SubmarineLogs.isVerbose()) {
LOG.info("Worker command =[" + launchCommand + "]");
}
return launchCommand + '\n';
}
}

View File

@ -0,0 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes to generate TensorFlow launch commands.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;

View File

@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Objects;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource;
/**
* Component implementation for Tensorboard's Tensorboard.
*/
public class TensorBoardComponent extends AbstractComponent {
private static final Logger LOG =
LoggerFactory.getLogger(TensorBoardComponent.class);
public static final String TENSORBOARD_QUICKLINK_LABEL = "Tensorboard";
private static final int DEFAULT_PORT = 6006;
//computed fields
private String tensorboardLink;
public TensorBoardComponent(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager,
RunJobParameters parameters,
LaunchCommandFactory launchCommandFactory,
Configuration yarnConfig) {
super(fsOperations, remoteDirectoryManager, parameters,
TaskType.TENSORBOARD, yarnConfig, launchCommandFactory);
}
@Override
public Component createComponent() throws IOException {
Objects.requireNonNull(parameters.getTensorboardResource(),
"TensorBoard resource must not be null!");
Component component = new Component();
component.setName(taskType.getComponentName());
component.setNumberOfContainers(1L);
component.setRestartPolicy(RestartPolicyEnum.NEVER);
component.setResource(convertYarnResourceToServiceResource(
parameters.getTensorboardResource()));
if (parameters.getTensorboardDockerImage() != null) {
component.setArtifact(
getDockerArtifact(parameters.getTensorboardDockerImage()));
}
addCommonEnvironments(component, taskType);
generateLaunchCommand(component);
tensorboardLink = "http://" + YarnServiceUtils.getDNSName(
parameters.getName(),
taskType.getComponentName() + "-" + 0, getUserName(),
getDNSDomain(yarnConfig), DEFAULT_PORT);
LOG.info("Link to tensorboard:" + tensorboardLink);
return component;
}
public String getTensorboardLink() {
return tensorboardLink;
}
}

View File

@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import java.io.IOException;
import java.util.Objects;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments;
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource;
/**
* Component implementation for TensorFlow's PS process.
*/
public class TensorFlowPsComponent extends AbstractComponent {
public TensorFlowPsComponent(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager,
LaunchCommandFactory launchCommandFactory,
RunJobParameters parameters,
Configuration yarnConfig) {
super(fsOperations, remoteDirectoryManager, parameters, TaskType.PS,
yarnConfig, launchCommandFactory);
}
@Override
public Component createComponent() throws IOException {
Objects.requireNonNull(parameters.getPsResource(),
"PS resource must not be null!");
if (parameters.getNumPS() < 1) {
throw new IllegalArgumentException("Number of PS should be at least 1!");
}
Component component = new Component();
component.setName(taskType.getComponentName());
component.setNumberOfContainers((long) parameters.getNumPS());
component.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
component.setResource(
convertYarnResourceToServiceResource(parameters.getPsResource()));
// Override global docker image if needed.
if (parameters.getPsDockerImage() != null) {
component.setArtifact(
getDockerArtifact(parameters.getPsDockerImage()));
}
addCommonEnvironments(component, taskType);
generateLaunchCommand(component);
return component;
}
}

View File

@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import java.io.IOException;
import java.util.Objects;
import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants.CONTAINER_STATE_REPORT_AS_SERVICE_STATE;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments;
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource;
/**
* Component implementation for TensorFlow's Worker process.
*/
public class TensorFlowWorkerComponent extends AbstractComponent {
public TensorFlowWorkerComponent(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager,
RunJobParameters parameters, TaskType taskType,
LaunchCommandFactory launchCommandFactory,
Configuration yarnConfig) {
super(fsOperations, remoteDirectoryManager, parameters, taskType,
yarnConfig, launchCommandFactory);
}
@Override
public Component createComponent() throws IOException {
Objects.requireNonNull(parameters.getWorkerResource(),
"Worker resource must not be null!");
if (parameters.getNumWorkers() < 1) {
throw new IllegalArgumentException(
"Number of workers should be at least 1!");
}
Component component = new Component();
component.setName(taskType.getComponentName());
if (taskType.equals(TaskType.PRIMARY_WORKER)) {
component.setNumberOfContainers(1L);
component.getConfiguration().setProperty(
CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true");
} else {
component.setNumberOfContainers(
(long) parameters.getNumWorkers() - 1);
}
if (parameters.getWorkerDockerImage() != null) {
component.setArtifact(
getDockerArtifact(parameters.getWorkerDockerImage()));
}
component.setResource(convertYarnResourceToServiceResource(
parameters.getWorkerResource()));
component.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
addCommonEnvironments(component, taskType);
generateLaunchCommand(component);
return component;
}
}

View File

@ -0,0 +1,20 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes to generate
* TensorFlow Native Service components.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;

View File

@ -0,0 +1,20 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes to generate
* TensorFlow-related Native Service runtime artifacts.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow;

View File

@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.utils;
import java.io.File;
import java.util.StringTokenizer;
/**
* Utilities for classpath operations.
*/
public final class ClassPathUtilities {
private ClassPathUtilities() {
throw new UnsupportedOperationException("This class should not be " +
"instantiated!");
}
public static File findFileOnClassPath(final String fileName) {
final String classpath = System.getProperty("java.class.path");
final String pathSeparator = System.getProperty("path.separator");
final StringTokenizer tokenizer = new StringTokenizer(classpath,
pathSeparator);
while (tokenizer.hasMoreTokens()) {
final String pathElement = tokenizer.nextToken();
final File directoryOrJar = new File(pathElement);
final File absoluteDirectoryOrJar = directoryOrJar.getAbsoluteFile();
if (absoluteDirectoryOrJar.isFile()) {
final File target =
new File(absoluteDirectoryOrJar.getParent(), fileName);
if (target.exists()) {
return target;
}
} else {
final File target = new File(directoryOrJar, fileName);
if (target.exists()) {
return target;
}
}
}
return null;
}
}

View File

@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.utils;
import org.apache.hadoop.yarn.service.api.records.Artifact;
/**
* Utilities for Docker-related operations.
*/
public final class DockerUtilities {
private DockerUtilities() {
throw new UnsupportedOperationException("This class should not be " +
"instantiated!");
}
public static Artifact getDockerArtifact(String dockerImageName) {
return new Artifact().type(Artifact.TypeEnum.DOCKER).id(dockerImageName);
}
}

View File

@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.utils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.Map;
import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION;
/**
* Utilities for environment variable related operations
* for {@link Service} objects.
*/
public final class EnvironmentUtilities {
private EnvironmentUtilities() {
throw new UnsupportedOperationException("This class should not be " +
"instantiated!");
}
private static final Logger LOG =
LoggerFactory.getLogger(EnvironmentUtilities.class);
static final String ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME =
"YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS";
private static final String MOUNTS_DELIM = ",";
private static final String ENV_SEPARATOR = "=";
private static final String ETC_PASSWD_MOUNT_STRING =
"/etc/passwd:/etc/passwd:ro";
private static final String KERBEROS_CONF_MOUNT_STRING =
"/etc/krb5.conf:/etc/krb5.conf:ro";
private static final String ENV_VAR_DELIM = ":";
/**
* Extracts value from a string representation of an environment variable.
* @param envVar The environment variable in 'key=value' format.
* @return The value of the environment variable
*/
public static String getValueOfEnvironment(String envVar) {
if (envVar == null || !envVar.contains(ENV_SEPARATOR)) {
return "";
} else {
return envVar.substring(envVar.indexOf(ENV_SEPARATOR) + 1);
}
}
public static void handleServiceEnvs(Service service,
Configuration yarnConfig, List<String> envVars) {
if (envVars != null) {
for (String envVarPair : envVars) {
String key, value;
if (envVarPair.contains(ENV_SEPARATOR)) {
int idx = envVarPair.indexOf(ENV_SEPARATOR);
key = envVarPair.substring(0, idx);
value = envVarPair.substring(idx + 1);
} else {
LOG.warn("Found environment variable with unusual format: '{}'",
envVarPair);
// No "=" found so use the whole key
key = envVarPair;
value = "";
}
appendToEnv(service, key, value, ENV_VAR_DELIM);
}
}
appendOtherConfigs(service, yarnConfig);
}
/**
* Appends other configs like /etc/passwd, /etc/krb5.conf.
* @param service
* @param yarnConfig
*/
private static void appendOtherConfigs(Service service,
Configuration yarnConfig) {
appendToEnv(service, ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME,
ETC_PASSWD_MOUNT_STRING, MOUNTS_DELIM);
String authentication = yarnConfig.get(HADOOP_SECURITY_AUTHENTICATION);
if (authentication != null && authentication.equals("kerberos")) {
appendToEnv(service, ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME,
KERBEROS_CONF_MOUNT_STRING, MOUNTS_DELIM);
}
}
static void appendToEnv(Service service, String key, String value,
String delim) {
Map<String, String> env = service.getConfiguration().getEnv();
if (!env.containsKey(key)) {
env.put(key, value);
} else {
if (!value.isEmpty()) {
String existingValue = env.get(key);
if (!existingValue.endsWith(delim)) {
env.put(key, existingValue + delim + value);
} else {
env.put(key, existingValue + value);
}
}
}
}
}

View File

@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.utils;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Objects;
/**
* Simple factory that creates a {@link KerberosPrincipal}.
*/
public final class KerberosPrincipalFactory {
private KerberosPrincipalFactory() {
throw new UnsupportedOperationException("This class should not be " +
"instantiated!");
}
private static final Logger LOG =
LoggerFactory.getLogger(KerberosPrincipalFactory.class);
public static KerberosPrincipal create(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager,
RunJobParameters parameters) throws IOException {
Objects.requireNonNull(fsOperations,
"FileSystemOperations must not be null!");
Objects.requireNonNull(remoteDirectoryManager,
"RemoteDirectoryManager must not be null!");
Objects.requireNonNull(parameters, "Parameters must not be null!");
if (StringUtils.isNotBlank(parameters.getKeytab()) && StringUtils
.isNotBlank(parameters.getPrincipal())) {
String keytab = parameters.getKeytab();
String principal = parameters.getPrincipal();
if (parameters.isDistributeKeytab()) {
return handleDistributedKeytab(fsOperations, remoteDirectoryManager,
parameters, keytab, principal);
} else {
return handleNormalKeytab(keytab, principal);
}
}
LOG.debug("Principal and keytab was null or empty, " +
"returning null KerberosPrincipal!");
return null;
}
private static KerberosPrincipal handleDistributedKeytab(
FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager,
RunJobParameters parameters, String keytab, String principal)
throws IOException {
Path stagingDir = remoteDirectoryManager
.getJobStagingArea(parameters.getName(), true);
Path remoteKeytabPath =
fsOperations.uploadToRemoteFile(stagingDir, keytab);
// Only the owner has read access
fsOperations.setPermission(remoteKeytabPath,
FsPermission.createImmutable((short)Integer.parseInt("400", 8)));
return new KerberosPrincipal()
.keytab(remoteKeytabPath.toString())
.principalName(principal);
}
private static KerberosPrincipal handleNormalKeytab(String keytab,
String principal) {
if(!keytab.startsWith("file")) {
keytab = "file://" + keytab;
}
return new KerberosPrincipal()
.keytab(keytab)
.principalName(principal);
}
}

View File

@ -0,0 +1,170 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.utils;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.service.api.records.ConfigFile;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.client.cli.param.Localization;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.List;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations.needHdfs;
import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.appendToEnv;
/**
* This class holds all dependencies in order to localize dependencies
* for containers.
*/
public class Localizer {
private static final Logger LOG = LoggerFactory.getLogger(Localizer.class);
private final FileSystemOperations fsOperations;
private final RemoteDirectoryManager remoteDirectoryManager;
private final RunJobParameters parameters;
public Localizer(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager,
RunJobParameters parameters) {
this.fsOperations = fsOperations;
this.remoteDirectoryManager = remoteDirectoryManager;
this.parameters = parameters;
}
/**
* Localize dependencies for all containers.
* If remoteUri is a local directory,
* we'll zip it, upload to HDFS staging dir HDFS.
* If remoteUri is directory, we'll download it, zip it and upload
* to HDFS.
* If localFilePath is ".", we'll use remoteUri's file/dir name
* */
public void handleLocalizations(Service service)
throws IOException {
// Handle localizations
Path stagingDir =
remoteDirectoryManager.getJobStagingArea(
parameters.getName(), true);
List<Localization> localizations = parameters.getLocalizations();
String remoteUri;
String containerLocalPath;
// Check to fail fast
for (Localization loc : localizations) {
remoteUri = loc.getRemoteUri();
Path resourceToLocalize = new Path(remoteUri);
// Check if remoteUri exists
if (remoteDirectoryManager.isRemote(remoteUri)) {
// check if exists
if (!remoteDirectoryManager.existsRemoteFile(resourceToLocalize)) {
throw new FileNotFoundException(
"File " + remoteUri + " doesn't exists.");
}
} else {
// Check if exists
File localFile = new File(remoteUri);
if (!localFile.exists()) {
throw new FileNotFoundException(
"File " + remoteUri + " doesn't exists.");
}
}
// check remote file size
fsOperations.validFileSize(remoteUri);
}
// Start download remote if needed and upload to HDFS
for (Localization loc : localizations) {
remoteUri = loc.getRemoteUri();
containerLocalPath = loc.getLocalPath();
String srcFileStr = remoteUri;
ConfigFile.TypeEnum destFileType = ConfigFile.TypeEnum.STATIC;
Path resourceToLocalize = new Path(remoteUri);
boolean needUploadToHDFS = true;
// Special handling of remoteUri directory
boolean needDeleteTempFile = false;
if (remoteDirectoryManager.isDir(remoteUri)) {
destFileType = ConfigFile.TypeEnum.ARCHIVE;
srcFileStr = fsOperations.downloadAndZip(
remoteUri, getLastNameFromPath(srcFileStr), true);
} else if (remoteDirectoryManager.isRemote(remoteUri)) {
if (!needHdfs(remoteUri)) {
// Non HDFS remote uri. Non directory, no need to zip
srcFileStr = fsOperations.downloadAndZip(
remoteUri, getLastNameFromPath(srcFileStr), false);
needDeleteTempFile = true;
} else {
// HDFS file, no need to upload
needUploadToHDFS = false;
}
}
// Upload file to HDFS
if (needUploadToHDFS) {
resourceToLocalize =
fsOperations.uploadToRemoteFile(stagingDir, srcFileStr);
}
if (needDeleteTempFile) {
fsOperations.deleteFiles(srcFileStr);
}
// Remove .zip from zipped dir name
if (destFileType == ConfigFile.TypeEnum.ARCHIVE
&& srcFileStr.endsWith(".zip")) {
// Delete local zip file
fsOperations.deleteFiles(srcFileStr);
int suffixIndex = srcFileStr.lastIndexOf('_');
srcFileStr = srcFileStr.substring(0, suffixIndex);
}
// If provided, use the name of local uri
if (!containerLocalPath.equals(".")
&& !containerLocalPath.equals("./")) {
// Change the YARN localized file name to what'll used in container
srcFileStr = getLastNameFromPath(containerLocalPath);
}
String localizedName = getLastNameFromPath(srcFileStr);
LOG.info("The file/dir to be localized is {}",
resourceToLocalize.toString());
LOG.info("Its localized file name will be {}", localizedName);
service.getConfiguration().getFiles().add(new ConfigFile().srcFile(
resourceToLocalize.toUri().toString()).destFile(localizedName)
.type(destFileType));
// set mounts
// if mount path is absolute, just use it.
// if relative, no need to mount explicitly
if (containerLocalPath.startsWith("/")) {
String mountStr = getLastNameFromPath(srcFileStr) + ":"
+ containerLocalPath + ":" + loc.getMountPermission();
LOG.info("Add bind-mount string {}", mountStr);
appendToEnv(service,
EnvironmentUtilities.ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME,
mountStr, ",");
}
}
}
private String getLastNameFromPath(String srcFileStr) {
return new Path(srcFileStr).getName();
}
}

View File

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.utils;
import org.apache.hadoop.yarn.service.api.records.Resource;
import org.apache.hadoop.yarn.service.api.records.ResourceInformation;
import java.util.HashMap;
import java.util.Map;
/**
* Resource utilities for Submarine.
*/
public final class SubmarineResourceUtils {
private SubmarineResourceUtils() {
throw new UnsupportedOperationException("This class should not be " +
"instantiated!");
}
public static Resource convertYarnResourceToServiceResource(
org.apache.hadoop.yarn.api.records.Resource yarnResource) {
Resource serviceResource = new Resource();
serviceResource.setCpus(yarnResource.getVirtualCores());
serviceResource.setMemory(String.valueOf(yarnResource.getMemorySize()));
Map<String, ResourceInformation> riMap = new HashMap<>();
for (org.apache.hadoop.yarn.api.records.ResourceInformation ri :
yarnResource.getAllResourcesListCopy()) {
ResourceInformation serviceRi = new ResourceInformation();
serviceRi.setValue(ri.getValue());
serviceRi.setUnit(ri.getUnits());
riMap.put(ri.getName(), serviceRi);
}
serviceResource.setResourceInformations(riMap);
return serviceResource;
}
}

View File

@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.utils;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
/**
* Utilities for zipping directories and adding existing directories to zips.
*/
public final class ZipUtilities {
private ZipUtilities() {
throw new UnsupportedOperationException("This class should not be " +
"instantiated!");
}
private static final Logger LOG = LoggerFactory.getLogger(ZipUtilities.class);
@VisibleForTesting
public static String zipDir(String srcDir, String dstFile)
throws IOException {
FileOutputStream fos = new FileOutputStream(dstFile);
ZipOutputStream zos = new ZipOutputStream(fos);
File srcFile = new File(srcDir);
LOG.info("Compressing directory {}", srcDir);
addDirToZip(zos, srcFile, srcFile);
// close the ZipOutputStream
zos.close();
LOG.info("Compressed directory {} to file: {}", srcDir, dstFile);
return dstFile;
}
private static void addDirToZip(ZipOutputStream zos, File srcFile, File base)
throws IOException {
File[] files = srcFile.listFiles();
if (files == null) {
return;
}
for (File file : files) {
// if it's directory, add recursively
if (file.isDirectory()) {
addDirToZip(zos, file, base);
continue;
}
byte[] buffer = new byte[1024];
try(FileInputStream fis = new FileInputStream(file)) {
String name = base.toURI().relativize(file.toURI()).getPath();
LOG.info("Adding file {} to zip", name);
zos.putNextEntry(new ZipEntry(name));
int length;
while ((length = fis.read(buffer)) > 0) {
zos.write(buffer, 0, length);
}
zos.flush();
} finally {
zos.closeEntry();
}
}
}
}

View File

@ -0,0 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes utility classes.
*/
package org.apache.hadoop.yarn.submarine.utils;

View File

@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine;
import com.google.common.collect.Lists;
import org.apache.commons.io.FileUtils;
import org.apache.hadoop.fs.Path;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.List;
import static org.junit.Assert.assertTrue;
/**
* File utilities for tests.
* Provides methods that can create, delete files or directories
* in a temp directory, or any specified directory.
*/
public class FileUtilitiesForTests {
private static final Logger LOG =
LoggerFactory.getLogger(FileUtilitiesForTests.class);
private String tempDir;
private List<File> cleanupFiles;
public void setup() {
cleanupFiles = Lists.newArrayList();
tempDir = System.getProperty("java.io.tmpdir");
}
public void teardown() throws IOException {
LOG.info("About to clean up files: " + cleanupFiles);
List<File> dirs = Lists.newArrayList();
for (File cleanupFile : cleanupFiles) {
if (cleanupFile.isDirectory()) {
dirs.add(cleanupFile);
} else {
deleteFile(cleanupFile);
}
}
for (File dir : dirs) {
deleteFile(dir);
}
}
public File createFileInTempDir(String filename) throws IOException {
File file = new File(tempDir, new Path(filename).getName());
createFile(file);
return file;
}
public File createDirInTempDir(String dirName) {
File file = new File(tempDir, new Path(dirName).getName());
createDirectory(file);
return file;
}
public File createFileInDir(Path dir, String filename) throws IOException {
File dirTmp = new File(dir.toUri().getPath());
if (!dirTmp.exists()) {
createDirectory(dirTmp);
}
File file =
new File(dir.toUri().getPath() + "/" + new Path(filename).getName());
createFile(file);
return file;
}
public File createFileInDir(File dir, String filename) throws IOException {
if (!dir.exists()) {
createDirectory(dir);
}
File file = new File(dir, filename);
createFile(file);
return file;
}
public File createDirectory(Path parent, String dirname) {
File dir =
new File(parent.toUri().getPath() + "/" + new Path(dirname).getName());
createDirectory(dir);
return dir;
}
public File createDirectory(File parent, String dirname) {
File dir =
new File(parent.getPath() + "/" + new Path(dirname).getName());
createDirectory(dir);
return dir;
}
private void createDirectory(File dir) {
boolean result = dir.mkdir();
assertTrue("Failed to create directory " + dir.getAbsolutePath(), result);
assertTrue("Directory does not exist: " + dir.getAbsolutePath(),
dir.exists());
this.cleanupFiles.add(dir);
}
private void createFile(File file) throws IOException {
boolean result = file.createNewFile();
assertTrue("Failed to create file " + file.getAbsolutePath(), result);
assertTrue("File does not exist: " + file.getAbsolutePath(), file.exists());
this.cleanupFiles.add(file);
}
private static void deleteFile(File file) throws IOException {
if (file.isDirectory()) {
LOG.info("Removing directory: " + file.getAbsolutePath());
FileUtils.deleteDirectory(file);
}
if (file.exists()) {
LOG.info("Removing file: " + file.getAbsolutePath());
boolean result = file.delete();
assertTrue("Deletion of file " + file.getAbsolutePath()
+ " was not successful!", result);
}
}
public File getTempFileWithName(String filename) {
return new File(tempDir + "/" + new Path(filename).getName());
}
public static File getFilename(Path parent, String filename) {
return new File(
parent.toUri().getPath() + "/" + new Path(filename).getName());
}
}

View File

@ -0,0 +1,139 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.yarnservice;
import com.google.common.collect.Lists;
import java.util.List;
class ParamBuilderForTest {
private final List<String> params = Lists.newArrayList();
static ParamBuilderForTest create() {
return new ParamBuilderForTest();
}
ParamBuilderForTest withJobName(String name) {
params.add("--name");
params.add(name);
return this;
}
ParamBuilderForTest withDockerImage(String dockerImage) {
params.add("--docker_image");
params.add(dockerImage);
return this;
}
ParamBuilderForTest withInputPath(String inputPath) {
params.add("--input_path");
params.add(inputPath);
return this;
}
ParamBuilderForTest withCheckpointPath(String checkpointPath) {
params.add("--checkpoint_path");
params.add(checkpointPath);
return this;
}
ParamBuilderForTest withNumberOfWorkers(int numWorkers) {
params.add("--num_workers");
params.add(String.valueOf(numWorkers));
return this;
}
ParamBuilderForTest withNumberOfPs(int numPs) {
params.add("--num_ps");
params.add(String.valueOf(numPs));
return this;
}
ParamBuilderForTest withWorkerLaunchCommand(String launchCommand) {
params.add("--worker_launch_cmd");
params.add(launchCommand);
return this;
}
ParamBuilderForTest withPsLaunchCommand(String launchCommand) {
params.add("--ps_launch_cmd");
params.add(launchCommand);
return this;
}
ParamBuilderForTest withWorkerResources(String workerResources) {
params.add("--worker_resources");
params.add(workerResources);
return this;
}
ParamBuilderForTest withPsResources(String psResources) {
params.add("--ps_resources");
params.add(psResources);
return this;
}
ParamBuilderForTest withWorkerDockerImage(String dockerImage) {
params.add("--worker_docker_image");
params.add(dockerImage);
return this;
}
ParamBuilderForTest withPsDockerImage(String dockerImage) {
params.add("--ps_docker_image");
params.add(dockerImage);
return this;
}
ParamBuilderForTest withVerbose() {
params.add("--verbose");
return this;
}
ParamBuilderForTest withTensorboard() {
params.add("--tensorboard");
return this;
}
ParamBuilderForTest withTensorboardResources(String resources) {
params.add("--tensorboard_resources");
params.add(resources);
return this;
}
ParamBuilderForTest withTensorboardDockerImage(String dockerImage) {
params.add("--tensorboard_docker_image");
params.add(dockerImage);
return this;
}
ParamBuilderForTest withQuickLink(String quickLink) {
params.add("--quicklink");
params.add(quickLink);
return this;
}
ParamBuilderForTest withLocalization(String remoteUrl, String localUrl) {
params.add("--localization");
params.add(remoteUrl + ":" + localUrl);
return this;
}
String[] build() {
return params.toArray(new String[0]);
}
}

View File

@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.yarnservice;
import org.apache.hadoop.yarn.client.api.AppAdminClient;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.FileUtilitiesForTests;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceJobSubmitter;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
import java.io.IOException;
import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Common operations shared with test classes using Run job-related actions.
*/
public class TestYarnServiceRunJobCliCommons {
static final String DEFAULT_JOB_NAME = "my-job";
static final String DEFAULT_DOCKER_IMAGE = "tf-docker:1.1.0";
static final String DEFAULT_INPUT_PATH = "s3://input";
static final String DEFAULT_CHECKPOINT_PATH = "s3://output";
static final String DEFAULT_WORKER_DOCKER_IMAGE = "worker.image";
static final String DEFAULT_PS_DOCKER_IMAGE = "ps.image";
static final String DEFAULT_WORKER_LAUNCH_CMD = "python run-job.py";
static final String DEFAULT_PS_LAUNCH_CMD = "python run-ps.py";
static final String DEFAULT_TENSORBOARD_RESOURCES = "memory=2G,vcores=2";
static final String DEFAULT_WORKER_RESOURCES = "memory=2048M,vcores=2";
static final String DEFAULT_PS_RESOURCES = "memory=4096M,vcores=4";
static final String DEFAULT_TENSORBOARD_DOCKER_IMAGE = "tb_docker_image:001";
private FileUtilitiesForTests fileUtils = new FileUtilitiesForTests();
void setup() throws IOException, YarnException {
SubmarineLogs.verboseOff();
AppAdminClient serviceClient = mock(AppAdminClient.class);
when(serviceClient.actionLaunch(any(String.class), any(String.class),
any(Long.class), any(String.class))).thenReturn(EXIT_SUCCESS);
when(serviceClient.getStatusString(any(String.class))).thenReturn(
"{\"id\": \"application_1234_1\"}");
YarnServiceUtils.setStubServiceClient(serviceClient);
fileUtils.setup();
}
void teardown() throws IOException {
fileUtils.teardown();
}
FileUtilitiesForTests getFileUtils() {
return fileUtils;
}
Service getServiceSpecFromJobSubmitter(JobSubmitter jobSubmitter) {
return ((YarnServiceJobSubmitter) jobSubmitter).getServiceWrapper()
.getService();
}
}

View File

@ -0,0 +1,599 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.yarnservice;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.service.api.records.ConfigFile;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.client.cli.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineConfiguration;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.*;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.io.IOException;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* Class to test YarnService localization feature with the Run job CLI action.
*/
public class TestYarnServiceRunJobCliLocalization {
private static final String ZIP_EXTENSION = ".zip";
private TestYarnServiceRunJobCliCommons testCommons =
new TestYarnServiceRunJobCliCommons();
private MockClientContext mockClientContext;
private RemoteDirectoryManager spyRdm;
@Before
public void before() throws IOException, YarnException {
testCommons.setup();
mockClientContext = YarnServiceCliTestUtils.getMockClientContext();
spyRdm = setupSpyRemoteDirManager();
}
@After
public void cleanup() throws IOException {
testCommons.teardown();
}
private ParamBuilderForTest createCommonParamsBuilder() {
return ParamBuilderForTest.create()
.withJobName(DEFAULT_JOB_NAME)
.withDockerImage(DEFAULT_DOCKER_IMAGE)
.withInputPath(DEFAULT_INPUT_PATH)
.withCheckpointPath(DEFAULT_CHECKPOINT_PATH)
.withNumberOfWorkers(3)
.withWorkerDockerImage(DEFAULT_WORKER_DOCKER_IMAGE)
.withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD)
.withWorkerResources(DEFAULT_WORKER_RESOURCES)
.withNumberOfPs(2)
.withPsDockerImage(DEFAULT_PS_DOCKER_IMAGE)
.withPsLaunchCommand(DEFAULT_PS_LAUNCH_CMD)
.withPsResources(DEFAULT_PS_RESOURCES)
.withVerbose();
}
private void assertFilesAreDeleted(File... files) {
for (File file : files) {
assertFalse("File should be deleted: " + file.getAbsolutePath(),
file.exists());
}
}
private RemoteDirectoryManager setupSpyRemoteDirManager() {
RemoteDirectoryManager spyRdm =
spy(mockClientContext.getRemoteDirectoryManager());
mockClientContext.setRemoteDirectoryMgr(spyRdm);
return spyRdm;
}
private Path getStagingDir() throws IOException {
return mockClientContext.getRemoteDirectoryManager()
.getJobStagingArea(DEFAULT_JOB_NAME, true);
}
private RunJobCli createRunJobCliWithoutVerboseAssertion() {
return new RunJobCli(mockClientContext);
}
private RunJobCli createRunJobCli() {
RunJobCli runJobCli = new RunJobCli(mockClientContext);
assertFalse(SubmarineLogs.isVerbose());
return runJobCli;
}
private String getFilePath(String localUrl, Path stagingDir) {
return stagingDir.toUri().getPath()
+ "/" + new Path(localUrl).getName();
}
private String getFilePathWithSuffix(Path stagingDir, String localUrl,
String suffix) {
return stagingDir.toUri().getPath() + "/" + new Path(localUrl).getName()
+ suffix;
}
private void assertConfigFile(ConfigFile expected, ConfigFile actual) {
assertEquals("ConfigFile does not equal to expected!", expected, actual);
}
private void assertNumberOfLocalizations(List<ConfigFile> files,
int expected) {
assertEquals("Number of localizations is not the expected!", expected,
files.size());
}
private void verifyRdmCopyToRemoteLocalCalls(int expectedCalls)
throws IOException {
verify(spyRdm, times(expectedCalls)).copyRemoteToLocal(anyString(),
anyString());
}
/**
* Basic test.
* In one hand, create local temp file/dir for hdfs URI in
* local staging dir.
* In the other hand, use MockRemoteDirectoryManager mock
* implementation when check FileStatus or exists of HDFS file/dir
* --localization hdfs:///user/yarn/script1.py:.
* --localization /temp/script2.py:./
* --localization /temp/script2.py:/opt/script.py
*/
@Test
public void testRunJobWithBasicLocalization() throws Exception {
String remoteUrl = "hdfs:///user/yarn/script1.py";
String containerLocal1 = ".";
String localUrl = "/temp/script2.py";
String containerLocal2 = "./";
String containerLocal3 = "/opt/script.py";
// Create local file, we need to put it under local temp dir
File localFile1 = testCommons.getFileUtils().createFileInTempDir(localUrl);
// create remote file in local staging dir to simulate HDFS
Path stagingDir = getStagingDir();
testCommons.getFileUtils().createFileInDir(stagingDir, remoteUrl);
String[] params = createCommonParamsBuilder()
.withLocalization(remoteUrl, containerLocal1)
.withLocalization(localFile1.getAbsolutePath(), containerLocal2)
.withLocalization(localFile1.getAbsolutePath(), containerLocal3)
.build();
RunJobCli runJobCli = createRunJobCli();
runJobCli.run(params);
Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter(
runJobCli.getJobSubmitter());
assertNumberOfServiceComponents(serviceSpec, 3);
// No remote dir and HDFS file exists.
// Ensure download never happened.
verifyRdmCopyToRemoteLocalCalls(0);
// Ensure local original files are not deleted
assertTrue(localFile1.exists());
List<ConfigFile> files = serviceSpec.getConfiguration().getFiles();
assertNumberOfLocalizations(files, 3);
ConfigFile expectedConfigFile = new ConfigFile();
expectedConfigFile.setType(ConfigFile.TypeEnum.STATIC);
expectedConfigFile.setSrcFile(remoteUrl);
expectedConfigFile.setDestFile(new Path(remoteUrl).getName());
assertConfigFile(expectedConfigFile, files.get(0));
expectedConfigFile = new ConfigFile();
expectedConfigFile.setType(ConfigFile.TypeEnum.STATIC);
expectedConfigFile.setSrcFile(getFilePath(localUrl, stagingDir));
expectedConfigFile.setDestFile(new Path(localUrl).getName());
assertConfigFile(expectedConfigFile, files.get(1));
expectedConfigFile = new ConfigFile();
expectedConfigFile.setType(ConfigFile.TypeEnum.STATIC);
expectedConfigFile.setSrcFile(getFilePath(localUrl, stagingDir));
expectedConfigFile.setDestFile(new Path(containerLocal3).getName());
assertConfigFile(expectedConfigFile, files.get(2));
// Ensure env value is correct
String env = serviceSpec.getConfiguration().getEnv()
.get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS");
String expectedMounts = new Path(containerLocal3).getName()
+ ":" + containerLocal3 + ":rw";
assertTrue(env.contains(expectedMounts));
}
private void assertNumberOfServiceComponents(Service serviceSpec,
int expected) {
assertEquals(expected, serviceSpec.getComponents().size());
}
/**
* Non HDFS remote URI test.
* --localization https://a/b/1.patch:.
* --localization s3a://a/dir:/opt/mys3dir
*/
@Test
public void testRunJobWithNonHDFSRemoteLocalization() throws Exception {
String remoteUri1 = "https://a/b/1.patch";
String containerLocal1 = ".";
String remoteUri2 = "s3a://a/s3dir";
String containerLocal2 = "/opt/mys3dir";
// create remote file in local staging dir to simulate HDFS
Path stagingDir = getStagingDir();
testCommons.getFileUtils().createFileInDir(stagingDir, remoteUri1);
File remoteDir1 =
testCommons.getFileUtils().createDirectory(stagingDir, remoteUri2);
testCommons.getFileUtils().createFileInDir(remoteDir1, "afile");
String suffix1 = "_" + remoteDir1.lastModified()
+ "-" + mockClientContext.getRemoteDirectoryManager()
.getRemoteFileSize(remoteUri2);
String[] params = createCommonParamsBuilder()
.withLocalization(remoteUri1, containerLocal1)
.withLocalization(remoteUri2, containerLocal2)
.build();
RunJobCli runJobCli = createRunJobCli();
runJobCli.run(params);
Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter(
runJobCli.getJobSubmitter());
assertNumberOfServiceComponents(serviceSpec, 3);
// Ensure download remote dir 2 times
verifyRdmCopyToRemoteLocalCalls(2);
// Ensure downloaded temp files are deleted
assertFilesAreDeleted(
testCommons.getFileUtils().getTempFileWithName(remoteUri1),
testCommons.getFileUtils().getTempFileWithName(remoteUri2));
// Ensure zip file are deleted
assertFilesAreDeleted(
testCommons.getFileUtils()
.getTempFileWithName(remoteUri2 + "_" + suffix1 + ZIP_EXTENSION));
List<ConfigFile> files = serviceSpec.getConfiguration().getFiles();
assertNumberOfLocalizations(files, 2);
ConfigFile expectedConfigFile = new ConfigFile();
expectedConfigFile.setType(ConfigFile.TypeEnum.STATIC);
expectedConfigFile.setSrcFile(getFilePath(remoteUri1, stagingDir));
expectedConfigFile.setDestFile(new Path(remoteUri1).getName());
assertConfigFile(expectedConfigFile, files.get(0));
expectedConfigFile = new ConfigFile();
expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
expectedConfigFile.setSrcFile(
getFilePathWithSuffix(stagingDir, remoteUri2, suffix1 + ZIP_EXTENSION));
expectedConfigFile.setDestFile(new Path(containerLocal2).getName());
assertConfigFile(expectedConfigFile, files.get(1));
// Ensure env value is correct
String env = serviceSpec.getConfiguration().getEnv()
.get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS");
String expectedMounts = new Path(remoteUri2).getName()
+ ":" + containerLocal2 + ":rw";
assertTrue(env.contains(expectedMounts));
}
/**
* Test HDFS dir localization.
* --localization hdfs:///user/yarn/mydir:./mydir1
* --localization hdfs:///user/yarn/mydir2:/opt/dir2:rw
* --localization hdfs:///user/yarn/mydir:.
* --localization hdfs:///user/yarn/mydir2:./
*/
@Test
public void testRunJobWithHdfsDirLocalization() throws Exception {
String remoteUrl = "hdfs:///user/yarn/mydir";
String containerPath = "./mydir1";
String remoteUrl2 = "hdfs:///user/yarn/mydir2";
String containerPath2 = "/opt/dir2";
String containerPath3 = ".";
String containerPath4 = "./";
// create remote file in local staging dir to simulate HDFS
Path stagingDir = getStagingDir();
File remoteDir1 =
testCommons.getFileUtils().createDirectory(stagingDir, remoteUrl);
testCommons.getFileUtils().createFileInDir(remoteDir1, "1.py");
testCommons.getFileUtils().createFileInDir(remoteDir1, "2.py");
File remoteDir2 =
testCommons.getFileUtils().createDirectory(stagingDir, remoteUrl2);
testCommons.getFileUtils().createFileInDir(remoteDir2, "3.py");
testCommons.getFileUtils().createFileInDir(remoteDir2, "4.py");
String suffix1 = "_" + remoteDir1.lastModified()
+ "-" + mockClientContext.getRemoteDirectoryManager()
.getRemoteFileSize(remoteUrl);
String suffix2 = "_" + remoteDir2.lastModified()
+ "-" + mockClientContext.getRemoteDirectoryManager()
.getRemoteFileSize(remoteUrl2);
String[] params = createCommonParamsBuilder()
.withLocalization(remoteUrl, containerPath)
.withLocalization(remoteUrl2, containerPath2)
.withLocalization(remoteUrl, containerPath3)
.withLocalization(remoteUrl2, containerPath4)
.build();
RunJobCli runJobCli = createRunJobCli();
runJobCli.run(params);
Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter(
runJobCli.getJobSubmitter());
assertNumberOfServiceComponents(serviceSpec, 3);
// Ensure download remote dir 4 times
verifyRdmCopyToRemoteLocalCalls(4);
// Ensure downloaded temp files are deleted
assertFilesAreDeleted(
testCommons.getFileUtils().getTempFileWithName(remoteUrl),
testCommons.getFileUtils().getTempFileWithName(remoteUrl2));
// Ensure zip file are deleted
assertFilesAreDeleted(
testCommons.getFileUtils()
.getTempFileWithName(remoteUrl + suffix1 + ZIP_EXTENSION),
testCommons.getFileUtils()
.getTempFileWithName(remoteUrl2 + suffix2 + ZIP_EXTENSION));
// Ensure files will be localized
List<ConfigFile> files = serviceSpec.getConfiguration().getFiles();
assertNumberOfLocalizations(files, 4);
ConfigFile expectedConfigFile = new ConfigFile();
// The hdfs dir should be download and compress and let YARN to uncompress
expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
expectedConfigFile.setSrcFile(
getFilePathWithSuffix(stagingDir, remoteUrl, suffix1 + ZIP_EXTENSION));
// Relative path in container, but not "." or "./". Use its own name
expectedConfigFile.setDestFile(new Path(containerPath).getName());
assertConfigFile(expectedConfigFile, files.get(0));
expectedConfigFile = new ConfigFile();
expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
expectedConfigFile.setSrcFile(
getFilePathWithSuffix(stagingDir, remoteUrl2, suffix2 + ZIP_EXTENSION));
expectedConfigFile.setDestFile(new Path(containerPath2).getName());
assertConfigFile(expectedConfigFile, files.get(1));
expectedConfigFile = new ConfigFile();
expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
expectedConfigFile.setSrcFile(
getFilePathWithSuffix(stagingDir, remoteUrl, suffix1 + ZIP_EXTENSION));
// Relative path in container ".", use remote path name
expectedConfigFile.setDestFile(new Path(remoteUrl).getName());
assertConfigFile(expectedConfigFile, files.get(2));
expectedConfigFile = new ConfigFile();
expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
expectedConfigFile.setSrcFile(
getFilePathWithSuffix(stagingDir, remoteUrl2, suffix2 + ZIP_EXTENSION));
// Relative path in container ".", use remote path name
expectedConfigFile.setDestFile(new Path(remoteUrl2).getName());
assertConfigFile(expectedConfigFile, files.get(3));
// Ensure mounts env value is correct. Add one mount string
String env = serviceSpec.getConfiguration().getEnv()
.get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS");
String expectedMounts =
new Path(containerPath2).getName() + ":" + containerPath2 + ":rw";
assertTrue(env.contains(expectedMounts));
}
/**
* Test if file/dir to be localized whose size exceeds limit.
* Max 10MB in configuration, mock remote will
* always return file size 100MB.
* This configuration will fail the job which has remoteUri
* But don't impact local dir/file
*
* --localization https://a/b/1.patch:.
* --localization s3a://a/dir:/opt/mys3dir
* --localization /temp/script2.py:./
*/
@Test
public void testRunJobRemoteUriExceedLocalizationSize() throws Exception {
String remoteUri1 = "https://a/b/1.patch";
String containerLocal1 = ".";
String remoteUri2 = "s3a://a/s3dir";
String containerLocal2 = "/opt/mys3dir";
String localUri1 = "/temp/script2";
String containerLocal3 = "./";
SubmarineConfiguration submarineConf = new SubmarineConfiguration();
// Max 10MB, mock remote will always return file size 100MB.
submarineConf.set(
SubmarineConfiguration.LOCALIZATION_MAX_ALLOWED_FILE_SIZE_MB,
"10");
mockClientContext.setSubmarineConfig(submarineConf);
assertFalse(SubmarineLogs.isVerbose());
// create remote file in local staging dir to simulate
Path stagingDir = getStagingDir();
testCommons.getFileUtils().createFileInDir(stagingDir, remoteUri1);
File remoteDir1 =
testCommons.getFileUtils().createDirectory(stagingDir, remoteUri2);
testCommons.getFileUtils().createFileInDir(remoteDir1, "afile");
// create local file, we need to put it under local temp dir
File localFile1 = testCommons.getFileUtils().createFileInTempDir(localUri1);
try {
RunJobCli runJobCli = createRunJobCli();
String[] params = createCommonParamsBuilder()
.withLocalization(remoteUri1, containerLocal1)
.build();
runJobCli.run(params);
} catch (IOException e) {
// Shouldn't have exception because it's within file size limit
fail();
}
// we should download because fail fast
verifyRdmCopyToRemoteLocalCalls(1);
try {
String[] params = createCommonParamsBuilder()
.withLocalization(remoteUri1, containerLocal1)
.withLocalization(remoteUri2, containerLocal2)
.withLocalization(localFile1.getAbsolutePath(), containerLocal3)
.build();
reset(spyRdm);
RunJobCli runJobCli = createRunJobCliWithoutVerboseAssertion();
runJobCli.run(params);
} catch (IOException e) {
assertTrue(e.getMessage()
.contains("104857600 exceeds configured max size:10485760"));
// we shouldn't do any download because fail fast
verifyRdmCopyToRemoteLocalCalls(0);
}
try {
String[] params = createCommonParamsBuilder()
.withLocalization(localFile1.getAbsolutePath(), containerLocal3)
.build();
RunJobCli runJobCli = createRunJobCliWithoutVerboseAssertion();
runJobCli.run(params);
} catch (IOException e) {
assertTrue(e.getMessage()
.contains("104857600 exceeds configured max size:10485760"));
// we shouldn't do any download because fail fast
verifyRdmCopyToRemoteLocalCalls(0);
}
}
/**
* Test remote Uri doesn't exist.
* */
@Test
public void testRunJobWithNonExistRemoteUri() throws Exception {
String remoteUri1 = "hdfs:///a/b/1.patch";
String containerLocal1 = ".";
String localUri1 = "/a/b/c";
String containerLocal2 = "./";
try {
String[] params = createCommonParamsBuilder()
.withLocalization(remoteUri1, containerLocal1)
.build();
RunJobCli runJobCli = createRunJobCli();
runJobCli.run(params);
} catch (IOException e) {
assertTrue(e.getMessage().contains("doesn't exists"));
}
try {
String[] params = createCommonParamsBuilder()
.withLocalization(localUri1, containerLocal2)
.build();
RunJobCli runJobCli = createRunJobCliWithoutVerboseAssertion();
runJobCli.run(params);
} catch (IOException e) {
assertTrue(e.getMessage().contains("doesn't exists"));
}
}
/**
* Test local dir
* --localization /user/yarn/mydir:./mydir1
* --localization /user/yarn/mydir2:/opt/dir2:rw
* --localization /user/yarn/mydir2:.
*/
@Test
public void testRunJobWithLocalDirLocalization() throws Exception {
String localUrl = "/user/yarn/mydir";
String containerPath = "./mydir1";
String localUrl2 = "/user/yarn/mydir2";
String containerPath2 = "/opt/dir2";
String containerPath3 = ".";
// create local file
File localDir1 = testCommons.getFileUtils().createDirInTempDir(localUrl);
testCommons.getFileUtils().createFileInDir(localDir1, "1.py");
testCommons.getFileUtils().createFileInDir(localDir1, "2.py");
File localDir2 = testCommons.getFileUtils().createDirInTempDir(localUrl2);
testCommons.getFileUtils().createFileInDir(localDir2, "3.py");
testCommons.getFileUtils().createFileInDir(localDir2, "4.py");
String suffix1 = "_" + localDir1.lastModified()
+ "-" + localDir1.length();
String suffix2 = "_" + localDir2.lastModified()
+ "-" + localDir2.length();
String[] params = createCommonParamsBuilder()
.withLocalization(localDir1.getAbsolutePath(), containerPath)
.withLocalization(localDir2.getAbsolutePath(), containerPath2)
.withLocalization(localDir2.getAbsolutePath(), containerPath3)
.build();
RunJobCli runJobCli = createRunJobCli();
runJobCli.run(params);
Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter(
runJobCli.getJobSubmitter());
assertNumberOfServiceComponents(serviceSpec, 3);
// we shouldn't do any download
verifyRdmCopyToRemoteLocalCalls(0);
// Ensure local original files are not deleted
assertTrue(localDir1.exists());
assertTrue(localDir2.exists());
// Ensure zip file are deleted
assertFalse(
testCommons.getFileUtils()
.getTempFileWithName(localUrl + suffix1 + ZIP_EXTENSION)
.exists());
assertFalse(
testCommons.getFileUtils()
.getTempFileWithName(localUrl2 + suffix2 + ZIP_EXTENSION)
.exists());
// Ensure dirs will be zipped and localized
List<ConfigFile> files = serviceSpec.getConfiguration().getFiles();
assertNumberOfLocalizations(files, 3);
Path stagingDir = getStagingDir();
ConfigFile expectedConfigFile = new ConfigFile();
expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
expectedConfigFile.setSrcFile(
getFilePathWithSuffix(stagingDir, localUrl, suffix1 + ZIP_EXTENSION));
expectedConfigFile.setDestFile(new Path(containerPath).getName());
assertConfigFile(expectedConfigFile, files.get(0));
expectedConfigFile = new ConfigFile();
expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
expectedConfigFile.setSrcFile(
getFilePathWithSuffix(stagingDir, localUrl2, suffix2 + ZIP_EXTENSION));
expectedConfigFile.setDestFile(new Path(containerPath2).getName());
assertConfigFile(expectedConfigFile, files.get(1));
expectedConfigFile = new ConfigFile();
expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE);
expectedConfigFile.setSrcFile(
getFilePathWithSuffix(stagingDir, localUrl2, suffix2 + ZIP_EXTENSION));
expectedConfigFile.setDestFile(new Path(localUrl2).getName());
assertConfigFile(expectedConfigFile, files.get(2));
// Ensure mounts env value is correct
String env = serviceSpec.getConfiguration().getEnv()
.get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS");
String expectedMounts = new Path(containerPath2).getName()
+ ":" + containerPath2 + ":rw";
assertTrue(env.contains(expectedMounts));
}
}

View File

@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
* Class to test the {@link ServiceWrapper}.
*/
public class TestServiceWrapper {
private AbstractComponent createMockAbstractComponent(Component mockComponent,
String componentName, String localScriptFile) throws IOException {
when(mockComponent.getName()).thenReturn(componentName);
AbstractComponent mockAbstractComponent = mock(AbstractComponent.class);
when(mockAbstractComponent.createComponent()).thenReturn(mockComponent);
when(mockAbstractComponent.getLocalScriptFile())
.thenReturn(localScriptFile);
return mockAbstractComponent;
}
@Test
public void testWithSingleComponent() throws IOException {
Service mockService = mock(Service.class);
ServiceWrapper serviceWrapper = new ServiceWrapper(mockService);
Component mockComponent = mock(Component.class);
AbstractComponent mockAbstractComponent =
createMockAbstractComponent(mockComponent, "testComponent",
"testLocalScriptFile");
serviceWrapper.addComponent(mockAbstractComponent);
verify(mockService).addComponent(eq(mockComponent));
String launchCommand =
serviceWrapper.getLocalLaunchCommandPathForComponent("testComponent");
assertEquals("testLocalScriptFile", launchCommand);
}
@Test
public void testWithMultipleComponent() throws IOException {
Service mockService = mock(Service.class);
ServiceWrapper serviceWrapper = new ServiceWrapper(mockService);
Component mockComponent1 = mock(Component.class);
AbstractComponent mockAbstractComponent1 =
createMockAbstractComponent(mockComponent1, "testComponent1",
"testLocalScriptFile1");
Component mockComponent2 = mock(Component.class);
AbstractComponent mockAbstractComponent2 =
createMockAbstractComponent(mockComponent2, "testComponent2",
"testLocalScriptFile2");
serviceWrapper.addComponent(mockAbstractComponent1);
serviceWrapper.addComponent(mockAbstractComponent2);
verify(mockService).addComponent(eq(mockComponent1));
verify(mockService).addComponent(eq(mockComponent2));
String launchCommand1 =
serviceWrapper.getLocalLaunchCommandPathForComponent("testComponent1");
assertEquals("testLocalScriptFile1", launchCommand1);
String launchCommand2 =
serviceWrapper.getLocalLaunchCommandPathForComponent("testComponent2");
assertEquals("testLocalScriptFile2", launchCommand2);
}
}

View File

@ -14,26 +14,30 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons;
import org.codehaus.jettison.json.JSONException;
import org.junit.Assert;
import org.junit.Test;
/**
* Class to test some functionality of {@link TensorFlowCommons}.
*/
public class TestTFConfigGenerator {
@Test
public void testSimpleDistributedTFConfigGenerator() throws JSONException {
String json = YarnServiceUtils.getTFConfigEnv("worker", 5, 3, "wtan",
String json = TensorFlowCommons.getTFConfigEnv("worker", 5, 3, "wtan",
"tf-job-001", "example.com");
String expected =
"{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"],\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\",\\\"worker-1.wtan.tf-job-001.example.com:8000\\\",\\\"worker-2.wtan.tf-job-001.example.com:8000\\\",\\\"worker-3.wtan.tf-job-001.example.com:8000\\\"],\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\",\\\"ps-1.wtan.tf-job-001.example.com:8000\\\",\\\"ps-2.wtan.tf-job-001.example.com:8000\\\"]},\\\"task\\\":{ \\\"type\\\":\\\"worker\\\", \\\"index\\\":$_TASK_INDEX},\\\"environment\\\":\\\"cloud\\\"}";
Assert.assertEquals(expected, json);
json = YarnServiceUtils.getTFConfigEnv("ps", 5, 3, "wtan", "tf-job-001",
json = TensorFlowCommons.getTFConfigEnv("ps", 5, 3, "wtan", "tf-job-001",
"example.com");
expected =
"{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"],\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\",\\\"worker-1.wtan.tf-job-001.example.com:8000\\\",\\\"worker-2.wtan.tf-job-001.example.com:8000\\\",\\\"worker-3.wtan.tf-job-001.example.com:8000\\\"],\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\",\\\"ps-1.wtan.tf-job-001.example.com:8000\\\",\\\"ps-2.wtan.tf-job-001.example.com:8000\\\"]},\\\"task\\\":{ \\\"type\\\":\\\"ps\\\", \\\"index\\\":$_TASK_INDEX},\\\"environment\\\":\\\"cloud\\\"}";
Assert.assertEquals(expected, json);
json = YarnServiceUtils.getTFConfigEnv("master", 2, 1, "wtan", "tf-job-001",
json = TensorFlowCommons.getTFConfigEnv("master", 2, 1, "wtan", "tf-job-001",
"example.com");
expected =
"{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"],\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\"],\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\"]},\\\"task\\\":{ \\\"type\\\":\\\"master\\\", \\\"index\\\":$_TASK_INDEX},\\\"environment\\\":\\\"cloud\\\"}";

View File

@ -0,0 +1,190 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import static junit.framework.TestCase.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
* This class is an abstract base class for testing Tensorboard and TensorFlow
* launch commands.
*/
public abstract class AbstractLaunchCommandTestHelper {
private TaskType taskType;
private boolean useTaskTypeOverride;
@Rule
public ExpectedException expectedException = ExpectedException.none();
private void assertScriptContainsExportedEnvVar(List<String> fileContents,
String varName) {
String expected = String.format("export %s=", varName);
assertScriptContainsLine(fileContents, expected);
}
public static void assertScriptContainsExportedEnvVarWithValue(
List<String> fileContents, String varName, String value) {
String expected = String.format("export %s=%s", varName, value);
assertScriptContainsLine(fileContents, expected);
}
public static void assertScriptContainsLine(List<String> fileContents,
String expected) {
String message = String.format(
"File does not contain expected line '%s'!" + " File contents: %s",
expected, Arrays.toString(fileContents.toArray()));
assertTrue(message, fileContents.contains(expected));
}
public static void assertScriptContainsLineWithRegex(
List<String> fileContents,
String regex) {
String message = String.format(
"File does not contain expected line '%s'!" + " File contents: %s",
regex, Arrays.toString(fileContents.toArray()));
for (String line : fileContents) {
if (line.matches(regex)) {
return;
}
}
fail(message);
}
public static void assertScriptDoesNotContainLine(List<String> fileContents,
String expected) {
String message = String.format(
"File contains unexpected line '%s'!" + " File contents: %s",
expected, Arrays.toString(fileContents.toArray()));
assertFalse(message, fileContents.contains(expected));
}
private AbstractLaunchCommand createLaunchCommandByTaskType(TaskType taskType,
RunJobParameters params) throws IOException {
MockClientContext mockClientContext = new MockClientContext();
FileSystemOperations fsOperations =
new FileSystemOperations(mockClientContext);
HadoopEnvironmentSetup hadoopEnvSetup =
new HadoopEnvironmentSetup(mockClientContext, fsOperations);
Component component = new Component();
Configuration yarnConfig = new Configuration();
return createLaunchCommandByTaskTypeInternal(taskType, params,
hadoopEnvSetup, component, yarnConfig);
}
private AbstractLaunchCommand createLaunchCommandByTaskTypeInternal(
TaskType taskType, RunJobParameters params,
HadoopEnvironmentSetup hadoopEnvSetup, Component component,
Configuration yarnConfig)
throws IOException {
if (taskType == TaskType.TENSORBOARD) {
return new TensorBoardLaunchCommand(
hadoopEnvSetup, getTaskType(taskType), component, params);
} else if (taskType == TaskType.WORKER
|| taskType == TaskType.PRIMARY_WORKER) {
return new TensorFlowWorkerLaunchCommand(
hadoopEnvSetup, getTaskType(taskType), component, params, yarnConfig);
} else if (taskType == TaskType.PS) {
return new TensorFlowPsLaunchCommand(
hadoopEnvSetup, getTaskType(taskType), component, params, yarnConfig);
}
throw new IllegalStateException("Unknown taskType!");
}
public void overrideTaskType(TaskType taskType) {
this.taskType = taskType;
this.useTaskTypeOverride = true;
}
private TaskType getTaskType(TaskType taskType) {
if (useTaskTypeOverride) {
return this.taskType;
}
return taskType;
}
public void testHdfsRelatedEnvironmentIsUndefined(TaskType taskType,
RunJobParameters params) throws IOException {
AbstractLaunchCommand launchCommand =
createLaunchCommandByTaskType(taskType, params);
expectedException.expect(IOException.class);
expectedException
.expectMessage("Failed to detect HDFS-related environments.");
launchCommand.generateLaunchScript();
}
public List<String> testHdfsRelatedEnvironmentIsDefined(TaskType taskType,
RunJobParameters params) throws IOException {
AbstractLaunchCommand launchCommand =
createLaunchCommandByTaskType(taskType, params);
String result = launchCommand.generateLaunchScript();
assertNotNull(result);
File resultFile = new File(result);
assertTrue(resultFile.exists());
List<String> fileContents = Files.readAllLines(
Paths.get(resultFile.toURI()),
Charset.forName("UTF-8"));
assertEquals("#!/bin/bash", fileContents.get(0));
assertScriptContainsExportedEnvVar(fileContents, "HADOOP_HOME");
assertScriptContainsExportedEnvVar(fileContents, "HADOOP_YARN_HOME");
assertScriptContainsExportedEnvVarWithValue(fileContents,
"HADOOP_HDFS_HOME", "testHdfsHome");
assertScriptContainsExportedEnvVarWithValue(fileContents,
"HADOOP_COMMON_HOME", "testHdfsHome");
assertScriptContainsExportedEnvVarWithValue(fileContents, "HADOOP_CONF_DIR",
"$WORK_DIR");
assertScriptContainsExportedEnvVarWithValue(fileContents, "JAVA_HOME",
"testJavaHome");
assertScriptContainsExportedEnvVarWithValue(fileContents, "LD_LIBRARY_PATH",
"$LD_LIBRARY_PATH:$JAVA_HOME/lib/amd64/server");
assertScriptContainsExportedEnvVarWithValue(fileContents, "CLASSPATH",
"`$HADOOP_HDFS_HOME/bin/hadoop classpath --glob`");
return fileContents;
}
}

View File

@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
/**
* This class is to test the {@link LaunchCommandFactory}.
*/
public class TestLaunchCommandFactory {
private LaunchCommandFactory createLaunchCommandFactory(
RunJobParameters parameters) {
HadoopEnvironmentSetup hadoopEnvSetup = mock(HadoopEnvironmentSetup.class);
Configuration configuration = mock(Configuration.class);
return new LaunchCommandFactory(hadoopEnvSetup, parameters, configuration);
}
@Test
public void createLaunchCommandWorkerAndPrimaryWorker() throws IOException {
RunJobParameters parameters = new RunJobParameters();
parameters.setWorkerLaunchCmd("testWorkerLaunchCommand");
LaunchCommandFactory launchCommandFactory = createLaunchCommandFactory(
parameters);
Component mockComponent = mock(Component.class);
AbstractLaunchCommand launchCommand =
launchCommandFactory.createLaunchCommand(TaskType.PRIMARY_WORKER,
mockComponent);
assertTrue(launchCommand instanceof TensorFlowWorkerLaunchCommand);
launchCommand =
launchCommandFactory.createLaunchCommand(TaskType.WORKER,
mockComponent);
assertTrue(launchCommand instanceof TensorFlowWorkerLaunchCommand);
}
@Test
public void createLaunchCommandPs() throws IOException {
RunJobParameters parameters = new RunJobParameters();
parameters.setPSLaunchCmd("testPSLaunchCommand");
LaunchCommandFactory launchCommandFactory = createLaunchCommandFactory(
parameters);
Component mockComponent = mock(Component.class);
AbstractLaunchCommand launchCommand =
launchCommandFactory.createLaunchCommand(TaskType.PS,
mockComponent);
assertTrue(launchCommand instanceof TensorFlowPsLaunchCommand);
}
@Test
public void createLaunchCommandTensorboard() throws IOException {
RunJobParameters parameters = new RunJobParameters();
parameters.setCheckpointPath("testCheckpointPath");
LaunchCommandFactory launchCommandFactory =
createLaunchCommandFactory(parameters);
Component mockComponent = mock(Component.class);
AbstractLaunchCommand launchCommand =
launchCommandFactory.createLaunchCommand(TaskType.TENSORBOARD,
mockComponent);
assertTrue(launchCommand instanceof TensorBoardLaunchCommand);
}
}

View File

@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;
import com.google.common.collect.ImmutableList;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommandTestHelper;
import org.junit.Test;
import java.io.IOException;
import java.util.List;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup.DOCKER_HADOOP_HDFS_HOME;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup.DOCKER_JAVA_HOME;
/**
* This class is to test the {@link TensorBoardLaunchCommand}.
*/
public class TestTensorBoardLaunchCommand extends
AbstractLaunchCommandTestHelper {
@Test
public void testHdfsRelatedEnvironmentIsUndefined() throws IOException {
RunJobParameters params = new RunJobParameters();
params.setInputPath("hdfs://bla");
params.setName("testJobname");
params.setCheckpointPath("something");
testHdfsRelatedEnvironmentIsUndefined(TaskType.TENSORBOARD,
params);
}
@Test
public void testHdfsRelatedEnvironmentIsDefined() throws IOException {
RunJobParameters params = new RunJobParameters();
params.setName("testName");
params.setCheckpointPath("testCheckpointPath");
params.setInputPath("hdfs://bla");
params.setEnvars(ImmutableList.of(
DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome",
DOCKER_JAVA_HOME + "=" + "testJavaHome"));
List<String> fileContents =
testHdfsRelatedEnvironmentIsDefined(TaskType.TENSORBOARD,
params);
assertScriptContainsExportedEnvVarWithValue(fileContents, "LC_ALL",
"C && tensorboard --logdir=testCheckpointPath");
}
@Test
public void testCheckpointPathUndefined() throws IOException {
MockClientContext mockClientContext = new MockClientContext();
FileSystemOperations fsOperations =
new FileSystemOperations(mockClientContext);
HadoopEnvironmentSetup hadoopEnvSetup =
new HadoopEnvironmentSetup(mockClientContext, fsOperations);
Component component = new Component();
RunJobParameters params = new RunJobParameters();
params.setCheckpointPath(null);
expectedException.expect(NullPointerException.class);
expectedException.expectMessage("CheckpointPath must not be null");
new TensorBoardLaunchCommand(hadoopEnvSetup, TaskType.TENSORBOARD,
component, params);
}
@Test
public void testCheckpointPathEmptyString() throws IOException {
MockClientContext mockClientContext = new MockClientContext();
FileSystemOperations fsOperations =
new FileSystemOperations(mockClientContext);
HadoopEnvironmentSetup hadoopEnvSetup =
new HadoopEnvironmentSetup(mockClientContext, fsOperations);
Component component = new Component();
RunJobParameters params = new RunJobParameters();
params.setCheckpointPath("");
expectedException.expect(IllegalArgumentException.class);
expectedException.expectMessage("CheckpointPath must not be empty");
new TensorBoardLaunchCommand(hadoopEnvSetup, TaskType.TENSORBOARD,
component, params);
}
}

View File

@ -0,0 +1,251 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;
import com.google.common.collect.ImmutableList;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommandTestHelper;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup.DOCKER_HADOOP_HDFS_HOME;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup.DOCKER_JAVA_HOME;
/**
* This class is to test the implementors of {@link TensorFlowLaunchCommand}.
*/
@RunWith(Parameterized.class)
public class TestTensorFlowLaunchCommand
extends AbstractLaunchCommandTestHelper {
private TaskType taskType;
@Parameterized.Parameters
public static Collection<Object[]> data() {
Collection<Object[]> params = new ArrayList<>();
params.add(new Object[]{TaskType.WORKER });
params.add(new Object[]{TaskType.PS });
return params;
}
public TestTensorFlowLaunchCommand(TaskType taskType) {
this.taskType = taskType;
}
private void assertScriptContainsLaunchCommand(List<String> fileContents,
RunJobParameters params) {
String launchCommand = null;
if (taskType == TaskType.WORKER) {
launchCommand = params.getWorkerLaunchCmd();
} else if (taskType == TaskType.PS) {
launchCommand = params.getPSLaunchCmd();
}
assertScriptContainsLine(fileContents, launchCommand);
}
private void setLaunchCommandToParams(RunJobParameters params) {
if (taskType == TaskType.WORKER) {
params.setWorkerLaunchCmd("testWorkerLaunchCommand");
} else if (taskType == TaskType.PS) {
params.setPSLaunchCmd("testPsLaunchCommand");
}
}
private void setLaunchCommandToParams(RunJobParameters params, String value) {
if (taskType == TaskType.WORKER) {
params.setWorkerLaunchCmd(value);
} else if (taskType == TaskType.PS) {
params.setPSLaunchCmd(value);
}
}
private void assertTypeInJson(List<String> fileContents) {
String expectedType = null;
if (taskType == TaskType.WORKER) {
expectedType = "worker";
} else if (taskType == TaskType.PS) {
expectedType = "ps";
}
assertScriptContainsLineWithRegex(fileContents, String.format(".*type.*:" +
".*%s.*", expectedType));
}
private TensorFlowLaunchCommand createTensorFlowLaunchCommandObject(
HadoopEnvironmentSetup hadoopEnvSetup, Configuration yarnConfig,
Component component, RunJobParameters params) throws IOException {
if (taskType == TaskType.WORKER) {
return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, taskType,
component,
params, yarnConfig);
} else if (taskType == TaskType.PS) {
return new TensorFlowPsLaunchCommand(hadoopEnvSetup, taskType, component,
params, yarnConfig);
}
throw new IllegalStateException("Unknown tasktype!");
}
@Test
public void testHdfsRelatedEnvironmentIsUndefined() throws IOException {
RunJobParameters params = new RunJobParameters();
params.setInputPath("hdfs://bla");
params.setName("testJobname");
setLaunchCommandToParams(params);
testHdfsRelatedEnvironmentIsUndefined(taskType, params);
}
@Test
public void testHdfsRelatedEnvironmentIsDefined() throws IOException {
RunJobParameters params = new RunJobParameters();
params.setName("testName");
params.setInputPath("hdfs://bla");
params.setEnvars(ImmutableList.of(
DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome",
DOCKER_JAVA_HOME + "=" + "testJavaHome"));
setLaunchCommandToParams(params);
List<String> fileContents =
testHdfsRelatedEnvironmentIsDefined(taskType,
params);
assertScriptContainsLaunchCommand(fileContents, params);
assertScriptDoesNotContainLine(fileContents, "export TF_CONFIG=");
}
@Test
public void testLaunchCommandIsNull() throws IOException {
MockClientContext mockClientContext = new MockClientContext();
FileSystemOperations fsOperations =
new FileSystemOperations(mockClientContext);
HadoopEnvironmentSetup hadoopEnvSetup =
new HadoopEnvironmentSetup(mockClientContext, fsOperations);
Configuration yarnConfig = new Configuration();
Component component = new Component();
RunJobParameters params = new RunJobParameters();
params.setName("testName");
setLaunchCommandToParams(params, null);
expectedException.expect(IllegalArgumentException.class);
expectedException.expectMessage("LaunchCommand must not be null or empty");
TensorFlowLaunchCommand launchCommand =
createTensorFlowLaunchCommandObject(hadoopEnvSetup, yarnConfig,
component,
params);
launchCommand.generateLaunchScript();
}
@Test
public void testLaunchCommandIsEmpty() throws IOException {
MockClientContext mockClientContext = new MockClientContext();
FileSystemOperations fsOperations =
new FileSystemOperations(mockClientContext);
HadoopEnvironmentSetup hadoopEnvSetup =
new HadoopEnvironmentSetup(mockClientContext, fsOperations);
Configuration yarnConfig = new Configuration();
Component component = new Component();
RunJobParameters params = new RunJobParameters();
params.setName("testName");
setLaunchCommandToParams(params, "");
expectedException.expect(IllegalArgumentException.class);
expectedException.expectMessage("LaunchCommand must not be null or empty");
TensorFlowLaunchCommand launchCommand =
createTensorFlowLaunchCommandObject(hadoopEnvSetup, yarnConfig,
component, params);
launchCommand.generateLaunchScript();
}
@Test
public void testDistributedTrainingMissingTaskType() throws IOException {
overrideTaskType(null);
RunJobParameters params = new RunJobParameters();
params.setDistributed(true);
params.setName("testName");
params.setInputPath("hdfs://bla");
params.setEnvars(ImmutableList.of(
DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome",
DOCKER_JAVA_HOME + "=" + "testJavaHome"));
setLaunchCommandToParams(params);
expectedException.expect(NullPointerException.class);
expectedException.expectMessage("TaskType must not be null");
testHdfsRelatedEnvironmentIsDefined(taskType, params);
}
@Test
public void testDistributedTrainingNumberOfWorkersAndPsIsZero()
throws IOException {
RunJobParameters params = new RunJobParameters();
params.setDistributed(true);
params.setNumWorkers(0);
params.setNumPS(0);
params.setName("testName");
params.setInputPath("hdfs://bla");
params.setEnvars(ImmutableList.of(
DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome",
DOCKER_JAVA_HOME + "=" + "testJavaHome"));
setLaunchCommandToParams(params);
List<String> fileContents =
testHdfsRelatedEnvironmentIsDefined(taskType, params);
assertScriptDoesNotContainLine(fileContents, "export TF_CONFIG=");
assertScriptContainsLineWithRegex(fileContents, ".*worker.*:\\[\\].*");
assertScriptContainsLineWithRegex(fileContents, ".*ps.*:\\[\\].*");
assertTypeInJson(fileContents);
}
@Test
public void testDistributedTrainingNumberOfWorkersAndPsIsNonZero()
throws IOException {
RunJobParameters params = new RunJobParameters();
params.setDistributed(true);
params.setNumWorkers(3);
params.setNumPS(2);
params.setName("testName");
params.setInputPath("hdfs://bla");
params.setEnvars(ImmutableList.of(
DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome",
DOCKER_JAVA_HOME + "=" + "testJavaHome"));
setLaunchCommandToParams(params);
List<String> fileContents =
testHdfsRelatedEnvironmentIsDefined(taskType, params);
//assert we have multiple PS and workers
assertScriptDoesNotContainLine(fileContents, "export TF_CONFIG=");
assertScriptContainsLineWithRegex(fileContents, ".*worker.*:\\[.*,.*\\].*");
assertScriptContainsLineWithRegex(fileContents, ".*ps.*:\\[.*,.*\\].*");
assertTypeInJson(fileContents);
}
}

View File

@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.common.Envs;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* This class has some helper methods and fields
* in order to test TensorFlow-related Components easier.
*/
public class ComponentTestCommons {
String userName;
TaskType taskType;
LaunchCommandFactory mockLaunchCommandFactory;
FileSystemOperations fsOperations;
MockClientContext mockClientContext;
Configuration yarnConfig;
Resource resource;
ComponentTestCommons(TaskType taskType) {
this.taskType = taskType;
}
public void setup() throws IOException {
this.userName = System.getProperty("user.name");
this.resource = Resource.newInstance(4000, 10);
setupDependencies();
}
private void setupDependencies() throws IOException {
fsOperations = mock(FileSystemOperations.class);
mockClientContext = new MockClientContext();
mockLaunchCommandFactory = mock(LaunchCommandFactory.class);
AbstractLaunchCommand mockLaunchCommand = mock(AbstractLaunchCommand.class);
when(mockLaunchCommand.generateLaunchScript()).thenReturn("mockScript");
when(mockLaunchCommandFactory.createLaunchCommand(eq(taskType),
any(Component.class))).thenReturn(mockLaunchCommand);
yarnConfig = new Configuration();
}
void verifyCommonConfigEnvs(Component component) {
assertNotNull(component.getConfiguration().getEnv());
assertEquals(2, component.getConfiguration().getEnv().size());
assertEquals(ServiceApiConstants.COMPONENT_ID,
component.getConfiguration().getEnv().get(Envs.TASK_INDEX_ENV));
assertEquals(taskType.name(),
component.getConfiguration().getEnv().get(Envs.TASK_TYPE_ENV));
}
void verifyResources(Component component) {
assertNotNull(component.getResource());
assertEquals(10, (int) component.getResource().getCpus());
assertEquals(4000,
(int) Integer.valueOf(component.getResource().getMemory()));
}
}

View File

@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.service.api.records.Artifact;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
/**
* This class is to test {@link TensorBoardComponent}.
*/
public class TestTensorBoardComponent {
@Rule
public ExpectedException expectedException = ExpectedException.none();
private ComponentTestCommons testCommons =
new ComponentTestCommons(TaskType.TENSORBOARD);
@Before
public void setUp() throws IOException {
testCommons.setup();
}
private TensorBoardComponent createTensorBoardComponent(
RunJobParameters parameters) {
return new TensorBoardComponent(
testCommons.fsOperations,
testCommons.mockClientContext.getRemoteDirectoryManager(),
parameters,
testCommons.mockLaunchCommandFactory,
testCommons.yarnConfig);
}
@Test
public void testTensorBoardComponentWithNullResource() throws IOException {
RunJobParameters parameters = new RunJobParameters();
parameters.setTensorboardResource(null);
TensorBoardComponent tensorBoardComponent =
createTensorBoardComponent(parameters);
expectedException.expect(NullPointerException.class);
expectedException.expectMessage("TensorBoard resource must not be null");
tensorBoardComponent.createComponent();
}
@Test
public void testTensorBoardComponentWithNullJobName() throws IOException {
RunJobParameters parameters = new RunJobParameters();
parameters.setTensorboardResource(testCommons.resource);
parameters.setName(null);
TensorBoardComponent tensorBoardComponent =
createTensorBoardComponent(parameters);
expectedException.expect(NullPointerException.class);
expectedException.expectMessage("Job name must not be null");
tensorBoardComponent.createComponent();
}
@Test
public void testTensorBoardComponent() throws IOException {
testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
RunJobParameters parameters = new RunJobParameters();
parameters.setTensorboardResource(testCommons.resource);
parameters.setName("testJobName");
parameters.setTensorboardDockerImage("testTBDockerImage");
TensorBoardComponent tensorBoardComponent =
createTensorBoardComponent(parameters);
Component component = tensorBoardComponent.createComponent();
assertEquals(testCommons.taskType.getComponentName(), component.getName());
testCommons.verifyCommonConfigEnvs(component);
assertEquals(1L, (long) component.getNumberOfContainers());
assertEquals(RestartPolicyEnum.NEVER, component.getRestartPolicy());
testCommons.verifyResources(component);
assertEquals(
new Artifact().type(Artifact.TypeEnum.DOCKER).id("testTBDockerImage"),
component.getArtifact());
assertEquals(String.format(
"http://tensorboard-0.testJobName.%s" + ".testDomain:6006",
testCommons.userName),
tensorBoardComponent.getTensorboardLink());
assertEquals("./run-TENSORBOARD.sh", component.getLaunchCommand());
verify(testCommons.fsOperations)
.uploadToRemoteFileAndLocalizeToContainerWorkDir(
any(Path.class), eq("mockScript"), eq("run-TENSORBOARD.sh"),
eq(component));
}
}

View File

@ -0,0 +1,166 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import java.io.IOException;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.service.api.records.Artifact;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
/**
* This class is to test {@link TensorFlowPsComponent}.
*/
public class TestTensorFlowPsComponent {
@Rule
public ExpectedException expectedException = ExpectedException.none();
private ComponentTestCommons testCommons =
new ComponentTestCommons(TaskType.PS);
@Before
public void setUp() throws IOException {
testCommons.setup();
}
private TensorFlowPsComponent createPsComponent(RunJobParameters parameters) {
return new TensorFlowPsComponent(
testCommons.fsOperations,
testCommons.mockClientContext.getRemoteDirectoryManager(),
testCommons.mockLaunchCommandFactory,
parameters,
testCommons.yarnConfig);
}
private void verifyCommons(Component component) throws IOException {
assertEquals(testCommons.taskType.getComponentName(), component.getName());
testCommons.verifyCommonConfigEnvs(component);
assertTrue(component.getConfiguration().getProperties().isEmpty());
assertEquals(RestartPolicyEnum.NEVER, component.getRestartPolicy());
testCommons.verifyResources(component);
assertEquals(
new Artifact().type(Artifact.TypeEnum.DOCKER).id("testPSDockerImage"),
component.getArtifact());
String taskTypeUppercase = testCommons.taskType.name().toUpperCase();
String expectedScriptName = String.format("run-%s.sh", taskTypeUppercase);
assertEquals(String.format("./%s", expectedScriptName),
component.getLaunchCommand());
verify(testCommons.fsOperations)
.uploadToRemoteFileAndLocalizeToContainerWorkDir(
any(Path.class), eq("mockScript"), eq(expectedScriptName),
eq(component));
}
@Test
public void testPSComponentWithNullResource() throws IOException {
RunJobParameters parameters = new RunJobParameters();
parameters.setPsResource(null);
TensorFlowPsComponent psComponent =
createPsComponent(parameters);
expectedException.expect(NullPointerException.class);
expectedException.expectMessage("PS resource must not be null");
psComponent.createComponent();
}
@Test
public void testPSComponentWithNullJobName() throws IOException {
RunJobParameters parameters = new RunJobParameters();
parameters.setPsResource(testCommons.resource);
parameters.setNumPS(1);
parameters.setName(null);
TensorFlowPsComponent psComponent =
createPsComponent(parameters);
expectedException.expect(NullPointerException.class);
expectedException.expectMessage("Job name must not be null");
psComponent.createComponent();
}
@Test
public void testPSComponentZeroNumberOfPS() throws IOException {
testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
RunJobParameters parameters = new RunJobParameters();
parameters.setPsResource(testCommons.resource);
parameters.setName("testJobName");
parameters.setPsDockerImage("testPSDockerImage");
parameters.setNumPS(0);
TensorFlowPsComponent psComponent =
createPsComponent(parameters);
expectedException.expect(IllegalArgumentException.class);
expectedException.expectMessage("Number of PS should be at least 1!");
psComponent.createComponent();
}
@Test
public void testPSComponentNumPSIsOne() throws IOException {
testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
RunJobParameters parameters = new RunJobParameters();
parameters.setPsResource(testCommons.resource);
parameters.setName("testJobName");
parameters.setNumPS(1);
parameters.setPsDockerImage("testPSDockerImage");
TensorFlowPsComponent psComponent =
createPsComponent(parameters);
Component component = psComponent.createComponent();
assertEquals(1L, (long) component.getNumberOfContainers());
verifyCommons(component);
}
@Test
public void testPSComponentNumPSIsTwo() throws IOException {
testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
RunJobParameters parameters = new RunJobParameters();
parameters.setPsResource(testCommons.resource);
parameters.setName("testJobName");
parameters.setNumPS(2);
parameters.setPsDockerImage("testPSDockerImage");
TensorFlowPsComponent psComponent =
createPsComponent(parameters);
Component component = psComponent.createComponent();
assertEquals(2L, (long) component.getNumberOfContainers());
verifyCommons(component);
}
}

View File

@ -0,0 +1,215 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;
import com.google.common.collect.ImmutableMap;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.service.api.records.Artifact;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.io.IOException;
import java.util.Map;
import static junit.framework.TestCase.assertTrue;
import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants.CONTAINER_STATE_REPORT_AS_SERVICE_STATE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
/**
* This class is to test {@link TensorFlowWorkerComponent}.
*/
public class TestTensorFlowWorkerComponent {
@Rule
public ExpectedException expectedException = ExpectedException.none();
private ComponentTestCommons testCommons =
new ComponentTestCommons(TaskType.TENSORBOARD);
@Before
public void setUp() throws IOException {
testCommons.setup();
}
private TensorFlowWorkerComponent createWorkerComponent(
RunJobParameters parameters) {
return new TensorFlowWorkerComponent(
testCommons.fsOperations,
testCommons.mockClientContext.getRemoteDirectoryManager(),
parameters, testCommons.taskType,
testCommons.mockLaunchCommandFactory,
testCommons.yarnConfig);
}
private void verifyCommons(Component component) throws IOException {
verifyCommonsInternal(component, ImmutableMap.of());
}
private void verifyCommons(Component component,
Map<String, String> expectedProperties) throws IOException {
verifyCommonsInternal(component, expectedProperties);
}
private void verifyCommonsInternal(Component component,
Map<String, String> expectedProperties) throws IOException {
assertEquals(testCommons.taskType.getComponentName(), component.getName());
testCommons.verifyCommonConfigEnvs(component);
Map<String, String> actualProperties =
component.getConfiguration().getProperties();
if (!expectedProperties.isEmpty()) {
assertFalse(actualProperties.isEmpty());
expectedProperties.forEach(
(k, v) -> assertEquals(v, actualProperties.get(k)));
} else {
assertTrue(actualProperties.isEmpty());
}
assertEquals(RestartPolicyEnum.NEVER, component.getRestartPolicy());
testCommons.verifyResources(component);
assertEquals(
new Artifact().type(Artifact.TypeEnum.DOCKER)
.id("testWorkerDockerImage"),
component.getArtifact());
String taskTypeUppercase = testCommons.taskType.name().toUpperCase();
String expectedScriptName = String.format("run-%s.sh", taskTypeUppercase);
assertEquals(String.format("./%s", expectedScriptName),
component.getLaunchCommand());
verify(testCommons.fsOperations)
.uploadToRemoteFileAndLocalizeToContainerWorkDir(
any(Path.class), eq("mockScript"), eq(expectedScriptName),
eq(component));
}
@Test
public void testWorkerComponentWithNullResource() throws IOException {
RunJobParameters parameters = new RunJobParameters();
parameters.setWorkerResource(null);
TensorFlowWorkerComponent workerComponent =
createWorkerComponent(parameters);
expectedException.expect(NullPointerException.class);
expectedException.expectMessage("Worker resource must not be null");
workerComponent.createComponent();
}
@Test
public void testWorkerComponentWithNullJobName() throws IOException {
RunJobParameters parameters = new RunJobParameters();
parameters.setWorkerResource(testCommons.resource);
parameters.setNumWorkers(1);
parameters.setName(null);
TensorFlowWorkerComponent workerComponent =
createWorkerComponent(parameters);
expectedException.expect(NullPointerException.class);
expectedException.expectMessage("Job name must not be null");
workerComponent.createComponent();
}
@Test
public void testNormalWorkerComponentZeroNumberOfWorkers()
throws IOException {
testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
RunJobParameters parameters = new RunJobParameters();
parameters.setWorkerResource(testCommons.resource);
parameters.setName("testJobName");
parameters.setWorkerDockerImage("testWorkerDockerImage");
parameters.setNumWorkers(0);
TensorFlowWorkerComponent workerComponent =
createWorkerComponent(parameters);
expectedException.expect(IllegalArgumentException.class);
expectedException.expectMessage("Number of workers should be at least 1!");
workerComponent.createComponent();
}
@Test
public void testNormalWorkerComponentNumWorkersIsOne() throws IOException {
testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
RunJobParameters parameters = new RunJobParameters();
parameters.setWorkerResource(testCommons.resource);
parameters.setName("testJobName");
parameters.setNumWorkers(1);
parameters.setWorkerDockerImage("testWorkerDockerImage");
TensorFlowWorkerComponent workerComponent =
createWorkerComponent(parameters);
Component component = workerComponent.createComponent();
assertEquals(0L, (long) component.getNumberOfContainers());
verifyCommons(component);
}
@Test
public void testNormalWorkerComponentNumWorkersIsTwo() throws IOException {
testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
RunJobParameters parameters = new RunJobParameters();
parameters.setWorkerResource(testCommons.resource);
parameters.setName("testJobName");
parameters.setNumWorkers(2);
parameters.setWorkerDockerImage("testWorkerDockerImage");
TensorFlowWorkerComponent workerComponent =
createWorkerComponent(parameters);
Component component = workerComponent.createComponent();
assertEquals(1L, (long) component.getNumberOfContainers());
verifyCommons(component);
}
@Test
public void testPrimaryWorkerComponentNumWorkersIsTwo() throws IOException {
testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain");
testCommons = new ComponentTestCommons(TaskType.PRIMARY_WORKER);
testCommons.setup();
RunJobParameters parameters = new RunJobParameters();
parameters.setWorkerResource(testCommons.resource);
parameters.setName("testJobName");
parameters.setNumWorkers(2);
parameters.setWorkerDockerImage("testWorkerDockerImage");
TensorFlowWorkerComponent workerComponent =
createWorkerComponent(parameters);
Component component = workerComponent.createComponent();
assertEquals(1L, (long) component.getNumberOfContainers());
verifyCommons(component, ImmutableMap.of(
CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true"));
}
}

View File

@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.utils;
import org.apache.hadoop.yarn.submarine.FileUtilitiesForTests;
import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import java.io.File;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
/**
* This class is to test {@link ClassPathUtilities}.
*/
public class TestClassPathUtilities {
private static final String CLASSPATH_KEY = "java.class.path";
private FileUtilitiesForTests fileUtils = new FileUtilitiesForTests();
private static String originalClasspath;
@BeforeClass
public static void setUpClass() {
originalClasspath = System.getProperty(CLASSPATH_KEY);
}
@Before
public void setUp() {
fileUtils.setup();
}
@After
public void teardown() throws IOException {
fileUtils.teardown();
System.setProperty(CLASSPATH_KEY, originalClasspath);
}
private static void addFileToClasspath(File file) {
String newClasspath = originalClasspath + ":" + file.getAbsolutePath();
System.setProperty(CLASSPATH_KEY, newClasspath);
}
@Test
public void findFileNotInClasspath() {
File resultFile = ClassPathUtilities.findFileOnClassPath("bla");
assertNull(resultFile);
}
@Test
public void findFileOnClasspath() throws Exception {
File testFile = fileUtils.createFileInTempDir("testFile");
addFileToClasspath(testFile);
File resultFile = ClassPathUtilities.findFileOnClassPath("testFile");
assertNotNull(resultFile);
assertEquals(testFile.getAbsolutePath(), resultFile.getAbsolutePath());
}
@Test
public void findDirectoryOnClasspath() throws Exception {
File testDir = fileUtils.createDirInTempDir("testDir");
File testFile = fileUtils.createFileInDir(testDir, "testFile");
addFileToClasspath(testDir);
File resultFile = ClassPathUtilities.findFileOnClassPath("testFile");
assertNotNull(resultFile);
assertEquals(testFile.getAbsolutePath(), resultFile.getAbsolutePath());
}
}

View File

@ -0,0 +1,231 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import org.apache.hadoop.yarn.service.api.records.Configuration;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.junit.Test;
import java.util.Map;
import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION;
import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME;
import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* This class is to test {@link EnvironmentUtilities}.
*/
public class TestEnvironmentUtilities {
private Service createServiceWithEmptyEnvVars() {
return createServiceWithEnvVars(Maps.newHashMap());
}
private Service createServiceWithEnvVars(Map<String, String> envVars) {
Service service = mock(Service.class);
Configuration config = mock(Configuration.class);
when(config.getEnv()).thenReturn(envVars);
when(service.getConfiguration()).thenReturn(config);
return service;
}
private void validateDefaultEnvVars(Map<String, String> resultEnvs) {
assertEquals("/etc/passwd:/etc/passwd:ro",
resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
}
private org.apache.hadoop.conf.Configuration
createYarnConfigWithSecurityValue(String value) {
org.apache.hadoop.conf.Configuration mockConfig =
mock(org.apache.hadoop.conf.Configuration.class);
when(mockConfig.get(HADOOP_SECURITY_AUTHENTICATION)).thenReturn(value);
return mockConfig;
}
@Test
public void testGetValueOfNullEnvVar() {
assertEquals("", EnvironmentUtilities.getValueOfEnvironment(null));
}
@Test
public void testGetValueOfEmptyEnvVar() {
assertEquals("", EnvironmentUtilities.getValueOfEnvironment(""));
}
@Test
public void testGetValueOfEnvVarJustAnEqualsSign() {
assertEquals("", EnvironmentUtilities.getValueOfEnvironment("="));
}
@Test
public void testGetValueOfEnvVarWithoutValue() {
assertEquals("", EnvironmentUtilities.getValueOfEnvironment("a="));
}
@Test
public void testGetValueOfEnvVarValidFormat() {
assertEquals("bbb", EnvironmentUtilities.getValueOfEnvironment("a=bbb"));
}
@Test
public void testHandleServiceEnvWithNullMap() {
Service service = createServiceWithEmptyEnvVars();
org.apache.hadoop.conf.Configuration yarnConfig =
mock(org.apache.hadoop.conf.Configuration.class);
EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, null);
Map<String, String> resultEnvs = service.getConfiguration().getEnv();
assertEquals(1, resultEnvs.size());
validateDefaultEnvVars(resultEnvs);
}
@Test
public void testHandleServiceEnvWithEmptyMap() {
Service service = createServiceWithEmptyEnvVars();
org.apache.hadoop.conf.Configuration yarnConfig =
mock(org.apache.hadoop.conf.Configuration.class);
EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, null);
Map<String, String> resultEnvs = service.getConfiguration().getEnv();
assertEquals(1, resultEnvs.size());
validateDefaultEnvVars(resultEnvs);
}
@Test
public void testHandleServiceEnvWithYarnConfigSecurityValueNonKerberos() {
Service service = createServiceWithEmptyEnvVars();
org.apache.hadoop.conf.Configuration yarnConfig =
createYarnConfigWithSecurityValue("nonkerberos");
EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, null);
Map<String, String> resultEnvs = service.getConfiguration().getEnv();
assertEquals(1, resultEnvs.size());
validateDefaultEnvVars(resultEnvs);
}
@Test
public void testHandleServiceEnvWithYarnConfigSecurityValueKerberos() {
Service service = createServiceWithEmptyEnvVars();
org.apache.hadoop.conf.Configuration yarnConfig =
createYarnConfigWithSecurityValue("kerberos");
EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, null);
Map<String, String> resultEnvs = service.getConfiguration().getEnv();
assertEquals(1, resultEnvs.size());
assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
}
@Test
public void testHandleServiceEnvWithExistingEnvsAndValidNewEnvs() {
Map<String, String> existingEnvs = Maps.newHashMap(
ImmutableMap.<String, String>builder().
put("a", "1").
put("b", "2").
build());
ImmutableList<String> newEnvs = ImmutableList.of("c=3", "d=4");
Service service = createServiceWithEnvVars(existingEnvs);
org.apache.hadoop.conf.Configuration yarnConfig =
createYarnConfigWithSecurityValue("kerberos");
EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, newEnvs);
Map<String, String> resultEnvs = service.getConfiguration().getEnv();
assertEquals(5, resultEnvs.size());
assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
assertEquals("1", resultEnvs.get("a"));
assertEquals("2", resultEnvs.get("b"));
assertEquals("3", resultEnvs.get("c"));
assertEquals("4", resultEnvs.get("d"));
}
@Test
public void testHandleServiceEnvWithExistingEnvsAndNewEnvsWithoutEquals() {
Map<String, String> existingEnvs = Maps.newHashMap(
ImmutableMap.<String, String>builder().
put("a", "1").
put("b", "2").
build());
ImmutableList<String> newEnvs = ImmutableList.of("c3", "d4");
Service service = createServiceWithEnvVars(existingEnvs);
org.apache.hadoop.conf.Configuration yarnConfig =
createYarnConfigWithSecurityValue("kerberos");
EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, newEnvs);
Map<String, String> resultEnvs = service.getConfiguration().getEnv();
assertEquals(5, resultEnvs.size());
assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
assertEquals("1", resultEnvs.get("a"));
assertEquals("2", resultEnvs.get("b"));
assertEquals("", resultEnvs.get("c3"));
assertEquals("", resultEnvs.get("d4"));
}
@Test
public void testHandleServiceEnvWithExistingEnvVarKey() {
Map<String, String> existingEnvs = Maps.newHashMap(
ImmutableMap.<String, String>builder().
put("a", "1").
put("b", "2").
build());
ImmutableList<String> newEnvs = ImmutableList.of("a=33", "c=44");
Service service = createServiceWithEnvVars(existingEnvs);
org.apache.hadoop.conf.Configuration yarnConfig =
createYarnConfigWithSecurityValue("kerberos");
EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, newEnvs);
Map<String, String> resultEnvs = service.getConfiguration().getEnv();
assertEquals(4, resultEnvs.size());
assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
assertEquals("1:33", resultEnvs.get("a"));
assertEquals("2", resultEnvs.get("b"));
assertEquals("44", resultEnvs.get("c"));
}
@Test
public void testHandleServiceEnvWithExistingEnvVarKeyMultipleTimes() {
Map<String, String> existingEnvs = Maps.newHashMap(
ImmutableMap.<String, String>builder().
put("a", "1").
put("b", "2").
build());
ImmutableList<String> newEnvs = ImmutableList.of("a=33", "a=44");
Service service = createServiceWithEnvVars(existingEnvs);
org.apache.hadoop.conf.Configuration yarnConfig =
createYarnConfigWithSecurityValue("kerberos");
EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, newEnvs);
Map<String, String> resultEnvs = service.getConfiguration().getEnv();
assertEquals(3, resultEnvs.size());
assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
assertEquals("1:33:44", resultEnvs.get("a"));
assertEquals("2", resultEnvs.get("b"));
}
}

View File

@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.utils;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
import org.apache.hadoop.yarn.submarine.FileUtilitiesForTests;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.io.IOException;
import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* This class is to test {@link KerberosPrincipalFactory}.
*/
public class TestKerberosPrincipalFactory {
private FileUtilitiesForTests fileUtils = new FileUtilitiesForTests();
@Before
public void setUp() {
fileUtils.setup();
}
@After
public void teardown() throws IOException {
fileUtils.teardown();
}
private File createKeytabFile(String keytabFileName) throws IOException {
return fileUtils.createFileInTempDir(keytabFileName);
}
@Test
public void testCreatePrincipalEmptyPrincipalAndKeytab() throws IOException {
MockClientContext mockClientContext = new MockClientContext();
RunJobParameters parameters = mock(RunJobParameters.class);
when(parameters.getPrincipal()).thenReturn("");
when(parameters.getKeytab()).thenReturn("");
FileSystemOperations fsOperations =
new FileSystemOperations(mockClientContext);
KerberosPrincipal result =
KerberosPrincipalFactory.create(fsOperations,
mockClientContext.getRemoteDirectoryManager(), parameters);
assertNull(result);
}
@Test
public void testCreatePrincipalEmptyPrincipalString() throws IOException {
MockClientContext mockClientContext = new MockClientContext();
RunJobParameters parameters = mock(RunJobParameters.class);
when(parameters.getPrincipal()).thenReturn("");
when(parameters.getKeytab()).thenReturn("keytab");
FileSystemOperations fsOperations =
new FileSystemOperations(mockClientContext);
KerberosPrincipal result =
KerberosPrincipalFactory.create(fsOperations,
mockClientContext.getRemoteDirectoryManager(), parameters);
assertNull(result);
}
@Test
public void testCreatePrincipalEmptyKeyTabString() throws IOException {
MockClientContext mockClientContext = new MockClientContext();
RunJobParameters parameters = mock(RunJobParameters.class);
when(parameters.getPrincipal()).thenReturn("principal");
when(parameters.getKeytab()).thenReturn("");
FileSystemOperations fsOperations =
new FileSystemOperations(mockClientContext);
KerberosPrincipal result =
KerberosPrincipalFactory.create(fsOperations,
mockClientContext.getRemoteDirectoryManager(), parameters);
assertNull(result);
}
@Test
public void testCreatePrincipalNonEmptyPrincipalAndKeytab()
throws IOException {
MockClientContext mockClientContext = new MockClientContext();
RunJobParameters parameters = mock(RunJobParameters.class);
when(parameters.getPrincipal()).thenReturn("principal");
when(parameters.getKeytab()).thenReturn("keytab");
FileSystemOperations fsOperations =
new FileSystemOperations(mockClientContext);
KerberosPrincipal result =
KerberosPrincipalFactory.create(fsOperations,
mockClientContext.getRemoteDirectoryManager(), parameters);
assertNotNull(result);
assertEquals("file://keytab", result.getKeytab());
assertEquals("principal", result.getPrincipalName());
}
@Test
public void testCreatePrincipalDistributedKeytab() throws IOException {
MockClientContext mockClientContext = new MockClientContext();
String jobname = "testJobname";
String keytab = "testKeytab";
File keytabFile = createKeytabFile(keytab);
RunJobParameters parameters = mock(RunJobParameters.class);
when(parameters.getPrincipal()).thenReturn("principal");
when(parameters.getKeytab()).thenReturn(keytabFile.getAbsolutePath());
when(parameters.getName()).thenReturn(jobname);
when(parameters.isDistributeKeytab()).thenReturn(true);
FileSystemOperations fsOperations =
new FileSystemOperations(mockClientContext);
KerberosPrincipal result =
KerberosPrincipalFactory.create(fsOperations,
mockClientContext.getRemoteDirectoryManager(), parameters);
Path stagingDir = mockClientContext.getRemoteDirectoryManager()
.getJobStagingArea(parameters.getName(), true);
String expectedKeytabFilePath =
FileUtilitiesForTests.getFilename(stagingDir, keytab).getAbsolutePath();
assertNotNull(result);
assertEquals(expectedKeytabFilePath, result.getKeytab());
assertEquals("principal", result.getPrincipalName());
}
}

View File

@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.utils;
import com.google.common.collect.ImmutableMap;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.service.api.records.ResourceInformation;
import org.apache.hadoop.yarn.util.resource.CustomResourceTypesConfigurationProvider;
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import org.junit.After;
import org.junit.Test;
import java.util.Map;
import static org.junit.Assert.*;
/**
* This class is to test {@link SubmarineResourceUtils}.
*/
public class TestSubmarineResourceUtils {
private static final String CUSTOM_RESOURCE_NAME = "a-custom-resource";
private void initResourceTypes() {
CustomResourceTypesConfigurationProvider.initResourceTypes(
ImmutableMap.<String, String>builder()
.put(CUSTOM_RESOURCE_NAME, "G")
.build());
}
@After
public void cleanup() {
ResourceUtils.resetResourceTypes(new Configuration());
}
@Test
public void testConvertResourceWithCustomResource() {
initResourceTypes();
Resource res = Resource.newInstance(4096, 12,
ImmutableMap.of(CUSTOM_RESOURCE_NAME, 20L));
org.apache.hadoop.yarn.service.api.records.Resource serviceResource =
SubmarineResourceUtils.convertYarnResourceToServiceResource(res);
assertEquals(12, serviceResource.getCpus().intValue());
assertEquals(4096, (int) Integer.valueOf(serviceResource.getMemory()));
Map<String, ResourceInformation> additionalResources =
serviceResource.getAdditional();
// Additional resources also includes vcores and memory
assertEquals(3, additionalResources.size());
ResourceInformation customResourceRI =
additionalResources.get(CUSTOM_RESOURCE_NAME);
assertEquals("G", customResourceRI.getUnit());
assertEquals(20L, (long) customResourceRI.getValue());
}
}