diff --git a/lms/djangoapps/discussion_api/api.py b/lms/djangoapps/discussion_api/api.py index 30579b22b1..43dcbbd4ba 100644 --- a/lms/djangoapps/discussion_api/api.py +++ b/lms/djangoapps/discussion_api/api.py @@ -10,6 +10,7 @@ from opaque_keys import InvalidKeyError from opaque_keys.edx.locator import CourseLocator from courseware.courses import get_course_with_access +from discussion_api.forms import ThreadCreateExtrasForm from discussion_api.pagination import get_paginated_data from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context from django_comment_client.base.views import ( @@ -243,17 +244,24 @@ def create_thread(request, thread_data): context = get_context(course, request) serializer = ThreadSerializer(data=thread_data, context=context) - if not serializer.is_valid(): - raise ValidationError(serializer.errors) + extras_form = ThreadCreateExtrasForm(thread_data) + if not (serializer.is_valid() and extras_form.is_valid()): + raise ValidationError(dict(serializer.errors.items() + extras_form.errors.items())) serializer.save() thread = serializer.object + ret = serializer.data + following = extras_form.cleaned_data["following"] + if following: + context["cc_requester"].follow(thread) + ret["following"] = True + track_forum_event( request, THREAD_CREATED_EVENT_NAME, course, thread, - get_thread_created_event_data(thread, followed=False) + get_thread_created_event_data(thread, followed=following) ) return serializer.data diff --git a/lms/djangoapps/discussion_api/forms.py b/lms/djangoapps/discussion_api/forms.py index 3f57ea71d0..08dd42330a 100644 --- a/lms/djangoapps/discussion_api/forms.py +++ b/lms/djangoapps/discussion_api/forms.py @@ -2,7 +2,7 @@ Discussion API forms """ from django.core.exceptions import ValidationError -from django.forms import CharField, Form, IntegerField, NullBooleanField +from django.forms import BooleanField, CharField, Form, IntegerField, NullBooleanField from opaque_keys import InvalidKeyError from opaque_keys.edx.locator import CourseLocator @@ -37,6 +37,14 @@ class ThreadListGetForm(_PaginationForm): raise ValidationError("'{}' is not a valid course id".format(value)) +class ThreadCreateExtrasForm(Form): + """ + A form to handle fields in thread creation that require separate + interactions with the comments service. + """ + following = BooleanField(required=False) + + class CommentListGetForm(_PaginationForm): """ A form to validate query parameters in the comment list retrieval endpoint diff --git a/lms/djangoapps/discussion_api/tests/test_api.py b/lms/djangoapps/discussion_api/tests/test_api.py index f7d768557f..7d1e97cfa0 100644 --- a/lms/djangoapps/discussion_api/tests/test_api.py +++ b/lms/djangoapps/discussion_api/tests/test_api.py @@ -3,6 +3,7 @@ Tests for Discussion API internal interface """ from datetime import datetime, timedelta import itertools +from urlparse import urlparse import ddt import httpretty @@ -1066,6 +1067,23 @@ class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC } ) + def test_following(self): + self.register_post_thread_response({"id": "test_id"}) + self.register_subscription_response(self.user) + data = self.minimal_data.copy() + data["following"] = "True" + result = create_thread(self.request, data) + self.assertEqual(result["following"], True) + cs_request = httpretty.last_request() + self.assertEqual( + urlparse(cs_request.path).path, + "/api/v1/users/{}/subscriptions".format(self.user.id) + ) + self.assertEqual( + cs_request.parsed_body, + {"source_type": ["thread"], "source_id": ["test_id"]} + ) + def test_course_id_missing(self): with self.assertRaises(ValidationError) as assertion: create_thread(self.request, {}) diff --git a/lms/djangoapps/discussion_api/tests/utils.py b/lms/djangoapps/discussion_api/tests/utils.py index 4f60a16e52..9ecfc71eb5 100644 --- a/lms/djangoapps/discussion_api/tests/utils.py +++ b/lms/djangoapps/discussion_api/tests/utils.py @@ -74,6 +74,17 @@ class CommentsServiceMockMixin(object): status=200 ) + def register_subscription_response(self, user): + """ + Register a mock response for POST on the CS user subscription endpoint + """ + httpretty.register_uri( + httpretty.POST, + "http://localhost:4567/api/v1/users/{id}/subscriptions".format(id=user.id), + body=json.dumps({}), # body is unused + status=200 + ) + def assert_query_params_equal(self, httpretty_request, expected_params): """ Assert that the given mock request had the expected query parameters diff --git a/lms/djangoapps/discussion_api/views.py b/lms/djangoapps/discussion_api/views.py index 8c0475f97e..2e82641246 100644 --- a/lms/djangoapps/discussion_api/views.py +++ b/lms/djangoapps/discussion_api/views.py @@ -96,6 +96,9 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): * raw_body (required): The thread's raw body text + * following (optional): A boolean indicating whether the user should + follow the thread upon its creation; defaults to false + **GET Response Values**: * results: The list of threads; each item in the list has the same