Merge pull request #8097 from edx/gprice/discussion-api-comment-list

Add comment list endpoint to Discussion API
This commit is contained in:
Greg Price
2015-05-20 14:20:24 -04:00
10 changed files with 1090 additions and 206 deletions

View File

@@ -1,15 +1,19 @@
"""
Discussion API internal interface
"""
from django.core.exceptions import ValidationError
from django.http import Http404
from collections import defaultdict
from opaque_keys.edx.locator import CourseLocator
from courseware.courses import get_course_with_access
from discussion_api.pagination import get_paginated_data
from discussion_api.serializers import ThreadSerializer, get_context
from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context
from django_comment_client.utils import get_accessible_discussion_modules
from lms.lib.comment_client.thread import Thread
from lms.lib.comment_client.utils import CommentClientRequestError
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id
from xmodule.tabs import DiscussionTab
@@ -123,3 +127,84 @@ def get_thread_list(request, course_key, page, page_size):
results = [ThreadSerializer(thread, context=context).data for thread in threads]
return get_paginated_data(request, results, page, num_pages)
def get_comment_list(request, thread_id, endorsed, page, page_size):
"""
Return the list of comments in the given thread.
Parameters:
request: The django request object used for build_absolute_uri and
determining the requesting user.
thread_id: The id of the thread to get comments for.
endorsed: Boolean indicating whether to get endorsed or non-endorsed
comments (or None for all comments). Must be None for a discussion
thread and non-None for a question thread.
page: The page number (1-indexed) to retrieve
page_size: The number of comments to retrieve per page
Returns:
A paginated result containing a list of comments; see
discussion_api.views.CommentViewSet for more detail.
"""
response_skip = page_size * (page - 1)
try:
cc_thread = Thread(id=thread_id).retrieve(
recursive=True,
user_id=request.user.id,
mark_as_read=True,
response_skip=response_skip,
response_limit=page_size
)
except CommentClientRequestError:
# page and page_size are validated at a higher level, so the only
# possible request error is if the thread doesn't exist
raise Http404
course_key = CourseLocator.from_string(cc_thread["course_id"])
course = _get_course_or_404(course_key, request.user)
context = get_context(course, request.user)
# Ensure user has access to the thread
if not context["is_requester_privileged"] and cc_thread["group_id"]:
requester_cohort = get_cohort_id(request.user, course_key)
if requester_cohort is not None and cc_thread["group_id"] != requester_cohort:
raise Http404
# Responses to discussion threads cannot be separated by endorsed, but
# responses to question threads must be separated by endorsed due to the
# existing comments service interface
if cc_thread["thread_type"] == "question":
if endorsed is None:
raise ValidationError({"endorsed": ["This field is required for question threads."]})
elif endorsed:
# CS does not apply resp_skip and resp_limit to endorsed responses
# of a question post
responses = cc_thread["endorsed_responses"][response_skip:(response_skip + page_size)]
resp_total = len(cc_thread["endorsed_responses"])
else:
responses = cc_thread["non_endorsed_responses"]
resp_total = cc_thread["non_endorsed_resp_total"]
else:
if endorsed is not None:
raise ValidationError(
{"endorsed": ["This field may not be specified for discussion threads."]}
)
responses = cc_thread["children"]
resp_total = cc_thread["resp_total"]
# The comments service returns the last page of results if the requested
# page is beyond the last page, but we want be consistent with DRF's general
# behavior and return a 404 in that case
if not responses and page != 1:
raise Http404
num_pages = (resp_total + page_size - 1) / page_size if resp_total else 1
results = [CommentSerializer(response, context=context).data for response in responses]
return get_paginated_data(request, results, page, num_pages)

View File

@@ -2,19 +2,31 @@
Discussion API forms
"""
from django.core.exceptions import ValidationError
from django.forms import Form, CharField, IntegerField
from django.forms import CharField, Form, IntegerField, NullBooleanField
from opaque_keys import InvalidKeyError
from opaque_keys.edx.locator import CourseLocator
class ThreadListGetForm(Form):
class _PaginationForm(Form):
"""A form that includes pagination fields"""
page = IntegerField(required=False, min_value=1)
page_size = IntegerField(required=False, min_value=1)
def clean_page(self):
"""Return given valid page or default of 1"""
return self.cleaned_data.get("page") or 1
def clean_page_size(self):
"""Return given valid page_size (capped at 100) or default of 10"""
return min(self.cleaned_data.get("page_size") or 10, 100)
class ThreadListGetForm(_PaginationForm):
"""
A form to validate query parameters in the thread list retrieval endpoint
"""
course_id = CharField()
page = IntegerField(required=False, min_value=1)
page_size = IntegerField(required=False, min_value=1)
def clean_course_id(self):
"""Validate course_id"""
@@ -24,10 +36,12 @@ class ThreadListGetForm(Form):
except InvalidKeyError:
raise ValidationError("'{}' is not a valid course id".format(value))
def clean_page(self):
"""Return given valid page or default of 1"""
return self.cleaned_data.get("page") or 1
def clean_page_size(self):
"""Return given valid page_size (capped at 100) or default of 10"""
return min(self.cleaned_data.get("page_size") or 10, 100)
class CommentListGetForm(_PaginationForm):
"""
A form to validate query parameters in the comment list retrieval endpoint
"""
thread_id = CharField()
# TODO: should we use something better here? This only accepts "True",
# "False", "1", and "0"
endorsed = NullBooleanField(required=False)

View File

@@ -14,7 +14,10 @@ from openedx.core.djangoapps.course_groups.cohorts import get_cohort_names
def get_context(course, requester):
"""Returns a context appropriate for use with ThreadSerializer."""
"""
Returns a context appropriate for use with ThreadSerializer or
CommentSerializer.
"""
# TODO: cache staff_user_ids and ta_user_ids if we need to improve perf
staff_user_ids = {
user.id
@@ -39,49 +42,27 @@ def get_context(course, requester):
}
class ThreadSerializer(serializers.Serializer):
"""
A serializer for thread data.
N.B. This should not be used with a comment_client Thread object that has
not had retrieve() called, because of the interaction between DRF's attempts
at introspection and Thread's __getattr__.
"""
class _ContentSerializer(serializers.Serializer):
"""A base class for thread and comment serializers."""
id_ = serializers.CharField(read_only=True)
course_id = serializers.CharField()
topic_id = serializers.CharField(source="commentable_id")
group_id = serializers.IntegerField()
group_name = serializers.SerializerMethodField("get_group_name")
author = serializers.SerializerMethodField("get_author")
author_label = serializers.SerializerMethodField("get_author_label")
created_at = serializers.CharField(read_only=True)
updated_at = serializers.CharField(read_only=True)
type_ = serializers.ChoiceField(source="thread_type", choices=("discussion", "question"))
title = serializers.CharField()
raw_body = serializers.CharField(source="body")
pinned = serializers.BooleanField()
closed = serializers.BooleanField()
following = serializers.SerializerMethodField("get_following")
abuse_flagged = serializers.SerializerMethodField("get_abuse_flagged")
voted = serializers.SerializerMethodField("get_voted")
vote_count = serializers.SerializerMethodField("get_vote_count")
comment_count = serializers.IntegerField(source="comments_count")
unread_comment_count = serializers.IntegerField(source="unread_comments_count")
def __init__(self, *args, **kwargs):
super(ThreadSerializer, self).__init__(*args, **kwargs)
# type and id are invalid class attribute names, so we must declare
# different names above and modify them here
super(_ContentSerializer, self).__init__(*args, **kwargs)
# id is an invalid class attribute name, so we must declare a different
# name above and modify it here
self.fields["id"] = self.fields.pop("id_")
self.fields["type"] = self.fields.pop("type_")
def get_group_name(self, obj):
"""Returns the name of the group identified by the thread's group_id."""
return self.context["group_ids_to_names"].get(obj["group_id"])
def _is_anonymous(self, obj):
"""
Returns a boolean indicating whether the thread should be anonymous to
Returns a boolean indicating whether the content should be anonymous to
the requester.
"""
return (
@@ -90,7 +71,7 @@ class ThreadSerializer(serializers.Serializer):
)
def get_author(self, obj):
"""Returns the author's username, or None if the thread is anonymous."""
"""Returns the author's username, or None if the content is anonymous."""
return None if self._is_anonymous(obj) else obj["username"]
def _get_user_label(self, user_id):
@@ -105,9 +86,58 @@ class ThreadSerializer(serializers.Serializer):
)
def get_author_label(self, obj):
"""Returns the role label for the thread author."""
"""Returns the role label for the content author."""
return None if self._is_anonymous(obj) else self._get_user_label(int(obj["user_id"]))
def get_abuse_flagged(self, obj):
"""
Returns a boolean indicating whether the requester has flagged the
content as abusive.
"""
return self.context["cc_requester"]["id"] in obj["abuse_flaggers"]
def get_voted(self, obj):
"""
Returns a boolean indicating whether the requester has voted for the
content.
"""
return obj["id"] in self.context["cc_requester"]["upvoted_ids"]
def get_vote_count(self, obj):
"""Returns the number of votes for the content."""
return obj["votes"]["up_count"]
class ThreadSerializer(_ContentSerializer):
"""
A serializer for thread data.
N.B. This should not be used with a comment_client Thread object that has
not had retrieve() called, because of the interaction between DRF's attempts
at introspection and Thread's __getattr__.
"""
course_id = serializers.CharField()
topic_id = serializers.CharField(source="commentable_id")
group_id = serializers.IntegerField()
group_name = serializers.SerializerMethodField("get_group_name")
type_ = serializers.ChoiceField(source="thread_type", choices=("discussion", "question"))
title = serializers.CharField()
pinned = serializers.BooleanField()
closed = serializers.BooleanField()
following = serializers.SerializerMethodField("get_following")
comment_count = serializers.IntegerField(source="comments_count")
unread_comment_count = serializers.IntegerField(source="unread_comments_count")
def __init__(self, *args, **kwargs):
super(ThreadSerializer, self).__init__(*args, **kwargs)
# type is an invalid class attribute name, so we must declare a
# different name above and modify it here
self.fields["type"] = self.fields.pop("type_")
def get_group_name(self, obj):
"""Returns the name of the group identified by the thread's group_id."""
return self.context["group_ids_to_names"].get(obj["group_id"])
def get_following(self, obj):
"""
Returns a boolean indicating whether the requester is following the
@@ -115,20 +145,25 @@ class ThreadSerializer(serializers.Serializer):
"""
return obj["id"] in self.context["cc_requester"]["subscribed_thread_ids"]
def get_abuse_flagged(self, obj):
"""
Returns a boolean indicating whether the requester has flagged the
thread as abusive.
"""
return self.context["cc_requester"]["id"] in obj["abuse_flaggers"]
def get_voted(self, obj):
"""
Returns a boolean indicating whether the requester has voted for the
thread.
"""
return obj["id"] in self.context["cc_requester"]["upvoted_ids"]
class CommentSerializer(_ContentSerializer):
"""
A serializer for comment data.
def get_vote_count(self, obj):
"""Returns the number of votes for the thread."""
return obj["votes"]["up_count"]
N.B. This should not be used with a comment_client Comment object that has
not had retrieve() called, because of the interaction between DRF's attempts
at introspection and Comment's __getattr__.
"""
thread_id = serializers.CharField()
parent_id = serializers.SerializerMethodField("get_parent_id")
children = serializers.SerializerMethodField("get_children")
def get_parent_id(self, _obj):
"""Returns the comment's parent's id (taken from the context)."""
return self.context.get("parent_id")
def get_children(self, obj):
"""Returns the list of the comment's children, serialized."""
child_context = dict(self.context)
child_context["parent_id"] = obj["id"]
return [CommentSerializer(child, context=child_context).data for child in obj["children"]]

View File

@@ -9,14 +9,19 @@ import httpretty
import mock
from pytz import UTC
from django.core.exceptions import ValidationError
from django.http import Http404
from django.test.client import RequestFactory
from opaque_keys.edx.locator import CourseLocator
from courseware.tests.factories import BetaTesterFactory, StaffFactory
from discussion_api.api import get_course_topics, get_thread_list
from discussion_api.tests.utils import CommentsServiceMockMixin
from discussion_api.api import get_comment_list, get_course_topics, get_thread_list
from discussion_api.tests.utils import (
CommentsServiceMockMixin,
make_minimal_cs_comment,
make_minimal_cs_thread,
)
from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_COMMUNITY_TA,
@@ -378,12 +383,6 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
ret = get_thread_list(self.request, course.id, page, page_size)
return ret
def create_role(self, role_name, users):
"""Create a Role in self.course with the given name and users"""
role = Role.objects.create(name=role_name, course_id=self.course.id)
role.users = users
role.save()
def test_nonexistent_course(self):
with self.assertRaises(Http404):
get_thread_list(self.request, CourseLocator.from_string("non/existent/course"), 1, 1)
@@ -573,3 +572,368 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
self.register_get_threads_response([], page=3, num_pages=3)
with self.assertRaises(Http404):
get_thread_list(self.request, self.course.id, page=4, page_size=10)
@ddt.ddt
class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"""Test for get_comment_list"""
def setUp(self):
super(GetCommentListTest, self).setUp()
httpretty.reset()
httpretty.enable()
self.addCleanup(httpretty.disable)
self.maxDiff = None # pylint: disable=invalid-name
self.user = UserFactory.create()
self.register_get_user_response(self.user)
self.request = RequestFactory().get("/test_path")
self.request.user = self.user
self.course = CourseFactory.create()
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
self.author = UserFactory.create()
def make_minimal_cs_thread(self, overrides=None):
"""
Create a thread with the given overrides, plus the course_id if not
already in overrides.
"""
overrides = overrides.copy() if overrides else {}
overrides.setdefault("course_id", unicode(self.course.id))
return make_minimal_cs_thread(overrides)
def get_comment_list(self, thread, endorsed=None, page=1, page_size=1):
"""
Register the appropriate comments service response, then call
get_comment_list and return the result.
"""
self.register_get_thread_response(thread)
return get_comment_list(self.request, thread["id"], endorsed, page, page_size)
def test_nonexistent_thread(self):
thread_id = "nonexistent_thread"
self.register_get_thread_error_response(thread_id, 404)
with self.assertRaises(Http404):
get_comment_list(self.request, thread_id, endorsed=False, page=1, page_size=1)
def test_nonexistent_course(self):
with self.assertRaises(Http404):
self.get_comment_list(self.make_minimal_cs_thread({"course_id": "non/existent/course"}))
def test_not_enrolled(self):
self.request.user = UserFactory.create()
with self.assertRaises(Http404):
self.get_comment_list(self.make_minimal_cs_thread())
def test_discussions_disabled(self):
_remove_discussion_tab(self.course, self.user.id)
with self.assertRaises(Http404):
self.get_comment_list(self.make_minimal_cs_thread())
@ddt.data(
*itertools.product(
[
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_MODERATOR,
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_STUDENT,
],
[True, False],
["no_group", "match_group", "different_group"],
)
)
@ddt.unpack
def test_group_access(self, role_name, course_is_cohorted, thread_group_state):
cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted})
CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id)
cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user])
role = Role.objects.create(name=role_name, course_id=cohort_course.id)
role.users = [self.user]
thread = self.make_minimal_cs_thread({
"course_id": unicode(cohort_course.id),
"group_id": (
None if thread_group_state == "no_group" else
cohort.id if thread_group_state == "match_group" else
cohort.id + 1
),
})
expected_error = (
role_name == FORUM_ROLE_STUDENT and
course_is_cohorted and
thread_group_state == "different_group"
)
try:
self.get_comment_list(thread)
self.assertFalse(expected_error)
except Http404:
self.assertTrue(expected_error)
@ddt.data(True, False)
def test_discussion_endorsed(self, endorsed_value):
with self.assertRaises(ValidationError) as assertion:
self.get_comment_list(
self.make_minimal_cs_thread({"thread_type": "discussion"}),
endorsed=endorsed_value
)
self.assertEqual(
assertion.exception.message_dict,
{"endorsed": ["This field may not be specified for discussion threads."]}
)
def test_question_without_endorsed(self):
with self.assertRaises(ValidationError) as assertion:
self.get_comment_list(
self.make_minimal_cs_thread({"thread_type": "question"}),
endorsed=None
)
self.assertEqual(
assertion.exception.message_dict,
{"endorsed": ["This field is required for question threads."]}
)
def test_empty(self):
discussion_thread = self.make_minimal_cs_thread(
{"thread_type": "discussion", "children": [], "resp_total": 0}
)
self.assertEqual(
self.get_comment_list(discussion_thread),
{"results": [], "next": None, "previous": None}
)
question_thread = self.make_minimal_cs_thread({
"thread_type": "question",
"endorsed_responses": [],
"non_endorsed_responses": [],
"non_endorsed_resp_total": 0
})
self.assertEqual(
self.get_comment_list(question_thread, endorsed=False),
{"results": [], "next": None, "previous": None}
)
self.assertEqual(
self.get_comment_list(question_thread, endorsed=True),
{"results": [], "next": None, "previous": None}
)
def test_basic_query_params(self):
self.get_comment_list(
self.make_minimal_cs_thread({
"children": [make_minimal_cs_comment()],
"resp_total": 71
}),
page=6,
page_size=14
)
self.assert_query_params_equal(
httpretty.httpretty.latest_requests[-2],
{
"recursive": ["True"],
"user_id": [str(self.user.id)],
"mark_as_read": ["True"],
"resp_skip": ["70"],
"resp_limit": ["14"],
}
)
def test_discussion_content(self):
source_comments = [
{
"id": "test_comment_1",
"thread_id": "test_thread",
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-05-11T00:00:00Z",
"updated_at": "2015-05-11T11:11:11Z",
"body": "Test body",
"abuse_flaggers": [],
"votes": {"up_count": 4},
"children": [],
},
{
"id": "test_comment_2",
"thread_id": "test_thread",
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": True,
"anonymous_to_peers": False,
"created_at": "2015-05-11T22:22:22Z",
"updated_at": "2015-05-11T33:33:33Z",
"body": "More content",
"abuse_flaggers": [str(self.user.id)],
"votes": {"up_count": 7},
"children": [],
}
]
expected_comments = [
{
"id": "test_comment_1",
"thread_id": "test_thread",
"parent_id": None,
"author": self.author.username,
"author_label": None,
"created_at": "2015-05-11T00:00:00Z",
"updated_at": "2015-05-11T11:11:11Z",
"raw_body": "Test body",
"abuse_flagged": False,
"voted": False,
"vote_count": 4,
"children": [],
},
{
"id": "test_comment_2",
"thread_id": "test_thread",
"parent_id": None,
"author": None,
"author_label": None,
"created_at": "2015-05-11T22:22:22Z",
"updated_at": "2015-05-11T33:33:33Z",
"raw_body": "More content",
"abuse_flagged": True,
"voted": False,
"vote_count": 7,
"children": [],
},
]
actual_comments = self.get_comment_list(
self.make_minimal_cs_thread({"children": source_comments})
)["results"]
self.assertEqual(actual_comments, expected_comments)
def test_question_content(self):
thread = self.make_minimal_cs_thread({
"thread_type": "question",
"endorsed_responses": [make_minimal_cs_comment({"id": "endorsed_comment"})],
"non_endorsed_responses": [make_minimal_cs_comment({"id": "non_endorsed_comment"})],
"non_endorsed_resp_total": 1,
})
endorsed_actual = self.get_comment_list(thread, endorsed=True)
self.assertEqual(endorsed_actual["results"][0]["id"], "endorsed_comment")
non_endorsed_actual = self.get_comment_list(thread, endorsed=False)
self.assertEqual(non_endorsed_actual["results"][0]["id"], "non_endorsed_comment")
@ddt.data(
("discussion", None, "children", "resp_total"),
("question", False, "non_endorsed_responses", "non_endorsed_resp_total"),
)
@ddt.unpack
def test_cs_pagination(self, thread_type, endorsed_arg, response_field, response_total_field):
"""
Test cases in which pagination is done by the comments service.
thread_type is the type of thread (question or discussion).
endorsed_arg is the value of the endorsed argument.
repsonse_field is the field in which responses are returned for the
given thread type.
response_total_field is the field in which the total number of responses
is returned for the given thread type.
"""
# N.B. The mismatch between the number of children and the listed total
# number of responses is unrealistic but convenient for this test
thread = self.make_minimal_cs_thread({
"thread_type": thread_type,
response_field: [make_minimal_cs_comment()],
response_total_field: 5,
})
# Only page
actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=1, page_size=5)
self.assertIsNone(actual["next"])
self.assertIsNone(actual["previous"])
# First page of many
actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=1, page_size=2)
self.assertEqual(actual["next"], "http://testserver/test_path?page=2")
self.assertIsNone(actual["previous"])
# Middle page of many
actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=2, page_size=2)
self.assertEqual(actual["next"], "http://testserver/test_path?page=3")
self.assertEqual(actual["previous"], "http://testserver/test_path?page=1")
# Last page of many
actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=3, page_size=2)
self.assertIsNone(actual["next"])
self.assertEqual(actual["previous"], "http://testserver/test_path?page=2")
# Page past the end
thread = self.make_minimal_cs_thread({
"thread_type": thread_type,
response_field: [],
response_total_field: 5
})
with self.assertRaises(Http404):
self.get_comment_list(thread, endorsed=endorsed_arg, page=2, page_size=5)
def test_question_endorsed_pagination(self):
thread = self.make_minimal_cs_thread({
"thread_type": "question",
"endorsed_responses": [
make_minimal_cs_comment({"id": "comment_{}".format(i)}) for i in range(10)
]
})
def assert_page_correct(page, page_size, expected_start, expected_stop, expected_next, expected_prev):
"""
Check that requesting the given page/page_size returns the expected
output
"""
actual = self.get_comment_list(thread, endorsed=True, page=page, page_size=page_size)
result_ids = [result["id"] for result in actual["results"]]
self.assertEqual(
result_ids,
["comment_{}".format(i) for i in range(expected_start, expected_stop)]
)
self.assertEqual(
actual["next"],
"http://testserver/test_path?page={}".format(expected_next) if expected_next else None
)
self.assertEqual(
actual["previous"],
"http://testserver/test_path?page={}".format(expected_prev) if expected_prev else None
)
# Only page
assert_page_correct(
page=1,
page_size=10,
expected_start=0,
expected_stop=10,
expected_next=None,
expected_prev=None
)
# First page of many
assert_page_correct(
page=1,
page_size=4,
expected_start=0,
expected_stop=4,
expected_next=2,
expected_prev=None
)
# Middle page of many
assert_page_correct(
page=2,
page_size=4,
expected_start=4,
expected_stop=8,
expected_next=3,
expected_prev=1
)
# Last page of many
assert_page_correct(
page=3,
page_size=4,
expected_start=8,
expected_stop=10,
expected_next=None,
expected_prev=2
)
# Page past the end
with self.assertRaises(Http404):
self.get_comment_list(thread, endorsed=True, page=2, page_size=10)

