From 565cdb8edf2b1ee8c60019a95940c2987c02e710 Mon Sep 17 00:00:00 2001 From: Greg Price Date: Fri, 8 May 2015 12:40:05 -0400 Subject: [PATCH 1/2] Move Discussion API access control checks The checks are now within the Python API instead of the DRF view. This will be necessary for certain operations (like fetching/editing threads) because the relevant course cannot be known until the thread is fetched from the comments service. This commit updates the existing endpoints to fit that pattern. --- lms/djangoapps/discussion_api/api.py | 24 +++- .../discussion_api/tests/test_api.py | 118 ++++++++++++------ .../discussion_api/tests/test_views.py | 33 +---- lms/djangoapps/discussion_api/views.py | 20 +-- 4 files changed, 108 insertions(+), 87 deletions(-) diff --git a/lms/djangoapps/discussion_api/api.py b/lms/djangoapps/discussion_api/api.py index 8ce8405f5e..09fad100d9 100644 --- a/lms/djangoapps/discussion_api/api.py +++ b/lms/djangoapps/discussion_api/api.py @@ -5,6 +5,7 @@ from django.http import Http404 from collections import defaultdict +from courseware.courses import get_course_with_access from discussion_api.pagination import get_paginated_data from django_comment_client.utils import get_accessible_discussion_modules from django_comment_common.models import ( @@ -16,15 +17,28 @@ from django_comment_common.models import ( from lms.lib.comment_client.thread import Thread from lms.lib.comment_client.user import User from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id, get_cohort_names +from xmodule.tabs import DiscussionTab -def get_course_topics(course, user): +def _get_course_or_404(course_key, user): + """ + Get the course descriptor, raising Http404 if the course is not found, + the user cannot access forums for the course, or the discussion tab is + disabled for the course. + """ + course = get_course_with_access(user, 'load_forum', course_key) + if not any([isinstance(tab, DiscussionTab) for tab in course.tabs]): + raise Http404 + return course + + +def get_course_topics(course_key, user): """ Return the course topic listing for the given course and user. Parameters: - course: The course to get topics for + course_key: The key of the course to get topics for user: The requesting user, for access control Returns: @@ -39,6 +53,7 @@ def get_course_topics(course, user): """ return module.sort_key or module.discussion_target + course = _get_course_or_404(course_key, user) discussion_modules = get_accessible_discussion_modules(course, user) modules_by_category = defaultdict(list) for module in discussion_modules: @@ -138,14 +153,14 @@ def _cc_thread_to_api_thread(thread, cc_user, staff_user_ids, ta_user_ids, group return ret -def get_thread_list(request, course, page, page_size): +def get_thread_list(request, course_key, page, page_size): """ Return the list of all discussion threads pertaining to the given course Parameters: request: The django request objects used for build_absolute_uri - course: The course to get discussion threads for + course_key: The key of the course to get discussion threads for page: The page number (1-indexed) to retrieve page_size: The number of threads to retrieve per page @@ -154,6 +169,7 @@ def get_thread_list(request, course, page, page_size): A paginated result containing a list of threads; see discussion_api.views.ThreadViewSet for more detail. """ + course = _get_course_or_404(course_key, request.user) user_is_privileged = Role.objects.filter( course_id=course.id, name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR, FORUM_ROLE_COMMUNITY_TA], diff --git a/lms/djangoapps/discussion_api/tests/test_api.py b/lms/djangoapps/discussion_api/tests/test_api.py index c92766c6e2..ccd164c266 100644 --- a/lms/djangoapps/discussion_api/tests/test_api.py +++ b/lms/djangoapps/discussion_api/tests/test_api.py @@ -12,6 +12,8 @@ from pytz import UTC from django.http import Http404 from django.test.client import RequestFactory +from opaque_keys.edx.locator import CourseLocator + from courseware.tests.factories import BetaTesterFactory, StaffFactory from discussion_api.api import get_course_topics, get_thread_list from discussion_api.tests.utils import CommentsServiceMockMixin @@ -24,10 +26,22 @@ from django_comment_common.models import ( ) from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory -from student.tests.factories import UserFactory +from student.tests.factories import CourseEnrollmentFactory, UserFactory +from xmodule.modulestore.django import modulestore from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.factories import CourseFactory, ItemFactory from xmodule.partitions.partitions import Group, UserPartition +from xmodule.tabs import DiscussionTab + + +def _remove_discussion_tab(course, user_id): + """ + Remove the discussion tab for the course. + + user_id is passed to the modulestore as the editor of the module. + """ + course.tabs = [tab for tab in course.tabs if not isinstance(tab, DiscussionTab)] + modulestore().update_item(course, user_id) @mock.patch.dict("django.conf.settings.FEATURES", {"DISABLE_START_DATES": False}) @@ -49,12 +63,13 @@ class GetCourseTopicsTest(ModuleStoreTestCase): course="y", run="z", start=datetime.now(UTC), - discussion_topics={}, + discussion_topics={"Test Topic": {"id": "non-courseware-topic-id"}}, user_partitions=[self.partition], cohort_config={"cohorted": True}, days_early_for_beta=3 ) self.user = UserFactory.create() + CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) def make_discussion_module(self, topic_id, category, subcategory, **kwargs): """Build a discussion module in self.course""" @@ -72,7 +87,7 @@ class GetCourseTopicsTest(ModuleStoreTestCase): Get course topics for self.course, using the given user or self.user if not provided, and generating absolute URIs with a test scheme/host. """ - return get_course_topics(self.course, user or self.user) + return get_course_topics(self.course.id, user or self.user) def make_expected_tree(self, topic_id, name, children=None): """ @@ -87,50 +102,58 @@ class GetCourseTopicsTest(ModuleStoreTestCase): } return node - def test_empty(self): + def test_nonexistent_course(self): + with self.assertRaises(Http404): + get_course_topics(CourseLocator.from_string("non/existent/course"), self.user) + + def test_not_enrolled(self): + unenrolled_user = UserFactory.create() + with self.assertRaises(Http404): + get_course_topics(self.course.id, unenrolled_user) + + def test_discussions_disabled(self): + _remove_discussion_tab(self.course, self.user.id) + with self.assertRaises(Http404): + self.get_course_topics() + + def test_without_courseware(self): actual = self.get_course_topics() expected = { "courseware_topics": [], - "non_courseware_topics": [], + "non_courseware_topics": [ + self.make_expected_tree("non-courseware-topic-id", "Test Topic") + ], } self.assertEqual(actual, expected) - def test_non_courseware(self): - self.course.discussion_topics = {"Topic Name": {"id": "topic-id"}} - self.course.save() - actual = self.get_course_topics() - expected = { - "courseware_topics": [], - "non_courseware_topics": [self.make_expected_tree("topic-id", "Topic Name")], - } - self.assertEqual(actual, expected) - - def test_courseware(self): - self.make_discussion_module("topic-id", "Foo", "Bar") + def test_with_courseware(self): + self.make_discussion_module("courseware-topic-id", "Foo", "Bar") actual = self.get_course_topics() expected = { "courseware_topics": [ self.make_expected_tree( None, "Foo", - [self.make_expected_tree("topic-id", "Bar")] + [self.make_expected_tree("courseware-topic-id", "Bar")] ), ], - "non_courseware_topics": [], + "non_courseware_topics": [ + self.make_expected_tree("non-courseware-topic-id", "Test Topic") + ], } self.assertEqual(actual, expected) def test_many(self): + self.course.discussion_topics = { + "A": {"id": "non-courseware-1"}, + "B": {"id": "non-courseware-2"}, + } + modulestore().update_item(self.course, self.user.id) self.make_discussion_module("courseware-1", "A", "1") self.make_discussion_module("courseware-2", "A", "2") self.make_discussion_module("courseware-3", "B", "1") self.make_discussion_module("courseware-4", "B", "2") self.make_discussion_module("courseware-5", "C", "1") - self.course.discussion_topics = { - "A": {"id": "non-courseware-1"}, - "B": {"id": "non-courseware-2"}, - } - self.course.save() actual = self.get_course_topics() expected = { "courseware_topics": [ @@ -164,6 +187,13 @@ class GetCourseTopicsTest(ModuleStoreTestCase): self.assertEqual(actual, expected) def test_sort_key(self): + self.course.discussion_topics = { + "W": {"id": "non-courseware-1", "sort_key": "Z"}, + "X": {"id": "non-courseware-2"}, + "Y": {"id": "non-courseware-3", "sort_key": "Y"}, + "Z": {"id": "non-courseware-4", "sort_key": "W"}, + } + modulestore().update_item(self.course, self.user.id) self.make_discussion_module("courseware-1", "First", "A", sort_key="D") self.make_discussion_module("courseware-2", "First", "B", sort_key="B") self.make_discussion_module("courseware-3", "First", "C", sort_key="E") @@ -171,13 +201,6 @@ class GetCourseTopicsTest(ModuleStoreTestCase): self.make_discussion_module("courseware-5", "Second", "B", sort_key="G") self.make_discussion_module("courseware-6", "Second", "C") self.make_discussion_module("courseware-7", "Second", "D", sort_key="A") - self.course.discussion_topics = { - "W": {"id": "non-courseware-1", "sort_key": "Z"}, - "X": {"id": "non-courseware-2"}, - "Y": {"id": "non-courseware-3", "sort_key": "Y"}, - "Z": {"id": "non-courseware-4", "sort_key": "W"}, - } - self.course.save() actual = self.get_course_topics() expected = { "courseware_topics": [ @@ -223,6 +246,7 @@ class GetCourseTopicsTest(ModuleStoreTestCase): subcategories does not appear in the result. """ beta_tester = BetaTesterFactory.create(course_key=self.course.id) + CourseEnrollmentFactory.create(user=beta_tester, course_id=self.course.id) staff = StaffFactory.create(course_key=self.course.id) for user, group_idx in [(self.user, 0), (beta_tester, 1)]: cohort = CohortFactory.create( @@ -269,7 +293,9 @@ class GetCourseTopicsTest(ModuleStoreTestCase): ] ), ], - "non_courseware_topics": [], + "non_courseware_topics": [ + self.make_expected_tree("non-courseware-topic-id", "Test Topic"), + ], } self.assertEqual(student_actual, student_expected) @@ -290,7 +316,9 @@ class GetCourseTopicsTest(ModuleStoreTestCase): [self.make_expected_tree("courseware-5", "Future Start Date")] ), ], - "non_courseware_topics": [], + "non_courseware_topics": [ + self.make_expected_tree("non-courseware-topic-id", "Test Topic"), + ], } self.assertEqual(beta_actual, beta_expected) @@ -315,7 +343,9 @@ class GetCourseTopicsTest(ModuleStoreTestCase): ] ), ], - "non_courseware_topics": [], + "non_courseware_topics": [ + self.make_expected_tree("non-courseware-topic-id", "Test Topic"), + ], } self.assertEqual(staff_actual, staff_expected) @@ -334,6 +364,7 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): self.request = RequestFactory().get("/test_path") self.request.user = self.user self.course = CourseFactory.create() + CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) self.author = UserFactory.create() self.cohort = CohortFactory.create(course_id=self.course.id) @@ -344,7 +375,7 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): """ course = course or self.course self.register_get_threads_response(threads, page, num_pages) - ret = get_thread_list(self.request, course, page, page_size) + ret = get_thread_list(self.request, course.id, page, page_size) return ret def create_role(self, role_name, users): @@ -382,6 +413,20 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ret.update(thread_data) return ret + def test_nonexistent_course(self): + with self.assertRaises(Http404): + get_thread_list(self.request, CourseLocator.from_string("non/existent/course"), 1, 1) + + def test_not_enrolled(self): + self.request.user = UserFactory.create() + with self.assertRaises(Http404): + self.get_thread_list([]) + + def test_discussions_disabled(self): + _remove_discussion_tab(self.course, self.user.id) + with self.assertRaises(Http404): + self.get_thread_list([]) + def test_empty(self): self.assertEqual( self.get_thread_list([]), @@ -565,6 +610,7 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): @ddt.unpack def test_request_group(self, role_name, course_is_cohorted): cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted}) + CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id) CohortFactory.create(course_id=cohort_course.id, users=[self.user]) role = Role.objects.create(name=role_name, course_id=cohort_course.id) role.users = [self.user] @@ -603,7 +649,7 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): # Test page past the last one self.register_get_threads_response([], page=3, num_pages=3) with self.assertRaises(Http404): - get_thread_list(self.request, self.course, page=4, page_size=10) + get_thread_list(self.request, self.course.id, page=4, page_size=10) @ddt.data( (FORUM_ROLE_ADMINISTRATOR, True, False, True), diff --git a/lms/djangoapps/discussion_api/tests/test_views.py b/lms/djangoapps/discussion_api/tests/test_views.py index d6eee0d221..a9c4e1a99d 100644 --- a/lms/djangoapps/discussion_api/tests/test_views.py +++ b/lms/djangoapps/discussion_api/tests/test_views.py @@ -42,11 +42,6 @@ class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin): CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) self.client.login(username=self.user.username, password=self.password) - def login_unenrolled_user(self): - """Create a user not enrolled in the course and log it in""" - unenrolled_user = UserFactory.create(password=self.password) - self.client.login(username=unenrolled_user.username, password=self.password) - def assert_response_correct(self, response, expected_status, expected_content): """ Assert that the response has the given status code and parsed content @@ -71,7 +66,7 @@ class CourseTopicsViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): super(CourseTopicsViewTest, self).setUp() self.url = reverse("course_topics", kwargs={"course_id": unicode(self.course.id)}) - def test_non_existent_course(self): + def test_404(self): response = self.client.get( reverse("course_topics", kwargs={"course_id": "non/existent/course"}) ) @@ -81,26 +76,7 @@ class CourseTopicsViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): {"developer_message": "Not found."} ) - def test_not_enrolled(self): - self.login_unenrolled_user() - response = self.client.get(self.url) - self.assert_response_correct( - response, - 404, - {"developer_message": "Not found."} - ) - - def test_discussions_disabled(self): - self.course.tabs = [tab for tab in self.course.tabs if not isinstance(tab, DiscussionTab)] - modulestore().update_item(self.course, self.user.id) - response = self.client.get(self.url) - self.assert_response_correct( - response, - 404, - {"developer_message": "Not found."} - ) - - def test_get(self): + def test_get_success(self): response = self.client.get(self.url) self.assert_response_correct( response, @@ -132,9 +108,8 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): {"field_errors": {"course_id": "This field is required."}} ) - def test_not_enrolled(self): - self.login_unenrolled_user() - response = self.client.get(self.url, {"course_id": unicode(self.course.id)}) + def test_404(self): + response = self.client.get(self.url, {"course_id": unicode("non/existent/course")}) self.assert_response_correct( response, 404, diff --git a/lms/djangoapps/discussion_api/views.py b/lms/djangoapps/discussion_api/views.py index b17b7e0caa..cc42f5b4ec 100644 --- a/lms/djangoapps/discussion_api/views.py +++ b/lms/djangoapps/discussion_api/views.py @@ -2,7 +2,6 @@ Discussion API views """ from django.core.exceptions import ValidationError -from django.http import Http404 from rest_framework.authentication import OAuth2Authentication, SessionAuthentication from rest_framework.permissions import IsAuthenticated @@ -12,11 +11,9 @@ from rest_framework.viewsets import ViewSet from opaque_keys.edx.locator import CourseLocator -from courseware.courses import get_course_with_access from discussion_api.api import get_course_topics, get_thread_list from discussion_api.forms import ThreadListGetForm from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin -from xmodule.tabs import DiscussionTab class _ViewMixin(object): @@ -27,17 +24,6 @@ class _ViewMixin(object): authentication_classes = (OAuth2Authentication, SessionAuthentication) permission_classes = (IsAuthenticated,) - def get_course_or_404(self, user, course_key): - """ - Get the course descriptor, raising Http404 if the course is not found, - the user cannot access forums for the course, or the discussion tab is - disabled for the course. - """ - course = get_course_with_access(user, 'load_forum', course_key) - if not any([isinstance(tab, DiscussionTab) for tab in course.tabs]): - raise Http404 - return course - class CourseTopicsView(_ViewMixin, DeveloperErrorViewMixin, APIView): """ @@ -68,8 +54,7 @@ class CourseTopicsView(_ViewMixin, DeveloperErrorViewMixin, APIView): def get(self, request, course_id): """Implements the GET method as described in the class docstring.""" course_key = CourseLocator.from_string(course_id) - course = self.get_course_or_404(request.user, course_key) - return Response(get_course_topics(course, request.user)) + return Response(get_course_topics(course_key, request.user)) class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): @@ -133,11 +118,10 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): form = ThreadListGetForm(request.GET) if not form.is_valid(): raise ValidationError(form.errors) - course = self.get_course_or_404(request.user, form.cleaned_data["course_id"]) return Response( get_thread_list( request, - course, + form.cleaned_data["course_id"], form.cleaned_data["page"], form.cleaned_data["page_size"] ) From 7309352ef739dc79b1f3a1a99b1e68bca2cdb4e3 Mon Sep 17 00:00:00 2001 From: Greg Price Date: Thu, 14 May 2015 17:50:10 -0400 Subject: [PATCH 2/2] Refactor discussion API to use DRF serializer This will make it easier to add the creation and update interfaces. --- lms/djangoapps/discussion_api/api.py | 104 +-------- lms/djangoapps/discussion_api/serializers.py | 134 +++++++++++ .../discussion_api/tests/test_api.py | 147 +----------- .../discussion_api/tests/test_serializers.py | 209 ++++++++++++++++++ 4 files changed, 353 insertions(+), 241 deletions(-) create mode 100644 lms/djangoapps/discussion_api/serializers.py create mode 100644 lms/djangoapps/discussion_api/tests/test_serializers.py diff --git a/lms/djangoapps/discussion_api/api.py b/lms/djangoapps/discussion_api/api.py index 09fad100d9..b36ab441b2 100644 --- a/lms/djangoapps/discussion_api/api.py +++ b/lms/djangoapps/discussion_api/api.py @@ -7,16 +7,10 @@ from collections import defaultdict from courseware.courses import get_course_with_access from discussion_api.pagination import get_paginated_data +from discussion_api.serializers import ThreadSerializer, get_context from django_comment_client.utils import get_accessible_discussion_modules -from django_comment_common.models import ( - FORUM_ROLE_ADMINISTRATOR, - FORUM_ROLE_COMMUNITY_TA, - FORUM_ROLE_MODERATOR, - Role, -) from lms.lib.comment_client.thread import Thread -from lms.lib.comment_client.user import User -from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id, get_cohort_names +from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id from xmodule.tabs import DiscussionTab @@ -92,67 +86,6 @@ def get_course_topics(course_key, user): } -def _cc_thread_to_api_thread(thread, cc_user, staff_user_ids, ta_user_ids, group_ids_to_names): - """ - Convert a thread data dict from the comment_client format (which is a direct - representation of the format returned by the comments service) to the format - used in this API - - Arguments: - thread (comment_client.thread.Thread): The thread to convert - cc_user (comment_client.user.User): The comment_client representation of - the requesting user - staff_user_ids (set): The set of user ids for users with the Moderator or - Administrator role in the course - ta_user_ids (set): The set of user ids for users with the Community TA - role in the course - group_ids_to_names (dict): A mapping of group ids to names - - Returns: - dict: The discussion_api format representation of the thread. - """ - is_anonymous = ( - thread["anonymous"] or - ( - thread["anonymous_to_peers"] and - int(cc_user["id"]) not in (staff_user_ids | ta_user_ids) - ) - ) - ret = { - key: thread[key] - for key in [ - "id", - "course_id", - "group_id", - "created_at", - "updated_at", - "title", - "pinned", - "closed", - ] - } - ret.update({ - "topic_id": thread["commentable_id"], - "group_name": group_ids_to_names.get(thread["group_id"]), - "author": None if is_anonymous else thread["username"], - "author_label": ( - None if is_anonymous else - "staff" if int(thread["user_id"]) in staff_user_ids else - "community_ta" if int(thread["user_id"]) in ta_user_ids else - None - ), - "type": thread["thread_type"], - "raw_body": thread["body"], - "following": thread["id"] in cc_user["subscribed_thread_ids"], - "abuse_flagged": cc_user["id"] in thread["abuse_flaggers"], - "voted": thread["id"] in cc_user["upvoted_ids"], - "vote_count": thread["votes"]["up_count"], - "comment_count": thread["comments_count"], - "unread_comment_count": thread["unread_comments_count"], - }) - return ret - - def get_thread_list(request, course_key, page, page_size): """ Return the list of all discussion threads pertaining to the given course @@ -170,15 +103,13 @@ def get_thread_list(request, course_key, page, page_size): discussion_api.views.ThreadViewSet for more detail. """ course = _get_course_or_404(course_key, request.user) - user_is_privileged = Role.objects.filter( - course_id=course.id, - name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR, FORUM_ROLE_COMMUNITY_TA], - users=request.user - ).exists() - cc_user = User.from_django_user(request.user).retrieve() + context = get_context(course, request.user) threads, result_page, num_pages, _ = Thread.search({ "course_id": unicode(course.id), - "group_id": None if user_is_privileged else get_cohort_id(request.user, course.id), + "group_id": ( + None if context["is_requester_privileged"] else + get_cohort_id(request.user, course.id) + ), "sort_key": "date", "sort_order": "desc", "page": page, @@ -189,25 +120,6 @@ def get_thread_list(request, course_key, page, page_size): # behavior and return a 404 in that case if result_page != page: raise Http404 - # TODO: cache staff_user_ids and ta_user_ids if we need to improve perf - staff_user_ids = { - user.id - for role in Role.objects.filter( - name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR], - course_id=course.id - ) - for user in role.users.all() - } - ta_user_ids = { - user.id - for role in Role.objects.filter(name=FORUM_ROLE_COMMUNITY_TA, course_id=course.id) - for user in role.users.all() - } - # For now, the only groups are cohorts - group_ids_to_names = get_cohort_names(course) - results = [ - _cc_thread_to_api_thread(thread, cc_user, staff_user_ids, ta_user_ids, group_ids_to_names) - for thread in threads - ] + results = [ThreadSerializer(thread, context=context).data for thread in threads] return get_paginated_data(request, results, page, num_pages) diff --git a/lms/djangoapps/discussion_api/serializers.py b/lms/djangoapps/discussion_api/serializers.py new file mode 100644 index 0000000000..727ea12b2e --- /dev/null +++ b/lms/djangoapps/discussion_api/serializers.py @@ -0,0 +1,134 @@ +""" +Discussion API serializers +""" +from rest_framework import serializers + +from django_comment_common.models import ( + FORUM_ROLE_ADMINISTRATOR, + FORUM_ROLE_COMMUNITY_TA, + FORUM_ROLE_MODERATOR, + Role, +) +from lms.lib.comment_client.user import User +from openedx.core.djangoapps.course_groups.cohorts import get_cohort_names + + +def get_context(course, requester): + """Returns a context appropriate for use with ThreadSerializer.""" + # TODO: cache staff_user_ids and ta_user_ids if we need to improve perf + staff_user_ids = { + user.id + for role in Role.objects.filter( + name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR], + course_id=course.id + ) + for user in role.users.all() + } + ta_user_ids = { + user.id + for role in Role.objects.filter(name=FORUM_ROLE_COMMUNITY_TA, course_id=course.id) + for user in role.users.all() + } + return { + # For now, the only groups are cohorts + "group_ids_to_names": get_cohort_names(course), + "is_requester_privileged": requester.id in staff_user_ids or requester.id in ta_user_ids, + "staff_user_ids": staff_user_ids, + "ta_user_ids": ta_user_ids, + "cc_requester": User.from_django_user(requester).retrieve(), + } + + +class ThreadSerializer(serializers.Serializer): + """ + A serializer for thread data. + + N.B. This should not be used with a comment_client Thread object that has + not had retrieve() called, because of the interaction between DRF's attempts + at introspection and Thread's __getattr__. + """ + id_ = serializers.CharField(read_only=True) + course_id = serializers.CharField() + topic_id = serializers.CharField(source="commentable_id") + group_id = serializers.IntegerField() + group_name = serializers.SerializerMethodField("get_group_name") + author = serializers.SerializerMethodField("get_author") + author_label = serializers.SerializerMethodField("get_author_label") + created_at = serializers.CharField(read_only=True) + updated_at = serializers.CharField(read_only=True) + type_ = serializers.ChoiceField(source="thread_type", choices=("discussion", "question")) + title = serializers.CharField() + raw_body = serializers.CharField(source="body") + pinned = serializers.BooleanField() + closed = serializers.BooleanField() + following = serializers.SerializerMethodField("get_following") + abuse_flagged = serializers.SerializerMethodField("get_abuse_flagged") + voted = serializers.SerializerMethodField("get_voted") + vote_count = serializers.SerializerMethodField("get_vote_count") + comment_count = serializers.IntegerField(source="comments_count") + unread_comment_count = serializers.IntegerField(source="unread_comments_count") + + def __init__(self, *args, **kwargs): + super(ThreadSerializer, self).__init__(*args, **kwargs) + # type and id are invalid class attribute names, so we must declare + # different names above and modify them here + self.fields["id"] = self.fields.pop("id_") + self.fields["type"] = self.fields.pop("type_") + + def get_group_name(self, obj): + """Returns the name of the group identified by the thread's group_id.""" + return self.context["group_ids_to_names"].get(obj["group_id"]) + + def _is_anonymous(self, obj): + """ + Returns a boolean indicating whether the thread should be anonymous to + the requester. + """ + return ( + obj["anonymous"] or + obj["anonymous_to_peers"] and not self.context["is_requester_privileged"] + ) + + def get_author(self, obj): + """Returns the author's username, or None if the thread is anonymous.""" + return None if self._is_anonymous(obj) else obj["username"] + + def _get_user_label(self, user_id): + """ + Returns the role label (i.e. "staff" or "community_ta") for the user + with the given id. + """ + return ( + "staff" if user_id in self.context["staff_user_ids"] else + "community_ta" if user_id in self.context["ta_user_ids"] else + None + ) + + def get_author_label(self, obj): + """Returns the role label for the thread author.""" + return None if self._is_anonymous(obj) else self._get_user_label(int(obj["user_id"])) + + def get_following(self, obj): + """ + Returns a boolean indicating whether the requester is following the + thread. + """ + return obj["id"] in self.context["cc_requester"]["subscribed_thread_ids"] + + def get_abuse_flagged(self, obj): + """ + Returns a boolean indicating whether the requester has flagged the + thread as abusive. + """ + return self.context["cc_requester"]["id"] in obj["abuse_flaggers"] + + def get_voted(self, obj): + """ + Returns a boolean indicating whether the requester has voted for the + thread. + """ + return obj["id"] in self.context["cc_requester"]["upvoted_ids"] + + def get_vote_count(self, obj): + """Returns the number of votes for the thread.""" + return obj["votes"]["up_count"] diff --git a/lms/djangoapps/discussion_api/tests/test_api.py b/lms/djangoapps/discussion_api/tests/test_api.py index ccd164c266..604f7c2391 100644 --- a/lms/djangoapps/discussion_api/tests/test_api.py +++ b/lms/djangoapps/discussion_api/tests/test_api.py @@ -384,35 +384,6 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): role.users = users role.save() - def make_cs_thread(self, thread_data): - """ - Create a dictionary containing all needed thread fields as returned by - the comments service with dummy data overridden by thread_data - """ - ret = { - "id": "dummy", - "course_id": unicode(self.course.id), - "commentable_id": "dummy", - "group_id": None, - "user_id": str(self.author.id), - "username": self.author.username, - "anonymous": False, - "anonymous_to_peers": False, - "created_at": "1970-01-01T00:00:00Z", - "updated_at": "1970-01-01T00:00:00Z", - "thread_type": "discussion", - "title": "dummy", - "body": "dummy", - "pinned": False, - "closed": False, - "abuse_flaggers": [], - "votes": {"up_count": 0}, - "comments_count": 0, - "unread_comments_count": 0, - } - ret.update(thread_data) - return ret - def test_nonexistent_course(self): with self.assertRaises(Http404): get_thread_list(self.request, CourseLocator.from_string("non/existent/course"), 1, 1) @@ -449,11 +420,6 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): }) def test_thread_content(self): - self.register_get_user_response( - self.user, - subscribed_thread_ids=["test_thread_id_0"], - upvoted_ids=["test_thread_id_1"] - ) source_threads = [ { "id": "test_thread_id_0", @@ -497,27 +463,6 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): "comments_count": 18, "unread_comments_count": 0, }, - { - "id": "test_thread_id_2", - "course_id": unicode(self.course.id), - "commentable_id": "topic_x", - "group_id": self.cohort.id + 1, # non-existent group - "user_id": str(self.author.id), - "username": self.author.username, - "anonymous": False, - "anonymous_to_peers": False, - "created_at": "2015-04-28T00:44:44Z", - "updated_at": "2015-04-28T00:55:55Z", - "thread_type": "discussion", - "title": "Yet Another Test Title", - "body": "Still more content", - "pinned": True, - "closed": False, - "abuse_flaggers": [str(self.user.id)], - "votes": {"up_count": 0}, - "comments_count": 0, - "unread_comments_count": 0, - }, ] expected_threads = [ { @@ -535,7 +480,7 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): "raw_body": "Test body", "pinned": False, "closed": False, - "following": True, + "following": False, "abuse_flagged": False, "voted": False, "vote_count": 4, @@ -559,33 +504,11 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): "closed": True, "following": False, "abuse_flagged": False, - "voted": True, + "voted": False, "vote_count": 9, "comment_count": 18, "unread_comment_count": 0, }, - { - "id": "test_thread_id_2", - "course_id": unicode(self.course.id), - "topic_id": "topic_x", - "group_id": self.cohort.id + 1, - "group_name": None, - "author": self.author.username, - "author_label": None, - "created_at": "2015-04-28T00:44:44Z", - "updated_at": "2015-04-28T00:55:55Z", - "type": "discussion", - "title": "Yet Another Test Title", - "raw_body": "Still more content", - "pinned": True, - "closed": False, - "following": False, - "abuse_flagged": True, - "voted": False, - "vote_count": 0, - "comment_count": 0, - "unread_comment_count": 0, - }, ] self.assertEqual( self.get_thread_list(source_threads), @@ -650,69 +573,3 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): self.register_get_threads_response([], page=3, num_pages=3) with self.assertRaises(Http404): get_thread_list(self.request, self.course.id, page=4, page_size=10) - - @ddt.data( - (FORUM_ROLE_ADMINISTRATOR, True, False, True), - (FORUM_ROLE_ADMINISTRATOR, False, True, False), - (FORUM_ROLE_MODERATOR, True, False, True), - (FORUM_ROLE_MODERATOR, False, True, False), - (FORUM_ROLE_COMMUNITY_TA, True, False, True), - (FORUM_ROLE_COMMUNITY_TA, False, True, False), - (FORUM_ROLE_STUDENT, True, False, True), - (FORUM_ROLE_STUDENT, False, True, True), - ) - @ddt.unpack - def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_api_anonymous): - """ - Test that a thread is properly made anonymous. - - A thread should be anonymous iff the anonymous field is true or the - anonymous_to_peers field is true and the requester does not have a - privileged role. - - role_name is the name of the requester's role. - thread_anon is the value of the anonymous field in the thread data. - thread_anon_to_peers is the value of the anonymous_to_peers field in the - thread data. - expected_api_anonymous is whether the thread should actually be - anonymous in the API output when requested by a user with the given - role. - """ - self.create_role(role_name, [self.user]) - result = self.get_thread_list([ - self.make_cs_thread({ - "anonymous": anonymous, - "anonymous_to_peers": anonymous_to_peers, - }) - ]) - actual_api_anonymous = result["results"][0]["author"] is None - self.assertEqual(actual_api_anonymous, expected_api_anonymous) - - @ddt.data( - (FORUM_ROLE_ADMINISTRATOR, False, "staff"), - (FORUM_ROLE_ADMINISTRATOR, True, None), - (FORUM_ROLE_MODERATOR, False, "staff"), - (FORUM_ROLE_MODERATOR, True, None), - (FORUM_ROLE_COMMUNITY_TA, False, "community_ta"), - (FORUM_ROLE_COMMUNITY_TA, True, None), - (FORUM_ROLE_STUDENT, False, None), - (FORUM_ROLE_STUDENT, True, None), - ) - @ddt.unpack - def test_author_labels(self, role_name, anonymous, expected_label): - """ - Test correctness of the author_label field. - - The label should be "staff", "staff", or "community_ta" for the - Administrator, Moderator, and Community TA roles, respectively, but - the label should not be present if the thread is anonymous. - - role_name is the name of the author's role. - anonymous is the value of the anonymous field in the thread data. - expected_label is the expected value of the author_label field in the - API output. - """ - self.create_role(role_name, [self.author]) - result = self.get_thread_list([self.make_cs_thread({"anonymous": anonymous})]) - actual_label = result["results"][0]["author_label"] - self.assertEqual(actual_label, expected_label) diff --git a/lms/djangoapps/discussion_api/tests/test_serializers.py b/lms/djangoapps/discussion_api/tests/test_serializers.py new file mode 100644 index 0000000000..92ee95c6ad --- /dev/null +++ b/lms/djangoapps/discussion_api/tests/test_serializers.py @@ -0,0 +1,209 @@ +""" +Tests for Discussion API serializers +""" +import ddt +import httpretty + +from discussion_api.serializers import ThreadSerializer, get_context +from discussion_api.tests.utils import CommentsServiceMockMixin +from django_comment_common.models import ( + FORUM_ROLE_ADMINISTRATOR, + FORUM_ROLE_COMMUNITY_TA, + FORUM_ROLE_MODERATOR, + FORUM_ROLE_STUDENT, + Role, +) +from student.tests.factories import UserFactory +from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase +from xmodule.modulestore.tests.factories import CourseFactory +from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory + + +@ddt.ddt +class ThreadSerializerTest(CommentsServiceMockMixin, ModuleStoreTestCase): + """Tests for ThreadSerializer.""" + def setUp(self): + super(ThreadSerializerTest, self).setUp() + httpretty.reset() + httpretty.enable() + self.addCleanup(httpretty.disable) + self.maxDiff = None # pylint: disable=invalid-name + self.user = UserFactory.create() + self.register_get_user_response(self.user) + self.course = CourseFactory.create() + self.author = UserFactory.create() + + def create_role(self, role_name, users, course=None): + """Create a Role in self.course with the given name and users""" + course = course or self.course + role = Role.objects.create(name=role_name, course_id=course.id) + role.users = users + + def make_cs_thread(self, thread_data=None): + """ + Create a dictionary containing all needed thread fields as returned by + the comments service with dummy data overridden by thread_data + """ + ret = { + "id": "dummy", + "course_id": unicode(self.course.id), + "commentable_id": "dummy", + "group_id": None, + "user_id": str(self.author.id), + "username": self.author.username, + "anonymous": False, + "anonymous_to_peers": False, + "created_at": "1970-01-01T00:00:00Z", + "updated_at": "1970-01-01T00:00:00Z", + "thread_type": "discussion", + "title": "dummy", + "body": "dummy", + "pinned": False, + "closed": False, + "abuse_flaggers": [], + "votes": {"up_count": 0}, + "comments_count": 0, + "unread_comments_count": 0, + "children": [], + "resp_total": 0, + } + if thread_data: + ret.update(thread_data) + return ret + + def serialize(self, thread): + """ + Create a serializer with an appropriate context and use it to serialize + the given thread, returning the result. + """ + return ThreadSerializer(thread, context=get_context(self.course, self.user)).data + + def test_basic(self): + thread = { + "id": "test_thread", + "course_id": unicode(self.course.id), + "commentable_id": "test_topic", + "group_id": None, + "user_id": str(self.author.id), + "username": self.author.username, + "anonymous": False, + "anonymous_to_peers": False, + "created_at": "2015-04-28T00:00:00Z", + "updated_at": "2015-04-28T11:11:11Z", + "thread_type": "discussion", + "title": "Test Title", + "body": "Test body", + "pinned": True, + "closed": False, + "abuse_flaggers": [], + "votes": {"up_count": 4}, + "comments_count": 5, + "unread_comments_count": 3, + } + expected = { + "id": "test_thread", + "course_id": unicode(self.course.id), + "topic_id": "test_topic", + "group_id": None, + "group_name": None, + "author": self.author.username, + "author_label": None, + "created_at": "2015-04-28T00:00:00Z", + "updated_at": "2015-04-28T11:11:11Z", + "type": "discussion", + "title": "Test Title", + "raw_body": "Test body", + "pinned": True, + "closed": False, + "following": False, + "abuse_flagged": False, + "voted": False, + "vote_count": 4, + "comment_count": 5, + "unread_comment_count": 3, + } + self.assertEqual(self.serialize(thread), expected) + + def test_group(self): + cohort = CohortFactory.create(course_id=self.course.id) + serialized = self.serialize(self.make_cs_thread({"group_id": cohort.id})) + self.assertEqual(serialized["group_id"], cohort.id) + self.assertEqual(serialized["group_name"], cohort.name) + + @ddt.data( + (FORUM_ROLE_ADMINISTRATOR, True, False, True), + (FORUM_ROLE_ADMINISTRATOR, False, True, False), + (FORUM_ROLE_MODERATOR, True, False, True), + (FORUM_ROLE_MODERATOR, False, True, False), + (FORUM_ROLE_COMMUNITY_TA, True, False, True), + (FORUM_ROLE_COMMUNITY_TA, False, True, False), + (FORUM_ROLE_STUDENT, True, False, True), + (FORUM_ROLE_STUDENT, False, True, True), + ) + @ddt.unpack + def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_serialized_anonymous): + """ + Test that content is properly made anonymous. + + Content should be anonymous iff the anonymous field is true or the + anonymous_to_peers field is true and the requester does not have a + privileged role. + + role_name is the name of the requester's role. + anonymous is the value of the anonymous field in the content. + anonymous_to_peers is the value of the anonymous_to_peers field in the + content. + expected_serialized_anonymous is whether the content should actually be + anonymous in the API output when requested by a user with the given + role. + """ + self.create_role(role_name, [self.user]) + serialized = self.serialize( + self.make_cs_thread({"anonymous": anonymous, "anonymous_to_peers": anonymous_to_peers}) + ) + actual_serialized_anonymous = serialized["author"] is None + self.assertEqual(actual_serialized_anonymous, expected_serialized_anonymous) + + @ddt.data( + (FORUM_ROLE_ADMINISTRATOR, False, "staff"), + (FORUM_ROLE_ADMINISTRATOR, True, None), + (FORUM_ROLE_MODERATOR, False, "staff"), + (FORUM_ROLE_MODERATOR, True, None), + (FORUM_ROLE_COMMUNITY_TA, False, "community_ta"), + (FORUM_ROLE_COMMUNITY_TA, True, None), + (FORUM_ROLE_STUDENT, False, None), + (FORUM_ROLE_STUDENT, True, None), + ) + @ddt.unpack + def test_author_labels(self, role_name, anonymous, expected_label): + """ + Test correctness of the author_label field. + + The label should be "staff", "staff", or "community_ta" for the + Administrator, Moderator, and Community TA roles, respectively, but + the label should not be present if the thread is anonymous. + + role_name is the name of the author's role. + anonymous is the value of the anonymous field in the content. + expected_label is the expected value of the author_label field in the + API output. + """ + self.create_role(role_name, [self.author]) + serialized = self.serialize(self.make_cs_thread({"anonymous": anonymous})) + self.assertEqual(serialized["author_label"], expected_label) + + def test_following(self): + thread_id = "test_thread" + self.register_get_user_response(self.user, subscribed_thread_ids=[thread_id]) + serialized = self.serialize(self.make_cs_thread({"id": thread_id})) + self.assertEqual(serialized["following"], True) + + def test_abuse_flagged(self): + serialized = self.serialize(self.make_cs_thread({"abuse_flaggers": [str(self.user.id)]})) + self.assertEqual(serialized["abuse_flagged"], True) + + def test_voted(self): + thread_id = "test_thread" + self.register_get_user_response(self.user, upvoted_ids=[thread_id]) + serialized = self.serialize(self.make_cs_thread({"id": thread_id})) + self.assertEqual(serialized["voted"], True)