Merge pull request #8311 from edx/gprice/discussion-api-edit-thread
Add thread editing to discussion API
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Discussion API internal interface
|
||||
"""
|
||||
from collections import defaultdict
|
||||
from urllib import urlencode
|
||||
from urlparse import urlunparse
|
||||
|
||||
@@ -8,11 +9,10 @@ from django.core.exceptions import ValidationError
|
||||
from django.core.urlresolvers import reverse
|
||||
from django.http import Http404
|
||||
|
||||
|
||||
from collections import defaultdict
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
|
||||
from opaque_keys import InvalidKeyError
|
||||
from opaque_keys.edx.locator import CourseLocator
|
||||
from opaque_keys.edx.locator import CourseKey
|
||||
|
||||
from courseware.courses import get_course_with_access
|
||||
from discussion_api.forms import ThreadCreateExtrasForm
|
||||
@@ -28,7 +28,7 @@ from django_comment_client.base.views import (
|
||||
from django_comment_client.utils import get_accessible_discussion_modules
|
||||
from lms.lib.comment_client.thread import Thread
|
||||
from lms.lib.comment_client.utils import CommentClientRequestError
|
||||
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id
|
||||
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id, is_commentable_cohorted
|
||||
|
||||
|
||||
def _get_course_or_404(course_key, user):
|
||||
@@ -43,6 +43,37 @@ def _get_course_or_404(course_key, user):
|
||||
return course
|
||||
|
||||
|
||||
def _get_thread_and_context(request, thread_id, parent_id=None, retrieve_kwargs=None):
|
||||
"""
|
||||
Retrieve the given thread and build a serializer context for it, returning
|
||||
both. This function also enforces access control for the thread (checking
|
||||
both the user's access to the course and to the thread's cohort if
|
||||
applicable). Raises Http404 if the thread does not exist or the user cannot
|
||||
access it.
|
||||
"""
|
||||
retrieve_kwargs = retrieve_kwargs or {}
|
||||
try:
|
||||
if "mark_as_read" not in retrieve_kwargs:
|
||||
retrieve_kwargs["mark_as_read"] = False
|
||||
cc_thread = Thread(id=thread_id).retrieve(**retrieve_kwargs)
|
||||
course_key = CourseKey.from_string(cc_thread["course_id"])
|
||||
course = _get_course_or_404(course_key, request.user)
|
||||
context = get_context(course, request, cc_thread, parent_id)
|
||||
if (
|
||||
not context["is_requester_privileged"] and
|
||||
cc_thread["group_id"] and
|
||||
is_commentable_cohorted(course.id, cc_thread["commentable_id"])
|
||||
):
|
||||
requester_cohort = get_cohort_id(request.user, course.id)
|
||||
if requester_cohort is not None and cc_thread["group_id"] != requester_cohort:
|
||||
raise Http404
|
||||
return cc_thread, context
|
||||
except CommentClientRequestError:
|
||||
# params are validated at a higher level, so the only possible request
|
||||
# error is if the thread doesn't exist
|
||||
raise Http404
|
||||
|
||||
|
||||
def get_thread_list_url(request, course_key, topic_id_list):
|
||||
"""
|
||||
Returns the URL for the thread_list_url field, given a list of topic_ids
|
||||
@@ -191,28 +222,17 @@ def get_comment_list(request, thread_id, endorsed, page, page_size):
|
||||
discussion_api.views.CommentViewSet for more detail.
|
||||
"""
|
||||
response_skip = page_size * (page - 1)
|
||||
try:
|
||||
cc_thread = Thread(id=thread_id).retrieve(
|
||||
recursive=True,
|
||||
user_id=request.user.id,
|
||||
mark_as_read=True,
|
||||
response_skip=response_skip,
|
||||
response_limit=page_size
|
||||
)
|
||||
except CommentClientRequestError:
|
||||
# page and page_size are validated at a higher level, so the only
|
||||
# possible request error is if the thread doesn't exist
|
||||
raise Http404
|
||||
|
||||
course_key = CourseLocator.from_string(cc_thread["course_id"])
|
||||
course = _get_course_or_404(course_key, request.user)
|
||||
context = get_context(course, request, cc_thread)
|
||||
|
||||
# Ensure user has access to the thread
|
||||
if not context["is_requester_privileged"] and cc_thread["group_id"]:
|
||||
requester_cohort = get_cohort_id(request.user, course_key)
|
||||
if requester_cohort is not None and cc_thread["group_id"] != requester_cohort:
|
||||
raise Http404
|
||||
cc_thread, context = _get_thread_and_context(
|
||||
request,
|
||||
thread_id,
|
||||
retrieve_kwargs={
|
||||
"recursive": True,
|
||||
"user_id": request.user.id,
|
||||
"mark_as_read": True,
|
||||
"response_skip": response_skip,
|
||||
"response_limit": page_size,
|
||||
}
|
||||
)
|
||||
|
||||
# Responses to discussion threads cannot be separated by endorsed, but
|
||||
# responses to question threads must be separated by endorsed due to the
|
||||
@@ -267,7 +287,7 @@ def create_thread(request, thread_data):
|
||||
if not course_id:
|
||||
raise ValidationError({"course_id": ["This field is required."]})
|
||||
try:
|
||||
course_key = CourseLocator.from_string(course_id)
|
||||
course_key = CourseKey.from_string(course_id)
|
||||
course = _get_course_or_404(course_key, request.user)
|
||||
except (Http404, InvalidKeyError):
|
||||
raise ValidationError({"course_id": ["Invalid value."]})
|
||||
@@ -314,29 +334,55 @@ def create_comment(request, comment_data):
|
||||
detail.
|
||||
"""
|
||||
thread_id = comment_data.get("thread_id")
|
||||
parent_id = comment_data.get("parent_id")
|
||||
if not thread_id:
|
||||
raise ValidationError({"thread_id": ["This field is required."]})
|
||||
try:
|
||||
thread = Thread(id=thread_id).retrieve(mark_as_read=False)
|
||||
course_key = CourseLocator.from_string(thread["course_id"])
|
||||
course = _get_course_or_404(course_key, request.user)
|
||||
except (Http404, CommentClientRequestError):
|
||||
cc_thread, context = _get_thread_and_context(request, thread_id, parent_id)
|
||||
except Http404:
|
||||
raise ValidationError({"thread_id": ["Invalid value."]})
|
||||
|
||||
parent_id = comment_data.get("parent_id")
|
||||
context = get_context(course, request, thread, parent_id)
|
||||
serializer = CommentSerializer(data=comment_data, context=context)
|
||||
if not serializer.is_valid():
|
||||
raise ValidationError(serializer.errors)
|
||||
serializer.save()
|
||||
|
||||
comment = serializer.object
|
||||
cc_comment = serializer.object
|
||||
track_forum_event(
|
||||
request,
|
||||
get_comment_created_event_name(comment),
|
||||
course,
|
||||
comment,
|
||||
get_comment_created_event_data(comment, thread["commentable_id"], followed=False)
|
||||
get_comment_created_event_name(cc_comment),
|
||||
context["course"],
|
||||
cc_comment,
|
||||
get_comment_created_event_data(cc_comment, cc_thread["commentable_id"], followed=False)
|
||||
)
|
||||
|
||||
return serializer.data
|
||||
|
||||
|
||||
def update_thread(request, thread_id, update_data):
|
||||
"""
|
||||
Update a thread.
|
||||
|
||||
Parameters:
|
||||
|
||||
request: The django request object used for build_absolute_uri and
|
||||
determining the requesting user.
|
||||
|
||||
thread_id: The id for the thread to update.
|
||||
|
||||
update_data: The data to update in the thread.
|
||||
|
||||
Returns:
|
||||
|
||||
The updated thread; see discussion_api.views.ThreadViewSet for more
|
||||
detail.
|
||||
"""
|
||||
cc_thread, context = _get_thread_and_context(request, thread_id)
|
||||
is_author = str(request.user.id) == cc_thread["user_id"]
|
||||
if not (context["is_requester_privileged"] or is_author):
|
||||
raise PermissionDenied()
|
||||
serializer = ThreadSerializer(cc_thread, data=update_data, partial=True, context=context)
|
||||
if not serializer.is_valid():
|
||||
raise ValidationError(serializer.errors)
|
||||
serializer.save()
|
||||
return serializer.data
|
||||
|
||||
@@ -21,6 +21,7 @@ from lms.lib.comment_client.thread import Thread
|
||||
from lms.lib.comment_client.user import User as CommentClientUser
|
||||
from lms.lib.comment_client.utils import CommentClientRequestError
|
||||
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_names
|
||||
from openedx.core.lib.api.fields import NonEmptyCharField
|
||||
|
||||
|
||||
def get_context(course, request, thread=None, parent_id=None):
|
||||
@@ -44,15 +45,16 @@ def get_context(course, request, thread=None, parent_id=None):
|
||||
}
|
||||
requester = request.user
|
||||
return {
|
||||
# For now, the only groups are cohorts
|
||||
"course": course,
|
||||
"request": request,
|
||||
"thread": thread,
|
||||
"parent_id": parent_id,
|
||||
# For now, the only groups are cohorts
|
||||
"group_ids_to_names": get_cohort_names(course),
|
||||
"is_requester_privileged": requester.id in staff_user_ids or requester.id in ta_user_ids,
|
||||
"staff_user_ids": staff_user_ids,
|
||||
"ta_user_ids": ta_user_ids,
|
||||
"cc_requester": CommentClientUser.from_django_user(requester).retrieve(),
|
||||
"thread": thread,
|
||||
"parent_id": parent_id,
|
||||
}
|
||||
|
||||
|
||||
@@ -63,7 +65,7 @@ class _ContentSerializer(serializers.Serializer):
|
||||
author_label = serializers.SerializerMethodField("get_author_label")
|
||||
created_at = serializers.CharField(read_only=True)
|
||||
updated_at = serializers.CharField(read_only=True)
|
||||
raw_body = serializers.CharField(source="body")
|
||||
raw_body = NonEmptyCharField(source="body")
|
||||
abuse_flagged = serializers.SerializerMethodField("get_abuse_flagged")
|
||||
voted = serializers.SerializerMethodField("get_voted")
|
||||
vote_count = serializers.SerializerMethodField("get_vote_count")
|
||||
@@ -138,14 +140,14 @@ class ThreadSerializer(_ContentSerializer):
|
||||
at introspection and Thread's __getattr__.
|
||||
"""
|
||||
course_id = serializers.CharField()
|
||||
topic_id = serializers.CharField(source="commentable_id")
|
||||
topic_id = NonEmptyCharField(source="commentable_id")
|
||||
group_id = serializers.IntegerField(read_only=True)
|
||||
group_name = serializers.SerializerMethodField("get_group_name")
|
||||
type_ = serializers.ChoiceField(
|
||||
source="thread_type",
|
||||
choices=[(val, val) for val in ["discussion", "question"]]
|
||||
)
|
||||
title = serializers.CharField()
|
||||
title = NonEmptyCharField()
|
||||
pinned = serializers.BooleanField(read_only=True)
|
||||
closed = serializers.BooleanField(read_only=True)
|
||||
following = serializers.SerializerMethodField("get_following")
|
||||
@@ -198,10 +200,19 @@ class ThreadSerializer(_ContentSerializer):
|
||||
"""Returns the URL to retrieve the thread's non-endorsed comments."""
|
||||
return self.get_comment_list_url(obj, endorsed=False)
|
||||
|
||||
def validate_course_id(self, attrs, _source):
|
||||
"""Ensure that course_id is not edited in an update operation."""
|
||||
if self.object:
|
||||
raise ValidationError("This field is not allowed in an update.")
|
||||
return attrs
|
||||
|
||||
def restore_object(self, attrs, instance=None):
|
||||
if instance:
|
||||
raise ValueError("ThreadSerializer cannot be used for updates.")
|
||||
return Thread(user_id=self.context["cc_requester"]["id"], **attrs)
|
||||
for key, val in attrs.items():
|
||||
instance[key] = val
|
||||
return instance
|
||||
else:
|
||||
return Thread(user_id=self.context["cc_requester"]["id"], **attrs)
|
||||
|
||||
|
||||
class CommentSerializer(_ContentSerializer):
|
||||
|
||||
@@ -15,6 +15,8 @@ from django.core.exceptions import ValidationError
|
||||
from django.http import Http404
|
||||
from django.test.client import RequestFactory
|
||||
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
|
||||
from opaque_keys.edx.locator import CourseLocator
|
||||
|
||||
from courseware.tests.factories import BetaTesterFactory, StaffFactory
|
||||
@@ -24,6 +26,7 @@ from discussion_api.api import (
|
||||
get_comment_list,
|
||||
get_course_topics,
|
||||
get_thread_list,
|
||||
update_thread,
|
||||
)
|
||||
from discussion_api.tests.utils import (
|
||||
CommentsServiceMockMixin,
|
||||
@@ -685,18 +688,32 @@ class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
|
||||
FORUM_ROLE_STUDENT,
|
||||
],
|
||||
[True, False],
|
||||
[True, False],
|
||||
["no_group", "match_group", "different_group"],
|
||||
)
|
||||
)
|
||||
@ddt.unpack
|
||||
def test_group_access(self, role_name, course_is_cohorted, thread_group_state):
|
||||
cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted})
|
||||
def test_group_access(
|
||||
self,
|
||||
role_name,
|
||||
course_is_cohorted,
|
||||
topic_is_cohorted,
|
||||
thread_group_state
|
||||
):
|
||||
cohort_course = CourseFactory.create(
|
||||
discussion_topics={"Test Topic": {"id": "test_topic"}},
|
||||
cohort_config={
|
||||
"cohorted": course_is_cohorted,
|
||||
"cohorted_discussions": ["test_topic"] if topic_is_cohorted else [],
|
||||
}
|
||||
)
|
||||
CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id)
|
||||
cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user])
|
||||
role = Role.objects.create(name=role_name, course_id=cohort_course.id)
|
||||
role.users = [self.user]
|
||||
thread = self.make_minimal_cs_thread({
|
||||
"course_id": unicode(cohort_course.id),
|
||||
"commentable_id": "test_topic",
|
||||
"group_id": (
|
||||
None if thread_group_state == "no_group" else
|
||||
cohort.id if thread_group_state == "match_group" else
|
||||
@@ -706,6 +723,7 @@ class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
|
||||
expected_error = (
|
||||
role_name == FORUM_ROLE_STUDENT and
|
||||
course_is_cohorted and
|
||||
topic_is_cohorted and
|
||||
thread_group_state == "different_group"
|
||||
)
|
||||
try:
|
||||
@@ -1287,8 +1305,224 @@ class CreateCommentTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
|
||||
create_comment(self.request, self.minimal_data)
|
||||
self.assertEqual(assertion.exception.message_dict, {"thread_id": ["Invalid value."]})
|
||||
|
||||
@ddt.data(
|
||||
*itertools.product(
|
||||
[
|
||||
FORUM_ROLE_ADMINISTRATOR,
|
||||
FORUM_ROLE_MODERATOR,
|
||||
FORUM_ROLE_COMMUNITY_TA,
|
||||
FORUM_ROLE_STUDENT,
|
||||
],
|
||||
[True, False],
|
||||
["no_group", "match_group", "different_group"],
|
||||
)
|
||||
)
|
||||
@ddt.unpack
|
||||
def test_group_access(self, role_name, course_is_cohorted, thread_group_state):
|
||||
cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted})
|
||||
CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id)
|
||||
cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user])
|
||||
role = Role.objects.create(name=role_name, course_id=cohort_course.id)
|
||||
role.users = [self.user]
|
||||
self.register_get_thread_response(make_minimal_cs_thread({
|
||||
"id": "cohort_thread",
|
||||
"course_id": unicode(cohort_course.id),
|
||||
"group_id": (
|
||||
None if thread_group_state == "no_group" else
|
||||
cohort.id if thread_group_state == "match_group" else
|
||||
cohort.id + 1
|
||||
),
|
||||
}))
|
||||
self.register_post_comment_response({}, thread_id="cohort_thread")
|
||||
data = self.minimal_data.copy()
|
||||
data["thread_id"] = "cohort_thread"
|
||||
expected_error = (
|
||||
role_name == FORUM_ROLE_STUDENT and
|
||||
course_is_cohorted and
|
||||
thread_group_state == "different_group"
|
||||
)
|
||||
try:
|
||||
create_comment(self.request, data)
|
||||
self.assertFalse(expected_error)
|
||||
except ValidationError as err:
|
||||
self.assertTrue(expected_error)
|
||||
self.assertEqual(
|
||||
err.message_dict,
|
||||
{"thread_id": ["Invalid value."]}
|
||||
)
|
||||
|
||||
def test_invalid_field(self):
|
||||
data = self.minimal_data.copy()
|
||||
del data["raw_body"]
|
||||
with self.assertRaises(ValidationError):
|
||||
create_comment(self.request, data)
|
||||
|
||||
|
||||
@ddt.ddt
|
||||
class UpdateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase):
|
||||
"""Tests for update_thread"""
|
||||
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
|
||||
def setUp(self):
|
||||
super(UpdateThreadTest, self).setUp()
|
||||
httpretty.reset()
|
||||
httpretty.enable()
|
||||
self.addCleanup(httpretty.disable)
|
||||
self.user = UserFactory.create()
|
||||
self.register_get_user_response(self.user)
|
||||
self.request = RequestFactory().get("/test_path")
|
||||
self.request.user = self.user
|
||||
self.course = CourseFactory.create()
|
||||
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
|
||||
|
||||
def register_thread(self, overrides=None):
|
||||
"""
|
||||
Make a thread with appropriate data overridden by the overrides
|
||||
parameter and register mock responses for both GET and PUT on its
|
||||
endpoint.
|
||||
"""
|
||||
cs_data = make_minimal_cs_thread({
|
||||
"id": "test_thread",
|
||||
"course_id": unicode(self.course.id),
|
||||
"commentable_id": "original_topic",
|
||||
"username": self.user.username,
|
||||
"user_id": str(self.user.id),
|
||||
"created_at": "2015-05-29T00:00:00Z",
|
||||
"updated_at": "2015-05-29T00:00:00Z",
|
||||
"type": "discussion",
|
||||
"title": "Original Title",
|
||||
"body": "Original body",
|
||||
})
|
||||
cs_data.update(overrides or {})
|
||||
self.register_get_thread_response(cs_data)
|
||||
self.register_put_thread_response(cs_data)
|
||||
|
||||
def test_basic(self):
|
||||
self.register_thread()
|
||||
actual = update_thread(self.request, "test_thread", {"raw_body": "Edited body"})
|
||||
expected = {
|
||||
"id": "test_thread",
|
||||
"course_id": unicode(self.course.id),
|
||||
"topic_id": "original_topic",
|
||||
"group_id": None,
|
||||
"group_name": None,
|
||||
"author": self.user.username,
|
||||
"author_label": None,
|
||||
"created_at": "2015-05-29T00:00:00Z",
|
||||
"updated_at": "2015-05-29T00:00:00Z",
|
||||
"type": "discussion",
|
||||
"title": "Original Title",
|
||||
"raw_body": "Edited body",
|
||||
"pinned": False,
|
||||
"closed": False,
|
||||
"following": False,
|
||||
"abuse_flagged": False,
|
||||
"voted": False,
|
||||
"vote_count": 0,
|
||||
"comment_count": 0,
|
||||
"unread_comment_count": 0,
|
||||
"comment_list_url": "http://testserver/api/discussion/v1/comments/?thread_id=test_thread",
|
||||
"endorsed_comment_list_url": None,
|
||||
"non_endorsed_comment_list_url": None,
|
||||
}
|
||||
self.assertEqual(actual, expected)
|
||||
self.assertEqual(
|
||||
httpretty.last_request().parsed_body,
|
||||
{
|
||||
"course_id": [unicode(self.course.id)],
|
||||
"commentable_id": ["original_topic"],
|
||||
"thread_type": ["discussion"],
|
||||
"title": ["Original Title"],
|
||||
"body": ["Edited body"],
|
||||
"user_id": [str(self.user.id)],
|
||||
"anonymous": ["False"],
|
||||
"anonymous_to_peers": ["False"],
|
||||
"closed": ["False"],
|
||||
"pinned": ["False"],
|
||||
}
|
||||
)
|
||||
|
||||
def test_nonexistent_thread(self):
|
||||
self.register_get_thread_error_response("test_thread", 404)
|
||||
with self.assertRaises(Http404):
|
||||
update_thread(self.request, "test_thread", {})
|
||||
|
||||
def test_nonexistent_course(self):
|
||||
self.register_thread({"course_id": "non/existent/course"})
|
||||
with self.assertRaises(Http404):
|
||||
update_thread(self.request, "test_thread", {})
|
||||
|
||||
def test_unenrolled(self):
|
||||
self.register_thread()
|
||||
self.request.user = UserFactory.create()
|
||||
with self.assertRaises(Http404):
|
||||
update_thread(self.request, "test_thread", {})
|
||||
|
||||
def test_discussions_disabled(self):
|
||||
_remove_discussion_tab(self.course, self.user.id)
|
||||
self.register_thread()
|
||||
with self.assertRaises(Http404):
|
||||
update_thread(self.request, "test_thread", {})
|
||||
|
||||
@ddt.data(
|
||||
*itertools.product(
|
||||
[
|
||||
FORUM_ROLE_ADMINISTRATOR,
|
||||
FORUM_ROLE_MODERATOR,
|
||||
FORUM_ROLE_COMMUNITY_TA,
|
||||
FORUM_ROLE_STUDENT,
|
||||
],
|
||||
[True, False],
|
||||
["no_group", "match_group", "different_group"],
|
||||
)
|
||||
)
|
||||
@ddt.unpack
|
||||
def test_group_access(self, role_name, course_is_cohorted, thread_group_state):
|
||||
cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted})
|
||||
CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id)
|
||||
cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user])
|
||||
role = Role.objects.create(name=role_name, course_id=cohort_course.id)
|
||||
role.users = [self.user]
|
||||
self.register_thread({
|
||||
"course_id": unicode(cohort_course.id),
|
||||
"group_id": (
|
||||
None if thread_group_state == "no_group" else
|
||||
cohort.id if thread_group_state == "match_group" else
|
||||
cohort.id + 1
|
||||
),
|
||||
})
|
||||
expected_error = (
|
||||
role_name == FORUM_ROLE_STUDENT and
|
||||
course_is_cohorted and
|
||||
thread_group_state == "different_group"
|
||||
)
|
||||
try:
|
||||
update_thread(self.request, "test_thread", {})
|
||||
self.assertFalse(expected_error)
|
||||
except Http404:
|
||||
self.assertTrue(expected_error)
|
||||
|
||||
@ddt.data(
|
||||
FORUM_ROLE_ADMINISTRATOR,
|
||||
FORUM_ROLE_MODERATOR,
|
||||
FORUM_ROLE_COMMUNITY_TA,
|
||||
FORUM_ROLE_STUDENT,
|
||||
)
|
||||
def test_non_author_access(self, role_name):
|
||||
role = Role.objects.create(name=role_name, course_id=self.course.id)
|
||||
role.users = [self.user]
|
||||
self.register_thread({"user_id": str(self.user.id + 1)})
|
||||
expected_error = role_name == FORUM_ROLE_STUDENT
|
||||
try:
|
||||
update_thread(self.request, "test_thread", {})
|
||||
self.assertFalse(expected_error)
|
||||
except PermissionDenied:
|
||||
self.assertTrue(expected_error)
|
||||
|
||||
def test_invalid_field(self):
|
||||
self.register_thread()
|
||||
with self.assertRaises(ValidationError) as assertion:
|
||||
update_thread(self.request, "test_thread", {"raw_body": ""})
|
||||
self.assertEqual(
|
||||
assertion.exception.message_dict,
|
||||
{"raw_body": ["This field is required."]}
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ from django_comment_common.models import (
|
||||
FORUM_ROLE_STUDENT,
|
||||
Role,
|
||||
)
|
||||
from lms.lib.comment_client.thread import Thread
|
||||
from student.tests.factories import UserFactory
|
||||
from util.testing import UrlResetMixin
|
||||
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
|
||||
@@ -378,7 +379,6 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
|
||||
httpretty.reset()
|
||||
httpretty.enable()
|
||||
self.addCleanup(httpretty.disable)
|
||||
self.register_post_thread_response({"id": "test_id"})
|
||||
self.course = CourseFactory.create()
|
||||
self.user = UserFactory.create()
|
||||
self.register_get_user_response(self.user)
|
||||
@@ -391,18 +391,34 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
|
||||
"title": "Test Title",
|
||||
"raw_body": "Test body",
|
||||
}
|
||||
self.existing_thread = Thread(**make_minimal_cs_thread({
|
||||
"id": "existing_thread",
|
||||
"course_id": unicode(self.course.id),
|
||||
"commentable_id": "original_topic",
|
||||
"thread_type": "discussion",
|
||||
"title": "Original Title",
|
||||
"body": "Original body",
|
||||
"user_id": str(self.user.id),
|
||||
}))
|
||||
|
||||
def save_and_reserialize(self, data):
|
||||
def save_and_reserialize(self, data, instance=None):
|
||||
"""
|
||||
Create a serializer with the given data, ensure that it is valid, save
|
||||
the result, and return the full thread data from the serializer.
|
||||
Create a serializer with the given data and (if updating) instance,
|
||||
ensure that it is valid, save the result, and return the full thread
|
||||
data from the serializer.
|
||||
"""
|
||||
serializer = ThreadSerializer(data=data, context=get_context(self.course, self.request))
|
||||
serializer = ThreadSerializer(
|
||||
instance,
|
||||
data=data,
|
||||
partial=(instance is not None),
|
||||
context=get_context(self.course, self.request)
|
||||
)
|
||||
self.assertTrue(serializer.is_valid())
|
||||
serializer.save()
|
||||
return serializer.data
|
||||
|
||||
def test_minimal(self):
|
||||
def test_create_minimal(self):
|
||||
self.register_post_thread_response({"id": "test_id"})
|
||||
saved = self.save_and_reserialize(self.minimal_data)
|
||||
self.assertEqual(
|
||||
urlparse(httpretty.last_request().path).path,
|
||||
@@ -421,7 +437,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
|
||||
)
|
||||
self.assertEqual(saved["id"], "test_id")
|
||||
|
||||
def test_missing_field(self):
|
||||
def test_create_missing_field(self):
|
||||
for field in self.minimal_data:
|
||||
data = self.minimal_data.copy()
|
||||
data.pop(field)
|
||||
@@ -432,7 +448,8 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
|
||||
{field: ["This field is required."]}
|
||||
)
|
||||
|
||||
def test_type(self):
|
||||
def test_create_type(self):
|
||||
self.register_post_thread_response({"id": "test_id"})
|
||||
data = self.minimal_data.copy()
|
||||
data["type"] = "question"
|
||||
self.save_and_reserialize(data)
|
||||
@@ -441,6 +458,76 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
|
||||
serializer = ThreadSerializer(data=data)
|
||||
self.assertFalse(serializer.is_valid())
|
||||
|
||||
def test_update_empty(self):
|
||||
self.register_put_thread_response(self.existing_thread.attributes)
|
||||
self.save_and_reserialize({}, self.existing_thread)
|
||||
self.assertEqual(
|
||||
httpretty.last_request().parsed_body,
|
||||
{
|
||||
"course_id": [unicode(self.course.id)],
|
||||
"commentable_id": ["original_topic"],
|
||||
"thread_type": ["discussion"],
|
||||
"title": ["Original Title"],
|
||||
"body": ["Original body"],
|
||||
"anonymous": ["False"],
|
||||
"anonymous_to_peers": ["False"],
|
||||
"closed": ["False"],
|
||||
"pinned": ["False"],
|
||||
"user_id": [str(self.user.id)],
|
||||
}
|
||||
)
|
||||
|
||||
def test_update_all(self):
|
||||
self.register_put_thread_response(self.existing_thread.attributes)
|
||||
data = {
|
||||
"topic_id": "edited_topic",
|
||||
"type": "question",
|
||||
"title": "Edited Title",
|
||||
"raw_body": "Edited body",
|
||||
}
|
||||
saved = self.save_and_reserialize(data, self.existing_thread)
|
||||
self.assertEqual(
|
||||
httpretty.last_request().parsed_body,
|
||||
{
|
||||
"course_id": [unicode(self.course.id)],
|
||||
"commentable_id": ["edited_topic"],
|
||||
"thread_type": ["question"],
|
||||
"title": ["Edited Title"],
|
||||
"body": ["Edited body"],
|
||||
"anonymous": ["False"],
|
||||
"anonymous_to_peers": ["False"],
|
||||
"closed": ["False"],
|
||||
"pinned": ["False"],
|
||||
"user_id": [str(self.user.id)],
|
||||
}
|
||||
)
|
||||
for key in data:
|
||||
self.assertEqual(saved[key], data[key])
|
||||
|
||||
def test_update_empty_string(self):
|
||||
serializer = ThreadSerializer(
|
||||
self.existing_thread,
|
||||
data={field: "" for field in ["topic_id", "title", "raw_body"]},
|
||||
partial=True,
|
||||
context=get_context(self.course, self.request)
|
||||
)
|
||||
self.assertEqual(
|
||||
serializer.errors,
|
||||
{field: ["This field is required."] for field in ["topic_id", "title", "raw_body"]}
|
||||
)
|
||||
|
||||
def test_update_course_id(self):
|
||||
serializer = ThreadSerializer(
|
||||
self.existing_thread,
|
||||
data={"course_id": "some/other/course"},
|
||||
partial=True,
|
||||
context=get_context(self.course, self.request)
|
||||
)
|
||||
self.assertEqual(
|
||||
serializer.errors,
|
||||
{"course_id": ["This field is not allowed in an update."]}
|
||||
)
|
||||
|
||||
|
||||
@ddt.ddt
|
||||
class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStoreTestCase):
|
||||
|
||||
@@ -11,6 +11,8 @@ from pytz import UTC
|
||||
|
||||
from django.core.urlresolvers import reverse
|
||||
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from discussion_api.tests.utils import CommentsServiceMockMixin, make_minimal_cs_thread
|
||||
from student.tests.factories import CourseEnrollmentFactory, UserFactory
|
||||
from util.testing import UrlResetMixin
|
||||
@@ -25,6 +27,8 @@ class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin):
|
||||
in the test client, utility functions, and a test case for unauthenticated
|
||||
requests. Subclasses must set self.url in their setUp methods.
|
||||
"""
|
||||
client_class = APIClient
|
||||
|
||||
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
|
||||
def setUp(self):
|
||||
super(DiscussionAPIViewTestMixin, self).setUp()
|
||||
@@ -294,6 +298,101 @@ class ThreadViewSetCreateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
|
||||
self.assertEqual(response_data, expected_response_data)
|
||||
|
||||
|
||||
@httpretty.activate
|
||||
class ThreadViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
|
||||
"""Tests for ThreadViewSet partial_update"""
|
||||
def setUp(self):
|
||||
super(ThreadViewSetPartialUpdateTest, self).setUp()
|
||||
self.url = reverse("thread-detail", kwargs={"thread_id": "test_thread"})
|
||||
|
||||
def test_basic(self):
|
||||
self.register_get_user_response(self.user)
|
||||
cs_thread = make_minimal_cs_thread({
|
||||
"id": "test_thread",
|
||||
"course_id": unicode(self.course.id),
|
||||
"commentable_id": "original_topic",
|
||||
"username": self.user.username,
|
||||
"user_id": str(self.user.id),
|
||||
"created_at": "2015-05-29T00:00:00Z",
|
||||
"updated_at": "2015-05-29T00:00:00Z",
|
||||
"thread_type": "discussion",
|
||||
"title": "Original Title",
|
||||
"body": "Original body",
|
||||
})
|
||||
self.register_get_thread_response(cs_thread)
|
||||
self.register_put_thread_response(cs_thread)
|
||||
request_data = {"raw_body": "Edited body"}
|
||||
expected_response_data = {
|
||||
"id": "test_thread",
|
||||
"course_id": unicode(self.course.id),
|
||||
"topic_id": "original_topic",
|
||||
"group_id": None,
|
||||
"group_name": None,
|
||||
"author": self.user.username,
|
||||
"author_label": None,
|
||||
"created_at": "2015-05-29T00:00:00Z",
|
||||
"updated_at": "2015-05-29T00:00:00Z",
|
||||
"type": "discussion",
|
||||
"title": "Original Title",
|
||||
"raw_body": "Edited body",
|
||||
"pinned": False,
|
||||
"closed": False,
|
||||
"following": False,
|
||||
"abuse_flagged": False,
|
||||
"voted": False,
|
||||
"vote_count": 0,
|
||||
"comment_count": 0,
|
||||
"unread_comment_count": 0,
|
||||
"comment_list_url": "http://testserver/api/discussion/v1/comments/?thread_id=test_thread",
|
||||
"endorsed_comment_list_url": None,
|
||||
"non_endorsed_comment_list_url": None,
|
||||
}
|
||||
response = self.client.patch( # pylint: disable=no-member
|
||||
self.url,
|
||||
json.dumps(request_data),
|
||||
content_type="application/json"
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
response_data = json.loads(response.content)
|
||||
self.assertEqual(response_data, expected_response_data)
|
||||
self.assertEqual(
|
||||
httpretty.last_request().parsed_body,
|
||||
{
|
||||
"course_id": [unicode(self.course.id)],
|
||||
"commentable_id": ["original_topic"],
|
||||
"thread_type": ["discussion"],
|
||||
"title": ["Original Title"],
|
||||
"body": ["Edited body"],
|
||||
"user_id": [str(self.user.id)],
|
||||
"anonymous": ["False"],
|
||||
"anonymous_to_peers": ["False"],
|
||||
"closed": ["False"],
|
||||
"pinned": ["False"],
|
||||
}
|
||||
)
|
||||
|
||||
def test_error(self):
|
||||
self.register_get_user_response(self.user)
|
||||
cs_thread = make_minimal_cs_thread({
|
||||
"id": "test_thread",
|
||||
"course_id": unicode(self.course.id),
|
||||
"user_id": str(self.user.id),
|
||||
})
|
||||
self.register_get_thread_response(cs_thread)
|
||||
request_data = {"title": ""}
|
||||
response = self.client.patch( # pylint: disable=no-member
|
||||
self.url,
|
||||
json.dumps(request_data),
|
||||
content_type="application/json"
|
||||
)
|
||||
expected_response_data = {
|
||||
"field_errors": {"title": {"developer_message": "This field is required."}}
|
||||
}
|
||||
self.assertEqual(response.status_code, 400)
|
||||
response_data = json.loads(response.content)
|
||||
self.assertEqual(response_data, expected_response_data)
|
||||
|
||||
|
||||
@httpretty.activate
|
||||
class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
|
||||
"""Tests for CommentViewSet list"""
|
||||
|
||||
@@ -7,6 +7,29 @@ import re
|
||||
import httpretty
|
||||
|
||||
|
||||
def _get_thread_callback(thread_data):
|
||||
"""
|
||||
Get a callback function that will return POST/PUT data overridden by
|
||||
response_overrides.
|
||||
"""
|
||||
def callback(request, _uri, headers):
|
||||
"""
|
||||
Simulate the thread creation or update endpoint by returning the provided
|
||||
data along with the data from response_overrides and dummy values for any
|
||||
additional required fields.
|
||||
"""
|
||||
response_data = make_minimal_cs_thread(thread_data)
|
||||
for key, val_list in request.parsed_body.items():
|
||||
val = val_list[0]
|
||||
if key in ["anonymous", "anonymous_to_peers", "closed", "pinned"]:
|
||||
response_data[key] = val == "True"
|
||||
else:
|
||||
response_data[key] = val
|
||||
return (200, headers, json.dumps(response_data))
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
class CommentsServiceMockMixin(object):
|
||||
"""Mixin with utility methods for mocking the comments service"""
|
||||
def register_get_threads_response(self, threads, page, num_pages):
|
||||
@@ -22,23 +45,23 @@ class CommentsServiceMockMixin(object):
|
||||
status=200
|
||||
)
|
||||
|
||||
def register_post_thread_response(self, response_overrides):
|
||||
def register_post_thread_response(self, thread_data):
|
||||
"""Register a mock response for POST on the CS commentable endpoint"""
|
||||
def callback(request, _uri, headers):
|
||||
"""
|
||||
Simulate the thread creation endpoint by returning the provided data
|
||||
along with the data from response_overrides.
|
||||
"""
|
||||
response_data = make_minimal_cs_thread(
|
||||
{key: val[0] for key, val in request.parsed_body.items()}
|
||||
)
|
||||
response_data.update(response_overrides)
|
||||
return (200, headers, json.dumps(response_data))
|
||||
|
||||
httpretty.register_uri(
|
||||
httpretty.POST,
|
||||
re.compile(r"http://localhost:4567/api/v1/(\w+)/threads"),
|
||||
body=callback
|
||||
body=_get_thread_callback(thread_data)
|
||||
)
|
||||
|
||||
def register_put_thread_response(self, thread_data):
|
||||
"""
|
||||
Register a mock response for PUT on the CS endpoint for the given
|
||||
thread_id.
|
||||
"""
|
||||
httpretty.register_uri(
|
||||
httpretty.PUT,
|
||||
"http://localhost:4567/api/v1/threads/{}".format(thread_data["id"]),
|
||||
body=_get_thread_callback(thread_data)
|
||||
)
|
||||
|
||||
def register_get_thread_error_response(self, thread_id, status_code):
|
||||
|
||||
@@ -17,6 +17,7 @@ from discussion_api.api import (
|
||||
get_comment_list,
|
||||
get_course_topics,
|
||||
get_thread_list,
|
||||
update_thread,
|
||||
)
|
||||
from discussion_api.forms import CommentListGetForm, ThreadListGetForm
|
||||
from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin
|
||||
@@ -67,7 +68,8 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
|
||||
"""
|
||||
**Use Cases**
|
||||
|
||||
Retrieve the list of threads for a course or post a new thread.
|
||||
Retrieve the list of threads for a course, post a new thread, or modify
|
||||
an existing thread.
|
||||
|
||||
**Example Requests**:
|
||||
|
||||
@@ -82,6 +84,9 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
|
||||
"body": "Body text"
|
||||
}
|
||||
|
||||
PATCH /api/discussion/v1/threads/thread_id
|
||||
{"raw_body": "Edited text"}
|
||||
|
||||
**GET Parameters**:
|
||||
|
||||
* course_id (required): The course to retrieve threads for
|
||||
@@ -109,16 +114,21 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
|
||||
* following (optional): A boolean indicating whether the user should
|
||||
follow the thread upon its creation; defaults to false
|
||||
|
||||
**PATCH Parameters**:
|
||||
|
||||
topic_id, type, title, and raw_body are accepted with the same meaning
|
||||
as in a POST request
|
||||
|
||||
**GET Response Values**:
|
||||
|
||||
* results: The list of threads; each item in the list has the same
|
||||
fields as the POST response below
|
||||
fields as the POST/PATCH response below
|
||||
|
||||
* next: The URL of the next page (or null if first page)
|
||||
|
||||
* previous: The URL of the previous page (or null if last page)
|
||||
|
||||
**POST response values**:
|
||||
**POST/PATCH response values**:
|
||||
|
||||
* id: The id of the thread
|
||||
|
||||
@@ -148,6 +158,8 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
|
||||
the thread
|
||||
|
||||
"""
|
||||
lookup_field = "thread_id"
|
||||
|
||||
def list(self, request):
|
||||
"""
|
||||
Implements the GET method for the list endpoint as described in the
|
||||
@@ -173,6 +185,13 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
|
||||
"""
|
||||
return Response(create_thread(request, request.DATA))
|
||||
|
||||
def partial_update(self, request, thread_id):
|
||||
"""
|
||||
Implements the PATCH method for the instance endpoint as described in
|
||||
the class docstring.
|
||||
"""
|
||||
return Response(update_thread(request, thread_id, request.DATA))
|
||||
|
||||
|
||||
class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Fields useful for edX API implementations."""
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
from rest_framework.serializers import Field
|
||||
from rest_framework.serializers import CharField, Field
|
||||
|
||||
|
||||
class ExpandableField(Field):
|
||||
@@ -20,3 +21,17 @@ class ExpandableField(Field):
|
||||
else:
|
||||
self.collapsed.initialize(self, field_name)
|
||||
return self.collapsed.field_to_native(obj, field_name)
|
||||
|
||||
|
||||
class NonEmptyCharField(CharField):
|
||||
"""
|
||||
A field that enforces non-emptiness even for partial updates.
|
||||
|
||||
This is necessary because prior to version 3, DRF skips validation for empty
|
||||
values. Thus, CharField's min_length and RegexField cannot be used to
|
||||
enforce this constraint.
|
||||
"""
|
||||
def validate(self, value):
|
||||
super(NonEmptyCharField, self).validate(value)
|
||||
if not value:
|
||||
raise ValidationError(self.error_messages["required"])
|
||||
|
||||
Reference in New Issue
Block a user