add tests
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
import urllib
|
||||
import unittest
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.core.urlresolvers import reverse
|
||||
from django.test.utils import override_settings
|
||||
from mock import patch
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from experiments.factories import ExperimentDataFactory, ExperimentKeyValueFactory
|
||||
@@ -9,7 +14,32 @@ from experiments.serializers import ExperimentDataSerializer
|
||||
from student.tests.factories import UserFactory
|
||||
|
||||
|
||||
CROSS_DOMAIN_REFERER = 'https://ecommerce.edx.org'
|
||||
|
||||
|
||||
def cross_domain_config(func):
|
||||
"""Decorator for configuring a cross-domain request. """
|
||||
feature_flag_decorator = patch.dict(settings.FEATURES, {
|
||||
'ENABLE_CORS_HEADERS': True,
|
||||
'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True
|
||||
})
|
||||
settings_decorator = override_settings(
|
||||
CORS_ORIGIN_WHITELIST=['ecommerce.edx.org'],
|
||||
CSRF_COOKIE_NAME="prod-edx-csrftoken",
|
||||
CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken",
|
||||
CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org"
|
||||
)
|
||||
is_secure_decorator = patch.object(WSGIRequest, 'is_secure', return_value=True)
|
||||
|
||||
return feature_flag_decorator(
|
||||
settings_decorator(
|
||||
is_secure_decorator(func)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ExperimentDataViewSetTests(APITestCase):
|
||||
|
||||
def assert_data_created_for_user(self, user, method='post', status=201):
|
||||
url = reverse('api_experiments:v0:data-list')
|
||||
data = {
|
||||
@@ -210,6 +240,79 @@ class ExperimentDataViewSetTests(APITestCase):
|
||||
ExperimentData.objects.get(user=other_user, **kwargs)
|
||||
|
||||
|
||||
class ExperimentCrossDomainTests(APITestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(ExperimentCrossDomainTests, self).setUp()
|
||||
self.client = self.client_class(enforce_csrf_checks=True)
|
||||
|
||||
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
|
||||
@cross_domain_config
|
||||
def test_cross_domain_create(self, *args): # pylint: disable=unused-argument
|
||||
user = UserFactory()
|
||||
self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)
|
||||
csrf_cookie = self._get_csrf_cookie()
|
||||
data = {
|
||||
'experiment_id': 1,
|
||||
'key': 'foo',
|
||||
'value': 'bar',
|
||||
}
|
||||
resp = self._cross_domain_post(csrf_cookie, data)
|
||||
|
||||
# Expect that the request gets through successfully,
|
||||
# passing the CSRF checks (including the referer check).
|
||||
self.assertEqual(resp.status_code, 201)
|
||||
|
||||
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
|
||||
@cross_domain_config
|
||||
def test_cross_domain_invalid_csrf_header(self, *args): # pylint: disable=unused-argument
|
||||
user = UserFactory()
|
||||
self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)
|
||||
self._get_csrf_cookie()
|
||||
data = {
|
||||
'experiment_id': 1,
|
||||
'key': 'foo',
|
||||
'value': 'bar',
|
||||
}
|
||||
resp = self._cross_domain_post('invalid_csrf_token', data)
|
||||
self.assertEqual(resp.status_code, 403)
|
||||
|
||||
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
|
||||
@cross_domain_config
|
||||
def test_cross_domain_not_in_whitelist(self, *args): # pylint: disable=unused-argument
|
||||
user = UserFactory()
|
||||
self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)
|
||||
csrf_cookie = self._get_csrf_cookie()
|
||||
data = {
|
||||
'experiment_id': 1,
|
||||
'key': 'foo',
|
||||
'value': 'bar',
|
||||
}
|
||||
resp = self._cross_domain_post(csrf_cookie, data, referer='www.example.com')
|
||||
self.assertEqual(resp.status_code, 403)
|
||||
|
||||
def _get_csrf_cookie(self):
|
||||
"""Retrieve the cross-domain CSRF cookie. """
|
||||
url = reverse('courseenrollments')
|
||||
resp = self.client.get(url, HTTP_REFERER=CROSS_DOMAIN_REFERER)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertIn(settings.CSRF_COOKIE_NAME, resp.cookies) # pylint: disable=no-member
|
||||
return resp.cookies[settings.CSRF_COOKIE_NAME].value # pylint: disable=no-member
|
||||
|
||||
def _cross_domain_post(self, csrf_token, data, referer=CROSS_DOMAIN_REFERER):
|
||||
"""Perform a cross-domain POST request. """
|
||||
url = reverse('api_experiments:v0:data-list')
|
||||
kwargs = {
|
||||
'HTTP_REFERER': referer,
|
||||
settings.CSRF_HEADER_NAME: csrf_token,
|
||||
}
|
||||
return self.client.post(
|
||||
url,
|
||||
data,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
class ExperimentKeyValueViewSetTests(APITestCase):
|
||||
def test_permissions(self):
|
||||
""" Staff access is required for write operations. """
|
||||
|
||||
Reference in New Issue
Block a user