From 565cdb8edf2b1ee8c60019a95940c2987c02e710 Mon Sep 17 00:00:00 2001 From: Greg Price Date: Fri, 8 May 2015 12:40:05 -0400 Subject: [PATCH] Move Discussion API access control checks The checks are now within the Python API instead of the DRF view. This will be necessary for certain operations (like fetching/editing threads) because the relevant course cannot be known until the thread is fetched from the comments service. This commit updates the existing endpoints to fit that pattern. --- lms/djangoapps/discussion_api/api.py | 24 +++- .../discussion_api/tests/test_api.py | 118 ++++++++++++------ .../discussion_api/tests/test_views.py | 33 +---- lms/djangoapps/discussion_api/views.py | 20 +-- 4 files changed, 108 insertions(+), 87 deletions(-) diff --git a/lms/djangoapps/discussion_api/api.py b/lms/djangoapps/discussion_api/api.py index 8ce8405f5e..09fad100d9 100644 --- a/lms/djangoapps/discussion_api/api.py +++ b/lms/djangoapps/discussion_api/api.py @@ -5,6 +5,7 @@ from django.http import Http404 from collections import defaultdict +from courseware.courses import get_course_with_access from discussion_api.pagination import get_paginated_data from django_comment_client.utils import get_accessible_discussion_modules from django_comment_common.models import ( @@ -16,15 +17,28 @@ from django_comment_common.models import ( from lms.lib.comment_client.thread import Thread from lms.lib.comment_client.user import User from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id, get_cohort_names +from xmodule.tabs import DiscussionTab -def get_course_topics(course, user): +def _get_course_or_404(course_key, user): + """ + Get the course descriptor, raising Http404 if the course is not found, + the user cannot access forums for the course, or the discussion tab is + disabled for the course. + """ + course = get_course_with_access(user, 'load_forum', course_key) + if not any([isinstance(tab, DiscussionTab) for tab in course.tabs]): + raise Http404 + return course + + +def get_course_topics(course_key, user): """ Return the course topic listing for the given course and user. Parameters: - course: The course to get topics for + course_key: The key of the course to get topics for user: The requesting user, for access control Returns: @@ -39,6 +53,7 @@ def get_course_topics(course, user): """ return module.sort_key or module.discussion_target + course = _get_course_or_404(course_key, user) discussion_modules = get_accessible_discussion_modules(course, user) modules_by_category = defaultdict(list) for module in discussion_modules: @@ -138,14 +153,14 @@ def _cc_thread_to_api_thread(thread, cc_user, staff_user_ids, ta_user_ids, group return ret -def get_thread_list(request, course, page, page_size): +def get_thread_list(request, course_key, page, page_size): """ Return the list of all discussion threads pertaining to the given course Parameters: request: The django request objects used for build_absolute_uri - course: The course to get discussion threads for + course_key: The key of the course to get discussion threads for page: The page number (1-indexed) to retrieve page_size: The number of threads to retrieve per page @@ -154,6 +169,7 @@ def get_thread_list(request, course, page, page_size): A paginated result containing a list of threads; see discussion_api.views.ThreadViewSet for more detail. """ + course = _get_course_or_404(course_key, request.user) user_is_privileged = Role.objects.filter( course_id=course.id, name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR, FORUM_ROLE_COMMUNITY_TA], diff --git a/lms/djangoapps/discussion_api/tests/test_api.py b/lms/djangoapps/discussion_api/tests/test_api.py index c92766c6e2..ccd164c266 100644 --- a/lms/djangoapps/discussion_api/tests/test_api.py +++ b/lms/djangoapps/discussion_api/tests/test_api.py @@ -12,6 +12,8 @@ from pytz import UTC from django.http import Http404 from django.test.client import RequestFactory +from opaque_keys.edx.locator import CourseLocator + from courseware.tests.factories import BetaTesterFactory, StaffFactory from discussion_api.api import get_course_topics, get_thread_list from discussion_api.tests.utils import CommentsServiceMockMixin @@ -24,10 +26,22 @@ from django_comment_common.models import ( ) from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory -from student.tests.factories import UserFactory +from student.tests.factories import CourseEnrollmentFactory, UserFactory +from xmodule.modulestore.django import modulestore from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.factories import CourseFactory, ItemFactory from xmodule.partitions.partitions import Group, UserPartition +from xmodule.tabs import DiscussionTab + + +def _remove_discussion_tab(course, user_id): + """ + Remove the discussion tab for the course. + + user_id is passed to the modulestore as the editor of the module. + """ + course.tabs = [tab for tab in course.tabs if not isinstance(tab, DiscussionTab)] + modulestore().update_item(course, user_id) @mock.patch.dict("django.conf.settings.FEATURES", {"DISABLE_START_DATES": False}) @@ -49,12 +63,13 @@ class GetCourseTopicsTest(ModuleStoreTestCase): course="y", run="z", start=datetime.now(UTC), - discussion_topics={}, + discussion_topics={"Test Topic": {"id": "non-courseware-topic-id"}}, user_partitions=[self.partition], cohort_config={"cohorted": True}, days_early_for_beta=3 ) self.user = UserFactory.create() + CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) def make_discussion_module(self, topic_id, category, subcategory, **kwargs): """Build a discussion module in self.course""" @@ -72,7 +87,7 @@ class GetCourseTopicsTest(ModuleStoreTestCase): Get course topics for self.course, using the given user or self.user if not provided, and generating absolute URIs with a test scheme/host. """ - return get_course_topics(self.course, user or self.user) + return get_course_topics(self.course.id, user or self.user) def make_expected_tree(self, topic_id, name, children=None): """ @@ -87,50 +102,58 @@ class GetCourseTopicsTest(ModuleStoreTestCase): } return node - def test_empty(self): + def test_nonexistent_course(self): + with self.assertRaises(Http404): + get_course_topics(CourseLocator.from_string("non/existent/course"), self.user) + + def test_not_enrolled(self): + unenrolled_user = UserFactory.create() + with self.assertRaises(Http404): + get_course_topics(self.course.id, unenrolled_user) + + def test_discussions_disabled(self): + _remove_discussion_tab(self.course, self.user.id) + with self.assertRaises(Http404): + self.get_course_topics() + + def test_without_courseware(self): actual = self.get_course_topics() expected = { "courseware_topics": [], - "non_courseware_topics": [], + "non_courseware_topics": [ + self.make_expected_tree("non-courseware-topic-id", "Test Topic") + ], } self.assertEqual(actual, expected) - def test_non_courseware(self): - self.course.discussion_topics = {"Topic Name": {"id": "topic-id"}} - self.course.save() - actual = self.get_course_topics() - expected = { - "courseware_topics": [], - "non_courseware_topics": [self.make_expected_tree("topic-id", "Topic Name")], - } - self.assertEqual(actual, expected) - - def test_courseware(self): - self.make_discussion_module("topic-id", "Foo", "Bar") + def test_with_courseware(self): + self.make_discussion_module("courseware-topic-id", "Foo", "Bar") actual = self.get_course_topics() expected = { "courseware_topics": [ self.make_expected_tree( None, "Foo", - [self.make_expected_tree("topic-id", "Bar")] + [self.make_expected_tree("courseware-topic-id", "Bar")] ), ], - "non_courseware_topics": [], + "non_courseware_topics": [ + self.make_expected_tree("non-courseware-topic-id", "Test Topic") + ], } self.assertEqual(actual, expected) def test_many(self): + self.course.discussion_topics = { + "A": {"id": "non-courseware-1"}, + "B": {"id": "non-courseware-2"}, + } + modulestore().update_item(self.course, self.user.id) self.make_discussion_module("courseware-1", "A", "1") self.make_discussion_module("courseware-2", "A", "2") self.make_discussion_module("courseware-3", "B", "1") self.make_discussion_module("courseware-4", "B", "2") self.make_discussion_module("courseware-5", "C", "1") - self.course.discussion_topics = { - "A": {"id": "non-courseware-1"}, - "B": {"id": "non-courseware-2"}, - } - self.course.save() actual = self.get_course_topics() expected = { "courseware_topics": [ @@ -164,6 +187,13 @@ class GetCourseTopicsTest(ModuleStoreTestCase): self.assertEqual(actual, expected) def test_sort_key(self): + self.course.discussion_topics = { + "W": {"id": "non-courseware-1", "sort_key": "Z"}, + "X": {"id": "non-courseware-2"}, + "Y": {"id": "non-courseware-3", "sort_key": "Y"}, + "Z": {"id": "non-courseware-4", "sort_key": "W"}, + } + modulestore().update_item(self.course, self.user.id) self.make_discussion_module("courseware-1", "First", "A", sort_key="D") self.make_discussion_module("courseware-2", "First", "B", sort_key="B") self.make_discussion_module("courseware-3", "First", "C", sort_key="E") @@ -171,13 +201,6 @@ class GetCourseTopicsTest(ModuleStoreTestCase): self.make_discussion_module("courseware-5", "Second", "B", sort_key="G") self.make_discussion_module("courseware-6", "Second", "C") self.make_discussion_module("courseware-7", "Second", "D", sort_key="A") - self.course.discussion_topics = { - "W": {"id": "non-courseware-1", "sort_key": "Z"}, - "X": {"id": "non-courseware-2"}, - "Y": {"id": "non-courseware-3", "sort_key": "Y"}, - "Z": {"id": "non-courseware-4", "sort_key": "W"}, - } - self.course.save() actual = self.get_course_topics() expected = { "courseware_topics": [ @@ -223,6 +246,7 @@ class GetCourseTopicsTest(ModuleStoreTestCase): subcategories does not appear in the result. """ beta_tester = BetaTesterFactory.create(course_key=self.course.id) + CourseEnrollmentFactory.create(user=beta_tester, course_id=self.course.id) staff = StaffFactory.create(course_key=self.course.id) for user, group_idx in [(self.user, 0), (beta_tester, 1)]: cohort = CohortFactory.create( @@ -269,7 +293,9 @@ class GetCourseTopicsTest(ModuleStoreTestCase): ] ), ], - "non_courseware_topics": [], + "non_courseware_topics": [ + self.make_expected_tree("non-courseware-topic-id", "Test Topic"), + ], } self.assertEqual(student_actual, student_expected) @@ -290,7 +316,9 @@ class GetCourseTopicsTest(ModuleStoreTestCase): [self.make_expected_tree("courseware-5", "Future Start Date")] ), ], - "non_courseware_topics": [], + "non_courseware_topics": [ + self.make_expected_tree("non-courseware-topic-id", "Test Topic"), + ], } self.assertEqual(beta_actual, beta_expected) @@ -315,7 +343,9 @@ class GetCourseTopicsTest(ModuleStoreTestCase): ] ), ], - "non_courseware_topics": [], + "non_courseware_topics": [ + self.make_expected_tree("non-courseware-topic-id", "Test Topic"), + ], } self.assertEqual(staff_actual, staff_expected) @@ -334,6 +364,7 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): self.request = RequestFactory().get("/test_path") self.request.user = self.user self.course = CourseFactory.create() + CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) self.author = UserFactory.create() self.cohort = CohortFactory.create(course_id=self.course.id) @@ -344,7 +375,7 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): """ course = course or self.course self.register_get_threads_response(threads, page, num_pages) - ret = get_thread_list(self.request, course, page, page_size) + ret = get_thread_list(self.request, course.id, page, page_size) return ret def create_role(self, role_name, users): @@ -382,6 +413,20 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ret.update(thread_data) return ret + def test_nonexistent_course(self): + with self.assertRaises(Http404): + get_thread_list(self.request, CourseLocator.from_string("non/existent/course"), 1, 1) + + def test_not_enrolled(self): + self.request.user = UserFactory.create() + with self.assertRaises(Http404): + self.get_thread_list([]) + + def test_discussions_disabled(self): + _remove_discussion_tab(self.course, self.user.id) + with self.assertRaises(Http404): + self.get_thread_list([]) + def test_empty(self): self.assertEqual( self.get_thread_list([]), @@ -565,6 +610,7 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): @ddt.unpack def test_request_group(self, role_name, course_is_cohorted): cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted}) + CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id) CohortFactory.create(course_id=cohort_course.id, users=[self.user]) role = Role.objects.create(name=role_name, course_id=cohort_course.id) role.users = [self.user] @@ -603,7 +649,7 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): # Test page past the last one self.register_get_threads_response([], page=3, num_pages=3) with self.assertRaises(Http404): - get_thread_list(self.request, self.course, page=4, page_size=10) + get_thread_list(self.request, self.course.id, page=4, page_size=10) @ddt.data( (FORUM_ROLE_ADMINISTRATOR, True, False, True), diff --git a/lms/djangoapps/discussion_api/tests/test_views.py b/lms/djangoapps/discussion_api/tests/test_views.py index d6eee0d221..a9c4e1a99d 100644 --- a/lms/djangoapps/discussion_api/tests/test_views.py +++ b/lms/djangoapps/discussion_api/tests/test_views.py @@ -42,11 +42,6 @@ class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin): CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) self.client.login(username=self.user.username, password=self.password) - def login_unenrolled_user(self): - """Create a user not enrolled in the course and log it in""" - unenrolled_user = UserFactory.create(password=self.password) - self.client.login(username=unenrolled_user.username, password=self.password) - def assert_response_correct(self, response, expected_status, expected_content): """ Assert that the response has the given status code and parsed content @@ -71,7 +66,7 @@ class CourseTopicsViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): super(CourseTopicsViewTest, self).setUp() self.url = reverse("course_topics", kwargs={"course_id": unicode(self.course.id)}) - def test_non_existent_course(self): + def test_404(self): response = self.client.get( reverse("course_topics", kwargs={"course_id": "non/existent/course"}) ) @@ -81,26 +76,7 @@ class CourseTopicsViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): {"developer_message": "Not found."} ) - def test_not_enrolled(self): - self.login_unenrolled_user() - response = self.client.get(self.url) - self.assert_response_correct( - response, - 404, - {"developer_message": "Not found."} - ) - - def test_discussions_disabled(self): - self.course.tabs = [tab for tab in self.course.tabs if not isinstance(tab, DiscussionTab)] - modulestore().update_item(self.course, self.user.id) - response = self.client.get(self.url) - self.assert_response_correct( - response, - 404, - {"developer_message": "Not found."} - ) - - def test_get(self): + def test_get_success(self): response = self.client.get(self.url) self.assert_response_correct( response, @@ -132,9 +108,8 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): {"field_errors": {"course_id": "This field is required."}} ) - def test_not_enrolled(self): - self.login_unenrolled_user() - response = self.client.get(self.url, {"course_id": unicode(self.course.id)}) + def test_404(self): + response = self.client.get(self.url, {"course_id": unicode("non/existent/course")}) self.assert_response_correct( response, 404, diff --git a/lms/djangoapps/discussion_api/views.py b/lms/djangoapps/discussion_api/views.py index b17b7e0caa..cc42f5b4ec 100644 --- a/lms/djangoapps/discussion_api/views.py +++ b/lms/djangoapps/discussion_api/views.py @@ -2,7 +2,6 @@ Discussion API views """ from django.core.exceptions import ValidationError -from django.http import Http404 from rest_framework.authentication import OAuth2Authentication, SessionAuthentication from rest_framework.permissions import IsAuthenticated @@ -12,11 +11,9 @@ from rest_framework.viewsets import ViewSet from opaque_keys.edx.locator import CourseLocator -from courseware.courses import get_course_with_access from discussion_api.api import get_course_topics, get_thread_list from discussion_api.forms import ThreadListGetForm from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin -from xmodule.tabs import DiscussionTab class _ViewMixin(object): @@ -27,17 +24,6 @@ class _ViewMixin(object): authentication_classes = (OAuth2Authentication, SessionAuthentication) permission_classes = (IsAuthenticated,) - def get_course_or_404(self, user, course_key): - """ - Get the course descriptor, raising Http404 if the course is not found, - the user cannot access forums for the course, or the discussion tab is - disabled for the course. - """ - course = get_course_with_access(user, 'load_forum', course_key) - if not any([isinstance(tab, DiscussionTab) for tab in course.tabs]): - raise Http404 - return course - class CourseTopicsView(_ViewMixin, DeveloperErrorViewMixin, APIView): """ @@ -68,8 +54,7 @@ class CourseTopicsView(_ViewMixin, DeveloperErrorViewMixin, APIView): def get(self, request, course_id): """Implements the GET method as described in the class docstring.""" course_key = CourseLocator.from_string(course_id) - course = self.get_course_or_404(request.user, course_key) - return Response(get_course_topics(course, request.user)) + return Response(get_course_topics(course_key, request.user)) class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): @@ -133,11 +118,10 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): form = ThreadListGetForm(request.GET) if not form.is_valid(): raise ValidationError(form.errors) - course = self.get_course_or_404(request.user, form.cleaned_data["course_id"]) return Response( get_thread_list( request, - course, + form.cleaned_data["course_id"], form.cleaned_data["page"], form.cleaned_data["page_size"] )