From 680f62278cd57514649e013f8b549a82a557fd1d Mon Sep 17 00:00:00 2001 From: "David J. Malan" Date: Tue, 1 Oct 2019 15:57:21 -0400 Subject: [PATCH] AccessTokenView: support for X-Token-Type in HTTP header (#21662) * AccessTokenView: support for X-Token-Type in HTTP header --- .../oauth_dispatch/tests/test_views.py | 44 ++++++++++++------- .../core/djangoapps/oauth_dispatch/views.py | 5 ++- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/openedx/core/djangoapps/oauth_dispatch/tests/test_views.py b/openedx/core/djangoapps/oauth_dispatch/tests/test_views.py index 25e4585e28..cf40d9d418 100644 --- a/openedx/core/djangoapps/oauth_dispatch/tests/test_views.py +++ b/openedx/core/djangoapps/oauth_dispatch/tests/test_views.py @@ -123,12 +123,12 @@ class _DispatchingViewTestCase(TestCase): ) models.RestrictedApplication.objects.create(application=self.restricted_dot_app) - def _post_request(self, user, client, token_type=None, scope=None): + def _post_request(self, user, client, token_type=None, scope=None, headers=None): """ Call the view with a POST request object with the appropriate format, returning the response object. """ - return self.client.post(self.url, self._post_body(user, client, token_type, scope)) # pylint: disable=no-member + return self.client.post(self.url, self._post_body(user, client, token_type, scope), **(headers or {})) # pylint: disable=no-member def _post_body(self, user, client, token_type=None, scope=None): """ @@ -186,6 +186,23 @@ class TestAccessTokenView(AccessTokenLoginMixin, mixins.AccessTokenMixin, _Dispa return serialized_public_keys_json, serialized_keypair_json + def _test_jwt_access_token(self, client_attr, token_type=None, headers=None): + """ + Test response for JWT token. + """ + client = getattr(self, client_attr) + response = self._post_request(self.user, client, token_type=token_type, headers=headers or {}) + self.assertEqual(response.status_code, 200) + data = json.loads(response.content.decode('utf-8')) + self.assertIn('expires_in', data) + self.assertEqual(data['token_type'], 'JWT') + self.assert_valid_jwt_access_token( + data['access_token'], + self.user, + data['scope'].split(' '), + should_be_restricted=False, + ) + @ddt.data('dop_app', 'dot_app') def test_access_token_fields(self, client_attr): client = getattr(self, client_attr) @@ -217,19 +234,16 @@ class TestAccessTokenView(AccessTokenLoginMixin, mixins.AccessTokenMixin, _Dispa ) @ddt.data('dop_app', 'dot_app') - def test_jwt_access_token(self, client_attr): - client = getattr(self, client_attr) - response = self._post_request(self.user, client, token_type='jwt') - self.assertEqual(response.status_code, 200) - data = json.loads(response.content.decode('utf-8')) - self.assertIn('expires_in', data) - self.assertEqual(data['token_type'], 'JWT') - self.assert_valid_jwt_access_token( - data['access_token'], - self.user, - data['scope'].split(' '), - should_be_restricted=False, - ) + def test_jwt_access_token_from_parameter(self, client_attr): + self._test_jwt_access_token(client_attr, token_type='jwt') + + @ddt.data('dop_app', 'dot_app') + def test_jwt_access_token_from_header(self, client_attr): + self._test_jwt_access_token(client_attr, headers={'HTTP_X_TOKEN_TYPE': 'jwt'}) + + @ddt.data('dop_app', 'dot_app') + def test_jwt_access_token_from_parameter_not_header(self, client_attr): + self._test_jwt_access_token(client_attr, token_type='jwt', headers={'HTTP_X_TOKEN_TYPE': 'invalid'}) @ddt.data( ('jwt', 'jwt'), diff --git a/openedx/core/djangoapps/oauth_dispatch/views.py b/openedx/core/djangoapps/oauth_dispatch/views.py index d890ab8cae..4d5a99261d 100644 --- a/openedx/core/djangoapps/oauth_dispatch/views.py +++ b/openedx/core/djangoapps/oauth_dispatch/views.py @@ -96,10 +96,11 @@ class AccessTokenView(RatelimitMixin, _DispatchingView): ratelimit_block = True ratelimit_method = ALL - def dispatch(self, request, *args, **kwargs): + def dispatch(self, request, *args, **kwargs): # pylint: disable=arguments-differ response = super(AccessTokenView, self).dispatch(request, *args, **kwargs) - token_type = request.POST.get('token_type', 'no_token_type_supplied').lower() + token_type = request.POST.get('token_type', + request.META.get('HTTP_X_TOKEN_TYPE', 'no_token_type_supplied')).lower() monitoring_utils.set_custom_metric('oauth_token_type', token_type) monitoring_utils.set_custom_metric('oauth_grant_type', request.POST.get('grant_type', ''))