Skip CSRF referer check for cross-domain requests.
This commit extends the workaround in `cors_csrf` middleware to Django Rest Framework's SessionAuthentication, which calls Django's CSRF middleware directly. The workaround checks the cross domain whitelist and skips the CSRF referer check for domains on the whitelist.
This commit is contained in:
29
common/djangoapps/cors_csrf/authentication.py
Normal file
29
common/djangoapps/cors_csrf/authentication.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Django Rest Framework Authentication classes for cross-domain end-points."""
|
||||
from rest_framework import authentication
|
||||
from cors_csrf.helpers import is_cross_domain_request_allowed, skip_cross_domain_referer_check
|
||||
|
||||
|
||||
class SessionAuthenticationCrossDomainCsrf(authentication.SessionAuthentication):
|
||||
"""Session authentication that skips the referer check over secure connections.
|
||||
|
||||
Django Rest Framework's `SessionAuthentication` class calls Django's
|
||||
CSRF middleware implementation directly, which bypasses the middleware
|
||||
stack.
|
||||
|
||||
This version of `SessionAuthentication` performs the same workaround
|
||||
as `CorsCSRFMiddleware` to skip the referer check for whitelisted
|
||||
domains over a secure connection. See `cors_csrf.middleware` for
|
||||
more information.
|
||||
|
||||
Since this subclass overrides only the `enforce_csrf()` method,
|
||||
it can be mixed in with other `SessionAuthentication` subclasses.
|
||||
|
||||
"""
|
||||
|
||||
def enforce_csrf(self, request):
|
||||
"""Skip the referer check if the cross-domain request is allowed. """
|
||||
if is_cross_domain_request_allowed(request):
|
||||
with skip_cross_domain_referer_check(request):
|
||||
return super(SessionAuthenticationCrossDomainCsrf, self).enforce_csrf(request)
|
||||
else:
|
||||
return super(SessionAuthenticationCrossDomainCsrf, self).enforce_csrf(request)
|
||||
92
common/djangoapps/cors_csrf/helpers.py
Normal file
92
common/djangoapps/cors_csrf/helpers.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Helper methods for CORS and CSRF checks. """
|
||||
import logging
|
||||
import urlparse
|
||||
import contextlib
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_cross_domain_request_allowed(request):
|
||||
"""Check whether we should allow the cross-domain request.
|
||||
|
||||
We allow a cross-domain request only if:
|
||||
|
||||
1) The request is made securely and the referer has "https://" as the protocol.
|
||||
2) The referer domain has been whitelisted.
|
||||
|
||||
Arguments:
|
||||
request (HttpRequest)
|
||||
|
||||
Returns:
|
||||
bool
|
||||
|
||||
"""
|
||||
referer = request.META.get('HTTP_REFERER')
|
||||
referer_parts = urlparse.urlparse(referer) if referer else None
|
||||
referer_hostname = referer_parts.hostname if referer_parts is not None else None
|
||||
|
||||
# Use CORS_ALLOW_INSECURE *only* for development and testing environments;
|
||||
# it should never be enabled in production.
|
||||
if not getattr(settings, 'CORS_ALLOW_INSECURE', False):
|
||||
if not request.is_secure():
|
||||
log.debug(
|
||||
u"Request is not secure, so we cannot send the CSRF token. "
|
||||
u"For testing purposes, you can disable this check by setting "
|
||||
u"`CORS_ALLOW_INSECURE` to True in the settings"
|
||||
)
|
||||
return False
|
||||
|
||||
if not referer:
|
||||
log.debug(u"No referer provided over a secure connection, so we cannot check the protocol.")
|
||||
return False
|
||||
|
||||
if not referer_parts.scheme == 'https':
|
||||
log.debug(u"Referer '%s' must have the scheme 'https'")
|
||||
return False
|
||||
|
||||
domain_is_whitelisted = (
|
||||
getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False) or
|
||||
referer_hostname in getattr(settings, 'CORS_ORIGIN_WHITELIST', [])
|
||||
)
|
||||
if not domain_is_whitelisted:
|
||||
if referer_hostname is None:
|
||||
# If no referer is specified, we can't check if it's a cross-domain
|
||||
# request or not.
|
||||
log.debug(u"Referrer hostname is `None`, so it is not on the whitelist.")
|
||||
elif referer_hostname != request.get_host():
|
||||
log.warning(
|
||||
(
|
||||
u"Domain '%s' is not on the cross domain whitelist. "
|
||||
u"Add the domain to `CORS_ORIGIN_WHITELIST` or set "
|
||||
u"`CORS_ORIGIN_ALLOW_ALL` to True in the settings."
|
||||
), referer_hostname
|
||||
)
|
||||
else:
|
||||
log.debug(
|
||||
(
|
||||
u"Domain '%s' is the same as the hostname in the request, "
|
||||
u"so we are not going to treat it as a cross-domain request."
|
||||
), referer_hostname
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def skip_cross_domain_referer_check(request):
|
||||
"""Skip the cross-domain CSRF referer check.
|
||||
|
||||
Django's CSRF middleware performs the referer check
|
||||
only when the request is made over a secure connection.
|
||||
To skip the check, we patch `request.is_secure()` to
|
||||
False.
|
||||
"""
|
||||
is_secure_default = request.is_secure
|
||||
request.is_secure = lambda: False
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
request.is_secure = is_secure_default
|
||||
@@ -43,82 +43,16 @@ CSRF cookie.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import urlparse
|
||||
|
||||
from django.conf import settings
|
||||
from django.middleware.csrf import CsrfViewMiddleware
|
||||
from django.core.exceptions import MiddlewareNotUsed, ImproperlyConfigured
|
||||
|
||||
from cors_csrf.helpers import is_cross_domain_request_allowed, skip_cross_domain_referer_check
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_cross_domain_request_allowed(request):
|
||||
"""Check whether we should allow the cross-domain request.
|
||||
|
||||
We allow a cross-domain request only if:
|
||||
|
||||
1) The request is made securely and the referer has "https://" as the protocol.
|
||||
2) The referer domain has been whitelisted.
|
||||
|
||||
Arguments:
|
||||
request (HttpRequest)
|
||||
|
||||
Returns:
|
||||
bool
|
||||
|
||||
"""
|
||||
referer = request.META.get('HTTP_REFERER')
|
||||
referer_parts = urlparse.urlparse(referer) if referer else None
|
||||
referer_hostname = referer_parts.hostname if referer_parts is not None else None
|
||||
|
||||
# Use CORS_ALLOW_INSECURE *only* for development and testing environments;
|
||||
# it should never be enabled in production.
|
||||
if not getattr(settings, 'CORS_ALLOW_INSECURE', False):
|
||||
if not request.is_secure():
|
||||
log.debug(
|
||||
u"Request is not secure, so we cannot send the CSRF token. "
|
||||
u"For testing purposes, you can disable this check by setting "
|
||||
u"`CORS_ALLOW_INSECURE` to True in the settings"
|
||||
)
|
||||
return False
|
||||
|
||||
if not referer:
|
||||
log.debug(u"No referer provided over a secure connection, so we cannot check the protocol.")
|
||||
return False
|
||||
|
||||
if not referer_parts.scheme == 'https':
|
||||
log.debug(u"Referer '%s' must have the scheme 'https'")
|
||||
return False
|
||||
|
||||
domain_is_whitelisted = (
|
||||
getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False) or
|
||||
referer_hostname in getattr(settings, 'CORS_ORIGIN_WHITELIST', [])
|
||||
)
|
||||
if not domain_is_whitelisted:
|
||||
if referer_hostname is None:
|
||||
# If no referer is specified, we can't check if it's a cross-domain
|
||||
# request or not.
|
||||
log.debug(u"Referrer hostname is `None`, so it is not on the whitelist.")
|
||||
elif referer_hostname != request.get_host():
|
||||
log.warning(
|
||||
(
|
||||
u"Domain '%s' is not on the cross domain whitelist. "
|
||||
u"Add the domain to `CORS_ORIGIN_WHITELIST` or set "
|
||||
u"`CORS_ORIGIN_ALLOW_ALL` to True in the settings."
|
||||
), referer_hostname
|
||||
)
|
||||
else:
|
||||
log.debug(
|
||||
(
|
||||
u"Domain '%s' is the same as the hostname in the request, "
|
||||
u"so we are not going to treat it as a cross-domain request."
|
||||
), referer_hostname
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class CorsCSRFMiddleware(CsrfViewMiddleware):
|
||||
"""
|
||||
Middleware for handling CSRF checks with CORS requests
|
||||
@@ -134,18 +68,8 @@ class CorsCSRFMiddleware(CsrfViewMiddleware):
|
||||
log.debug("Could not disable CSRF middleware referer check for cross-domain request.")
|
||||
return
|
||||
|
||||
is_secure_default = request.is_secure
|
||||
|
||||
def is_secure_patched():
|
||||
"""
|
||||
Avoid triggering the additional CSRF middleware checks on the referrer
|
||||
"""
|
||||
return False
|
||||
request.is_secure = is_secure_patched
|
||||
|
||||
res = super(CorsCSRFMiddleware, self).process_view(request, callback, callback_args, callback_kwargs)
|
||||
request.is_secure = is_secure_default
|
||||
return res
|
||||
with skip_cross_domain_referer_check(request):
|
||||
return super(CorsCSRFMiddleware, self).process_view(request, callback, callback_args, callback_kwargs)
|
||||
|
||||
|
||||
class CsrfCrossDomainCookieMiddleware(object):
|
||||
|
||||
56
common/djangoapps/cors_csrf/tests/test_authentication.py
Normal file
56
common/djangoapps/cors_csrf/tests/test_authentication.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Tests for the CORS CSRF version of Django Rest Framework's SessionAuthentication."""
|
||||
from mock import patch
|
||||
|
||||
from django.test import TestCase
|
||||
from django.test.utils import override_settings
|
||||
from django.test.client import RequestFactory
|
||||
from django.conf import settings
|
||||
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
|
||||
from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
|
||||
|
||||
|
||||
class CrossDomainAuthTest(TestCase):
|
||||
"""Tests for the CORS CSRF version of Django Rest Framework's SessionAuthentication. """
|
||||
|
||||
URL = "/dummy_url"
|
||||
REFERER = "https://www.edx.org"
|
||||
CSRF_TOKEN = 'abcd1234'
|
||||
|
||||
def setUp(self):
|
||||
super(CrossDomainAuthTest, self).setUp()
|
||||
self.auth = SessionAuthenticationCrossDomainCsrf()
|
||||
|
||||
def test_perform_csrf_referer_check(self):
|
||||
request = self._fake_request()
|
||||
with self.assertRaisesRegexp(AuthenticationFailed, 'CSRF'):
|
||||
self.auth.enforce_csrf(request)
|
||||
|
||||
@patch.dict(settings.FEATURES, {
|
||||
'ENABLE_CORS_HEADERS': True,
|
||||
'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True
|
||||
})
|
||||
@override_settings(
|
||||
CORS_ORIGIN_WHITELIST=["www.edx.org"],
|
||||
CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken",
|
||||
CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org"
|
||||
)
|
||||
def test_skip_csrf_referer_check(self):
|
||||
request = self._fake_request()
|
||||
result = self.auth.enforce_csrf(request)
|
||||
self.assertIs(result, None)
|
||||
self.assertTrue(request.is_secure())
|
||||
|
||||
def _fake_request(self):
|
||||
"""Construct a fake request with a referer and CSRF token over a secure connection. """
|
||||
factory = RequestFactory()
|
||||
factory.cookies[settings.CSRF_COOKIE_NAME] = self.CSRF_TOKEN
|
||||
|
||||
request = factory.post(
|
||||
self.URL,
|
||||
HTTP_REFERER=self.REFERER,
|
||||
HTTP_X_CSRFTOKEN=self.CSRF_TOKEN
|
||||
)
|
||||
request.is_secure = lambda: True
|
||||
return request
|
||||
@@ -6,6 +6,8 @@ import json
|
||||
import unittest
|
||||
|
||||
from mock import patch
|
||||
from django.test import Client
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.core.urlresolvers import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
from rest_framework import status
|
||||
@@ -365,3 +367,81 @@ class EnrollmentEmbargoTest(UrlResetMixin, ModuleStoreTestCase):
|
||||
url = reverse('courseenrollments')
|
||||
resp = self.client.get(url)
|
||||
return json.loads(resp.content)
|
||||
|
||||
|
||||
def cross_domain_config(func):
|
||||
"""Decorator for configuring a cross-domain request. """
|
||||
feature_flag_decorator = patch.dict(settings.FEATURES, {
|
||||
'ENABLE_CORS_HEADERS': True,
|
||||
'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True
|
||||
})
|
||||
settings_decorator = override_settings(
|
||||
CORS_ORIGIN_WHITELIST=["www.edx.org"],
|
||||
CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken",
|
||||
CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org"
|
||||
)
|
||||
is_secure_decorator = patch.object(WSGIRequest, 'is_secure', return_value=True)
|
||||
|
||||
return feature_flag_decorator(
|
||||
settings_decorator(
|
||||
is_secure_decorator(func)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
|
||||
class EnrollmentCrossDomainTest(ModuleStoreTestCase):
|
||||
"""Test cross-domain calls to the enrollment end-points. """
|
||||
|
||||
USERNAME = "Bob"
|
||||
EMAIL = "bob@example.com"
|
||||
PASSWORD = "edx"
|
||||
REFERER = "https://www.edx.org"
|
||||
|
||||
def setUp(self):
|
||||
""" Create a course and user, then log in. """
|
||||
super(EnrollmentCrossDomainTest, self).setUp()
|
||||
self.course = CourseFactory.create()
|
||||
self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
|
||||
|
||||
self.client = Client(enforce_csrf_checks=True)
|
||||
self.client.login(username=self.USERNAME, password=self.PASSWORD)
|
||||
|
||||
@cross_domain_config
|
||||
def test_cross_domain_change_enrollment(self, *args): # pylint: disable=unused-argument
|
||||
csrf_cookie = self._get_csrf_cookie()
|
||||
resp = self._cross_domain_post(csrf_cookie)
|
||||
|
||||
# Expect that the request gets through successfully,
|
||||
# passing the CSRF checks (including the referer check).
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
|
||||
@cross_domain_config
|
||||
def test_cross_domain_missing_csrf(self, *args): # pylint: disable=unused-argument
|
||||
resp = self._cross_domain_post('invalid_csrf_token')
|
||||
self.assertEqual(resp.status_code, 401)
|
||||
|
||||
def _get_csrf_cookie(self):
|
||||
"""Retrieve the cross-domain CSRF cookie. """
|
||||
url = reverse('courseenrollment', kwargs={
|
||||
'course_id': unicode(self.course.id)
|
||||
})
|
||||
resp = self.client.get(url, HTTP_REFERER=self.REFERER)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertIn('prod-edx-csrftoken', resp.cookies) # pylint: disable=no-member
|
||||
return resp.cookies['prod-edx-csrftoken'].value # pylint: disable=no-member
|
||||
|
||||
def _cross_domain_post(self, csrf_cookie):
|
||||
"""Perform a cross-domain POST request. """
|
||||
url = reverse('courseenrollments')
|
||||
params = json.dumps({
|
||||
'course_details': {
|
||||
'course_id': unicode(self.course.id),
|
||||
},
|
||||
'user': self.user.username
|
||||
})
|
||||
return self.client.post(
|
||||
url, params, content_type='application/json',
|
||||
HTTP_REFERER=self.REFERER,
|
||||
HTTP_X_CSRFTOKEN=csrf_cookie
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ from rest_framework.throttling import UserRateThrottle
|
||||
from rest_framework.views import APIView
|
||||
from opaque_keys.edx.keys import CourseKey
|
||||
from embargo import api as embargo_api
|
||||
from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
|
||||
from cors_csrf.decorators import ensure_csrf_cookie_cross_domain
|
||||
from util.authentication import SessionAuthenticationAllowInactiveUser, OAuth2AuthenticationAllowInactiveUser
|
||||
from util.disable_rate_limit import can_disable_rate_limit
|
||||
@@ -24,6 +25,11 @@ from enrollment.errors import (
|
||||
)
|
||||
|
||||
|
||||
class EnrollmentCrossDomainSessionAuth(SessionAuthenticationAllowInactiveUser, SessionAuthenticationCrossDomainCsrf):
|
||||
"""Session authentication that allows inactive users and cross-domain requests. """
|
||||
pass
|
||||
|
||||
|
||||
class EnrollmentUserThrottle(UserRateThrottle):
|
||||
"""Limit the number of requests users can make to the enrollment API."""
|
||||
# TODO Limit significantly after performance testing. # pylint: disable=fixme
|
||||
@@ -267,7 +273,7 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn):
|
||||
* user: The ID of the user.
|
||||
"""
|
||||
|
||||
authentication_classes = OAuth2AuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser
|
||||
authentication_classes = OAuth2AuthenticationAllowInactiveUser, EnrollmentCrossDomainSessionAuth
|
||||
permission_classes = ApiKeyHeaderPermissionIsAuthenticated,
|
||||
throttle_classes = EnrollmentUserThrottle,
|
||||
|
||||
|
||||
Reference in New Issue
Block a user