diff --git a/common/djangoapps/enrollment/api.py b/common/djangoapps/enrollment/api.py index 25047219f7..2e352a7581 100644 --- a/common/djangoapps/enrollment/api.py +++ b/common/djangoapps/enrollment/api.py @@ -7,6 +7,8 @@ import importlib import logging from django.conf import settings from django.core.cache import cache +from opaque_keys.edx.keys import CourseKey + from course_modes.models import CourseMode from enrollment import errors @@ -372,9 +374,8 @@ def _default_course_mode(course_id): Returns: str """ - course_enrollment_info = _data_api().get_course_enrollment_info(course_id, include_expired=False) - course_modes = course_enrollment_info["course_modes"] - available_modes = [m['slug'] for m in course_modes] + course_modes = CourseMode.modes_for_course(CourseKey.from_string(course_id)) + available_modes = [m.slug for m in course_modes] if CourseMode.DEFAULT_MODE_SLUG in available_modes: return CourseMode.DEFAULT_MODE_SLUG diff --git a/common/djangoapps/enrollment/tests/test_api.py b/common/djangoapps/enrollment/tests/test_api.py index 3d41d24743..95b441d157 100644 --- a/common/djangoapps/enrollment/tests/test_api.py +++ b/common/djangoapps/enrollment/tests/test_api.py @@ -1,6 +1,8 @@ """ Tests for student enrollment. """ +from mock import patch, Mock + import ddt from django.core.cache import cache from nose.tools import raises @@ -67,12 +69,15 @@ class EnrollmentTest(TestCase): def test_enroll_no_mode_success(self, course_modes, expected_mode): # Add a fake course enrollment information to the fake data API fake_data_api.add_course(self.COURSE_ID, course_modes=course_modes) - # Enroll in the course and verify the URL we get sent to - result = api.add_enrollment(self.USERNAME, self.COURSE_ID) - self.assertIsNotNone(result) - self.assertEquals(result['student'], self.USERNAME) - self.assertEquals(result['course']['course_id'], self.COURSE_ID) - self.assertEquals(result['mode'], expected_mode) + with patch('enrollment.api.CourseMode.modes_for_course') as mock_modes_for_course: + mock_course_modes = [Mock(slug=mode) for mode in course_modes] + mock_modes_for_course.return_value = mock_course_modes + # Enroll in the course and verify the URL we get sent to + result = api.add_enrollment(self.USERNAME, self.COURSE_ID) + self.assertIsNotNone(result) + self.assertEquals(result['student'], self.USERNAME) + self.assertEquals(result['course']['course_id'], self.COURSE_ID) + self.assertEquals(result['mode'], expected_mode) @ddt.data( ['professional'], diff --git a/common/djangoapps/enrollment/views.py b/common/djangoapps/enrollment/views.py index 110df7e906..69aea5996c 100644 --- a/common/djangoapps/enrollment/views.py +++ b/common/djangoapps/enrollment/views.py @@ -520,7 +520,7 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): } ) - mode = request.data.get('mode', CourseMode.DEFAULT_MODE_SLUG) + mode = request.data.get('mode') has_api_key_permissions = self.has_api_key_permissions(request) @@ -532,7 +532,7 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): # other users, do not let them deduce the existence of an enrollment. return Response(status=status.HTTP_404_NOT_FOUND) - if mode != CourseMode.DEFAULT_MODE_SLUG and not has_api_key_permissions: + if mode not in (CourseMode.AUDIT, CourseMode.HONOR, None) and not has_api_key_permissions: return Response( status=status.HTTP_403_FORBIDDEN, data={ diff --git a/common/djangoapps/student/models.py b/common/djangoapps/student/models.py index de736b7f9f..d3976288a4 100644 --- a/common/djangoapps/student/models.py +++ b/common/djangoapps/student/models.py @@ -47,6 +47,7 @@ from xmodule_django.models import CourseKeyField, NoneToEmptyManager from certificates.models import GeneratedCertificate from course_modes.models import CourseMode +from enrollment.api import _default_course_mode import lms.lib.comment_client as cc from openedx.core.djangoapps.commerce.utils import ecommerce_api_client, ECOMMERCE_DATE_FORMAT from openedx.core.djangoapps.content.course_overviews.models import CourseOverview @@ -1090,7 +1091,7 @@ class CourseEnrollment(models.Model): ) @classmethod - def enroll(cls, user, course_key, mode=CourseMode.DEFAULT_MODE_SLUG, check_access=False): + def enroll(cls, user, course_key, mode=None, check_access=False): """ Enroll a user in a course. This saves immediately. @@ -1124,6 +1125,8 @@ class CourseEnrollment(models.Model): Also emits relevant events for analytics purposes. """ + if mode is None: + mode = _default_course_mode(unicode(course_key)) # All the server-side checks for whether a user is allowed to enroll. try: course = CourseOverview.get_from_id(course_key) @@ -1165,7 +1168,7 @@ class CourseEnrollment(models.Model): return enrollment @classmethod - def enroll_by_email(cls, email, course_id, mode=CourseMode.DEFAULT_MODE_SLUG, ignore_errors=True): + def enroll_by_email(cls, email, course_id, mode=None, ignore_errors=True): """ Enroll a user in a course given their email. This saves immediately. diff --git a/lms/djangoapps/instructor/enrollment.py b/lms/djangoapps/instructor/enrollment.py index e134964fa2..31f0f3ff25 100644 --- a/lms/djangoapps/instructor/enrollment.py +++ b/lms/djangoapps/instructor/enrollment.py @@ -119,7 +119,7 @@ def enroll_email(course_id, student_email, auto_enroll=False, email_students=Fal if CourseMode.is_white_label(course_id): course_mode = CourseMode.DEFAULT_SHOPPINGCART_MODE_SLUG else: - course_mode = CourseMode.DEFAULT_MODE_SLUG + course_mode = None if previous_state.enrollment: course_mode = previous_state.mode