View File

@@ -5,25 +5,17 @@ from unittest import TestCase
from opaque_keys.edx.locator import CourseLocator
from discussion_api.forms import ThreadListGetForm
from discussion_api.forms import CommentListGetForm, ThreadListGetForm
class ThreadListGetFormTest(TestCase):
"""Tests for ThreadListGetForm"""
def setUp(self):
super(ThreadListGetFormTest, self).setUp()
self.form_data = {
"course_id": "Foo/Bar/Baz",
"page": "2",
"page_size": "13",
}
class FormTestMixin(object):
"""A mixin for testing forms"""
def get_form(self, expected_valid):
"""
Return a form bound to self.form_data, asserting its validity (or lack
thereof) according to expected_valid
"""
form = ThreadListGetForm(self.form_data)
form = self.FORM_CLASS(self.form_data)
self.assertEqual(form.is_valid(), expected_valid)
return form
@@ -44,25 +36,9 @@ class ThreadListGetFormTest(TestCase):
form = self.get_form(expected_valid=True)
self.assertEqual(form.cleaned_data[field], expected_value)
def test_basic(self):
form = self.get_form(expected_valid=True)
self.assertEqual(
form.cleaned_data,
{
"course_id": CourseLocator.from_string("Foo/Bar/Baz"),
"page": 2,
"page_size": 13,
}
)
def test_missing_course_id(self):
self.form_data.pop("course_id")
self.assert_error("course_id", "This field is required.")
def test_invalid_course_id(self):
self.form_data["course_id"] = "invalid course id"
self.assert_error("course_id", "'invalid course id' is not a valid course id")
class PaginationTestMixin(object):
"""A mixin for testing forms with pagination fields"""
def test_missing_page(self):
self.form_data.pop("page")
self.assert_field_value("page", 1)
@@ -82,3 +58,69 @@ class ThreadListGetFormTest(TestCase):
def test_excessive_page_size(self):
self.form_data["page_size"] = "101"
self.assert_field_value("page_size", 100)
class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"""Tests for ThreadListGetForm"""
FORM_CLASS = ThreadListGetForm
def setUp(self):
super(ThreadListGetFormTest, self).setUp()
self.form_data = {
"course_id": "Foo/Bar/Baz",
"page": "2",
"page_size": "13",
}
def test_basic(self):
form = self.get_form(expected_valid=True)
self.assertEqual(
form.cleaned_data,
{
"course_id": CourseLocator.from_string("Foo/Bar/Baz"),
"page": 2,
"page_size": 13,
}
)
def test_missing_course_id(self):
self.form_data.pop("course_id")
self.assert_error("course_id", "This field is required.")
def test_invalid_course_id(self):
self.form_data["course_id"] = "invalid course id"
self.assert_error("course_id", "'invalid course id' is not a valid course id")
class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"""Tests for CommentListGetForm"""
FORM_CLASS = CommentListGetForm
def setUp(self):
super(CommentListGetFormTest, self).setUp()
self.form_data = {
"thread_id": "deadbeef",
"endorsed": "False",
"page": "2",
"page_size": "13",
}
def test_basic(self):
form = self.get_form(expected_valid=True)
self.assertEqual(
form.cleaned_data,
{
"thread_id": "deadbeef",
"endorsed": False,
"page": 2,
"page_size": 13,
}
)
def test_missing_thread_id(self):
self.form_data.pop("thread_id")
self.assert_error("thread_id", "This field is required.")
def test_missing_endorsed(self):
self.form_data.pop("endorsed")
self.assert_field_value("endorsed", None)

