diff --git a/openedx/core/djangoapps/course_groups/tests/test_views.py b/openedx/core/djangoapps/course_groups/tests/test_views.py index 1032b18d68..ef73700838 100644 --- a/openedx/core/djangoapps/course_groups/tests/test_views.py +++ b/openedx/core/djangoapps/course_groups/tests/test_views.py @@ -11,6 +11,7 @@ from django.http import Http404 from django.test.client import RequestFactory from django_comment_common.models import CourseDiscussionSettings from django_comment_common.utils import get_course_discussion_settings +from lms.djangoapps.courseware.tests.factories import StaffFactory, InstructorFactory from opaque_keys.edx.locator import CourseLocator from openedx.core.lib.tests import attr from student.models import CourseEnrollment @@ -93,13 +94,15 @@ class CohortViewsTestCase(ModuleStoreTestCase): view_args.insert(0, request) self.assertRaises(Http404, view, *view_args) - def get_handler(self, course, cohort=None, expected_response_code=200, handler=cohort_handler): + def get_handler(self, course, cohort=None, expected_response_code=200, handler=cohort_handler, user=None): """ Call a GET on `handler` for a given `course` and return its response as a dict. Raise an exception if response status code is not as expected. """ request = RequestFactory().get("dummy_url") - request.user = self.staff_user + if not user: + user = self.staff_user + request.user = user if cohort: response = handler(request, unicode(course.id), cohort.id) else: @@ -236,13 +239,24 @@ class CohortHandlerTestCase(CohortViewsTestCase): """ Tests the `cohort_handler` view. """ - def verify_lists_expected_cohorts(self, expected_cohorts, response_dict=None): + def setUp(self): + super(CohortHandlerTestCase, self).setUp() + self.course_staff_user = StaffFactory( + username="coursestaff", + course_key=self.course.id + ) + self.course_instructor_user = InstructorFactory( + username='courseinstructor', + course_key=self.course.id + ) + + def verify_lists_expected_cohorts(self, expected_cohorts, response_dict=None, user=None): """ Verify that the server response contains the expected_cohorts. If response_dict is None, the list of cohorts is requested from the server. """ if response_dict is None: - response_dict = self.get_handler(self.course) + response_dict = self.get_handler(self.course, user=user) self.assertEqual( response_dict.get("cohorts"), @@ -274,10 +288,16 @@ class CohortHandlerTestCase(CohortViewsTestCase): """ Verify that we cannot access cohort_handler if we're a non-staff user. """ - self._verify_non_staff_cannot_access(cohort_handler, "GET", [unicode(self.course.id)]) self._verify_non_staff_cannot_access(cohort_handler, "POST", [unicode(self.course.id)]) self._verify_non_staff_cannot_access(cohort_handler, "PUT", [unicode(self.course.id)]) + def test_course_writers(self): + """ + Verify course staff and course instructors can access cohort_handler view + """ + self.verify_lists_expected_cohorts([], user=self.course_staff_user) + self.verify_lists_expected_cohorts([], user=self.course_instructor_user) + def test_no_cohorts(self): """ Verify that no cohorts are in response for a course with no cohorts. diff --git a/openedx/core/djangoapps/course_groups/views.py b/openedx/core/djangoapps/course_groups/views.py index 307635b31c..6930b7143f 100644 --- a/openedx/core/djangoapps/course_groups/views.py +++ b/openedx/core/djangoapps/course_groups/views.py @@ -23,13 +23,13 @@ from rest_framework.generics import GenericAPIView from rest_framework.response import Response from six import text_type -from courseware.courses import get_course_with_access +from courseware.courses import get_course, get_course_with_access from edxmako.shortcuts import render_to_response from util.json_request import JsonResponse, expect_json - from openedx.core.djangoapps.course_groups.models import CohortMembership from openedx.core.lib.api.authentication import OAuth2AuthenticationAllowInactiveUser from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin +from student.auth import has_course_author_access from . import api, cohorts from .models import CourseUserGroup, CourseUserGroupPartitionGroup from .serializers import CohortUsersAPISerializer @@ -165,7 +165,11 @@ def cohort_handler(request, course_key_string, cohort_id=None): cohort. """ course_key = CourseKey.from_string(course_key_string) - course = get_course_with_access(request.user, 'staff', course_key) + if not has_course_author_access(request.user, course_key): + raise Http404('The requesting user does not have course author permissions.') + + course = get_course(course_key) + if request.method == 'GET': if not cohort_id: all_cohorts = [