diff --git a/lms/djangoapps/django_comment_client/base/tests.py b/lms/djangoapps/django_comment_client/base/tests.py index 6da8b1bcb7..3091b78217 100644 --- a/lms/djangoapps/django_comment_client/base/tests.py +++ b/lms/djangoapps/django_comment_client/base/tests.py @@ -316,11 +316,20 @@ class ViewsTestCaseMixin(object): Issues a request to update a thread and verifies the result. """ self._setup_mock_request(mock_request) - response = self.client.post( - reverse("update_thread", kwargs={"thread_id": "dummy", "course_id": self.course_id.to_deprecated_string()}), - data={"body": "foo", "title": "foo", "commentable_id": "some_topic"} - ) + # Mock out saving in order to test that content is correctly + # updated. Otherwise, the call to thread.save() receives the + # same mocked request data that the original call to retrieve + # the thread did, overwriting any changes. + with patch.object(Thread, 'save'): + response = self.client.post( + reverse("update_thread", kwargs={"thread_id": "dummy", "course_id": self.course_id.to_deprecated_string()}), + data={"body": "foo", "title": "foo", "commentable_id": "some_topic"} + ) self.assertEqual(response.status_code, 200) + data = json.loads(response.content) + self.assertEqual(data['body'], 'foo') + self.assertEqual(data['title'], 'foo') + self.assertEqual(data['commentable_id'], 'some_topic') @ddt.ddt diff --git a/lms/djangoapps/django_comment_client/base/views.py b/lms/djangoapps/django_comment_client/base/views.py index aacc1cc571..9778fb4392 100644 --- a/lms/djangoapps/django_comment_client/base/views.py +++ b/lms/djangoapps/django_comment_client/base/views.py @@ -243,6 +243,8 @@ def update_thread(request, course_id, thread_id): course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) thread = cc.Thread.find(thread_id) + # Get thread context first in order to be safe from reseting the values of thread object later + thread_context = getattr(thread, "context", "course") thread.body = request.POST["body"] thread.title = request.POST["title"] # The following checks should avoid issues we've seen during deploys, where end users are hitting an updated server @@ -252,7 +254,6 @@ def update_thread(request, course_id, thread_id): if "commentable_id" in request.POST: commentable_id = request.POST["commentable_id"] course = get_course_with_access(request.user, 'load', course_key) - thread_context = getattr(thread, "context", "course") if thread_context == "course" and not discussion_category_id_access(course, request.user, commentable_id): return JsonError(_("Topic doesn't exist")) else: