diff --git a/common/djangoapps/cors_csrf/authentication.py b/common/djangoapps/cors_csrf/authentication.py new file mode 100644 index 0000000000..723ec8eed1 --- /dev/null +++ b/common/djangoapps/cors_csrf/authentication.py @@ -0,0 +1,29 @@ +"""Django Rest Framework Authentication classes for cross-domain end-points.""" +from rest_framework import authentication +from cors_csrf.helpers import is_cross_domain_request_allowed, skip_cross_domain_referer_check + + +class SessionAuthenticationCrossDomainCsrf(authentication.SessionAuthentication): + """Session authentication that skips the referer check over secure connections. + + Django Rest Framework's `SessionAuthentication` class calls Django's + CSRF middleware implementation directly, which bypasses the middleware + stack. + + This version of `SessionAuthentication` performs the same workaround + as `CorsCSRFMiddleware` to skip the referer check for whitelisted + domains over a secure connection. See `cors_csrf.middleware` for + more information. + + Since this subclass overrides only the `enforce_csrf()` method, + it can be mixed in with other `SessionAuthentication` subclasses. + + """ + + def enforce_csrf(self, request): + """Skip the referer check if the cross-domain request is allowed. """ + if is_cross_domain_request_allowed(request): + with skip_cross_domain_referer_check(request): + return super(SessionAuthenticationCrossDomainCsrf, self).enforce_csrf(request) + else: + return super(SessionAuthenticationCrossDomainCsrf, self).enforce_csrf(request) diff --git a/common/djangoapps/cors_csrf/helpers.py b/common/djangoapps/cors_csrf/helpers.py new file mode 100644 index 0000000000..b04db29b25 --- /dev/null +++ b/common/djangoapps/cors_csrf/helpers.py @@ -0,0 +1,92 @@ +"""Helper methods for CORS and CSRF checks. """ +import logging +import urlparse +import contextlib + +from django.conf import settings + +log = logging.getLogger(__name__) + + +def is_cross_domain_request_allowed(request): + """Check whether we should allow the cross-domain request. + + We allow a cross-domain request only if: + + 1) The request is made securely and the referer has "https://" as the protocol. + 2) The referer domain has been whitelisted. + + Arguments: + request (HttpRequest) + + Returns: + bool + + """ + referer = request.META.get('HTTP_REFERER') + referer_parts = urlparse.urlparse(referer) if referer else None + referer_hostname = referer_parts.hostname if referer_parts is not None else None + + # Use CORS_ALLOW_INSECURE *only* for development and testing environments; + # it should never be enabled in production. + if not getattr(settings, 'CORS_ALLOW_INSECURE', False): + if not request.is_secure(): + log.debug( + u"Request is not secure, so we cannot send the CSRF token. " + u"For testing purposes, you can disable this check by setting " + u"`CORS_ALLOW_INSECURE` to True in the settings" + ) + return False + + if not referer: + log.debug(u"No referer provided over a secure connection, so we cannot check the protocol.") + return False + + if not referer_parts.scheme == 'https': + log.debug(u"Referer '%s' must have the scheme 'https'") + return False + + domain_is_whitelisted = ( + getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False) or + referer_hostname in getattr(settings, 'CORS_ORIGIN_WHITELIST', []) + ) + if not domain_is_whitelisted: + if referer_hostname is None: + # If no referer is specified, we can't check if it's a cross-domain + # request or not. + log.debug(u"Referrer hostname is `None`, so it is not on the whitelist.") + elif referer_hostname != request.get_host(): + log.warning( + ( + u"Domain '%s' is not on the cross domain whitelist. " + u"Add the domain to `CORS_ORIGIN_WHITELIST` or set " + u"`CORS_ORIGIN_ALLOW_ALL` to True in the settings." + ), referer_hostname + ) + else: + log.debug( + ( + u"Domain '%s' is the same as the hostname in the request, " + u"so we are not going to treat it as a cross-domain request." + ), referer_hostname + ) + return False + + return True + + +@contextlib.contextmanager +def skip_cross_domain_referer_check(request): + """Skip the cross-domain CSRF referer check. + + Django's CSRF middleware performs the referer check + only when the request is made over a secure connection. + To skip the check, we patch `request.is_secure()` to + False. + """ + is_secure_default = request.is_secure + request.is_secure = lambda: False + try: + yield + finally: + request.is_secure = is_secure_default diff --git a/common/djangoapps/cors_csrf/middleware.py b/common/djangoapps/cors_csrf/middleware.py index 920acaeea9..dfd554d9dc 100644 --- a/common/djangoapps/cors_csrf/middleware.py +++ b/common/djangoapps/cors_csrf/middleware.py @@ -43,82 +43,16 @@ CSRF cookie. """ import logging -import urlparse from django.conf import settings from django.middleware.csrf import CsrfViewMiddleware from django.core.exceptions import MiddlewareNotUsed, ImproperlyConfigured +from cors_csrf.helpers import is_cross_domain_request_allowed, skip_cross_domain_referer_check + log = logging.getLogger(__name__) -def is_cross_domain_request_allowed(request): - """Check whether we should allow the cross-domain request. - - We allow a cross-domain request only if: - - 1) The request is made securely and the referer has "https://" as the protocol. - 2) The referer domain has been whitelisted. - - Arguments: - request (HttpRequest) - - Returns: - bool - - """ - referer = request.META.get('HTTP_REFERER') - referer_parts = urlparse.urlparse(referer) if referer else None - referer_hostname = referer_parts.hostname if referer_parts is not None else None - - # Use CORS_ALLOW_INSECURE *only* for development and testing environments; - # it should never be enabled in production. - if not getattr(settings, 'CORS_ALLOW_INSECURE', False): - if not request.is_secure(): - log.debug( - u"Request is not secure, so we cannot send the CSRF token. " - u"For testing purposes, you can disable this check by setting " - u"`CORS_ALLOW_INSECURE` to True in the settings" - ) - return False - - if not referer: - log.debug(u"No referer provided over a secure connection, so we cannot check the protocol.") - return False - - if not referer_parts.scheme == 'https': - log.debug(u"Referer '%s' must have the scheme 'https'") - return False - - domain_is_whitelisted = ( - getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False) or - referer_hostname in getattr(settings, 'CORS_ORIGIN_WHITELIST', []) - ) - if not domain_is_whitelisted: - if referer_hostname is None: - # If no referer is specified, we can't check if it's a cross-domain - # request or not. - log.debug(u"Referrer hostname is `None`, so it is not on the whitelist.") - elif referer_hostname != request.get_host(): - log.warning( - ( - u"Domain '%s' is not on the cross domain whitelist. " - u"Add the domain to `CORS_ORIGIN_WHITELIST` or set " - u"`CORS_ORIGIN_ALLOW_ALL` to True in the settings." - ), referer_hostname - ) - else: - log.debug( - ( - u"Domain '%s' is the same as the hostname in the request, " - u"so we are not going to treat it as a cross-domain request." - ), referer_hostname - ) - return False - - return True - - class CorsCSRFMiddleware(CsrfViewMiddleware): """ Middleware for handling CSRF checks with CORS requests @@ -134,18 +68,8 @@ class CorsCSRFMiddleware(CsrfViewMiddleware): log.debug("Could not disable CSRF middleware referer check for cross-domain request.") return - is_secure_default = request.is_secure - - def is_secure_patched(): - """ - Avoid triggering the additional CSRF middleware checks on the referrer - """ - return False - request.is_secure = is_secure_patched - - res = super(CorsCSRFMiddleware, self).process_view(request, callback, callback_args, callback_kwargs) - request.is_secure = is_secure_default - return res + with skip_cross_domain_referer_check(request): + return super(CorsCSRFMiddleware, self).process_view(request, callback, callback_args, callback_kwargs) class CsrfCrossDomainCookieMiddleware(object): diff --git a/common/djangoapps/cors_csrf/tests/test_authentication.py b/common/djangoapps/cors_csrf/tests/test_authentication.py new file mode 100644 index 0000000000..7f2d78fd8a --- /dev/null +++ b/common/djangoapps/cors_csrf/tests/test_authentication.py @@ -0,0 +1,56 @@ +"""Tests for the CORS CSRF version of Django Rest Framework's SessionAuthentication.""" +from mock import patch + +from django.test import TestCase +from django.test.utils import override_settings +from django.test.client import RequestFactory +from django.conf import settings + +from rest_framework.exceptions import AuthenticationFailed + +from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf + + +class CrossDomainAuthTest(TestCase): + """Tests for the CORS CSRF version of Django Rest Framework's SessionAuthentication. """ + + URL = "/dummy_url" + REFERER = "https://www.edx.org" + CSRF_TOKEN = 'abcd1234' + + def setUp(self): + super(CrossDomainAuthTest, self).setUp() + self.auth = SessionAuthenticationCrossDomainCsrf() + + def test_perform_csrf_referer_check(self): + request = self._fake_request() + with self.assertRaisesRegexp(AuthenticationFailed, 'CSRF'): + self.auth.enforce_csrf(request) + + @patch.dict(settings.FEATURES, { + 'ENABLE_CORS_HEADERS': True, + 'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True + }) + @override_settings( + CORS_ORIGIN_WHITELIST=["www.edx.org"], + CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken", + CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org" + ) + def test_skip_csrf_referer_check(self): + request = self._fake_request() + result = self.auth.enforce_csrf(request) + self.assertIs(result, None) + self.assertTrue(request.is_secure()) + + def _fake_request(self): + """Construct a fake request with a referer and CSRF token over a secure connection. """ + factory = RequestFactory() + factory.cookies[settings.CSRF_COOKIE_NAME] = self.CSRF_TOKEN + + request = factory.post( + self.URL, + HTTP_REFERER=self.REFERER, + HTTP_X_CSRFTOKEN=self.CSRF_TOKEN + ) + request.is_secure = lambda: True + return request diff --git a/common/djangoapps/enrollment/tests/test_views.py b/common/djangoapps/enrollment/tests/test_views.py index 948b3e1cc4..a052c94522 100644 --- a/common/djangoapps/enrollment/tests/test_views.py +++ b/common/djangoapps/enrollment/tests/test_views.py @@ -6,6 +6,8 @@ import json import unittest from mock import patch +from django.test import Client +from django.core.handlers.wsgi import WSGIRequest from django.core.urlresolvers import reverse from rest_framework.test import APITestCase from rest_framework import status @@ -365,3 +367,81 @@ class EnrollmentEmbargoTest(UrlResetMixin, ModuleStoreTestCase): url = reverse('courseenrollments') resp = self.client.get(url) return json.loads(resp.content) + + +def cross_domain_config(func): + """Decorator for configuring a cross-domain request. """ + feature_flag_decorator = patch.dict(settings.FEATURES, { + 'ENABLE_CORS_HEADERS': True, + 'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True + }) + settings_decorator = override_settings( + CORS_ORIGIN_WHITELIST=["www.edx.org"], + CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken", + CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org" + ) + is_secure_decorator = patch.object(WSGIRequest, 'is_secure', return_value=True) + + return feature_flag_decorator( + settings_decorator( + is_secure_decorator(func) + ) + ) + + +@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') +class EnrollmentCrossDomainTest(ModuleStoreTestCase): + """Test cross-domain calls to the enrollment end-points. """ + + USERNAME = "Bob" + EMAIL = "bob@example.com" + PASSWORD = "edx" + REFERER = "https://www.edx.org" + + def setUp(self): + """ Create a course and user, then log in. """ + super(EnrollmentCrossDomainTest, self).setUp() + self.course = CourseFactory.create() + self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD) + + self.client = Client(enforce_csrf_checks=True) + self.client.login(username=self.USERNAME, password=self.PASSWORD) + + @cross_domain_config + def test_cross_domain_change_enrollment(self, *args): # pylint: disable=unused-argument + csrf_cookie = self._get_csrf_cookie() + resp = self._cross_domain_post(csrf_cookie) + + # Expect that the request gets through successfully, + # passing the CSRF checks (including the referer check). + self.assertEqual(resp.status_code, 200) + + @cross_domain_config + def test_cross_domain_missing_csrf(self, *args): # pylint: disable=unused-argument + resp = self._cross_domain_post('invalid_csrf_token') + self.assertEqual(resp.status_code, 401) + + def _get_csrf_cookie(self): + """Retrieve the cross-domain CSRF cookie. """ + url = reverse('courseenrollment', kwargs={ + 'course_id': unicode(self.course.id) + }) + resp = self.client.get(url, HTTP_REFERER=self.REFERER) + self.assertEqual(resp.status_code, 200) + self.assertIn('prod-edx-csrftoken', resp.cookies) # pylint: disable=no-member + return resp.cookies['prod-edx-csrftoken'].value # pylint: disable=no-member + + def _cross_domain_post(self, csrf_cookie): + """Perform a cross-domain POST request. """ + url = reverse('courseenrollments') + params = json.dumps({ + 'course_details': { + 'course_id': unicode(self.course.id), + }, + 'user': self.user.username + }) + return self.client.post( + url, params, content_type='application/json', + HTTP_REFERER=self.REFERER, + HTTP_X_CSRFTOKEN=csrf_cookie + ) diff --git a/common/djangoapps/enrollment/views.py b/common/djangoapps/enrollment/views.py index e2bb848d2c..4143a935be 100644 --- a/common/djangoapps/enrollment/views.py +++ b/common/djangoapps/enrollment/views.py @@ -14,6 +14,7 @@ from rest_framework.throttling import UserRateThrottle from rest_framework.views import APIView from opaque_keys.edx.keys import CourseKey from embargo import api as embargo_api +from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf from cors_csrf.decorators import ensure_csrf_cookie_cross_domain from util.authentication import SessionAuthenticationAllowInactiveUser, OAuth2AuthenticationAllowInactiveUser from util.disable_rate_limit import can_disable_rate_limit @@ -24,6 +25,11 @@ from enrollment.errors import ( ) +class EnrollmentCrossDomainSessionAuth(SessionAuthenticationAllowInactiveUser, SessionAuthenticationCrossDomainCsrf): + """Session authentication that allows inactive users and cross-domain requests. """ + pass + + class EnrollmentUserThrottle(UserRateThrottle): """Limit the number of requests users can make to the enrollment API.""" # TODO Limit significantly after performance testing. # pylint: disable=fixme @@ -267,7 +273,7 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): * user: The ID of the user. """ - authentication_classes = OAuth2AuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser + authentication_classes = OAuth2AuthenticationAllowInactiveUser, EnrollmentCrossDomainSessionAuth permission_classes = ApiKeyHeaderPermissionIsAuthenticated, throttle_classes = EnrollmentUserThrottle,