diff --git a/lms/djangoapps/discussion_api/api.py b/lms/djangoapps/discussion_api/api.py index 597f08fd80..c65a917cf6 100644 --- a/lms/djangoapps/discussion_api/api.py +++ b/lms/djangoapps/discussion_api/api.py @@ -1,6 +1,7 @@ """ Discussion API internal interface """ +from collections import defaultdict from urllib import urlencode from urlparse import urlunparse @@ -8,11 +9,10 @@ from django.core.exceptions import ValidationError from django.core.urlresolvers import reverse from django.http import Http404 - -from collections import defaultdict +from rest_framework.exceptions import PermissionDenied from opaque_keys import InvalidKeyError -from opaque_keys.edx.locator import CourseLocator +from opaque_keys.edx.locator import CourseKey from courseware.courses import get_course_with_access from discussion_api.forms import ThreadCreateExtrasForm @@ -28,7 +28,7 @@ from django_comment_client.base.views import ( from django_comment_client.utils import get_accessible_discussion_modules from lms.lib.comment_client.thread import Thread from lms.lib.comment_client.utils import CommentClientRequestError -from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id +from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id, is_commentable_cohorted def _get_course_or_404(course_key, user): @@ -43,6 +43,37 @@ def _get_course_or_404(course_key, user): return course +def _get_thread_and_context(request, thread_id, parent_id=None, retrieve_kwargs=None): + """ + Retrieve the given thread and build a serializer context for it, returning + both. This function also enforces access control for the thread (checking + both the user's access to the course and to the thread's cohort if + applicable). Raises Http404 if the thread does not exist or the user cannot + access it. + """ + retrieve_kwargs = retrieve_kwargs or {} + try: + if "mark_as_read" not in retrieve_kwargs: + retrieve_kwargs["mark_as_read"] = False + cc_thread = Thread(id=thread_id).retrieve(**retrieve_kwargs) + course_key = CourseKey.from_string(cc_thread["course_id"]) + course = _get_course_or_404(course_key, request.user) + context = get_context(course, request, cc_thread, parent_id) + if ( + not context["is_requester_privileged"] and + cc_thread["group_id"] and + is_commentable_cohorted(course.id, cc_thread["commentable_id"]) + ): + requester_cohort = get_cohort_id(request.user, course.id) + if requester_cohort is not None and cc_thread["group_id"] != requester_cohort: + raise Http404 + return cc_thread, context + except CommentClientRequestError: + # params are validated at a higher level, so the only possible request + # error is if the thread doesn't exist + raise Http404 + + def get_thread_list_url(request, course_key, topic_id_list): """ Returns the URL for the thread_list_url field, given a list of topic_ids @@ -191,28 +222,17 @@ def get_comment_list(request, thread_id, endorsed, page, page_size): discussion_api.views.CommentViewSet for more detail. """ response_skip = page_size * (page - 1) - try: - cc_thread = Thread(id=thread_id).retrieve( - recursive=True, - user_id=request.user.id, - mark_as_read=True, - response_skip=response_skip, - response_limit=page_size - ) - except CommentClientRequestError: - # page and page_size are validated at a higher level, so the only - # possible request error is if the thread doesn't exist - raise Http404 - - course_key = CourseLocator.from_string(cc_thread["course_id"]) - course = _get_course_or_404(course_key, request.user) - context = get_context(course, request, cc_thread) - - # Ensure user has access to the thread - if not context["is_requester_privileged"] and cc_thread["group_id"]: - requester_cohort = get_cohort_id(request.user, course_key) - if requester_cohort is not None and cc_thread["group_id"] != requester_cohort: - raise Http404 + cc_thread, context = _get_thread_and_context( + request, + thread_id, + retrieve_kwargs={ + "recursive": True, + "user_id": request.user.id, + "mark_as_read": True, + "response_skip": response_skip, + "response_limit": page_size, + } + ) # Responses to discussion threads cannot be separated by endorsed, but # responses to question threads must be separated by endorsed due to the @@ -267,7 +287,7 @@ def create_thread(request, thread_data): if not course_id: raise ValidationError({"course_id": ["This field is required."]}) try: - course_key = CourseLocator.from_string(course_id) + course_key = CourseKey.from_string(course_id) course = _get_course_or_404(course_key, request.user) except (Http404, InvalidKeyError): raise ValidationError({"course_id": ["Invalid value."]}) @@ -314,29 +334,55 @@ def create_comment(request, comment_data): detail. """ thread_id = comment_data.get("thread_id") + parent_id = comment_data.get("parent_id") if not thread_id: raise ValidationError({"thread_id": ["This field is required."]}) try: - thread = Thread(id=thread_id).retrieve(mark_as_read=False) - course_key = CourseLocator.from_string(thread["course_id"]) - course = _get_course_or_404(course_key, request.user) - except (Http404, CommentClientRequestError): + cc_thread, context = _get_thread_and_context(request, thread_id, parent_id) + except Http404: raise ValidationError({"thread_id": ["Invalid value."]}) - parent_id = comment_data.get("parent_id") - context = get_context(course, request, thread, parent_id) serializer = CommentSerializer(data=comment_data, context=context) if not serializer.is_valid(): raise ValidationError(serializer.errors) serializer.save() - comment = serializer.object + cc_comment = serializer.object track_forum_event( request, - get_comment_created_event_name(comment), - course, - comment, - get_comment_created_event_data(comment, thread["commentable_id"], followed=False) + get_comment_created_event_name(cc_comment), + context["course"], + cc_comment, + get_comment_created_event_data(cc_comment, cc_thread["commentable_id"], followed=False) ) return serializer.data + + +def update_thread(request, thread_id, update_data): + """ + Update a thread. + + Parameters: + + request: The django request object used for build_absolute_uri and + determining the requesting user. + + thread_id: The id for the thread to update. + + update_data: The data to update in the thread. + + Returns: + + The updated thread; see discussion_api.views.ThreadViewSet for more + detail. + """ + cc_thread, context = _get_thread_and_context(request, thread_id) + is_author = str(request.user.id) == cc_thread["user_id"] + if not (context["is_requester_privileged"] or is_author): + raise PermissionDenied() + serializer = ThreadSerializer(cc_thread, data=update_data, partial=True, context=context) + if not serializer.is_valid(): + raise ValidationError(serializer.errors) + serializer.save() + return serializer.data diff --git a/lms/djangoapps/discussion_api/serializers.py b/lms/djangoapps/discussion_api/serializers.py index d524d44fd3..1dc137565c 100644 --- a/lms/djangoapps/discussion_api/serializers.py +++ b/lms/djangoapps/discussion_api/serializers.py @@ -21,6 +21,7 @@ from lms.lib.comment_client.thread import Thread from lms.lib.comment_client.user import User as CommentClientUser from lms.lib.comment_client.utils import CommentClientRequestError from openedx.core.djangoapps.course_groups.cohorts import get_cohort_names +from openedx.core.lib.api.fields import NonEmptyCharField def get_context(course, request, thread=None, parent_id=None): @@ -44,15 +45,16 @@ def get_context(course, request, thread=None, parent_id=None): } requester = request.user return { - # For now, the only groups are cohorts + "course": course, "request": request, + "thread": thread, + "parent_id": parent_id, + # For now, the only groups are cohorts "group_ids_to_names": get_cohort_names(course), "is_requester_privileged": requester.id in staff_user_ids or requester.id in ta_user_ids, "staff_user_ids": staff_user_ids, "ta_user_ids": ta_user_ids, "cc_requester": CommentClientUser.from_django_user(requester).retrieve(), - "thread": thread, - "parent_id": parent_id, } @@ -63,7 +65,7 @@ class _ContentSerializer(serializers.Serializer): author_label = serializers.SerializerMethodField("get_author_label") created_at = serializers.CharField(read_only=True) updated_at = serializers.CharField(read_only=True) - raw_body = serializers.CharField(source="body") + raw_body = NonEmptyCharField(source="body") abuse_flagged = serializers.SerializerMethodField("get_abuse_flagged") voted = serializers.SerializerMethodField("get_voted") vote_count = serializers.SerializerMethodField("get_vote_count") @@ -138,14 +140,14 @@ class ThreadSerializer(_ContentSerializer): at introspection and Thread's __getattr__. """ course_id = serializers.CharField() - topic_id = serializers.CharField(source="commentable_id") + topic_id = NonEmptyCharField(source="commentable_id") group_id = serializers.IntegerField(read_only=True) group_name = serializers.SerializerMethodField("get_group_name") type_ = serializers.ChoiceField( source="thread_type", choices=[(val, val) for val in ["discussion", "question"]] ) - title = serializers.CharField() + title = NonEmptyCharField() pinned = serializers.BooleanField(read_only=True) closed = serializers.BooleanField(read_only=True) following = serializers.SerializerMethodField("get_following") @@ -198,10 +200,19 @@ class ThreadSerializer(_ContentSerializer): """Returns the URL to retrieve the thread's non-endorsed comments.""" return self.get_comment_list_url(obj, endorsed=False) + def validate_course_id(self, attrs, _source): + """Ensure that course_id is not edited in an update operation.""" + if self.object: + raise ValidationError("This field is not allowed in an update.") + return attrs + def restore_object(self, attrs, instance=None): if instance: - raise ValueError("ThreadSerializer cannot be used for updates.") - return Thread(user_id=self.context["cc_requester"]["id"], **attrs) + for key, val in attrs.items(): + instance[key] = val + return instance + else: + return Thread(user_id=self.context["cc_requester"]["id"], **attrs) class CommentSerializer(_ContentSerializer): diff --git a/lms/djangoapps/discussion_api/tests/test_api.py b/lms/djangoapps/discussion_api/tests/test_api.py index 96843a70f5..535abf592f 100644 --- a/lms/djangoapps/discussion_api/tests/test_api.py +++ b/lms/djangoapps/discussion_api/tests/test_api.py @@ -15,6 +15,8 @@ from django.core.exceptions import ValidationError from django.http import Http404 from django.test.client import RequestFactory +from rest_framework.exceptions import PermissionDenied + from opaque_keys.edx.locator import CourseLocator from courseware.tests.factories import BetaTesterFactory, StaffFactory @@ -24,6 +26,7 @@ from discussion_api.api import ( get_comment_list, get_course_topics, get_thread_list, + update_thread, ) from discussion_api.tests.utils import ( CommentsServiceMockMixin, @@ -685,18 +688,32 @@ class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase): FORUM_ROLE_STUDENT, ], [True, False], + [True, False], ["no_group", "match_group", "different_group"], ) ) @ddt.unpack - def test_group_access(self, role_name, course_is_cohorted, thread_group_state): - cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted}) + def test_group_access( + self, + role_name, + course_is_cohorted, + topic_is_cohorted, + thread_group_state + ): + cohort_course = CourseFactory.create( + discussion_topics={"Test Topic": {"id": "test_topic"}}, + cohort_config={ + "cohorted": course_is_cohorted, + "cohorted_discussions": ["test_topic"] if topic_is_cohorted else [], + } + ) CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id) cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user]) role = Role.objects.create(name=role_name, course_id=cohort_course.id) role.users = [self.user] thread = self.make_minimal_cs_thread({ "course_id": unicode(cohort_course.id), + "commentable_id": "test_topic", "group_id": ( None if thread_group_state == "no_group" else cohort.id if thread_group_state == "match_group" else @@ -706,6 +723,7 @@ class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase): expected_error = ( role_name == FORUM_ROLE_STUDENT and course_is_cohorted and + topic_is_cohorted and thread_group_state == "different_group" ) try: @@ -1287,8 +1305,224 @@ class CreateCommentTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest create_comment(self.request, self.minimal_data) self.assertEqual(assertion.exception.message_dict, {"thread_id": ["Invalid value."]}) + @ddt.data( + *itertools.product( + [ + FORUM_ROLE_ADMINISTRATOR, + FORUM_ROLE_MODERATOR, + FORUM_ROLE_COMMUNITY_TA, + FORUM_ROLE_STUDENT, + ], + [True, False], + ["no_group", "match_group", "different_group"], + ) + ) + @ddt.unpack + def test_group_access(self, role_name, course_is_cohorted, thread_group_state): + cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted}) + CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id) + cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user]) + role = Role.objects.create(name=role_name, course_id=cohort_course.id) + role.users = [self.user] + self.register_get_thread_response(make_minimal_cs_thread({ + "id": "cohort_thread", + "course_id": unicode(cohort_course.id), + "group_id": ( + None if thread_group_state == "no_group" else + cohort.id if thread_group_state == "match_group" else + cohort.id + 1 + ), + })) + self.register_post_comment_response({}, thread_id="cohort_thread") + data = self.minimal_data.copy() + data["thread_id"] = "cohort_thread" + expected_error = ( + role_name == FORUM_ROLE_STUDENT and + course_is_cohorted and + thread_group_state == "different_group" + ) + try: + create_comment(self.request, data) + self.assertFalse(expected_error) + except ValidationError as err: + self.assertTrue(expected_error) + self.assertEqual( + err.message_dict, + {"thread_id": ["Invalid value."]} + ) + def test_invalid_field(self): data = self.minimal_data.copy() del data["raw_body"] with self.assertRaises(ValidationError): create_comment(self.request, data) + + +@ddt.ddt +class UpdateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase): + """Tests for update_thread""" + @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + def setUp(self): + super(UpdateThreadTest, self).setUp() + httpretty.reset() + httpretty.enable() + self.addCleanup(httpretty.disable) + self.user = UserFactory.create() + self.register_get_user_response(self.user) + self.request = RequestFactory().get("/test_path") + self.request.user = self.user + self.course = CourseFactory.create() + CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) + + def register_thread(self, overrides=None): + """ + Make a thread with appropriate data overridden by the overrides + parameter and register mock responses for both GET and PUT on its + endpoint. + """ + cs_data = make_minimal_cs_thread({ + "id": "test_thread", + "course_id": unicode(self.course.id), + "commentable_id": "original_topic", + "username": self.user.username, + "user_id": str(self.user.id), + "created_at": "2015-05-29T00:00:00Z", + "updated_at": "2015-05-29T00:00:00Z", + "type": "discussion", + "title": "Original Title", + "body": "Original body", + }) + cs_data.update(overrides or {}) + self.register_get_thread_response(cs_data) + self.register_put_thread_response(cs_data) + + def test_basic(self): + self.register_thread() + actual = update_thread(self.request, "test_thread", {"raw_body": "Edited body"}) + expected = { + "id": "test_thread", + "course_id": unicode(self.course.id), + "topic_id": "original_topic", + "group_id": None, + "group_name": None, + "author": self.user.username, + "author_label": None, + "created_at": "2015-05-29T00:00:00Z", + "updated_at": "2015-05-29T00:00:00Z", + "type": "discussion", + "title": "Original Title", + "raw_body": "Edited body", + "pinned": False, + "closed": False, + "following": False, + "abuse_flagged": False, + "voted": False, + "vote_count": 0, + "comment_count": 0, + "unread_comment_count": 0, + "comment_list_url": "http://testserver/api/discussion/v1/comments/?thread_id=test_thread", + "endorsed_comment_list_url": None, + "non_endorsed_comment_list_url": None, + } + self.assertEqual(actual, expected) + self.assertEqual( + httpretty.last_request().parsed_body, + { + "course_id": [unicode(self.course.id)], + "commentable_id": ["original_topic"], + "thread_type": ["discussion"], + "title": ["Original Title"], + "body": ["Edited body"], + "user_id": [str(self.user.id)], + "anonymous": ["False"], + "anonymous_to_peers": ["False"], + "closed": ["False"], + "pinned": ["False"], + } + ) + + def test_nonexistent_thread(self): + self.register_get_thread_error_response("test_thread", 404) + with self.assertRaises(Http404): + update_thread(self.request, "test_thread", {}) + + def test_nonexistent_course(self): + self.register_thread({"course_id": "non/existent/course"}) + with self.assertRaises(Http404): + update_thread(self.request, "test_thread", {}) + + def test_unenrolled(self): + self.register_thread() + self.request.user = UserFactory.create() + with self.assertRaises(Http404): + update_thread(self.request, "test_thread", {}) + + def test_discussions_disabled(self): + _remove_discussion_tab(self.course, self.user.id) + self.register_thread() + with self.assertRaises(Http404): + update_thread(self.request, "test_thread", {}) + + @ddt.data( + *itertools.product( + [ + FORUM_ROLE_ADMINISTRATOR, + FORUM_ROLE_MODERATOR, + FORUM_ROLE_COMMUNITY_TA, + FORUM_ROLE_STUDENT, + ], + [True, False], + ["no_group", "match_group", "different_group"], + ) + ) + @ddt.unpack + def test_group_access(self, role_name, course_is_cohorted, thread_group_state): + cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted}) + CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id) + cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user]) + role = Role.objects.create(name=role_name, course_id=cohort_course.id) + role.users = [self.user] + self.register_thread({ + "course_id": unicode(cohort_course.id), + "group_id": ( + None if thread_group_state == "no_group" else + cohort.id if thread_group_state == "match_group" else + cohort.id + 1 + ), + }) + expected_error = ( + role_name == FORUM_ROLE_STUDENT and + course_is_cohorted and + thread_group_state == "different_group" + ) + try: + update_thread(self.request, "test_thread", {}) + self.assertFalse(expected_error) + except Http404: + self.assertTrue(expected_error) + + @ddt.data( + FORUM_ROLE_ADMINISTRATOR, + FORUM_ROLE_MODERATOR, + FORUM_ROLE_COMMUNITY_TA, + FORUM_ROLE_STUDENT, + ) + def test_non_author_access(self, role_name): + role = Role.objects.create(name=role_name, course_id=self.course.id) + role.users = [self.user] + self.register_thread({"user_id": str(self.user.id + 1)}) + expected_error = role_name == FORUM_ROLE_STUDENT + try: + update_thread(self.request, "test_thread", {}) + self.assertFalse(expected_error) + except PermissionDenied: + self.assertTrue(expected_error) + + def test_invalid_field(self): + self.register_thread() + with self.assertRaises(ValidationError) as assertion: + update_thread(self.request, "test_thread", {"raw_body": ""}) + self.assertEqual( + assertion.exception.message_dict, + {"raw_body": ["This field is required."]} + ) diff --git a/lms/djangoapps/discussion_api/tests/test_serializers.py b/lms/djangoapps/discussion_api/tests/test_serializers.py index 106273b3bd..b160658060 100644 --- a/lms/djangoapps/discussion_api/tests/test_serializers.py +++ b/lms/djangoapps/discussion_api/tests/test_serializers.py @@ -23,6 +23,7 @@ from django_comment_common.models import ( FORUM_ROLE_STUDENT, Role, ) +from lms.lib.comment_client.thread import Thread from student.tests.factories import UserFactory from util.testing import UrlResetMixin from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase @@ -378,7 +379,6 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi httpretty.reset() httpretty.enable() self.addCleanup(httpretty.disable) - self.register_post_thread_response({"id": "test_id"}) self.course = CourseFactory.create() self.user = UserFactory.create() self.register_get_user_response(self.user) @@ -391,18 +391,34 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi "title": "Test Title", "raw_body": "Test body", } + self.existing_thread = Thread(**make_minimal_cs_thread({ + "id": "existing_thread", + "course_id": unicode(self.course.id), + "commentable_id": "original_topic", + "thread_type": "discussion", + "title": "Original Title", + "body": "Original body", + "user_id": str(self.user.id), + })) - def save_and_reserialize(self, data): + def save_and_reserialize(self, data, instance=None): """ - Create a serializer with the given data, ensure that it is valid, save - the result, and return the full thread data from the serializer. + Create a serializer with the given data and (if updating) instance, + ensure that it is valid, save the result, and return the full thread + data from the serializer. """ - serializer = ThreadSerializer(data=data, context=get_context(self.course, self.request)) + serializer = ThreadSerializer( + instance, + data=data, + partial=(instance is not None), + context=get_context(self.course, self.request) + ) self.assertTrue(serializer.is_valid()) serializer.save() return serializer.data - def test_minimal(self): + def test_create_minimal(self): + self.register_post_thread_response({"id": "test_id"}) saved = self.save_and_reserialize(self.minimal_data) self.assertEqual( urlparse(httpretty.last_request().path).path, @@ -421,7 +437,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi ) self.assertEqual(saved["id"], "test_id") - def test_missing_field(self): + def test_create_missing_field(self): for field in self.minimal_data: data = self.minimal_data.copy() data.pop(field) @@ -432,7 +448,8 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi {field: ["This field is required."]} ) - def test_type(self): + def test_create_type(self): + self.register_post_thread_response({"id": "test_id"}) data = self.minimal_data.copy() data["type"] = "question" self.save_and_reserialize(data) @@ -441,6 +458,76 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi serializer = ThreadSerializer(data=data) self.assertFalse(serializer.is_valid()) + def test_update_empty(self): + self.register_put_thread_response(self.existing_thread.attributes) + self.save_and_reserialize({}, self.existing_thread) + self.assertEqual( + httpretty.last_request().parsed_body, + { + "course_id": [unicode(self.course.id)], + "commentable_id": ["original_topic"], + "thread_type": ["discussion"], + "title": ["Original Title"], + "body": ["Original body"], + "anonymous": ["False"], + "anonymous_to_peers": ["False"], + "closed": ["False"], + "pinned": ["False"], + "user_id": [str(self.user.id)], + } + ) + + def test_update_all(self): + self.register_put_thread_response(self.existing_thread.attributes) + data = { + "topic_id": "edited_topic", + "type": "question", + "title": "Edited Title", + "raw_body": "Edited body", + } + saved = self.save_and_reserialize(data, self.existing_thread) + self.assertEqual( + httpretty.last_request().parsed_body, + { + "course_id": [unicode(self.course.id)], + "commentable_id": ["edited_topic"], + "thread_type": ["question"], + "title": ["Edited Title"], + "body": ["Edited body"], + "anonymous": ["False"], + "anonymous_to_peers": ["False"], + "closed": ["False"], + "pinned": ["False"], + "user_id": [str(self.user.id)], + } + ) + for key in data: + self.assertEqual(saved[key], data[key]) + + def test_update_empty_string(self): + serializer = ThreadSerializer( + self.existing_thread, + data={field: "" for field in ["topic_id", "title", "raw_body"]}, + partial=True, + context=get_context(self.course, self.request) + ) + self.assertEqual( + serializer.errors, + {field: ["This field is required."] for field in ["topic_id", "title", "raw_body"]} + ) + + def test_update_course_id(self): + serializer = ThreadSerializer( + self.existing_thread, + data={"course_id": "some/other/course"}, + partial=True, + context=get_context(self.course, self.request) + ) + self.assertEqual( + serializer.errors, + {"course_id": ["This field is not allowed in an update."]} + ) + @ddt.ddt class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStoreTestCase): diff --git a/lms/djangoapps/discussion_api/tests/test_views.py b/lms/djangoapps/discussion_api/tests/test_views.py index 4bdc87193e..c5a16ee39d 100644 --- a/lms/djangoapps/discussion_api/tests/test_views.py +++ b/lms/djangoapps/discussion_api/tests/test_views.py @@ -11,6 +11,8 @@ from pytz import UTC from django.core.urlresolvers import reverse +from rest_framework.test import APIClient + from discussion_api.tests.utils import CommentsServiceMockMixin, make_minimal_cs_thread from student.tests.factories import CourseEnrollmentFactory, UserFactory from util.testing import UrlResetMixin @@ -25,6 +27,8 @@ class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin): in the test client, utility functions, and a test case for unauthenticated requests. Subclasses must set self.url in their setUp methods. """ + client_class = APIClient + @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) def setUp(self): super(DiscussionAPIViewTestMixin, self).setUp() @@ -294,6 +298,101 @@ class ThreadViewSetCreateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): self.assertEqual(response_data, expected_response_data) +@httpretty.activate +class ThreadViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): + """Tests for ThreadViewSet partial_update""" + def setUp(self): + super(ThreadViewSetPartialUpdateTest, self).setUp() + self.url = reverse("thread-detail", kwargs={"thread_id": "test_thread"}) + + def test_basic(self): + self.register_get_user_response(self.user) + cs_thread = make_minimal_cs_thread({ + "id": "test_thread", + "course_id": unicode(self.course.id), + "commentable_id": "original_topic", + "username": self.user.username, + "user_id": str(self.user.id), + "created_at": "2015-05-29T00:00:00Z", + "updated_at": "2015-05-29T00:00:00Z", + "thread_type": "discussion", + "title": "Original Title", + "body": "Original body", + }) + self.register_get_thread_response(cs_thread) + self.register_put_thread_response(cs_thread) + request_data = {"raw_body": "Edited body"} + expected_response_data = { + "id": "test_thread", + "course_id": unicode(self.course.id), + "topic_id": "original_topic", + "group_id": None, + "group_name": None, + "author": self.user.username, + "author_label": None, + "created_at": "2015-05-29T00:00:00Z", + "updated_at": "2015-05-29T00:00:00Z", + "type": "discussion", + "title": "Original Title", + "raw_body": "Edited body", + "pinned": False, + "closed": False, + "following": False, + "abuse_flagged": False, + "voted": False, + "vote_count": 0, + "comment_count": 0, + "unread_comment_count": 0, + "comment_list_url": "http://testserver/api/discussion/v1/comments/?thread_id=test_thread", + "endorsed_comment_list_url": None, + "non_endorsed_comment_list_url": None, + } + response = self.client.patch( # pylint: disable=no-member + self.url, + json.dumps(request_data), + content_type="application/json" + ) + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.content) + self.assertEqual(response_data, expected_response_data) + self.assertEqual( + httpretty.last_request().parsed_body, + { + "course_id": [unicode(self.course.id)], + "commentable_id": ["original_topic"], + "thread_type": ["discussion"], + "title": ["Original Title"], + "body": ["Edited body"], + "user_id": [str(self.user.id)], + "anonymous": ["False"], + "anonymous_to_peers": ["False"], + "closed": ["False"], + "pinned": ["False"], + } + ) + + def test_error(self): + self.register_get_user_response(self.user) + cs_thread = make_minimal_cs_thread({ + "id": "test_thread", + "course_id": unicode(self.course.id), + "user_id": str(self.user.id), + }) + self.register_get_thread_response(cs_thread) + request_data = {"title": ""} + response = self.client.patch( # pylint: disable=no-member + self.url, + json.dumps(request_data), + content_type="application/json" + ) + expected_response_data = { + "field_errors": {"title": {"developer_message": "This field is required."}} + } + self.assertEqual(response.status_code, 400) + response_data = json.loads(response.content) + self.assertEqual(response_data, expected_response_data) + + @httpretty.activate class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): """Tests for CommentViewSet list""" diff --git a/lms/djangoapps/discussion_api/tests/utils.py b/lms/djangoapps/discussion_api/tests/utils.py index 726eaf20d1..e534bf1553 100644 --- a/lms/djangoapps/discussion_api/tests/utils.py +++ b/lms/djangoapps/discussion_api/tests/utils.py @@ -7,6 +7,29 @@ import re import httpretty +def _get_thread_callback(thread_data): + """ + Get a callback function that will return POST/PUT data overridden by + response_overrides. + """ + def callback(request, _uri, headers): + """ + Simulate the thread creation or update endpoint by returning the provided + data along with the data from response_overrides and dummy values for any + additional required fields. + """ + response_data = make_minimal_cs_thread(thread_data) + for key, val_list in request.parsed_body.items(): + val = val_list[0] + if key in ["anonymous", "anonymous_to_peers", "closed", "pinned"]: + response_data[key] = val == "True" + else: + response_data[key] = val + return (200, headers, json.dumps(response_data)) + + return callback + + class CommentsServiceMockMixin(object): """Mixin with utility methods for mocking the comments service""" def register_get_threads_response(self, threads, page, num_pages): @@ -22,23 +45,23 @@ class CommentsServiceMockMixin(object): status=200 ) - def register_post_thread_response(self, response_overrides): + def register_post_thread_response(self, thread_data): """Register a mock response for POST on the CS commentable endpoint""" - def callback(request, _uri, headers): - """ - Simulate the thread creation endpoint by returning the provided data - along with the data from response_overrides. - """ - response_data = make_minimal_cs_thread( - {key: val[0] for key, val in request.parsed_body.items()} - ) - response_data.update(response_overrides) - return (200, headers, json.dumps(response_data)) - httpretty.register_uri( httpretty.POST, re.compile(r"http://localhost:4567/api/v1/(\w+)/threads"), - body=callback + body=_get_thread_callback(thread_data) + ) + + def register_put_thread_response(self, thread_data): + """ + Register a mock response for PUT on the CS endpoint for the given + thread_id. + """ + httpretty.register_uri( + httpretty.PUT, + "http://localhost:4567/api/v1/threads/{}".format(thread_data["id"]), + body=_get_thread_callback(thread_data) ) def register_get_thread_error_response(self, thread_id, status_code): diff --git a/lms/djangoapps/discussion_api/views.py b/lms/djangoapps/discussion_api/views.py index a1978f0aa5..d2cc19f6b7 100644 --- a/lms/djangoapps/discussion_api/views.py +++ b/lms/djangoapps/discussion_api/views.py @@ -17,6 +17,7 @@ from discussion_api.api import ( get_comment_list, get_course_topics, get_thread_list, + update_thread, ) from discussion_api.forms import CommentListGetForm, ThreadListGetForm from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin @@ -67,7 +68,8 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): """ **Use Cases** - Retrieve the list of threads for a course or post a new thread. + Retrieve the list of threads for a course, post a new thread, or modify + an existing thread. **Example Requests**: @@ -82,6 +84,9 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): "body": "Body text" } + PATCH /api/discussion/v1/threads/thread_id + {"raw_body": "Edited text"} + **GET Parameters**: * course_id (required): The course to retrieve threads for @@ -109,16 +114,21 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): * following (optional): A boolean indicating whether the user should follow the thread upon its creation; defaults to false + **PATCH Parameters**: + + topic_id, type, title, and raw_body are accepted with the same meaning + as in a POST request + **GET Response Values**: * results: The list of threads; each item in the list has the same - fields as the POST response below + fields as the POST/PATCH response below * next: The URL of the next page (or null if first page) * previous: The URL of the previous page (or null if last page) - **POST response values**: + **POST/PATCH response values**: * id: The id of the thread @@ -148,6 +158,8 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): the thread """ + lookup_field = "thread_id" + def list(self, request): """ Implements the GET method for the list endpoint as described in the @@ -173,6 +185,13 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): """ return Response(create_thread(request, request.DATA)) + def partial_update(self, request, thread_id): + """ + Implements the PATCH method for the instance endpoint as described in + the class docstring. + """ + return Response(update_thread(request, thread_id, request.DATA)) + class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): """ diff --git a/openedx/core/lib/api/fields.py b/openedx/core/lib/api/fields.py index 17d2de0b9a..eaec6243ca 100644 --- a/openedx/core/lib/api/fields.py +++ b/openedx/core/lib/api/fields.py @@ -1,6 +1,7 @@ """Fields useful for edX API implementations.""" +from django.core.exceptions import ValidationError -from rest_framework.serializers import Field +from rest_framework.serializers import CharField, Field class ExpandableField(Field): @@ -20,3 +21,17 @@ class ExpandableField(Field): else: self.collapsed.initialize(self, field_name) return self.collapsed.field_to_native(obj, field_name) + + +class NonEmptyCharField(CharField): + """ + A field that enforces non-emptiness even for partial updates. + + This is necessary because prior to version 3, DRF skips validation for empty + values. Thus, CharField's min_length and RegexField cannot be used to + enforce this constraint. + """ + def validate(self, value): + super(NonEmptyCharField, self).validate(value) + if not value: + raise ValidationError(self.error_messages["required"])