diff --git a/lms/djangoapps/discussion_api/api.py b/lms/djangoapps/discussion_api/api.py index 21034f5aa7..46919297fa 100644 --- a/lms/djangoapps/discussion_api/api.py +++ b/lms/djangoapps/discussion_api/api.py @@ -1,8 +1,12 @@ """ Discussion API internal interface """ +from django.http import Http404 + from collections import defaultdict +from lms.lib.comment_client.thread import Thread +from discussion_api.pagination import get_paginated_data from django_comment_client.utils import get_accessible_discussion_modules @@ -63,3 +67,62 @@ def get_course_topics(course, user): "courseware_topics": courseware_topics, "non_courseware_topics": non_courseware_topics, } + + +def _cc_thread_to_api_thread(thread): + """ + Convert a thread data dict from the comment_client format (which is a direct + representation of the format returned by the comments service) to the format + used in this API + """ + ret = { + key: thread[key] + for key in [ + "id", + "course_id", + "created_at", + "updated_at", + "type", + "title", + "pinned", + "closed", + ] + } + ret.update({ + "topic_id": thread["commentable_id"], + "raw_body": thread["body"], + "comment_count": thread["comments_count"], + "unread_comment_count": thread["unread_comments_count"], + }) + return ret + + +def get_thread_list(request, course_key, page, page_size): + """ + Return the list of all discussion threads pertaining to the given course + + Parameters: + + request: The django request objects used for build_absolute_uri + course_key: The key of the course to get discussion threads for + page: The page number (1-indexed) to retrieve + page_size: The number of threads to retrieve per page + + Returns: + + A paginated result containing a list of threads; see + discussion_api.views.ThreadViewSet for more detail. + """ + threads, result_page, num_pages, _ = Thread.search({ + "course_id": unicode(course_key), + "page": page, + "per_page": page_size + }) + # The comments service returns the last page of results if the requested + # page is beyond the last page, but we want be consistent with DRF's general + # behavior and return a 404 in that case + if result_page != page: + raise Http404 + + results = [_cc_thread_to_api_thread(thread) for thread in threads] + return get_paginated_data(request, results, page, num_pages) diff --git a/lms/djangoapps/discussion_api/forms.py b/lms/djangoapps/discussion_api/forms.py new file mode 100644 index 0000000000..e25b9a53e9 --- /dev/null +++ b/lms/djangoapps/discussion_api/forms.py @@ -0,0 +1,33 @@ +""" +Discussion API forms +""" +from django.core.exceptions import ValidationError +from django.forms import Form, CharField, IntegerField + +from opaque_keys import InvalidKeyError +from opaque_keys.edx.locator import CourseLocator + + +class ThreadListGetForm(Form): + """ + A form to validate query parameters in the thread list retrieval endpoint + """ + course_id = CharField() + page = IntegerField(required=False, min_value=1) + page_size = IntegerField(required=False, min_value=1) + + def clean_course_id(self): + """Validate course_id""" + value = self.cleaned_data["course_id"] + try: + return CourseLocator.from_string(value) + except InvalidKeyError: + raise ValidationError("'{}' is not a valid course id".format(value)) + + def clean_page(self): + """Return given valid page or default of 1""" + return self.cleaned_data.get("page") or 1 + + def clean_page_size(self): + """Return given valid page_size (capped at 100) or default of 10""" + return min(self.cleaned_data.get("page_size") or 10, 100) diff --git a/lms/djangoapps/discussion_api/pagination.py b/lms/djangoapps/discussion_api/pagination.py new file mode 100644 index 0000000000..5f14b430e8 --- /dev/null +++ b/lms/djangoapps/discussion_api/pagination.py @@ -0,0 +1,58 @@ +""" +Discussion API pagination support +""" +from rest_framework.pagination import BasePaginationSerializer, NextPageField, PreviousPageField + + +class _PaginationSerializer(BasePaginationSerializer): + """ + A pagination serializer without the count field, because the Comments + Service does not return result counts + """ + next = NextPageField(source="*") + previous = PreviousPageField(source="*") + + +class _Page(object): + """ + Implements just enough of the django.core.paginator.Page interface to allow + PaginationSerializer to work. + """ + def __init__(self, object_list, page_num, num_pages): + """ + Create a new page containing the given objects, with the given page + number and number of pages + """ + self.object_list = object_list + self.page_num = page_num + self.num_pages = num_pages + + def has_next(self): + """Returns True if there is a page after this one, otherwise False""" + return self.page_num < self.num_pages + + def has_previous(self): + """Returns True if there is a page before this one, otherwise False""" + return self.page_num > 1 + + def next_page_number(self): + """Returns the number of the next page""" + return self.page_num + 1 + + def previous_page_number(self): + """Returns the number of the previous page""" + return self.page_num - 1 + + +def get_paginated_data(request, results, page_num, per_page): + """ + Return a dict with the following values: + + next: The URL for the next page + previous: The URL for the previous page + results: The results on this page + """ + return _PaginationSerializer( + instance=_Page(results, page_num, per_page), + context={"request": request} + ).data diff --git a/lms/djangoapps/discussion_api/tests/__init__.py b/lms/djangoapps/discussion_api/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lms/djangoapps/discussion_api/tests/test_api.py b/lms/djangoapps/discussion_api/tests/test_api.py index 7b034183f6..79fead35d3 100644 --- a/lms/djangoapps/discussion_api/tests/test_api.py +++ b/lms/djangoapps/discussion_api/tests/test_api.py @@ -3,11 +3,18 @@ Tests for Discussion API internal interface """ from datetime import datetime, timedelta +import httpretty import mock from pytz import UTC +from django.http import Http404 +from django.test.client import RequestFactory + +from opaque_keys.edx.locator import CourseLocator + from courseware.tests.factories import BetaTesterFactory, StaffFactory -from discussion_api.api import get_course_topics +from discussion_api.api import get_course_topics, get_thread_list +from discussion_api.tests.utils import CommentsServiceMockMixin from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory from student.tests.factories import UserFactory @@ -304,3 +311,171 @@ class GetCourseTopicsTest(ModuleStoreTestCase): "non_courseware_topics": [], } self.assertEqual(staff_actual, staff_expected) + + +@httpretty.activate +class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): + """Test for get_thread_list""" + def setUp(self): + super(GetThreadListTest, self).setUp() + self.maxDiff = None # pylint: disable=invalid-name + self.request = RequestFactory().get("/test_path") + self.course_key = CourseLocator.from_string("a/b/c") + + def get_thread_list(self, threads, page=1, page_size=1, num_pages=1): + """ + Register the appropriate comments service response, then call + get_thread_list and return the result. + """ + self.register_get_threads_response(threads, page, num_pages) + ret = get_thread_list(self.request, self.course_key, page, page_size) + return ret + + def test_empty(self): + self.assertEqual( + self.get_thread_list([]), + { + "results": [], + "next": None, + "previous": None, + } + ) + + def test_basic_query_params(self): + self.get_thread_list([], page=6, page_size=14) + self.assert_last_query_params({ + "course_id": [unicode(self.course_key)], + "page": ["6"], + "per_page": ["14"], + "recursive": ["False"], + }) + + def test_thread_content(self): + source_threads = [ + { + "id": "test_thread_id_0", + "course_id": unicode(self.course_key), + "commentable_id": "topic_x", + "created_at": "2015-04-28T00:00:00Z", + "updated_at": "2015-04-28T11:11:11Z", + "type": "discussion", + "title": "Test Title", + "body": "Test body", + "pinned": False, + "closed": False, + "comments_count": 5, + "unread_comments_count": 3, + }, + { + "id": "test_thread_id_1", + "course_id": unicode(self.course_key), + "commentable_id": "topic_y", + "created_at": "2015-04-28T22:22:22Z", + "updated_at": "2015-04-28T00:33:33Z", + "type": "question", + "title": "Another Test Title", + "body": "More content", + "pinned": False, + "closed": True, + "comments_count": 18, + "unread_comments_count": 0, + }, + { + "id": "test_thread_id_2", + "course_id": unicode(self.course_key), + "commentable_id": "topic_x", + "created_at": "2015-04-28T00:44:44Z", + "updated_at": "2015-04-28T00:55:55Z", + "type": "discussion", + "title": "Yet Another Test Title", + "body": "Still more content", + "pinned": True, + "closed": False, + "comments_count": 0, + "unread_comments_count": 0, + }, + ] + expected_threads = [ + { + "id": "test_thread_id_0", + "course_id": unicode(self.course_key), + "topic_id": "topic_x", + "created_at": "2015-04-28T00:00:00Z", + "updated_at": "2015-04-28T11:11:11Z", + "type": "discussion", + "title": "Test Title", + "raw_body": "Test body", + "pinned": False, + "closed": False, + "comment_count": 5, + "unread_comment_count": 3, + }, + { + "id": "test_thread_id_1", + "course_id": unicode(self.course_key), + "topic_id": "topic_y", + "created_at": "2015-04-28T22:22:22Z", + "updated_at": "2015-04-28T00:33:33Z", + "type": "question", + "title": "Another Test Title", + "raw_body": "More content", + "pinned": False, + "closed": True, + "comment_count": 18, + "unread_comment_count": 0, + }, + { + "id": "test_thread_id_2", + "course_id": unicode(self.course_key), + "topic_id": "topic_x", + "created_at": "2015-04-28T00:44:44Z", + "updated_at": "2015-04-28T00:55:55Z", + "type": "discussion", + "title": "Yet Another Test Title", + "raw_body": "Still more content", + "pinned": True, + "closed": False, + "comment_count": 0, + "unread_comment_count": 0, + }, + ] + self.assertEqual( + self.get_thread_list(source_threads), + { + "results": expected_threads, + "next": None, + "previous": None, + } + ) + + def test_pagination(self): + # N.B. Empty thread list is not realistic but convenient for this test + self.assertEqual( + self.get_thread_list([], page=1, num_pages=3), + { + "results": [], + "next": "http://testserver/test_path?page=2", + "previous": None, + } + ) + self.assertEqual( + self.get_thread_list([], page=2, num_pages=3), + { + "results": [], + "next": "http://testserver/test_path?page=3", + "previous": "http://testserver/test_path?page=1", + } + ) + self.assertEqual( + self.get_thread_list([], page=3, num_pages=3), + { + "results": [], + "next": None, + "previous": "http://testserver/test_path?page=2", + } + ) + + # Test page past the last one + self.register_get_threads_response([], page=3, num_pages=3) + with self.assertRaises(Http404): + get_thread_list(self.request, self.course_key, page=4, page_size=10) diff --git a/lms/djangoapps/discussion_api/tests/test_forms.py b/lms/djangoapps/discussion_api/tests/test_forms.py new file mode 100644 index 0000000000..f988b1f484 --- /dev/null +++ b/lms/djangoapps/discussion_api/tests/test_forms.py @@ -0,0 +1,84 @@ +""" +Tests for Discussion API forms +""" +from unittest import TestCase + +from opaque_keys.edx.locator import CourseLocator + +from discussion_api.forms import ThreadListGetForm + + +class ThreadListGetFormTest(TestCase): + """Tests for ThreadListGetForm""" + def setUp(self): + super(ThreadListGetFormTest, self).setUp() + self.form_data = { + "course_id": "Foo/Bar/Baz", + "page": "2", + "page_size": "13", + } + + def get_form(self, expected_valid): + """ + Return a form bound to self.form_data, asserting its validity (or lack + thereof) according to expected_valid + """ + form = ThreadListGetForm(self.form_data) + self.assertEqual(form.is_valid(), expected_valid) + return form + + def assert_error(self, expected_field, expected_message): + """ + Create a form bound to self.form_data, assert its invalidity, and assert + that its error dictionary contains one entry with the expected field and + message + """ + form = self.get_form(expected_valid=False) + self.assertEqual(form.errors, {expected_field: [expected_message]}) + + def assert_field_value(self, field, expected_value): + """ + Create a form bound to self.form_data, assert its validity, and assert + that the given field in the cleaned data has the expected value + """ + form = self.get_form(expected_valid=True) + self.assertEqual(form.cleaned_data[field], expected_value) + + def test_basic(self): + form = self.get_form(expected_valid=True) + self.assertEqual( + form.cleaned_data, + { + "course_id": CourseLocator.from_string("Foo/Bar/Baz"), + "page": 2, + "page_size": 13, + } + ) + + def test_missing_course_id(self): + self.form_data.pop("course_id") + self.assert_error("course_id", "This field is required.") + + def test_invalid_course_id(self): + self.form_data["course_id"] = "invalid course id" + self.assert_error("course_id", "'invalid course id' is not a valid course id") + + def test_missing_page(self): + self.form_data.pop("page") + self.assert_field_value("page", 1) + + def test_invalid_page(self): + self.form_data["page"] = "0" + self.assert_error("page", "Ensure this value is greater than or equal to 1.") + + def test_missing_page_size(self): + self.form_data.pop("page_size") + self.assert_field_value("page_size", 10) + + def test_zero_page_size(self): + self.form_data["page_size"] = "0" + self.assert_error("page_size", "Ensure this value is greater than or equal to 1.") + + def test_excessive_page_size(self): + self.form_data["page_size"] = "101" + self.assert_field_value("page_size", 100) diff --git a/lms/djangoapps/discussion_api/tests/test_pagination.py b/lms/djangoapps/discussion_api/tests/test_pagination.py new file mode 100644 index 0000000000..5f1775ab52 --- /dev/null +++ b/lms/djangoapps/discussion_api/tests/test_pagination.py @@ -0,0 +1,70 @@ +""" +Tests for Discussion API pagination support +""" +from unittest import TestCase + +from django.test import RequestFactory + +from discussion_api.pagination import get_paginated_data + + +class PaginationSerializerTest(TestCase): + """Tests for PaginationSerializer""" + def do_case(self, objects, page_num, num_pages, expected): + """ + Make a dummy request, and assert that get_paginated_data with the given + parameters returns the expected result + """ + request = RequestFactory().get("/test") + actual = get_paginated_data(request, objects, page_num, num_pages) + self.assertEqual(actual, expected) + + def test_empty(self): + self.do_case( + [], 1, 0, + { + "next": None, + "previous": None, + "results": [], + } + ) + + def test_only_page(self): + self.do_case( + ["foo"], 1, 1, + { + "next": None, + "previous": None, + "results": ["foo"], + } + ) + + def test_first_of_many(self): + self.do_case( + ["foo"], 1, 3, + { + "next": "http://testserver/test?page=2", + "previous": None, + "results": ["foo"], + } + ) + + def test_last_of_many(self): + self.do_case( + ["foo"], 3, 3, + { + "next": None, + "previous": "http://testserver/test?page=2", + "results": ["foo"], + } + ) + + def test_middle_of_many(self): + self.do_case( + ["foo"], 2, 3, + { + "next": "http://testserver/test?page=3", + "previous": "http://testserver/test?page=1", + "results": ["foo"], + } + ) diff --git a/lms/djangoapps/discussion_api/tests/test_views.py b/lms/djangoapps/discussion_api/tests/test_views.py index c6d213963e..1bf852eceb 100644 --- a/lms/djangoapps/discussion_api/tests/test_views.py +++ b/lms/djangoapps/discussion_api/tests/test_views.py @@ -4,11 +4,13 @@ Tests for Discussion API views from datetime import datetime import json +import httpretty import mock from pytz import UTC from django.core.urlresolvers import reverse +from discussion_api.tests.utils import CommentsServiceMockMixin from student.tests.factories import CourseEnrollmentFactory, UserFactory from util.testing import UrlResetMixin from xmodule.modulestore.django import modulestore @@ -17,12 +19,16 @@ from xmodule.modulestore.tests.factories import CourseFactory from xmodule.tabs import DiscussionTab -class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase): - """Tests for CourseTopicsView""" - +class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin): + """ + Mixin for common code in tests of Discussion API views. This includes + creation of common structures (e.g. a course, user, and enrollment), logging + in the test client, utility functions, and a test case for unauthenticated + requests. Subclasses must set self.url in their setUp methods. + """ @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) def setUp(self): - super(CourseTopicsViewTest, self).setUp() + super(DiscussionAPIViewTestMixin, self).setUp() self.maxDiff = None # pylint: disable=invalid-name self.course = CourseFactory.create( org="x", @@ -34,9 +40,13 @@ class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase): self.password = "password" self.user = UserFactory.create(password=self.password) CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) - self.url = reverse("course_topics", kwargs={"course_id": unicode(self.course.id)}) self.client.login(username=self.user.username, password=self.password) + def login_unenrolled_user(self): + """Create a user not enrolled in the course and log it in""" + unenrolled_user = UserFactory.create(password=self.password) + self.client.login(username=unenrolled_user.username, password=self.password) + def assert_response_correct(self, response, expected_status, expected_content): """ Assert that the response has the given status code and parsed content @@ -54,6 +64,13 @@ class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase): {"developer_message": "Authentication credentials were not provided."} ) + +class CourseTopicsViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): + """Tests for CourseTopicsView""" + def setUp(self): + super(CourseTopicsViewTest, self).setUp() + self.url = reverse("course_topics", kwargs={"course_id": unicode(self.course.id)}) + def test_non_existent_course(self): response = self.client.get( reverse("course_topics", kwargs={"course_id": "non/existent/course"}) @@ -65,8 +82,7 @@ class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase): ) def test_not_enrolled(self): - unenrolled_user = UserFactory.create(password=self.password) - self.client.login(username=unenrolled_user.username, password=self.password) + self.login_unenrolled_user() response = self.client.get(self.url) self.assert_response_correct( response, @@ -98,3 +114,93 @@ class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase): }], } ) + + +@httpretty.activate +class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): + """Tests for ThreadViewSet list""" + def setUp(self): + super(ThreadViewSetListTest, self).setUp() + self.url = reverse("thread-list") + + def test_course_id_missing(self): + response = self.client.get(self.url) + self.assert_response_correct( + response, + 400, + {"field_errors": {"course_id": "This field is required."}} + ) + + def test_not_enrolled(self): + self.login_unenrolled_user() + response = self.client.get(self.url, {"course_id": unicode(self.course.id)}) + self.assert_response_correct( + response, + 404, + {"developer_message": "Not found."} + ) + + def test_basic(self): + source_threads = [{ + "id": "test_thread", + "course_id": unicode(self.course.id), + "commentable_id": "test_topic", + "created_at": "2015-04-28T00:00:00Z", + "updated_at": "2015-04-28T11:11:11Z", + "type": "discussion", + "title": "Test Title", + "body": "Test body", + "pinned": False, + "closed": False, + "comments_count": 5, + "unread_comments_count": 3, + }] + expected_threads = [{ + "id": "test_thread", + "course_id": unicode(self.course.id), + "topic_id": "test_topic", + "created_at": "2015-04-28T00:00:00Z", + "updated_at": "2015-04-28T11:11:11Z", + "type": "discussion", + "title": "Test Title", + "raw_body": "Test body", + "pinned": False, + "closed": False, + "comment_count": 5, + "unread_comment_count": 3, + }] + self.register_get_threads_response(source_threads, page=1, num_pages=2) + response = self.client.get(self.url, {"course_id": unicode(self.course.id)}) + self.assert_response_correct( + response, + 200, + { + "results": expected_threads, + "next": "http://testserver/api/discussion/v1/threads/?course_id=x%2Fy%2Fz&page=2", + "previous": None, + } + ) + self.assert_last_query_params({ + "course_id": [unicode(self.course.id)], + "page": ["1"], + "per_page": ["10"], + "recursive": ["False"], + }) + + def test_pagination(self): + self.register_get_threads_response([], page=1, num_pages=1) + response = self.client.get( + self.url, + {"course_id": unicode(self.course.id), "page": "18", "page_size": "4"} + ) + self.assert_response_correct( + response, + 404, + {"developer_message": "Not found."} + ) + self.assert_last_query_params({ + "course_id": [unicode(self.course.id)], + "page": ["18"], + "per_page": ["4"], + "recursive": ["False"], + }) diff --git a/lms/djangoapps/discussion_api/tests/utils.py b/lms/djangoapps/discussion_api/tests/utils.py new file mode 100644 index 0000000000..41f7f50e8c --- /dev/null +++ b/lms/djangoapps/discussion_api/tests/utils.py @@ -0,0 +1,30 @@ +""" +Discussion API test utilities +""" +import json + +import httpretty + + +class CommentsServiceMockMixin(object): + """Mixin with utility methods for mocking the comments service""" + def register_get_threads_response(self, threads, page, num_pages): + """Register a mock response for GET on the CS thread list endpoint""" + httpretty.register_uri( + httpretty.GET, + "http://localhost:4567/api/v1/threads", + body=json.dumps({ + "collection": threads, + "page": page, + "num_pages": num_pages, + }), + status=200 + ) + + def assert_last_query_params(self, expected_params): + """ + Assert that the last mock request had the expected query parameters + """ + actual_params = dict(httpretty.last_request().querystring) + actual_params.pop("request_id") # request_id is random + self.assertEqual(actual_params, expected_params) diff --git a/lms/djangoapps/discussion_api/urls.py b/lms/djangoapps/discussion_api/urls.py index 2f023eda57..040ad0e20a 100644 --- a/lms/djangoapps/discussion_api/urls.py +++ b/lms/djangoapps/discussion_api/urls.py @@ -2,10 +2,15 @@ Discussion API URLs """ from django.conf import settings -from django.conf.urls import patterns, url +from django.conf.urls import include, patterns, url -from discussion_api.views import CourseTopicsView +from rest_framework.routers import SimpleRouter +from discussion_api.views import CourseTopicsView, ThreadViewSet + + +ROUTER = SimpleRouter() +ROUTER.register("threads", ThreadViewSet, base_name="thread") urlpatterns = patterns( "discussion_api", @@ -14,4 +19,5 @@ urlpatterns = patterns( CourseTopicsView.as_view(), name="course_topics" ), + url("^v1/", include(ROUTER.urls)), ) diff --git a/lms/djangoapps/discussion_api/views.py b/lms/djangoapps/discussion_api/views.py index f9d82f8b58..3362d5ba27 100644 --- a/lms/djangoapps/discussion_api/views.py +++ b/lms/djangoapps/discussion_api/views.py @@ -1,22 +1,45 @@ """ Discussion API views """ +from django.core.exceptions import ValidationError from django.http import Http404 from rest_framework.authentication import OAuth2Authentication, SessionAuthentication from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView +from rest_framework.viewsets import ViewSet from opaque_keys.edx.locator import CourseLocator from courseware.courses import get_course_with_access -from discussion_api.api import get_course_topics +from discussion_api.api import get_course_topics, get_thread_list +from discussion_api.forms import ThreadListGetForm from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin from xmodule.tabs import DiscussionTab -class CourseTopicsView(DeveloperErrorViewMixin, APIView): +class _ViewMixin(object): + """ + Mixin to provide common characteristics and utility functions for Discussion + API views + """ + authentication_classes = (OAuth2Authentication, SessionAuthentication) + permission_classes = (IsAuthenticated,) + + def get_course_or_404(self, user, course_key): + """ + Get the course descriptor, raising Http404 if the course is not found, + the user cannot access forums for the course, or the discussion tab is + disabled for the course. + """ + course = get_course_with_access(user, 'load_forum', course_key) + if not any([isinstance(tab, DiscussionTab) for tab in course.tabs]): + raise Http404 + return course + + +class CourseTopicsView(_ViewMixin, DeveloperErrorViewMixin, APIView): """ **Use Cases** @@ -42,13 +65,81 @@ class CourseTopicsView(DeveloperErrorViewMixin, APIView): * non_courseware_topics: The list of topic trees that are not linked to courseware. Items are of the same format as in courseware_topics. """ - authentication_classes = (OAuth2Authentication, SessionAuthentication) - permission_classes = (IsAuthenticated,) - def get(self, request, course_id): """Implements the GET method as described in the class docstring.""" course_key = CourseLocator.from_string(course_id) - course = get_course_with_access(request.user, 'load_forum', course_key) - if not any([isinstance(tab, DiscussionTab) for tab in course.tabs]): - raise Http404 + course = self.get_course_or_404(request.user, course_key) return Response(get_course_topics(course, request.user)) + + +class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): + """ + **Use Cases** + + Retrieve the list of threads for a course. + + **Example Requests**: + + GET /api/discussion/v1/threads/?course_id=ExampleX/Demo/2015 + + **GET Parameters**: + + * course_id (required): The course to retrieve threads for + + * page: The (1-indexed) page to retrieve (default is 1) + + * page_size: The number of items per page (default is 10, max is 100) + + **Response Values**: + + * results: The list of threads. Each item in the list includes: + + * id: The id of the thread + + * course_id: The id of the thread's course + + * topic_id: The id of the thread's topic + + * created_at: The ISO 8601 timestamp for the creation of the thread + + * updated_at: The ISO 8601 timestamp for the last modification of + the thread, which may not have been an update of the title/body + + * type: The thread's type (either "question" or "discussion") + + * title: The thread's title + + * raw_body: The thread's raw body text without any rendering applied + + * pinned: Boolean indicating whether the thread has been pinned + + * closed: Boolean indicating whether the thread has been closed + + * comment_count: The number of comments within the thread + + * unread_comment_count: The number of comments within the thread + that were created or updated since the last time the user read + the thread + + * next: The URL of the next page (or null if first page) + + * previous: The URL of the previous page (or null if last page) + """ + def list(self, request): + """ + Implements the GET method for the list endpoint as described in the + class docstring. + """ + form = ThreadListGetForm(request.GET) + if not form.is_valid(): + raise ValidationError(form.errors) + course_key = form.cleaned_data["course_id"] + self.get_course_or_404(request.user, course_key) + return Response( + get_thread_list( + request, + course_key, + form.cleaned_data["page"], + form.cleaned_data["page_size"] + ) + ) diff --git a/openedx/core/lib/api/view_utils.py b/openedx/core/lib/api/view_utils.py index 9ab8c57135..a455edfac0 100644 --- a/openedx/core/lib/api/view_utils.py +++ b/openedx/core/lib/api/view_utils.py @@ -1,6 +1,7 @@ """ Utilities related to API views """ +from django.core.exceptions import NON_FIELD_ERRORS, ValidationError from django.http import Http404 from rest_framework.exceptions import APIException @@ -19,10 +20,31 @@ class DeveloperErrorViewMixin(object): """ return Response({"developer_message": developer_message}, status=status_code) + def make_validation_error_response(self, validation_error): + """ + Build a 400 error response from the given ValidationError + """ + if hasattr(validation_error, "message_dict"): + response_obj = {} + message_dict = dict(validation_error.message_dict) + non_field_error_list = message_dict.pop(NON_FIELD_ERRORS, None) + if non_field_error_list: + response_obj["developer_message"] = non_field_error_list[0] + if message_dict: + response_obj["field_errors"] = { + field: message_dict[field][0] + for field in message_dict + } + return Response(response_obj, status=400) + else: + return self.make_error_response(400, validation_error.messages[0]) + def handle_exception(self, exc): if isinstance(exc, APIException): return self.make_error_response(exc.status_code, exc.detail) elif isinstance(exc, Http404): return self.make_error_response(404, "Not found.") + elif isinstance(exc, ValidationError): + return self.make_validation_error_response(exc) else: raise