From 4ca2692f5da346982268665ebec298aa466282fe Mon Sep 17 00:00:00 2001 From: Toby Lawrence Date: Thu, 3 Mar 2016 14:00:12 -0500 Subject: [PATCH 1/2] Rename 'clean_headers' to 'header_control'. This is part of adding the ability to forcefully set headers through the middleware in addition to removing specific headers. --- common/djangoapps/clean_headers/__init__.py | 15 ----- common/djangoapps/clean_headers/decorators.py | 36 ---------- common/djangoapps/clean_headers/middleware.py | 25 ------- .../clean_headers/tests/test_decorators.py | 20 ------ .../clean_headers/tests/test_middleware.py | 34 ---------- common/djangoapps/header_control/__init__.py | 21 ++++++ .../djangoapps/header_control/decorators.py | 67 +++++++++++++++++++ .../djangoapps/header_control/middleware.py | 34 ++++++++++ .../header_control/tests/test_decorators.py | 32 +++++++++ .../header_control/tests/test_middleware.py | 47 +++++++++++++ lms/envs/common.py | 2 +- 11 files changed, 202 insertions(+), 131 deletions(-) delete mode 100644 common/djangoapps/clean_headers/__init__.py delete mode 100644 common/djangoapps/clean_headers/decorators.py delete mode 100644 common/djangoapps/clean_headers/middleware.py delete mode 100644 common/djangoapps/clean_headers/tests/test_decorators.py delete mode 100644 common/djangoapps/clean_headers/tests/test_middleware.py create mode 100644 common/djangoapps/header_control/__init__.py create mode 100644 common/djangoapps/header_control/decorators.py create mode 100644 common/djangoapps/header_control/middleware.py create mode 100644 common/djangoapps/header_control/tests/test_decorators.py create mode 100644 common/djangoapps/header_control/tests/test_middleware.py diff --git a/common/djangoapps/clean_headers/__init__.py b/common/djangoapps/clean_headers/__init__.py deleted file mode 100644 index 0c718ec7ad..0000000000 --- a/common/djangoapps/clean_headers/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -""" -This middleware is used for cleaning headers from a response before it is sent to the end user. - -Due to the nature of how middleware runs, a piece of middleware high in the chain cannot ensure -that response headers won't be present on the final response body, as middleware further down -the chain could be adding them. - -This middleware is intended to sit as close as possible to the top of the list, so that it has -a chance on the reponse going out to strip the intended headers. -""" - - -def remove_headers_from_response(response, *headers): - """Removes the given headers from the response using the clean_headers middleware.""" - response.clean_headers = headers diff --git a/common/djangoapps/clean_headers/decorators.py b/common/djangoapps/clean_headers/decorators.py deleted file mode 100644 index dadba59619..0000000000 --- a/common/djangoapps/clean_headers/decorators.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -Middleware decorator for removing headers. -""" - -from functools import wraps - - -def clean_headers(*headers): - """ - Decorator that removes any headers specified from the response. - Usage: - @clean_headers("Vary") - def myview(request): - ... - - The CleanHeadersMiddleware must be used and placed as closely as possible to the top - of the middleware chain, ideally after any caching middleware but before everything else. - - This decorator is not safe for multiple uses: each call will overwrite any previously set values. - """ - def _decorator(func): - """ - Decorates the given function. - """ - @wraps(func) - def _inner(*args, **kwargs): - """ - Alters the response. - """ - response = func(*args, **kwargs) - response.clean_headers = headers - return response - - return _inner - - return _decorator diff --git a/common/djangoapps/clean_headers/middleware.py b/common/djangoapps/clean_headers/middleware.py deleted file mode 100644 index d3f8e0be6a..0000000000 --- a/common/djangoapps/clean_headers/middleware.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -Middleware used for cleaning headers from a response before it is sent to the end user. -""" - - -class CleanHeadersMiddleware(object): - """ - Middleware that can drop headers present in a response. - - This can be used, for example, to remove headers i.e. drop any Vary headers to improve cache performance. - """ - - def process_response(self, _request, response): - """ - Processes the given response, potentially stripping out any unwanted headers. - """ - - if len(getattr(response, 'clean_headers', [])) > 0: - for header in response.clean_headers: - try: - del response[header] - except KeyError: - pass - - return response diff --git a/common/djangoapps/clean_headers/tests/test_decorators.py b/common/djangoapps/clean_headers/tests/test_decorators.py deleted file mode 100644 index d9f0642405..0000000000 --- a/common/djangoapps/clean_headers/tests/test_decorators.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Tests for clean_headers decorator. """ -from django.http import HttpResponse, HttpRequest -from django.test import TestCase -from clean_headers.decorators import clean_headers - - -def fake_view(_request): - """Fake view that returns an empty response.""" - return HttpResponse() - - -class TestCleanHeaders(TestCase): - """Test the `clean_headers` decorator.""" - - def test_clean_headers(self): - request = HttpRequest() - wrapper = clean_headers('Vary', 'Accept-Encoding') - wrapped_view = wrapper(fake_view) - response = wrapped_view(request) - self.assertEqual(len(response.clean_headers), 2) diff --git a/common/djangoapps/clean_headers/tests/test_middleware.py b/common/djangoapps/clean_headers/tests/test_middleware.py deleted file mode 100644 index 3be79a6f11..0000000000 --- a/common/djangoapps/clean_headers/tests/test_middleware.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Tests for clean_headers middleware.""" -from django.http import HttpResponse, HttpRequest -from django.test import TestCase -from clean_headers.middleware import CleanHeadersMiddleware - - -class TestCleanHeadersMiddlewareProcessResponse(TestCase): - """Test the `clean_headers` middleware. """ - def setUp(self): - super(TestCleanHeadersMiddlewareProcessResponse, self).setUp() - self.middleware = CleanHeadersMiddleware() - - def test_cleans_intended_headers(self): - fake_request = HttpRequest() - - fake_response = HttpResponse() - fake_response['Vary'] = 'Cookie' - fake_response['Accept-Encoding'] = 'gzip' - fake_response.clean_headers = ['Vary'] - - result = self.middleware.process_response(fake_request, fake_response) - self.assertNotIn('Vary', result) - self.assertEquals('gzip', result['Accept-Encoding']) - - def test_does_not_mangle_undecorated_response(self): - fake_request = HttpRequest() - - fake_response = HttpResponse() - fake_response['Vary'] = 'Cookie' - fake_response['Accept-Encoding'] = 'gzip' - - result = self.middleware.process_response(fake_request, fake_response) - self.assertEquals('Cookie', result['Vary']) - self.assertEquals('gzip', result['Accept-Encoding']) diff --git a/common/djangoapps/header_control/__init__.py b/common/djangoapps/header_control/__init__.py new file mode 100644 index 0000000000..a33b8d2de9 --- /dev/null +++ b/common/djangoapps/header_control/__init__.py @@ -0,0 +1,21 @@ +""" +This middleware is used for adjusting the headers in a response before it is sent to the end user. + +This middleware is intended to sit as close as possible to the top of the middleare list as possible, +so that it is one of the last pieces of middleware to touch the response, and thus can most accurately +adjust/control the headers of the response. +""" + + +def remove_headers_from_response(response, *headers): + """Removes the given headers from the response using the header_control middleware.""" + response.remove_headers = headers + +def force_header_for_response(response, header, value): + """Forces the given header for the given response using the header_control middleware.""" + force_headers = {} + if hasattr(response, 'force_headers'): + force_headers = response.force_headers + force_headers[header] = value + + response.force_headers = force_headers diff --git a/common/djangoapps/header_control/decorators.py b/common/djangoapps/header_control/decorators.py new file mode 100644 index 0000000000..6a3ca9c221 --- /dev/null +++ b/common/djangoapps/header_control/decorators.py @@ -0,0 +1,67 @@ +""" +Middleware decorator for removing headers. +""" + +from functools import wraps +from header_control import remove_headers_from_response, force_header_for_response + +def remove_headers(*headers): + """ + Decorator that removes specific headers from the response. + Usage: + @remove_headers("Vary") + def myview(request): + ... + + The HeaderControlMiddleware must be used and placed as closely as possible to the top + of the middleware chain, ideally after any caching middleware but before everything else. + + This decorator is not safe for multiple uses: each call will overwrite any previously set values. + """ + def _decorator(func): + """ + Decorates the given function. + """ + @wraps(func) + def _inner(*args, **kwargs): + """ + Alters the response. + """ + response = func(*args, **kwargs) + remove_headers_from_response(response, *headers) + return response + + return _inner + + return _decorator + + +def force_header(header, value): + """ + Decorator that forces a header in the response to have a specific value. + Usage: + @force_header("Vary", "Origin") + def myview(request): + ... + + The HeaderControlMiddleware must be used and placed as closely as possible to the top + of the middleware chain, ideally after any caching middleware but before everything else. + + This decorator is not safe for multiple uses: each call will overwrite any previously set values. + """ + def _decorator(func): + """ + Decorates the given function. + """ + @wraps(func) + def _inner(*args, **kwargs): + """ + Alters the response. + """ + response = func(*args, **kwargs) + force_header_for_response(response, header, value) + return response + + return _inner + + return _decorator diff --git a/common/djangoapps/header_control/middleware.py b/common/djangoapps/header_control/middleware.py new file mode 100644 index 0000000000..c0f118fe81 --- /dev/null +++ b/common/djangoapps/header_control/middleware.py @@ -0,0 +1,34 @@ +""" +Middleware used for adjusting headers in a response before it is sent to the end user. +""" + + +class HeaderControlMiddleware(object): + """ + Middleware that can modify/remove headers in a response. + + This can be used, for example, to remove headers i.e. drop any Vary headers to improve cache performance. + """ + + def process_response(self, _request, response): + """ + Processes the given response, potentially remove or modifying headers. + """ + + if len(getattr(response, 'remove_headers', [])) > 0: + for header in response.remove_headers: + try: + del response[header] + except KeyError: + pass + + if len(getattr(response, 'force_headers', {})) > 0: + for header, value in response.force_headers.iteritems(): + try: + del response[header] + except KeyError: + pass + + response[header] = value + + return response diff --git a/common/djangoapps/header_control/tests/test_decorators.py b/common/djangoapps/header_control/tests/test_decorators.py new file mode 100644 index 0000000000..65b754ca1d --- /dev/null +++ b/common/djangoapps/header_control/tests/test_decorators.py @@ -0,0 +1,32 @@ +"""Tests for remove_headers and force_header decorator. """ +from django.http import HttpResponse, HttpRequest +from django.test import TestCase +from header_control.decorators import remove_headers, force_header + + +def fake_view(_request): + """Fake view that returns an empty response.""" + return HttpResponse() + + +class TestRemoveHeaders(TestCase): + """Test the `remove_headers` decorator.""" + + def test_remove_headers(self): + request = HttpRequest() + wrapper = remove_headers('Vary', 'Accept-Encoding') + wrapped_view = wrapper(fake_view) + response = wrapped_view(request) + self.assertEqual(len(response.remove_headers), 2) + + +class TestForceHeader(TestCase): + """Test the `force_header` decorator.""" + + def test_force_header(self): + request = HttpRequest() + wrapper = force_header('Vary', 'Origin') + wrapped_view = wrapper(fake_view) + response = wrapped_view(request) + self.assertEqual(len(response.force_headers), 1) + self.assertEqual(response.force_headers['Vary'], 'Origin') \ No newline at end of file diff --git a/common/djangoapps/header_control/tests/test_middleware.py b/common/djangoapps/header_control/tests/test_middleware.py new file mode 100644 index 0000000000..db3b77749f --- /dev/null +++ b/common/djangoapps/header_control/tests/test_middleware.py @@ -0,0 +1,47 @@ +"""Tests for header_control middleware.""" +from django.http import HttpResponse, HttpRequest +from django.test import TestCase +from header_control import remove_headers_from_response, force_header_for_response +from header_control.middleware import HeaderControlMiddleware + + +class TestHeaderControlMiddlewareProcessResponse(TestCase): + """Test the `header_control` middleware. """ + def setUp(self): + super(TestHeaderControlMiddlewareProcessResponse, self).setUp() + self.middleware = HeaderControlMiddleware() + + def test_removes_intended_headers(self): + fake_request = HttpRequest() + + fake_response = HttpResponse() + fake_response['Vary'] = 'Cookie' + fake_response['Accept-Encoding'] = 'gzip' + remove_headers_from_response(fake_response, 'Vary') + + result = self.middleware.process_response(fake_request, fake_response) + self.assertNotIn('Vary', result) + self.assertEquals('gzip', result['Accept-Encoding']) + + def test_forces_intended_header(self): + fake_request = HttpRequest() + + fake_response = HttpResponse() + fake_response['Vary'] = 'Cookie' + fake_response['Accept-Encoding'] = 'gzip' + force_header_for_response(fake_response, 'Vary', 'Origin') + + result = self.middleware.process_response(fake_request, fake_response) + self.assertEquals('Origin', result['Vary']) + self.assertEquals('gzip', result['Accept-Encoding']) + + def test_does_not_mangle_undecorated_response(self): + fake_request = HttpRequest() + + fake_response = HttpResponse() + fake_response['Vary'] = 'Cookie' + fake_response['Accept-Encoding'] = 'gzip' + + result = self.middleware.process_response(fake_request, fake_response) + self.assertEquals('Cookie', result['Vary']) + self.assertEquals('gzip', result['Accept-Encoding']) diff --git a/lms/envs/common.py b/lms/envs/common.py index c73eaae451..f14d9d406a 100644 --- a/lms/envs/common.py +++ b/lms/envs/common.py @@ -1087,7 +1087,7 @@ simplefilter('ignore') MIDDLEWARE_CLASSES = ( 'request_cache.middleware.RequestCache', - 'clean_headers.middleware.CleanHeadersMiddleware', + 'header_control.middleware.HeaderControlMiddleware', 'microsite_configuration.middleware.MicrositeMiddleware', 'django_comment_client.middleware.AjaxExceptionMiddleware', 'django.middleware.common.CommonMiddleware', From 1a3464152f29933336edf8f1b4bd61e821360bad Mon Sep 17 00:00:00 2001 From: Toby Lawrence Date: Thu, 3 Mar 2016 14:25:29 -0500 Subject: [PATCH 2/2] Switch to header_control in contentserver. --- cms/envs/common.py | 2 +- common/djangoapps/contentserver/middleware.py | 7 +- .../contentserver/test/test_contentserver.py | 87 ++++++++++--------- common/djangoapps/header_control/__init__.py | 1 + .../djangoapps/header_control/decorators.py | 1 + .../djangoapps/header_control/middleware.py | 18 +--- .../header_control/tests/test_decorators.py | 2 +- .../header_control/tests/test_middleware.py | 23 +++++ 8 files changed, 80 insertions(+), 61 deletions(-) diff --git a/cms/envs/common.py b/cms/envs/common.py index 40dd6c5ac0..fcb455643b 100644 --- a/cms/envs/common.py +++ b/cms/envs/common.py @@ -311,7 +311,7 @@ simplefilter('ignore') MIDDLEWARE_CLASSES = ( 'request_cache.middleware.RequestCache', - 'clean_headers.middleware.CleanHeadersMiddleware', + 'header_control.middleware.HeaderControlMiddleware', 'django.middleware.cache.UpdateCacheMiddleware', 'django.middleware.common.CommonMiddleware', 'django.middleware.csrf.CsrfViewMiddleware', diff --git a/common/djangoapps/contentserver/middleware.py b/common/djangoapps/contentserver/middleware.py index 61528c214b..433967e1cb 100644 --- a/common/djangoapps/contentserver/middleware.py +++ b/common/djangoapps/contentserver/middleware.py @@ -11,7 +11,7 @@ from django.http import ( from student.models import CourseEnrollment from contentserver.models import CourseAssetCacheTtlConfig -from clean_headers import remove_headers_from_response +from header_control import force_header_for_response from xmodule.assetstore.assetmgr import AssetManager from xmodule.contentstore.content import StaticContent, XASSET_LOCATION_TAG from xmodule.modulestore import InvalidLocationError @@ -153,7 +153,10 @@ class StaticContentServer(object): response['Last-Modified'] = content.last_modified_at.strftime(HTTP_DATE_FORMAT) - remove_headers_from_response(response, "Vary") + # Force the Vary header to only vary responses on Origin, so that XHR and browser requests get cached + # separately and don't screw over one another. i.e. a browser request that doesn't send Origin, and + # caches a version of the response without CORS headers, in turn breaking XHR requests. + force_header_for_response(response, 'Vary', 'Origin') @staticmethod def get_expiration_value(now, cache_ttl): diff --git a/common/djangoapps/contentserver/test/test_contentserver.py b/common/djangoapps/contentserver/test/test_contentserver.py index 576640759f..d34e714b1f 100644 --- a/common/djangoapps/contentserver/test/test_contentserver.py +++ b/common/djangoapps/contentserver/test/test_contentserver.py @@ -16,12 +16,13 @@ from mock import patch from xmodule.contentstore.django import contentstore from xmodule.modulestore.django import modulestore -from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase +from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase from xmodule.modulestore import ModuleStoreEnum from xmodule.modulestore.xml_importer import import_course_from_xml from contentserver.middleware import parse_range_header, HTTP_DATE_FORMAT, StaticContentServer from student.models import CourseEnrollment +from student.tests.factories import UserFactory, AdminFactory log = logging.getLogger(__name__) @@ -33,39 +34,44 @@ TEST_DATA_DIR = settings.COMMON_TEST_DATA_ROOT @ddt.ddt @override_settings(CONTENTSTORE=TEST_DATA_CONTENTSTORE) -class ContentStoreToyCourseTest(ModuleStoreTestCase): +class ContentStoreToyCourseTest(SharedModuleStoreTestCase): """ Tests that use the toy course. """ + @classmethod + def setUpClass(cls): + super(ContentStoreToyCourseTest, cls).setUpClass() + + cls.contentstore = contentstore() + cls.modulestore = modulestore() + + cls.course_key = cls.modulestore.make_course_key('edX', 'toy', '2012_Fall') + + import_course_from_xml( + cls.modulestore, 1, TEST_DATA_DIR, ['toy'], + static_content_store=cls.contentstore, verbose=True + ) + + # A locked asset + cls.locked_asset = cls.course_key.make_asset_key('asset', 'sample_static.txt') + cls.url_locked = unicode(cls.locked_asset) + cls.contentstore.set_attr(cls.locked_asset, 'locked', True) + + # An unlocked asset + cls.unlocked_asset = cls.course_key.make_asset_key('asset', 'another_static.txt') + cls.url_unlocked = unicode(cls.unlocked_asset) + cls.length_unlocked = cls.contentstore.get_attr(cls.unlocked_asset, 'length') + def setUp(self): """ Create user and login. """ - self.staff_pwd = super(ContentStoreToyCourseTest, self).setUp() - self.staff_usr = self.user - self.non_staff_usr, self.non_staff_pwd = self.create_non_staff_user() + super(ContentStoreToyCourseTest, self).setUp() + self.staff_usr = AdminFactory.create() + self.non_staff_usr = UserFactory.create() self.client = Client() - self.contentstore = contentstore() - store = modulestore()._get_modulestore_by_type(ModuleStoreEnum.Type.mongo) # pylint: disable=protected-access - - self.course_key = store.make_course_key('edX', 'toy', '2012_Fall') - - import_course_from_xml( - store, self.user.id, TEST_DATA_DIR, ['toy'], - static_content_store=self.contentstore, verbose=True - ) - - # A locked asset - self.locked_asset = self.course_key.make_asset_key('asset', 'sample_static.txt') - self.url_locked = unicode(self.locked_asset) - self.contentstore.set_attr(self.locked_asset, 'locked', True) - - # An unlocked asset - self.unlocked_asset = self.course_key.make_asset_key('asset', 'another_static.txt') - self.url_unlocked = unicode(self.unlocked_asset) - self.length_unlocked = self.contentstore.get_attr(self.unlocked_asset, 'length') def test_unlocked_asset(self): """ @@ -89,7 +95,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): Test that locked assets behave appropriately in case user is logged in in but not registered for the course. """ - self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd) + self.client.login(username=self.non_staff_usr, password='test') resp = self.client.get(self.url_locked) self.assertEqual(resp.status_code, 403) @@ -101,7 +107,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): CourseEnrollment.enroll(self.non_staff_usr, self.course_key) self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key)) - self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd) + self.client.login(username=self.non_staff_usr, password='test') resp = self.client.get(self.url_locked) self.assertEqual(resp.status_code, 200) @@ -109,7 +115,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): """ Test that locked assets behave appropriately in case user is staff. """ - self.client.login(username=self.staff_usr, password=self.staff_pwd) + self.client.login(username=self.staff_usr, password='test') resp = self.client.get(self.url_locked) self.assertEqual(resp.status_code, 200) @@ -191,6 +197,15 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): first=(self.length_unlocked), last=(self.length_unlocked))) self.assertEqual(resp.status_code, 416) + def test_vary_header_sent(self): + """ + Tests that we're properly setting the Vary header to ensure browser requests don't get + cached in a way that breaks XHR requests to the same asset. + """ + resp = self.client.get(self.url_unlocked) + self.assertEqual(resp.status_code, 200) + self.assertEquals('Origin', resp['Vary']) + @patch('contentserver.models.CourseAssetCacheTtlConfig.get_cache_ttl') def test_cache_headers_with_ttl_unlocked(self, mock_get_cache_ttl): """ @@ -215,7 +230,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): CourseEnrollment.enroll(self.non_staff_usr, self.course_key) self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key)) - self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd) + self.client.login(username=self.non_staff_usr, password='test') resp = self.client.get(self.url_locked) self.assertEqual(resp.status_code, 200) self.assertNotIn('Expires', resp) @@ -245,7 +260,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): CourseEnrollment.enroll(self.non_staff_usr, self.course_key) self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key)) - self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd) + self.client.login(username=self.non_staff_usr, password='test') resp = self.client.get(self.url_locked) self.assertEqual(resp.status_code, 200) self.assertNotIn('Expires', resp) @@ -256,20 +271,6 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): near_expire_dt = StaticContentServer.get_expiration_value(start_dt, 55) self.assertEqual("Thu, 01 Dec 1983 20:00:55 GMT", near_expire_dt) - def test_response_no_vary_header_unlocked(self): - resp = self.client.get(self.url_unlocked) - self.assertEqual(resp.status_code, 200) - self.assertNotIn('Vary', resp) - - def test_response_no_vary_header_locked(self): - CourseEnrollment.enroll(self.non_staff_usr, self.course_key) - self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key)) - - self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd) - resp = self.client.get(self.url_locked) - self.assertEqual(resp.status_code, 200) - self.assertNotIn('Vary', resp) - @ddt.ddt class ParseRangeHeaderTestCase(unittest.TestCase): diff --git a/common/djangoapps/header_control/__init__.py b/common/djangoapps/header_control/__init__.py index a33b8d2de9..ad2208736d 100644 --- a/common/djangoapps/header_control/__init__.py +++ b/common/djangoapps/header_control/__init__.py @@ -11,6 +11,7 @@ def remove_headers_from_response(response, *headers): """Removes the given headers from the response using the header_control middleware.""" response.remove_headers = headers + def force_header_for_response(response, header, value): """Forces the given header for the given response using the header_control middleware.""" force_headers = {} diff --git a/common/djangoapps/header_control/decorators.py b/common/djangoapps/header_control/decorators.py index 6a3ca9c221..cb0d148ad6 100644 --- a/common/djangoapps/header_control/decorators.py +++ b/common/djangoapps/header_control/decorators.py @@ -5,6 +5,7 @@ Middleware decorator for removing headers. from functools import wraps from header_control import remove_headers_from_response, force_header_for_response + def remove_headers(*headers): """ Decorator that removes specific headers from the response. diff --git a/common/djangoapps/header_control/middleware.py b/common/djangoapps/header_control/middleware.py index c0f118fe81..9798c846e2 100644 --- a/common/djangoapps/header_control/middleware.py +++ b/common/djangoapps/header_control/middleware.py @@ -15,20 +15,10 @@ class HeaderControlMiddleware(object): Processes the given response, potentially remove or modifying headers. """ - if len(getattr(response, 'remove_headers', [])) > 0: - for header in response.remove_headers: - try: - del response[header] - except KeyError: - pass + for header in getattr(response, 'remove_headers', []): + del response[header] - if len(getattr(response, 'force_headers', {})) > 0: - for header, value in response.force_headers.iteritems(): - try: - del response[header] - except KeyError: - pass - - response[header] = value + for header, value in getattr(response, 'force_headers', {}).iteritems(): + response[header] = value return response diff --git a/common/djangoapps/header_control/tests/test_decorators.py b/common/djangoapps/header_control/tests/test_decorators.py index 65b754ca1d..cf5c1afbb0 100644 --- a/common/djangoapps/header_control/tests/test_decorators.py +++ b/common/djangoapps/header_control/tests/test_decorators.py @@ -29,4 +29,4 @@ class TestForceHeader(TestCase): wrapped_view = wrapper(fake_view) response = wrapped_view(request) self.assertEqual(len(response.force_headers), 1) - self.assertEqual(response.force_headers['Vary'], 'Origin') \ No newline at end of file + self.assertEqual(response.force_headers['Vary'], 'Origin') diff --git a/common/djangoapps/header_control/tests/test_middleware.py b/common/djangoapps/header_control/tests/test_middleware.py index db3b77749f..34e5e503c5 100644 --- a/common/djangoapps/header_control/tests/test_middleware.py +++ b/common/djangoapps/header_control/tests/test_middleware.py @@ -11,6 +11,29 @@ class TestHeaderControlMiddlewareProcessResponse(TestCase): super(TestHeaderControlMiddlewareProcessResponse, self).setUp() self.middleware = HeaderControlMiddleware() + def test_doesnt_barf_if_not_modifying_anything(self): + fake_request = HttpRequest() + + fake_response = HttpResponse() + fake_response['Vary'] = 'Cookie' + fake_response['Accept-Encoding'] = 'gzip' + + result = self.middleware.process_response(fake_request, fake_response) + self.assertEquals('Cookie', result['Vary']) + self.assertEquals('gzip', result['Accept-Encoding']) + + def test_doesnt_barf_removing_nonexistent_headers(self): + fake_request = HttpRequest() + + fake_response = HttpResponse() + fake_response['Vary'] = 'Cookie' + fake_response['Accept-Encoding'] = 'gzip' + remove_headers_from_response(fake_response, 'Vary', 'FakeHeaderWeeee') + + result = self.middleware.process_response(fake_request, fake_response) + self.assertNotIn('Vary', result) + self.assertEquals('gzip', result['Accept-Encoding']) + def test_removes_intended_headers(self): fake_request = HttpRequest()