Updated access token view to return a JWT as an access token
The JWT includes the user email and username, along with details pulled from the original access token (e.g. scope, expiration). ECOM-4221
This commit is contained in:
committed by
Clinton Blackburn
parent
5adf6fec66
commit
6941fcd766
@@ -111,8 +111,10 @@ class CcxRestApiTest(CcxTestCase, APITestCase):
|
||||
token_resp = self.client.post('/oauth2/access_token/', data=token_data)
|
||||
self.assertEqual(token_resp.status_code, status.HTTP_200_OK)
|
||||
token_resp_json = json.loads(token_resp.content)
|
||||
self.assertIn('access_token', token_resp_json)
|
||||
return 'Bearer {0}'.format(token_resp_json.get('access_token'))
|
||||
return '{token_type} {token}'.format(
|
||||
token_type=token_resp_json['token_type'],
|
||||
token=token_resp_json['access_token']
|
||||
)
|
||||
|
||||
def expect_error(self, http_code, error_code_str, resp_obj):
|
||||
"""
|
||||
|
||||
@@ -17,6 +17,7 @@ from rest_framework_oauth.authentication import OAuth2Authentication
|
||||
|
||||
from ccx_keys.locator import CCXLocator
|
||||
from courseware import courses
|
||||
from edx_rest_framework_extensions.authentication import JwtAuthentication
|
||||
from instructor.enrollment import (
|
||||
enroll_email,
|
||||
get_email_params,
|
||||
@@ -361,7 +362,7 @@ class CCXListView(GenericAPIView):
|
||||
]
|
||||
}
|
||||
"""
|
||||
authentication_classes = (OAuth2Authentication, SessionAuthentication,)
|
||||
authentication_classes = (JwtAuthentication, OAuth2Authentication, SessionAuthentication,)
|
||||
permission_classes = (IsAuthenticated, permissions.IsMasterCourseStaffInstructor)
|
||||
serializer_class = CCXCourseSerializer
|
||||
pagination_class = CCXAPIPagination
|
||||
@@ -599,7 +600,7 @@ class CCXDetailView(GenericAPIView):
|
||||
response is returned.
|
||||
"""
|
||||
|
||||
authentication_classes = (OAuth2Authentication, SessionAuthentication,)
|
||||
authentication_classes = (JwtAuthentication, OAuth2Authentication, SessionAuthentication,)
|
||||
permission_classes = (IsAuthenticated, permissions.IsCourseStaffInstructor)
|
||||
serializer_class = CCXCourseSerializer
|
||||
|
||||
|
||||
@@ -12,7 +12,12 @@ class DOTAdapter(object):
|
||||
|
||||
backend = object()
|
||||
|
||||
def create_confidential_client(self, name, user, redirect_uri, client_id=None):
|
||||
def create_confidential_client(self,
|
||||
name,
|
||||
user,
|
||||
redirect_uri,
|
||||
client_id=None,
|
||||
authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE):
|
||||
"""
|
||||
Create an oauth client application that is confidential.
|
||||
"""
|
||||
@@ -21,7 +26,7 @@ class DOTAdapter(object):
|
||||
user=user,
|
||||
client_id=client_id,
|
||||
client_type=models.Application.CLIENT_CONFIDENTIAL,
|
||||
authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE,
|
||||
authorization_grant_type=authorization_grant_type,
|
||||
redirect_uris=redirect_uri,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Classes that override default django-oauth-toolkit behavior
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.contrib.auth import authenticate, get_user_model
|
||||
from oauth2_provider.oauth2_validators import OAuth2Validator
|
||||
@@ -41,3 +42,25 @@ class EdxOAuth2Validator(OAuth2Validator):
|
||||
else:
|
||||
authenticated_user = authenticate(username=email_user.username, password=password)
|
||||
return authenticated_user
|
||||
|
||||
def save_bearer_token(self, token, request, *args, **kwargs):
|
||||
"""
|
||||
Ensure that access tokens issued via client credentials grant are associated with the owner of the
|
||||
``Application``.
|
||||
"""
|
||||
grant_type = request.grant_type
|
||||
user = request.user
|
||||
|
||||
if grant_type == 'client_credentials':
|
||||
# Temporarily remove the grant type to avoid triggering the super method's code that removes request.user.
|
||||
request.grant_type = None
|
||||
|
||||
# Ensure the tokens get associated with the correct user since DOT does not normally
|
||||
# associate access tokens issued with the client_credentials grant to users.
|
||||
request.user = request.client.user
|
||||
|
||||
super(EdxOAuth2Validator, self).save_bearer_token(token, request, *args, **kwargs)
|
||||
|
||||
# Restore the original request attributes
|
||||
request.grant_type = grant_type
|
||||
request.user = user
|
||||
|
||||
@@ -1,3 +1,44 @@
|
||||
"""
|
||||
OAuth Dispatch test mixins
|
||||
"""
|
||||
import jwt
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
class AccessTokenMixin(object):
|
||||
""" Mixin for tests dealing with OAuth 2 access tokens. """
|
||||
|
||||
def assert_valid_jwt_access_token(self, access_token, user, scopes=None):
|
||||
"""
|
||||
Verify the specified JWT access token is valid, and belongs to the specified user.
|
||||
|
||||
Args:
|
||||
access_token (str): JWT
|
||||
user (User): User whose information is contained in the JWT payload.
|
||||
|
||||
Returns:
|
||||
dict: Decoded JWT payload
|
||||
"""
|
||||
scopes = scopes or []
|
||||
audience = settings.JWT_AUTH['JWT_AUDIENCE']
|
||||
issuer = settings.JWT_AUTH['JWT_ISSUER']
|
||||
payload = jwt.decode(
|
||||
access_token,
|
||||
settings.JWT_AUTH['JWT_SECRET_KEY'],
|
||||
algorithms=[settings.JWT_AUTH['JWT_ALGORITHM']],
|
||||
audience=audience,
|
||||
issuer=issuer
|
||||
)
|
||||
|
||||
expected = {
|
||||
'aud': audience,
|
||||
'iss': issuer,
|
||||
'preferred_username': user.username,
|
||||
}
|
||||
|
||||
if 'email' in scopes:
|
||||
expected['email'] = user.email
|
||||
|
||||
self.assertDictContainsSubset(expected, payload)
|
||||
|
||||
return payload
|
||||
|
||||
@@ -4,33 +4,37 @@ import json
|
||||
from django.core.urlresolvers import reverse
|
||||
from django.test import TestCase
|
||||
from edx_oauth2_provider.tests.factories import ClientFactory
|
||||
from oauth2_provider.models import Application
|
||||
from provider.oauth2.models import AccessToken
|
||||
from student.tests.factories import UserFactory
|
||||
|
||||
from . import mixins
|
||||
from .constants import DUMMY_REDIRECT_URL
|
||||
from ..adapters import DOTAdapter
|
||||
|
||||
class ClientCredentialsTest(TestCase):
|
||||
|
||||
class ClientCredentialsTest(mixins.AccessTokenMixin, TestCase):
|
||||
""" Tests validating the client credentials grant behavior. """
|
||||
|
||||
def setUp(self):
|
||||
super(ClientCredentialsTest, self).setUp()
|
||||
|
||||
self.user = UserFactory()
|
||||
self.oauth_client = ClientFactory(user=self.user)
|
||||
|
||||
def test_access_token(self):
|
||||
""" Verify the client credentials grant can be used to obtain an access token whose default scopes allow access
|
||||
to the user info endpoint.
|
||||
"""
|
||||
oauth_client = ClientFactory(user=self.user)
|
||||
data = {
|
||||
'grant_type': 'client_credentials',
|
||||
'client_id': self.oauth_client.client_id,
|
||||
'client_secret': self.oauth_client.client_secret
|
||||
'client_id': oauth_client.client_id,
|
||||
'client_secret': oauth_client.client_secret
|
||||
}
|
||||
response = self.client.post(reverse('oauth2:access_token'), data)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
access_token = json.loads(response.content)['access_token']
|
||||
expected = AccessToken.objects.filter(client=self.oauth_client, user=self.user).first().token
|
||||
expected = AccessToken.objects.filter(client=oauth_client, user=self.user).first().token
|
||||
self.assertEqual(access_token, expected)
|
||||
|
||||
headers = {
|
||||
@@ -38,3 +42,29 @@ class ClientCredentialsTest(TestCase):
|
||||
}
|
||||
response = self.client.get(reverse('oauth2:user_info'), **headers)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_jwt_access_token(self):
|
||||
""" Verify the client credentials grant can be used to obtain a JWT access token. """
|
||||
application = DOTAdapter().create_confidential_client(
|
||||
name='test dot application',
|
||||
user=self.user,
|
||||
authorization_grant_type=Application.GRANT_CLIENT_CREDENTIALS,
|
||||
redirect_uri=DUMMY_REDIRECT_URL,
|
||||
client_id='dot-app-client-id',
|
||||
)
|
||||
scopes = ('read', 'write', 'email')
|
||||
data = {
|
||||
'grant_type': 'client_credentials',
|
||||
'client_id': application.client_id,
|
||||
'client_secret': application.client_secret,
|
||||
'scope': ' '.join(scopes),
|
||||
'token_type': 'jwt'
|
||||
}
|
||||
|
||||
response = self.client.post(reverse('access_token'), data)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
content = json.loads(response.content)
|
||||
access_token = content['access_token']
|
||||
self.assertEqual(content['scope'], data['scope'])
|
||||
self.assert_valid_jwt_access_token(access_token, self.user, scopes)
|
||||
|
||||
@@ -12,9 +12,10 @@ import httpretty
|
||||
from student.tests.factories import UserFactory
|
||||
from third_party_auth.tests.utils import ThirdPartyOAuthTestMixin, ThirdPartyOAuthTestMixinGoogle
|
||||
|
||||
from .constants import DUMMY_REDIRECT_URL
|
||||
from .. import adapters
|
||||
from .. import views
|
||||
from .constants import DUMMY_REDIRECT_URL
|
||||
from . import mixins
|
||||
|
||||
|
||||
class _DispatchingViewTestCase(TestCase):
|
||||
@@ -43,14 +44,14 @@ class _DispatchingViewTestCase(TestCase):
|
||||
client_id='dop-app-client-id',
|
||||
)
|
||||
|
||||
def _post_request(self, user, client):
|
||||
def _post_request(self, user, client, token_type=None):
|
||||
"""
|
||||
Call the view with a POST request objectwith the appropriate format,
|
||||
returning the response object.
|
||||
"""
|
||||
return self.client.post(self.url, self._post_body(user, client))
|
||||
return self.client.post(self.url, self._post_body(user, client, token_type))
|
||||
|
||||
def _post_body(self, user, client):
|
||||
def _post_body(self, user, client, token_type=None):
|
||||
"""
|
||||
Return a dictionary to be used as the body of the POST request
|
||||
"""
|
||||
@@ -58,7 +59,7 @@ class _DispatchingViewTestCase(TestCase):
|
||||
|
||||
|
||||
@ddt.ddt
|
||||
class TestAccessTokenView(_DispatchingViewTestCase):
|
||||
class TestAccessTokenView(mixins.AccessTokenMixin, _DispatchingViewTestCase):
|
||||
"""
|
||||
Test class for AccessTokenView
|
||||
"""
|
||||
@@ -66,17 +67,22 @@ class TestAccessTokenView(_DispatchingViewTestCase):
|
||||
view_class = views.AccessTokenView
|
||||
url = reverse('access_token')
|
||||
|
||||
def _post_body(self, user, client):
|
||||
def _post_body(self, user, client, token_type=None):
|
||||
"""
|
||||
Return a dictionary to be used as the body of the POST request
|
||||
"""
|
||||
return {
|
||||
body = {
|
||||
'client_id': client.client_id,
|
||||
'grant_type': 'password',
|
||||
'username': user.username,
|
||||
'password': 'test',
|
||||
}
|
||||
|
||||
if token_type:
|
||||
body['token_type'] = token_type
|
||||
|
||||
return body
|
||||
|
||||
@ddt.data('dop_client', 'dot_app')
|
||||
def test_access_token_fields(self, client_attr):
|
||||
client = getattr(self, client_attr)
|
||||
@@ -88,6 +94,16 @@ class TestAccessTokenView(_DispatchingViewTestCase):
|
||||
self.assertIn('scope', data)
|
||||
self.assertIn('token_type', data)
|
||||
|
||||
@ddt.data('dop_client', 'dot_app')
|
||||
def test_jwt_access_token(self, client_attr):
|
||||
client = getattr(self, client_attr)
|
||||
response = self._post_request(self.user, client, token_type='jwt')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
data = json.loads(response.content)
|
||||
self.assertIn('expires_in', data)
|
||||
self.assertEqual(data['token_type'], 'JWT')
|
||||
self.assert_valid_jwt_access_token(data['access_token'], self.user, data['scope'].split(' '))
|
||||
|
||||
def test_dot_access_token_provides_refresh_token(self):
|
||||
response = self._post_request(self.user, self.dot_app)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
@@ -111,7 +127,7 @@ class TestAccessTokenExchangeView(ThirdPartyOAuthTestMixinGoogle, ThirdPartyOAut
|
||||
view_class = views.AccessTokenExchangeView
|
||||
url = reverse('exchange_access_token', kwargs={'backend': 'google-oauth2'})
|
||||
|
||||
def _post_body(self, user, client):
|
||||
def _post_body(self, user, client, token_type=None):
|
||||
return {
|
||||
'client_id': client.client_id,
|
||||
'access_token': self.access_token,
|
||||
|
||||
@@ -5,12 +5,17 @@ django-oauth-toolkit as appropriate.
|
||||
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import json
|
||||
from time import time
|
||||
|
||||
import jwt
|
||||
from auth_exchange import views as auth_exchange_views
|
||||
from django.conf import settings
|
||||
from django.utils.functional import cached_property
|
||||
from django.views.generic import View
|
||||
from edx_oauth2_provider import views as dop_views # django-oauth2-provider views
|
||||
from oauth2_provider import models as dot_models, views as dot_views # django-oauth-toolkit
|
||||
|
||||
from auth_exchange import views as auth_exchange_views
|
||||
|
||||
from . import adapters
|
||||
|
||||
|
||||
@@ -25,6 +30,15 @@ class _DispatchingView(View):
|
||||
dot_adapter = adapters.DOTAdapter()
|
||||
dop_adapter = adapters.DOPAdapter()
|
||||
|
||||
def get_adapter(self, request):
|
||||
"""
|
||||
Returns the appropriate adapter based on the OAuth client linked to the request.
|
||||
"""
|
||||
if dot_models.Application.objects.filter(client_id=self._get_client_id(request)).exists():
|
||||
return self.dot_adapter
|
||||
else:
|
||||
return self.dop_adapter
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
"""
|
||||
Dispatch the request to the selected backend's view.
|
||||
@@ -41,11 +55,7 @@ class _DispatchingView(View):
|
||||
otherwise use the django-oauth2-provider (DOP) adapter, and allow the
|
||||
calls to fail normally if the client does not exist.
|
||||
"""
|
||||
|
||||
if dot_models.Application.objects.filter(client_id=self._get_client_id(request)).exists():
|
||||
return self.dot_adapter.backend
|
||||
else:
|
||||
return self.dop_adapter.backend
|
||||
return self.get_adapter(request).backend
|
||||
|
||||
def get_view_for_backend(self, backend):
|
||||
"""
|
||||
@@ -72,6 +82,77 @@ class AccessTokenView(_DispatchingView):
|
||||
dot_view = dot_views.TokenView
|
||||
dop_view = dop_views.AccessTokenView
|
||||
|
||||
@cached_property
|
||||
def claim_handlers(self):
|
||||
""" Returns a dictionary mapping scopes to methods that will add claims to the JWT payload. """
|
||||
|
||||
return {
|
||||
'email': self._attach_email_claim,
|
||||
'profile': self._attach_profile_claim
|
||||
}
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
response = super(AccessTokenView, self).dispatch(request, *args, **kwargs)
|
||||
|
||||
if response.status_code == 200 and request.POST.get('token_type', '').lower() == 'jwt':
|
||||
expires_in, scopes, user = self._decompose_access_token_response(request, response)
|
||||
|
||||
content = {
|
||||
'access_token': self._generate_jwt(user, scopes, expires_in),
|
||||
'expires_in': expires_in,
|
||||
'token_type': 'JWT',
|
||||
'scope': ' '.join(scopes),
|
||||
}
|
||||
response.content = json.dumps(content)
|
||||
|
||||
return response
|
||||
|
||||
def _decompose_access_token_response(self, request, response):
|
||||
""" Decomposes the access token in the request to an expiration date, scopes, and User. """
|
||||
content = json.loads(response.content)
|
||||
access_token = content['access_token']
|
||||
scope = content['scope']
|
||||
access_token_obj = self.get_adapter(request).get_access_token(access_token)
|
||||
user = access_token_obj.user
|
||||
scopes = scope.split(' ')
|
||||
expires_in = content['expires_in']
|
||||
return expires_in, scopes, user
|
||||
|
||||
def _generate_jwt(self, user, scopes, expires_in):
|
||||
""" Returns a JWT access token. """
|
||||
now = int(time())
|
||||
|
||||
payload = {
|
||||
'iss': settings.JWT_AUTH['JWT_ISSUER'],
|
||||
'aud': settings.JWT_AUTH['JWT_AUDIENCE'],
|
||||
'exp': now + expires_in,
|
||||
'iat': now,
|
||||
'preferred_username': user.username,
|
||||
}
|
||||
|
||||
for scope in scopes:
|
||||
handler = self.claim_handlers.get(scope)
|
||||
|
||||
if handler:
|
||||
handler(payload, user)
|
||||
|
||||
secret = settings.JWT_AUTH['JWT_SECRET_KEY']
|
||||
token = jwt.encode(payload, secret, algorithm=settings.JWT_AUTH['JWT_ALGORITHM'])
|
||||
|
||||
return token
|
||||
|
||||
def _attach_email_claim(self, payload, user):
|
||||
""" Add the email claim details to the JWT payload. """
|
||||
payload['email'] = user.email
|
||||
|
||||
def _attach_profile_claim(self, payload, user):
|
||||
""" Add the profile claim details to the JWT payload. """
|
||||
payload.update({
|
||||
'family_name': user.last_name,
|
||||
'name': user.get_full_name(),
|
||||
'given_name': user.first_name,
|
||||
})
|
||||
|
||||
|
||||
class AuthorizationView(_DispatchingView):
|
||||
"""
|
||||
|
||||
@@ -458,7 +458,13 @@ OAUTH_EXPIRE_PUBLIC_CLIENT_DAYS = 30
|
||||
################################## DJANGO OAUTH TOOLKIT #######################################
|
||||
|
||||
OAUTH2_PROVIDER = {
|
||||
'OAUTH2_VALIDATOR_CLASS': 'lms.djangoapps.oauth_dispatch.dot_overrides.EdxOAuth2Validator'
|
||||
'OAUTH2_VALIDATOR_CLASS': 'lms.djangoapps.oauth_dispatch.dot_overrides.EdxOAuth2Validator',
|
||||
'SCOPES': {
|
||||
'read': 'Read scope',
|
||||
'write': 'Write scope',
|
||||
'email': 'Email scope',
|
||||
'profile': 'Profile scope',
|
||||
}
|
||||
}
|
||||
|
||||
################################## TEMPLATE CONFIGURATION #####################################
|
||||
|
||||
Reference in New Issue
Block a user