diff --git a/lms/djangoapps/program_enrollments/api/v1/constants.py b/lms/djangoapps/program_enrollments/api/v1/constants.py index 3b17d20364..f336cfa924 100644 --- a/lms/djangoapps/program_enrollments/api/v1/constants.py +++ b/lms/djangoapps/program_enrollments/api/v1/constants.py @@ -10,29 +10,48 @@ MAX_ENROLLMENT_RECORDS = 25 REQUEST_STUDENT_KEY = 'student_key' -class CourseEnrollmentResponseStatuses(object): +class BaseEnrollmentResponseStatuses(object): """ - Class to group response statuses returned by the course enrollment endpoint + Class to group common response statuses """ - ACTIVE = "active" - INACTIVE = "inactive" - DUPLICATED = "duplicated" + DUPLICATED = 'duplicated' INVALID_STATUS = "invalid-status" CONFLICT = "conflict" ILLEGAL_OPERATION = "illegal-operation" NOT_IN_PROGRAM = "not-in-program" - NOT_FOUND = "not-found" INTERNAL_ERROR = "internal-error" - ERROR_STATUSES = ( + ERROR_STATUSES = { DUPLICATED, INVALID_STATUS, CONFLICT, ILLEGAL_OPERATION, NOT_IN_PROGRAM, - NOT_FOUND, INTERNAL_ERROR, - ) + } + + +class CourseEnrollmentResponseStatuses(BaseEnrollmentResponseStatuses): + """ + Class to group response statuses returned by the course enrollment endpoint + """ + ACTIVE = "active" + INACTIVE = "inactive" + NOT_FOUND = "not-found" + + ERROR_STATUSES = BaseEnrollmentResponseStatuses.ERROR_STATUSES | {NOT_FOUND} + + +class ProgramEnrollmentResponseStatuses(BaseEnrollmentResponseStatuses): + """ + Class to group response statuses returned by the program enrollment endpoint + """ + ENROLLED = 'enrolled' + PENDING = 'pending' + SUSPENDED = 'suspended' + CANCELED = 'canceled' + + VALID_STATUSES = [ENROLLED, PENDING, SUSPENDED, CANCELED] class CourseRunProgressStatuses(object): diff --git a/lms/djangoapps/program_enrollments/api/v1/serializers.py b/lms/djangoapps/program_enrollments/api/v1/serializers.py index 8fb167ca2d..fb766e79be 100644 --- a/lms/djangoapps/program_enrollments/api/v1/serializers.py +++ b/lms/djangoapps/program_enrollments/api/v1/serializers.py @@ -7,11 +7,31 @@ from rest_framework import serializers from six import text_type from lms.djangoapps.program_enrollments.models import ProgramCourseEnrollment, ProgramEnrollment -from lms.djangoapps.program_enrollments.api.v1.constants import CourseRunProgressStatuses +from lms.djangoapps.program_enrollments.api.v1.constants import ( + CourseRunProgressStatuses, + ProgramEnrollmentResponseStatuses +) + + +class InvalidStatusMixin(object): + """ + Mixin to provide has_invalid_status method + """ + def has_invalid_status(self): + """ + Returns whether or not this serializer has an invalid error choice on the "status" field + """ + try: + for status_error in self.errors['status']: + if status_error.code == 'invalid_choice': + return True + except KeyError: + pass + return False # pylint: disable=abstract-method -class ProgramEnrollmentSerializer(serializers.ModelSerializer): +class ProgramEnrollmentSerializer(serializers.ModelSerializer, InvalidStatusMixin): """ Serializer for Program Enrollments """ @@ -37,6 +57,31 @@ class ProgramEnrollmentSerializer(serializers.ModelSerializer): return ProgramEnrollment.objects.create(**validated_data) +class BaseProgramEnrollmentRequestMixin(serializers.Serializer, InvalidStatusMixin): + """ + Base fields for all program enrollment related serializers + """ + student_key = serializers.CharField() + status = serializers.ChoiceField( + allow_blank=False, + choices=ProgramEnrollmentResponseStatuses.VALID_STATUSES + ) + + +class ProgramEnrollmentCreateRequestSerializer(BaseProgramEnrollmentRequestMixin): + """ + Serializer for program enrollment creation requests + """ + curriculum_uuid = serializers.UUIDField() + + +class ProgramEnrollmentModifyRequestSerializer(BaseProgramEnrollmentRequestMixin): + """ + Serializer for program enrollment modification requests + """ + pass + + class ProgramEnrollmentListSerializer(serializers.Serializer): """ Serializer for listing enrollments in a program. @@ -53,23 +98,6 @@ class ProgramEnrollmentListSerializer(serializers.Serializer): return bool(obj.user) -class InvalidStatusMixin(object): - """ - Mixin to provide has_invalid_status method - """ - def has_invalid_status(self): - """ - Returns whether or not this serializer has an invalid error choice on the "status" field - """ - try: - for status_error in self.errors['status']: - if status_error.code == 'invalid_choice': - return True - except KeyError: - pass - return False - - # pylint: disable=abstract-method class ProgramCourseEnrollmentRequestSerializer(serializers.Serializer, InvalidStatusMixin): """ diff --git a/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py b/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py index 76e6c666e9..475e9cc8aa 100644 --- a/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py +++ b/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py @@ -28,6 +28,7 @@ from lms.djangoapps.program_enrollments.api.v1.constants import ( CourseEnrollmentResponseStatuses as CourseStatuses, CourseRunProgressStatuses, MAX_ENROLLMENT_RECORDS, + ProgramEnrollmentResponseStatuses as ProgramStatuses, REQUEST_STUDENT_KEY, ) from lms.djangoapps.program_enrollments.tests.factories import ProgramCourseEnrollmentFactory, ProgramEnrollmentFactory @@ -801,12 +802,112 @@ class ProgramCourseEnrollmentListTest(ListViewTestMixin, APITestCase): @ddt.ddt -class ProgramEnrollmentViewPostTests(APITestCase): +class BaseProgramEnrollmentWriteTestsMixin(object): + """ Mixin class that defines common tests for program enrollment write endpoints """ + add_uuid = False + + def student_enrollment(self, enrollment_status, external_user_key=None, prepare_student=False): + """ Convenience method to create a student enrollment record """ + enrollment = { + REQUEST_STUDENT_KEY: external_user_key or str(uuid4().hex[0:10]), + 'status': enrollment_status, + } + if self.add_uuid: + enrollment['curriculum_uuid'] = str(uuid4()) + if prepare_student: + self.prepare_student(enrollment) + return enrollment + + def prepare_student(self, enrollment): + pass + + def get_url(self, program_uuid=None): + if program_uuid is None: + program_uuid = uuid4() + return reverse('programs_api:v1:program_enrollments', args=[program_uuid]) + + def test_unauthenticated(self): + self.client.logout() + request_data = [self.student_enrollment('enrolled')] + response = self.request(self.get_url(), json.dumps(request_data), content_type='application/json') + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + + def test_enrollment_payload_limit(self): + request_data = [self.student_enrollment('enrolled') for _ in range(MAX_ENROLLMENT_RECORDS + 1)] + with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): + response = self.request(self.get_url(), json.dumps(request_data), content_type='application/json') + self.assertEqual(response.status_code, status.HTTP_413_REQUEST_ENTITY_TOO_LARGE) + + def test_duplicate_enrollment(self): + request_data = [ + self.student_enrollment('enrolled', '001'), + self.student_enrollment('enrolled', '001'), + ] + + with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): + response = self.request(self.get_url(), json.dumps(request_data), content_type='application/json') + + self.assertEqual(response.status_code, status.HTTP_422_UNPROCESSABLE_ENTITY) + self.assertEqual(response.data, {'001': 'duplicated'}) + + def test_unprocessable_enrollment(self): + with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): + response = self.request( + self.get_url(), + json.dumps([{'status': 'enrolled'}]), + content_type='application/json' + ) + self.assertEqual(response.status_code, status.HTTP_422_UNPROCESSABLE_ENTITY) + self.assertEqual(response.data, 'invalid enrollment record') + + def test_program_unauthorized(self): + student = UserFactory.create(password='password') + self.client.login(username=student.username, password='password') + + request_data = [self.student_enrollment('enrolled')] + with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): + response = self.request(self.get_url(), json.dumps(request_data), content_type='application/json') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_program_not_found(self): + post_data = [self.student_enrollment('enrolled')] + nonexistant_uuid = uuid4() + response = self.request( + self.get_url(program_uuid=nonexistant_uuid), + json.dumps(post_data), + content_type='application/json' + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + @ddt.data( + [{'status': 'pending'}], + [{'status': 'not-a-status'}], + [{'status': 'pending'}, {'status': 'pending'}], + ) + def test_no_student_key(self, bad_records): + program_uuid = uuid4() + url = reverse('programs_api:v1:program_enrollments', args=[program_uuid]) + enrollments = [self.student_enrollment('enrolled', '001', True)] + enrollments.extend(bad_records) + + with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): + response = self.request(url, json.dumps(enrollments), content_type='application/json') + + self.assertEqual(422, response.status_code) + self.assertEqual('invalid enrollment record', response.data) + + +@ddt.ddt +class ProgramEnrollmentViewPostTests(BaseProgramEnrollmentWriteTestsMixin, APITestCase): """ Tests for the ProgramEnrollment view POST method. """ + add_uuid = True + success_status = status.HTTP_201_CREATED + def setUp(self): super(ProgramEnrollmentViewPostTests, self).setUp() + self.request = self.client.post global_staff = GlobalStaffFactory.create(username='global-staff', password='password') self.client.login(username=global_staff.username, password='password') @@ -814,13 +915,6 @@ class ProgramEnrollmentViewPostTests(APITestCase): super(ProgramEnrollmentViewPostTests, self).tearDown() ProgramEnrollment.objects.all().delete() - def student_enrollment(self, enrollment_status, external_user_key=None): - return { - REQUEST_STUDENT_KEY: external_user_key or str(uuid4().hex[0:10]), - 'status': enrollment_status, - 'curriculum_uuid': str(uuid4()) - } - def test_successful_program_enrollments_no_existing_user(self): program_key = uuid4() statuses = ['pending', 'enrolled', 'pending'] @@ -923,189 +1017,18 @@ class ProgramEnrollmentViewPostTests(APITestCase): self.assertEqual(enrollment.curriculum_uuid, curriculum_uuid) self.assertIsNone(enrollment.user) - def test_enrollment_payload_limit(self): - - post_data = [] - for _ in range(MAX_ENROLLMENT_RECORDS + 1): - post_data += self.student_enrollment('enrolled') - - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): - with mock.patch( - 'lms.djangoapps.program_enrollments.api.v1.views.get_user_by_program_id', - autospec=True, - return_value=None - ): - response = self.client.post(url, json.dumps(post_data), content_type='application/json') - self.assertEqual(response.status_code, status.HTTP_413_REQUEST_ENTITY_TOO_LARGE) - - def test_duplicate_enrollment(self): - post_data = [ - self.student_enrollment('enrolled', '001'), - self.student_enrollment('enrolled', '002'), - self.student_enrollment('enrolled', '001'), - ] - - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): - with mock.patch( - 'lms.djangoapps.program_enrollments.api.v1.views.get_user_by_program_id', - autospec=True, - return_value=None - ): - response = self.client.post(url, json.dumps(post_data), content_type='application/json') - - self.assertEqual(response.status_code, status.HTTP_207_MULTI_STATUS) - self.assertEqual(response.data, { - '001': 'duplicated', - '002': 'enrolled', - }) - - def test_unprocessable_enrollment(self): - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - - with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): - with mock.patch( - 'lms.djangoapps.program_enrollments.api.v1.views.get_user_by_program_id', - autospec=True, - return_value=None - ): - response = self.client.post( - url, - json.dumps([{'status': 'enrolled'}]), - content_type='application/json' - ) - - self.assertEqual(response.status_code, 400) - self.assertEqual(response.data, 'invalid enrollment record') - - def test_unauthenticated(self): - self.client.logout() - post_data = [ - self.student_enrollment('enrolled') - ] - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - response = self.client.post( - url, - json.dumps(post_data), - content_type='application/json' - ) - - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - - def test_program_unauthorized(self): - student = UserFactory.create(username='student', password='password') - self.client.login(username=student.username, password='password') - - post_data = [ - self.student_enrollment('enrolled') - ] - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - response = self.client.post( - url, - json.dumps(post_data), - content_type='application/json' - ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - def test_program_not_found(self): - post_data = [ - self.student_enrollment('enrolled') - ] - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - response = self.client.post( - url, - json.dumps(post_data), - content_type='application/json' - ) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - def test_partially_valid_enrollment(self): - - post_data = [ - self.student_enrollment('new', '001'), - self.student_enrollment('pending', '003'), - ] - - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): - with mock.patch( - 'lms.djangoapps.program_enrollments.api.v1.views.get_user_by_program_id', - autospec=True, - return_value=None - ): - response = self.client.post(url, json.dumps(post_data), content_type='application/json') - - self.assertEqual(response.status_code, status.HTTP_207_MULTI_STATUS) - self.assertEqual(response.data, { - '001': 'invalid-status', - '003': 'pending', - }) - - @ddt.data(REQUEST_STUDENT_KEY, 'status', 'curriculum_uuid') - def test_missing_field(self, removed_field): - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - enrollments = [ - self.student_enrollment('enrolled') for _ in range(3) - ] - enrollments[2].pop(removed_field) - with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): - with mock.patch( - 'lms.djangoapps.program_enrollments.api.v1.views.get_user_by_program_id', - autospec=True, - return_value=None - ): - response = self.client.post(url, json.dumps(enrollments), content_type='application/json') - - self.assertEqual(400, response.status_code) - self.assertEqual('invalid enrollment record', response.data) - - @ddt.data(REQUEST_STUDENT_KEY, 'status', 'curriculum_uuid') - def test_none_field(self, none_field): - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - enrollments = [ - self.student_enrollment('enrolled') for _ in range(3) - ] - enrollments[2][none_field] = None - with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): - with mock.patch( - 'lms.djangoapps.program_enrollments.api.v1.views.get_user_by_program_id', - autospec=True, - return_value=None - ): - response = self.client.post(url, json.dumps(enrollments), content_type='application/json') - - self.assertEqual(400, response.status_code) - self.assertEqual('invalid enrollment record', response.data) - - @ddt.data( - [{'status': 'pending'}], - [{'status': 'not-a-status'}], - [{'status': 'pending'}, {'status': 'pending'}], - ) - def test_no_student_key(self, bad_records): - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - enrollments = [self.student_enrollment('enrolled')] - enrollments.extend(bad_records) - with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): - with mock.patch( - 'lms.djangoapps.program_enrollments.api.v1.views.get_user_by_program_id', - autospec=True, - return_value=None - ): - response = self.client.post(url, json.dumps(enrollments), content_type='application/json') - - self.assertEqual(response.status_code, 400) - self.assertEqual('invalid enrollment record', response.data) - @ddt.ddt -class ProgramEnrollmentViewPatchTests(APITestCase): +class ProgramEnrollmentViewPatchTests(BaseProgramEnrollmentWriteTestsMixin, APITestCase): """ Tests for the ProgramEnrollment view PATCH method. """ + add_uuid = False + success_status = status.HTTP_200_OK + def setUp(self): super(ProgramEnrollmentViewPatchTests, self).setUp() + self.request = self.client.patch self.program_uuid = '00000000-1111-2222-3333-444444444444' self.curriculum_uuid = 'aaaaaaaa-1111-2222-3333-444444444444' @@ -1120,11 +1043,14 @@ class ProgramEnrollmentViewPatchTests(APITestCase): self.client.login(username=self.global_staff.username, password=self.password) - def student_enrollment(self, enrollment_status, external_user_key=None): - return { - 'status': enrollment_status, - REQUEST_STUDENT_KEY: external_user_key or str(uuid4().hex[0:10]), - } + def prepare_student(self, enrollment): + ProgramEnrollment.objects.create( + program_uuid=self.program_uuid, + curriculum_uuid=self.curriculum_uuid, + user=None, + status='pending', + external_user_key=enrollment[REQUEST_STUDENT_KEY], + ) def test_successfully_patched_program_enrollment(self): enrollments = {} @@ -1169,71 +1095,7 @@ class ProgramEnrollmentViewPatchTests(APITestCase): assert status.HTTP_200_OK == response.status_code assert expected_response == response.data - def test_enrollment_payload_limit(self): - patch_data = [] - for _ in range(MAX_ENROLLMENT_RECORDS + 1): - patch_data += self.student_enrollment('enrolled') - - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): - response = self.client.patch(url, json.dumps(patch_data), content_type='application/json') - - self.assertEqual(response.status_code, status.HTTP_413_REQUEST_ENTITY_TOO_LARGE) - - def test_unauthenticated(self): - self.client.logout() - patch_data = [ - self.student_enrollment('enrolled') - ] - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - response = self.client.patch( - url, - json.dumps(patch_data), - content_type='application/json' - ) - - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - - def test_program_unauthorized(self): - self.client.login(username=self.student.username, password=self.password) - - patch_data = [ - self.student_enrollment('enrolled') - ] - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - response = self.client.patch( - url, - json.dumps(patch_data), - content_type='application/json' - ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - def test_program_not_found(self): - patch_data = [ - self.student_enrollment('enrolled') - ] - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - response = self.client.patch( - url, - json.dumps(patch_data), - content_type='application/json' - ) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - def test_unprocessable_enrollment(self): - url = reverse('programs_api:v1:program_enrollments', args=[uuid4()]) - - with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): - response = self.client.patch( - url, - json.dumps([{'status': 'enrolled'}]), - content_type='application/json' - ) - - self.assertEqual(response.status_code, 400) - self.assertEqual(response.data, 'invalid enrollment record') - - def test_duplicate_enrollment(self): + def test_duplicate_enrollment_record_changed(self): enrollments = {} for i in range(4): user_key = 'user-{}'.format(i) @@ -1274,7 +1136,7 @@ class ProgramEnrollmentViewPatchTests(APITestCase): 'user-2': 'enrolled', }) - def test_partially_valid_enrollment(self): + def test_partially_valid_enrollment_record_changed(self): enrollments = {} for i in range(4): user_key = 'user-{}'.format(i) @@ -1316,26 +1178,89 @@ class ProgramEnrollmentViewPatchTests(APITestCase): 'user-who-is-not-in-program': 'not-in-program', }) - @ddt.data( - [{'status': 'pending'}], - [{'status': 'not-a-status'}], - [{'status': 'pending'}, {'status': 'pending'}], - ) - def test_no_student_key(self, bad_records): - program_uuid = uuid4() - url = reverse('programs_api:v1:program_enrollments', args=[program_uuid]) - enrollments = [self.student_enrollment('enrolled')] - ProgramEnrollmentFactory.create( - external_user_key=enrollments[0]['student_key'], - program_uuid=program_uuid + +@ddt.ddt +class ProgramEnrollmentViewPutTests(BaseProgramEnrollmentWriteTestsMixin, APITestCase): + """ + Tests for the ProgramEnrollment view PATCH method. + """ + add_uuid = True + success_status = status.HTTP_200_OK + + def setUp(self): + super(ProgramEnrollmentViewPutTests, self).setUp() + self.request = self.client.put + + self.program_uuid = '00000000-1111-2222-3333-444444444444' + self.curriculum_uuid = 'aaaaaaaa-1111-2222-3333-444444444444' + + self.global_staff = GlobalStaffFactory.create(username='global-staff', password='password') + self.client.login(username=self.global_staff.username, password='password') + + patch_get_user = mock.patch( + 'lms.djangoapps.program_enrollments.api.v1.views.get_user_by_program_id', + autospec=True, + return_value=None ) - enrollments.extend(bad_records) + self.mock_get_user = patch_get_user.start() + self.addCleanup(patch_get_user.stop) + def prepare_student(self, enrollment): + ProgramEnrollment.objects.create( + program_uuid=self.program_uuid, + curriculum_uuid=self.curriculum_uuid, + user=None, + status='pending', + external_user_key=enrollment[REQUEST_STUDENT_KEY], + ) + + @ddt.data(True, False) + def test_all_create_or_modify(self, create_users): + request_data = [ + self.student_enrollment(ProgramStatuses.ENROLLED) + for _ in range(5) + ] + if create_users: + for enrollment in request_data: + ProgramEnrollmentFactory( + program_uuid=self.program_uuid, + status=ProgramStatuses.PENDING, + external_user_key=enrollment[REQUEST_STUDENT_KEY], + ) + + url = self.get_url(program_uuid=self.program_uuid) with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): - response = self.client.patch(url, json.dumps(enrollments), content_type='application/json') + response = self.client.put(url, json.dumps(request_data), content_type='application/json') + self.assertEqual(self.success_status, response.status_code) + self.assertEqual(5, len(response.data)) + for response_status in response.data.values(): + self.assertEqual(response_status, ProgramStatuses.ENROLLED) - self.assertEqual(400, response.status_code) - self.assertEqual('invalid enrollment record', response.data) + def test_half_create_modify(self): + request_data = [ + self.student_enrollment(ProgramStatuses.ENROLLED, 'learner-01'), + self.student_enrollment(ProgramStatuses.ENROLLED, 'learner-02'), + self.student_enrollment(ProgramStatuses.ENROLLED, 'learner-03'), + self.student_enrollment(ProgramStatuses.ENROLLED, 'learner-04'), + ] + ProgramEnrollmentFactory( + program_uuid=self.program_uuid, + status=ProgramStatuses.PENDING, + external_user_key='learner-03', + ) + ProgramEnrollmentFactory( + program_uuid=self.program_uuid, + status=ProgramStatuses.PENDING, + external_user_key='learner-04', + ) + + url = self.get_url(program_uuid=self.program_uuid) + with mock.patch('lms.djangoapps.program_enrollments.api.v1.views.get_programs', autospec=True): + response = self.client.put(url, json.dumps(request_data), content_type='application/json') + self.assertEqual(self.success_status, response.status_code) + self.assertEqual(4, len(response.data)) + for response_status in response.data.values(): + self.assertEqual(response_status, ProgramStatuses.ENROLLED) @ddt.ddt diff --git a/lms/djangoapps/program_enrollments/api/v1/views.py b/lms/djangoapps/program_enrollments/api/v1/views.py index 31fa4a3528..df21fdd2d8 100644 --- a/lms/djangoapps/program_enrollments/api/v1/views.py +++ b/lms/djangoapps/program_enrollments/api/v1/views.py @@ -5,7 +5,6 @@ ProgramEnrollment Views from __future__ import absolute_import, unicode_literals import logging -from collections import Counter, OrderedDict from datetime import datetime, timedelta from functools import wraps from pytz import UTC @@ -33,14 +32,15 @@ from lms.djangoapps.program_enrollments.api.v1.constants import ( CourseEnrollmentResponseStatuses, CourseRunProgressStatuses, MAX_ENROLLMENT_RECORDS, - REQUEST_STUDENT_KEY, + ProgramEnrollmentResponseStatuses, ) from lms.djangoapps.program_enrollments.api.v1.serializers import ( CourseRunOverviewListSerializer, ProgramCourseEnrollmentListSerializer, ProgramCourseEnrollmentRequestSerializer, + ProgramEnrollmentCreateRequestSerializer, ProgramEnrollmentListSerializer, - ProgramEnrollmentSerializer, + ProgramEnrollmentModifyRequestSerializer, ) from lms.djangoapps.program_enrollments.models import ProgramCourseEnrollment, ProgramEnrollment from lms.djangoapps.program_enrollments.utils import get_user_by_program_id, ProviderDoesNotExistException @@ -345,150 +345,162 @@ class ProgramEnrollmentsView(DeveloperErrorViewMixin, PaginatedAPIView): """ Create program enrollments for a list of learners """ - if len(request.data) > MAX_ENROLLMENT_RECORDS: - return Response( - status=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, - content_type='application/json', - ) - - program_uuid = kwargs['program_uuid'] - student_data = self._request_data_by_student_key(request, program_uuid) - if None in student_data: - return Response( - 'invalid enrollment record', - status.HTTP_400_BAD_REQUEST - ) - - response_data = {} - response_data.update(self._remove_duplicate_entries(request, student_data)) - response_data.update(self._remove_existing_entries(program_uuid, student_data)) - - enrollments_to_create = {} - - for student_key, data in student_data.items(): - curriculum_uuid = data['curriculum_uuid'] - - try: - existing_user = get_user_by_program_id(student_key, program_uuid) - if existing_user: - data['user'] = existing_user.id - except ProviderDoesNotExistException: - pass # IDP has not yet been set up, just create waiting enrollments - - serializer = ProgramEnrollmentSerializer(data=data) - if serializer.is_valid(): - enrollments_to_create[(student_key, curriculum_uuid)] = serializer - response_data[student_key] = data.get('status') - else: - if 'status' in serializer.errors and serializer.errors['status'][0].code == 'invalid_choice': - response_data[student_key] = CourseEnrollmentResponseStatuses.INVALID_STATUS - else: - return Response( - 'invalid enrollment record', - status.HTTP_400_BAD_REQUEST - ) - - # TODO: make this a bulk save - https://openedx.atlassian.net/browse/EDUCATOR-4305 - for (student_key, _), enrollment_serializer in enrollments_to_create.items(): - enrollment_serializer.save() - - return self._get_created_or_updated_response(request, enrollments_to_create, response_data) + return self.create_or_modify_enrollments( + request, + kwargs['program_uuid'], + ProgramEnrollmentCreateRequestSerializer, + self.create_program_enrollment, + status.HTTP_201_CREATED, + ) @verify_program_exists def patch(self, request, **kwargs): """ - Modify the program enrollments for a list of learners + Modify program enrollments for a list of learners """ - if len(request.data) > MAX_ENROLLMENT_RECORDS: - return Response( - status=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, - content_type='application/json', - ) - - program_uuid = kwargs['program_uuid'] - student_data = self._request_data_by_student_key(request, program_uuid) - if None in student_data: - return Response( - 'invalid enrollment record', - status.HTTP_400_BAD_REQUEST - ) - - response_data = {} - response_data.update(self._remove_duplicate_entries(request, student_data)) - - existing_enrollments = { - enrollment.external_user_key: enrollment - for enrollment in - ProgramEnrollment.bulk_read_by_student_key(program_uuid, student_data) - } - - enrollments_to_create = {} - - for external_user_key in student_data.keys(): - if external_user_key not in existing_enrollments: - student_data.pop(external_user_key) - response_data[external_user_key] = CourseEnrollmentResponseStatuses.NOT_IN_PROGRAM - - for external_user_key, enrollment in existing_enrollments.items(): - student = {key: value for key, value in student_data[external_user_key].items() if key == 'status'} - enrollment_serializer = ProgramEnrollmentSerializer(enrollment, data=student, partial=True) - if enrollment_serializer.is_valid(): - enrollments_to_create[(external_user_key, enrollment.curriculum_uuid)] = enrollment_serializer - enrollment_serializer.save() - response_data[external_user_key] = student['status'] - else: - serializer_is_invalid = enrollment_serializer.errors['status'][0].code == 'invalid_choice' - if 'status' in enrollment_serializer.errors and serializer_is_invalid: - response_data[external_user_key] = CourseEnrollmentResponseStatuses.INVALID_STATUS - - return self._get_created_or_updated_response(request, enrollments_to_create, response_data, status.HTTP_200_OK) - - def _remove_duplicate_entries(self, request, student_data): - """ Helper method to remove duplicate entries (based on student key) from request data. """ - result = {} - key_counter = Counter([enrollment.get(REQUEST_STUDENT_KEY) for enrollment in request.data]) - for student_key, count in key_counter.items(): - if count > 1: - result[student_key] = CourseEnrollmentResponseStatuses.DUPLICATED - student_data.pop(student_key) - return result - - def _request_data_by_student_key(self, request, program_uuid): - """ - Helper method that returns an OrderedDict of rows from request.data, - keyed by the `external_user_key`. - """ - return OrderedDict(( - row.get(REQUEST_STUDENT_KEY), - { - 'program_uuid': program_uuid, - 'curriculum_uuid': row.get('curriculum_uuid'), - 'status': row.get('status'), - 'external_user_key': row.get(REQUEST_STUDENT_KEY), - }) - for row in request.data + return self.create_or_modify_enrollments( + request, + kwargs['program_uuid'], + ProgramEnrollmentModifyRequestSerializer, + self.modify_program_enrollment, + status.HTTP_200_OK, ) - def _remove_existing_entries(self, program_uuid, student_data): - """ Helper method to remove entries that have existing ProgramEnrollment records. """ - result = {} - existing_enrollments = ProgramEnrollment.bulk_read_by_student_key(program_uuid, student_data) - for enrollment in existing_enrollments: - result[enrollment.external_user_key] = CourseEnrollmentResponseStatuses.CONFLICT - student_data.pop(enrollment.external_user_key) - return result + @verify_program_exists + def put(self, request, **kwargs): + """ + Create/modify program enrollments for a list of learners + """ + return self.create_or_modify_enrollments( + request, + kwargs['program_uuid'], + ProgramEnrollmentCreateRequestSerializer, + self.create_or_modify_program_enrollment, + status.HTTP_200_OK, + ) - def _get_created_or_updated_response( - self, request, created_or_updated_data, response_data, default_status=status.HTTP_201_CREATED - ): + def validate_enrollment_request(self, enrollment, seen_student_keys, serializer_class): + """ + Validates the given enrollment record and checks that it isn't a duplicate + """ + student_key = enrollment['student_key'] + if student_key in seen_student_keys: + return CourseEnrollmentResponseStatuses.DUPLICATED + seen_student_keys.add(student_key) + enrollment_serializer = serializer_class(data=enrollment) + try: + enrollment_serializer.is_valid(raise_exception=True) + except ValidationError as e: + if enrollment_serializer.has_invalid_status(): + return CourseEnrollmentResponseStatuses.INVALID_STATUS + else: + raise e + + def create_or_modify_enrollments(self, request, program_uuid, serializer_class, operation, success_status): + """ + Process a list of program course enrollment request objects + and create or modify enrollments based on method + """ + results = {} + seen_student_keys = set() + enrollments = [] + + if not isinstance(request.data, list): + return Response('invalid enrollment record', status.HTTP_422_UNPROCESSABLE_ENTITY) + if len(request.data) > MAX_ENROLLMENT_RECORDS: + return Response( + 'enrollment limit {}'.format(MAX_ENROLLMENT_RECORDS), + status.HTTP_413_REQUEST_ENTITY_TOO_LARGE + ) + + try: + for enrollment_request in request.data: + error_status = self.validate_enrollment_request(enrollment_request, seen_student_keys, serializer_class) + if error_status: + results[enrollment_request["student_key"]] = error_status + else: + enrollments.append(enrollment_request) + except KeyError: # student_key is not in enrollment_request + return Response('invalid enrollment record', status.HTTP_422_UNPROCESSABLE_ENTITY) + except TypeError: # enrollment_request isn't a dict + return Response('invalid enrollment record', status.HTTP_422_UNPROCESSABLE_ENTITY) + except ValidationError: # there was some other error raised by the serializer + return Response('invalid enrollment record', status.HTTP_422_UNPROCESSABLE_ENTITY) + + program_enrollments = self.get_existing_program_enrollments(program_uuid, enrollments) + for enrollment in enrollments: + student_key = enrollment["student_key"] + if student_key in results and results[student_key] == ProgramEnrollmentResponseStatuses.DUPLICATED: + continue + try: + program_enrollment = program_enrollments[student_key] + except KeyError: + program_enrollment = None + results[student_key] = operation(enrollment, program_uuid, program_enrollment) + + return self._get_created_or_updated_response(results, success_status) + + def create_program_enrollment(self, request_data, program_uuid, program_enrollment): + """ + Create new ProgramEnrollment, unless the learner is already enrolled in the program + """ + if program_enrollment: + return ProgramEnrollmentResponseStatuses.CONFLICT + + student_key = request_data.get('student_key') + try: + user = get_user_by_program_id(student_key, program_uuid) + except ProviderDoesNotExistException: + # IDP has not yet been set up, just create waiting enrollments + user = None + + enrollment = ProgramEnrollment.objects.create( + user=user, + external_user_key=student_key, + program_uuid=program_uuid, + curriculum_uuid=request_data.get('curriculum_uuid'), + status=request_data.get('status') + ) + return enrollment.status + + # pylint: disable=unused-argument + def modify_program_enrollment(self, request_data, program_uuid, program_enrollment): + """ + Change the status of an existing program enrollment + """ + if not program_enrollment: + return ProgramEnrollmentResponseStatuses.NOT_IN_PROGRAM + + program_enrollment.status = request_data.get('status') + program_enrollment.save() + return program_enrollment.status + + def create_or_modify_program_enrollment(self, request_data, program_uuid, program_enrollment): + if program_enrollment: + return self.modify_program_enrollment(request_data, program_uuid, program_enrollment) + else: + return self.create_program_enrollment(request_data, program_uuid, program_enrollment) + + def get_existing_program_enrollments(self, program_uuid, student_data): + """ Returns the existing program enrollments for the given students and program """ + student_keys = [data['student_key'] for data in student_data] + return { + e.external_user_key: e + for e in ProgramEnrollment.bulk_read_by_student_key(program_uuid, student_keys) + } + + def _get_created_or_updated_response(self, response_data, default_status=status.HTTP_201_CREATED): """ Helper method to determine an appropirate HTTP response status code. """ response_status = default_status - - if not created_or_updated_data: + good_count = len([ + v for v in response_data.values() + if v not in CourseEnrollmentResponseStatuses.ERROR_STATUSES + ]) + if not good_count: response_status = status.HTTP_422_UNPROCESSABLE_ENTITY - elif len(request.data) != len(created_or_updated_data): + elif good_count != len(response_data): response_status = status.HTTP_207_MULTI_STATUS return Response( diff --git a/lms/djangoapps/program_enrollments/models.py b/lms/djangoapps/program_enrollments/models.py index ee437bc81f..e7ed710ade 100644 --- a/lms/djangoapps/program_enrollments/models.py +++ b/lms/djangoapps/program_enrollments/models.py @@ -67,16 +67,15 @@ class ProgramEnrollment(TimeStampedModel): # pylint: disable=model-missing-unic raise ValidationError(_('One of user or external_user_key must not be null.')) @classmethod - def bulk_read_by_student_key(cls, program_uuid, student_data): + def bulk_read_by_student_key(cls, program_uuid, student_keys): """ args: program_uuid - The UUID of the program to read enrollment data of. - student_data - A dictionary keyed by external_user_key and - valued by a dict containing the curriculum_uuid for the user in the given program. + student_keys - list of student keys """ return cls.objects.filter( program_uuid=program_uuid, - external_user_key__in=list(student_data.keys()), + external_user_key__in=student_keys, ) @classmethod