diff --git a/common/djangoapps/enrollment/tests/test_views.py b/common/djangoapps/enrollment/tests/test_views.py index a3e961a69f..bc8272569d 100644 --- a/common/djangoapps/enrollment/tests/test_views.py +++ b/common/djangoapps/enrollment/tests/test_views.py @@ -1,10 +1,11 @@ """ Tests for user enrollment. """ -import ddt import json import unittest +import ddt +from django.core.cache import cache from mock import patch from django.test import Client from django.core.handlers.wsgi import WSGIRequest @@ -14,30 +15,81 @@ from rest_framework import status from django.conf import settings from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.factories import CourseFactory +from django.test.utils import override_settings + from course_modes.models import CourseMode +from embargo.models import CountryAccessRule, Country, RestrictedCourse from enrollment.views import EnrollmentUserThrottle from util.models import RateLimitConfiguration from util.testing import UrlResetMixin from enrollment import api from enrollment.errors import CourseEnrollmentError from openedx.core.djangoapps.user_api.models import UserOrgTag -from django.test.utils import override_settings from student.tests.factories import UserFactory, CourseModeFactory from student.models import CourseEnrollment from embargo.test_utils import restrict_course +class EnrollmentTestMixin(object): + """ Mixin with methods useful for testing enrollments. """ + API_KEY = "i am a key" + + def assert_enrollment_status( + self, + course_id=None, + username=None, + expected_status=status.HTTP_200_OK, + email_opt_in=None, + as_server=False, + mode=CourseMode.HONOR, + ): + """ + Enroll in the course and verify the response's status code. If the expected status is 200, also validates + the response content. + + Returns + Response + """ + course_id = course_id or unicode(self.course.id) + username = username or self.user.username + + data = { + 'mode': mode, + 'course_details': { + 'course_id': course_id + }, + 'user': username + } + if email_opt_in is not None: + data['email_opt_in'] = email_opt_in + + extra = {} + if as_server: + extra['HTTP_X_EDX_API_KEY'] = self.API_KEY + + url = reverse('courseenrollments') + response = self.client.post(url, json.dumps(data), content_type='application/json', **extra) + self.assertEqual(response.status_code, expected_status) + + if expected_status in [status.HTTP_200_OK, status.HTTP_200_OK]: + data = json.loads(response.content) + self.assertEqual(course_id, data['course_details']['course_id']) + self.assertEqual(mode, data['mode']) + self.assertTrue(data['is_active']) + + return response + + @override_settings(EDX_API_KEY="i am a key") @ddt.ddt @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') -class EnrollmentTest(ModuleStoreTestCase, APITestCase): +class EnrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase, APITestCase): """ Test user enrollment, especially with different course modes. """ USERNAME = "Bob" EMAIL = "bob@example.com" PASSWORD = "edx" - API_KEY = "i am a key" def setUp(self): """ Create a course and user, then log in. """ @@ -76,7 +128,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ) # Create an enrollment - self._create_enrollment() + self.assert_enrollment_status() self.assertTrue(CourseEnrollment.is_enrolled(self.user, self.course.id)) course_mode, is_active = CourseEnrollment.enrollment_mode_for_user(self.user, self.course.id) @@ -90,9 +142,9 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): mode_display_name=CourseMode.HONOR, ) # Create an enrollment - self._create_enrollment() + self.assert_enrollment_status() resp = self.client.get( - reverse('courseenrollment', kwargs={"user": self.user.username, "course_id": unicode(self.course.id)}) + reverse('courseenrollment', kwargs={'username': self.user.username, "course_id": unicode(self.course.id)}) ) self.assertEqual(resp.status_code, status.HTTP_200_OK) data = json.loads(resp.content) @@ -111,13 +163,14 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): Verify that the email_opt_in parameter sets the underlying flag. And that if the argument is not present, then it does not affect the flag """ + def _assert_no_opt_in_set(): """ Check the tag doesn't exit""" with self.assertRaises(UserOrgTag.DoesNotExist): UserOrgTag.objects.get(user=self.user, org=self.course.id.org, key="email-optin") _assert_no_opt_in_set() - self._create_enrollment(email_opt_in=opt_in) + self.assert_enrollment_status(email_opt_in=opt_in) if opt_in is None: _assert_no_opt_in_set() else: @@ -133,7 +186,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ) # Enroll in the course, this will fail if the mode is not explicitly professional. - resp = self._create_enrollment(expected_status=status.HTTP_400_BAD_REQUEST) + resp = self.assert_enrollment_status(expected_status=status.HTTP_400_BAD_REQUEST) # While the enrollment wrong is invalid, the response content should have # all the valid enrollment modes. @@ -149,7 +202,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): mode_display_name=CourseMode.HONOR, ) # Create an enrollment - self._create_enrollment() + self.assert_enrollment_status() resp = self.client.get( reverse('courseenrollment', kwargs={"course_id": unicode(self.course.id)}) ) @@ -164,7 +217,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): self.client.logout() # Try to enroll, this should fail. - self._create_enrollment(expected_status=status.HTTP_401_UNAUTHORIZED) + self.assert_enrollment_status(expected_status=status.HTTP_401_UNAUTHORIZED) def test_user_not_activated(self): # Log out the default user, Bob. @@ -187,7 +240,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): self.user.save() # Enrollment should succeed, even though we haven't authenticated. - self._create_enrollment() + self.assert_enrollment_status() def test_user_does_not_match_url(self): # Try to enroll a user that is not the authenticated user. @@ -196,10 +249,10 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): mode_slug=CourseMode.HONOR, mode_display_name=CourseMode.HONOR, ) - self._create_enrollment(username=self.other_user.username, expected_status=status.HTTP_404_NOT_FOUND) + self.assert_enrollment_status(username=self.other_user.username, expected_status=status.HTTP_404_NOT_FOUND) # Verify that the server still has access to this endpoint. self.client.logout() - self._create_enrollment(username=self.other_user.username, as_server=True) + self.assert_enrollment_status(username=self.other_user.username, as_server=True) def test_user_does_not_match_param_for_list(self): CourseModeFactory.create( @@ -207,12 +260,12 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): mode_slug=CourseMode.HONOR, mode_display_name=CourseMode.HONOR, ) - resp = self.client.get(reverse('courseenrollments'), {"user": self.other_user.username}) + resp = self.client.get(reverse('courseenrollments'), {'user': self.other_user.username}) self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND) # Verify that the server still has access to this endpoint. self.client.logout() resp = self.client.get( - reverse('courseenrollments'), {"user": self.other_user.username}, **{'HTTP_X_EDX_API_KEY': self.API_KEY} + reverse('courseenrollments'), {'username': self.other_user.username}, **{'HTTP_X_EDX_API_KEY': self.API_KEY} ) self.assertEqual(resp.status_code, status.HTTP_200_OK) @@ -222,16 +275,15 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): mode_slug=CourseMode.HONOR, mode_display_name=CourseMode.HONOR, ) - resp = self.client.get( - reverse('courseenrollment', kwargs={"user": self.other_user.username, "course_id": unicode(self.course.id)}) - ) + url = reverse('courseenrollment', + kwargs={'username': self.other_user.username, "course_id": unicode(self.course.id)}) + + resp = self.client.get(url) + # Verify that the server still has access to this endpoint. self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND) self.client.logout() - resp = self.client.get( - reverse('courseenrollment', kwargs={"user": self.other_user.username, "course_id": unicode(self.course.id)}), - **{'HTTP_X_EDX_API_KEY': self.API_KEY} - ) + resp = self.client.get(url, **{'HTTP_X_EDX_API_KEY': self.API_KEY}) self.assertEqual(resp.status_code, status.HTTP_200_OK) def test_get_course_details(self): @@ -254,7 +306,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): self.assertEqual(mode['name'], CourseMode.HONOR) def test_with_invalid_course_id(self): - self._create_enrollment(course_id='entirely/fake/course', expected_status=status.HTTP_400_BAD_REQUEST) + self.assert_enrollment_status(course_id='entirely/fake/course', expected_status=status.HTTP_400_BAD_REQUEST) def test_get_enrollment_details_bad_course(self): resp = self.client.get( @@ -266,13 +318,13 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): def test_get_enrollment_internal_error(self, mock_get_enrollment): mock_get_enrollment.side_effect = CourseEnrollmentError("Something bad happened.") resp = self.client.get( - reverse('courseenrollment', kwargs={"user": self.user.username, "course_id": unicode(self.course.id)}) + reverse('courseenrollment', kwargs={'username': self.user.username, "course_id": unicode(self.course.id)}) ) self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) def test_enrollment_already_enrolled(self): - response = self._create_enrollment() - repeat_response = self._create_enrollment(expected_status=status.HTTP_200_OK) + response = self.assert_enrollment_status() + repeat_response = self.assert_enrollment_status(expected_status=status.HTTP_200_OK) self.assertEqual(json.loads(response.content), json.loads(repeat_response.content)) def test_get_enrollment_with_invalid_key(self): @@ -301,7 +353,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): for attempt in xrange(self.rate_limit + 10): expected_status = status.HTTP_429_TOO_MANY_REQUESTS if attempt >= self.rate_limit else status.HTTP_200_OK - self._create_enrollment(expected_status=expected_status) + self.assert_enrollment_status(expected_status=expected_status) def test_enrollment_throttle_for_service(self): """Make sure a service can call the enrollment API as many times as needed. """ @@ -314,7 +366,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ) for attempt in xrange(self.rate_limit + 10): - self._create_enrollment(as_server=True) + self.assert_enrollment_status(as_server=True) def test_create_enrollment_with_mode(self): """With the right API key, create a new enrollment with a mode set other than the default.""" @@ -326,7 +378,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ) # Create an enrollment - self._create_enrollment(as_server=True, mode='professional') + self.assert_enrollment_status(as_server=True, mode='professional') self.assertTrue(CourseEnrollment.is_enrolled(self.user, self.course.id)) course_mode, is_active = CourseEnrollment.enrollment_mode_for_user(self.user, self.course.id) @@ -344,7 +396,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ) # Create an enrollment - self._create_enrollment(as_server=True) + self.assert_enrollment_status(as_server=True) # Check that the enrollment is honor. self.assertTrue(CourseEnrollment.is_enrolled(self.user, self.course.id)) @@ -353,7 +405,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): self.assertEqual(course_mode, CourseMode.HONOR) # Check that the enrollment upgraded to verified. - self._create_enrollment(as_server=True, mode=CourseMode.VERIFIED, expected_status=status.HTTP_200_OK) + self.assert_enrollment_status(as_server=True, mode=CourseMode.VERIFIED, expected_status=status.HTTP_200_OK) course_mode, is_active = CourseEnrollment.enrollment_mode_for_user(self.user, self.course.id) self.assertTrue(is_active) self.assertEqual(course_mode, CourseMode.VERIFIED) @@ -369,7 +421,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ) # Create a 'verified' enrollment - self._create_enrollment(as_server=True, mode=CourseMode.VERIFIED) + self.assert_enrollment_status(as_server=True, mode=CourseMode.VERIFIED) # Check that the enrollment is verified. self.assertTrue(CourseEnrollment.is_enrolled(self.user, self.course.id)) @@ -378,7 +430,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): self.assertEqual(course_mode, CourseMode.VERIFIED) # Check that the enrollment downgraded to honor. - self._create_enrollment(as_server=True, mode=CourseMode.HONOR, expected_status=status.HTTP_200_OK) + self.assert_enrollment_status(as_server=True, mode=CourseMode.HONOR, expected_status=status.HTTP_200_OK) course_mode, is_active = CourseEnrollment.enrollment_mode_for_user(self.user, self.course.id) self.assertTrue(is_active) self.assertEqual(course_mode, CourseMode.HONOR) @@ -394,7 +446,7 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ) # Create an enrollment - self._create_enrollment() + self.assert_enrollment_status() # Check that the enrollment is honor. self.assertTrue(CourseEnrollment.is_enrolled(self.user, self.course.id)) @@ -403,50 +455,23 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): self.assertEqual(course_mode, CourseMode.HONOR) # Get a 403 response when trying to upgrade yourself. - self._create_enrollment(mode=CourseMode.VERIFIED, expected_status=status.HTTP_403_FORBIDDEN) + self.assert_enrollment_status(mode=CourseMode.VERIFIED, expected_status=status.HTTP_403_FORBIDDEN) course_mode, is_active = CourseEnrollment.enrollment_mode_for_user(self.user, self.course.id) self.assertTrue(is_active) self.assertEqual(course_mode, CourseMode.HONOR) - def _create_enrollment( - self, - course_id=None, - username=None, - expected_status=status.HTTP_200_OK, - email_opt_in=None, - as_server=False, - mode=CourseMode.HONOR, - ): - """Enroll in the course and verify the URL we are sent to. """ - course_id = unicode(self.course.id) if course_id is None else course_id - username = self.user.username if username is None else username - - params = { - 'mode': mode, - 'course_details': { - 'course_id': course_id - }, - 'user': username - } - if email_opt_in is not None: - params['email_opt_in'] = email_opt_in - if as_server: - resp = self.client.post(reverse('courseenrollments'), params, format='json', **{'HTTP_X_EDX_API_KEY': self.API_KEY}) - else: - resp = self.client.post(reverse('courseenrollments'), params, format='json') - - self.assertEqual(resp.status_code, expected_status) - - if expected_status in [status.HTTP_200_OK, status.HTTP_200_OK]: - data = json.loads(resp.content) - self.assertEqual(course_id, data['course_details']['course_id']) - self.assertEqual(mode, data['mode']) - self.assertTrue(data['is_active']) - return resp + def test_change_mode_invalid_user(self): + """ + Attempts to change an enrollment for a non-existent user should result in an HTTP 404 for non-server users, + and HTTP 406 for server users. + """ + self.assert_enrollment_status(username='fake-user', expected_status=status.HTTP_404_NOT_FOUND, as_server=False) + self.assert_enrollment_status(username='fake-user', expected_status=status.HTTP_406_NOT_ACCEPTABLE, + as_server=True) @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') -class EnrollmentEmbargoTest(UrlResetMixin, ModuleStoreTestCase): +class EnrollmentEmbargoTest(EnrollmentTestMixin, UrlResetMixin, ModuleStoreTestCase): """Test that enrollment is blocked from embargoed countries. """ USERNAME = "Bob" @@ -460,51 +485,98 @@ class EnrollmentEmbargoTest(UrlResetMixin, ModuleStoreTestCase): self.course = CourseFactory.create() self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD) self.client.login(username=self.USERNAME, password=self.PASSWORD) + self.url = reverse('courseenrollments') - @patch.dict(settings.FEATURES, {'EMBARGO': True}) - def test_embargo_change_enrollment_restrict(self): - url = reverse('courseenrollments') - data = json.dumps({ + def _generate_data(self): + return json.dumps({ 'course_details': { 'course_id': unicode(self.course.id) }, 'user': self.user.username }) - # Attempt to enroll from a country embargoed for this course - with restrict_course(self.course.id) as redirect_url: - response = self.client.post(url, data, content_type='application/json') + def assert_access_denied(self, user_message_url): + """ + Verify that the view returns HTTP status 403 and includes a URL in the response, and no enrollment is created. + """ + data = self._generate_data() + response = self.client.post(self.url, data, content_type='application/json') - # Expect an error response - self.assertEqual(response.status_code, 403) + # Expect an error response + self.assertEqual(response.status_code, 403) - # Expect that the redirect URL is included in the response - resp_data = json.loads(response.content) - self.assertEqual(resp_data['user_message_url'], redirect_url) + # Expect that the redirect URL is included in the response + resp_data = json.loads(response.content) + self.assertEqual(resp_data['user_message_url'], user_message_url) # Verify that we were not enrolled self.assertEqual(self._get_enrollments(), []) @patch.dict(settings.FEATURES, {'EMBARGO': True}) - def test_embargo_change_enrollment_allow(self): - url = reverse('courseenrollments') - data = json.dumps({ - 'course_details': { - 'course_id': unicode(self.course.id) - }, - 'user': self.user.username - }) + def test_embargo_change_enrollment_restrict_geoip(self): + """ Validates that enrollment changes are blocked if the request originates from an embargoed country. """ - response = self.client.post(url, data, content_type='application/json') - self.assertEqual(response.status_code, status.HTTP_200_OK) + # Use the helper to setup the embargo and simulate a request from a blocked IP address. + with restrict_course(self.course.id) as redirect_url: + self.assert_access_denied(redirect_url) + + def _setup_embargo(self): + restricted_course = RestrictedCourse.objects.create(course_key=self.course.id) + + restricted_country = Country.objects.create(country='US') + unrestricted_country = Country.objects.create(country='CA') + + CountryAccessRule.objects.create( + rule_type=CountryAccessRule.BLACKLIST_RULE, + restricted_course=restricted_course, + country=restricted_country + ) + + # Clear the cache to remove the effects of previous embargo tests + cache.clear() + + return unrestricted_country, restricted_country + + @override_settings(EDX_API_KEY=EnrollmentTestMixin.API_KEY) + @patch.dict(settings.FEATURES, {'EMBARGO': True}) + def test_embargo_change_enrollment_restrict_user_profile(self): + """ Validates that enrollment changes are blocked if the user's profile is linked to an embargoed country. """ + + __, restricted_country = self._setup_embargo() + + # Update the user's profile, linking the user to the embargoed country. + self.user.profile.country = restricted_country.country + self.user.profile.save() + + user_message_url = reverse('embargo_blocked_message', + kwargs={'access_point': 'enrollment', 'message_key': 'default'}) + self.assert_access_denied(user_message_url) + + @override_settings(EDX_API_KEY=EnrollmentTestMixin.API_KEY) + @patch.dict(settings.FEATURES, {'EMBARGO': True}) + def test_embargo_change_enrollment_allow_user_profile(self): + """ + Validates that enrollment changes are allowed if the user's profile is NOT linked to an embargoed country. + """ + + # Setup the embargo + unrestricted_country, __ = self._setup_embargo() + + # Verify that users without black-listed country codes *can* be enrolled + self.user.profile.country = unrestricted_country.country + self.user.profile.save() + self.assert_enrollment_status() + + @patch.dict(settings.FEATURES, {'EMBARGO': True}) + def test_embargo_change_enrollment_allow(self): + self.assert_enrollment_status() # Verify that we were enrolled self.assertEqual(len(self._get_enrollments()), 1) def _get_enrollments(self): """Retrieve the enrollment list for the current user. """ - url = reverse('courseenrollments') - resp = self.client.get(url) + resp = self.client.get(self.url) return json.loads(resp.content) diff --git a/common/djangoapps/enrollment/urls.py b/common/djangoapps/enrollment/urls.py index 4609b07a68..438ad259df 100644 --- a/common/djangoapps/enrollment/urls.py +++ b/common/djangoapps/enrollment/urls.py @@ -11,12 +11,13 @@ from .views import ( EnrollmentCourseDetailView ) -USER_PATTERN = '(?P[\w.@+-]+)' +USERNAME_PATTERN = '(?P[\w.@+-]+)' urlpatterns = patterns( 'enrollment.views', url( - r'^enrollment/{user},{course_key}$'.format(user=USER_PATTERN, course_key=settings.COURSE_ID_PATTERN), + r'^enrollment/{username},{course_key}$'.format(username=USERNAME_PATTERN, + course_key=settings.COURSE_ID_PATTERN), EnrollmentView.as_view(), name='courseenrollment' ), diff --git a/common/djangoapps/enrollment/views.py b/common/djangoapps/enrollment/views.py index bdd867dd4d..4dd6d6422e 100644 --- a/common/djangoapps/enrollment/views.py +++ b/common/djangoapps/enrollment/views.py @@ -4,6 +4,7 @@ consist primarily of authentication, request validation, and serialization. """ from ipware.ip import get_ip +from django.core.exceptions import ObjectDoesNotExist from django.utils.decorators import method_decorator from opaque_keys import InvalidKeyError from course_modes.models import CourseMode @@ -24,6 +25,7 @@ from enrollment.errors import ( CourseNotFoundError, CourseEnrollmentError, CourseModeNotFoundError, CourseEnrollmentExistsError ) +from student.models import User class EnrollmentCrossDomainSessionAuth(SessionAuthenticationAllowInactiveUser, SessionAuthenticationCrossDomainCsrf): @@ -104,11 +106,10 @@ class EnrollmentView(APIView, ApiKeyPermissionMixIn): permission_classes = ApiKeyHeaderPermissionIsAuthenticated, throttle_classes = EnrollmentUserThrottle, - # Since the course about page on the marketing site - # uses this API to auto-enroll users, we need to support - # cross-domain CSRF. + # Since the course about page on the marketing site uses this API to auto-enroll users, + # we need to support cross-domain CSRF. @method_decorator(ensure_csrf_cookie_cross_domain) - def get(self, request, course_id=None, user=None): + def get(self, request, course_id=None, username=None): """Create, read, or update enrollment information for a user. HTTP Endpoint for all CRUD operations for a user course enrollment. Allows creation, reading, and @@ -119,27 +120,29 @@ class EnrollmentView(APIView, ApiKeyPermissionMixIn): information for the current user and the specified course. course_id (str): URI element specifying the course location. Enrollment information will be returned, created, or updated for this particular course. - user (str): The user username associated with this enrollment request. + username (str): The username associated with this enrollment request. Return: A JSON serialized representation of the course enrollment. """ - user = user if user else request.user.username - if request.user.username != user and not self.has_api_key_permissions(request): + username = username or request.user.username + + if request.user.username != username and not self.has_api_key_permissions(request): # Return a 404 instead of a 403 (Unauthorized). If one user is looking up # other users, do not let them deduce the existence of an enrollment. return Response(status=status.HTTP_404_NOT_FOUND) + try: - return Response(api.get_enrollment(user, course_id)) + return Response(api.get_enrollment(username, course_id)) except CourseEnrollmentError: return Response( status=status.HTTP_400_BAD_REQUEST, data={ "message": ( u"An error occurred while retrieving enrollments for user " - u"'{user}' in course '{course_id}'" - ).format(user=user, course_id=course_id) + u"'{username}' in course '{course_id}'" + ).format(username=username, course_id=course_id) } ) @@ -295,20 +298,20 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): """ Gets a list of all course enrollments for the currently logged in user. """ - user = request.GET.get('user', request.user.username) - if request.user.username != user and not self.has_api_key_permissions(request): + username = request.GET.get('user', request.user.username) + if request.user.username != username and not self.has_api_key_permissions(request): # Return a 404 instead of a 403 (Unauthorized). If one user is looking up # other users, do not let them deduce the existence of an enrollment. return Response(status=status.HTTP_404_NOT_FOUND) try: - return Response(api.get_enrollments(user)) + return Response(api.get_enrollments(username)) except CourseEnrollmentError: return Response( status=status.HTTP_400_BAD_REQUEST, data={ "message": ( - u"An error occurred while retrieving enrollments for user '{user}'" - ).format(user=user) + u"An error occurred while retrieving enrollments for user '{username}'" + ).format(username=username) } ) @@ -317,14 +320,15 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): Enrolls the currently logged in user in a course. """ # Get the User, Course ID, and Mode from the request. - user = request.DATA.get('user', request.user.username) + username = request.DATA.get('user', request.user.username) + course_id = request.DATA.get('course_details', {}).get('course_id') - if 'course_details' not in request.DATA or 'course_id' not in request.DATA['course_details']: + if not course_id: return Response( status=status.HTTP_400_BAD_REQUEST, data={"message": u"Course ID must be specified to create a new enrollment."} ) - course_id = request.DATA['course_details']['course_id'] + try: course_id = CourseKey.from_string(course_id) except InvalidKeyError: @@ -340,9 +344,9 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): has_api_key_permissions = self.has_api_key_permissions(request) # Check that the user specified is either the same user, or this is a server-to-server request. - if not user: - user = request.user.username - if user != request.user.username and not has_api_key_permissions: + if not username: + username = request.user.username + if username != request.user.username and not has_api_key_permissions: # Return a 404 instead of a 403 (Unauthorized). If one user is looking up # other users, do not let them deduce the existence of an enrollment. return Response(status=status.HTTP_404_NOT_FOUND) @@ -357,14 +361,22 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): } ) + try: + # Lookup the user, instead of using request.user, since request.user may not match the username POSTed. + user = User.objects.get(username=username) + except ObjectDoesNotExist: + return Response( + status=status.HTTP_406_NOT_ACCEPTABLE, + data={ + 'message': u'The user {} does not exist.'.format(username) + } + ) + # Check whether any country access rules block the user from enrollment # We do this at the view level (rather than the Python API level) # because this check requires information about the HTTP request. redirect_url = embargo_api.redirect_if_blocked( - course_id, user=user, - ip_address=get_ip(request), - url=request.path - ) + course_id, user=user, ip_address=get_ip(request), url=request.path) if redirect_url: return Response( status=status.HTTP_403_FORBIDDEN, @@ -384,11 +396,11 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): # Only server-to-server calls will currently be allowed to modify the mode for existing enrollments. All # other requests will go through add_enrollment(), which will allow creating of new enrollments, and # re-activating enrollments - enrollment = api.get_enrollment(user, unicode(course_id)) + enrollment = api.get_enrollment(username, unicode(course_id)) if has_api_key_permissions and enrollment and enrollment['mode'] != mode: - response = api.update_enrollment(user, unicode(course_id), mode=mode) + response = api.update_enrollment(username, unicode(course_id), mode=mode) else: - response = api.add_enrollment(user, unicode(course_id), mode=mode) + response = api.add_enrollment(username, unicode(course_id), mode=mode) email_opt_in = request.DATA.get('email_opt_in', None) if email_opt_in is not None: org = course_id.org @@ -418,7 +430,7 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): data={ "message": ( u"An error occurred while creating the new course enrollment for user " - u"'{user}' in course '{course_id}'" - ).format(user=user, course_id=course_id) + u"'{username}' in course '{course_id}'" + ).format(username=username, course_id=course_id) } )