diff --git a/lms/djangoapps/discussion_api/api.py b/lms/djangoapps/discussion_api/api.py index b36ab441b2..7a03dba870 100644 --- a/lms/djangoapps/discussion_api/api.py +++ b/lms/djangoapps/discussion_api/api.py @@ -1,15 +1,19 @@ """ Discussion API internal interface """ +from django.core.exceptions import ValidationError from django.http import Http404 from collections import defaultdict +from opaque_keys.edx.locator import CourseLocator + from courseware.courses import get_course_with_access from discussion_api.pagination import get_paginated_data -from discussion_api.serializers import ThreadSerializer, get_context +from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context from django_comment_client.utils import get_accessible_discussion_modules from lms.lib.comment_client.thread import Thread +from lms.lib.comment_client.utils import CommentClientRequestError from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id from xmodule.tabs import DiscussionTab @@ -123,3 +127,84 @@ def get_thread_list(request, course_key, page, page_size): results = [ThreadSerializer(thread, context=context).data for thread in threads] return get_paginated_data(request, results, page, num_pages) + + +def get_comment_list(request, thread_id, endorsed, page, page_size): + """ + Return the list of comments in the given thread. + + Parameters: + + request: The django request object used for build_absolute_uri and + determining the requesting user. + + thread_id: The id of the thread to get comments for. + + endorsed: Boolean indicating whether to get endorsed or non-endorsed + comments (or None for all comments). Must be None for a discussion + thread and non-None for a question thread. + + page: The page number (1-indexed) to retrieve + + page_size: The number of comments to retrieve per page + + Returns: + + A paginated result containing a list of comments; see + discussion_api.views.CommentViewSet for more detail. + """ + response_skip = page_size * (page - 1) + try: + cc_thread = Thread(id=thread_id).retrieve( + recursive=True, + user_id=request.user.id, + mark_as_read=True, + response_skip=response_skip, + response_limit=page_size + ) + except CommentClientRequestError: + # page and page_size are validated at a higher level, so the only + # possible request error is if the thread doesn't exist + raise Http404 + + course_key = CourseLocator.from_string(cc_thread["course_id"]) + course = _get_course_or_404(course_key, request.user) + context = get_context(course, request.user) + + # Ensure user has access to the thread + if not context["is_requester_privileged"] and cc_thread["group_id"]: + requester_cohort = get_cohort_id(request.user, course_key) + if requester_cohort is not None and cc_thread["group_id"] != requester_cohort: + raise Http404 + + # Responses to discussion threads cannot be separated by endorsed, but + # responses to question threads must be separated by endorsed due to the + # existing comments service interface + if cc_thread["thread_type"] == "question": + if endorsed is None: + raise ValidationError({"endorsed": ["This field is required for question threads."]}) + elif endorsed: + # CS does not apply resp_skip and resp_limit to endorsed responses + # of a question post + responses = cc_thread["endorsed_responses"][response_skip:(response_skip + page_size)] + resp_total = len(cc_thread["endorsed_responses"]) + else: + responses = cc_thread["non_endorsed_responses"] + resp_total = cc_thread["non_endorsed_resp_total"] + else: + if endorsed is not None: + raise ValidationError( + {"endorsed": ["This field may not be specified for discussion threads."]} + ) + responses = cc_thread["children"] + resp_total = cc_thread["resp_total"] + + # The comments service returns the last page of results if the requested + # page is beyond the last page, but we want be consistent with DRF's general + # behavior and return a 404 in that case + if not responses and page != 1: + raise Http404 + num_pages = (resp_total + page_size - 1) / page_size if resp_total else 1 + + results = [CommentSerializer(response, context=context).data for response in responses] + return get_paginated_data(request, results, page, num_pages) diff --git a/lms/djangoapps/discussion_api/forms.py b/lms/djangoapps/discussion_api/forms.py index e25b9a53e9..3f57ea71d0 100644 --- a/lms/djangoapps/discussion_api/forms.py +++ b/lms/djangoapps/discussion_api/forms.py @@ -2,19 +2,31 @@ Discussion API forms """ from django.core.exceptions import ValidationError -from django.forms import Form, CharField, IntegerField +from django.forms import CharField, Form, IntegerField, NullBooleanField from opaque_keys import InvalidKeyError from opaque_keys.edx.locator import CourseLocator -class ThreadListGetForm(Form): +class _PaginationForm(Form): + """A form that includes pagination fields""" + page = IntegerField(required=False, min_value=1) + page_size = IntegerField(required=False, min_value=1) + + def clean_page(self): + """Return given valid page or default of 1""" + return self.cleaned_data.get("page") or 1 + + def clean_page_size(self): + """Return given valid page_size (capped at 100) or default of 10""" + return min(self.cleaned_data.get("page_size") or 10, 100) + + +class ThreadListGetForm(_PaginationForm): """ A form to validate query parameters in the thread list retrieval endpoint """ course_id = CharField() - page = IntegerField(required=False, min_value=1) - page_size = IntegerField(required=False, min_value=1) def clean_course_id(self): """Validate course_id""" @@ -24,10 +36,12 @@ class ThreadListGetForm(Form): except InvalidKeyError: raise ValidationError("'{}' is not a valid course id".format(value)) - def clean_page(self): - """Return given valid page or default of 1""" - return self.cleaned_data.get("page") or 1 - def clean_page_size(self): - """Return given valid page_size (capped at 100) or default of 10""" - return min(self.cleaned_data.get("page_size") or 10, 100) +class CommentListGetForm(_PaginationForm): + """ + A form to validate query parameters in the comment list retrieval endpoint + """ + thread_id = CharField() + # TODO: should we use something better here? This only accepts "True", + # "False", "1", and "0" + endorsed = NullBooleanField(required=False) diff --git a/lms/djangoapps/discussion_api/serializers.py b/lms/djangoapps/discussion_api/serializers.py index 727ea12b2e..db8f94c072 100644 --- a/lms/djangoapps/discussion_api/serializers.py +++ b/lms/djangoapps/discussion_api/serializers.py @@ -14,7 +14,10 @@ from openedx.core.djangoapps.course_groups.cohorts import get_cohort_names def get_context(course, requester): - """Returns a context appropriate for use with ThreadSerializer.""" + """ + Returns a context appropriate for use with ThreadSerializer or + CommentSerializer. + """ # TODO: cache staff_user_ids and ta_user_ids if we need to improve perf staff_user_ids = { user.id @@ -39,49 +42,27 @@ def get_context(course, requester): } -class ThreadSerializer(serializers.Serializer): - """ - A serializer for thread data. - - N.B. This should not be used with a comment_client Thread object that has - not had retrieve() called, because of the interaction between DRF's attempts - at introspection and Thread's __getattr__. - """ +class _ContentSerializer(serializers.Serializer): + """A base class for thread and comment serializers.""" id_ = serializers.CharField(read_only=True) - course_id = serializers.CharField() - topic_id = serializers.CharField(source="commentable_id") - group_id = serializers.IntegerField() - group_name = serializers.SerializerMethodField("get_group_name") author = serializers.SerializerMethodField("get_author") author_label = serializers.SerializerMethodField("get_author_label") created_at = serializers.CharField(read_only=True) updated_at = serializers.CharField(read_only=True) - type_ = serializers.ChoiceField(source="thread_type", choices=("discussion", "question")) - title = serializers.CharField() raw_body = serializers.CharField(source="body") - pinned = serializers.BooleanField() - closed = serializers.BooleanField() - following = serializers.SerializerMethodField("get_following") abuse_flagged = serializers.SerializerMethodField("get_abuse_flagged") voted = serializers.SerializerMethodField("get_voted") vote_count = serializers.SerializerMethodField("get_vote_count") - comment_count = serializers.IntegerField(source="comments_count") - unread_comment_count = serializers.IntegerField(source="unread_comments_count") def __init__(self, *args, **kwargs): - super(ThreadSerializer, self).__init__(*args, **kwargs) - # type and id are invalid class attribute names, so we must declare - # different names above and modify them here + super(_ContentSerializer, self).__init__(*args, **kwargs) + # id is an invalid class attribute name, so we must declare a different + # name above and modify it here self.fields["id"] = self.fields.pop("id_") - self.fields["type"] = self.fields.pop("type_") - - def get_group_name(self, obj): - """Returns the name of the group identified by the thread's group_id.""" - return self.context["group_ids_to_names"].get(obj["group_id"]) def _is_anonymous(self, obj): """ - Returns a boolean indicating whether the thread should be anonymous to + Returns a boolean indicating whether the content should be anonymous to the requester. """ return ( @@ -90,7 +71,7 @@ class ThreadSerializer(serializers.Serializer): ) def get_author(self, obj): - """Returns the author's username, or None if the thread is anonymous.""" + """Returns the author's username, or None if the content is anonymous.""" return None if self._is_anonymous(obj) else obj["username"] def _get_user_label(self, user_id): @@ -105,9 +86,58 @@ class ThreadSerializer(serializers.Serializer): ) def get_author_label(self, obj): - """Returns the role label for the thread author.""" + """Returns the role label for the content author.""" return None if self._is_anonymous(obj) else self._get_user_label(int(obj["user_id"])) + def get_abuse_flagged(self, obj): + """ + Returns a boolean indicating whether the requester has flagged the + content as abusive. + """ + return self.context["cc_requester"]["id"] in obj["abuse_flaggers"] + + def get_voted(self, obj): + """ + Returns a boolean indicating whether the requester has voted for the + content. + """ + return obj["id"] in self.context["cc_requester"]["upvoted_ids"] + + def get_vote_count(self, obj): + """Returns the number of votes for the content.""" + return obj["votes"]["up_count"] + + +class ThreadSerializer(_ContentSerializer): + """ + A serializer for thread data. + + N.B. This should not be used with a comment_client Thread object that has + not had retrieve() called, because of the interaction between DRF's attempts + at introspection and Thread's __getattr__. + """ + course_id = serializers.CharField() + topic_id = serializers.CharField(source="commentable_id") + group_id = serializers.IntegerField() + group_name = serializers.SerializerMethodField("get_group_name") + type_ = serializers.ChoiceField(source="thread_type", choices=("discussion", "question")) + title = serializers.CharField() + pinned = serializers.BooleanField() + closed = serializers.BooleanField() + following = serializers.SerializerMethodField("get_following") + comment_count = serializers.IntegerField(source="comments_count") + unread_comment_count = serializers.IntegerField(source="unread_comments_count") + + def __init__(self, *args, **kwargs): + super(ThreadSerializer, self).__init__(*args, **kwargs) + # type is an invalid class attribute name, so we must declare a + # different name above and modify it here + self.fields["type"] = self.fields.pop("type_") + + def get_group_name(self, obj): + """Returns the name of the group identified by the thread's group_id.""" + return self.context["group_ids_to_names"].get(obj["group_id"]) + def get_following(self, obj): """ Returns a boolean indicating whether the requester is following the @@ -115,20 +145,25 @@ class ThreadSerializer(serializers.Serializer): """ return obj["id"] in self.context["cc_requester"]["subscribed_thread_ids"] - def get_abuse_flagged(self, obj): - """ - Returns a boolean indicating whether the requester has flagged the - thread as abusive. - """ - return self.context["cc_requester"]["id"] in obj["abuse_flaggers"] - def get_voted(self, obj): - """ - Returns a boolean indicating whether the requester has voted for the - thread. - """ - return obj["id"] in self.context["cc_requester"]["upvoted_ids"] +class CommentSerializer(_ContentSerializer): + """ + A serializer for comment data. - def get_vote_count(self, obj): - """Returns the number of votes for the thread.""" - return obj["votes"]["up_count"] + N.B. This should not be used with a comment_client Comment object that has + not had retrieve() called, because of the interaction between DRF's attempts + at introspection and Comment's __getattr__. + """ + thread_id = serializers.CharField() + parent_id = serializers.SerializerMethodField("get_parent_id") + children = serializers.SerializerMethodField("get_children") + + def get_parent_id(self, _obj): + """Returns the comment's parent's id (taken from the context).""" + return self.context.get("parent_id") + + def get_children(self, obj): + """Returns the list of the comment's children, serialized.""" + child_context = dict(self.context) + child_context["parent_id"] = obj["id"] + return [CommentSerializer(child, context=child_context).data for child in obj["children"]] diff --git a/lms/djangoapps/discussion_api/tests/test_api.py b/lms/djangoapps/discussion_api/tests/test_api.py index 604f7c2391..e55db83369 100644 --- a/lms/djangoapps/discussion_api/tests/test_api.py +++ b/lms/djangoapps/discussion_api/tests/test_api.py @@ -9,14 +9,19 @@ import httpretty import mock from pytz import UTC +from django.core.exceptions import ValidationError from django.http import Http404 from django.test.client import RequestFactory from opaque_keys.edx.locator import CourseLocator from courseware.tests.factories import BetaTesterFactory, StaffFactory -from discussion_api.api import get_course_topics, get_thread_list -from discussion_api.tests.utils import CommentsServiceMockMixin +from discussion_api.api import get_comment_list, get_course_topics, get_thread_list +from discussion_api.tests.utils import ( + CommentsServiceMockMixin, + make_minimal_cs_comment, + make_minimal_cs_thread, +) from django_comment_common.models import ( FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_COMMUNITY_TA, @@ -378,12 +383,6 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ret = get_thread_list(self.request, course.id, page, page_size) return ret - def create_role(self, role_name, users): - """Create a Role in self.course with the given name and users""" - role = Role.objects.create(name=role_name, course_id=self.course.id) - role.users = users - role.save() - def test_nonexistent_course(self): with self.assertRaises(Http404): get_thread_list(self.request, CourseLocator.from_string("non/existent/course"), 1, 1) @@ -573,3 +572,368 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): self.register_get_threads_response([], page=3, num_pages=3) with self.assertRaises(Http404): get_thread_list(self.request, self.course.id, page=4, page_size=10) + + +@ddt.ddt +class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase): + """Test for get_comment_list""" + def setUp(self): + super(GetCommentListTest, self).setUp() + httpretty.reset() + httpretty.enable() + self.addCleanup(httpretty.disable) + self.maxDiff = None # pylint: disable=invalid-name + self.user = UserFactory.create() + self.register_get_user_response(self.user) + self.request = RequestFactory().get("/test_path") + self.request.user = self.user + self.course = CourseFactory.create() + CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) + self.author = UserFactory.create() + + def make_minimal_cs_thread(self, overrides=None): + """ + Create a thread with the given overrides, plus the course_id if not + already in overrides. + """ + overrides = overrides.copy() if overrides else {} + overrides.setdefault("course_id", unicode(self.course.id)) + return make_minimal_cs_thread(overrides) + + def get_comment_list(self, thread, endorsed=None, page=1, page_size=1): + """ + Register the appropriate comments service response, then call + get_comment_list and return the result. + """ + self.register_get_thread_response(thread) + return get_comment_list(self.request, thread["id"], endorsed, page, page_size) + + def test_nonexistent_thread(self): + thread_id = "nonexistent_thread" + self.register_get_thread_error_response(thread_id, 404) + with self.assertRaises(Http404): + get_comment_list(self.request, thread_id, endorsed=False, page=1, page_size=1) + + def test_nonexistent_course(self): + with self.assertRaises(Http404): + self.get_comment_list(self.make_minimal_cs_thread({"course_id": "non/existent/course"})) + + def test_not_enrolled(self): + self.request.user = UserFactory.create() + with self.assertRaises(Http404): + self.get_comment_list(self.make_minimal_cs_thread()) + + def test_discussions_disabled(self): + _remove_discussion_tab(self.course, self.user.id) + with self.assertRaises(Http404): + self.get_comment_list(self.make_minimal_cs_thread()) + + @ddt.data( + *itertools.product( + [ + FORUM_ROLE_ADMINISTRATOR, + FORUM_ROLE_MODERATOR, + FORUM_ROLE_COMMUNITY_TA, + FORUM_ROLE_STUDENT, + ], + [True, False], + ["no_group", "match_group", "different_group"], + ) + ) + @ddt.unpack + def test_group_access(self, role_name, course_is_cohorted, thread_group_state): + cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted}) + CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id) + cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user]) + role = Role.objects.create(name=role_name, course_id=cohort_course.id) + role.users = [self.user] + thread = self.make_minimal_cs_thread({ + "course_id": unicode(cohort_course.id), + "group_id": ( + None if thread_group_state == "no_group" else + cohort.id if thread_group_state == "match_group" else + cohort.id + 1 + ), + }) + expected_error = ( + role_name == FORUM_ROLE_STUDENT and + course_is_cohorted and + thread_group_state == "different_group" + ) + try: + self.get_comment_list(thread) + self.assertFalse(expected_error) + except Http404: + self.assertTrue(expected_error) + + @ddt.data(True, False) + def test_discussion_endorsed(self, endorsed_value): + with self.assertRaises(ValidationError) as assertion: + self.get_comment_list( + self.make_minimal_cs_thread({"thread_type": "discussion"}), + endorsed=endorsed_value + ) + self.assertEqual( + assertion.exception.message_dict, + {"endorsed": ["This field may not be specified for discussion threads."]} + ) + + def test_question_without_endorsed(self): + with self.assertRaises(ValidationError) as assertion: + self.get_comment_list( + self.make_minimal_cs_thread({"thread_type": "question"}), + endorsed=None + ) + self.assertEqual( + assertion.exception.message_dict, + {"endorsed": ["This field is required for question threads."]} + ) + + def test_empty(self): + discussion_thread = self.make_minimal_cs_thread( + {"thread_type": "discussion", "children": [], "resp_total": 0} + ) + self.assertEqual( + self.get_comment_list(discussion_thread), + {"results": [], "next": None, "previous": None} + ) + + question_thread = self.make_minimal_cs_thread({ + "thread_type": "question", + "endorsed_responses": [], + "non_endorsed_responses": [], + "non_endorsed_resp_total": 0 + }) + self.assertEqual( + self.get_comment_list(question_thread, endorsed=False), + {"results": [], "next": None, "previous": None} + ) + self.assertEqual( + self.get_comment_list(question_thread, endorsed=True), + {"results": [], "next": None, "previous": None} + ) + + def test_basic_query_params(self): + self.get_comment_list( + self.make_minimal_cs_thread({ + "children": [make_minimal_cs_comment()], + "resp_total": 71 + }), + page=6, + page_size=14 + ) + self.assert_query_params_equal( + httpretty.httpretty.latest_requests[-2], + { + "recursive": ["True"], + "user_id": [str(self.user.id)], + "mark_as_read": ["True"], + "resp_skip": ["70"], + "resp_limit": ["14"], + } + ) + + def test_discussion_content(self): + source_comments = [ + { + "id": "test_comment_1", + "thread_id": "test_thread", + "user_id": str(self.author.id), + "username": self.author.username, + "anonymous": False, + "anonymous_to_peers": False, + "created_at": "2015-05-11T00:00:00Z", + "updated_at": "2015-05-11T11:11:11Z", + "body": "Test body", + "abuse_flaggers": [], + "votes": {"up_count": 4}, + "children": [], + }, + { + "id": "test_comment_2", + "thread_id": "test_thread", + "user_id": str(self.author.id), + "username": self.author.username, + "anonymous": True, + "anonymous_to_peers": False, + "created_at": "2015-05-11T22:22:22Z", + "updated_at": "2015-05-11T33:33:33Z", + "body": "More content", + "abuse_flaggers": [str(self.user.id)], + "votes": {"up_count": 7}, + "children": [], + } + ] + expected_comments = [ + { + "id": "test_comment_1", + "thread_id": "test_thread", + "parent_id": None, + "author": self.author.username, + "author_label": None, + "created_at": "2015-05-11T00:00:00Z", + "updated_at": "2015-05-11T11:11:11Z", + "raw_body": "Test body", + "abuse_flagged": False, + "voted": False, + "vote_count": 4, + "children": [], + }, + { + "id": "test_comment_2", + "thread_id": "test_thread", + "parent_id": None, + "author": None, + "author_label": None, + "created_at": "2015-05-11T22:22:22Z", + "updated_at": "2015-05-11T33:33:33Z", + "raw_body": "More content", + "abuse_flagged": True, + "voted": False, + "vote_count": 7, + "children": [], + }, + ] + actual_comments = self.get_comment_list( + self.make_minimal_cs_thread({"children": source_comments}) + )["results"] + self.assertEqual(actual_comments, expected_comments) + + def test_question_content(self): + thread = self.make_minimal_cs_thread({ + "thread_type": "question", + "endorsed_responses": [make_minimal_cs_comment({"id": "endorsed_comment"})], + "non_endorsed_responses": [make_minimal_cs_comment({"id": "non_endorsed_comment"})], + "non_endorsed_resp_total": 1, + }) + + endorsed_actual = self.get_comment_list(thread, endorsed=True) + self.assertEqual(endorsed_actual["results"][0]["id"], "endorsed_comment") + + non_endorsed_actual = self.get_comment_list(thread, endorsed=False) + self.assertEqual(non_endorsed_actual["results"][0]["id"], "non_endorsed_comment") + + @ddt.data( + ("discussion", None, "children", "resp_total"), + ("question", False, "non_endorsed_responses", "non_endorsed_resp_total"), + ) + @ddt.unpack + def test_cs_pagination(self, thread_type, endorsed_arg, response_field, response_total_field): + """ + Test cases in which pagination is done by the comments service. + + thread_type is the type of thread (question or discussion). + endorsed_arg is the value of the endorsed argument. + repsonse_field is the field in which responses are returned for the + given thread type. + response_total_field is the field in which the total number of responses + is returned for the given thread type. + """ + # N.B. The mismatch between the number of children and the listed total + # number of responses is unrealistic but convenient for this test + thread = self.make_minimal_cs_thread({ + "thread_type": thread_type, + response_field: [make_minimal_cs_comment()], + response_total_field: 5, + }) + + # Only page + actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=1, page_size=5) + self.assertIsNone(actual["next"]) + self.assertIsNone(actual["previous"]) + + # First page of many + actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=1, page_size=2) + self.assertEqual(actual["next"], "http://testserver/test_path?page=2") + self.assertIsNone(actual["previous"]) + + # Middle page of many + actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=2, page_size=2) + self.assertEqual(actual["next"], "http://testserver/test_path?page=3") + self.assertEqual(actual["previous"], "http://testserver/test_path?page=1") + + # Last page of many + actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=3, page_size=2) + self.assertIsNone(actual["next"]) + self.assertEqual(actual["previous"], "http://testserver/test_path?page=2") + + # Page past the end + thread = self.make_minimal_cs_thread({ + "thread_type": thread_type, + response_field: [], + response_total_field: 5 + }) + with self.assertRaises(Http404): + self.get_comment_list(thread, endorsed=endorsed_arg, page=2, page_size=5) + + def test_question_endorsed_pagination(self): + thread = self.make_minimal_cs_thread({ + "thread_type": "question", + "endorsed_responses": [ + make_minimal_cs_comment({"id": "comment_{}".format(i)}) for i in range(10) + ] + }) + + def assert_page_correct(page, page_size, expected_start, expected_stop, expected_next, expected_prev): + """ + Check that requesting the given page/page_size returns the expected + output + """ + actual = self.get_comment_list(thread, endorsed=True, page=page, page_size=page_size) + result_ids = [result["id"] for result in actual["results"]] + self.assertEqual( + result_ids, + ["comment_{}".format(i) for i in range(expected_start, expected_stop)] + ) + self.assertEqual( + actual["next"], + "http://testserver/test_path?page={}".format(expected_next) if expected_next else None + ) + self.assertEqual( + actual["previous"], + "http://testserver/test_path?page={}".format(expected_prev) if expected_prev else None + ) + + # Only page + assert_page_correct( + page=1, + page_size=10, + expected_start=0, + expected_stop=10, + expected_next=None, + expected_prev=None + ) + + # First page of many + assert_page_correct( + page=1, + page_size=4, + expected_start=0, + expected_stop=4, + expected_next=2, + expected_prev=None + ) + + # Middle page of many + assert_page_correct( + page=2, + page_size=4, + expected_start=4, + expected_stop=8, + expected_next=3, + expected_prev=1 + ) + + # Last page of many + assert_page_correct( + page=3, + page_size=4, + expected_start=8, + expected_stop=10, + expected_next=None, + expected_prev=2 + ) + + # Page past the end + with self.assertRaises(Http404): + self.get_comment_list(thread, endorsed=True, page=2, page_size=10) diff --git a/lms/djangoapps/discussion_api/tests/test_forms.py b/lms/djangoapps/discussion_api/tests/test_forms.py index f988b1f484..584314d0cc 100644 --- a/lms/djangoapps/discussion_api/tests/test_forms.py +++ b/lms/djangoapps/discussion_api/tests/test_forms.py @@ -5,25 +5,17 @@ from unittest import TestCase from opaque_keys.edx.locator import CourseLocator -from discussion_api.forms import ThreadListGetForm +from discussion_api.forms import CommentListGetForm, ThreadListGetForm -class ThreadListGetFormTest(TestCase): - """Tests for ThreadListGetForm""" - def setUp(self): - super(ThreadListGetFormTest, self).setUp() - self.form_data = { - "course_id": "Foo/Bar/Baz", - "page": "2", - "page_size": "13", - } - +class FormTestMixin(object): + """A mixin for testing forms""" def get_form(self, expected_valid): """ Return a form bound to self.form_data, asserting its validity (or lack thereof) according to expected_valid """ - form = ThreadListGetForm(self.form_data) + form = self.FORM_CLASS(self.form_data) self.assertEqual(form.is_valid(), expected_valid) return form @@ -44,25 +36,9 @@ class ThreadListGetFormTest(TestCase): form = self.get_form(expected_valid=True) self.assertEqual(form.cleaned_data[field], expected_value) - def test_basic(self): - form = self.get_form(expected_valid=True) - self.assertEqual( - form.cleaned_data, - { - "course_id": CourseLocator.from_string("Foo/Bar/Baz"), - "page": 2, - "page_size": 13, - } - ) - - def test_missing_course_id(self): - self.form_data.pop("course_id") - self.assert_error("course_id", "This field is required.") - - def test_invalid_course_id(self): - self.form_data["course_id"] = "invalid course id" - self.assert_error("course_id", "'invalid course id' is not a valid course id") +class PaginationTestMixin(object): + """A mixin for testing forms with pagination fields""" def test_missing_page(self): self.form_data.pop("page") self.assert_field_value("page", 1) @@ -82,3 +58,69 @@ class ThreadListGetFormTest(TestCase): def test_excessive_page_size(self): self.form_data["page_size"] = "101" self.assert_field_value("page_size", 100) + + +class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): + """Tests for ThreadListGetForm""" + FORM_CLASS = ThreadListGetForm + + def setUp(self): + super(ThreadListGetFormTest, self).setUp() + self.form_data = { + "course_id": "Foo/Bar/Baz", + "page": "2", + "page_size": "13", + } + + def test_basic(self): + form = self.get_form(expected_valid=True) + self.assertEqual( + form.cleaned_data, + { + "course_id": CourseLocator.from_string("Foo/Bar/Baz"), + "page": 2, + "page_size": 13, + } + ) + + def test_missing_course_id(self): + self.form_data.pop("course_id") + self.assert_error("course_id", "This field is required.") + + def test_invalid_course_id(self): + self.form_data["course_id"] = "invalid course id" + self.assert_error("course_id", "'invalid course id' is not a valid course id") + + +class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): + """Tests for CommentListGetForm""" + FORM_CLASS = CommentListGetForm + + def setUp(self): + super(CommentListGetFormTest, self).setUp() + self.form_data = { + "thread_id": "deadbeef", + "endorsed": "False", + "page": "2", + "page_size": "13", + } + + def test_basic(self): + form = self.get_form(expected_valid=True) + self.assertEqual( + form.cleaned_data, + { + "thread_id": "deadbeef", + "endorsed": False, + "page": 2, + "page_size": 13, + } + ) + + def test_missing_thread_id(self): + self.form_data.pop("thread_id") + self.assert_error("thread_id", "This field is required.") + + def test_missing_endorsed(self): + self.form_data.pop("endorsed") + self.assert_field_value("endorsed", None) diff --git a/lms/djangoapps/discussion_api/tests/test_serializers.py b/lms/djangoapps/discussion_api/tests/test_serializers.py index 92ee95c6ad..7426d39aba 100644 --- a/lms/djangoapps/discussion_api/tests/test_serializers.py +++ b/lms/djangoapps/discussion_api/tests/test_serializers.py @@ -4,8 +4,12 @@ Tests for Discussion API serializers import ddt import httpretty -from discussion_api.serializers import ThreadSerializer, get_context -from discussion_api.tests.utils import CommentsServiceMockMixin +from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context +from discussion_api.tests.utils import ( + CommentsServiceMockMixin, + make_minimal_cs_thread, + make_minimal_cs_comment, +) from django_comment_common.models import ( FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_COMMUNITY_TA, @@ -20,10 +24,9 @@ from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory @ddt.ddt -class ThreadSerializerTest(CommentsServiceMockMixin, ModuleStoreTestCase): - """Tests for ThreadSerializer.""" +class SerializerTestMixin(CommentsServiceMockMixin): def setUp(self): - super(ThreadSerializerTest, self).setUp() + super(SerializerTestMixin, self).setUp() httpretty.reset() httpretty.enable() self.addCleanup(httpretty.disable) @@ -39,37 +42,93 @@ class ThreadSerializerTest(CommentsServiceMockMixin, ModuleStoreTestCase): role = Role.objects.create(name=role_name, course_id=course.id) role.users = users - def make_cs_thread(self, thread_data=None): + @ddt.data( + (FORUM_ROLE_ADMINISTRATOR, True, False, True), + (FORUM_ROLE_ADMINISTRATOR, False, True, False), + (FORUM_ROLE_MODERATOR, True, False, True), + (FORUM_ROLE_MODERATOR, False, True, False), + (FORUM_ROLE_COMMUNITY_TA, True, False, True), + (FORUM_ROLE_COMMUNITY_TA, False, True, False), + (FORUM_ROLE_STUDENT, True, False, True), + (FORUM_ROLE_STUDENT, False, True, True), + ) + @ddt.unpack + def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_serialized_anonymous): """ - Create a dictionary containing all needed thread fields as returned by - the comments service with dummy data overridden by thread_data + Test that content is properly made anonymous. + + Content should be anonymous iff the anonymous field is true or the + anonymous_to_peers field is true and the requester does not have a + privileged role. + + role_name is the name of the requester's role. + anonymous is the value of the anonymous field in the content. + anonymous_to_peers is the value of the anonymous_to_peers field in the + content. + expected_serialized_anonymous is whether the content should actually be + anonymous in the API output when requested by a user with the given + role. """ - ret = { - "id": "dummy", + self.create_role(role_name, [self.user]) + serialized = self.serialize( + self.make_cs_content({"anonymous": anonymous, "anonymous_to_peers": anonymous_to_peers}) + ) + actual_serialized_anonymous = serialized["author"] is None + self.assertEqual(actual_serialized_anonymous, expected_serialized_anonymous) + + @ddt.data( + (FORUM_ROLE_ADMINISTRATOR, False, "staff"), + (FORUM_ROLE_ADMINISTRATOR, True, None), + (FORUM_ROLE_MODERATOR, False, "staff"), + (FORUM_ROLE_MODERATOR, True, None), + (FORUM_ROLE_COMMUNITY_TA, False, "community_ta"), + (FORUM_ROLE_COMMUNITY_TA, True, None), + (FORUM_ROLE_STUDENT, False, None), + (FORUM_ROLE_STUDENT, True, None), + ) + @ddt.unpack + def test_author_labels(self, role_name, anonymous, expected_label): + """ + Test correctness of the author_label field. + + The label should be "staff", "staff", or "community_ta" for the + Administrator, Moderator, and Community TA roles, respectively, but + the label should not be present if the content is anonymous. + + role_name is the name of the author's role. + anonymous is the value of the anonymous field in the content. + expected_label is the expected value of the author_label field in the + API output. + """ + self.create_role(role_name, [self.author]) + serialized = self.serialize(self.make_cs_content({"anonymous": anonymous})) + self.assertEqual(serialized["author_label"], expected_label) + + def test_abuse_flagged(self): + serialized = self.serialize(self.make_cs_content({"abuse_flaggers": [str(self.user.id)]})) + self.assertEqual(serialized["abuse_flagged"], True) + + def test_voted(self): + thread_id = "test_thread" + self.register_get_user_response(self.user, upvoted_ids=[thread_id]) + serialized = self.serialize(self.make_cs_content({"id": thread_id})) + self.assertEqual(serialized["voted"], True) + + +@ddt.ddt +class ThreadSerializerTest(SerializerTestMixin, ModuleStoreTestCase): + """Tests for ThreadSerializer.""" + def make_cs_content(self, overrides): + """ + Create a thread with the given overrides, plus some useful test data. + """ + merged_overrides = { "course_id": unicode(self.course.id), - "commentable_id": "dummy", - "group_id": None, "user_id": str(self.author.id), "username": self.author.username, - "anonymous": False, - "anonymous_to_peers": False, - "created_at": "1970-01-01T00:00:00Z", - "updated_at": "1970-01-01T00:00:00Z", - "thread_type": "discussion", - "title": "dummy", - "body": "dummy", - "pinned": False, - "closed": False, - "abuse_flaggers": [], - "votes": {"up_count": 0}, - "comments_count": 0, - "unread_comments_count": 0, - "children": [], - "resp_total": 0, } - if thread_data: - ret.update(thread_data) - return ret + merged_overrides.update(overrides) + return make_minimal_cs_thread(merged_overrides) def serialize(self, thread): """ @@ -126,84 +185,86 @@ class ThreadSerializerTest(CommentsServiceMockMixin, ModuleStoreTestCase): def test_group(self): cohort = CohortFactory.create(course_id=self.course.id) - serialized = self.serialize(self.make_cs_thread({"group_id": cohort.id})) + serialized = self.serialize(self.make_cs_content({"group_id": cohort.id})) self.assertEqual(serialized["group_id"], cohort.id) self.assertEqual(serialized["group_name"], cohort.name) - @ddt.data( - (FORUM_ROLE_ADMINISTRATOR, True, False, True), - (FORUM_ROLE_ADMINISTRATOR, False, True, False), - (FORUM_ROLE_MODERATOR, True, False, True), - (FORUM_ROLE_MODERATOR, False, True, False), - (FORUM_ROLE_COMMUNITY_TA, True, False, True), - (FORUM_ROLE_COMMUNITY_TA, False, True, False), - (FORUM_ROLE_STUDENT, True, False, True), - (FORUM_ROLE_STUDENT, False, True, True), - ) - @ddt.unpack - def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_serialized_anonymous): - """ - Test that content is properly made anonymous. - - Content should be anonymous iff the anonymous field is true or the - anonymous_to_peers field is true and the requester does not have a - privileged role. - - role_name is the name of the requester's role. - anonymous is the value of the anonymous field in the content. - anonymous_to_peers is the value of the anonymous_to_peers field in the - content. - expected_serialized_anonymous is whether the content should actually be - anonymous in the API output when requested by a user with the given - role. - """ - self.create_role(role_name, [self.user]) - serialized = self.serialize( - self.make_cs_thread({"anonymous": anonymous, "anonymous_to_peers": anonymous_to_peers}) - ) - actual_serialized_anonymous = serialized["author"] is None - self.assertEqual(actual_serialized_anonymous, expected_serialized_anonymous) - - @ddt.data( - (FORUM_ROLE_ADMINISTRATOR, False, "staff"), - (FORUM_ROLE_ADMINISTRATOR, True, None), - (FORUM_ROLE_MODERATOR, False, "staff"), - (FORUM_ROLE_MODERATOR, True, None), - (FORUM_ROLE_COMMUNITY_TA, False, "community_ta"), - (FORUM_ROLE_COMMUNITY_TA, True, None), - (FORUM_ROLE_STUDENT, False, None), - (FORUM_ROLE_STUDENT, True, None), - ) - @ddt.unpack - def test_author_labels(self, role_name, anonymous, expected_label): - """ - Test correctness of the author_label field. - - The label should be "staff", "staff", or "community_ta" for the - Administrator, Moderator, and Community TA roles, respectively, but - the label should not be present if the thread is anonymous. - - role_name is the name of the author's role. - anonymous is the value of the anonymous field in the content. - expected_label is the expected value of the author_label field in the - API output. - """ - self.create_role(role_name, [self.author]) - serialized = self.serialize(self.make_cs_thread({"anonymous": anonymous})) - self.assertEqual(serialized["author_label"], expected_label) - def test_following(self): thread_id = "test_thread" self.register_get_user_response(self.user, subscribed_thread_ids=[thread_id]) - serialized = self.serialize(self.make_cs_thread({"id": thread_id})) + serialized = self.serialize(self.make_cs_content({"id": thread_id})) self.assertEqual(serialized["following"], True) - def test_abuse_flagged(self): - serialized = self.serialize(self.make_cs_thread({"abuse_flaggers": [str(self.user.id)]})) - self.assertEqual(serialized["abuse_flagged"], True) - def test_voted(self): - thread_id = "test_thread" - self.register_get_user_response(self.user, upvoted_ids=[thread_id]) - serialized = self.serialize(self.make_cs_thread({"id": thread_id})) - self.assertEqual(serialized["voted"], True) +@ddt.ddt +class CommentSerializerTest(SerializerTestMixin, ModuleStoreTestCase): + """Tests for CommentSerializer.""" + def make_cs_content(self, overrides): + """ + Create a comment with the given overrides, plus some useful test data. + """ + merged_overrides = { + "user_id": str(self.author.id), + "username": self.author.username + } + merged_overrides.update(overrides) + return make_minimal_cs_comment(merged_overrides) + + def serialize(self, comment): + """ + Create a serializer with an appropriate context and use it to serialize + the given comment, returning the result. + """ + return CommentSerializer(comment, context=get_context(self.course, self.user)).data + + def test_basic(self): + comment = { + "id": "test_comment", + "thread_id": "test_thread", + "user_id": str(self.author.id), + "username": self.author.username, + "anonymous": False, + "anonymous_to_peers": False, + "created_at": "2015-04-28T00:00:00Z", + "updated_at": "2015-04-28T11:11:11Z", + "body": "Test body", + "abuse_flaggers": [], + "votes": {"up_count": 4}, + "children": [], + } + expected = { + "id": "test_comment", + "thread_id": "test_thread", + "parent_id": None, + "author": self.author.username, + "author_label": None, + "created_at": "2015-04-28T00:00:00Z", + "updated_at": "2015-04-28T11:11:11Z", + "raw_body": "Test body", + "abuse_flagged": False, + "voted": False, + "vote_count": 4, + "children": [], + } + self.assertEqual(self.serialize(comment), expected) + + def test_children(self): + comment = self.make_cs_content({ + "id": "test_root", + "children": [ + self.make_cs_content({ + "id": "test_child_1", + }), + self.make_cs_content({ + "id": "test_child_2", + "children": [self.make_cs_content({"id": "test_grandchild"})], + }), + ], + }) + serialized = self.serialize(comment) + self.assertEqual(serialized["children"][0]["id"], "test_child_1") + self.assertEqual(serialized["children"][0]["parent_id"], "test_root") + self.assertEqual(serialized["children"][1]["id"], "test_child_2") + self.assertEqual(serialized["children"][1]["parent_id"], "test_root") + self.assertEqual(serialized["children"][1]["children"][0]["id"], "test_grandchild") + self.assertEqual(serialized["children"][1]["children"][0]["parent_id"], "test_child_2") diff --git a/lms/djangoapps/discussion_api/tests/test_views.py b/lms/djangoapps/discussion_api/tests/test_views.py index a9c4e1a99d..1fea0ac3e3 100644 --- a/lms/djangoapps/discussion_api/tests/test_views.py +++ b/lms/djangoapps/discussion_api/tests/test_views.py @@ -10,13 +10,11 @@ from pytz import UTC from django.core.urlresolvers import reverse -from discussion_api.tests.utils import CommentsServiceMockMixin +from discussion_api.tests.utils import CommentsServiceMockMixin, make_minimal_cs_thread from student.tests.factories import CourseEnrollmentFactory, UserFactory from util.testing import UrlResetMixin -from xmodule.modulestore.django import modulestore from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.factories import CourseFactory -from xmodule.tabs import DiscussionTab class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin): @@ -201,3 +199,125 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): "per_page": ["4"], "recursive": ["False"], }) + + +@httpretty.activate +class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): + """Tests for CommentViewSet list""" + def setUp(self): + super(CommentViewSetListTest, self).setUp() + self.author = UserFactory.create() + self.url = reverse("comment-list") + self.thread_id = "test_thread" + + def test_thread_id_missing(self): + response = self.client.get(self.url) + self.assert_response_correct( + response, + 400, + {"field_errors": {"thread_id": "This field is required."}} + ) + + def test_404(self): + self.register_get_thread_error_response(self.thread_id, 404) + response = self.client.get(self.url, {"thread_id": self.thread_id}) + self.assert_response_correct( + response, + 404, + {"developer_message": "Not found."} + ) + + def test_basic(self): + self.register_get_user_response(self.user, upvoted_ids=["test_comment"]) + source_comments = [{ + "id": "test_comment", + "thread_id": self.thread_id, + "parent_id": None, + "user_id": str(self.author.id), + "username": self.author.username, + "anonymous": False, + "anonymous_to_peers": False, + "created_at": "2015-05-11T00:00:00Z", + "updated_at": "2015-05-11T11:11:11Z", + "body": "Test body", + "abuse_flaggers": [], + "votes": {"up_count": 4}, + "children": [], + }] + expected_comments = [{ + "id": "test_comment", + "thread_id": self.thread_id, + "parent_id": None, + "author": self.author.username, + "author_label": None, + "created_at": "2015-05-11T00:00:00Z", + "updated_at": "2015-05-11T11:11:11Z", + "raw_body": "Test body", + "abuse_flagged": False, + "voted": True, + "vote_count": 4, + "children": [], + }] + self.register_get_thread_response({ + "id": self.thread_id, + "course_id": unicode(self.course.id), + "thread_type": "discussion", + "children": source_comments, + "resp_total": 100, + }) + response = self.client.get(self.url, {"thread_id": self.thread_id}) + self.assert_response_correct( + response, + 200, + { + "results": expected_comments, + "next": "http://testserver/api/discussion/v1/comments/?thread_id={}&page=2".format( + self.thread_id + ), + "previous": None, + } + ) + self.assert_query_params_equal( + httpretty.httpretty.latest_requests[-2], + { + "recursive": ["True"], + "resp_skip": ["0"], + "resp_limit": ["10"], + "user_id": [str(self.user.id)], + "mark_as_read": ["True"], + } + ) + + def test_pagination(self): + """ + Test that pagination parameters are correctly plumbed through to the + comments service and that a 404 is correctly returned if a page past the + end is requested + """ + self.register_get_user_response(self.user) + self.register_get_thread_response(make_minimal_cs_thread({ + "id": self.thread_id, + "course_id": unicode(self.course.id), + "thread_type": "discussion", + "children": [], + "resp_total": 10, + })) + response = self.client.get( + self.url, + {"thread_id": self.thread_id, "page": "18", "page_size": "4"} + ) + self.assert_response_correct( + response, + 404, + {"developer_message": "Not found."} + ) + self.assert_query_params_equal( + httpretty.httpretty.latest_requests[-2], + { + "recursive": ["True"], + "resp_skip": ["68"], + "resp_limit": ["4"], + "user_id": [str(self.user.id)], + "mark_as_read": ["True"], + } + ) diff --git a/lms/djangoapps/discussion_api/tests/utils.py b/lms/djangoapps/discussion_api/tests/utils.py index b931d7dbc8..8982abffcb 100644 --- a/lms/djangoapps/discussion_api/tests/utils.py +++ b/lms/djangoapps/discussion_api/tests/utils.py @@ -21,6 +21,26 @@ class CommentsServiceMockMixin(object): status=200 ) + def register_get_thread_error_response(self, thread_id, status_code): + """Register a mock error response for GET on the CS thread endpoint.""" + httpretty.register_uri( + httpretty.GET, + "http://localhost:4567/api/v1/threads/{id}".format(id=thread_id), + body="", + status=status_code + ) + + def register_get_thread_response(self, thread): + """ + Register a mock response for GET on the CS thread instance endpoint. + """ + httpretty.register_uri( + httpretty.GET, + "http://localhost:4567/api/v1/threads/{id}".format(id=thread["id"]), + body=json.dumps(thread), + status=200 + ) + def register_get_user_response(self, user, subscribed_thread_ids=None, upvoted_ids=None): """Register a mock response for GET on the CS user instance endpoint""" httpretty.register_uri( @@ -34,10 +54,73 @@ class CommentsServiceMockMixin(object): status=200 ) + def assert_query_params_equal(self, httpretty_request, expected_params): + """ + Assert that the given mock request had the expected query parameters + """ + actual_params = dict(httpretty_request.querystring) + actual_params.pop("request_id") # request_id is random + self.assertEqual(actual_params, expected_params) + def assert_last_query_params(self, expected_params): """ Assert that the last mock request had the expected query parameters """ - actual_params = dict(httpretty.last_request().querystring) - actual_params.pop("request_id") # request_id is random - self.assertEqual(actual_params, expected_params) + self.assert_query_params_equal(httpretty.last_request(), expected_params) + + +def make_minimal_cs_thread(overrides=None): + """ + Create a dictionary containing all needed thread fields as returned by the + comments service with dummy data and optional overrides + """ + ret = { + "id": "dummy", + "course_id": "dummy/dummy/dummy", + "commentable_id": "dummy", + "group_id": None, + "user_id": "0", + "username": "dummy", + "anonymous": False, + "anonymous_to_peers": False, + "created_at": "1970-01-01T00:00:00Z", + "updated_at": "1970-01-01T00:00:00Z", + "thread_type": "discussion", + "title": "dummy", + "body": "dummy", + "pinned": False, + "closed": False, + "abuse_flaggers": [], + "votes": {"up_count": 0}, + "comments_count": 0, + "unread_comments_count": 0, + "children": [], + "resp_total": 0, + } + ret.update(overrides or {}) + return ret + + +def make_minimal_cs_comment(overrides=None): + """ + Create a dictionary containing all needed comment fields as returned by the + comments service with dummy data and optional overrides + """ + ret = { + "id": "dummy", + "thread_id": "dummy", + "user_id": "0", + "username": "dummy", + "anonymous": False, + "anonymous_to_peers": False, + "created_at": "1970-01-01T00:00:00Z", + "updated_at": "1970-01-01T00:00:00Z", + "body": "dummy", + "abuse_flaggers": [], + "votes": {"up_count": 0}, + "endorsed": False, + "endorsement": None, + "children": [], + } + ret.update(overrides or {}) + return ret diff --git a/lms/djangoapps/discussion_api/urls.py b/lms/djangoapps/discussion_api/urls.py index 040ad0e20a..b2808c5567 100644 --- a/lms/djangoapps/discussion_api/urls.py +++ b/lms/djangoapps/discussion_api/urls.py @@ -6,11 +6,12 @@ from django.conf.urls import include, patterns, url from rest_framework.routers import SimpleRouter -from discussion_api.views import CourseTopicsView, ThreadViewSet +from discussion_api.views import CommentViewSet, CourseTopicsView, ThreadViewSet ROUTER = SimpleRouter() ROUTER.register("threads", ThreadViewSet, base_name="thread") +ROUTER.register("comments", CommentViewSet, base_name="comment") urlpatterns = patterns( "discussion_api", diff --git a/lms/djangoapps/discussion_api/views.py b/lms/djangoapps/discussion_api/views.py index cc42f5b4ec..f201a8f408 100644 --- a/lms/djangoapps/discussion_api/views.py +++ b/lms/djangoapps/discussion_api/views.py @@ -11,8 +11,8 @@ from rest_framework.viewsets import ViewSet from opaque_keys.edx.locator import CourseLocator -from discussion_api.api import get_course_topics, get_thread_list -from discussion_api.forms import ThreadListGetForm +from discussion_api.api import get_comment_list, get_course_topics, get_thread_list +from discussion_api.forms import CommentListGetForm, ThreadListGetForm from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin @@ -126,3 +126,82 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): form.cleaned_data["page_size"] ) ) + + +class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): + """ + **Use Cases** + + Retrieve the list of comments in a thread. + + **Example Requests**: + + GET /api/discussion/v1/comments/?thread_id=0123456789abcdef01234567 + + **GET Parameters**: + + * thread_id (required): The thread to retrieve comments for + + * endorsed: If specified, only retrieve the endorsed or non-endorsed + comments accordingly. Required for a question thread, must be absent + for a discussion thread. + + * page: The (1-indexed) page to retrieve (default is 1) + + * page_size: The number of items per page (default is 10, max is 100) + + **Response Values**: + + * results: The list of comments. Each item in the list includes: + + * id: The id of the comment + + * thread_id: The id of the comment's thread + + * parent_id: The id of the comment's parent + + * author: The username of the comment's author, or None if the + comment is anonymous + + * author_label: A label indicating whether the author has a special + role in the course, either "staff" for moderators and + administrators or "community_ta" for community TAs + + * created_at: The ISO 8601 timestamp for the creation of the comment + + * updated_at: The ISO 8601 timestamp for the last modification of + the comment, which may not have been an update of the body + + * raw_body: The comment's raw body text without any rendering applied + + * abuse_flagged: Boolean indicating whether the requesting user has + flagged the comment for abuse + + * voted: Boolean indicating whether the requesting user has voted + for the comment + + * vote_count: The number of votes for the comment + + * children: The list of child comments (with the same format) + + * next: The URL of the next page (or null if first page) + + * previous: The URL of the previous page (or null if last page) + """ + def list(self, request): + """ + Implements the GET method for the list endpoint as described in the + class docstring. + """ + form = CommentListGetForm(request.GET) + if not form.is_valid(): + raise ValidationError(form.errors) + return Response( + get_comment_list( + request, + form.cleaned_data["thread_id"], + form.cleaned_data["endorsed"], + form.cleaned_data["page"], + form.cleaned_data["page_size"] + ) + )