diff --git a/lms/djangoapps/discussion_api/api.py b/lms/djangoapps/discussion_api/api.py index a3a81b33ce..85c3a9fc56 100644 --- a/lms/djangoapps/discussion_api/api.py +++ b/lms/djangoapps/discussion_api/api.py @@ -17,7 +17,12 @@ from opaque_keys.edx.locator import CourseKey from courseware.courses import get_course_with_access from discussion_api.forms import CommentActionsForm, ThreadActionsForm from discussion_api.pagination import get_paginated_data -from discussion_api.permissions import can_delete, get_editable_fields +from discussion_api.permissions import ( + can_delete, + get_editable_fields, + get_initializable_comment_fields, + get_initializable_thread_fields, +) from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context from django_comment_client.base.views import ( THREAD_CREATED_EVENT_NAME, @@ -295,7 +300,7 @@ def get_comment_list(request, thread_id, endorsed, page, page_size): """ Return the list of comments in the given thread. - Parameters: + Arguments: request: The django request object used for build_absolute_uri and determining the requesting user. @@ -361,19 +366,76 @@ def get_comment_list(request, thread_id, endorsed, page, page_size): return get_paginated_data(request, results, page, num_pages) +def _check_fields(allowed_fields, data, message): + """ + Checks that the keys given in data is in allowed_fields + + Arguments: + allowed_fields (set): A set of allowed fields + data (dict): The data to compare the allowed_fields against + message (str): The message to return if there are any invalid fields + + Raises: + ValidationError if the given data contains a key that is not in + allowed_fields + """ + non_allowed_fields = {field: [message] for field in data.keys() if field not in allowed_fields} + if non_allowed_fields: + raise ValidationError(non_allowed_fields) + + +def _check_initializable_thread_fields(data, context): # pylint: disable=invalid-name + """ + Checks if the given data contains a thread field that is not initializable + by the requesting user + + Arguments: + data (dict): The data to compare the allowed_fields against + context (dict): The context appropriate for use with the thread which + includes the requesting user + + Raises: + ValidationError if the given data contains a thread field that is not + initializable by the requesting user + """ + _check_fields( + get_initializable_thread_fields(context), + data, + "This field is not initializable." + ) + + +def _check_initializable_comment_fields(data, context): # pylint: disable=invalid-name + """ + Checks if the given data contains a comment field that is not initializable + by the requesting user + + Arguments: + data (dict): The data to compare the allowed_fields against + context (dict): The context appropriate for use with the comment which + includes the requesting user + + Raises: + ValidationError if the given data contains a comment field that is not + initializable by the requesting user + """ + _check_fields( + get_initializable_comment_fields(context), + data, + "This field is not initializable." + ) + + def _check_editable_fields(cc_content, data, context): """ Raise ValidationError if the given update data contains a field that is not - in editable_fields. + editable by the requesting user """ - editable_fields = get_editable_fields(cc_content, context) - non_editable_errors = { - field: ["This field is not editable."] - for field in data.keys() - if field not in editable_fields - } - if non_editable_errors: - raise ValidationError(non_editable_errors) + _check_fields( + get_editable_fields(cc_content, context), + data, + "This field is not editable." + ) def _do_extra_actions(api_content, cc_content, request_fields, actions_form, context): @@ -406,7 +468,7 @@ def create_thread(request, thread_data): """ Create a thread. - Parameters: + Arguments: request: The django request object used for build_absolute_uri and determining the requesting user. @@ -428,6 +490,13 @@ def create_thread(request, thread_data): raise ValidationError({"course_id": ["Invalid value."]}) context = get_context(course, request) + _check_initializable_thread_fields(thread_data, context) + if ( + "group_id" not in thread_data and + is_commentable_cohorted(course_key, thread_data.get("topic_id")) + ): + thread_data = thread_data.copy() + thread_data["group_id"] = get_cohort_id(request.user, course_key) serializer = ThreadSerializer(data=thread_data, context=context) actions_form = ThreadActionsForm(thread_data) if not (serializer.is_valid() and actions_form.is_valid()): @@ -453,7 +522,7 @@ def create_comment(request, comment_data): """ Create a comment. - Parameters: + Arguments: request: The django request object used for build_absolute_uri and determining the requesting user. @@ -473,6 +542,7 @@ def create_comment(request, comment_data): except Http404: raise ValidationError({"thread_id": ["Invalid value."]}) + _check_initializable_comment_fields(comment_data, context) serializer = CommentSerializer(data=comment_data, context=context) actions_form = CommentActionsForm(comment_data) if not (serializer.is_valid() and actions_form.is_valid()): @@ -498,7 +568,7 @@ def update_thread(request, thread_id, update_data): """ Update a thread. - Parameters: + Arguments: request: The django request object used for build_absolute_uri and determining the requesting user. @@ -530,7 +600,7 @@ def update_comment(request, comment_id, update_data): """ Update a comment. - Parameters: + Arguments: request: The django request object used for build_absolute_uri and determining the requesting user. @@ -573,7 +643,7 @@ def delete_thread(request, thread_id): """ Delete a thread. - Parameters: + Arguments: request: The django request object used for build_absolute_uri and determining the requesting user. @@ -596,7 +666,7 @@ def delete_comment(request, comment_id): """ Delete a comment. - Parameters: + Arguments: request: The django request object used for build_absolute_uri and determining the requesting user. diff --git a/lms/djangoapps/discussion_api/permissions.py b/lms/djangoapps/discussion_api/permissions.py index 596e3792f4..0b77be202e 100644 --- a/lms/djangoapps/discussion_api/permissions.py +++ b/lms/djangoapps/discussion_api/permissions.py @@ -1,6 +1,8 @@ """ Discussion API permission logic """ +from lms.lib.comment_client.comment import Comment +from lms.lib.comment_client.thread import Thread def _is_author(cc_content, context): @@ -18,6 +20,38 @@ def _is_author_or_privileged(cc_content, context): return context["is_requester_privileged"] or _is_author(cc_content, context) +NON_UPDATABLE_THREAD_FIELDS = {"course_id"} +NON_UPDATABLE_COMMENT_FIELDS = {"thread_id", "parent_id"} + + +def get_initializable_thread_fields(context): + """ + Return the set of fields that the requester can initialize for a thread + + Any field that is editable by the author should also be initializable. + """ + ret = get_editable_fields( + Thread(user_id=context["cc_requester"]["id"], type="thread"), + context + ) + ret |= NON_UPDATABLE_THREAD_FIELDS + return ret + + +def get_initializable_comment_fields(context): # pylint: disable=invalid-name + """ + Return the set of fields that the requester can initialize for a comment + + Any field that is editable by the author should also be initializable. + """ + ret = get_editable_fields( + Comment(user_id=context["cc_requester"]["id"], type="comment"), + context + ) + ret |= NON_UPDATABLE_COMMENT_FIELDS + return ret + + def get_editable_fields(cc_content, context): """ Return the set of fields that the requester can edit on the given content @@ -32,6 +66,8 @@ def get_editable_fields(cc_content, context): ret |= {"following"} if _is_author_or_privileged(cc_content, context): ret |= {"topic_id", "type", "title"} + if context["is_requester_privileged"] and context["course"].is_cohorted: + ret |= {"group_id"} # Comment fields if ( diff --git a/lms/djangoapps/discussion_api/serializers.py b/lms/djangoapps/discussion_api/serializers.py index 4ee207e16f..8dc68fe1de 100644 --- a/lms/djangoapps/discussion_api/serializers.py +++ b/lms/djangoapps/discussion_api/serializers.py @@ -10,7 +10,11 @@ from django.core.urlresolvers import reverse from rest_framework import serializers -from discussion_api.permissions import get_editable_fields +from discussion_api.permissions import ( + NON_UPDATABLE_COMMENT_FIELDS, + NON_UPDATABLE_THREAD_FIELDS, + get_editable_fields, +) from discussion_api.render import render_body from django_comment_client.utils import is_comment_too_deep from django_comment_common.models import ( @@ -76,7 +80,7 @@ class _ContentSerializer(serializers.Serializer): vote_count = serializers.SerializerMethodField("get_vote_count") editable_fields = serializers.SerializerMethodField("get_editable_fields") - non_updatable_fields = () + non_updatable_fields = set() def __init__(self, *args, **kwargs): super(_ContentSerializer, self).__init__(*args, **kwargs) @@ -166,7 +170,7 @@ class ThreadSerializer(_ContentSerializer): """ course_id = serializers.CharField() topic_id = NonEmptyCharField(source="commentable_id") - group_id = serializers.IntegerField(read_only=True) + group_id = serializers.IntegerField(required=False) group_name = serializers.SerializerMethodField("get_group_name") type_ = serializers.ChoiceField( source="thread_type", @@ -182,7 +186,7 @@ class ThreadSerializer(_ContentSerializer): endorsed_comment_list_url = serializers.SerializerMethodField("get_endorsed_comment_list_url") non_endorsed_comment_list_url = serializers.SerializerMethodField("get_non_endorsed_comment_list_url") - non_updatable_fields = ("course_id",) + non_updatable_fields = NON_UPDATABLE_THREAD_FIELDS def __init__(self, *args, **kwargs): super(ThreadSerializer, self).__init__(*args, **kwargs) @@ -256,7 +260,7 @@ class CommentSerializer(_ContentSerializer): endorsed_at = serializers.SerializerMethodField("get_endorsed_at") children = serializers.SerializerMethodField("get_children") - non_updatable_fields = ("thread_id", "parent_id") + non_updatable_fields = NON_UPDATABLE_COMMENT_FIELDS def get_endorsed_by(self, obj): """ diff --git a/lms/djangoapps/discussion_api/tests/test_api.py b/lms/djangoapps/discussion_api/tests/test_api.py index 8286d2b0fa..619484e86c 100644 --- a/lms/djangoapps/discussion_api/tests/test_api.py +++ b/lms/djangoapps/discussion_api/tests/test_api.py @@ -1182,6 +1182,7 @@ class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase): self.get_comment_list(thread, endorsed=True, page=2, page_size=10) +@ddt.ddt class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase): """Tests for create_thread""" @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) @@ -1273,6 +1274,69 @@ class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC } ) + @ddt.data( + *itertools.product( + [ + FORUM_ROLE_ADMINISTRATOR, + FORUM_ROLE_MODERATOR, + FORUM_ROLE_COMMUNITY_TA, + FORUM_ROLE_STUDENT, + ], + [True, False], + [True, False], + ["no_group_set", "group_is_none", "group_is_set"], + ) + ) + @ddt.unpack + def test_group_id(self, role_name, course_is_cohorted, topic_is_cohorted, data_group_state): + """ + Tests whether the user has permission to create a thread with certain + group_id values. + + If there is no group, user cannot create a thread. + Else if group is None or set, and the course is not cohorted and/or the + role is a student, user can create a thread. + """ + + cohort_course = CourseFactory.create( + discussion_topics={"Test Topic": {"id": "test_topic"}}, + cohort_config={ + "cohorted": course_is_cohorted, + "cohorted_discussions": ["test_topic"] if topic_is_cohorted else [], + } + ) + CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id) + if course_is_cohorted: + cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user]) + role = Role.objects.create(name=role_name, course_id=cohort_course.id) + role.users = [self.user] + self.register_post_thread_response({}) + data = self.minimal_data.copy() + data["course_id"] = unicode(cohort_course.id) + if data_group_state == "group_is_none": + data["group_id"] = None + elif data_group_state == "group_is_set": + if course_is_cohorted: + data["group_id"] = cohort.id + 1 + else: + data["group_id"] = 1 # Set to any value since there is no cohort + expected_error = ( + data_group_state in ["group_is_none", "group_is_set"] and + (not course_is_cohorted or role_name == FORUM_ROLE_STUDENT) + ) + try: + create_thread(self.request, data) + self.assertFalse(expected_error) + actual_post_data = httpretty.last_request().parsed_body + if data_group_state == "group_is_set": + self.assertEqual(actual_post_data["group_id"], [str(data["group_id"])]) + elif data_group_state == "no_group_set" and course_is_cohorted and topic_is_cohorted: + self.assertEqual(actual_post_data["group_id"], [str(cohort.id)]) + else: + self.assertNotIn("group_id", actual_post_data) + except ValidationError: + self.assertTrue(expected_error) + def test_following(self): self.register_post_thread_response({"id": "test_id"}) self.register_subscription_response(self.user) @@ -1456,6 +1520,44 @@ class CreateCommentTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest self.assertEqual(actual_event_name, expected_event_name) self.assertEqual(actual_event_data, expected_event_data) + @ddt.data( + *itertools.product( + [ + FORUM_ROLE_ADMINISTRATOR, + FORUM_ROLE_MODERATOR, + FORUM_ROLE_COMMUNITY_TA, + FORUM_ROLE_STUDENT, + ], + [True, False], + ["question", "discussion"], + ) + ) + @ddt.unpack + def test_endorsed(self, role_name, is_thread_author, thread_type): + role = Role.objects.create(name=role_name, course_id=self.course.id) + role.users = [self.user] + self.register_get_thread_response( + make_minimal_cs_thread({ + "id": "test_thread", + "course_id": unicode(self.course.id), + "thread_type": thread_type, + "user_id": str(self.user.id) if is_thread_author else str(self.user.id + 1), + }) + ) + self.register_post_comment_response({}, "test_thread") + data = self.minimal_data.copy() + data["endorsed"] = True + expected_error = ( + role_name == FORUM_ROLE_STUDENT and + (not is_thread_author or thread_type == "discussion") + ) + try: + create_comment(self.request, data) + self.assertEqual(httpretty.last_request().parsed_body["endorsed"], ["True"]) + self.assertFalse(expected_error) + except ValidationError: + self.assertTrue(expected_error) + def test_voted(self): self.register_post_comment_response({"id": "test_comment"}, "test_thread") self.register_comment_votes_response("test_comment") diff --git a/lms/djangoapps/discussion_api/tests/test_permissions.py b/lms/djangoapps/discussion_api/tests/test_permissions.py index 0de1c13d54..78f0df9606 100644 --- a/lms/djangoapps/discussion_api/tests/test_permissions.py +++ b/lms/djangoapps/discussion_api/tests/test_permissions.py @@ -2,37 +2,86 @@ Tests for discussion API permission logic """ import itertools -from unittest import TestCase import ddt -from discussion_api.permissions import can_delete, get_editable_fields +from discussion_api.permissions import ( + can_delete, + get_editable_fields, + get_initializable_comment_fields, + get_initializable_thread_fields, +) from lms.lib.comment_client.comment import Comment from lms.lib.comment_client.thread import Thread from lms.lib.comment_client.user import User +from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase +from xmodule.modulestore.tests.factories import CourseFactory -def _get_context(requester_id, is_requester_privileged, thread=None): +def _get_context(requester_id, is_requester_privileged, is_cohorted=False, thread=None): """Return a context suitable for testing the permissions module""" return { "cc_requester": User(id=requester_id), "is_requester_privileged": is_requester_privileged, + "course": CourseFactory(cohort_config={"cohorted": is_cohorted}), "thread": thread, } @ddt.ddt -class GetEditableFieldsTest(TestCase): - """Tests for get_editable_fields""" +class GetInitializableFieldsTest(ModuleStoreTestCase): + """Tests for get_*_initializable_fields""" @ddt.data(*itertools.product([True, False], [True, False])) @ddt.unpack - def test_thread(self, is_author, is_privileged): + def test_thread(self, is_privileged, is_cohorted): + context = _get_context( + requester_id="5", + is_requester_privileged=is_privileged, + is_cohorted=is_cohorted + ) + actual = get_initializable_thread_fields(context) + expected = { + "abuse_flagged", "course_id", "following", "raw_body", "title", "topic_id", "type", "voted" + } + if is_privileged and is_cohorted: + expected |= {"group_id"} + self.assertEqual(actual, expected) + + @ddt.data(*itertools.product([True, False], ["question", "discussion"], [True, False])) + @ddt.unpack + def test_comment(self, is_thread_author, thread_type, is_privileged): + context = _get_context( + requester_id="5", + is_requester_privileged=is_privileged, + thread=Thread(user_id="5" if is_thread_author else "6", thread_type=thread_type) + ) + actual = get_initializable_comment_fields(context) + expected = { + "abuse_flagged", "parent_id", "raw_body", "thread_id", "voted" + } + if (is_thread_author and thread_type == "question") or is_privileged: + expected |= {"endorsed"} + self.assertEqual(actual, expected) + + +@ddt.ddt +class GetEditableFieldsTest(ModuleStoreTestCase): + """Tests for get_editable_fields""" + @ddt.data(*itertools.product([True, False], [True, False], [True, False])) + @ddt.unpack + def test_thread(self, is_author, is_privileged, is_cohorted): thread = Thread(user_id="5" if is_author else "6", type="thread") - context = _get_context(requester_id="5", is_requester_privileged=is_privileged) + context = _get_context( + requester_id="5", + is_requester_privileged=is_privileged, + is_cohorted=is_cohorted + ) actual = get_editable_fields(thread, context) expected = {"abuse_flagged", "following", "voted"} if is_author or is_privileged: expected |= {"topic_id", "type", "title", "raw_body"} + if is_privileged and is_cohorted: + expected |= {"group_id"} self.assertEqual(actual, expected) @ddt.data(*itertools.product([True, False], [True, False], ["question", "discussion"], [True, False])) @@ -54,7 +103,7 @@ class GetEditableFieldsTest(TestCase): @ddt.ddt -class CanDeleteTest(TestCase): +class CanDeleteTest(ModuleStoreTestCase): """Tests for can_delete""" @ddt.data(*itertools.product([True, False], [True, False])) @ddt.unpack diff --git a/lms/djangoapps/discussion_api/tests/test_serializers.py b/lms/djangoapps/discussion_api/tests/test_serializers.py index 7992715a30..cae90175d9 100644 --- a/lms/djangoapps/discussion_api/tests/test_serializers.py +++ b/lms/djangoapps/discussion_api/tests/test_serializers.py @@ -462,6 +462,24 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi ) self.assertEqual(saved["id"], "test_id") + def test_create_all_fields(self): + self.register_post_thread_response({"id": "test_id"}) + data = self.minimal_data.copy() + data["group_id"] = 42 + self.save_and_reserialize(data) + self.assertEqual( + httpretty.last_request().parsed_body, + { + "course_id": [unicode(self.course.id)], + "commentable_id": ["test_topic"], + "thread_type": ["discussion"], + "title": ["Test Title"], + "body": ["Test body"], + "user_id": [str(self.user.id)], + "group_id": ["42"], + } + ) + def test_create_missing_field(self): for field in self.minimal_data: data = self.minimal_data.copy() @@ -638,6 +656,27 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore self.assertEqual(saved["id"], "test_comment") self.assertEqual(saved["parent_id"], parent_id) + def test_create_all_fields(self): + data = self.minimal_data.copy() + data["parent_id"] = "test_parent" + data["endorsed"] = True + self.register_get_comment_response({"thread_id": "test_thread", "id": "test_parent"}) + self.register_post_comment_response( + {"id": "test_comment"}, + thread_id="test_thread", + parent_id="test_parent" + ) + self.save_and_reserialize(data) + self.assertEqual( + httpretty.last_request().parsed_body, + { + "course_id": [unicode(self.course.id)], + "body": ["Test body"], + "user_id": [str(self.user.id)], + "endorsed": ["True"], + } + ) + def test_create_parent_id_nonexistent(self): self.register_get_comment_error_response("bad_parent", 404) data = self.minimal_data.copy()