diff --git a/lms/djangoapps/courseware/tests/test_lti_integration.py b/lms/djangoapps/courseware/tests/test_lti_integration.py index e634e460ba..d22050d3a1 100644 --- a/lms/djangoapps/courseware/tests/test_lti_integration.py +++ b/lms/djangoapps/courseware/tests/test_lti_integration.py @@ -178,11 +178,10 @@ class TestLTIModuleListing(ModuleStoreTestCase): def test_lti_rest_bad_course(self): """Tests what happens when the lti listing rest endpoint gets a bad course_id""" - bad_ids = [u"sf", u"dne/dne/dne", u"fo/ey/\u5305"] - request = mock.Mock() - request.method = 'GET' + bad_ids = [u"sf", u"dne/dne/dne", u"fo/ey/\\u5305"] for bad_course_id in bad_ids: - response = get_course_lti_endpoints(request, bad_course_id) + lti_rest_endpoints_url = 'courses/{}/lti_rest_endpoints/'.format(bad_course_id) + response = self.client.get(lti_rest_endpoints_url) self.assertEqual(404, response.status_code) def test_lti_rest_listing(self): diff --git a/lms/djangoapps/courseware/tests/test_views.py b/lms/djangoapps/courseware/tests/test_views.py index 540e725801..d373f9e805 100644 --- a/lms/djangoapps/courseware/tests/test_views.py +++ b/lms/djangoapps/courseware/tests/test_views.py @@ -5,7 +5,7 @@ Tests courseware views.py import unittest from datetime import datetime -from mock import MagicMock, patch +from mock import MagicMock, patch, create_autospec from pytz import UTC from django.test import TestCase @@ -152,6 +152,10 @@ class ViewsTestCase(TestCase): response = self.client.get('/courses/MITx/3.091X/') self.assertEqual(response.status_code, 404) + def test_incomplete_course_id(self): + response = self.client.get('/courses/MITx/') + self.assertEqual(response.status_code, 404) + def test_index_invalid_position(self): request_url = '/'.join([ '/courses', @@ -562,3 +566,26 @@ class ProgressPageTests(ModuleStoreTestCase): def test_non_asci_grade_cutoffs(self): resp = views.progress(self.request, self.course.id.to_deprecated_string()) self.assertEqual(resp.status_code, 200) + + +class TestVerifyCourseIdDecorator(TestCase): + """ + Tests for the verify_course_id decorator. + """ + + def setUp(self): + self.request = RequestFactory().get("foo") + self.valid_course_id = "edX/test/1" + self.invalid_course_id = "edX/" + + def test_decorator_with_valid_course_id(self): + mocked_view = create_autospec(views.course_about) + view_function = views.verify_course_id(mocked_view) + view_function(self.request, self.valid_course_id) + self.assertTrue(mocked_view.called) + + def test_decorator_with_invalid_course_id(self): + mocked_view = create_autospec(views.course_about) + view_function = views.verify_course_id(mocked_view) + self.assertRaises(Http404, view_function, self.request, self.invalid_course_id) + self.assertFalse(mocked_view.called) diff --git a/lms/djangoapps/courseware/views.py b/lms/djangoapps/courseware/views.py index a048309ce3..350485f2e3 100644 --- a/lms/djangoapps/courseware/views.py +++ b/lms/djangoapps/courseware/views.py @@ -23,6 +23,7 @@ from edxmako.shortcuts import render_to_response, render_to_string from django_future.csrf import ensure_csrf_cookie from django.views.decorators.cache import cache_control from django.db import transaction +from functools import wraps from markupsafe import escape from courseware import grades @@ -81,6 +82,24 @@ def user_groups(user): return group_names +def verify_course_id(view_func): + """ + This decorator should only be used with views whose second argument is course_id. + If course_id is not valid raise 404. + """ + + @wraps(view_func) + def _decorated(request, course_id, *args, **kwargs): + try: + SlashSeparatedCourseKey.from_deprecated_string(course_id) + except InvalidKeyError: + raise Http404 + response = view_func(request, course_id, *args, **kwargs) + return response + + return _decorated + + @ensure_csrf_cookie @cache_if_anonymous def courses(request): @@ -242,6 +261,7 @@ def chat_settings(course, user): @login_required @ensure_csrf_cookie @cache_control(no_cache=True, no_store=True, must_revalidate=True) +@verify_course_id def index(request, course_id, chapter=None, section=None, position=None): """ @@ -266,7 +286,9 @@ def index(request, course_id, chapter=None, section=None, - HTTPresponse """ + course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) + user = User.objects.prefetch_related("groups").get(id=request.user.id) request.user = user # keep just one instance of User course = get_course_with_access(user, 'load', course_key, depth=2) @@ -458,6 +480,7 @@ def index(request, course_id, chapter=None, section=None, @ensure_csrf_cookie +@verify_course_id def jump_to_id(request, course_id, module_id): """ This entry point allows for a shorter version of a jump to where just the id of the element is @@ -516,13 +539,16 @@ def jump_to(request, course_id, location): @ensure_csrf_cookie +@verify_course_id def course_info(request, course_id): """ Display the course's info.html, or 404 if there is no such course. Assumes the course_id is in a valid format. """ + course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) + course = get_course_with_access(request.user, 'load', course_key) staff_access = has_access(request.user, 'staff', course) masq = setup_masquerade(request, staff_access) # allow staff to toggle masquerade on info page @@ -544,16 +570,15 @@ def course_info(request, course_id): @ensure_csrf_cookie +@verify_course_id def static_tab(request, course_id, tab_slug): """ Display the courses tab with the given name. Assumes the course_id is in a valid format. """ - try: - course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) - except InvalidKeyError: - raise Http404 + + course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) course = get_course_with_access(request.user, 'load', course_key) @@ -579,13 +604,16 @@ def static_tab(request, course_id, tab_slug): @ensure_csrf_cookie +@verify_course_id def syllabus(request, course_id): """ Display the course's syllabus.html, or 404 if there is no such course. Assumes the course_id is in a valid format. """ + course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) + course = get_course_with_access(request.user, 'load', course_key) staff_access = has_access(request.user, 'staff', course) @@ -621,7 +649,9 @@ def course_about(request, course_id): settings.FEATURES.get('ENABLE_MKTG_SITE', False) ): raise Http404 + course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) + course = get_course_with_access(request.user, 'see_exists', course_key) registered = registered_for_course(course, request.user) staff_access = has_access(request.user, 'staff', course) @@ -683,12 +713,14 @@ def course_about(request, course_id): @ensure_csrf_cookie @cache_if_anonymous +@verify_course_id def mktg_course_about(request, course_id): """ This is the button that gets put into an iframe on the Drupal site """ course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) + try: course = get_course_with_access(request.user, 'see_exists', course_key) except (ValueError, Http404) as e: @@ -739,13 +771,17 @@ def mktg_course_about(request, course_id): @login_required @cache_control(no_cache=True, no_store=True, must_revalidate=True) @transaction.commit_manually +@verify_course_id def progress(request, course_id, student_id=None): """ Wraps "_progress" with the manual_transaction context manager just in case there are unanticipated errors. """ + + course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) + with grades.manual_transaction(): - return _progress(request, SlashSeparatedCourseKey.from_deprecated_string(course_id), student_id) + return _progress(request, course_key, student_id) def _progress(request, course_key, student_id): @@ -816,16 +852,15 @@ def fetch_reverify_banner_info(request, course_key): @login_required +@verify_course_id def submission_history(request, course_id, student_username, location): """Render an HTML fragment (meant for inclusion elsewhere) that renders a history of all state changes made by this user for this problem location. Right now this only works for problems because that's all StudentModuleHistory records. """ - try: - course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) - except (InvalidKeyError, AssertionError): - return HttpResponse(escape(_(u'Invalid course id.'))) + + course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) try: usage_key = course_key.make_usage_key_from_deprecated_string(location) @@ -925,6 +960,7 @@ def get_static_tab_contents(request, course, tab): @require_GET +@verify_course_id def get_course_lti_endpoints(request, course_id): """ View that, given a course_id, returns the a JSON object that enumerates all of the LTI endpoints for that course. @@ -941,10 +977,8 @@ def get_course_lti_endpoints(request, course_id): Returns: (django response object): HTTP response. 404 if course is not found, otherwise 200 with JSON body. """ - try: - course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) - except InvalidKeyError: - return HttpResponse(status=404) + + course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) try: course = get_course(course_key, depth=2)