AccessTokenView: support for X-Token-Type in HTTP header (#21662)

* AccessTokenView: support for X-Token-Type in HTTP header
This commit is contained in:
David J. Malan
2019-10-01 15:57:21 -04:00
committed by Nimisha Asthagiri
parent 62300b922c
commit 680f62278c
2 changed files with 32 additions and 17 deletions

View File

@@ -123,12 +123,12 @@ class _DispatchingViewTestCase(TestCase):
)
models.RestrictedApplication.objects.create(application=self.restricted_dot_app)
def _post_request(self, user, client, token_type=None, scope=None):
def _post_request(self, user, client, token_type=None, scope=None, headers=None):
"""
Call the view with a POST request object with the appropriate format,
returning the response object.
"""
return self.client.post(self.url, self._post_body(user, client, token_type, scope)) # pylint: disable=no-member
return self.client.post(self.url, self._post_body(user, client, token_type, scope), **(headers or {})) # pylint: disable=no-member
def _post_body(self, user, client, token_type=None, scope=None):
"""
@@ -186,6 +186,23 @@ class TestAccessTokenView(AccessTokenLoginMixin, mixins.AccessTokenMixin, _Dispa
return serialized_public_keys_json, serialized_keypair_json
def _test_jwt_access_token(self, client_attr, token_type=None, headers=None):
"""
Test response for JWT token.
"""
client = getattr(self, client_attr)
response = self._post_request(self.user, client, token_type=token_type, headers=headers or {})
self.assertEqual(response.status_code, 200)
data = json.loads(response.content.decode('utf-8'))
self.assertIn('expires_in', data)
self.assertEqual(data['token_type'], 'JWT')
self.assert_valid_jwt_access_token(
data['access_token'],
self.user,
data['scope'].split(' '),
should_be_restricted=False,
)
@ddt.data('dop_app', 'dot_app')
def test_access_token_fields(self, client_attr):
client = getattr(self, client_attr)
@@ -217,19 +234,16 @@ class TestAccessTokenView(AccessTokenLoginMixin, mixins.AccessTokenMixin, _Dispa
)
@ddt.data('dop_app', 'dot_app')
def test_jwt_access_token(self, client_attr):
client = getattr(self, client_attr)
response = self._post_request(self.user, client, token_type='jwt')
self.assertEqual(response.status_code, 200)
data = json.loads(response.content.decode('utf-8'))
self.assertIn('expires_in', data)
self.assertEqual(data['token_type'], 'JWT')
self.assert_valid_jwt_access_token(
data['access_token'],
self.user,
data['scope'].split(' '),
should_be_restricted=False,
)
def test_jwt_access_token_from_parameter(self, client_attr):
self._test_jwt_access_token(client_attr, token_type='jwt')
@ddt.data('dop_app', 'dot_app')
def test_jwt_access_token_from_header(self, client_attr):
self._test_jwt_access_token(client_attr, headers={'HTTP_X_TOKEN_TYPE': 'jwt'})
@ddt.data('dop_app', 'dot_app')
def test_jwt_access_token_from_parameter_not_header(self, client_attr):
self._test_jwt_access_token(client_attr, token_type='jwt', headers={'HTTP_X_TOKEN_TYPE': 'invalid'})
@ddt.data(
('jwt', 'jwt'),

View File

@@ -96,10 +96,11 @@ class AccessTokenView(RatelimitMixin, _DispatchingView):
ratelimit_block = True
ratelimit_method = ALL
def dispatch(self, request, *args, **kwargs):
def dispatch(self, request, *args, **kwargs): # pylint: disable=arguments-differ
response = super(AccessTokenView, self).dispatch(request, *args, **kwargs)
token_type = request.POST.get('token_type', 'no_token_type_supplied').lower()
token_type = request.POST.get('token_type',
request.META.get('HTTP_X_TOKEN_TYPE', 'no_token_type_supplied')).lower()
monitoring_utils.set_custom_metric('oauth_token_type', token_type)
monitoring_utils.set_custom_metric('oauth_grant_type', request.POST.get('grant_type', ''))