View File

@@ -4,8 +4,12 @@ Tests for Discussion API serializers
import ddt
import httpretty
from discussion_api.serializers import ThreadSerializer, get_context
from discussion_api.tests.utils import CommentsServiceMockMixin
from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context
from discussion_api.tests.utils import (
CommentsServiceMockMixin,
make_minimal_cs_thread,
make_minimal_cs_comment,
)
from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_COMMUNITY_TA,
@@ -20,10 +24,9 @@ from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory
@ddt.ddt
class ThreadSerializerTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"""Tests for ThreadSerializer."""
class SerializerTestMixin(CommentsServiceMockMixin):
def setUp(self):
super(ThreadSerializerTest, self).setUp()
super(SerializerTestMixin, self).setUp()
httpretty.reset()
httpretty.enable()
self.addCleanup(httpretty.disable)
@@ -39,37 +42,93 @@ class ThreadSerializerTest(CommentsServiceMockMixin, ModuleStoreTestCase):
role = Role.objects.create(name=role_name, course_id=course.id)
role.users = users
def make_cs_thread(self, thread_data=None):
@ddt.data(
(FORUM_ROLE_ADMINISTRATOR, True, False, True),
(FORUM_ROLE_ADMINISTRATOR, False, True, False),
(FORUM_ROLE_MODERATOR, True, False, True),
(FORUM_ROLE_MODERATOR, False, True, False),
(FORUM_ROLE_COMMUNITY_TA, True, False, True),
(FORUM_ROLE_COMMUNITY_TA, False, True, False),
(FORUM_ROLE_STUDENT, True, False, True),
(FORUM_ROLE_STUDENT, False, True, True),
)
@ddt.unpack
def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_serialized_anonymous):
"""
Create a dictionary containing all needed thread fields as returned by
the comments service with dummy data overridden by thread_data
Test that content is properly made anonymous.
Content should be anonymous iff the anonymous field is true or the
anonymous_to_peers field is true and the requester does not have a
privileged role.
role_name is the name of the requester's role.
anonymous is the value of the anonymous field in the content.
anonymous_to_peers is the value of the anonymous_to_peers field in the
content.
expected_serialized_anonymous is whether the content should actually be
anonymous in the API output when requested by a user with the given
role.
"""
ret = {
"id": "dummy",
self.create_role(role_name, [self.user])
serialized = self.serialize(
self.make_cs_content({"anonymous": anonymous, "anonymous_to_peers": anonymous_to_peers})
)
actual_serialized_anonymous = serialized["author"] is None
self.assertEqual(actual_serialized_anonymous, expected_serialized_anonymous)
@ddt.data(
(FORUM_ROLE_ADMINISTRATOR, False, "staff"),
(FORUM_ROLE_ADMINISTRATOR, True, None),
(FORUM_ROLE_MODERATOR, False, "staff"),
(FORUM_ROLE_MODERATOR, True, None),
(FORUM_ROLE_COMMUNITY_TA, False, "community_ta"),
(FORUM_ROLE_COMMUNITY_TA, True, None),
(FORUM_ROLE_STUDENT, False, None),
(FORUM_ROLE_STUDENT, True, None),
)
@ddt.unpack
def test_author_labels(self, role_name, anonymous, expected_label):
"""
Test correctness of the author_label field.
The label should be "staff", "staff", or "community_ta" for the
Administrator, Moderator, and Community TA roles, respectively, but
the label should not be present if the content is anonymous.
role_name is the name of the author's role.
anonymous is the value of the anonymous field in the content.
expected_label is the expected value of the author_label field in the
API output.
"""
self.create_role(role_name, [self.author])
serialized = self.serialize(self.make_cs_content({"anonymous": anonymous}))
self.assertEqual(serialized["author_label"], expected_label)
def test_abuse_flagged(self):
serialized = self.serialize(self.make_cs_content({"abuse_flaggers": [str(self.user.id)]}))
self.assertEqual(serialized["abuse_flagged"], True)
def test_voted(self):
thread_id = "test_thread"
self.register_get_user_response(self.user, upvoted_ids=[thread_id])
serialized = self.serialize(self.make_cs_content({"id": thread_id}))
self.assertEqual(serialized["voted"], True)
@ddt.ddt
class ThreadSerializerTest(SerializerTestMixin, ModuleStoreTestCase):
"""Tests for ThreadSerializer."""
def make_cs_content(self, overrides):
"""
Create a thread with the given overrides, plus some useful test data.
"""
merged_overrides = {
"course_id": unicode(self.course.id),
"commentable_id": "dummy",
"group_id": None,
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "1970-01-01T00:00:00Z",
"updated_at": "1970-01-01T00:00:00Z",
"thread_type": "discussion",
"title": "dummy",
"body": "dummy",
"pinned": False,
"closed": False,
"abuse_flaggers": [],
"votes": {"up_count": 0},
"comments_count": 0,
"unread_comments_count": 0,
"children": [],
"resp_total": 0,
}
if thread_data:
ret.update(thread_data)
return ret
merged_overrides.update(overrides)
return make_minimal_cs_thread(merged_overrides)
def serialize(self, thread):
"""
@@ -126,84 +185,86 @@ class ThreadSerializerTest(CommentsServiceMockMixin, ModuleStoreTestCase):
def test_group(self):
cohort = CohortFactory.create(course_id=self.course.id)
serialized = self.serialize(self.make_cs_thread({"group_id": cohort.id}))
serialized = self.serialize(self.make_cs_content({"group_id": cohort.id}))
self.assertEqual(serialized["group_id"], cohort.id)
self.assertEqual(serialized["group_name"], cohort.name)
@ddt.data(
(FORUM_ROLE_ADMINISTRATOR, True, False, True),
(FORUM_ROLE_ADMINISTRATOR, False, True, False),
(FORUM_ROLE_MODERATOR, True, False, True),
(FORUM_ROLE_MODERATOR, False, True, False),
(FORUM_ROLE_COMMUNITY_TA, True, False, True),
(FORUM_ROLE_COMMUNITY_TA, False, True, False),
(FORUM_ROLE_STUDENT, True, False, True),
(FORUM_ROLE_STUDENT, False, True, True),
)
@ddt.unpack
def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_serialized_anonymous):
"""
Test that content is properly made anonymous.
Content should be anonymous iff the anonymous field is true or the
anonymous_to_peers field is true and the requester does not have a
privileged role.
role_name is the name of the requester's role.
anonymous is the value of the anonymous field in the content.
anonymous_to_peers is the value of the anonymous_to_peers field in the
content.
expected_serialized_anonymous is whether the content should actually be
anonymous in the API output when requested by a user with the given
role.
"""
self.create_role(role_name, [self.user])
serialized = self.serialize(
self.make_cs_thread({"anonymous": anonymous, "anonymous_to_peers": anonymous_to_peers})
)
actual_serialized_anonymous = serialized["author"] is None
self.assertEqual(actual_serialized_anonymous, expected_serialized_anonymous)
@ddt.data(
(FORUM_ROLE_ADMINISTRATOR, False, "staff"),
(FORUM_ROLE_ADMINISTRATOR, True, None),
(FORUM_ROLE_MODERATOR, False, "staff"),
(FORUM_ROLE_MODERATOR, True, None),
(FORUM_ROLE_COMMUNITY_TA, False, "community_ta"),
(FORUM_ROLE_COMMUNITY_TA, True, None),
(FORUM_ROLE_STUDENT, False, None),
(FORUM_ROLE_STUDENT, True, None),
)
@ddt.unpack
def test_author_labels(self, role_name, anonymous, expected_label):
"""
Test correctness of the author_label field.
The label should be "staff", "staff", or "community_ta" for the
Administrator, Moderator, and Community TA roles, respectively, but
the label should not be present if the thread is anonymous.
role_name is the name of the author's role.
anonymous is the value of the anonymous field in the content.
expected_label is the expected value of the author_label field in the
API output.
"""
self.create_role(role_name, [self.author])
serialized = self.serialize(self.make_cs_thread({"anonymous": anonymous}))
self.assertEqual(serialized["author_label"], expected_label)
def test_following(self):
thread_id = "test_thread"
self.register_get_user_response(self.user, subscribed_thread_ids=[thread_id])
serialized = self.serialize(self.make_cs_thread({"id": thread_id}))
serialized = self.serialize(self.make_cs_content({"id": thread_id}))
self.assertEqual(serialized["following"], True)
def test_abuse_flagged(self):
serialized = self.serialize(self.make_cs_thread({"abuse_flaggers": [str(self.user.id)]}))
self.assertEqual(serialized["abuse_flagged"], True)
def test_voted(self):
thread_id = "test_thread"
self.register_get_user_response(self.user, upvoted_ids=[thread_id])
serialized = self.serialize(self.make_cs_thread({"id": thread_id}))
self.assertEqual(serialized["voted"], True)
@ddt.ddt
class CommentSerializerTest(SerializerTestMixin, ModuleStoreTestCase):
"""Tests for CommentSerializer."""
def make_cs_content(self, overrides):
"""
Create a comment with the given overrides, plus some useful test data.
"""
merged_overrides = {
"user_id": str(self.author.id),
"username": self.author.username
}
merged_overrides.update(overrides)
return make_minimal_cs_comment(merged_overrides)
def serialize(self, comment):
"""
Create a serializer with an appropriate context and use it to serialize
the given comment, returning the result.
"""
return CommentSerializer(comment, context=get_context(self.course, self.user)).data
def test_basic(self):
comment = {
"id": "test_comment",
"thread_id": "test_thread",
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"body": "Test body",
"abuse_flaggers": [],
"votes": {"up_count": 4},
"children": [],
}
expected = {
"id": "test_comment",
"thread_id": "test_thread",
"parent_id": None,
"author": self.author.username,
"author_label": None,
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"raw_body": "Test body",
"abuse_flagged": False,
"voted": False,
"vote_count": 4,
"children": [],
}
self.assertEqual(self.serialize(comment), expected)
def test_children(self):
comment = self.make_cs_content({
"id": "test_root",
"children": [
self.make_cs_content({
"id": "test_child_1",
}),
self.make_cs_content({
"id": "test_child_2",
"children": [self.make_cs_content({"id": "test_grandchild"})],
}),
],
})
serialized = self.serialize(comment)
self.assertEqual(serialized["children"][0]["id"], "test_child_1")
self.assertEqual(serialized["children"][0]["parent_id"], "test_root")
self.assertEqual(serialized["children"][1]["id"], "test_child_2")
self.assertEqual(serialized["children"][1]["parent_id"], "test_root")
self.assertEqual(serialized["children"][1]["children"][0]["id"], "test_grandchild")
self.assertEqual(serialized["children"][1]["children"][0]["parent_id"], "test_child_2")

View File

@@ -10,13 +10,11 @@ from pytz import UTC
from django.core.urlresolvers import reverse
from discussion_api.tests.utils import CommentsServiceMockMixin
from discussion_api.tests.utils import CommentsServiceMockMixin, make_minimal_cs_thread
from student.tests.factories import CourseEnrollmentFactory, UserFactory
from util.testing import UrlResetMixin
from xmodule.modulestore.django import modulestore
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory
from xmodule.tabs import DiscussionTab
class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin):
@@ -201,3 +199,125 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"per_page": ["4"],
"recursive": ["False"],
})
@httpretty.activate
class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"""Tests for CommentViewSet list"""
def setUp(self):
super(CommentViewSetListTest, self).setUp()
self.author = UserFactory.create()
self.url = reverse("comment-list")
self.thread_id = "test_thread"
def test_thread_id_missing(self):
response = self.client.get(self.url)
self.assert_response_correct(
response,
400,
{"field_errors": {"thread_id": "This field is required."}}
)
def test_404(self):
self.register_get_thread_error_response(self.thread_id, 404)
response = self.client.get(self.url, {"thread_id": self.thread_id})
self.assert_response_correct(
response,
404,
{"developer_message": "Not found."}
)
def test_basic(self):
self.register_get_user_response(self.user, upvoted_ids=["test_comment"])
source_comments = [{
"id": "test_comment",
"thread_id": self.thread_id,
"parent_id": None,
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-05-11T00:00:00Z",
"updated_at": "2015-05-11T11:11:11Z",
"body": "Test body",
"abuse_flaggers": [],
"votes": {"up_count": 4},
"children": [],
}]
expected_comments = [{
"id": "test_comment",
"thread_id": self.thread_id,
"parent_id": None,
"author": self.author.username,
"author_label": None,
"created_at": "2015-05-11T00:00:00Z",
"updated_at": "2015-05-11T11:11:11Z",
"raw_body": "Test body",
"abuse_flagged": False,
"voted": True,
"vote_count": 4,
"children": [],
}]
self.register_get_thread_response({
"id": self.thread_id,
"course_id": unicode(self.course.id),
"thread_type": "discussion",
"children": source_comments,
"resp_total": 100,
})
response = self.client.get(self.url, {"thread_id": self.thread_id})
self.assert_response_correct(
response,
200,
{
"results": expected_comments,
"next": "http://testserver/api/discussion/v1/comments/?thread_id={}&page=2".format(
self.thread_id
),
"previous": None,
}
)
self.assert_query_params_equal(
httpretty.httpretty.latest_requests[-2],
{
"recursive": ["True"],
"resp_skip": ["0"],
"resp_limit": ["10"],
"user_id": [str(self.user.id)],
"mark_as_read": ["True"],
}
)
def test_pagination(self):
"""
Test that pagination parameters are correctly plumbed through to the
comments service and that a 404 is correctly returned if a page past the
end is requested
"""
self.register_get_user_response(self.user)
self.register_get_thread_response(make_minimal_cs_thread({
"id": self.thread_id,
"course_id": unicode(self.course.id),
"thread_type": "discussion",
"children": [],
"resp_total": 10,
}))
response = self.client.get(
self.url,
{"thread_id": self.thread_id, "page": "18", "page_size": "4"}
)
self.assert_response_correct(
response,
404,
{"developer_message": "Not found."}
)
self.assert_query_params_equal(
httpretty.httpretty.latest_requests[-2],
{
"recursive": ["True"],
"resp_skip": ["68"],
"resp_limit": ["4"],
"user_id": [str(self.user.id)],
"mark_as_read": ["True"],
}
)

View File

@@ -21,6 +21,26 @@ class CommentsServiceMockMixin(object):
status=200
)
def register_get_thread_error_response(self, thread_id, status_code):
"""Register a mock error response for GET on the CS thread endpoint."""
httpretty.register_uri(
httpretty.GET,
"http://localhost:4567/api/v1/threads/{id}".format(id=thread_id),
body="",
status=status_code
)
def register_get_thread_response(self, thread):
"""
Register a mock response for GET on the CS thread instance endpoint.
"""
httpretty.register_uri(
httpretty.GET,
"http://localhost:4567/api/v1/threads/{id}".format(id=thread["id"]),
body=json.dumps(thread),
status=200
)
def register_get_user_response(self, user, subscribed_thread_ids=None, upvoted_ids=None):
"""Register a mock response for GET on the CS user instance endpoint"""
httpretty.register_uri(
@@ -34,10 +54,73 @@ class CommentsServiceMockMixin(object):
status=200
)
def assert_query_params_equal(self, httpretty_request, expected_params):
"""
Assert that the given mock request had the expected query parameters
"""
actual_params = dict(httpretty_request.querystring)
actual_params.pop("request_id") # request_id is random
self.assertEqual(actual_params, expected_params)
def assert_last_query_params(self, expected_params):
"""
Assert that the last mock request had the expected query parameters
"""
actual_params = dict(httpretty.last_request().querystring)
actual_params.pop("request_id") # request_id is random
self.assertEqual(actual_params, expected_params)
self.assert_query_params_equal(httpretty.last_request(), expected_params)
def make_minimal_cs_thread(overrides=None):
"""
Create a dictionary containing all needed thread fields as returned by the
comments service with dummy data and optional overrides
"""
ret = {
"id": "dummy",
"course_id": "dummy/dummy/dummy",
"commentable_id": "dummy",
"group_id": None,
"user_id": "0",
"username": "dummy",
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "1970-01-01T00:00:00Z",
"updated_at": "1970-01-01T00:00:00Z",
"thread_type": "discussion",
"title": "dummy",
"body": "dummy",
"pinned": False,
"closed": False,
"abuse_flaggers": [],
"votes": {"up_count": 0},
"comments_count": 0,
"unread_comments_count": 0,
"children": [],
"resp_total": 0,
}
ret.update(overrides or {})
return ret
def make_minimal_cs_comment(overrides=None):
"""
Create a dictionary containing all needed comment fields as returned by the
comments service with dummy data and optional overrides
"""
ret = {
"id": "dummy",
"thread_id": "dummy",
"user_id": "0",
"username": "dummy",
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "1970-01-01T00:00:00Z",
"updated_at": "1970-01-01T00:00:00Z",
"body": "dummy",
"abuse_flaggers": [],
"votes": {"up_count": 0},
"endorsed": False,
"endorsement": None,
"children": [],
}
ret.update(overrides or {})
return ret

View File

@@ -6,11 +6,12 @@ from django.conf.urls import include, patterns, url
from rest_framework.routers import SimpleRouter
from discussion_api.views import CourseTopicsView, ThreadViewSet
from discussion_api.views import CommentViewSet, CourseTopicsView, ThreadViewSet
ROUTER = SimpleRouter()
ROUTER.register("threads", ThreadViewSet, base_name="thread")
ROUTER.register("comments", CommentViewSet, base_name="comment")
urlpatterns = patterns(
"discussion_api",

View File

@@ -11,8 +11,8 @@ from rest_framework.viewsets import ViewSet
from opaque_keys.edx.locator import CourseLocator
from discussion_api.api import get_course_topics, get_thread_list
from discussion_api.forms import ThreadListGetForm
from discussion_api.api import get_comment_list, get_course_topics, get_thread_list
from discussion_api.forms import CommentListGetForm, ThreadListGetForm
from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin
@@ -126,3 +126,82 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
form.cleaned_data["page_size"]
)
)
class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
"""
**Use Cases**
Retrieve the list of comments in a thread.
**Example Requests**:
GET /api/discussion/v1/comments/?thread_id=0123456789abcdef01234567
**GET Parameters**:
* thread_id (required): The thread to retrieve comments for
* endorsed: If specified, only retrieve the endorsed or non-endorsed
comments accordingly. Required for a question thread, must be absent
for a discussion thread.
* page: The (1-indexed) page to retrieve (default is 1)
* page_size: The number of items per page (default is 10, max is 100)
**Response Values**:
* results: The list of comments. Each item in the list includes:
* id: The id of the comment
* thread_id: The id of the comment's thread
* parent_id: The id of the comment's parent
* author: The username of the comment's author, or None if the
comment is anonymous
* author_label: A label indicating whether the author has a special
role in the course, either "staff" for moderators and
administrators or "community_ta" for community TAs
* created_at: The ISO 8601 timestamp for the creation of the comment
* updated_at: The ISO 8601 timestamp for the last modification of
the comment, which may not have been an update of the body
* raw_body: The comment's raw body text without any rendering applied
* abuse_flagged: Boolean indicating whether the requesting user has
flagged the comment for abuse
* voted: Boolean indicating whether the requesting user has voted
for the comment
* vote_count: The number of votes for the comment
* children: The list of child comments (with the same format)
* next: The URL of the next page (or null if first page)
* previous: The URL of the previous page (or null if last page)
"""
def list(self, request):
"""
Implements the GET method for the list endpoint as described in the
class docstring.
"""
form = CommentListGetForm(request.GET)
if not form.is_valid():
raise ValidationError(form.errors)
return Response(
get_comment_list(
request,
form.cleaned_data["thread_id"],
form.cleaned_data["endorsed"],
form.cleaned_data["page"],
form.cleaned_data["page_size"]
)
)