diff --git a/cms/envs/common.py b/cms/envs/common.py index 954ae13380..2beb6e6925 100644 --- a/cms/envs/common.py +++ b/cms/envs/common.py @@ -864,11 +864,14 @@ INSTALLED_APPS = ( # Self-paced course configuration 'openedx.core.djangoapps.self_paced', - # OAuth2 Provider + # django-oauth2-provider (deprecated) 'provider', 'provider.oauth2', 'edx_oauth2_provider', + # django-oauth-toolkit + 'oauth2_provider', + # These are apps that aren't strictly needed by Studio, but are imported by # other apps that are. Django 1.8 wants to have imported models supported # by installed apps. diff --git a/common/djangoapps/auth_exchange/forms.py b/common/djangoapps/auth_exchange/forms.py index 81cd9da183..8caf61799d 100644 --- a/common/djangoapps/auth_exchange/forms.py +++ b/common/djangoapps/auth_exchange/forms.py @@ -8,6 +8,7 @@ import provider.constants from provider.forms import OAuthForm, OAuthValidationError from provider.oauth2.forms import ScopeChoiceField, ScopeMixin from provider.oauth2.models import Client +from oauth2_provider.models import Application from requests import HTTPError from social.backends import oauth as social_oauth from social.exceptions import AuthException @@ -21,9 +22,10 @@ class AccessTokenExchangeForm(ScopeMixin, OAuthForm): scope = ScopeChoiceField(choices=SCOPE_NAMES, required=False) client_id = CharField(required=False) - def __init__(self, request, *args, **kwargs): + def __init__(self, request, oauth2_adapter, *args, **kwargs): super(AccessTokenExchangeForm, self).__init__(*args, **kwargs) self.request = request + self.oauth2_adapter = oauth2_adapter def _require_oauth_field(self, field_name): """ @@ -68,15 +70,15 @@ class AccessTokenExchangeForm(ScopeMixin, OAuthForm): client_id = self.cleaned_data["client_id"] try: - client = Client.objects.get(client_id=client_id) - except Client.DoesNotExist: + client = self.oauth2_adapter.get_client(client_id=client_id) + except (Client.DoesNotExist, Application.DoesNotExist): raise OAuthValidationError( { "error": "invalid_client", "error_description": "{} is not a valid client_id".format(client_id), } ) - if client.client_type != provider.constants.PUBLIC: + if client.client_type not in [provider.constants.PUBLIC, Application.CLIENT_PUBLIC]: raise OAuthValidationError( { # invalid_client isn't really the right code, but this mirrors diff --git a/common/djangoapps/auth_exchange/tests/mixins.py b/common/djangoapps/auth_exchange/tests/mixins.py new file mode 100644 index 0000000000..450f789657 --- /dev/null +++ b/common/djangoapps/auth_exchange/tests/mixins.py @@ -0,0 +1,111 @@ +""" +Mixins to facilitate testing OAuth connections to Django-OAuth-Toolkit or +Django-OAuth2-Provider. +""" + +# pylint: disable=protected-access + +from unittest import skip, expectedFailure +from django.test.client import RequestFactory + +from lms.djangoapps.oauth_dispatch import adapters +from lms.djangoapps.oauth_dispatch.tests.constants import DUMMY_REDIRECT_URL + +from ..views import DOTAccessTokenExchangeView + + +class DOPAdapterMixin(object): + """ + Mixin to rewire existing tests to use django-oauth2-provider (DOP) backend + + Overwrites self.client_id, self.access_token, self.oauth2_adapter + """ + client_id = 'dop_test_client_id' + access_token = 'dop_test_access_token' + oauth2_adapter = adapters.DOPAdapter() + + def create_public_client(self, user, client_id=None): + """ + Create an oauth client application that is public. + """ + return self.oauth2_adapter.create_public_client( + name='Test Public Client', + user=user, + client_id=client_id, + redirect_uri=DUMMY_REDIRECT_URL, + ) + + def create_confidential_client(self, user, client_id=None): + """ + Create an oauth client application that is confidential. + """ + return self.oauth2_adapter.create_confidential_client( + name='Test Confidential Client', + user=user, + client_id=client_id, + redirect_uri=DUMMY_REDIRECT_URL, + ) + + def get_token_response_keys(self): + """ + Return the set of keys provided when requesting an access token + """ + return {'access_token', 'token_type', 'expires_in', 'scope'} + + +class DOTAdapterMixin(object): + """ + Mixin to rewire existing tests to use django-oauth-toolkit (DOT) backend + + Overwrites self.client_id, self.access_token, self.oauth2_adapter + """ + + client_id = 'dot_test_client_id' + access_token = 'dot_test_access_token' + oauth2_adapter = adapters.DOTAdapter() + + def create_public_client(self, user, client_id=None): + """ + Create an oauth client application that is public. + """ + return self.oauth2_adapter.create_public_client( + name='Test Public Application', + user=user, + client_id=client_id, + redirect_uri=DUMMY_REDIRECT_URL, + ) + + def create_confidential_client(self, user, client_id=None): + """ + Create an oauth client application that is confidential. + """ + return self.oauth2_adapter.create_confidential_client( + name='Test Confidential Application', + user=user, + client_id=client_id, + redirect_uri=DUMMY_REDIRECT_URL, + ) + + def get_token_response_keys(self): + """ + Return the set of keys provided when requesting an access token + """ + return {'access_token', 'refresh_token', 'token_type', 'expires_in', 'scope'} + + def test_get_method(self): + # Dispatch routes all get methods to DOP, so we test this on the view + request_factory = RequestFactory() + request = request_factory.get('/oauth2/exchange_access_token/') + request.session = {} + view = DOTAccessTokenExchangeView.as_view() + response = view(request, backend='facebook') + self.assertEqual(response.status_code, 400) + + @expectedFailure + def test_single_access_token(self): + # TODO: Single access tokens not supported yet for DOT (See MA-2122) + super(DOTAdapterMixin, self).test_single_access_token() + + @skip("Not supported yet (See MA-2123)") + def test_scopes(self): + super(DOTAdapterMixin, self).test_scopes() diff --git a/common/djangoapps/auth_exchange/tests/test_forms.py b/common/djangoapps/auth_exchange/tests/test_forms.py index ffeffb4e22..490628f868 100644 --- a/common/djangoapps/auth_exchange/tests/test_forms.py +++ b/common/djangoapps/auth_exchange/tests/test_forms.py @@ -12,10 +12,12 @@ import httpretty from provider import scope import social.apps.django_app.utils as social_utils -from auth_exchange.forms import AccessTokenExchangeForm -from auth_exchange.tests.utils import AccessTokenExchangeTestMixin from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle +from ..forms import AccessTokenExchangeForm +from .utils import AccessTokenExchangeTestMixin +from .mixins import DOPAdapterMixin, DOTAdapterMixin + class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin): """ @@ -31,7 +33,7 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin): self.request.backend = social_utils.load_backend(self.request.social_strategy, self.BACKEND, redirect_uri) def _assert_error(self, data, expected_error, expected_error_description): - form = AccessTokenExchangeForm(request=self.request, data=data) + form = AccessTokenExchangeForm(request=self.request, oauth2_adapter=self.oauth2_adapter, data=data) self.assertEqual( form.errors, {"error": expected_error, "error_description": expected_error_description} @@ -39,7 +41,7 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin): self.assertNotIn("partial_pipeline", self.request.session) def _assert_success(self, data, expected_scopes): - form = AccessTokenExchangeForm(request=self.request, data=data) + form = AccessTokenExchangeForm(request=self.request, oauth2_adapter=self.oauth2_adapter, data=data) self.assertTrue(form.is_valid()) self.assertEqual(form.cleaned_data["user"], self.user) self.assertEqual(form.cleaned_data["client"], self.oauth_client) @@ -49,13 +51,15 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin): # This is necessary because cms does not implement third party auth @unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @httpretty.activate -class AccessTokenExchangeFormTestFacebook( +class DOPAccessTokenExchangeFormTestFacebook( + DOPAdapterMixin, AccessTokenExchangeFormTest, ThirdPartyOAuthTestMixinFacebook, - TestCase + TestCase, ): """ - Tests for AccessTokenExchangeForm used with Facebook + Tests for AccessTokenExchangeForm used with Facebook, tested against + django-oauth2-provider (DOP). """ pass @@ -63,12 +67,46 @@ class AccessTokenExchangeFormTestFacebook( # This is necessary because cms does not implement third party auth @unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @httpretty.activate -class AccessTokenExchangeFormTestGoogle( +class DOTAccessTokenExchangeFormTestFacebook( + DOTAdapterMixin, AccessTokenExchangeFormTest, - ThirdPartyOAuthTestMixinGoogle, - TestCase + ThirdPartyOAuthTestMixinFacebook, + TestCase, ): """ - Tests for AccessTokenExchangeForm used with Google + Tests for AccessTokenExchangeForm used with Facebook, tested against + django-oauth-toolkit (DOT). + """ + pass + + +# This is necessary because cms does not implement third party auth +@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") +@httpretty.activate +class DOPAccessTokenExchangeFormTestGoogle( + DOPAdapterMixin, + AccessTokenExchangeFormTest, + ThirdPartyOAuthTestMixinGoogle, + TestCase, +): + """ + Tests for AccessTokenExchangeForm used with Google, tested against + django-oauth2-provider (DOP). + """ + pass + + +# This is necessary because cms does not implement third party auth +@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") +@httpretty.activate +class DOTAccessTokenExchangeFormTestGoogle( + DOTAdapterMixin, + AccessTokenExchangeFormTest, + ThirdPartyOAuthTestMixinGoogle, + TestCase, +): + """ + Tests for AccessTokenExchangeForm used with Google, tested against + django-oauth-toolkit (DOT). """ pass diff --git a/common/djangoapps/auth_exchange/tests/test_views.py b/common/djangoapps/auth_exchange/tests/test_views.py index 5c0b7503f9..60c687960c 100644 --- a/common/djangoapps/auth_exchange/tests/test_views.py +++ b/common/djangoapps/auth_exchange/tests/test_views.py @@ -1,25 +1,30 @@ -# pylint: disable=no-member """ Tests for OAuth token exchange views """ + +# pylint: disable=no-member + from datetime import timedelta import json import mock import unittest +import ddt from django.conf import settings from django.core.urlresolvers import reverse from django.test import TestCase import httpretty import provider.constants -from provider import scope from provider.oauth2.models import AccessToken, Client +from rest_framework.test import APIClient -from auth_exchange.tests.utils import AccessTokenExchangeTestMixin from student.tests.factories import UserFactory from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle +from .mixins import DOPAdapterMixin, DOTAdapterMixin +from .utils import AccessTokenExchangeTestMixin +@ddt.ddt class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin): """ Mixin that defines test cases for AccessTokenExchangeView @@ -27,33 +32,34 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin): def setUp(self): super(AccessTokenExchangeViewTest, self).setUp() self.url = reverse("exchange_access_token", kwargs={"backend": self.BACKEND}) + self.csrf_client = APIClient(enforce_csrf_checks=True) def _assert_error(self, data, expected_error, expected_error_description): - response = self.client.post(self.url, data) + response = self.csrf_client.post(self.url, data) self.assertEqual(response.status_code, 400) self.assertEqual(response["Content-Type"], "application/json") self.assertEqual( json.loads(response.content), - {"error": expected_error, "error_description": expected_error_description} + {u"error": expected_error, u"error_description": expected_error_description} ) self.assertNotIn("partial_pipeline", self.client.session) def _assert_success(self, data, expected_scopes): - response = self.client.post(self.url, data) + response = self.csrf_client.post(self.url, data) self.assertEqual(response.status_code, 200) self.assertEqual(response["Content-Type"], "application/json") content = json.loads(response.content) - self.assertEqual(set(content.keys()), {"access_token", "token_type", "expires_in", "scope"}) + self.assertEqual(set(content.keys()), self.get_token_response_keys()) self.assertEqual(content["token_type"], "Bearer") self.assertLessEqual( timedelta(seconds=int(content["expires_in"])), provider.constants.EXPIRE_DELTA_PUBLIC ) - self.assertEqual(content["scope"], " ".join(expected_scopes)) - token = AccessToken.objects.get(token=content["access_token"]) + self.assertEqual(content["scope"], self.oauth2_adapter.normalize_scopes(expected_scopes)) + token = self.oauth2_adapter.get_access_token(token_string=content["access_token"]) self.assertEqual(token.user, self.user) - self.assertEqual(token.client, self.oauth_client) - self.assertEqual(scope.to_names(token.scope), expected_scopes) + self.assertEqual(self.oauth2_adapter.get_client_for_token(token), self.oauth_client) + self.assertEqual(self.oauth2_adapter.get_token_scope_names(token), expected_scopes) def test_single_access_token(self): def extract_token(response): @@ -64,16 +70,15 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin): self._setup_provider_response(success=True) for single_access_token in [True, False]: - with mock.patch( - "auth_exchange.views.constants.SINGLE_ACCESS_TOKEN", - single_access_token - ): + with mock.patch("auth_exchange.views.constants.SINGLE_ACCESS_TOKEN", single_access_token): first_response = self.client.post(self.url, self.data) second_response = self.client.post(self.url, self.data) - self.assertEqual( - extract_token(first_response) == extract_token(second_response), - single_access_token - ) + self.assertEqual(first_response.status_code, 200) + self.assertEqual(second_response.status_code, 200) + self.assertEqual( + extract_token(first_response) == extract_token(second_response), + single_access_token + ) def test_get_method(self): response = self.client.get(self.url, self.data) @@ -95,10 +100,11 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin): # This is necessary because cms does not implement third party auth @unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @httpretty.activate -class AccessTokenExchangeViewTestFacebook( +class DOPAccessTokenExchangeViewTestFacebook( + DOPAdapterMixin, AccessTokenExchangeViewTest, ThirdPartyOAuthTestMixinFacebook, - TestCase + TestCase, ): """ Tests for AccessTokenExchangeView used with Facebook @@ -106,16 +112,48 @@ class AccessTokenExchangeViewTestFacebook( pass +@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") +@httpretty.activate +class DOTAccessTokenExchangeViewTestFacebook( + DOTAdapterMixin, + AccessTokenExchangeViewTest, + ThirdPartyOAuthTestMixinFacebook, + TestCase, +): + """ + Rerun AccessTokenExchangeViewTestFacebook tests against DOT backend + """ + pass + + # This is necessary because cms does not implement third party auth @unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @httpretty.activate -class AccessTokenExchangeViewTestGoogle( +class DOPAccessTokenExchangeViewTestGoogle( + DOPAdapterMixin, AccessTokenExchangeViewTest, ThirdPartyOAuthTestMixinGoogle, - TestCase + TestCase, ): """ - Tests for AccessTokenExchangeView used with Google + Tests for AccessTokenExchangeView used with Google using + django-oauth2-provider backend. + """ + pass + + +# This is necessary because cms does not implement third party auth +@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") +@httpretty.activate +class DOTAccessTokenExchangeViewTestGoogle( + DOTAdapterMixin, + AccessTokenExchangeViewTest, + ThirdPartyOAuthTestMixinGoogle, + TestCase, +): + """ + Tests for AccessTokenExchangeView used with Google using + django-oauth-toolkit backend. """ pass diff --git a/common/djangoapps/auth_exchange/tests/utils.py b/common/djangoapps/auth_exchange/tests/utils.py index 60608f292f..2629557c08 100644 --- a/common/djangoapps/auth_exchange/tests/utils.py +++ b/common/djangoapps/auth_exchange/tests/utils.py @@ -1,9 +1,8 @@ """ Test utilities for OAuth access token exchange """ -import provider.constants -from social.apps.django_app.default.models import UserSocialAuth +from social.apps.django_app.default.models import UserSocialAuth from third_party_auth.tests.utils import ThirdPartyOAuthTestMixin @@ -37,6 +36,12 @@ class AccessTokenExchangeTestMixin(ThirdPartyOAuthTestMixin): """ raise NotImplementedError() + def _create_client(self): + """ + Create an oauth2 client application using class defaults. + """ + return self.create_public_client(self.user, self.client_id) + def test_minimal(self): self._setup_provider_response(success=True) self._assert_success(self.data, expected_scopes=[]) @@ -61,12 +66,12 @@ class AccessTokenExchangeTestMixin(ThirdPartyOAuthTestMixin): ) def test_confidential_client(self): - self.oauth_client.client_type = provider.constants.CONFIDENTIAL - self.oauth_client.save() + self.data['client_id'] += '_confidential' + self.oauth_client = self.create_confidential_client(self.user, self.data['client_id']) self._assert_error( self.data, "invalid_client", - "test_client_id is not a public client" + "{}_confidential is not a public client".format(self.client_id), ) def test_inactive_user(self): diff --git a/common/djangoapps/auth_exchange/views.py b/common/djangoapps/auth_exchange/views.py index abd31de9ec..6669e489c5 100644 --- a/common/djangoapps/auth_exchange/views.py +++ b/common/djangoapps/auth_exchange/views.py @@ -1,4 +1,3 @@ -# pylint: disable=abstract-method """ Views to support exchange of authentication credentials. The following are currently implemented: @@ -7,36 +6,52 @@ The following are currently implemented: 2. LoginWithAccessTokenView: 1st party (open-edx) OAuth 2.0 access token -> session cookie """ + +# pylint: disable=abstract-method + from django.conf import settings from django.contrib.auth import login import django.contrib.auth as auth from django.http import HttpResponse from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt +from edx_oauth2_provider.constants import SCOPE_VALUE_DICT +from oauth2_provider.settings import oauth2_settings +from oauth2_provider.views.base import TokenView as DOTAccessTokenView +from oauthlib.oauth2.rfc6749.tokens import BearerToken from provider import constants -from provider.oauth2.views import AccessTokenView as AccessTokenView +from provider.oauth2.views import AccessTokenView as DOPAccessTokenView from rest_framework import permissions +from rest_framework.response import Response from rest_framework.views import APIView import social.apps.django_app.utils as social_utils from auth_exchange.forms import AccessTokenExchangeForm +from lms.djangoapps.oauth_dispatch import adapters from openedx.core.lib.api.authentication import OAuth2AuthenticationAllowInactiveUser -class AccessTokenExchangeView(AccessTokenView): +class AccessTokenExchangeBase(APIView): """ - View for token exchange from 3rd party OAuth access token to 1st party OAuth access token + View for token exchange from 3rd party OAuth access token to 1st party + OAuth access token. """ @method_decorator(csrf_exempt) @method_decorator(social_utils.strategy("social:complete")) def dispatch(self, *args, **kwargs): - return super(AccessTokenExchangeView, self).dispatch(*args, **kwargs) + return super(AccessTokenExchangeBase, self).dispatch(*args, **kwargs) def get(self, request, _backend): # pylint: disable=arguments-differ - return super(AccessTokenExchangeView, self).get(request) + """ + Pass through GET requests without the _backend + """ + return super(AccessTokenExchangeBase, self).get(request) def post(self, request, _backend): # pylint: disable=arguments-differ - form = AccessTokenExchangeForm(request=request, data=request.POST) + """ + Handle POST requests to get a first-party access token. + """ + form = AccessTokenExchangeForm(request=request, oauth2_adapter=self.oauth2_adapter, data=request.POST) # pylint: disable=no-member if not form.is_valid(): return self.error_response(form.errors) @@ -44,12 +59,89 @@ class AccessTokenExchangeView(AccessTokenView): scope = form.cleaned_data["scope"] client = form.cleaned_data["client"] + return self.exchange_access_token(request, user, scope, client) + + def exchange_access_token(self, request, user, scope, client): + """ + Exchange third party credentials for an edx access token, and return a + serialized access token response. + """ if constants.SINGLE_ACCESS_TOKEN: - edx_access_token = self.get_access_token(request, user, scope, client) + edx_access_token = self.get_access_token(request, user, scope, client) # pylint: disable=no-member else: edx_access_token = self.create_access_token(request, user, scope, client) + return self.access_token_response(edx_access_token) # pylint: disable=no-member - return self.access_token_response(edx_access_token) + +class DOPAccessTokenExchangeView(AccessTokenExchangeBase, DOPAccessTokenView): + """ + View for token exchange from 3rd party OAuth access token to 1st party + OAuth access token. Uses django-oauth2-provider (DOP) to manage access + tokens. + """ + + oauth2_adapter = adapters.DOPAdapter() + + +class DOTAccessTokenExchangeView(AccessTokenExchangeBase, DOTAccessTokenView): + """ + View for token exchange from 3rd party OAuth access token to 1st party + OAuth access token. Uses django-oauth-toolkit (DOT) to manage access + tokens. + """ + + oauth2_adapter = adapters.DOTAdapter() + + def get(self, request, _backend): + return Response(status=400, data={ + 'error': 'invalid_request', + 'error_description': 'Only POST requests allowed.', + }) + + def get_access_token(self, request, user, scope, client): + """ + TODO: MA-2122: Reusing access tokens is not yet supported for DOT. + Just return a new access token. + """ + return self.create_access_token(request, user, scope, client) + + def create_access_token(self, request, user, scope, client): + """ + Create and return a new access token. + """ + _days = 24 * 60 * 60 + token_generator = BearerToken( + expires_in=settings.OAUTH_EXPIRE_PUBLIC_CLIENT_DAYS * _days, + request_validator=oauth2_settings.OAUTH2_VALIDATOR_CLASS(), + ) + self._populate_create_access_token_request(request, user, scope, client) + return token_generator.create_token(request, refresh_token=True) + + def access_token_response(self, token): + """ + Wrap an access token in an appropriate response + """ + return Response(data=token) + + def _populate_create_access_token_request(self, request, user, scope, client): + """ + django-oauth-toolkit expects certain non-standard attributes to + be present on the request object. This function modifies the + request object to match these expectations + """ + request.user = user + request.scopes = [SCOPE_VALUE_DICT[scope]] + request.client = client + request.state = None + request.refresh_token = None + request.extra_credentials = None + request.grant_type = client.authorization_grant_type + + def error_response(self, form_errors): + """ + Return an error response consisting of the errors in the form + """ + return Response(status=400, data=form_errors) class LoginWithAccessTokenView(APIView): diff --git a/common/djangoapps/third_party_auth/tests/utils.py b/common/djangoapps/third_party_auth/tests/utils.py index cce2edd59b..588d23dce9 100644 --- a/common/djangoapps/third_party_auth/tests/utils.py +++ b/common/djangoapps/third_party_auth/tests/utils.py @@ -22,23 +22,30 @@ class ThirdPartyOAuthTestMixin(ThirdPartyAuthTestMixin): USER_URL: The URL of the endpoint that the backend retrieves user data from UID_FIELD: The field in the user data that the backend uses as the user id """ + social_uid = "test_social_uid" + access_token = "test_access_token" + client_id = "test_client_id" + def setUp(self, create_user=True): super(ThirdPartyOAuthTestMixin, self).setUp() - self.social_uid = "test_social_uid" - self.access_token = "test_access_token" - self.client_id = "test_client_id" - self.oauth_client = Client.objects.create( - client_id=self.client_id, - client_type=PUBLIC - ) if create_user: self.user = UserFactory() UserSocialAuth.objects.create(user=self.user, provider=self.BACKEND, uid=self.social_uid) + self.oauth_client = self._create_client() if self.BACKEND == 'google-oauth2': self.configure_google_provider(enabled=True) elif self.BACKEND == 'facebook': self.configure_facebook_provider(enabled=True) + def _create_client(self): + """ + Create an OAuth2 client application + """ + return Client.objects.create( + client_id=self.client_id, + client_type=PUBLIC, + ) + def _setup_provider_response(self, success=False, email=''): """ Register a mock response for the third party user information endpoint; @@ -65,7 +72,7 @@ class ThirdPartyOAuthTestMixin(ThirdPartyAuthTestMixin): self.USER_URL, body=body, status=status, - content_type="application/json" + content_type="application/json", ) diff --git a/common/test/db_cache/lettuce.db b/common/test/db_cache/lettuce.db index 9a9cf04105..c90f206e86 100644 Binary files a/common/test/db_cache/lettuce.db and b/common/test/db_cache/lettuce.db differ diff --git a/lms/djangoapps/oauth_dispatch/__init__.py b/lms/djangoapps/oauth_dispatch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lms/djangoapps/oauth_dispatch/adapters/__init__.py b/lms/djangoapps/oauth_dispatch/adapters/__init__.py new file mode 100644 index 0000000000..1f6227b50a --- /dev/null +++ b/lms/djangoapps/oauth_dispatch/adapters/__init__.py @@ -0,0 +1,7 @@ +""" +Adapters to provide a common interface to django-oauth2-provider (DOP) and +django-oauth-toolkit (DOT). +""" + +from .dop import DOPAdapter +from .dot import DOTAdapter diff --git a/lms/djangoapps/oauth_dispatch/adapters/dop.py b/lms/djangoapps/oauth_dispatch/adapters/dop.py new file mode 100644 index 0000000000..b12994b96f --- /dev/null +++ b/lms/djangoapps/oauth_dispatch/adapters/dop.py @@ -0,0 +1,70 @@ +""" +Adapter to isolate django-oauth2-provider dependencies +""" + +from provider.oauth2 import models +from provider import constants, scope + + +class DOPAdapter(object): + """ + Standard interface for working with django-oauth2-provider + """ + + backend = object() + + def create_confidential_client(self, name, user, redirect_uri, client_id=None): + """ + Create an oauth client application that is confidential. + """ + return models.Client.objects.create( + name=name, + user=user, + client_id=client_id, + redirect_uri=redirect_uri, + client_type=constants.CONFIDENTIAL, + ) + + def create_public_client(self, name, user, redirect_uri, client_id=None): + """ + Create an oauth client application that is public. + """ + return models.Client.objects.create( + name=name, + user=user, + client_id=client_id, + redirect_uri=redirect_uri, + client_type=constants.PUBLIC, + ) + + def get_client(self, **filters): + """ + Get the oauth client application with the specified filters. + + Wraps django's queryset.get() method. + """ + return models.Client.objects.get(**filters) + + def get_client_for_token(self, token): + """ + Given an AccessToken object, return the associated client application. + """ + return token.client + + def get_access_token(self, token_string): + """ + Given a token string, return the matching AccessToken object. + """ + return models.AccessToken.objects.get(token=token_string) + + def normalize_scopes(self, scopes): + """ + Given a list of scopes, return a space-separated list of those scopes. + """ + return ' '.join(scopes) + + def get_token_scope_names(self, token): + """ + Given an access token object, return its scopes. + """ + return scope.to_names(token.scope) diff --git a/lms/djangoapps/oauth_dispatch/adapters/dot.py b/lms/djangoapps/oauth_dispatch/adapters/dot.py new file mode 100644 index 0000000000..84dcb7ece4 --- /dev/null +++ b/lms/djangoapps/oauth_dispatch/adapters/dot.py @@ -0,0 +1,73 @@ +""" +Adapter to isolate django-oauth-toolkit dependencies +""" + +from oauth2_provider import models + + +class DOTAdapter(object): + """ + Standard interface for working with django-oauth-toolkit + """ + + backend = object() + + def create_confidential_client(self, name, user, redirect_uri, client_id=None): + """ + Create an oauth client application that is confidential. + """ + return models.Application.objects.create( + name=name, + user=user, + client_id=client_id, + client_type=models.Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE, + redirect_uris=redirect_uri, + ) + + def create_public_client(self, name, user, redirect_uri, client_id=None): + """ + Create an oauth client application that is public. + """ + return models.Application.objects.create( + name=name, + user=user, + client_id=client_id, + client_type=models.Application.CLIENT_PUBLIC, + authorization_grant_type=models.Application.GRANT_PASSWORD, + redirect_uris=redirect_uri, + ) + + def get_client(self, **filters): + """ + Get the oauth client application with the specified filters. + + Wraps django's queryset.get() method. + """ + return models.Application.objects.get(**filters) + + def get_client_for_token(self, token): + """ + Given an AccessToken object, return the associated client application. + """ + return token.application + + def get_access_token(self, token_string): + """ + Given a token string, return the matching AccessToken object. + """ + return models.AccessToken.objects.get(token=token_string) + + def normalize_scopes(self, scopes): + """ + Given a list of scopes, return a space-separated list of those scopes. + """ + if not scopes: + scopes = ['default'] + return ' '.join(scopes) + + def get_token_scope_names(self, token): + """ + Given an access token object, return its scopes. + """ + return list(token.scopes) diff --git a/lms/djangoapps/oauth_dispatch/tests/__init__.py b/lms/djangoapps/oauth_dispatch/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lms/djangoapps/oauth_dispatch/tests/constants.py b/lms/djangoapps/oauth_dispatch/tests/constants.py new file mode 100644 index 0000000000..b38868bcfb --- /dev/null +++ b/lms/djangoapps/oauth_dispatch/tests/constants.py @@ -0,0 +1,5 @@ +""" +Constants for testing purposes +""" + +DUMMY_REDIRECT_URL = u'https://example.edx/redirect' diff --git a/lms/djangoapps/oauth_dispatch/tests/mixins.py b/lms/djangoapps/oauth_dispatch/tests/mixins.py new file mode 100644 index 0000000000..8b16c53fc2 --- /dev/null +++ b/lms/djangoapps/oauth_dispatch/tests/mixins.py @@ -0,0 +1,3 @@ +""" +OAuth Dispatch test mixins +""" diff --git a/lms/djangoapps/oauth_dispatch/tests/test_dop_adapter.py b/lms/djangoapps/oauth_dispatch/tests/test_dop_adapter.py new file mode 100644 index 0000000000..6dfe9ac9d4 --- /dev/null +++ b/lms/djangoapps/oauth_dispatch/tests/test_dop_adapter.py @@ -0,0 +1,77 @@ +""" +Tests for DOP Adapter +""" + +from datetime import timedelta + +import ddt +from django.test import TestCase +from django.utils.timezone import now +from provider.oauth2 import models +from provider import constants + +from student.tests.factories import UserFactory + +from ..adapters import DOPAdapter +from .constants import DUMMY_REDIRECT_URL + + +@ddt.ddt +class DOPAdapterTestCase(TestCase): + """ + Test class for DOPAdapter. + """ + + adapter = DOPAdapter() + + def setUp(self): + super(DOPAdapterTestCase, self).setUp() + self.user = UserFactory() + self.public_client = self.adapter.create_public_client( + name='public client', + user=self.user, + redirect_uri=DUMMY_REDIRECT_URL, + client_id='public-client-id', + ) + self.confidential_client = self.adapter.create_confidential_client( + name='confidential client', + user=self.user, + redirect_uri=DUMMY_REDIRECT_URL, + client_id='confidential-client-id', + ) + + @ddt.data( + ('confidential', constants.CONFIDENTIAL), + ('public', constants.PUBLIC), + ) + @ddt.unpack + def test_create_client(self, client_name, client_type): + client = getattr(self, '{}_client'.format(client_name)) + self.assertIsInstance(client, models.Client) + self.assertEqual(client.client_id, '{}-client-id'.format(client_name)) + self.assertEqual(client.client_type, client_type) + + def test_get_client(self): + client = self.adapter.get_client(client_type=constants.CONFIDENTIAL) + self.assertIsInstance(client, models.Client) + self.assertEqual(client.client_type, constants.CONFIDENTIAL) + + def test_get_client_not_found(self): + with self.assertRaises(models.Client.DoesNotExist): + self.adapter.get_client(client_id='not-found') + + def test_get_client_for_token(self): + token = models.AccessToken( + user=self.user, + client=self.public_client, + ) + self.assertEqual(self.adapter.get_client_for_token(token), self.public_client) + + def test_get_access_token(self): + token = models.AccessToken.objects.create( + token='token-id', + client=self.public_client, + user=self.user, + expires=now() + timedelta(days=30), + ) + self.assertEqual(self.adapter.get_access_token(token_string='token-id'), token) diff --git a/lms/djangoapps/oauth_dispatch/tests/test_dot_adapter.py b/lms/djangoapps/oauth_dispatch/tests/test_dot_adapter.py new file mode 100644 index 0000000000..a623e8cc6d --- /dev/null +++ b/lms/djangoapps/oauth_dispatch/tests/test_dot_adapter.py @@ -0,0 +1,76 @@ +""" +Tests for DOT Adapter +""" + +from datetime import timedelta + +import ddt +from django.test import TestCase +from django.utils.timezone import now +from oauth2_provider import models + +from student.tests.factories import UserFactory + +from ..adapters import DOTAdapter +from .constants import DUMMY_REDIRECT_URL + + +@ddt.ddt +class DOTAdapterTestCase(TestCase): + """ + Test class for DOTAdapter. + """ + + adapter = DOTAdapter() + + def setUp(self): + super(DOTAdapterTestCase, self).setUp() + self.user = UserFactory() + self.public_client = self.adapter.create_public_client( + name='public app', + user=self.user, + redirect_uri=DUMMY_REDIRECT_URL, + client_id='public-client-id', + ) + self.confidential_client = self.adapter.create_confidential_client( + name='confidential app', + user=self.user, + redirect_uri=DUMMY_REDIRECT_URL, + client_id='confidential-client-id', + ) + + @ddt.data( + ('confidential', models.Application.CLIENT_CONFIDENTIAL), + ('public', models.Application.CLIENT_PUBLIC), + ) + @ddt.unpack + def test_create_client(self, client_name, client_type): + client = getattr(self, '{}_client'.format(client_name)) + self.assertIsInstance(client, models.Application) + self.assertEqual(client.client_id, '{}-client-id'.format(client_name)) + self.assertEqual(client.client_type, client_type) + + def test_get_client(self): + client = self.adapter.get_client(client_type=models.Application.CLIENT_CONFIDENTIAL) + self.assertIsInstance(client, models.Application) + self.assertEqual(client.client_type, models.Application.CLIENT_CONFIDENTIAL) + + def test_get_client_not_found(self): + with self.assertRaises(models.Application.DoesNotExist): + self.adapter.get_client(client_id='not-found') + + def test_get_client_for_token(self): + token = models.AccessToken( + user=self.user, + application=self.public_client, + ) + self.assertEqual(self.adapter.get_client_for_token(token), self.public_client) + + def test_get_access_token(self): + token = models.AccessToken.objects.create( + token='token-id', + application=self.public_client, + user=self.user, + expires=now() + timedelta(days=30), + ) + self.assertEqual(self.adapter.get_access_token(token_string='token-id'), token) diff --git a/lms/djangoapps/oauth_dispatch/tests/test_views.py b/lms/djangoapps/oauth_dispatch/tests/test_views.py new file mode 100644 index 0000000000..b8fb7435ff --- /dev/null +++ b/lms/djangoapps/oauth_dispatch/tests/test_views.py @@ -0,0 +1,251 @@ +""" +Tests for Blocks Views +""" + +import json + +import ddt +from django.test import RequestFactory, TestCase +from django.core.urlresolvers import reverse +import httpretty + +from student.tests.factories import UserFactory +from third_party_auth.tests.utils import ThirdPartyOAuthTestMixin, ThirdPartyOAuthTestMixinGoogle + +from .. import adapters +from .. import views +from .constants import DUMMY_REDIRECT_URL + + +class _DispatchingViewTestCase(TestCase): + """ + Base class for tests that exercise DispatchingViews. + """ + dop_adapter = adapters.DOPAdapter() + dot_adapter = adapters.DOTAdapter() + + view_class = None + url = None + + def setUp(self): + super(_DispatchingViewTestCase, self).setUp() + self.user = UserFactory() + self.dot_app = self.dot_adapter.create_public_client( + name='test dot application', + user=self.user, + redirect_uri=DUMMY_REDIRECT_URL, + client_id='dot-app-client-id', + ) + self.dop_client = self.dop_adapter.create_public_client( + name='test dop client', + user=self.user, + redirect_uri=DUMMY_REDIRECT_URL, + client_id='dop-app-client-id', + ) + + def _post_request(self, user, client): + """ + Call the view with a POST request objectwith the appropriate format, + returning the response object. + """ + return self.client.post(self.url, self._post_body(user, client)) + + def _post_body(self, user, client): + """ + Return a dictionary to be used as the body of the POST request + """ + raise NotImplementedError() + + +@ddt.ddt +class TestAccessTokenView(_DispatchingViewTestCase): + """ + Test class for AccessTokenView + """ + + view_class = views.AccessTokenView + url = reverse('access_token') + + def _post_body(self, user, client): + """ + Return a dictionary to be used as the body of the POST request + """ + return { + 'client_id': client.client_id, + 'grant_type': 'password', + 'username': user.username, + 'password': 'test', + } + + @ddt.data('dop_client', 'dot_app') + def test_access_token_fields(self, client_attr): + client = getattr(self, client_attr) + response = self._post_request(self.user, client) + self.assertEqual(response.status_code, 200) + data = json.loads(response.content) + self.assertIn('access_token', data) + self.assertIn('expires_in', data) + self.assertIn('scope', data) + self.assertIn('token_type', data) + + def test_dot_access_token_provides_refresh_token(self): + response = self._post_request(self.user, self.dot_app) + self.assertEqual(response.status_code, 200) + data = json.loads(response.content) + self.assertIn('refresh_token', data) + + def test_dop_public_client_access_token(self): + response = self._post_request(self.user, self.dop_client) + self.assertEqual(response.status_code, 200) + data = json.loads(response.content) + self.assertNotIn('refresh_token', data) + + +@ddt.ddt +@httpretty.activate +class TestAccessTokenExchangeView(ThirdPartyOAuthTestMixinGoogle, ThirdPartyOAuthTestMixin, _DispatchingViewTestCase): + """ + Test class for AccessTokenExchangeView + """ + + view_class = views.AccessTokenExchangeView + url = reverse('exchange_access_token', kwargs={'backend': 'google-oauth2'}) + + def _post_body(self, user, client): + return { + 'client_id': client.client_id, + 'access_token': self.access_token, + } + + @ddt.data('dop_client', 'dot_app') + def test_access_token_exchange_calls_dispatched_view(self, client_attr): + client = getattr(self, client_attr) + self.oauth_client = client + self._setup_provider_response(success=True) + response = self._post_request(self.user, client) + self.assertEqual(response.status_code, 200) + + +@ddt.ddt +class TestAuthorizationView(TestCase): + """ + Test class for AuthorizationView + """ + + dop_adapter = adapters.DOPAdapter() + + def setUp(self): + super(TestAuthorizationView, self).setUp() + self.user = UserFactory() + self.dop_client = self._create_confidential_client(user=self.user, client_id='dop-app-client-id') + + def _create_confidential_client(self, user, client_id): + """ + Create a confidential client suitable for testing purposes. + """ + return self.dop_adapter.create_confidential_client( + name='test_app', + user=user, + client_id=client_id, + redirect_uri=DUMMY_REDIRECT_URL + ) + + def test_authorization_view(self): + self.client.login(username=self.user.username, password='test') + response = self.client.post( + '/oauth2/authorize/', + { + 'client_id': self.dop_client.client_id, # TODO: DOT is not yet supported (MA-2124) + 'response_type': 'code', + 'state': 'random_state_string', + 'redirect_uri': DUMMY_REDIRECT_URL, + }, + follow=True, + ) + + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + context = response.context # pylint: disable=no-member + self.assertIn('form', context) + self.assertIsNone(context['form']['authorize'].value()) + + self.assertIn('oauth_data', context) + oauth_data = context['oauth_data'] + self.assertEqual(oauth_data['redirect_uri'], DUMMY_REDIRECT_URL) + self.assertEqual(oauth_data['state'], 'random_state_string') + + +class TestViewDispatch(TestCase): + """ + Test that the DispatchingView dispatches the right way. + """ + + dop_adapter = adapters.DOPAdapter() + dot_adapter = adapters.DOTAdapter() + + def setUp(self): + super(TestViewDispatch, self).setUp() + self.user = UserFactory() + self.view = views._DispatchingView() # pylint: disable=protected-access + self.dop_adapter.create_public_client( + name='', + user=self.user, + client_id='dop-id', + redirect_uri=DUMMY_REDIRECT_URL + ) + self.dot_adapter.create_public_client( + name='', + user=self.user, + client_id='dot-id', + redirect_uri=DUMMY_REDIRECT_URL + ) + + def assert_is_view(self, view_candidate): + """ + Assert that a given object is a view. That is, it is callable, and + takes a request argument. Note: while technically, the request argument + could take any name, this assertion requires the argument to be named + `request`. This is good practice. You should do it anyway. + """ + _msg_base = u'{view} is not a view: {reason}' + msg_not_callable = _msg_base.format(view=view_candidate, reason=u'it is not callable') + msg_no_request = _msg_base.format(view=view_candidate, reason=u'it has no request argument') + self.assertTrue(hasattr(view_candidate, '__call__'), msg_not_callable) + args = view_candidate.func_code.co_varnames + self.assertTrue(args, msg_no_request) + self.assertEqual(args[0], 'request') + + def _get_request(self, client_id): + """ + Return a request with the specified client_id in the body + """ + return RequestFactory().post('/', {'client_id': client_id}) + + def test_dispatching_to_dot(self): + request = self._get_request('dot-id') + self.assertEqual(self.view.select_backend(request), self.dot_adapter.backend) + + def test_dispatching_to_dop(self): + request = self._get_request('dop-id') + self.assertEqual(self.view.select_backend(request), self.dop_adapter.backend) + + def test_dispatching_with_no_client(self): + request = self._get_request(None) + self.assertEqual(self.view.select_backend(request), self.dop_adapter.backend) + + def test_dispatching_with_invalid_client(self): + request = self._get_request('abcesdfljh') + self.assertEqual(self.view.select_backend(request), self.dop_adapter.backend) + + def test_get_view_for_dot(self): + view_object = views.AccessTokenView() + self.assert_is_view(view_object.get_view_for_backend(self.dot_adapter.backend)) + + def test_get_view_for_dop(self): + view_object = views.AccessTokenView() + self.assert_is_view(view_object.get_view_for_backend(self.dop_adapter.backend)) + + def test_get_view_for_no_backend(self): + view_object = views.AccessTokenView() + self.assertRaises(KeyError, view_object.get_view_for_backend, None) diff --git a/lms/djangoapps/oauth_dispatch/urls.py b/lms/djangoapps/oauth_dispatch/urls.py new file mode 100644 index 0000000000..59e9068bdb --- /dev/null +++ b/lms/djangoapps/oauth_dispatch/urls.py @@ -0,0 +1,25 @@ +""" +OAuth2 wrapper urls +""" + +from django.conf import settings +from django.conf.urls import patterns, url +from django.views.decorators.csrf import csrf_exempt + +from . import views + + +urlpatterns = patterns( + '', + # TODO: authorize/ URL not yet supported for DOT (MA-2124) + url(r'^access_token/?$', csrf_exempt(views.AccessTokenView.as_view()), name='access_token'), +) + +if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'): + urlpatterns += ( + url( + r'^exchange_access_token/(?P[^/]+)/$', + csrf_exempt(views.AccessTokenExchangeView.as_view()), + name='exchange_access_token', + ), + ) diff --git a/lms/djangoapps/oauth_dispatch/views.py b/lms/djangoapps/oauth_dispatch/views.py new file mode 100644 index 0000000000..22652f127b --- /dev/null +++ b/lms/djangoapps/oauth_dispatch/views.py @@ -0,0 +1,89 @@ +""" +Views that dispatch processing of OAuth requests to django-oauth2-provider or +django-oauth-toolkit as appropriate. +""" + +from __future__ import unicode_literals + +from django.views.generic import View +from edx_oauth2_provider import views as dop_views # django-oauth2-provider views +from oauth2_provider import models as dot_models, views as dot_views # django-oauth-toolkit + +from auth_exchange import views as auth_exchange_views + +from . import adapters + + +class _DispatchingView(View): + """ + Base class that route views to the appropriate provider view. The default + behavior routes based on client_id, but this can be overridden by redefining + `select_backend()` if particular views need different behavior. + """ + # pylint: disable=no-member + + dot_adapter = adapters.DOTAdapter() + dop_adapter = adapters.DOPAdapter() + + def dispatch(self, request, *args, **kwargs): + """ + Dispatch the request to the selected backend's view. + """ + backend = self.select_backend(request) + view = self.get_view_for_backend(backend) + return view(request, *args, **kwargs) + + def select_backend(self, request): + """ + Given a request that specifies an oauth `client_id`, return the adapter + for the appropriate OAuth handling library. If the client_id is found + in a django-oauth-toolkit (DOT) Application, use the DOT adapter, + otherwise use the django-oauth2-provider (DOP) adapter, and allow the + calls to fail normally if the client does not exist. + """ + + if dot_models.Application.objects.filter(client_id=self._get_client_id(request)).exists(): + return self.dot_adapter.backend + else: + return self.dop_adapter.backend + + def get_view_for_backend(self, backend): + """ + Return the appropriate view from the requested backend. + """ + if backend == self.dot_adapter.backend: + return self.dot_view.as_view() + elif backend == self.dop_adapter.backend: + return self.dop_view.as_view() + else: + raise KeyError('Failed to dispatch view. Invalid backend {}'.format(backend)) + + def _get_client_id(self, request): + """ + Return the client_id from the provided request + """ + return request.POST.get('client_id') + + +class AccessTokenView(_DispatchingView): + """ + Handle access token requests. + """ + dot_view = dot_views.TokenView + dop_view = dop_views.AccessTokenView + + +class AuthorizationView(_DispatchingView): + """ + Part of the authorization flow. + """ + dop_view = dop_views.Capture + dot_view = dot_views.AuthorizationView + + +class AccessTokenExchangeView(_DispatchingView): + """ + Exchange a third party auth token. + """ + dop_view = auth_exchange_views.DOPAccessTokenExchangeView + dot_view = auth_exchange_views.DOTAccessTokenExchangeView diff --git a/lms/djangoapps/support/tests/test_programs.py b/lms/djangoapps/support/tests/test_programs.py index 9d2356e40f..2bad288ddf 100644 --- a/lms/djangoapps/support/tests/test_programs.py +++ b/lms/djangoapps/support/tests/test_programs.py @@ -2,7 +2,7 @@ from django.core.urlresolvers import reverse from django.test import TestCase import mock -from oauth2_provider.tests.factories import AccessTokenFactory, ClientFactory +from edx_oauth2_provider.tests.factories import AccessTokenFactory, ClientFactory from openedx.core.djangoapps.programs.tests.mixins import ProgramsApiConfigMixin from student.tests.factories import UserFactory diff --git a/lms/envs/common.py b/lms/envs/common.py index bcb6bb156a..9b1c85f206 100644 --- a/lms/envs/common.py +++ b/lms/envs/common.py @@ -1834,11 +1834,14 @@ INSTALLED_APPS = ( 'external_auth', 'django_openid_auth', - # OAuth2 Provider + # django-oauth2-provider (deprecated) 'provider', 'provider.oauth2', 'edx_oauth2_provider', + # django-oauth-toolkit + 'oauth2_provider', + 'third_party_auth', # We don't use this directly (since we use OAuth2), but we need to install it anyway. diff --git a/lms/urls.py b/lms/urls.py index 964f7a495c..18c5cb80c4 100644 --- a/lms/urls.py +++ b/lms/urls.py @@ -820,10 +820,19 @@ if settings.FEATURES.get('AUTH_USE_OPENID_PROVIDER'): if settings.FEATURES.get('ENABLE_OAUTH2_PROVIDER'): urlpatterns += ( + # These URLs dispatch to django-oauth-toolkit or django-oauth2-provider as appropriate. + # Developers should use these routes, to maintain compatibility for existing client code + url(r'^oauth2/', include('lms.djangoapps.oauth_dispatch.urls')), + # These URLs contain the django-oauth2-provider default behavior. It exists to provide + # URLs for django-oauth2-provider to call using reverse() with the oauth2 namespace, and + # also to maintain support for views that have not yet been wrapped in dispatch views. url(r'^oauth2/', include('edx_oauth2_provider.urls', namespace='oauth2')), + # The /_o/ prefix exists to provide a target for code in django-oauth-toolkit that + # uses reverse() with the 'oauth2_provider' namespace. Developers should not access these + # views directly, but should rather use the wrapped views at /oauth2/ + url(r'^_o/', include('oauth2_provider.urls', namespace='oauth2_provider')), ) - if settings.FEATURES.get('ENABLE_LMS_MIGRATION'): urlpatterns += ( url(r'^migrate/modules$', 'lms_migration.migrate.manage_modulestores'), @@ -888,14 +897,6 @@ if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'): # OAuth token exchange if settings.FEATURES.get('ENABLE_OAUTH2_PROVIDER'): - if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'): - urlpatterns += ( - url( - r'^oauth2/exchange_access_token/(?P[^/]+)/$', - auth_exchange.views.AccessTokenExchangeView.as_view(), - name="exchange_access_token" - ), - ) urlpatterns += ( url( r'^oauth2/login/$', diff --git a/openedx/core/djangoapps/programs/tasks/v1/tests/test_tasks.py b/openedx/core/djangoapps/programs/tasks/v1/tests/test_tasks.py index 93cf9a5370..397bd593c1 100644 --- a/openedx/core/djangoapps/programs/tasks/v1/tests/test_tasks.py +++ b/openedx/core/djangoapps/programs/tasks/v1/tests/test_tasks.py @@ -12,8 +12,8 @@ from celery.exceptions import MaxRetriesExceededError from django.conf import settings from django.test import override_settings, TestCase from edx_rest_api_client.client import EdxRestApiClient - from edx_oauth2_provider.tests.factories import ClientFactory + from openedx.core.djangoapps.credentials.tests.mixins import CredentialsApiConfigMixin from openedx.core.djangoapps.programs.tests.mixins import ProgramsApiConfigMixin from openedx.core.djangoapps.programs.tasks.v1 import tasks diff --git a/openedx/core/lib/api/authentication.py b/openedx/core/lib/api/authentication.py index 8e773e2f84..95a296d5b4 100644 --- a/openedx/core/lib/api/authentication.py +++ b/openedx/core/lib/api/authentication.py @@ -2,10 +2,13 @@ import logging +import django.utils.timezone from rest_framework.authentication import SessionAuthentication from rest_framework import exceptions as drf_exceptions from rest_framework_oauth.authentication import OAuth2Authentication -from rest_framework_oauth.compat import oauth2_provider, provider_now + +from provider.oauth2 import models as dop_models +from oauth2_provider import models as dot_models from openedx.core.lib.api.exceptions import AuthenticationFailed @@ -114,21 +117,44 @@ class OAuth2AuthenticationAllowInactiveUser(OAuth2Authentication): def authenticate_credentials(self, request, access_token): """ Authenticate the request, given the access token. - Overrides base class implementation to discard failure if user is inactive. + + Overrides base class implementation to discard failure if user is + inactive. """ - token_query = oauth2_provider.oauth2.models.AccessToken.objects.select_related('user') - token = token_query.filter(token=access_token).first() + + token = self.get_access_token(access_token) if not token: raise AuthenticationFailed({ u'error_code': OAUTH2_TOKEN_ERROR_NONEXISTENT, u'developer_message': u'The provided access token does not match any valid tokens.' }) - # provider_now switches to timezone aware datetime when - # the oauth2_provider version supports it. - elif token.expires < provider_now(): + elif token.expires < django.utils.timezone.now(): raise AuthenticationFailed({ u'error_code': OAUTH2_TOKEN_ERROR_EXPIRED, u'developer_message': u'The provided access token has expired and is no longer valid.', }) else: return token.user, token + + def get_access_token(self, access_token): + """ + Return a valid access token that exists in one of our OAuth2 libraries, + or None if no matching token is found. + """ + return self._get_dot_token(access_token) or self._get_dop_token(access_token) + + def _get_dop_token(self, access_token): + """ + Return a valid access token stored by django-oauth2-provider (DOP), or + None if no matching token is found. + """ + token_query = dop_models.AccessToken.objects.select_related('user') + return token_query.filter(token=access_token).first() + + def _get_dot_token(self, access_token): + """ + Return a valid access token stored by django-oauth-toolkit (DOT), or + None if no matching token is found. + """ + token_query = dot_models.AccessToken.objects.select_related('user') + return token_query.filter(token=access_token).first() diff --git a/openedx/core/lib/api/tests/test_authentication.py b/openedx/core/lib/api/tests/test_authentication.py index ea572b841b..ab3f077a27 100644 --- a/openedx/core/lib/api/tests/test_authentication.py +++ b/openedx/core/lib/api/tests/test_authentication.py @@ -19,6 +19,7 @@ from django.utils import unittest from django.utils.http import urlencode from mock import patch from nose.plugins.attrib import attr +from oauth2_provider import models as dot_models from rest_framework import exceptions from rest_framework import status from rest_framework.permissions import IsAuthenticated @@ -28,6 +29,7 @@ from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView from rest_framework_jwt.settings import api_settings +from lms.djangoapps.oauth_dispatch import adapters from openedx.core.lib.api import authentication from openedx.core.lib.api.tests.mixins import JwtMixin from provider import constants, scope @@ -84,6 +86,8 @@ class OAuth2Tests(TestCase): def setUp(self): super(OAuth2Tests, self).setUp() + self.dop_adapter = adapters.DOPAdapter() + self.dot_adapter = adapters.DOTAdapter() self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' @@ -95,24 +99,35 @@ class OAuth2Tests(TestCase): self.ACCESS_TOKEN = 'access_token' # pylint: disable=invalid-name self.REFRESH_TOKEN = 'refresh_token' # pylint: disable=invalid-name - self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create( - client_id=self.CLIENT_ID, - client_secret=self.CLIENT_SECRET, - redirect_uri='', - client_type=0, + self.dop_oauth2_client = self.dop_adapter.create_public_client( name='example', - user=None, + user=self.user, + client_id=self.CLIENT_ID, + redirect_uri='https://example.edx/redirect', ) self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create( token=self.ACCESS_TOKEN, - client=self.oauth2_client, + client=self.dop_oauth2_client, user=self.user, ) self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create( user=self.user, access_token=self.access_token, - client=self.oauth2_client + client=self.dop_oauth2_client, + ) + + self.dot_oauth2_client = self.dot_adapter.create_public_client( + name='example', + user=self.user, + client_id='dot-client-id', + redirect_uri='https://example.edx/redirect', + ) + self.dot_access_token = dot_models.AccessToken.objects.create( + user=self.user, + token='dot-access-token', + application=self.dot_oauth2_client, + expires=datetime.now() + timedelta(days=30), ) # This is the a change we've made from the django-rest-framework-oauth version @@ -182,6 +197,10 @@ class OAuth2Tests(TestCase): response = self.get_with_bearer_token('/oauth2-test/') self.assertEqual(response.status_code, status.HTTP_200_OK) + def test_get_form_passing_auth_with_dot(self): + response = self.get_with_bearer_token('/oauth2-test/', token=self.dot_access_token.token) + self.assertEqual(response.status_code, status.HTTP_200_OK) + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_post_form_passing_auth_url_transport(self): """Ensure GETing form over OAuth with correct client credentials in form data succeed""" diff --git a/pavelib/prereqs.py b/pavelib/prereqs.py index da314158dd..6df3eb31f3 100644 --- a/pavelib/prereqs.py +++ b/pavelib/prereqs.py @@ -163,6 +163,7 @@ PACKAGES_TO_UNINSTALL = [ "edxval", # Because it was bork-installed somehow. "django-storages", "django-oauth2-provider", # Because now it's called edx-django-oauth2-provider. + "edx-oauth2-provider", # Because it moved from github to pypi ] @@ -203,7 +204,6 @@ def uninstall_python_packages(): # Uninstall the pacakge sh("pip uninstall --disable-pip-version-check -y {}".format(package_name)) uninstalled = True - if not uninstalled: break else: diff --git a/requirements/edx/base.txt b/requirements/edx/base.txt index eb3d248b0a..d3a60a81aa 100644 --- a/requirements/edx/base.txt +++ b/requirements/edx/base.txt @@ -23,6 +23,7 @@ django-mako==0.1.5pre django-model-utils==2.3.1 django-mptt==0.7.4 django-oauth-plus==2.2.8 +django-oauth-toolkit==0.10.0 django-sekizai==0.8.2 django-ses==0.7.0 django-simple-history==1.6.3 @@ -35,9 +36,9 @@ git+https://github.com/edx/django-rest-framework.git@3c72cb5ee5baebc432894737119 django==1.8.11 djangorestframework-jwt==1.7.2 djangorestframework-oauth==1.1.0 -edx-django-oauth2-provider==0.5.0 edx-lint==0.4.3 -edx-oauth2-provider==0.5.9 +edx-django-oauth2-provider==1.0.1 +edx-oauth2-provider==1.0.0 edx-opaque-keys==0.2.1 edx-organizations==0.4.0 edx-rest-api-client==1.2.1