Merge pull request #7371 from edx/will/django-rest-framework-cors-csrf
Skip CSRF referer check for cross-domain requests.
This commit is contained in:
29
common/djangoapps/cors_csrf/authentication.py
Normal file
29
common/djangoapps/cors_csrf/authentication.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Django Rest Framework Authentication classes for cross-domain end-points."""
|
||||
from rest_framework import authentication
|
||||
from cors_csrf.helpers import is_cross_domain_request_allowed, skip_cross_domain_referer_check
|
||||
|
||||
|
||||
class SessionAuthenticationCrossDomainCsrf(authentication.SessionAuthentication):
|
||||
"""Session authentication that skips the referer check over secure connections.
|
||||
|
||||
Django Rest Framework's `SessionAuthentication` class calls Django's
|
||||
CSRF middleware implementation directly, which bypasses the middleware
|
||||
stack.
|
||||
|
||||
This version of `SessionAuthentication` performs the same workaround
|
||||
as `CorsCSRFMiddleware` to skip the referer check for whitelisted
|
||||
domains over a secure connection. See `cors_csrf.middleware` for
|
||||
more information.
|
||||
|
||||
Since this subclass overrides only the `enforce_csrf()` method,
|
||||
it can be mixed in with other `SessionAuthentication` subclasses.
|
||||
|
||||
"""
|
||||
|
||||
def enforce_csrf(self, request):
|
||||
"""Skip the referer check if the cross-domain request is allowed. """
|
||||
if is_cross_domain_request_allowed(request):
|
||||
with skip_cross_domain_referer_check(request):
|
||||
return super(SessionAuthenticationCrossDomainCsrf, self).enforce_csrf(request)
|
||||
else:
|
||||
return super(SessionAuthenticationCrossDomainCsrf, self).enforce_csrf(request)
|
||||
92
common/djangoapps/cors_csrf/helpers.py
Normal file
92
common/djangoapps/cors_csrf/helpers.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Helper methods for CORS and CSRF checks. """
|
||||
import logging
|
||||
import urlparse
|
||||
import contextlib
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_cross_domain_request_allowed(request):
|
||||
"""Check whether we should allow the cross-domain request.
|
||||
|
||||
We allow a cross-domain request only if:
|
||||
|
||||
1) The request is made securely and the referer has "https://" as the protocol.
|
||||
2) The referer domain has been whitelisted.
|
||||
|
||||
Arguments:
|
||||
request (HttpRequest)
|
||||
|
||||
Returns:
|
||||
bool
|
||||
|
||||
"""
|
||||
referer = request.META.get('HTTP_REFERER')
|
||||
referer_parts = urlparse.urlparse(referer) if referer else None
|
||||
referer_hostname = referer_parts.hostname if referer_parts is not None else None
|
||||
|
||||
# Use CORS_ALLOW_INSECURE *only* for development and testing environments;
|
||||
# it should never be enabled in production.
|
||||
if not getattr(settings, 'CORS_ALLOW_INSECURE', False):
|
||||
if not request.is_secure():
|
||||
log.debug(
|
||||
u"Request is not secure, so we cannot send the CSRF token. "
|
||||
u"For testing purposes, you can disable this check by setting "
|
||||
u"`CORS_ALLOW_INSECURE` to True in the settings"
|
||||
)
|
||||
return False
|
||||
|
||||
if not referer:
|
||||
log.debug(u"No referer provided over a secure connection, so we cannot check the protocol.")
|
||||
return False
|
||||
|
||||
if not referer_parts.scheme == 'https':
|
||||
log.debug(u"Referer '%s' must have the scheme 'https'")
|
||||
return False
|
||||
|
||||
domain_is_whitelisted = (
|
||||
getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False) or
|
||||
referer_hostname in getattr(settings, 'CORS_ORIGIN_WHITELIST', [])
|
||||
)
|
||||
if not domain_is_whitelisted:
|
||||
if referer_hostname is None:
|
||||
# If no referer is specified, we can't check if it's a cross-domain
|
||||
# request or not.
|
||||
log.debug(u"Referrer hostname is `None`, so it is not on the whitelist.")
|
||||
elif referer_hostname != request.get_host():
|
||||
log.warning(
|
||||
(
|
||||
u"Domain '%s' is not on the cross domain whitelist. "
|
||||
u"Add the domain to `CORS_ORIGIN_WHITELIST` or set "
|
||||
u"`CORS_ORIGIN_ALLOW_ALL` to True in the settings."
|
||||
), referer_hostname
|
||||
)
|
||||
else:
|
||||
log.debug(
|
||||
(
|
||||
u"Domain '%s' is the same as the hostname in the request, "
|
||||
u"so we are not going to treat it as a cross-domain request."
|
||||
), referer_hostname
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def skip_cross_domain_referer_check(request):
|
||||
"""Skip the cross-domain CSRF referer check.
|
||||
|
||||
Django's CSRF middleware performs the referer check
|
||||
only when the request is made over a secure connection.
|
||||
To skip the check, we patch `request.is_secure()` to
|
||||
False.
|
||||
"""
|
||||
is_secure_default = request.is_secure
|
||||
request.is_secure = lambda: False
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
request.is_secure = is_secure_default
|
||||
@@ -43,82 +43,16 @@ CSRF cookie.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import urlparse
|
||||
|
||||
from django.conf import settings
|
||||
from django.middleware.csrf import CsrfViewMiddleware
|
||||
from django.core.exceptions import MiddlewareNotUsed, ImproperlyConfigured
|
||||
|
||||
from cors_csrf.helpers import is_cross_domain_request_allowed, skip_cross_domain_referer_check
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_cross_domain_request_allowed(request):
|
||||
"""Check whether we should allow the cross-domain request.
|
||||
|
||||
We allow a cross-domain request only if:
|
||||
|
||||
1) The request is made securely and the referer has "https://" as the protocol.
|
||||
2) The referer domain has been whitelisted.
|
||||
|
||||
Arguments:
|
||||
request (HttpRequest)
|
||||
|
||||
Returns:
|
||||
bool
|
||||
|
||||
"""
|
||||
referer = request.META.get('HTTP_REFERER')
|
||||
referer_parts = urlparse.urlparse(referer) if referer else None
|
||||
referer_hostname = referer_parts.hostname if referer_parts is not None else None
|
||||
|
||||
# Use CORS_ALLOW_INSECURE *only* for development and testing environments;
|
||||
# it should never be enabled in production.
|
||||
if not getattr(settings, 'CORS_ALLOW_INSECURE', False):
|
||||
if not request.is_secure():
|
||||
log.debug(
|
||||
u"Request is not secure, so we cannot send the CSRF token. "
|
||||
u"For testing purposes, you can disable this check by setting "
|
||||
u"`CORS_ALLOW_INSECURE` to True in the settings"
|
||||
)
|
||||
return False
|
||||
|
||||
if not referer:
|
||||
log.debug(u"No referer provided over a secure connection, so we cannot check the protocol.")
|
||||
return False
|
||||
|
||||
if not referer_parts.scheme == 'https':
|
||||
log.debug(u"Referer '%s' must have the scheme 'https'")
|
||||
return False
|
||||
|
||||
domain_is_whitelisted = (
|
||||
getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False) or
|
||||
referer_hostname in getattr(settings, 'CORS_ORIGIN_WHITELIST', [])
|
||||
)
|
||||
if not domain_is_whitelisted:
|
||||
if referer_hostname is None:
|
||||
# If no referer is specified, we can't check if it's a cross-domain
|
||||
# request or not.
|
||||
log.debug(u"Referrer hostname is `None`, so it is not on the whitelist.")
|
||||
elif referer_hostname != request.get_host():
|
||||
log.warning(
|
||||
(
|
||||
u"Domain '%s' is not on the cross domain whitelist. "
|
||||
u"Add the domain to `CORS_ORIGIN_WHITELIST` or set "
|
||||
u"`CORS_ORIGIN_ALLOW_ALL` to True in the settings."
|
||||
), referer_hostname
|
||||
)
|
||||
else:
|
||||
log.debug(
|
||||
(
|
||||
u"Domain '%s' is the same as the hostname in the request, "
|
||||
u"so we are not going to treat it as a cross-domain request."
|
||||
), referer_hostname
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class CorsCSRFMiddleware(CsrfViewMiddleware):
|
||||
"""
|
||||
Middleware for handling CSRF checks with CORS requests
|
||||
@@ -134,18 +68,8 @@ class CorsCSRFMiddleware(CsrfViewMiddleware):
|
||||
log.debug("Could not disable CSRF middleware referer check for cross-domain request.")
|
||||
return
|
||||
|
||||
is_secure_default = request.is_secure
|
||||
|
||||
def is_secure_patched():
|
||||
"""
|
||||
Avoid triggering the additional CSRF middleware checks on the referrer
|
||||
"""
|
||||
return False
|
||||
request.is_secure = is_secure_patched
|
||||
|
||||
res = super(CorsCSRFMiddleware, self).process_view(request, callback, callback_args, callback_kwargs)
|
||||
request.is_secure = is_secure_default
|
||||
return res
|
||||
with skip_cross_domain_referer_check(request):
|
||||
return super(CorsCSRFMiddleware, self).process_view(request, callback, callback_args, callback_kwargs)
|
||||
|
||||
|
||||
class CsrfCrossDomainCookieMiddleware(object):
|
||||
|
||||
56
common/djangoapps/cors_csrf/tests/test_authentication.py
Normal file
56
common/djangoapps/cors_csrf/tests/test_authentication.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Tests for the CORS CSRF version of Django Rest Framework's SessionAuthentication."""
|
||||
from mock import patch
|
||||
|
||||
from django.test import TestCase
|
||||
from django.test.utils import override_settings
|
||||
from django.test.client import RequestFactory
|
||||
from django.conf import settings
|
||||
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
|
||||
from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
|
||||
|
||||
|
||||
class CrossDomainAuthTest(TestCase):
|
||||
"""Tests for the CORS CSRF version of Django Rest Framework's SessionAuthentication. """
|
||||
|
||||
URL = "/dummy_url"
|
||||
REFERER = "https://www.edx.org"
|
||||
CSRF_TOKEN = 'abcd1234'
|
||||
|
||||
def setUp(self):
|
||||
super(CrossDomainAuthTest, self).setUp()
|
||||
self.auth = SessionAuthenticationCrossDomainCsrf()
|
||||
|
||||
def test_perform_csrf_referer_check(self):
|
||||
request = self._fake_request()
|
||||
with self.assertRaisesRegexp(AuthenticationFailed, 'CSRF'):
|
||||
self.auth.enforce_csrf(request)
|
||||
|
||||
@patch.dict(settings.FEATURES, {
|
||||
'ENABLE_CORS_HEADERS': True,
|
||||
'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True
|
||||
})
|
||||
@override_settings(
|
||||
CORS_ORIGIN_WHITELIST=["www.edx.org"],
|
||||
CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken",
|
||||
CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org"
|
||||
)
|
||||
def test_skip_csrf_referer_check(self):
|
||||
request = self._fake_request()
|
||||
result = self.auth.enforce_csrf(request)
|
||||
self.assertIs(result, None)
|
||||
self.assertTrue(request.is_secure())
|
||||
|
||||
def _fake_request(self):
|
||||
"""Construct a fake request with a referer and CSRF token over a secure connection. """
|
||||
factory = RequestFactory()
|
||||
factory.cookies[settings.CSRF_COOKIE_NAME] = self.CSRF_TOKEN
|
||||
|
||||
request = factory.post(
|
||||
self.URL,
|
||||
HTTP_REFERER=self.REFERER,
|
||||
HTTP_X_CSRFTOKEN=self.CSRF_TOKEN
|
||||
)
|
||||
request.is_secure = lambda: True
|
||||
return request
|
||||
@@ -6,6 +6,8 @@ import json
|
||||
import unittest
|
||||
|
||||
from mock import patch
|
||||
from django.test import Client
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.core.urlresolvers import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
from rest_framework import status
|
||||
@@ -365,3 +367,81 @@ class EnrollmentEmbargoTest(UrlResetMixin, ModuleStoreTestCase):
|
||||
url = reverse('courseenrollments')
|
||||
resp = self.client.get(url)
|
||||
return json.loads(resp.content)
|
||||
|
||||
|
||||
def cross_domain_config(func):
|
||||
"""Decorator for configuring a cross-domain request. """
|
||||
feature_flag_decorator = patch.dict(settings.FEATURES, {
|
||||
'ENABLE_CORS_HEADERS': True,
|
||||
'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True
|
||||
})
|
||||
settings_decorator = override_settings(
|
||||
CORS_ORIGIN_WHITELIST=["www.edx.org"],
|
||||
CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken",
|
||||
CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org"
|
||||
)
|
||||
is_secure_decorator = patch.object(WSGIRequest, 'is_secure', return_value=True)
|
||||
|
||||
return feature_flag_decorator(
|
||||
settings_decorator(
|
||||
is_secure_decorator(func)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
|
||||
class EnrollmentCrossDomainTest(ModuleStoreTestCase):
|
||||
"""Test cross-domain calls to the enrollment end-points. """
|
||||
|
||||
USERNAME = "Bob"
|
||||
EMAIL = "bob@example.com"
|
||||
PASSWORD = "edx"
|
||||
REFERER = "https://www.edx.org"
|
||||
|
||||
def setUp(self):
|
||||
""" Create a course and user, then log in. """
|
||||
super(EnrollmentCrossDomainTest, self).setUp()
|
||||
self.course = CourseFactory.create()
|
||||
self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
|
||||
|
||||
self.client = Client(enforce_csrf_checks=True)
|
||||
self.client.login(username=self.USERNAME, password=self.PASSWORD)
|
||||
|
||||
@cross_domain_config
|
||||
def test_cross_domain_change_enrollment(self, *args): # pylint: disable=unused-argument
|
||||
csrf_cookie = self._get_csrf_cookie()
|
||||
resp = self._cross_domain_post(csrf_cookie)
|
||||
|
||||
# Expect that the request gets through successfully,
|
||||
# passing the CSRF checks (including the referer check).
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
|
||||
@cross_domain_config
|
||||
def test_cross_domain_missing_csrf(self, *args): # pylint: disable=unused-argument
|
||||
resp = self._cross_domain_post('invalid_csrf_token')
|
||||
self.assertEqual(resp.status_code, 401)
|
||||
|
||||
def _get_csrf_cookie(self):
|
||||
"""Retrieve the cross-domain CSRF cookie. """
|
||||
url = reverse('courseenrollment', kwargs={
|
||||
'course_id': unicode(self.course.id)
|
||||
})
|
||||
resp = self.client.get(url, HTTP_REFERER=self.REFERER)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertIn('prod-edx-csrftoken', resp.cookies) # pylint: disable=no-member
|
||||
return resp.cookies['prod-edx-csrftoken'].value # pylint: disable=no-member
|
||||
|
||||
def _cross_domain_post(self, csrf_cookie):
|
||||
"""Perform a cross-domain POST request. """
|
||||
url = reverse('courseenrollments')
|
||||
params = json.dumps({
|
||||
'course_details': {
|
||||
'course_id': unicode(self.course.id),
|
||||
},
|
||||
'user': self.user.username
|
||||
})
|
||||
return self.client.post(
|
||||
url, params, content_type='application/json',
|
||||
HTTP_REFERER=self.REFERER,
|
||||
HTTP_X_CSRFTOKEN=csrf_cookie
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ from rest_framework.throttling import UserRateThrottle
|
||||
from rest_framework.views import APIView
|
||||
from opaque_keys.edx.keys import CourseKey
|
||||
from embargo import api as embargo_api
|
||||
from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
|
||||
from cors_csrf.decorators import ensure_csrf_cookie_cross_domain
|
||||
from util.authentication import SessionAuthenticationAllowInactiveUser, OAuth2AuthenticationAllowInactiveUser
|
||||
from util.disable_rate_limit import can_disable_rate_limit
|
||||
@@ -24,6 +25,11 @@ from enrollment.errors import (
|
||||
)
|
||||
|
||||
|
||||
class EnrollmentCrossDomainSessionAuth(SessionAuthenticationAllowInactiveUser, SessionAuthenticationCrossDomainCsrf):
|
||||
"""Session authentication that allows inactive users and cross-domain requests. """
|
||||
pass
|
||||
|
||||
|
||||
class EnrollmentUserThrottle(UserRateThrottle):
|
||||
"""Limit the number of requests users can make to the enrollment API."""
|
||||
# TODO Limit significantly after performance testing. # pylint: disable=fixme
|
||||
@@ -267,7 +273,7 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn):
|
||||
* user: The ID of the user.
|
||||
"""
|
||||
|
||||
authentication_classes = OAuth2AuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser
|
||||
authentication_classes = OAuth2AuthenticationAllowInactiveUser, EnrollmentCrossDomainSessionAuth
|
||||
permission_classes = ApiKeyHeaderPermissionIsAuthenticated,
|
||||
throttle_classes = EnrollmentUserThrottle,
|
||||
|
||||
|
||||
Reference in New Issue
Block a user