diff --git a/lms/djangoapps/discussion_api/api.py b/lms/djangoapps/discussion_api/api.py index 3d69c74a40..51846d7d9e 100644 --- a/lms/djangoapps/discussion_api/api.py +++ b/lms/djangoapps/discussion_api/api.py @@ -185,7 +185,7 @@ def get_course_topics(request, course_key): } -def get_thread_list(request, course_key, page, page_size, topic_id_list=None): +def get_thread_list(request, course_key, page, page_size, topic_id_list=None, text_search=None): """ Return the list of all discussion threads pertaining to the given course @@ -196,16 +196,30 @@ def get_thread_list(request, course_key, page, page_size, topic_id_list=None): page: The page number (1-indexed) to retrieve page_size: The number of threads to retrieve per page topic_id_list: The list of topic_ids to get the discussion threads for + text_search A text search query string to match + + Note that topic_id_list and text_search are mutually exclusive. Returns: A paginated result containing a list of threads; see discussion_api.views.ThreadViewSet for more detail. + + Raises: + + ValueError: if more than one of the mutually exclusive parameters is + provided + Http404: if the requesting user does not have access to the requested course + or a page beyond the last is requested """ + exclusive_param_count = sum(1 for param in [topic_id_list, text_search] if param) + if exclusive_param_count > 1: # pragma: no cover + raise ValueError("More than one mutually exclusive param passed to get_thread_list") + course = _get_course_or_404(course_key, request.user) context = get_context(course, request) topic_ids_csv = ",".join(topic_id_list) if topic_id_list else None - threads, result_page, num_pages, _ = Thread.search({ + threads, result_page, num_pages, text_search_rewrite = Thread.search({ "course_id": unicode(course.id), "group_id": ( None if context["is_requester_privileged"] else @@ -216,6 +230,7 @@ def get_thread_list(request, course_key, page, page_size, topic_id_list=None): "page": page, "per_page": page_size, "commentable_ids": topic_ids_csv, + "text": text_search, }) # The comments service returns the last page of results if the requested # page is beyond the last page, but we want be consistent with DRF's general @@ -224,7 +239,9 @@ def get_thread_list(request, course_key, page, page_size, topic_id_list=None): raise Http404 results = [ThreadSerializer(thread, context=context).data for thread in threads] - return get_paginated_data(request, results, page, num_pages) + ret = get_paginated_data(request, results, page, num_pages) + ret["text_search_rewrite"] = text_search_rewrite + return ret def get_comment_list(request, thread_id, endorsed, page, page_size): diff --git a/lms/djangoapps/discussion_api/forms.py b/lms/djangoapps/discussion_api/forms.py index 2f1df946d6..02683ecb5a 100644 --- a/lms/djangoapps/discussion_api/forms.py +++ b/lms/djangoapps/discussion_api/forms.py @@ -45,8 +45,11 @@ class ThreadListGetForm(_PaginationForm): """ A form to validate query parameters in the thread list retrieval endpoint """ + EXCLUSIVE_PARAMS = ["topic_id", "text_search"] + course_id = CharField() topic_id = TopicIdField(required=False) + text_search = CharField(required=False) def clean_course_id(self): """Validate course_id""" @@ -56,6 +59,19 @@ class ThreadListGetForm(_PaginationForm): except InvalidKeyError: raise ValidationError("'{}' is not a valid course id".format(value)) + def clean(self): + cleaned_data = super(ThreadListGetForm, self).clean() + exclusive_params_count = sum( + 1 for param in self.EXCLUSIVE_PARAMS if cleaned_data.get(param) + ) + if exclusive_params_count > 1: + raise ValidationError( + "The following query parameters are mutually exclusive: {}".format( + ", ".join(self.EXCLUSIVE_PARAMS) + ) + ) + return cleaned_data + class ThreadActionsForm(Form): """ diff --git a/lms/djangoapps/discussion_api/tests/test_api.py b/lms/djangoapps/discussion_api/tests/test_api.py index 15c93a116e..9767f07374 100644 --- a/lms/djangoapps/discussion_api/tests/test_api.py +++ b/lms/djangoapps/discussion_api/tests/test_api.py @@ -404,7 +404,15 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest self.author = UserFactory.create() self.cohort = CohortFactory.create(course_id=self.course.id) - def get_thread_list(self, threads, page=1, page_size=1, num_pages=1, course=None, topic_id_list=None): + def get_thread_list( + self, + threads, + page=1, + page_size=1, + num_pages=1, + course=None, + topic_id_list=None, + ): """ Register the appropriate comments service response, then call get_thread_list and return the result. @@ -435,6 +443,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest "results": [], "next": None, "previous": None, + "text_search_rewrite": None, } ) @@ -569,6 +578,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest "results": expected_threads, "next": None, "previous": None, + "text_search_rewrite": None, } ) @@ -603,6 +613,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest "results": [], "next": "http://testserver/test_path?page=2", "previous": None, + "text_search_rewrite": None, } ) self.assertEqual( @@ -611,6 +622,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest "results": [], "next": "http://testserver/test_path?page=3", "previous": "http://testserver/test_path?page=1", + "text_search_rewrite": None, } ) self.assertEqual( @@ -619,6 +631,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest "results": [], "next": None, "previous": "http://testserver/test_path?page=2", + "text_search_rewrite": None, } ) @@ -627,6 +640,34 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest with self.assertRaises(Http404): get_thread_list(self.request, self.course.id, page=4, page_size=10) + @ddt.data(None, "rewritten search string") + def test_text_search(self, text_search_rewrite): + self.register_get_threads_search_response([], text_search_rewrite) + self.assertEqual( + get_thread_list( + self.request, + self.course.id, + page=1, + page_size=10, + text_search="test search string" + ), + { + "results": [], + "next": None, + "previous": None, + "text_search_rewrite": text_search_rewrite, + } + ) + self.assert_last_query_params({ + "course_id": [unicode(self.course.id)], + "sort_key": ["date"], + "sort_order": ["desc"], + "page": ["1"], + "per_page": ["10"], + "recursive": ["False"], + "text": ["test search string"], + }) + @ddt.ddt class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase): diff --git a/lms/djangoapps/discussion_api/tests/test_forms.py b/lms/djangoapps/discussion_api/tests/test_forms.py index 07c7128311..60223c58fd 100644 --- a/lms/djangoapps/discussion_api/tests/test_forms.py +++ b/lms/djangoapps/discussion_api/tests/test_forms.py @@ -1,9 +1,12 @@ """ Tests for Discussion API forms """ +import itertools from unittest import TestCase from urllib import urlencode +import ddt + from django.http import QueryDict from opaque_keys.edx.locator import CourseLocator @@ -63,6 +66,7 @@ class PaginationTestMixin(object): self.assert_field_value("page_size", 100) +@ddt.ddt class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): """Tests for ThreadListGetForm""" FORM_CLASS = ThreadListGetForm @@ -81,7 +85,6 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): ) def test_basic(self): - self.form_data.setlist("topic_id", ["example topic_id", "example 2nd topic_id"]) form = self.get_form(expected_valid=True) self.assertEqual( form.cleaned_data, @@ -89,10 +92,27 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): "course_id": CourseLocator.from_string("Foo/Bar/Baz"), "page": 2, "page_size": 13, - "topic_id": ["example topic_id", "example 2nd topic_id"], + "topic_id": [], + "text_search": "", } ) + def test_topic_id(self): + self.form_data.setlist("topic_id", ["example topic_id", "example 2nd topic_id"]) + form = self.get_form(expected_valid=True) + self.assertEqual( + form.cleaned_data["topic_id"], + ["example topic_id", "example 2nd topic_id"], + ) + + def test_text_search(self): + self.form_data["text_search"] = "test search string" + form = self.get_form(expected_valid=True) + self.assertEqual( + form.cleaned_data["text_search"], + "test search string", + ) + def test_missing_course_id(self): self.form_data.pop("course_id") self.assert_error("course_id", "This field is required.") @@ -105,6 +125,14 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): self.form_data.setlist("topic_id", ["", "not empty"]) self.assert_error("topic_id", "This field cannot be empty.") + @ddt.data(*itertools.combinations(["topic_id", "text_search"], 2)) + def test_mutually_exclusive(self, params): + self.form_data.update({param: "dummy" for param in params}) + self.assert_error( + "__all__", + "The following query parameters are mutually exclusive: topic_id, text_search" + ) + class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): """Tests for CommentListGetForm""" diff --git a/lms/djangoapps/discussion_api/tests/test_views.py b/lms/djangoapps/discussion_api/tests/test_views.py index 1b1b038a3d..8da2001c91 100644 --- a/lms/djangoapps/discussion_api/tests/test_views.py +++ b/lms/djangoapps/discussion_api/tests/test_views.py @@ -182,6 +182,7 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): "results": expected_threads, "next": "http://testserver/api/discussion/v1/threads/?course_id=x%2Fy%2Fz&page=2", "previous": None, + "text_search_rewrite": None, } ) self.assert_last_query_params({ @@ -214,6 +215,28 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): "recursive": ["False"], }) + def test_text_search(self): + self.register_get_user_response(self.user) + self.register_get_threads_search_response([], None) + response = self.client.get( + self.url, + {"course_id": unicode(self.course.id), "text_search": "test search string"} + ) + self.assert_response_correct( + response, + 200, + {"results": [], "next": None, "previous": None, "text_search_rewrite": None} + ) + self.assert_last_query_params({ + "course_id": [unicode(self.course.id)], + "sort_key": ["date"], + "sort_order": ["desc"], + "page": ["1"], + "per_page": ["10"], + "recursive": ["False"], + "text": ["test search string"], + }) + @httpretty.activate class ThreadViewSetCreateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): diff --git a/lms/djangoapps/discussion_api/tests/utils.py b/lms/djangoapps/discussion_api/tests/utils.py index 32a4e9dcfe..4e39639ee7 100644 --- a/lms/djangoapps/discussion_api/tests/utils.py +++ b/lms/djangoapps/discussion_api/tests/utils.py @@ -71,6 +71,20 @@ class CommentsServiceMockMixin(object): status=200 ) + def register_get_threads_search_response(self, threads, rewrite): + """Register a mock response for GET on the CS thread search endpoint""" + httpretty.register_uri( + httpretty.GET, + "http://localhost:4567/api/v1/search/threads", + body=json.dumps({ + "collection": threads, + "page": 1, + "num_pages": 1, + "corrected_text": rewrite, + }), + status=200 + ) + def register_post_thread_response(self, thread_data): """Register a mock response for POST on the CS commentable endpoint""" httpretty.register_uri( diff --git a/lms/djangoapps/discussion_api/views.py b/lms/djangoapps/discussion_api/views.py index 500ae5d570..ba943ebd78 100644 --- a/lms/djangoapps/discussion_api/views.py +++ b/lms/djangoapps/discussion_api/views.py @@ -104,6 +104,10 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): multiple topic_id queries to retrieve threads from multiple topics at once. + * text_search: A search string to match. Any thread whose content + (including the bodies of comments in the thread) matches the search + string will be returned. + **POST Parameters**: * course_id (required): The course to create the thread in @@ -133,6 +137,10 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): * previous: The URL of the previous page (or null if last page) + * text_search_rewrite: The search string to which the text_search + parameter was rewritten in order to match threads (e.g. for spelling + correction) + **POST/PATCH response values**: * id: The id of the thread @@ -184,6 +192,7 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): form.cleaned_data["page"], form.cleaned_data["page_size"], form.cleaned_data["topic_id"], + form.cleaned_data["text_search"], ) )