Add an endpoint to exchange OAuth access tokens
This allows the holder of a third-party access token (e.g. from Google or Facebook) to get a first-party access token for the edX account linked to the given access token.
This commit is contained in:
@@ -26,7 +26,17 @@ from uuid import uuid4
|
||||
|
||||
# import settings from LMS for consistent behavior with CMS
|
||||
# pylint: disable=unused-import
|
||||
from lms.envs.test import (WIKI_ENABLED, PLATFORM_NAME, SITE_NAME, DEFAULT_FILE_STORAGE, MEDIA_ROOT, MEDIA_URL)
|
||||
from lms.envs.test import (
|
||||
WIKI_ENABLED,
|
||||
PLATFORM_NAME,
|
||||
SITE_NAME,
|
||||
DEFAULT_FILE_STORAGE,
|
||||
MEDIA_ROOT,
|
||||
MEDIA_URL,
|
||||
# This is practically unused but needed by the oauth2_provider package, which
|
||||
# some tests in common/ rely on.
|
||||
OAUTH_OIDC_ISSUER,
|
||||
)
|
||||
|
||||
# mongo connection settings
|
||||
MONGO_PORT_NUM = int(os.environ.get('EDXAPP_TEST_MONGO_PORT', '27017'))
|
||||
|
||||
0
common/djangoapps/oauth_exchange/__init__.py
Normal file
0
common/djangoapps/oauth_exchange/__init__.py
Normal file
100
common/djangoapps/oauth_exchange/forms.py
Normal file
100
common/djangoapps/oauth_exchange/forms.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
Forms to support third-party to first-party OAuth 2.0 access token exchange
|
||||
"""
|
||||
from django.contrib.auth.models import User
|
||||
from django.forms import CharField
|
||||
from oauth2_provider.constants import SCOPE_NAMES
|
||||
import provider.constants
|
||||
from provider.forms import OAuthForm, OAuthValidationError
|
||||
from provider.oauth2.forms import ScopeChoiceField, ScopeMixin
|
||||
from provider.oauth2.models import Client
|
||||
from requests import HTTPError
|
||||
from social.backends import oauth as social_oauth
|
||||
|
||||
from third_party_auth import pipeline
|
||||
|
||||
|
||||
class AccessTokenExchangeForm(ScopeMixin, OAuthForm):
|
||||
"""Form for access token exchange endpoint"""
|
||||
access_token = CharField(required=False)
|
||||
scope = ScopeChoiceField(choices=SCOPE_NAMES, required=False)
|
||||
client_id = CharField(required=False)
|
||||
|
||||
def __init__(self, request, *args, **kwargs):
|
||||
super(AccessTokenExchangeForm, self).__init__(*args, **kwargs)
|
||||
self.request = request
|
||||
|
||||
def _require_oauth_field(self, field_name):
|
||||
"""
|
||||
Raise an appropriate OAuthValidationError error if the field is missing
|
||||
"""
|
||||
field_val = self.cleaned_data.get(field_name)
|
||||
if not field_val:
|
||||
raise OAuthValidationError(
|
||||
{
|
||||
"error": "invalid_request",
|
||||
"error_description": "{} is required".format(field_name),
|
||||
}
|
||||
)
|
||||
return field_val
|
||||
|
||||
def clean_access_token(self):
|
||||
return self._require_oauth_field("access_token")
|
||||
|
||||
def clean_client_id(self):
|
||||
return self._require_oauth_field("client_id")
|
||||
|
||||
def clean(self):
|
||||
if self._errors:
|
||||
return {}
|
||||
|
||||
backend = self.request.social_strategy.backend
|
||||
if not isinstance(backend, social_oauth.BaseOAuth2):
|
||||
raise OAuthValidationError(
|
||||
{
|
||||
"error": "invalid_request",
|
||||
"error_description": "{} is not a supported provider".format(backend.name),
|
||||
}
|
||||
)
|
||||
|
||||
self.request.session[pipeline.AUTH_ENTRY_KEY] = pipeline.AUTH_ENTRY_API
|
||||
|
||||
client_id = self.cleaned_data["client_id"]
|
||||
try:
|
||||
client = Client.objects.get(client_id=client_id)
|
||||
except Client.DoesNotExist:
|
||||
raise OAuthValidationError(
|
||||
{
|
||||
"error": "invalid_client",
|
||||
"error_description": "{} is not a valid client_id".format(client_id),
|
||||
}
|
||||
)
|
||||
if client.client_type != provider.constants.PUBLIC:
|
||||
raise OAuthValidationError(
|
||||
{
|
||||
# invalid_client isn't really the right code, but this mirrors
|
||||
# https://github.com/edx/django-oauth2-provider/blob/edx/provider/oauth2/forms.py#L331
|
||||
"error": "invalid_client",
|
||||
"error_description": "{} is not a public client".format(client_id),
|
||||
}
|
||||
)
|
||||
self.cleaned_data["client"] = client
|
||||
|
||||
user = None
|
||||
try:
|
||||
user = backend.do_auth(self.cleaned_data.get("access_token"))
|
||||
except HTTPError:
|
||||
pass
|
||||
if user and isinstance(user, User):
|
||||
self.cleaned_data["user"] = user
|
||||
else:
|
||||
# Ensure user does not re-enter the pipeline
|
||||
self.request.social_strategy.clean_partial_pipeline()
|
||||
raise OAuthValidationError(
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "access_token is not valid",
|
||||
}
|
||||
)
|
||||
|
||||
return self.cleaned_data
|
||||
3
common/djangoapps/oauth_exchange/models.py
Normal file
3
common/djangoapps/oauth_exchange/models.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
A models.py is required to make this an app (until we move to Django 1.7)
|
||||
"""
|
||||
0
common/djangoapps/oauth_exchange/tests/__init__.py
Normal file
0
common/djangoapps/oauth_exchange/tests/__init__.py
Normal file
73
common/djangoapps/oauth_exchange/tests/test_forms.py
Normal file
73
common/djangoapps/oauth_exchange/tests/test_forms.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Tests for OAuth token exchange forms
|
||||
"""
|
||||
import unittest
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.sessions.middleware import SessionMiddleware
|
||||
from django.test import TestCase
|
||||
from django.test.client import RequestFactory
|
||||
import httpretty
|
||||
from provider import scope
|
||||
import social.apps.django_app.utils as social_utils
|
||||
|
||||
from oauth_exchange.forms import AccessTokenExchangeForm
|
||||
from oauth_exchange.tests.utils import (
|
||||
AccessTokenExchangeTestMixin,
|
||||
AccessTokenExchangeMixinFacebook,
|
||||
AccessTokenExchangeMixinGoogle
|
||||
)
|
||||
|
||||
|
||||
class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin):
|
||||
"""
|
||||
Mixin that defines test cases for AccessTokenExchangeForm
|
||||
"""
|
||||
def setUp(self):
|
||||
super(AccessTokenExchangeFormTest, self).setUp()
|
||||
self.request = RequestFactory().post("dummy_url")
|
||||
SessionMiddleware().process_request(self.request)
|
||||
self.request.social_strategy = social_utils.load_strategy(self.request, self.BACKEND)
|
||||
|
||||
def _assert_error(self, data, expected_error, expected_error_description):
|
||||
form = AccessTokenExchangeForm(request=self.request, data=data)
|
||||
self.assertEqual(
|
||||
form.errors,
|
||||
{"error": expected_error, "error_description": expected_error_description}
|
||||
)
|
||||
self.assertNotIn("partial_pipeline", self.request.session)
|
||||
|
||||
def _assert_success(self, data, expected_scopes):
|
||||
form = AccessTokenExchangeForm(request=self.request, data=data)
|
||||
self.assertTrue(form.is_valid())
|
||||
self.assertEqual(form.cleaned_data["user"], self.user)
|
||||
self.assertEqual(form.cleaned_data["client"], self.oauth_client)
|
||||
self.assertEqual(scope.to_names(form.cleaned_data["scope"]), expected_scopes)
|
||||
|
||||
|
||||
# This is necessary because cms does not implement third party auth
|
||||
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
|
||||
@httpretty.activate
|
||||
class AccessTokenExchangeFormTestFacebook(
|
||||
AccessTokenExchangeFormTest,
|
||||
AccessTokenExchangeMixinFacebook,
|
||||
TestCase
|
||||
):
|
||||
"""
|
||||
Tests for AccessTokenExchangeForm used with Facebook
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# This is necessary because cms does not implement third party auth
|
||||
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
|
||||
@httpretty.activate
|
||||
class AccessTokenExchangeFormTestGoogle(
|
||||
AccessTokenExchangeFormTest,
|
||||
AccessTokenExchangeMixinGoogle,
|
||||
TestCase
|
||||
):
|
||||
"""
|
||||
Tests for AccessTokenExchangeForm used with Google
|
||||
"""
|
||||
pass
|
||||
118
common/djangoapps/oauth_exchange/tests/test_views.py
Normal file
118
common/djangoapps/oauth_exchange/tests/test_views.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Tests for OAuth token exchange views
|
||||
"""
|
||||
from datetime import timedelta
|
||||
import json
|
||||
import mock
|
||||
import unittest
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.urlresolvers import reverse
|
||||
from django.test import TestCase
|
||||
import httpretty
|
||||
import provider.constants
|
||||
from provider import scope
|
||||
from provider.oauth2.models import AccessToken
|
||||
|
||||
from oauth_exchange.tests.utils import (
|
||||
AccessTokenExchangeTestMixin,
|
||||
AccessTokenExchangeMixinFacebook,
|
||||
AccessTokenExchangeMixinGoogle
|
||||
)
|
||||
|
||||
|
||||
class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
|
||||
"""
|
||||
Mixin that defines test cases for AccessTokenExchangeView
|
||||
"""
|
||||
def setUp(self):
|
||||
super(AccessTokenExchangeViewTest, self).setUp()
|
||||
self.url = reverse("exchange_access_token", kwargs={"backend": self.BACKEND})
|
||||
|
||||
def _assert_error(self, data, expected_error, expected_error_description):
|
||||
response = self.client.post(self.url, data)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(response["Content-Type"], "application/json")
|
||||
self.assertEqual(
|
||||
json.loads(response.content),
|
||||
{"error": expected_error, "error_description": expected_error_description}
|
||||
)
|
||||
self.assertNotIn("partial_pipeline", self.client.session)
|
||||
|
||||
def _assert_success(self, data, expected_scopes):
|
||||
response = self.client.post(self.url, data)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response["Content-Type"], "application/json")
|
||||
content = json.loads(response.content)
|
||||
self.assertEqual(set(content.keys()), {"access_token", "token_type", "expires_in", "scope"})
|
||||
self.assertEqual(content["token_type"], "Bearer")
|
||||
self.assertLessEqual(
|
||||
timedelta(seconds=int(content["expires_in"])),
|
||||
provider.constants.EXPIRE_DELTA_PUBLIC
|
||||
)
|
||||
self.assertEqual(content["scope"], " ".join(expected_scopes))
|
||||
token = AccessToken.objects.get(token=content["access_token"])
|
||||
self.assertEqual(token.user, self.user)
|
||||
self.assertEqual(token.client, self.oauth_client)
|
||||
self.assertEqual(scope.to_names(token.scope), expected_scopes)
|
||||
|
||||
def test_single_access_token(self):
|
||||
def extract_token(response):
|
||||
return json.loads(response.content)["access_token"]
|
||||
|
||||
self._setup_provider_response(success=True)
|
||||
for single_access_token in [True, False]:
|
||||
with mock.patch(
|
||||
"oauth_exchange.views.constants.SINGLE_ACCESS_TOKEN",
|
||||
single_access_token
|
||||
):
|
||||
first_response = self.client.post(self.url, self.data)
|
||||
second_response = self.client.post(self.url, self.data)
|
||||
self.assertEqual(
|
||||
extract_token(first_response) == extract_token(second_response),
|
||||
single_access_token
|
||||
)
|
||||
|
||||
def test_get_method(self):
|
||||
response = self.client.get(self.url, self.data)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
json.loads(response.content),
|
||||
{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Only POST requests allowed.",
|
||||
}
|
||||
)
|
||||
|
||||
def test_invalid_provider(self):
|
||||
url = reverse("exchange_access_token", kwargs={"backend": "invalid"})
|
||||
response = self.client.post(url, self.data)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
|
||||
# This is necessary because cms does not implement third party auth
|
||||
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
|
||||
@httpretty.activate
|
||||
class AccessTokenExchangeViewTestFacebook(
|
||||
AccessTokenExchangeViewTest,
|
||||
AccessTokenExchangeMixinFacebook,
|
||||
TestCase
|
||||
):
|
||||
"""
|
||||
Tests for AccessTokenExchangeView used with Facebook
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# This is necessary because cms does not implement third party auth
|
||||
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
|
||||
@httpretty.activate
|
||||
class AccessTokenExchangeViewTestGoogle(
|
||||
AccessTokenExchangeViewTest,
|
||||
AccessTokenExchangeMixinGoogle,
|
||||
TestCase
|
||||
):
|
||||
"""
|
||||
Tests for AccessTokenExchangeView used with Google
|
||||
"""
|
||||
pass
|
||||
127
common/djangoapps/oauth_exchange/tests/utils.py
Normal file
127
common/djangoapps/oauth_exchange/tests/utils.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Test utilities for OAuth access token exchange
|
||||
"""
|
||||
import json
|
||||
|
||||
import httpretty
|
||||
import provider.constants
|
||||
from provider.oauth2.models import Client
|
||||
from social.apps.django_app.default.models import UserSocialAuth
|
||||
|
||||
from student.tests.factories import UserFactory
|
||||
|
||||
|
||||
class AccessTokenExchangeTestMixin(object):
|
||||
"""
|
||||
A mixin to define test cases for access token exchange. The following
|
||||
methods must be implemented by subclasses:
|
||||
* _assert_error(data, expected_error, expected_error_description)
|
||||
* _assert_success(data, expected_scopes)
|
||||
"""
|
||||
def setUp(self):
|
||||
super(AccessTokenExchangeTestMixin, self).setUp()
|
||||
|
||||
self.client_id = "test_client_id"
|
||||
self.oauth_client = Client.objects.create(
|
||||
client_id=self.client_id,
|
||||
client_type=provider.constants.PUBLIC
|
||||
)
|
||||
self.social_uid = "test_social_uid"
|
||||
self.user = UserFactory()
|
||||
UserSocialAuth.objects.create(user=self.user, provider=self.BACKEND, uid=self.social_uid)
|
||||
self.access_token = "test_access_token"
|
||||
# Initialize to minimal data
|
||||
self.data = {
|
||||
"access_token": self.access_token,
|
||||
"client_id": self.client_id,
|
||||
}
|
||||
|
||||
def _setup_provider_response(self, success):
|
||||
"""
|
||||
Register a mock response for the third party user information endpoint;
|
||||
success indicates whether the response status code should be 200 or 400
|
||||
"""
|
||||
if success:
|
||||
status = 200
|
||||
body = json.dumps({self.UID_FIELD: self.social_uid})
|
||||
else:
|
||||
status = 400
|
||||
body = json.dumps({})
|
||||
httpretty.register_uri(
|
||||
httpretty.GET,
|
||||
self.USER_URL,
|
||||
body=body,
|
||||
status=status,
|
||||
content_type="application/json"
|
||||
)
|
||||
|
||||
def _assert_error(self, _data, _expected_error, _expected_error_description):
|
||||
"""
|
||||
Given request data, execute a test and check that the expected error
|
||||
was returned (along with any other appropriate assertions).
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _assert_success(self, data, expected_scopes):
|
||||
"""
|
||||
Given request data, execute a test and check that the expected scopes
|
||||
were returned (along with any other appropriate assertions).
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def test_minimal(self):
|
||||
self._setup_provider_response(success=True)
|
||||
self._assert_success(self.data, expected_scopes=[])
|
||||
|
||||
def test_scopes(self):
|
||||
self._setup_provider_response(success=True)
|
||||
self.data["scope"] = "profile email"
|
||||
self._assert_success(self.data, expected_scopes=["profile", "email"])
|
||||
|
||||
def test_missing_fields(self):
|
||||
for field in ["access_token", "client_id"]:
|
||||
data = dict(self.data)
|
||||
del data[field]
|
||||
self._assert_error(data, "invalid_request", "{} is required".format(field))
|
||||
|
||||
def test_invalid_client(self):
|
||||
self.data["client_id"] = "nonexistent_client"
|
||||
self._assert_error(
|
||||
self.data,
|
||||
"invalid_client",
|
||||
"nonexistent_client is not a valid client_id"
|
||||
)
|
||||
|
||||
def test_confidential_client(self):
|
||||
self.oauth_client.client_type = provider.constants.CONFIDENTIAL
|
||||
self.oauth_client.save()
|
||||
self._assert_error(
|
||||
self.data,
|
||||
"invalid_client",
|
||||
"test_client_id is not a public client"
|
||||
)
|
||||
|
||||
def test_invalid_acess_token(self):
|
||||
self._setup_provider_response(success=False)
|
||||
self._assert_error(self.data, "invalid_grant", "access_token is not valid")
|
||||
|
||||
def test_no_linked_user(self):
|
||||
UserSocialAuth.objects.all().delete()
|
||||
self._setup_provider_response(success=True)
|
||||
self._assert_error(self.data, "invalid_grant", "access_token is not valid")
|
||||
|
||||
|
||||
class AccessTokenExchangeMixinFacebook(object):
|
||||
"""Tests access token exchange with the Facebook backend"""
|
||||
BACKEND = "facebook"
|
||||
USER_URL = "https://graph.facebook.com/me"
|
||||
# In facebook responses, the "id" field is used as the user's identifier
|
||||
UID_FIELD = "id"
|
||||
|
||||
|
||||
class AccessTokenExchangeMixinGoogle(object):
|
||||
"""Tests access token exchange with the Google backend"""
|
||||
BACKEND = "google-oauth2"
|
||||
USER_URL = "https://www.googleapis.com/oauth2/v1/userinfo"
|
||||
# In google-oauth2 responses, the "email" field is used as the user's identifier
|
||||
UID_FIELD = "email"
|
||||
37
common/djangoapps/oauth_exchange/views.py
Normal file
37
common/djangoapps/oauth_exchange/views.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Views to support third-party to first-party OAuth 2.0 access token exchange
|
||||
"""
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from provider import constants
|
||||
from provider.oauth2.views import AccessTokenView as AccessTokenView
|
||||
import social.apps.django_app.utils as social_utils
|
||||
|
||||
from oauth_exchange.forms import AccessTokenExchangeForm
|
||||
|
||||
|
||||
class AccessTokenExchangeView(AccessTokenView):
|
||||
"""View for access token exchange"""
|
||||
@method_decorator(csrf_exempt)
|
||||
@method_decorator(social_utils.strategy("social:complete"))
|
||||
def dispatch(self, *args, **kwargs):
|
||||
return super(AccessTokenExchangeView, self).dispatch(*args, **kwargs)
|
||||
|
||||
def get(self, request, _backend):
|
||||
return super(AccessTokenExchangeView, self).get(request)
|
||||
|
||||
def post(self, request, _backend):
|
||||
form = AccessTokenExchangeForm(request=request, data=request.POST)
|
||||
if not form.is_valid():
|
||||
return self.error_response(form.errors)
|
||||
|
||||
user = form.cleaned_data["user"]
|
||||
scope = form.cleaned_data["scope"]
|
||||
client = form.cleaned_data["client"]
|
||||
|
||||
if constants.SINGLE_ACCESS_TOKEN:
|
||||
edx_access_token = self.get_access_token(request, user, scope, client)
|
||||
else:
|
||||
edx_access_token = self.create_access_token(request, user, scope, client)
|
||||
|
||||
return self.access_token_response(edx_access_token)
|
||||
@@ -1548,6 +1548,8 @@ INSTALLED_APPS = (
|
||||
'provider.oauth2',
|
||||
'oauth2_provider',
|
||||
|
||||
'oauth_exchange',
|
||||
|
||||
# For the wiki
|
||||
'wiki', # The new django-wiki from benjaoming
|
||||
'django_notify',
|
||||
|
||||
@@ -5,6 +5,7 @@ from django.conf.urls.static import static
|
||||
|
||||
import django.contrib.auth.views
|
||||
from microsite_configuration import microsite
|
||||
import oauth_exchange.views
|
||||
|
||||
# Uncomment the next two lines to enable the admin:
|
||||
if settings.DEBUG or settings.FEATURES.get('ENABLE_DJANGO_ADMIN_SITE'):
|
||||
@@ -585,6 +586,11 @@ if settings.FEATURES.get('AUTOMATIC_AUTH_FOR_TESTING'):
|
||||
if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'):
|
||||
urlpatterns += (
|
||||
url(r'', include('third_party_auth.urls')),
|
||||
url(
|
||||
r'^oauth2/exchange_access_token/(?P<backend>[^/]+)/$',
|
||||
oauth_exchange.views.AccessTokenExchangeView.as_view(),
|
||||
name="exchange_access_token"
|
||||
),
|
||||
url(r'^login_oauth_token/(?P<backend>[^/]+)/$', 'student.views.login_oauth_token'),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user