Configure LMS to select oauth2 providing library.

Available backends:

* django-oauth-toolkit (DOT)
* django-oauth2-provider (DOP)

* Use provided client ID to select backend for
  * AccessToken requests
  * third party auth-token exchange
* Create adapters to isolate library-dependent functionality
* Handle django-oauth-toolkit tokens in edX DRF authenticator class

MA-1998
MA-2000
This commit is contained in:
J. Cliff Dyer
2016-02-01 20:36:35 +00:00
parent 88fef8b2a4
commit 1df040228a
29 changed files with 1114 additions and 92 deletions

View File

@@ -8,6 +8,7 @@ import provider.constants
from provider.forms import OAuthForm, OAuthValidationError
from provider.oauth2.forms import ScopeChoiceField, ScopeMixin
from provider.oauth2.models import Client
from oauth2_provider.models import Application
from requests import HTTPError
from social.backends import oauth as social_oauth
from social.exceptions import AuthException
@@ -21,9 +22,10 @@ class AccessTokenExchangeForm(ScopeMixin, OAuthForm):
scope = ScopeChoiceField(choices=SCOPE_NAMES, required=False)
client_id = CharField(required=False)
def __init__(self, request, *args, **kwargs):
def __init__(self, request, oauth2_adapter, *args, **kwargs):
super(AccessTokenExchangeForm, self).__init__(*args, **kwargs)
self.request = request
self.oauth2_adapter = oauth2_adapter
def _require_oauth_field(self, field_name):
"""
@@ -68,15 +70,15 @@ class AccessTokenExchangeForm(ScopeMixin, OAuthForm):
client_id = self.cleaned_data["client_id"]
try:
client = Client.objects.get(client_id=client_id)
except Client.DoesNotExist:
client = self.oauth2_adapter.get_client(client_id=client_id)
except (Client.DoesNotExist, Application.DoesNotExist):
raise OAuthValidationError(
{
"error": "invalid_client",
"error_description": "{} is not a valid client_id".format(client_id),
}
)
if client.client_type != provider.constants.PUBLIC:
if client.client_type not in [provider.constants.PUBLIC, Application.CLIENT_PUBLIC]:
raise OAuthValidationError(
{
# invalid_client isn't really the right code, but this mirrors

View File

@@ -0,0 +1,111 @@
"""
Mixins to facilitate testing OAuth connections to Django-OAuth-Toolkit or
Django-OAuth2-Provider.
"""
# pylint: disable=protected-access
from unittest import skip, expectedFailure
from django.test.client import RequestFactory
from lms.djangoapps.oauth_dispatch import adapters
from lms.djangoapps.oauth_dispatch.tests.constants import DUMMY_REDIRECT_URL
from ..views import DOTAccessTokenExchangeView
class DOPAdapterMixin(object):
"""
Mixin to rewire existing tests to use django-oauth2-provider (DOP) backend
Overwrites self.client_id, self.access_token, self.oauth2_adapter
"""
client_id = 'dop_test_client_id'
access_token = 'dop_test_access_token'
oauth2_adapter = adapters.DOPAdapter()
def create_public_client(self, user, client_id=None):
"""
Create an oauth client application that is public.
"""
return self.oauth2_adapter.create_public_client(
name='Test Public Client',
user=user,
client_id=client_id,
redirect_uri=DUMMY_REDIRECT_URL,
)
def create_confidential_client(self, user, client_id=None):
"""
Create an oauth client application that is confidential.
"""
return self.oauth2_adapter.create_confidential_client(
name='Test Confidential Client',
user=user,
client_id=client_id,
redirect_uri=DUMMY_REDIRECT_URL,
)
def get_token_response_keys(self):
"""
Return the set of keys provided when requesting an access token
"""
return {'access_token', 'token_type', 'expires_in', 'scope'}
class DOTAdapterMixin(object):
"""
Mixin to rewire existing tests to use django-oauth-toolkit (DOT) backend
Overwrites self.client_id, self.access_token, self.oauth2_adapter
"""
client_id = 'dot_test_client_id'
access_token = 'dot_test_access_token'
oauth2_adapter = adapters.DOTAdapter()
def create_public_client(self, user, client_id=None):
"""
Create an oauth client application that is public.
"""
return self.oauth2_adapter.create_public_client(
name='Test Public Application',
user=user,
client_id=client_id,
redirect_uri=DUMMY_REDIRECT_URL,
)
def create_confidential_client(self, user, client_id=None):
"""
Create an oauth client application that is confidential.
"""
return self.oauth2_adapter.create_confidential_client(
name='Test Confidential Application',
user=user,
client_id=client_id,
redirect_uri=DUMMY_REDIRECT_URL,
)
def get_token_response_keys(self):
"""
Return the set of keys provided when requesting an access token
"""
return {'access_token', 'refresh_token', 'token_type', 'expires_in', 'scope'}
def test_get_method(self):
# Dispatch routes all get methods to DOP, so we test this on the view
request_factory = RequestFactory()
request = request_factory.get('/oauth2/exchange_access_token/')
request.session = {}
view = DOTAccessTokenExchangeView.as_view()
response = view(request, backend='facebook')
self.assertEqual(response.status_code, 400)
@expectedFailure
def test_single_access_token(self):
# TODO: Single access tokens not supported yet for DOT (See MA-2122)
super(DOTAdapterMixin, self).test_single_access_token()
@skip("Not supported yet (See MA-2123)")
def test_scopes(self):
super(DOTAdapterMixin, self).test_scopes()

View File

@@ -12,10 +12,12 @@ import httpretty
from provider import scope
import social.apps.django_app.utils as social_utils
from auth_exchange.forms import AccessTokenExchangeForm
from auth_exchange.tests.utils import AccessTokenExchangeTestMixin
from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle
from ..forms import AccessTokenExchangeForm
from .utils import AccessTokenExchangeTestMixin
from .mixins import DOPAdapterMixin, DOTAdapterMixin
class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin):
"""
@@ -31,7 +33,7 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin):
self.request.backend = social_utils.load_backend(self.request.social_strategy, self.BACKEND, redirect_uri)
def _assert_error(self, data, expected_error, expected_error_description):
form = AccessTokenExchangeForm(request=self.request, data=data)
form = AccessTokenExchangeForm(request=self.request, oauth2_adapter=self.oauth2_adapter, data=data)
self.assertEqual(
form.errors,
{"error": expected_error, "error_description": expected_error_description}
@@ -39,7 +41,7 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin):
self.assertNotIn("partial_pipeline", self.request.session)
def _assert_success(self, data, expected_scopes):
form = AccessTokenExchangeForm(request=self.request, data=data)
form = AccessTokenExchangeForm(request=self.request, oauth2_adapter=self.oauth2_adapter, data=data)
self.assertTrue(form.is_valid())
self.assertEqual(form.cleaned_data["user"], self.user)
self.assertEqual(form.cleaned_data["client"], self.oauth_client)
@@ -49,13 +51,15 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin):
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeFormTestFacebook(
class DOPAccessTokenExchangeFormTestFacebook(
DOPAdapterMixin,
AccessTokenExchangeFormTest,
ThirdPartyOAuthTestMixinFacebook,
TestCase
TestCase,
):
"""
Tests for AccessTokenExchangeForm used with Facebook
Tests for AccessTokenExchangeForm used with Facebook, tested against
django-oauth2-provider (DOP).
"""
pass
@@ -63,12 +67,46 @@ class AccessTokenExchangeFormTestFacebook(
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeFormTestGoogle(
class DOTAccessTokenExchangeFormTestFacebook(
DOTAdapterMixin,
AccessTokenExchangeFormTest,
ThirdPartyOAuthTestMixinGoogle,
TestCase
ThirdPartyOAuthTestMixinFacebook,
TestCase,
):
"""
Tests for AccessTokenExchangeForm used with Google
Tests for AccessTokenExchangeForm used with Facebook, tested against
django-oauth-toolkit (DOT).
"""
pass
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class DOPAccessTokenExchangeFormTestGoogle(
DOPAdapterMixin,
AccessTokenExchangeFormTest,
ThirdPartyOAuthTestMixinGoogle,
TestCase,
):
"""
Tests for AccessTokenExchangeForm used with Google, tested against
django-oauth2-provider (DOP).
"""
pass
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class DOTAccessTokenExchangeFormTestGoogle(
DOTAdapterMixin,
AccessTokenExchangeFormTest,
ThirdPartyOAuthTestMixinGoogle,
TestCase,
):
"""
Tests for AccessTokenExchangeForm used with Google, tested against
django-oauth-toolkit (DOT).
"""
pass

View File

@@ -1,25 +1,30 @@
# pylint: disable=no-member
"""
Tests for OAuth token exchange views
"""
# pylint: disable=no-member
from datetime import timedelta
import json
import mock
import unittest
import ddt
from django.conf import settings
from django.core.urlresolvers import reverse
from django.test import TestCase
import httpretty
import provider.constants
from provider import scope
from provider.oauth2.models import AccessToken, Client
from rest_framework.test import APIClient
from auth_exchange.tests.utils import AccessTokenExchangeTestMixin
from student.tests.factories import UserFactory
from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle
from .mixins import DOPAdapterMixin, DOTAdapterMixin
from .utils import AccessTokenExchangeTestMixin
@ddt.ddt
class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
"""
Mixin that defines test cases for AccessTokenExchangeView
@@ -27,33 +32,34 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
def setUp(self):
super(AccessTokenExchangeViewTest, self).setUp()
self.url = reverse("exchange_access_token", kwargs={"backend": self.BACKEND})
self.csrf_client = APIClient(enforce_csrf_checks=True)
def _assert_error(self, data, expected_error, expected_error_description):
response = self.client.post(self.url, data)
response = self.csrf_client.post(self.url, data)
self.assertEqual(response.status_code, 400)
self.assertEqual(response["Content-Type"], "application/json")
self.assertEqual(
json.loads(response.content),
{"error": expected_error, "error_description": expected_error_description}
{u"error": expected_error, u"error_description": expected_error_description}
)
self.assertNotIn("partial_pipeline", self.client.session)
def _assert_success(self, data, expected_scopes):
response = self.client.post(self.url, data)
response = self.csrf_client.post(self.url, data)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "application/json")
content = json.loads(response.content)
self.assertEqual(set(content.keys()), {"access_token", "token_type", "expires_in", "scope"})
self.assertEqual(set(content.keys()), self.get_token_response_keys())
self.assertEqual(content["token_type"], "Bearer")
self.assertLessEqual(
timedelta(seconds=int(content["expires_in"])),
provider.constants.EXPIRE_DELTA_PUBLIC
)
self.assertEqual(content["scope"], " ".join(expected_scopes))
token = AccessToken.objects.get(token=content["access_token"])
self.assertEqual(content["scope"], self.oauth2_adapter.normalize_scopes(expected_scopes))
token = self.oauth2_adapter.get_access_token(token_string=content["access_token"])
self.assertEqual(token.user, self.user)
self.assertEqual(token.client, self.oauth_client)
self.assertEqual(scope.to_names(token.scope), expected_scopes)
self.assertEqual(self.oauth2_adapter.get_client_for_token(token), self.oauth_client)
self.assertEqual(self.oauth2_adapter.get_token_scope_names(token), expected_scopes)
def test_single_access_token(self):
def extract_token(response):
@@ -64,16 +70,15 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
self._setup_provider_response(success=True)
for single_access_token in [True, False]:
with mock.patch(
"auth_exchange.views.constants.SINGLE_ACCESS_TOKEN",
single_access_token
):
with mock.patch("auth_exchange.views.constants.SINGLE_ACCESS_TOKEN", single_access_token):
first_response = self.client.post(self.url, self.data)
second_response = self.client.post(self.url, self.data)
self.assertEqual(
extract_token(first_response) == extract_token(second_response),
single_access_token
)
self.assertEqual(first_response.status_code, 200)
self.assertEqual(second_response.status_code, 200)
self.assertEqual(
extract_token(first_response) == extract_token(second_response),
single_access_token
)
def test_get_method(self):
response = self.client.get(self.url, self.data)
@@ -95,10 +100,11 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeViewTestFacebook(
class DOPAccessTokenExchangeViewTestFacebook(
DOPAdapterMixin,
AccessTokenExchangeViewTest,
ThirdPartyOAuthTestMixinFacebook,
TestCase
TestCase,
):
"""
Tests for AccessTokenExchangeView used with Facebook
@@ -106,16 +112,48 @@ class AccessTokenExchangeViewTestFacebook(
pass
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class DOTAccessTokenExchangeViewTestFacebook(
DOTAdapterMixin,
AccessTokenExchangeViewTest,
ThirdPartyOAuthTestMixinFacebook,
TestCase,
):
"""
Rerun AccessTokenExchangeViewTestFacebook tests against DOT backend
"""
pass
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeViewTestGoogle(
class DOPAccessTokenExchangeViewTestGoogle(
DOPAdapterMixin,
AccessTokenExchangeViewTest,
ThirdPartyOAuthTestMixinGoogle,
TestCase
TestCase,
):
"""
Tests for AccessTokenExchangeView used with Google
Tests for AccessTokenExchangeView used with Google using
django-oauth2-provider backend.
"""
pass
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class DOTAccessTokenExchangeViewTestGoogle(
DOTAdapterMixin,
AccessTokenExchangeViewTest,
ThirdPartyOAuthTestMixinGoogle,
TestCase,
):
"""
Tests for AccessTokenExchangeView used with Google using
django-oauth-toolkit backend.
"""
pass

View File

@@ -1,9 +1,8 @@
"""
Test utilities for OAuth access token exchange
"""
import provider.constants
from social.apps.django_app.default.models import UserSocialAuth
from social.apps.django_app.default.models import UserSocialAuth
from third_party_auth.tests.utils import ThirdPartyOAuthTestMixin
@@ -37,6 +36,12 @@ class AccessTokenExchangeTestMixin(ThirdPartyOAuthTestMixin):
"""
raise NotImplementedError()
def _create_client(self):
"""
Create an oauth2 client application using class defaults.
"""
return self.create_public_client(self.user, self.client_id)
def test_minimal(self):
self._setup_provider_response(success=True)
self._assert_success(self.data, expected_scopes=[])
@@ -61,12 +66,12 @@ class AccessTokenExchangeTestMixin(ThirdPartyOAuthTestMixin):
)
def test_confidential_client(self):
self.oauth_client.client_type = provider.constants.CONFIDENTIAL
self.oauth_client.save()
self.data['client_id'] += '_confidential'
self.oauth_client = self.create_confidential_client(self.user, self.data['client_id'])
self._assert_error(
self.data,
"invalid_client",
"test_client_id is not a public client"
"{}_confidential is not a public client".format(self.client_id),
)
def test_inactive_user(self):

View File

@@ -1,4 +1,3 @@
# pylint: disable=abstract-method
"""
Views to support exchange of authentication credentials.
The following are currently implemented:
@@ -7,36 +6,52 @@ The following are currently implemented:
2. LoginWithAccessTokenView:
1st party (open-edx) OAuth 2.0 access token -> session cookie
"""
# pylint: disable=abstract-method
from django.conf import settings
from django.contrib.auth import login
import django.contrib.auth as auth
from django.http import HttpResponse
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
from edx_oauth2_provider.constants import SCOPE_VALUE_DICT
from oauth2_provider.settings import oauth2_settings
from oauth2_provider.views.base import TokenView as DOTAccessTokenView
from oauthlib.oauth2.rfc6749.tokens import BearerToken
from provider import constants
from provider.oauth2.views import AccessTokenView as AccessTokenView
from provider.oauth2.views import AccessTokenView as DOPAccessTokenView
from rest_framework import permissions
from rest_framework.response import Response
from rest_framework.views import APIView
import social.apps.django_app.utils as social_utils
from auth_exchange.forms import AccessTokenExchangeForm
from lms.djangoapps.oauth_dispatch import adapters
from openedx.core.lib.api.authentication import OAuth2AuthenticationAllowInactiveUser
class AccessTokenExchangeView(AccessTokenView):
class AccessTokenExchangeBase(APIView):
"""
View for token exchange from 3rd party OAuth access token to 1st party OAuth access token
View for token exchange from 3rd party OAuth access token to 1st party
OAuth access token.
"""
@method_decorator(csrf_exempt)
@method_decorator(social_utils.strategy("social:complete"))
def dispatch(self, *args, **kwargs):
return super(AccessTokenExchangeView, self).dispatch(*args, **kwargs)
return super(AccessTokenExchangeBase, self).dispatch(*args, **kwargs)
def get(self, request, _backend): # pylint: disable=arguments-differ
return super(AccessTokenExchangeView, self).get(request)
"""
Pass through GET requests without the _backend
"""
return super(AccessTokenExchangeBase, self).get(request)
def post(self, request, _backend): # pylint: disable=arguments-differ
form = AccessTokenExchangeForm(request=request, data=request.POST)
"""
Handle POST requests to get a first-party access token.
"""
form = AccessTokenExchangeForm(request=request, oauth2_adapter=self.oauth2_adapter, data=request.POST) # pylint: disable=no-member
if not form.is_valid():
return self.error_response(form.errors)
@@ -44,12 +59,89 @@ class AccessTokenExchangeView(AccessTokenView):
scope = form.cleaned_data["scope"]
client = form.cleaned_data["client"]
return self.exchange_access_token(request, user, scope, client)
def exchange_access_token(self, request, user, scope, client):
"""
Exchange third party credentials for an edx access token, and return a
serialized access token response.
"""
if constants.SINGLE_ACCESS_TOKEN:
edx_access_token = self.get_access_token(request, user, scope, client)
edx_access_token = self.get_access_token(request, user, scope, client) # pylint: disable=no-member
else:
edx_access_token = self.create_access_token(request, user, scope, client)
return self.access_token_response(edx_access_token) # pylint: disable=no-member
return self.access_token_response(edx_access_token)
class DOPAccessTokenExchangeView(AccessTokenExchangeBase, DOPAccessTokenView):
"""
View for token exchange from 3rd party OAuth access token to 1st party
OAuth access token. Uses django-oauth2-provider (DOP) to manage access
tokens.
"""
oauth2_adapter = adapters.DOPAdapter()
class DOTAccessTokenExchangeView(AccessTokenExchangeBase, DOTAccessTokenView):
"""
View for token exchange from 3rd party OAuth access token to 1st party
OAuth access token. Uses django-oauth-toolkit (DOT) to manage access
tokens.
"""
oauth2_adapter = adapters.DOTAdapter()
def get(self, request, _backend):
return Response(status=400, data={
'error': 'invalid_request',
'error_description': 'Only POST requests allowed.',
})
def get_access_token(self, request, user, scope, client):
"""
TODO: MA-2122: Reusing access tokens is not yet supported for DOT.
Just return a new access token.
"""
return self.create_access_token(request, user, scope, client)
def create_access_token(self, request, user, scope, client):
"""
Create and return a new access token.
"""
_days = 24 * 60 * 60
token_generator = BearerToken(
expires_in=settings.OAUTH_EXPIRE_PUBLIC_CLIENT_DAYS * _days,
request_validator=oauth2_settings.OAUTH2_VALIDATOR_CLASS(),
)
self._populate_create_access_token_request(request, user, scope, client)
return token_generator.create_token(request, refresh_token=True)
def access_token_response(self, token):
"""
Wrap an access token in an appropriate response
"""
return Response(data=token)
def _populate_create_access_token_request(self, request, user, scope, client):
"""
django-oauth-toolkit expects certain non-standard attributes to
be present on the request object. This function modifies the
request object to match these expectations
"""
request.user = user
request.scopes = [SCOPE_VALUE_DICT[scope]]
request.client = client
request.state = None
request.refresh_token = None
request.extra_credentials = None
request.grant_type = client.authorization_grant_type
def error_response(self, form_errors):
"""
Return an error response consisting of the errors in the form
"""
return Response(status=400, data=form_errors)
class LoginWithAccessTokenView(APIView):

View File

@@ -22,23 +22,30 @@ class ThirdPartyOAuthTestMixin(ThirdPartyAuthTestMixin):
USER_URL: The URL of the endpoint that the backend retrieves user data from
UID_FIELD: The field in the user data that the backend uses as the user id
"""
social_uid = "test_social_uid"
access_token = "test_access_token"
client_id = "test_client_id"
def setUp(self, create_user=True):
super(ThirdPartyOAuthTestMixin, self).setUp()
self.social_uid = "test_social_uid"
self.access_token = "test_access_token"
self.client_id = "test_client_id"
self.oauth_client = Client.objects.create(
client_id=self.client_id,
client_type=PUBLIC
)
if create_user:
self.user = UserFactory()
UserSocialAuth.objects.create(user=self.user, provider=self.BACKEND, uid=self.social_uid)
self.oauth_client = self._create_client()
if self.BACKEND == 'google-oauth2':
self.configure_google_provider(enabled=True)
elif self.BACKEND == 'facebook':
self.configure_facebook_provider(enabled=True)
def _create_client(self):
"""
Create an OAuth2 client application
"""
return Client.objects.create(
client_id=self.client_id,
client_type=PUBLIC,
)
def _setup_provider_response(self, success=False, email=''):
"""
Register a mock response for the third party user information endpoint;
@@ -65,7 +72,7 @@ class ThirdPartyOAuthTestMixin(ThirdPartyAuthTestMixin):
self.USER_URL,
body=body,
status=status,
content_type="application/json"
content_type="application/json",
)

Binary file not shown.