diff --git a/lms/djangoapps/instructor/tests/test_tools.py b/lms/djangoapps/instructor/tests/test_tools.py index aa216625a7..35ca654255 100644 --- a/lms/djangoapps/instructor/tests/test_tools.py +++ b/lms/djangoapps/instructor/tests/test_tools.py @@ -62,6 +62,27 @@ class TestHandleDashboardError(unittest.TestCase): self.assertEqual(view(None, None), "Oh yes!") +class TestRequireStudentIdentifier(unittest.TestCase): + """ + Test require_student_from_identifier() + """ + def setUp(self): + """ + Fixtures + """ + self.student = UserFactory.create() + + def test_valid_student_id(self): + self.assertEqual( + self.student, + tools.require_student_from_identifier(self.student.username) + ) + + def test_invalid_student_id(self): + with self.assertRaises(tools.DashboardError): + tools.require_student_from_identifier("invalid") + + class TestParseDatetime(unittest.TestCase): """ Test date parsing. diff --git a/lms/djangoapps/instructor/views/api.py b/lms/djangoapps/instructor/views/api.py index 72990cbea4..18e08b9138 100644 --- a/lms/djangoapps/instructor/views/api.py +++ b/lms/djangoapps/instructor/views/api.py @@ -66,6 +66,7 @@ from .tools import ( dump_module_extensions, find_unit, get_student_from_identifier, + require_student_from_identifier, handle_dashboard_error, parse_datetime, set_due_date_extension, @@ -1392,7 +1393,7 @@ def change_due_date(request, course_id): Grants a due date extension to a student for a particular unit. """ course = get_course_by_id(SlashSeparatedCourseKey.from_deprecated_string(course_id)) - student = get_student_from_identifier(request.GET.get('student')) + student = require_student_from_identifier(request.GET.get('student')) unit = find_unit(course, request.GET.get('url')) due_date = parse_datetime(request.GET.get('due_datetime')) set_due_date_extension(course, unit, student, due_date) @@ -1413,7 +1414,7 @@ def reset_due_date(request, course_id): Rescinds a due date extension for a student on a particular unit. """ course = get_course_by_id(SlashSeparatedCourseKey.from_deprecated_string(course_id)) - student = get_student_from_identifier(request.GET.get('student')) + student = require_student_from_identifier(request.GET.get('student')) unit = find_unit(course, request.GET.get('url')) set_due_date_extension(course, unit, student, None) if not getattr(unit, "due", None): @@ -1453,7 +1454,7 @@ def show_student_extensions(request, course_id): Shows all of the due date extensions granted to a particular student in a particular course. """ - student = get_student_from_identifier(request.GET.get('student')) + student = require_student_from_identifier(request.GET.get('student')) course = get_course_by_id(SlashSeparatedCourseKey.from_deprecated_string(course_id)) return JsonResponse(dump_student_extensions(course, student)) diff --git a/lms/djangoapps/instructor/views/tools.py b/lms/djangoapps/instructor/views/tools.py index 44919ca064..302b0755fb 100644 --- a/lms/djangoapps/instructor/views/tools.py +++ b/lms/djangoapps/instructor/views/tools.py @@ -90,6 +90,21 @@ def get_student_from_identifier(unique_student_identifier): return student +def require_student_from_identifier(unique_student_identifier): + """ + Same as get_student_from_identifier() but will raise a DashboardError if + the student does not exist. + """ + try: + return get_student_from_identifier(unique_student_identifier) + except User.DoesNotExist: + raise DashboardError( + _("Could not find student matching identifier: {student_identifier}").format( + student_identifier=unique_student_identifier + ) + ) + + def parse_datetime(datestr): """ Convert user input date string into an instance of `datetime.datetime` in