diff --git a/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py b/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py index 8ac0644a12..ec25886b45 100644 --- a/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py +++ b/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py @@ -477,6 +477,19 @@ class BaseCourseEnrollmentTestsMixin(ProgramCacheTestCaseMixin): self.assertEqual(response.status_code, 422) self.assertIn('invalid enrollment record', response.data) + @ddt.data( + [{'status': 'pending'}], + [{'status': 'not-a-status'}], + [{'status': 'pending'}, {'status': 'pending'}], + ) + def test_no_student_key(self, bad_records): + request_data = [self.learner_enrollment('learner-1')] + request_data.extend(bad_records) + response = self.request(self.default_url, request_data) + + self.assertEqual(response.status_code, status.HTTP_422_UNPROCESSABLE_ENTITY) + self.assertIn('invalid enrollment record', response.data) + class CourseEnrollmentPostTests(BaseCourseEnrollmentTestsMixin, APITestCase): """ Tests for course enrollment POST """ @@ -750,6 +763,7 @@ class ProgramCourseEnrollmentListTest(ListViewTestMixin, APITestCase): assert '?cursor=' in next_response.data['previous'] +@ddt.ddt class ProgramEnrollmentViewPostTests(APITestCase): """ Tests for the ProgramEnrollment view POST method. @@ -990,7 +1004,64 @@ class ProgramEnrollmentViewPostTests(APITestCase): '003': 'pending', }) + @ddt.data(REQUEST_STUDENT_KEY, 'status', 'curriculum_uuid') + def test_missing_field(self, removed_field): + url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) + enrollments = [ + self.student_enrollment('enrolled') for _ in range(3) + ] + enrollments[2].pop(removed_field) + with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): + with mock.patch( + 'lms.djangoapps.program_enrollments.api.v1.views.get_user_by_program_id', + autospec=True, + return_value=None + ): + response = self.client.post(url, json.dumps(enrollments), content_type='application/json') + self.assertEqual(422, response.status_code) + self.assertEqual('invalid enrollment record', response.data) + + @ddt.data(REQUEST_STUDENT_KEY, 'status', 'curriculum_uuid') + def test_none_field(self, none_field): + url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) + enrollments = [ + self.student_enrollment('enrolled') for _ in range(3) + ] + enrollments[2][none_field] = None + with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): + with mock.patch( + 'lms.djangoapps.program_enrollments.api.v1.views.get_user_by_program_id', + autospec=True, + return_value=None + ): + response = self.client.post(url, json.dumps(enrollments), content_type='application/json') + + self.assertEqual(422, response.status_code) + self.assertEqual('invalid enrollment record', response.data) + + @ddt.data( + [{'status': 'pending'}], + [{'status': 'not-a-status'}], + [{'status': 'pending'}, {'status': 'pending'}], + ) + def test_no_student_key(self, bad_records): + url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) + enrollments = [self.student_enrollment('enrolled')] + enrollments.extend(bad_records) + with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): + with mock.patch( + 'lms.djangoapps.program_enrollments.api.v1.views.get_user_by_program_id', + autospec=True, + return_value=None + ): + response = self.client.post(url, json.dumps(enrollments), content_type='application/json') + + self.assertEqual(response.status_code, status.HTTP_422_UNPROCESSABLE_ENTITY) + self.assertEqual('invalid enrollment record', response.data) + + +@ddt.ddt class ProgramEnrollmentViewPatchTests(APITestCase): """ Tests for the ProgramEnrollment view PATCH method. @@ -1206,6 +1277,27 @@ class ProgramEnrollmentViewPatchTests(APITestCase): 'user-who-is-not-in-program': 'not-in-program', }) + @ddt.data( + [{'status': 'pending'}], + [{'status': 'not-a-status'}], + [{'status': 'pending'}, {'status': 'pending'}], + ) + def test_no_student_key(self, bad_records): + program_uuid = uuid4() + url = reverse('programs_api:v1:program_enrollments', args=[program_uuid]) + enrollments = [self.student_enrollment('enrolled')] + ProgramEnrollmentFactory.create( + external_user_key=enrollments[0]['student_key'], + program_uuid=program_uuid + ) + enrollments.extend(bad_records) + + with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): + response = self.client.patch(url, json.dumps(enrollments), content_type='application/json') + + self.assertEqual(422, response.status_code) + self.assertEqual('invalid enrollment record', response.data) + @ddt.ddt class ProgramCourseEnrollmentOverviewViewTests(ProgramCacheTestCaseMixin, SharedModuleStoreTestCase, APITestCase): diff --git a/lms/djangoapps/program_enrollments/api/v1/views.py b/lms/djangoapps/program_enrollments/api/v1/views.py index 515f7c957c..2ca6e3f57a 100644 --- a/lms/djangoapps/program_enrollments/api/v1/views.py +++ b/lms/djangoapps/program_enrollments/api/v1/views.py @@ -353,6 +353,11 @@ class ProgramEnrollmentsView(DeveloperErrorViewMixin, PaginatedAPIView): program_uuid = kwargs['program_uuid'] student_data = self._request_data_by_student_key(request, program_uuid) + if None in student_data: + return Response( + 'invalid enrollment record', + status.HTTP_422_UNPROCESSABLE_ENTITY + ) response_data = {} response_data.update(self._remove_duplicate_entries(request, student_data)) @@ -402,6 +407,11 @@ class ProgramEnrollmentsView(DeveloperErrorViewMixin, PaginatedAPIView): program_uuid = kwargs['program_uuid'] student_data = self._request_data_by_student_key(request, program_uuid) + if None in student_data: + return Response( + 'invalid enrollment record', + status.HTTP_422_UNPROCESSABLE_ENTITY + ) response_data = {} response_data.update(self._remove_duplicate_entries(request, student_data))