diff --git a/lms/djangoapps/discussion_api/api.py b/lms/djangoapps/discussion_api/api.py index 1826b19a8f..c59e88b715 100644 --- a/lms/djangoapps/discussion_api/api.py +++ b/lms/djangoapps/discussion_api/api.py @@ -270,13 +270,20 @@ def _do_extra_thread_actions(api_thread, cc_thread, request_fields, actions_form Perform any necessary additional actions related to thread creation or update that require a separate comments service request. """ - form_following = actions_form.cleaned_data["following"] - if "following" in request_fields and form_following != api_thread["following"]: - if form_following: - context["cc_requester"].follow(cc_thread) - else: - context["cc_requester"].unfollow(cc_thread) - api_thread["following"] = form_following + for field, form_value in actions_form.cleaned_data.items(): + if field in request_fields and form_value != api_thread[field]: + api_thread[field] = form_value + if field == "following": + if form_value: + context["cc_requester"].follow(cc_thread) + else: + context["cc_requester"].unfollow(cc_thread) + else: + assert field == "voted" + if form_value: + context["cc_requester"].vote(cc_thread, "up") + else: + context["cc_requester"].unvote(cc_thread) def create_thread(request, thread_data): @@ -368,7 +375,7 @@ def create_comment(request, comment_data): return serializer.data -_THREAD_EDITABLE_BY_ANY = {"following"} +_THREAD_EDITABLE_BY_ANY = {"following", "voted"} _THREAD_EDITABLE_BY_AUTHOR = {"topic_id", "type", "title", "raw_body"} | _THREAD_EDITABLE_BY_ANY diff --git a/lms/djangoapps/discussion_api/forms.py b/lms/djangoapps/discussion_api/forms.py index be4daa2d36..2f1df946d6 100644 --- a/lms/djangoapps/discussion_api/forms.py +++ b/lms/djangoapps/discussion_api/forms.py @@ -63,6 +63,7 @@ class ThreadActionsForm(Form): interactions with the comments service. """ following = BooleanField(required=False) + voted = BooleanField(required=False) class CommentListGetForm(_PaginationForm): diff --git a/lms/djangoapps/discussion_api/tests/test_api.py b/lms/djangoapps/discussion_api/tests/test_api.py index 69880074dd..299b2e5059 100644 --- a/lms/djangoapps/discussion_api/tests/test_api.py +++ b/lms/djangoapps/discussion_api/tests/test_api.py @@ -1136,6 +1136,21 @@ class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC {"source_type": ["thread"], "source_id": ["test_id"]} ) + def test_voted(self): + self.register_post_thread_response({"id": "test_id"}) + self.register_thread_votes_response("test_id") + data = self.minimal_data.copy() + data["voted"] = "True" + result = create_thread(self.request, data) + self.assertEqual(result["voted"], True) + cs_request = httpretty.last_request() + self.assertEqual(urlparse(cs_request.path).path, "/api/v1/threads/test_id/votes") + self.assertEqual(cs_request.method, "PUT") + self.assertEqual( + cs_request.parsed_body, + {"user_id": [str(self.user.id)], "value": ["up"]} + ) + def test_course_id_missing(self): with self.assertRaises(ValidationError) as assertion: create_thread(self.request, {}) @@ -1571,6 +1586,45 @@ class UpdateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC {"source_type": ["thread"], "source_id": ["test_thread"]} ) + @ddt.data(*itertools.product([True, False], [True, False])) + @ddt.unpack + def test_voted(self, old_voted, new_voted): + """ + Test attempts to edit the "voted" field. + + old_voted indicates whether the thread should be upvoted at the start of + the test. new_voted indicates the value for the "voted" field in the + update. If old_voted and new_voted are the same, no update should be + made. Otherwise, a vote should be PUT or DELETEd according to the + new_voted value. + """ + if old_voted: + self.register_get_user_response(self.user, upvoted_ids=["test_thread"]) + self.register_thread_votes_response("test_thread") + self.register_thread() + data = {"voted": new_voted} + result = update_thread(self.request, "test_thread", data) + self.assertEqual(result["voted"], new_voted) + last_request_path = urlparse(httpretty.last_request().path).path + votes_url = "/api/v1/threads/test_thread/votes" + if old_voted == new_voted: + self.assertNotEqual(last_request_path, votes_url) + else: + self.assertEqual(last_request_path, votes_url) + self.assertEqual( + httpretty.last_request().method, + "PUT" if new_voted else "DELETE" + ) + actual_request_data = ( + httpretty.last_request().parsed_body if new_voted else + parse_qs(urlparse(httpretty.last_request().path).query) + ) + actual_request_data.pop("request_id", None) + expected_request_data = {"user_id": [str(self.user.id)]} + if new_voted: + expected_request_data["value"] = ["up"] + self.assertEqual(actual_request_data, expected_request_data) + def test_invalid_field(self): self.register_thread() with self.assertRaises(ValidationError) as assertion: diff --git a/lms/djangoapps/discussion_api/tests/utils.py b/lms/djangoapps/discussion_api/tests/utils.py index a6c863f4cf..add247dbd2 100644 --- a/lms/djangoapps/discussion_api/tests/utils.py +++ b/lms/djangoapps/discussion_api/tests/utils.py @@ -160,6 +160,19 @@ class CommentsServiceMockMixin(object): status=200 ) + def register_thread_votes_response(self, thread_id): + """ + Register a mock response for PUT and DELETE on the CS thread votes + endpoint + """ + for method in [httpretty.PUT, httpretty.DELETE]: + httpretty.register_uri( + method, + "http://localhost:4567/api/v1/threads/{}/votes".format(thread_id), + body=json.dumps({}), # body is unused + status=200 + ) + def assert_query_params_equal(self, httpretty_request, expected_params): """ Assert that the given mock request had the expected query parameters