From 6941fcd76684613be2f0336efdf31d345a7bfbbd Mon Sep 17 00:00:00 2001 From: Clinton Blackburn Date: Tue, 19 Apr 2016 18:10:39 -0400 Subject: [PATCH] Updated access token view to return a JWT as an access token The JWT includes the user email and username, along with details pulled from the original access token (e.g. scope, expiration). ECOM-4221 --- lms/djangoapps/ccx/api/v0/tests/test_views.py | 6 +- lms/djangoapps/ccx/api/v0/views.py | 5 +- lms/djangoapps/oauth_dispatch/adapters/dot.py | 9 +- .../oauth_dispatch/dot_overrides.py | 23 +++++ lms/djangoapps/oauth_dispatch/tests/mixins.py | 41 ++++++++ .../tests/test_client_credentials.py | 42 ++++++-- .../oauth_dispatch/tests/test_views.py | 32 +++++-- lms/djangoapps/oauth_dispatch/views.py | 95 +++++++++++++++++-- lms/envs/common.py | 8 +- 9 files changed, 233 insertions(+), 28 deletions(-) diff --git a/lms/djangoapps/ccx/api/v0/tests/test_views.py b/lms/djangoapps/ccx/api/v0/tests/test_views.py index 43d7e9b03a..843085397e 100644 --- a/lms/djangoapps/ccx/api/v0/tests/test_views.py +++ b/lms/djangoapps/ccx/api/v0/tests/test_views.py @@ -111,8 +111,10 @@ class CcxRestApiTest(CcxTestCase, APITestCase): token_resp = self.client.post('/oauth2/access_token/', data=token_data) self.assertEqual(token_resp.status_code, status.HTTP_200_OK) token_resp_json = json.loads(token_resp.content) - self.assertIn('access_token', token_resp_json) - return 'Bearer {0}'.format(token_resp_json.get('access_token')) + return '{token_type} {token}'.format( + token_type=token_resp_json['token_type'], + token=token_resp_json['access_token'] + ) def expect_error(self, http_code, error_code_str, resp_obj): """ diff --git a/lms/djangoapps/ccx/api/v0/views.py b/lms/djangoapps/ccx/api/v0/views.py index 1e3d14de47..fa01df1f2c 100644 --- a/lms/djangoapps/ccx/api/v0/views.py +++ b/lms/djangoapps/ccx/api/v0/views.py @@ -17,6 +17,7 @@ from rest_framework_oauth.authentication import OAuth2Authentication from ccx_keys.locator import CCXLocator from courseware import courses +from edx_rest_framework_extensions.authentication import JwtAuthentication from instructor.enrollment import ( enroll_email, get_email_params, @@ -361,7 +362,7 @@ class CCXListView(GenericAPIView): ] } """ - authentication_classes = (OAuth2Authentication, SessionAuthentication,) + authentication_classes = (JwtAuthentication, OAuth2Authentication, SessionAuthentication,) permission_classes = (IsAuthenticated, permissions.IsMasterCourseStaffInstructor) serializer_class = CCXCourseSerializer pagination_class = CCXAPIPagination @@ -599,7 +600,7 @@ class CCXDetailView(GenericAPIView): response is returned. """ - authentication_classes = (OAuth2Authentication, SessionAuthentication,) + authentication_classes = (JwtAuthentication, OAuth2Authentication, SessionAuthentication,) permission_classes = (IsAuthenticated, permissions.IsCourseStaffInstructor) serializer_class = CCXCourseSerializer diff --git a/lms/djangoapps/oauth_dispatch/adapters/dot.py b/lms/djangoapps/oauth_dispatch/adapters/dot.py index 84dcb7ece4..e1ac9705db 100644 --- a/lms/djangoapps/oauth_dispatch/adapters/dot.py +++ b/lms/djangoapps/oauth_dispatch/adapters/dot.py @@ -12,7 +12,12 @@ class DOTAdapter(object): backend = object() - def create_confidential_client(self, name, user, redirect_uri, client_id=None): + def create_confidential_client(self, + name, + user, + redirect_uri, + client_id=None, + authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE): """ Create an oauth client application that is confidential. """ @@ -21,7 +26,7 @@ class DOTAdapter(object): user=user, client_id=client_id, client_type=models.Application.CLIENT_CONFIDENTIAL, - authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE, + authorization_grant_type=authorization_grant_type, redirect_uris=redirect_uri, ) diff --git a/lms/djangoapps/oauth_dispatch/dot_overrides.py b/lms/djangoapps/oauth_dispatch/dot_overrides.py index 98203d6cde..2e71bc316b 100644 --- a/lms/djangoapps/oauth_dispatch/dot_overrides.py +++ b/lms/djangoapps/oauth_dispatch/dot_overrides.py @@ -1,6 +1,7 @@ """ Classes that override default django-oauth-toolkit behavior """ +from __future__ import unicode_literals from django.contrib.auth import authenticate, get_user_model from oauth2_provider.oauth2_validators import OAuth2Validator @@ -41,3 +42,25 @@ class EdxOAuth2Validator(OAuth2Validator): else: authenticated_user = authenticate(username=email_user.username, password=password) return authenticated_user + + def save_bearer_token(self, token, request, *args, **kwargs): + """ + Ensure that access tokens issued via client credentials grant are associated with the owner of the + ``Application``. + """ + grant_type = request.grant_type + user = request.user + + if grant_type == 'client_credentials': + # Temporarily remove the grant type to avoid triggering the super method's code that removes request.user. + request.grant_type = None + + # Ensure the tokens get associated with the correct user since DOT does not normally + # associate access tokens issued with the client_credentials grant to users. + request.user = request.client.user + + super(EdxOAuth2Validator, self).save_bearer_token(token, request, *args, **kwargs) + + # Restore the original request attributes + request.grant_type = grant_type + request.user = user diff --git a/lms/djangoapps/oauth_dispatch/tests/mixins.py b/lms/djangoapps/oauth_dispatch/tests/mixins.py index 8b16c53fc2..3c2acfe149 100644 --- a/lms/djangoapps/oauth_dispatch/tests/mixins.py +++ b/lms/djangoapps/oauth_dispatch/tests/mixins.py @@ -1,3 +1,44 @@ """ OAuth Dispatch test mixins """ +import jwt +from django.conf import settings + + +class AccessTokenMixin(object): + """ Mixin for tests dealing with OAuth 2 access tokens. """ + + def assert_valid_jwt_access_token(self, access_token, user, scopes=None): + """ + Verify the specified JWT access token is valid, and belongs to the specified user. + + Args: + access_token (str): JWT + user (User): User whose information is contained in the JWT payload. + + Returns: + dict: Decoded JWT payload + """ + scopes = scopes or [] + audience = settings.JWT_AUTH['JWT_AUDIENCE'] + issuer = settings.JWT_AUTH['JWT_ISSUER'] + payload = jwt.decode( + access_token, + settings.JWT_AUTH['JWT_SECRET_KEY'], + algorithms=[settings.JWT_AUTH['JWT_ALGORITHM']], + audience=audience, + issuer=issuer + ) + + expected = { + 'aud': audience, + 'iss': issuer, + 'preferred_username': user.username, + } + + if 'email' in scopes: + expected['email'] = user.email + + self.assertDictContainsSubset(expected, payload) + + return payload diff --git a/lms/djangoapps/oauth_dispatch/tests/test_client_credentials.py b/lms/djangoapps/oauth_dispatch/tests/test_client_credentials.py index 749f712c0c..027c4eebfa 100644 --- a/lms/djangoapps/oauth_dispatch/tests/test_client_credentials.py +++ b/lms/djangoapps/oauth_dispatch/tests/test_client_credentials.py @@ -4,33 +4,37 @@ import json from django.core.urlresolvers import reverse from django.test import TestCase from edx_oauth2_provider.tests.factories import ClientFactory +from oauth2_provider.models import Application from provider.oauth2.models import AccessToken from student.tests.factories import UserFactory +from . import mixins +from .constants import DUMMY_REDIRECT_URL +from ..adapters import DOTAdapter -class ClientCredentialsTest(TestCase): + +class ClientCredentialsTest(mixins.AccessTokenMixin, TestCase): """ Tests validating the client credentials grant behavior. """ def setUp(self): super(ClientCredentialsTest, self).setUp() - self.user = UserFactory() - self.oauth_client = ClientFactory(user=self.user) def test_access_token(self): """ Verify the client credentials grant can be used to obtain an access token whose default scopes allow access to the user info endpoint. """ + oauth_client = ClientFactory(user=self.user) data = { 'grant_type': 'client_credentials', - 'client_id': self.oauth_client.client_id, - 'client_secret': self.oauth_client.client_secret + 'client_id': oauth_client.client_id, + 'client_secret': oauth_client.client_secret } response = self.client.post(reverse('oauth2:access_token'), data) self.assertEqual(response.status_code, 200) access_token = json.loads(response.content)['access_token'] - expected = AccessToken.objects.filter(client=self.oauth_client, user=self.user).first().token + expected = AccessToken.objects.filter(client=oauth_client, user=self.user).first().token self.assertEqual(access_token, expected) headers = { @@ -38,3 +42,29 @@ class ClientCredentialsTest(TestCase): } response = self.client.get(reverse('oauth2:user_info'), **headers) self.assertEqual(response.status_code, 200) + + def test_jwt_access_token(self): + """ Verify the client credentials grant can be used to obtain a JWT access token. """ + application = DOTAdapter().create_confidential_client( + name='test dot application', + user=self.user, + authorization_grant_type=Application.GRANT_CLIENT_CREDENTIALS, + redirect_uri=DUMMY_REDIRECT_URL, + client_id='dot-app-client-id', + ) + scopes = ('read', 'write', 'email') + data = { + 'grant_type': 'client_credentials', + 'client_id': application.client_id, + 'client_secret': application.client_secret, + 'scope': ' '.join(scopes), + 'token_type': 'jwt' + } + + response = self.client.post(reverse('access_token'), data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content) + access_token = content['access_token'] + self.assertEqual(content['scope'], data['scope']) + self.assert_valid_jwt_access_token(access_token, self.user, scopes) diff --git a/lms/djangoapps/oauth_dispatch/tests/test_views.py b/lms/djangoapps/oauth_dispatch/tests/test_views.py index b8fb7435ff..6edd6cda35 100644 --- a/lms/djangoapps/oauth_dispatch/tests/test_views.py +++ b/lms/djangoapps/oauth_dispatch/tests/test_views.py @@ -12,9 +12,10 @@ import httpretty from student.tests.factories import UserFactory from third_party_auth.tests.utils import ThirdPartyOAuthTestMixin, ThirdPartyOAuthTestMixinGoogle +from .constants import DUMMY_REDIRECT_URL from .. import adapters from .. import views -from .constants import DUMMY_REDIRECT_URL +from . import mixins class _DispatchingViewTestCase(TestCase): @@ -43,14 +44,14 @@ class _DispatchingViewTestCase(TestCase): client_id='dop-app-client-id', ) - def _post_request(self, user, client): + def _post_request(self, user, client, token_type=None): """ Call the view with a POST request objectwith the appropriate format, returning the response object. """ - return self.client.post(self.url, self._post_body(user, client)) + return self.client.post(self.url, self._post_body(user, client, token_type)) - def _post_body(self, user, client): + def _post_body(self, user, client, token_type=None): """ Return a dictionary to be used as the body of the POST request """ @@ -58,7 +59,7 @@ class _DispatchingViewTestCase(TestCase): @ddt.ddt -class TestAccessTokenView(_DispatchingViewTestCase): +class TestAccessTokenView(mixins.AccessTokenMixin, _DispatchingViewTestCase): """ Test class for AccessTokenView """ @@ -66,17 +67,22 @@ class TestAccessTokenView(_DispatchingViewTestCase): view_class = views.AccessTokenView url = reverse('access_token') - def _post_body(self, user, client): + def _post_body(self, user, client, token_type=None): """ Return a dictionary to be used as the body of the POST request """ - return { + body = { 'client_id': client.client_id, 'grant_type': 'password', 'username': user.username, 'password': 'test', } + if token_type: + body['token_type'] = token_type + + return body + @ddt.data('dop_client', 'dot_app') def test_access_token_fields(self, client_attr): client = getattr(self, client_attr) @@ -88,6 +94,16 @@ class TestAccessTokenView(_DispatchingViewTestCase): self.assertIn('scope', data) self.assertIn('token_type', data) + @ddt.data('dop_client', 'dot_app') + def test_jwt_access_token(self, client_attr): + client = getattr(self, client_attr) + response = self._post_request(self.user, client, token_type='jwt') + self.assertEqual(response.status_code, 200) + data = json.loads(response.content) + self.assertIn('expires_in', data) + self.assertEqual(data['token_type'], 'JWT') + self.assert_valid_jwt_access_token(data['access_token'], self.user, data['scope'].split(' ')) + def test_dot_access_token_provides_refresh_token(self): response = self._post_request(self.user, self.dot_app) self.assertEqual(response.status_code, 200) @@ -111,7 +127,7 @@ class TestAccessTokenExchangeView(ThirdPartyOAuthTestMixinGoogle, ThirdPartyOAut view_class = views.AccessTokenExchangeView url = reverse('exchange_access_token', kwargs={'backend': 'google-oauth2'}) - def _post_body(self, user, client): + def _post_body(self, user, client, token_type=None): return { 'client_id': client.client_id, 'access_token': self.access_token, diff --git a/lms/djangoapps/oauth_dispatch/views.py b/lms/djangoapps/oauth_dispatch/views.py index 22652f127b..0a62a0c80f 100644 --- a/lms/djangoapps/oauth_dispatch/views.py +++ b/lms/djangoapps/oauth_dispatch/views.py @@ -5,12 +5,17 @@ django-oauth-toolkit as appropriate. from __future__ import unicode_literals +import json +from time import time + +import jwt +from auth_exchange import views as auth_exchange_views +from django.conf import settings +from django.utils.functional import cached_property from django.views.generic import View from edx_oauth2_provider import views as dop_views # django-oauth2-provider views from oauth2_provider import models as dot_models, views as dot_views # django-oauth-toolkit -from auth_exchange import views as auth_exchange_views - from . import adapters @@ -25,6 +30,15 @@ class _DispatchingView(View): dot_adapter = adapters.DOTAdapter() dop_adapter = adapters.DOPAdapter() + def get_adapter(self, request): + """ + Returns the appropriate adapter based on the OAuth client linked to the request. + """ + if dot_models.Application.objects.filter(client_id=self._get_client_id(request)).exists(): + return self.dot_adapter + else: + return self.dop_adapter + def dispatch(self, request, *args, **kwargs): """ Dispatch the request to the selected backend's view. @@ -41,11 +55,7 @@ class _DispatchingView(View): otherwise use the django-oauth2-provider (DOP) adapter, and allow the calls to fail normally if the client does not exist. """ - - if dot_models.Application.objects.filter(client_id=self._get_client_id(request)).exists(): - return self.dot_adapter.backend - else: - return self.dop_adapter.backend + return self.get_adapter(request).backend def get_view_for_backend(self, backend): """ @@ -72,6 +82,77 @@ class AccessTokenView(_DispatchingView): dot_view = dot_views.TokenView dop_view = dop_views.AccessTokenView + @cached_property + def claim_handlers(self): + """ Returns a dictionary mapping scopes to methods that will add claims to the JWT payload. """ + + return { + 'email': self._attach_email_claim, + 'profile': self._attach_profile_claim + } + + def dispatch(self, request, *args, **kwargs): + response = super(AccessTokenView, self).dispatch(request, *args, **kwargs) + + if response.status_code == 200 and request.POST.get('token_type', '').lower() == 'jwt': + expires_in, scopes, user = self._decompose_access_token_response(request, response) + + content = { + 'access_token': self._generate_jwt(user, scopes, expires_in), + 'expires_in': expires_in, + 'token_type': 'JWT', + 'scope': ' '.join(scopes), + } + response.content = json.dumps(content) + + return response + + def _decompose_access_token_response(self, request, response): + """ Decomposes the access token in the request to an expiration date, scopes, and User. """ + content = json.loads(response.content) + access_token = content['access_token'] + scope = content['scope'] + access_token_obj = self.get_adapter(request).get_access_token(access_token) + user = access_token_obj.user + scopes = scope.split(' ') + expires_in = content['expires_in'] + return expires_in, scopes, user + + def _generate_jwt(self, user, scopes, expires_in): + """ Returns a JWT access token. """ + now = int(time()) + + payload = { + 'iss': settings.JWT_AUTH['JWT_ISSUER'], + 'aud': settings.JWT_AUTH['JWT_AUDIENCE'], + 'exp': now + expires_in, + 'iat': now, + 'preferred_username': user.username, + } + + for scope in scopes: + handler = self.claim_handlers.get(scope) + + if handler: + handler(payload, user) + + secret = settings.JWT_AUTH['JWT_SECRET_KEY'] + token = jwt.encode(payload, secret, algorithm=settings.JWT_AUTH['JWT_ALGORITHM']) + + return token + + def _attach_email_claim(self, payload, user): + """ Add the email claim details to the JWT payload. """ + payload['email'] = user.email + + def _attach_profile_claim(self, payload, user): + """ Add the profile claim details to the JWT payload. """ + payload.update({ + 'family_name': user.last_name, + 'name': user.get_full_name(), + 'given_name': user.first_name, + }) + class AuthorizationView(_DispatchingView): """ diff --git a/lms/envs/common.py b/lms/envs/common.py index e6a1d7889f..fae4573e92 100644 --- a/lms/envs/common.py +++ b/lms/envs/common.py @@ -458,7 +458,13 @@ OAUTH_EXPIRE_PUBLIC_CLIENT_DAYS = 30 ################################## DJANGO OAUTH TOOLKIT ####################################### OAUTH2_PROVIDER = { - 'OAUTH2_VALIDATOR_CLASS': 'lms.djangoapps.oauth_dispatch.dot_overrides.EdxOAuth2Validator' + 'OAUTH2_VALIDATOR_CLASS': 'lms.djangoapps.oauth_dispatch.dot_overrides.EdxOAuth2Validator', + 'SCOPES': { + 'read': 'Read scope', + 'write': 'Write scope', + 'email': 'Email scope', + 'profile': 'Profile scope', + } } ################################## TEMPLATE CONFIGURATION #####################################