diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-hs/src/main/java/org/apache/hadoop/mapreduce/v2/hs/webapp/HsWebServices.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-hs/src/main/java/org/apache/hadoop/mapreduce/v2/hs/webapp/HsWebServices.java index b5ebdb3b9b..c95ce3e557 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-hs/src/main/java/org/apache/hadoop/mapreduce/v2/hs/webapp/HsWebServices.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-hs/src/main/java/org/apache/hadoop/mapreduce/v2/hs/webapp/HsWebServices.java @@ -20,17 +20,21 @@ import java.io.IOException; +import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.UriInfo; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.JobACL; import org.apache.hadoop.mapreduce.v2.api.records.AMInfo; import org.apache.hadoop.mapreduce.v2.api.records.JobState; import org.apache.hadoop.mapreduce.v2.api.records.TaskId; @@ -55,11 +59,13 @@ import org.apache.hadoop.mapreduce.v2.hs.webapp.dao.JobInfo; import org.apache.hadoop.mapreduce.v2.hs.webapp.dao.JobsInfo; import org.apache.hadoop.mapreduce.v2.util.MRApps; +import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.yarn.exceptions.YarnRuntimeException; import org.apache.hadoop.yarn.webapp.BadRequestException; import org.apache.hadoop.yarn.webapp.NotFoundException; import org.apache.hadoop.yarn.webapp.WebApp; +import com.google.common.annotations.VisibleForTesting; import com.google.inject.Inject; @Path("/ws/v1/history") @@ -78,11 +84,31 @@ public HsWebServices(final HistoryContext ctx, final Configuration conf, this.webapp = webapp; } + private boolean hasAccess(Job job, HttpServletRequest request) { + String remoteUser = request.getRemoteUser(); + if (remoteUser != null) { + return job.checkAccess(UserGroupInformation.createRemoteUser(remoteUser), + JobACL.VIEW_JOB); + } + return true; + } + + private void checkAccess(Job job, HttpServletRequest request) { + if (!hasAccess(job, request)) { + throw new WebApplicationException(Status.UNAUTHORIZED); + } + } + private void init() { //clear content type response.setContentType(null); } + @VisibleForTesting + void setResponse(HttpServletResponse response) { + this.response = response; + } + @GET @Produces({ MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML }) public HistoryInfo get() { @@ -190,10 +216,12 @@ public JobsInfo getJobs(@QueryParam("user") String userQuery, @GET @Path("/mapreduce/jobs/{jobid}") @Produces({ MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML }) - public JobInfo getJob(@PathParam("jobid") String jid) { + public JobInfo getJob(@Context HttpServletRequest hsr, + @PathParam("jobid") String jid) { init(); Job job = AMWebServices.getJobFromJobIdString(jid, ctx); + checkAccess(job, hsr); return new JobInfo(job); } @@ -217,20 +245,24 @@ public AMAttemptsInfo getJobAttempts(@PathParam("jobid") String jid) { @GET @Path("/mapreduce/jobs/{jobid}/counters") @Produces({ MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML }) - public JobCounterInfo getJobCounters(@PathParam("jobid") String jid) { + public JobCounterInfo getJobCounters(@Context HttpServletRequest hsr, + @PathParam("jobid") String jid) { init(); Job job = AMWebServices.getJobFromJobIdString(jid, ctx); + checkAccess(job, hsr); return new JobCounterInfo(this.ctx, job); } @GET @Path("/mapreduce/jobs/{jobid}/conf") @Produces({ MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML }) - public ConfInfo getJobConf(@PathParam("jobid") String jid) { + public ConfInfo getJobConf(@Context HttpServletRequest hsr, + @PathParam("jobid") String jid) { init(); Job job = AMWebServices.getJobFromJobIdString(jid, ctx); + checkAccess(job, hsr); ConfInfo info; try { info = new ConfInfo(job); @@ -244,11 +276,12 @@ public ConfInfo getJobConf(@PathParam("jobid") String jid) { @GET @Path("/mapreduce/jobs/{jobid}/tasks") @Produces({ MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML }) - public TasksInfo getJobTasks(@PathParam("jobid") String jid, - @QueryParam("type") String type) { + public TasksInfo getJobTasks(@Context HttpServletRequest hsr, + @PathParam("jobid") String jid, @QueryParam("type") String type) { init(); Job job = AMWebServices.getJobFromJobIdString(jid, ctx); + checkAccess(job, hsr); TasksInfo allTasks = new TasksInfo(); for (Task task : job.getTasks().values()) { TaskType ttype = null; @@ -270,11 +303,12 @@ public TasksInfo getJobTasks(@PathParam("jobid") String jid, @GET @Path("/mapreduce/jobs/{jobid}/tasks/{taskid}") @Produces({ MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML }) - public TaskInfo getJobTask(@PathParam("jobid") String jid, - @PathParam("taskid") String tid) { + public TaskInfo getJobTask(@Context HttpServletRequest hsr, + @PathParam("jobid") String jid, @PathParam("taskid") String tid) { init(); Job job = AMWebServices.getJobFromJobIdString(jid, ctx); + checkAccess(job, hsr); Task task = AMWebServices.getTaskFromTaskIdString(tid, job); return new TaskInfo(task); @@ -284,10 +318,12 @@ public TaskInfo getJobTask(@PathParam("jobid") String jid, @Path("/mapreduce/jobs/{jobid}/tasks/{taskid}/counters") @Produces({ MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML }) public JobTaskCounterInfo getSingleTaskCounters( - @PathParam("jobid") String jid, @PathParam("taskid") String tid) { + @Context HttpServletRequest hsr, @PathParam("jobid") String jid, + @PathParam("taskid") String tid) { init(); Job job = AMWebServices.getJobFromJobIdString(jid, ctx); + checkAccess(job, hsr); TaskId taskID = MRApps.toTaskID(tid); if (taskID == null) { throw new NotFoundException("taskid " + tid + " not found or invalid"); @@ -302,12 +338,13 @@ public JobTaskCounterInfo getSingleTaskCounters( @GET @Path("/mapreduce/jobs/{jobid}/tasks/{taskid}/attempts") @Produces({ MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML }) - public TaskAttemptsInfo getJobTaskAttempts(@PathParam("jobid") String jid, - @PathParam("taskid") String tid) { + public TaskAttemptsInfo getJobTaskAttempts(@Context HttpServletRequest hsr, + @PathParam("jobid") String jid, @PathParam("taskid") String tid) { init(); TaskAttemptsInfo attempts = new TaskAttemptsInfo(); Job job = AMWebServices.getJobFromJobIdString(jid, ctx); + checkAccess(job, hsr); Task task = AMWebServices.getTaskFromTaskIdString(tid, job); for (TaskAttempt ta : task.getAttempts().values()) { if (ta != null) { @@ -324,11 +361,13 @@ public TaskAttemptsInfo getJobTaskAttempts(@PathParam("jobid") String jid, @GET @Path("/mapreduce/jobs/{jobid}/tasks/{taskid}/attempts/{attemptid}") @Produces({ MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML }) - public TaskAttemptInfo getJobTaskAttemptId(@PathParam("jobid") String jid, - @PathParam("taskid") String tid, @PathParam("attemptid") String attId) { + public TaskAttemptInfo getJobTaskAttemptId(@Context HttpServletRequest hsr, + @PathParam("jobid") String jid, @PathParam("taskid") String tid, + @PathParam("attemptid") String attId) { init(); Job job = AMWebServices.getJobFromJobIdString(jid, ctx); + checkAccess(job, hsr); Task task = AMWebServices.getTaskFromTaskIdString(tid, job); TaskAttempt ta = AMWebServices.getTaskAttemptFromTaskAttemptString(attId, task); @@ -343,11 +382,12 @@ public TaskAttemptInfo getJobTaskAttemptId(@PathParam("jobid") String jid, @Path("/mapreduce/jobs/{jobid}/tasks/{taskid}/attempts/{attemptid}/counters") @Produces({ MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML }) public JobTaskAttemptCounterInfo getJobTaskAttemptIdCounters( - @PathParam("jobid") String jid, @PathParam("taskid") String tid, - @PathParam("attemptid") String attId) { + @Context HttpServletRequest hsr, @PathParam("jobid") String jid, + @PathParam("taskid") String tid, @PathParam("attemptid") String attId) { init(); Job job = AMWebServices.getJobFromJobIdString(jid, ctx); + checkAccess(job, hsr); Task task = AMWebServices.getTaskFromTaskIdString(tid, job); TaskAttempt ta = AMWebServices.getTaskAttemptFromTaskAttemptString(attId, task); diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-hs/src/test/java/org/apache/hadoop/mapreduce/v2/hs/webapp/TestHsWebServicesAcls.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-hs/src/test/java/org/apache/hadoop/mapreduce/v2/hs/webapp/TestHsWebServicesAcls.java new file mode 100644 index 0000000000..63f7fb0789 --- /dev/null +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-hs/src/test/java/org/apache/hadoop/mapreduce/v2/hs/webapp/TestHsWebServicesAcls.java @@ -0,0 +1,419 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.mapreduce.v2.hs.webapp; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response.Status; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.CommonConfigurationKeys; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapred.JobACLsManager; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.TaskCompletionEvent; +import org.apache.hadoop.mapreduce.Counters; +import org.apache.hadoop.mapreduce.JobACL; +import org.apache.hadoop.mapreduce.MRConfig; +import org.apache.hadoop.mapreduce.v2.api.records.AMInfo; +import org.apache.hadoop.mapreduce.v2.api.records.JobId; +import org.apache.hadoop.mapreduce.v2.api.records.JobReport; +import org.apache.hadoop.mapreduce.v2.api.records.JobState; +import org.apache.hadoop.mapreduce.v2.api.records.TaskAttemptCompletionEvent; +import org.apache.hadoop.mapreduce.v2.api.records.TaskId; +import org.apache.hadoop.mapreduce.v2.api.records.TaskType; +import org.apache.hadoop.mapreduce.v2.app.job.Job; +import org.apache.hadoop.mapreduce.v2.app.job.Task; +import org.apache.hadoop.mapreduce.v2.hs.HistoryContext; +import org.apache.hadoop.mapreduce.v2.hs.MockHistoryContext; +import org.apache.hadoop.security.GroupMappingServiceProvider; +import org.apache.hadoop.security.Groups; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.authorize.AccessControlList; +import org.apache.hadoop.yarn.webapp.WebApp; +import org.junit.Before; +import org.junit.Test; + +public class TestHsWebServicesAcls { + private static String FRIENDLY_USER = "friendly"; + private static String ENEMY_USER = "enemy"; + + private JobConf conf; + private HistoryContext ctx; + private String jobIdStr; + private String taskIdStr; + private String taskAttemptIdStr; + private HsWebServices hsWebServices; + + @Before + public void setup() throws IOException { + this.conf = new JobConf(); + this.conf.set(CommonConfigurationKeys.HADOOP_SECURITY_GROUP_MAPPING, + NullGroupsProvider.class.getName()); + this.conf.setBoolean(MRConfig.MR_ACLS_ENABLED, true); + Groups.getUserToGroupsMappingService(conf); + this.ctx = buildHistoryContext(this.conf); + WebApp webApp = mock(HsWebApp.class); + when(webApp.name()).thenReturn("hsmockwebapp"); + this.hsWebServices= new HsWebServices(ctx, conf, webApp); + this.hsWebServices.setResponse(mock(HttpServletResponse.class)); + + Job job = ctx.getAllJobs().values().iterator().next(); + this.jobIdStr = job.getID().toString(); + Task task = job.getTasks().values().iterator().next(); + this.taskIdStr = task.getID().toString(); + this.taskAttemptIdStr = + task.getAttempts().keySet().iterator().next().toString(); + } + + @Test + public void testGetJobAcls() { + HttpServletRequest hsr = mock(HttpServletRequest.class); + when(hsr.getRemoteUser()).thenReturn(ENEMY_USER); + + try { + hsWebServices.getJob(hsr, jobIdStr); + fail("enemy can access job"); + } catch (WebApplicationException e) { + assertEquals(Status.UNAUTHORIZED, + Status.fromStatusCode(e.getResponse().getStatus())); + } + + when(hsr.getRemoteUser()).thenReturn(FRIENDLY_USER); + hsWebServices.getJob(hsr, jobIdStr); + } + + @Test + public void testGetJobCountersAcls() { + HttpServletRequest hsr = mock(HttpServletRequest.class); + when(hsr.getRemoteUser()).thenReturn(ENEMY_USER); + + try { + hsWebServices.getJobCounters(hsr, jobIdStr); + fail("enemy can access job"); + } catch (WebApplicationException e) { + assertEquals(Status.UNAUTHORIZED, + Status.fromStatusCode(e.getResponse().getStatus())); + } + + when(hsr.getRemoteUser()).thenReturn(FRIENDLY_USER); + hsWebServices.getJobCounters(hsr, jobIdStr); + } + + @Test + public void testGetJobConfAcls() { + HttpServletRequest hsr = mock(HttpServletRequest.class); + when(hsr.getRemoteUser()).thenReturn(ENEMY_USER); + + try { + hsWebServices.getJobConf(hsr, jobIdStr); + fail("enemy can access job"); + } catch (WebApplicationException e) { + assertEquals(Status.UNAUTHORIZED, + Status.fromStatusCode(e.getResponse().getStatus())); + } + + when(hsr.getRemoteUser()).thenReturn(FRIENDLY_USER); + hsWebServices.getJobConf(hsr, jobIdStr); + } + + @Test + public void testGetJobTasksAcls() { + HttpServletRequest hsr = mock(HttpServletRequest.class); + when(hsr.getRemoteUser()).thenReturn(ENEMY_USER); + + try { + hsWebServices.getJobTasks(hsr, jobIdStr, "m"); + fail("enemy can access job"); + } catch (WebApplicationException e) { + assertEquals(Status.UNAUTHORIZED, + Status.fromStatusCode(e.getResponse().getStatus())); + } + + when(hsr.getRemoteUser()).thenReturn(FRIENDLY_USER); + hsWebServices.getJobTasks(hsr, jobIdStr, "m"); + } + + @Test + public void testGetJobTaskAcls() { + HttpServletRequest hsr = mock(HttpServletRequest.class); + when(hsr.getRemoteUser()).thenReturn(ENEMY_USER); + + try { + hsWebServices.getJobTask(hsr, jobIdStr, this.taskIdStr); + fail("enemy can access job"); + } catch (WebApplicationException e) { + assertEquals(Status.UNAUTHORIZED, + Status.fromStatusCode(e.getResponse().getStatus())); + } + + when(hsr.getRemoteUser()).thenReturn(FRIENDLY_USER); + hsWebServices.getJobTask(hsr, this.jobIdStr, this.taskIdStr); + } + + @Test + public void testGetSingleTaskCountersAcls() { + HttpServletRequest hsr = mock(HttpServletRequest.class); + when(hsr.getRemoteUser()).thenReturn(ENEMY_USER); + + try { + hsWebServices.getSingleTaskCounters(hsr, this.jobIdStr, this.taskIdStr); + fail("enemy can access job"); + } catch (WebApplicationException e) { + assertEquals(Status.UNAUTHORIZED, + Status.fromStatusCode(e.getResponse().getStatus())); + } + + when(hsr.getRemoteUser()).thenReturn(FRIENDLY_USER); + hsWebServices.getSingleTaskCounters(hsr, this.jobIdStr, this.taskIdStr); + } + + @Test + public void testGetJobTaskAttemptsAcls() { + HttpServletRequest hsr = mock(HttpServletRequest.class); + when(hsr.getRemoteUser()).thenReturn(ENEMY_USER); + + try { + hsWebServices.getJobTaskAttempts(hsr, this.jobIdStr, this.taskIdStr); + fail("enemy can access job"); + } catch (WebApplicationException e) { + assertEquals(Status.UNAUTHORIZED, + Status.fromStatusCode(e.getResponse().getStatus())); + } + + when(hsr.getRemoteUser()).thenReturn(FRIENDLY_USER); + hsWebServices.getJobTaskAttempts(hsr, this.jobIdStr, this.taskIdStr); + } + + @Test + public void testGetJobTaskAttemptIdAcls() { + HttpServletRequest hsr = mock(HttpServletRequest.class); + when(hsr.getRemoteUser()).thenReturn(ENEMY_USER); + + try { + hsWebServices.getJobTaskAttemptId(hsr, this.jobIdStr, this.taskIdStr, + this.taskAttemptIdStr); + fail("enemy can access job"); + } catch (WebApplicationException e) { + assertEquals(Status.UNAUTHORIZED, + Status.fromStatusCode(e.getResponse().getStatus())); + } + + when(hsr.getRemoteUser()).thenReturn(FRIENDLY_USER); + hsWebServices.getJobTaskAttemptId(hsr, this.jobIdStr, this.taskIdStr, + this.taskAttemptIdStr); + } + + @Test + public void testGetJobTaskAttemptIdCountersAcls() { + HttpServletRequest hsr = mock(HttpServletRequest.class); + when(hsr.getRemoteUser()).thenReturn(ENEMY_USER); + + try { + hsWebServices.getJobTaskAttemptIdCounters(hsr, this.jobIdStr, + this.taskIdStr, this.taskAttemptIdStr); + fail("enemy can access job"); + } catch (WebApplicationException e) { + assertEquals(Status.UNAUTHORIZED, + Status.fromStatusCode(e.getResponse().getStatus())); + } + + when(hsr.getRemoteUser()).thenReturn(FRIENDLY_USER); + hsWebServices.getJobTaskAttemptIdCounters(hsr, this.jobIdStr, + this.taskIdStr, this.taskAttemptIdStr); + } + + private static HistoryContext buildHistoryContext(final Configuration conf) + throws IOException { + HistoryContext ctx = new MockHistoryContext(1, 1, 1); + Map jobs = ctx.getAllJobs(); + JobId jobId = jobs.keySet().iterator().next(); + Job mockJob = new MockJobForAcls(jobs.get(jobId), conf); + jobs.put(jobId, mockJob); + return ctx; + } + + private static class NullGroupsProvider + implements GroupMappingServiceProvider { + @Override + public List getGroups(String user) throws IOException { + return Collections.emptyList(); + } + + @Override + public void cacheGroupsRefresh() throws IOException { + } + + @Override + public void cacheGroupsAdd(List groups) throws IOException { + } + } + + private static class MockJobForAcls implements Job { + private Job mockJob; + private Configuration conf; + private Map jobAcls; + private JobACLsManager aclsMgr; + + public MockJobForAcls(Job mockJob, Configuration conf) { + this.mockJob = mockJob; + this.conf = conf; + AccessControlList viewAcl = new AccessControlList(FRIENDLY_USER); + this.jobAcls = new HashMap(); + this.jobAcls.put(JobACL.VIEW_JOB, viewAcl); + this.aclsMgr = new JobACLsManager(conf); + } + + @Override + public JobId getID() { + return mockJob.getID(); + } + + @Override + public String getName() { + return mockJob.getName(); + } + + @Override + public JobState getState() { + return mockJob.getState(); + } + + @Override + public JobReport getReport() { + return mockJob.getReport(); + } + + @Override + public Counters getAllCounters() { + return mockJob.getAllCounters(); + } + + @Override + public Map getTasks() { + return mockJob.getTasks(); + } + + @Override + public Map getTasks(TaskType taskType) { + return mockJob.getTasks(taskType); + } + + @Override + public Task getTask(TaskId taskID) { + return mockJob.getTask(taskID); + } + + @Override + public List getDiagnostics() { + return mockJob.getDiagnostics(); + } + + @Override + public int getTotalMaps() { + return mockJob.getTotalMaps(); + } + + @Override + public int getTotalReduces() { + return mockJob.getTotalReduces(); + } + + @Override + public int getCompletedMaps() { + return mockJob.getCompletedMaps(); + } + + @Override + public int getCompletedReduces() { + return mockJob.getCompletedReduces(); + } + + @Override + public float getProgress() { + return mockJob.getProgress(); + } + + @Override + public boolean isUber() { + return mockJob.isUber(); + } + + @Override + public String getUserName() { + return mockJob.getUserName(); + } + + @Override + public String getQueueName() { + return mockJob.getQueueName(); + } + + @Override + public Path getConfFile() { + return new Path("/some/path/to/conf"); + } + + @Override + public Configuration loadConfFile() throws IOException { + return conf; + } + + @Override + public Map getJobACLs() { + return jobAcls; + } + + @Override + public TaskAttemptCompletionEvent[] getTaskAttemptCompletionEvents( + int fromEventId, int maxEvents) { + return mockJob.getTaskAttemptCompletionEvents(fromEventId, maxEvents); + } + + @Override + public TaskCompletionEvent[] getMapAttemptCompletionEvents( + int startIndex, int maxEvents) { + return mockJob.getMapAttemptCompletionEvents(startIndex, maxEvents); + } + + @Override + public List getAMInfos() { + return mockJob.getAMInfos(); + } + + @Override + public boolean checkAccess(UserGroupInformation callerUGI, + JobACL jobOperation) { + return aclsMgr.checkAccess(callerUGI, jobOperation, + this.getUserName(), jobAcls.get(jobOperation)); + } + } +}