From 6052a4e4c63933a4e72fe2b2a65eeb01e3675f47 Mon Sep 17 00:00:00 2001 From: jansenk Date: Thu, 2 May 2019 14:51:33 -0400 Subject: [PATCH] add program course enrollment post endpoint --- .../program_enrollments/api/urls.py | 2 +- .../program_enrollments/api/v1/constants.py | 30 +++ .../program_enrollments/api/v1/serializers.py | 28 ++ .../api/v1/tests/factories.py | 16 +- .../api/v1/tests/test_views.py | 241 +++++++++++++++++- .../program_enrollments/api/v1/urls.py | 12 +- .../program_enrollments/api/v1/views.py | 205 ++++++++++++++- ...add_programcourseenrollment_relatedname.py | 21 ++ lms/djangoapps/program_enrollments/models.py | 43 +++- 9 files changed, 584 insertions(+), 14 deletions(-) create mode 100644 lms/djangoapps/program_enrollments/api/v1/constants.py create mode 100644 lms/djangoapps/program_enrollments/migrations/0004_add_programcourseenrollment_relatedname.py diff --git a/lms/djangoapps/program_enrollments/api/urls.py b/lms/djangoapps/program_enrollments/api/urls.py index 5c7c7075a9..85201e4762 100644 --- a/lms/djangoapps/program_enrollments/api/urls.py +++ b/lms/djangoapps/program_enrollments/api/urls.py @@ -1,5 +1,5 @@ """ -Grades API URLs. +Program Enrollment API URLs. """ from django.conf.urls import include, url diff --git a/lms/djangoapps/program_enrollments/api/v1/constants.py b/lms/djangoapps/program_enrollments/api/v1/constants.py new file mode 100644 index 0000000000..39cf3327dc --- /dev/null +++ b/lms/djangoapps/program_enrollments/api/v1/constants.py @@ -0,0 +1,30 @@ +""" + Constants and strings for the course-enrollment app +""" + +# Captures strings composed of alphanumeric characters a-f and dashes. +PROGRAM_UUID_PATTERN = r'(?P[A-Fa-f0-9-]+)' +MAX_ENROLLMENT_RECORDS = 25 + + +class CourseEnrollmentResponseStatuses(object): + """ + Class to group response statuses returned by the course enrollment endpoint + """ + ACTIVE = "active" + INACTIVE = "inactive" + DUPLICATED = "duplicated" + INVALID_STATUS = "invalid-status" + CONFLICT = "conflict" + ILLEGAL_OPERATION = "illegal-operation" + NOT_IN_PROGRAM = "not-in-program" + INTERNAL_ERROR = "internal-error" + + ERROR_STATUSES = ( + DUPLICATED, + INVALID_STATUS, + CONFLICT, + ILLEGAL_OPERATION, + NOT_IN_PROGRAM, + INTERNAL_ERROR, + ) diff --git a/lms/djangoapps/program_enrollments/api/v1/serializers.py b/lms/djangoapps/program_enrollments/api/v1/serializers.py index 75f2bbbe76..331e556ddd 100644 --- a/lms/djangoapps/program_enrollments/api/v1/serializers.py +++ b/lms/djangoapps/program_enrollments/api/v1/serializers.py @@ -40,3 +40,31 @@ class ProgramEnrollmentListSerializer(serializers.Serializer): def get_account_exists(self, obj): return bool(obj.user) + + +class InvalidStatusMixin(object): + """ + Mixin to provide has_invalid_status method + """ + def has_invalid_status(self): + """ + Returns whether or not this serializer has an invalid error choice on the "status" field + """ + try: + for status_error in self.errors['status']: + if status_error.code == 'invalid_choice': + return True + except KeyError: + pass + return False + + +# pylint: disable=abstract-method +class ProgramCourseEnrollmentRequestSerializer(serializers.Serializer, InvalidStatusMixin): + """ + Serializer for request to create a ProgramCourseEnrollment + """ + STATUS_CHOICES = ['active', 'inactive'] + + student_key = serializers.CharField(allow_blank=False) + status = serializers.ChoiceField(allow_blank=False, choices=STATUS_CHOICES) diff --git a/lms/djangoapps/program_enrollments/api/v1/tests/factories.py b/lms/djangoapps/program_enrollments/api/v1/tests/factories.py index df507355a5..0e77a17331 100644 --- a/lms/djangoapps/program_enrollments/api/v1/tests/factories.py +++ b/lms/djangoapps/program_enrollments/api/v1/tests/factories.py @@ -5,9 +5,10 @@ from uuid import uuid4 import factory from factory.django import DjangoModelFactory +from opaque_keys.edx.keys import CourseKey from lms.djangoapps.program_enrollments import models -from student.tests.factories import UserFactory +from student.tests.factories import UserFactory, CourseEnrollmentFactory class ProgramEnrollmentFactory(DjangoModelFactory): @@ -20,3 +21,16 @@ class ProgramEnrollmentFactory(DjangoModelFactory): program_uuid = uuid4() curriculum_uuid = uuid4() status = 'enrolled' + + +class ProgramCourseEnrollmentFactory(factory.DjangoModelFactory): + """ + Factory for ProgramCourseEnrollment models + """ + class Meta(object): + model = models.ProgramCourseEnrollment + + program_enrollment = factory.SubFactory(ProgramEnrollmentFactory) + course_enrollment = factory.SubFactory(CourseEnrollmentFactory) + course_key = CourseKey.from_string("course-v1:edX+DemoX+Demo_Course") + status = "active" diff --git a/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py b/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py index e38e085058..bd993aa5c0 100644 --- a/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py +++ b/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py @@ -2,19 +2,30 @@ Unit tests for ProgramEnrollment views. """ from __future__ import unicode_literals - +from uuid import uuid4 import mock +import ddt +from django.core.cache import cache from django.urls import reverse +from opaque_keys.edx.keys import CourseKey from rest_framework import status from rest_framework.test import APITestCase from six import text_type from lms.djangoapps.courseware.tests.factories import GlobalStaffFactory -from lms.djangoapps.program_enrollments.models import ProgramEnrollment +from lms.djangoapps.program_enrollments.api.v1.constants import CourseEnrollmentResponseStatuses as CourseStatuses +from lms.djangoapps.program_enrollments.models import ProgramEnrollment, ProgramCourseEnrollment from student.tests.factories import UserFactory - -from .factories import ProgramEnrollmentFactory +from openedx.core.djangoapps.catalog.tests.factories import ( + CourseFactory, + OrganizationFactory as CatalogOrganizationFactory, + ProgramFactory, +) +from openedx.core.djangolib.testing.utils import CacheIsolationMixin +from openedx.core.djangoapps.content.course_overviews.tests.factories import CourseOverviewFactory +from openedx.core.djangoapps.catalog.cache import PROGRAM_CACHE_KEY_TPL +from .factories import ProgramEnrollmentFactory, ProgramCourseEnrollmentFactory class ProgramEnrollmentListTest(APITestCase): @@ -171,3 +182,225 @@ class ProgramEnrollmentListTest(APITestCase): assert next_response.data['previous'] is not None assert self.get_url(self.program_uuid) in next_response.data['previous'] assert '?cursor=' in next_response.data['previous'] + + +class ProgramCacheTestCaseMixin(CacheIsolationMixin): + """ + Mixin for using program cache in tests + """ + ENABLED_CACHES = ['default'] + + def setup_catalog_cache(self, program_uuid, organization_key): + """ + helper function to initialize a cached program with an single authoring_organization + """ + catalog_org = CatalogOrganizationFactory.create(key=organization_key) + program = ProgramFactory.create( + uuid=program_uuid, + authoring_organizations=[catalog_org] + ) + cache.set(PROGRAM_CACHE_KEY_TPL.format(uuid=program_uuid), program, None) + return program + + +@ddt.ddt +class CourseEnrollmentPostTests(APITestCase, ProgramCacheTestCaseMixin): + """ Tests for mock course enrollment """ + + @classmethod + def setUpClass(cls): + super(CourseEnrollmentPostTests, cls).setUpClass() + cls.start_cache_isolation() + cls.password = 'password' + cls.student = UserFactory.create(username='student', password=cls.password) + cls.global_staff = GlobalStaffFactory.create(username='global-staff', password=cls.password) + + @classmethod + def tearDownClass(cls): + cls.end_cache_isolation() + super(CourseEnrollmentPostTests, cls).tearDownClass() + + def setUp(self): + super(CourseEnrollmentPostTests, self).setUp() + self.clear_caches() + self.addCleanup(self.clear_caches) + self.program_uuid = uuid4() + self.organization_key = "orgkey" + self.program = self.setup_catalog_cache(self.program_uuid, self.organization_key) + self.course = self.program["courses"][0] + self.course_run = self.course["course_runs"][0] + self.course_key = CourseKey.from_string(self.course_run["key"]) + CourseOverviewFactory(id=self.course_key) + self.course_not_in_program = CourseFactory() + self.course_not_in_program_key = CourseKey.from_string( + self.course_not_in_program["course_runs"][0]["key"] + ) + CourseOverviewFactory(id=self.course_not_in_program_key) + self.default_url = self.get_url(self.program_uuid, self.course_key) + self.client.login(username=self.global_staff, password=self.password) + + def learner_enrollment(self, student_key, enrollment_status="active"): + """ + Convenience method to create a learner enrollment record + """ + return {"student_key": student_key, "status": enrollment_status} + + def get_url(self, program_uuid, course_id): + """ + Convenience method to build a path for a program course enrollment request + """ + return reverse( + 'programs_api:v1:program_course_enrollments', + kwargs={ + 'program_uuid': str(program_uuid), + 'course_id': str(course_id) + } + ) + + def create_program_enrollment(self, external_user_key, user=False): + """ + Creates and returns a ProgramEnrollment for the given external_user_key and + user if specified. + """ + program_enrollment = ProgramEnrollmentFactory.create( + external_user_key=external_user_key, + program_uuid=self.program_uuid, + ) + if user is not False: + program_enrollment.user = user + program_enrollment.save() + return program_enrollment + + def test_enrollments(self): + self.create_program_enrollment('l1') + self.create_program_enrollment('l2') + self.create_program_enrollment('l3', user=None) + self.create_program_enrollment('l4', user=None) + post_data = [ + self.learner_enrollment("l1", "active"), + self.learner_enrollment("l2", "inactive"), + self.learner_enrollment("l3", "active"), + self.learner_enrollment("l4", "inactive"), + ] + response = self.client.post(self.default_url, post_data, format="json") + self.assertEqual(200, response.status_code) + self.assertDictEqual( + { + "l1": "active", + "l2": "inactive", + "l3": "active", + "l4": "inactive", + }, + response.data + ) + self.assert_program_course_enrollment("l1", "active", True) + self.assert_program_course_enrollment("l2", "inactive", True) + self.assert_program_course_enrollment("l3", "active", False) + self.assert_program_course_enrollment("l4", "inactive", False) + + def assert_program_course_enrollment(self, external_user_key, expected_status, has_user): + """ + Convenience method to assert that a ProgramCourseEnrollment has been created, + and potentially that a CourseEnrollment has also been created + """ + enrollment = ProgramCourseEnrollment.objects.get( + program_enrollment__external_user_key=external_user_key, + program_enrollment__program_uuid=self.program_uuid + ) + self.assertEqual(expected_status, enrollment.status) + self.assertEqual(self.course_key, enrollment.course_key) + course_enrollment = enrollment.course_enrollment + if has_user: + self.assertIsNotNone(course_enrollment) + self.assertEqual(expected_status == "active", course_enrollment.is_active) + self.assertEqual(self.course_key, course_enrollment.course_id) + else: + self.assertIsNone(course_enrollment) + + def test_duplicate(self): + post_data = [ + self.learner_enrollment("l1", "active"), + self.learner_enrollment("l1", "active"), + ] + response = self.client.post(self.default_url, post_data, format="json") + self.assertEqual(422, response.status_code) + self.assertDictEqual( + { + "l1": CourseStatuses.DUPLICATED + }, + response.data + ) + + def test_conflict(self): + program_enrollment = self.create_program_enrollment('l1') + ProgramCourseEnrollmentFactory.create( + program_enrollment=program_enrollment, + course_key=self.course_key + ) + post_data = [self.learner_enrollment("l1")] + response = self.client.post(self.default_url, post_data, format="json") + self.assertEqual(422, response.status_code) + self.assertDictEqual({'l1': CourseStatuses.CONFLICT}, response.data) + + def test_user_not_in_program(self): + self.create_program_enrollment('l1') + post_data = [ + self.learner_enrollment("l1"), + self.learner_enrollment("l2"), + ] + response = self.client.post(self.default_url, post_data, format="json") + self.assertEqual(207, response.status_code) + self.assertDictEqual( + { + "l1": "active", + "l2": "not-in-program", + }, + response.data + ) + + def test_401_not_logged_in(self): + self.client.logout() + post_data = [self.learner_enrollment("A")] + response = self.client.post(self.default_url, post_data, format="json") + self.assertEqual(401, response.status_code) + + def test_403_forbidden(self): + self.client.logout() + self.client.login(username=self.student, password=self.password) + post_data = [self.learner_enrollment("A")] + response = self.client.post(self.default_url, post_data, format="json") + self.assertEqual(403, response.status_code) + + def test_413_payload_too_large(self): + post_data = [self.learner_enrollment(str(i)) for i in range(30)] + response = self.client.post(self.default_url, post_data, format="json") + self.assertEqual(413, response.status_code) + + def test_404_not_found_program(self): + paths = [ + self.get_url(uuid4(), self.course_key), + self.get_url(self.program_uuid, CourseKey.from_string("course-v1:fake+fake+fake")), + self.get_url(self.program_uuid, self.course_not_in_program_key), + ] + post_data = [self.learner_enrollment("A")] + for path_404 in paths: + response = self.client.post(path_404, post_data, format="json") + self.assertEqual(404, response.status_code) + + def test_invalid_status(self): + post_data = [self.learner_enrollment('A', 'this-is-not-a-status')] + response = self.client.post(self.default_url, post_data, format="json") + self.assertEqual(422, response.status_code) + self.assertDictEqual({'A': CourseStatuses.INVALID_STATUS}, response.data) + + @ddt.data( + [{'status': 'active'}], + [{'student_key': '000'}], + ["this isn't even a dict!"], + [{'student_key': '000', 'status': 'active'}, "bad_data"], + "not a list", + ) + def test_422_unprocessable_entity_bad_data(self, post_data): + response = self.client.post(self.default_url, post_data, format="json") + self.assertEqual(response.status_code, 422) + self.assertIn('invalid enrollment record', response.data) diff --git a/lms/djangoapps/program_enrollments/api/v1/urls.py b/lms/djangoapps/program_enrollments/api/v1/urls.py index b055a0bc85..a4ada97693 100644 --- a/lms/djangoapps/program_enrollments/api/v1/urls.py +++ b/lms/djangoapps/program_enrollments/api/v1/urls.py @@ -1,7 +1,9 @@ """ Program Enrollments API v1 URLs. """ from django.conf.urls import url -from lms.djangoapps.program_enrollments.api.v1.views import ProgramEnrollmentsView +from lms.djangoapps.program_enrollments.api.v1.constants import PROGRAM_UUID_PATTERN +from lms.djangoapps.program_enrollments.api.v1.views import ProgramEnrollmentsView, ProgramCourseEnrollmentsView +from openedx.core.constants import COURSE_ID_PATTERN app_name = 'lms.djangoapps.program_enrollments' @@ -12,4 +14,12 @@ urlpatterns = [ ProgramEnrollmentsView.as_view(), name='program_enrollments' ), + url( + r'^programs/{program_uuid}/course/{course_id}/enrollments/'.format( + program_uuid=PROGRAM_UUID_PATTERN, + course_id=COURSE_ID_PATTERN + ), + ProgramCourseEnrollmentsView.as_view(), + name="program_course_enrollments" + ), ] diff --git a/lms/djangoapps/program_enrollments/api/v1/views.py b/lms/djangoapps/program_enrollments/api/v1/views.py index 89d40c0c7d..0036e48e6e 100644 --- a/lms/djangoapps/program_enrollments/api/v1/views.py +++ b/lms/djangoapps/program_enrollments/api/v1/views.py @@ -3,18 +3,27 @@ ProgramEnrollment Views """ from __future__ import unicode_literals - from functools import wraps +from django.http import Http404 +from opaque_keys.edx.keys import CourseKey +from rest_framework import status +from rest_framework.exceptions import ValidationError +from rest_framework.pagination import CursorPagination +from rest_framework.response import Response +from rest_framework.views import APIView + from edx_rest_framework_extensions import permissions from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication from edx_rest_framework_extensions.auth.session.authentication import SessionAuthenticationAllowInactiveUser -from rest_framework import status -from rest_framework.pagination import CursorPagination - -from lms.djangoapps.program_enrollments.api.v1.serializers import ProgramEnrollmentListSerializer -from lms.djangoapps.program_enrollments.models import ProgramEnrollment +from lms.djangoapps.program_enrollments.api.v1.constants import CourseEnrollmentResponseStatuses, MAX_ENROLLMENT_RECORDS +from lms.djangoapps.program_enrollments.api.v1.serializers import ( + ProgramEnrollmentListSerializer, + ProgramCourseEnrollmentRequestSerializer, +) +from lms.djangoapps.program_enrollments.models import ProgramEnrollment, ProgramCourseEnrollment from openedx.core.djangoapps.catalog.utils import get_programs +from openedx.core.djangoapps.content.course_overviews.models import CourseOverview from openedx.core.lib.api.authentication import OAuth2AuthenticationAllowInactiveUser from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin, PaginatedAPIView @@ -124,3 +133,187 @@ class ProgramEnrollmentsView(DeveloperErrorViewMixin, PaginatedAPIView): paginated_enrollments = self.paginate_queryset(enrollments) serializer = ProgramEnrollmentListSerializer(paginated_enrollments, many=True) return self.get_paginated_response(serializer.data) + + +class ProgramSpecificViewMixin(object): + """ + A mixin for views that operate on or within a specific program. + """ + + @property + def program(self): + """ + The program specified by the `program_uuid` URL parameter. + """ + program = get_programs(uuid=self.kwargs['program_uuid']) + if program is None: + raise Http404() + return program + + +class ProgramCourseRunSpecificViewMixin(ProgramSpecificViewMixin): + """ + A mixin for views that operate on or within a specific course run in a program + """ + + def check_course_existence_and_membership(self): + """ + Attempting to look up the course and program will trigger 404 responses if: + - The program does not exist + - The course run (course_key) does not exist + - The course run is not part of the program + """ + self.course_run # pylint: disable=pointless-statement + + @property + def course_run(self): + """ + The course run specified by the `course_id` URL parameter. + """ + try: + CourseOverview.get_from_id(self.course_key) + except CourseOverview.DoesNotExist: + raise Http404() + for course in self.program["courses"]: + for course_run in course["course_runs"]: + if self.course_key == CourseKey.from_string(course_run["key"]): + return course_run + raise Http404() + + @property + def course_key(self): + """ + The course key for the course run specified by the `course_id` URL parameter. + """ + return CourseKey.from_string(self.kwargs['course_id']) + + +class ProgramCourseEnrollmentsView(ProgramCourseRunSpecificViewMixin, APIView): + """ + A view for enrolling students in a course through a program, + modifying program course enrollments, and listing program course + enrollments + + Path: /api/v1/programs/{program_uuid}/courses/{course_id}/enrollments + + Accepts: [POST] + + ------------------------------------------------------------------------------------ + POST + ------------------------------------------------------------------------------------ + + Returns: + * 200: Returns a map of students and their enrollment status. + * 207: Not all students enrolled. Returns resulting enrollment status. + * 401: User is not authenticated + * 403: User lacks read access organization of specified program. + * 404: Program does not exist, or course does not exist in program + * 422: Invalid request, unable to enroll students. + """ + authentication_classes = ( + JwtAuthentication, + OAuth2AuthenticationAllowInactiveUser, + SessionAuthenticationAllowInactiveUser, + ) + permission_classes = (permissions.JWT_RESTRICTED_APPLICATION_OR_USER_ACCESS,) + pagination_class = ProgramEnrollmentPagination + + def post(self, request, program_uuid=None, course_id=None): + """ + Enroll a list of students in a course in a program + """ + self.check_course_existence_and_membership() + results = {} + seen_student_keys = set() + enrollments = [] + + if not isinstance(request.data, list): + return Response('invalid enrollment record', status.HTTP_422_UNPROCESSABLE_ENTITY) + if len(request.data) > MAX_ENROLLMENT_RECORDS: + return Response( + 'enrollment limit 25', status.HTTP_413_REQUEST_ENTITY_TOO_LARGE + ) + + try: + for enrollment_request in request.data: + error_status = self.check_enrollment_request(enrollment_request, seen_student_keys) + if error_status: + results[enrollment_request["student_key"]] = error_status + else: + enrollments.append(enrollment_request) + except KeyError: # student_key is not in enrollment_request + return Response('invalid enrollment record', status.HTTP_422_UNPROCESSABLE_ENTITY) + except TypeError: # enrollment_request isn't a dict + return Response('invalid enrollment record', status.HTTP_422_UNPROCESSABLE_ENTITY) + except ValidationError: # there was some other error raised by the serializer + return Response('invalid enrollment record', status.HTTP_422_UNPROCESSABLE_ENTITY) + + program_enrollments = self.get_existing_program_enrollments(program_uuid, enrollments) + for enrollment in enrollments: + student_key = enrollment["student_key"] + if student_key in results and results[student_key] == CourseEnrollmentResponseStatuses.DUPLICATED: + continue + results[student_key] = self.enroll_learner_in_course(enrollment, program_enrollments) + + good_count = sum(1 for _, v in results.items() if v not in CourseEnrollmentResponseStatuses.ERROR_STATUSES) + if not good_count: + return Response(results, status.HTTP_422_UNPROCESSABLE_ENTITY) + if good_count != len(results): + return Response(results, status.HTTP_207_MULTI_STATUS) + else: + return Response(results) + + def check_enrollment_request(self, enrollment, seen_student_keys): + """ + Checks that the given enrollment record is valid and hasn't been duplicated + """ + student_key = enrollment['student_key'] + if student_key in seen_student_keys: + return CourseEnrollmentResponseStatuses.DUPLICATED + seen_student_keys.add(student_key) + enrollment_serializer = ProgramCourseEnrollmentRequestSerializer(data=enrollment) + try: + enrollment_serializer.is_valid(raise_exception=True) + except ValidationError as e: + if enrollment_serializer.has_invalid_status(): + return CourseEnrollmentResponseStatuses.INVALID_STATUS + else: + raise e + + def get_existing_program_enrollments(self, program_uuid, enrollments): + """ + Parameters: + - enrollments: A list of enrollment requests + Returns: + - Dictionary mapping all student keys in the enrollment requests + to that user's existing program enrollment in + """ + external_user_keys = [e["student_key"] for e in enrollments] + existing_enrollments = ProgramEnrollment.objects.filter( + external_user_key__in=external_user_keys, + program_uuid=program_uuid, + ) + existing_enrollments = existing_enrollments.prefetch_related('program_course_enrollments') + return {enrollment.external_user_key: enrollment for enrollment in existing_enrollments} + + def enroll_learner_in_course(self, enrollment_request, program_enrollments): + """ + Attempts to enroll the specified user into the course as a part of the + given program enrollment with the given status + + Returns the actual status + """ + student_key = enrollment_request['student_key'] + try: + program_enrollment = program_enrollments[student_key] + except KeyError: + return CourseEnrollmentResponseStatuses.NOT_IN_PROGRAM + if program_enrollment.get_program_course_enrollment(self.course_key): + return CourseEnrollmentResponseStatuses.CONFLICT + + enrollment_status = ProgramCourseEnrollment.enroll( + program_enrollment, + self.course_key, + enrollment_request['status'] + ) + return enrollment_status diff --git a/lms/djangoapps/program_enrollments/migrations/0004_add_programcourseenrollment_relatedname.py b/lms/djangoapps/program_enrollments/migrations/0004_add_programcourseenrollment_relatedname.py new file mode 100644 index 0000000000..67fc74aa1f --- /dev/null +++ b/lms/djangoapps/program_enrollments/migrations/0004_add_programcourseenrollment_relatedname.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.20 on 2019-05-01 21:46 +from __future__ import unicode_literals + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('program_enrollments', '0003_auto_20190424_1622'), + ] + + operations = [ + migrations.AlterField( + model_name='programcourseenrollment', + name='program_enrollment', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='program_course_enrollments', to='program_enrollments.ProgramEnrollment'), + ), + ] diff --git a/lms/djangoapps/program_enrollments/models.py b/lms/djangoapps/program_enrollments/models.py index 3a3d9802d2..d6d134bbd7 100644 --- a/lms/djangoapps/program_enrollments/models.py +++ b/lms/djangoapps/program_enrollments/models.py @@ -8,6 +8,8 @@ from django.contrib.auth.models import User from django.core.exceptions import ValidationError from django.db import models from django.utils.translation import ugettext_lazy as _ +from course_modes.models import CourseMode +from lms.djangoapps.program_enrollments.api.v1.constants import CourseEnrollmentResponseStatuses from model_utils.models import TimeStampedModel from opaque_keys.edx.django.models import CourseKeyField from simple_history.models import HistoricalRecords @@ -77,6 +79,17 @@ class ProgramEnrollment(TimeStampedModel): # pylint: disable=model-missing-unic enrollments.update(external_user_key=None) return True + def get_program_course_enrollment(self, course_key): + """ + Returns the ProgramCourseEnrollment associated with this ProgramEnrollment and given course, + None if it does not exist + """ + try: + program_course_enrollment = self.program_course_enrollments.get(course_key=course_key) + except ProgramCourseEnrollment.DoesNotExist: + return None + return program_course_enrollment + def __str__(self): return '[ProgramEnrollment id={}]'.format(self.id) @@ -96,7 +109,11 @@ class ProgramCourseEnrollment(TimeStampedModel): # pylint: disable=model-missin class Meta(object): app_label = "program_enrollments" - program_enrollment = models.ForeignKey(ProgramEnrollment, on_delete=models.CASCADE) + program_enrollment = models.ForeignKey( + ProgramEnrollment, + on_delete=models.CASCADE, + related_name="program_course_enrollments" + ) course_enrollment = models.OneToOneField( StudentCourseEnrollment, null=True, @@ -108,3 +125,27 @@ class ProgramCourseEnrollment(TimeStampedModel): # pylint: disable=model-missin def __str__(self): return '[ProgramCourseEnrollment id={}]'.format(self.id) + + @classmethod + def enroll(cls, program_enrollment, course_key, status): + """ + Create ProgramCourseEnrollment for the given course and program enrollment + """ + course_enrollment = None + if program_enrollment.user: + course_enrollment = StudentCourseEnrollment.enroll( + program_enrollment.user, + course_key, + mode=CourseMode.MASTERS, + check_access=True, + ) + if status == CourseEnrollmentResponseStatuses.INACTIVE: + course_enrollment.deactivate() + + program_course_enrollment = ProgramCourseEnrollment.objects.create( + program_enrollment=program_enrollment, + course_enrollment=course_enrollment, + course_key=course_key, + status=status, + ) + return program_course_enrollment.status