diff --git a/cms/envs/test.py b/cms/envs/test.py index 4ea673698f..d16355de26 100644 --- a/cms/envs/test.py +++ b/cms/envs/test.py @@ -26,7 +26,17 @@ from uuid import uuid4 # import settings from LMS for consistent behavior with CMS # pylint: disable=unused-import -from lms.envs.test import (WIKI_ENABLED, PLATFORM_NAME, SITE_NAME, DEFAULT_FILE_STORAGE, MEDIA_ROOT, MEDIA_URL) +from lms.envs.test import ( + WIKI_ENABLED, + PLATFORM_NAME, + SITE_NAME, + DEFAULT_FILE_STORAGE, + MEDIA_ROOT, + MEDIA_URL, + # This is practically unused but needed by the oauth2_provider package, which + # some tests in common/ rely on. + OAUTH_OIDC_ISSUER, +) # mongo connection settings MONGO_PORT_NUM = int(os.environ.get('EDXAPP_TEST_MONGO_PORT', '27017')) diff --git a/common/djangoapps/oauth_exchange/__init__.py b/common/djangoapps/oauth_exchange/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/common/djangoapps/oauth_exchange/forms.py b/common/djangoapps/oauth_exchange/forms.py new file mode 100644 index 0000000000..36051684e8 --- /dev/null +++ b/common/djangoapps/oauth_exchange/forms.py @@ -0,0 +1,100 @@ +""" +Forms to support third-party to first-party OAuth 2.0 access token exchange +""" +from django.contrib.auth.models import User +from django.forms import CharField +from oauth2_provider.constants import SCOPE_NAMES +import provider.constants +from provider.forms import OAuthForm, OAuthValidationError +from provider.oauth2.forms import ScopeChoiceField, ScopeMixin +from provider.oauth2.models import Client +from requests import HTTPError +from social.backends import oauth as social_oauth + +from third_party_auth import pipeline + + +class AccessTokenExchangeForm(ScopeMixin, OAuthForm): + """Form for access token exchange endpoint""" + access_token = CharField(required=False) + scope = ScopeChoiceField(choices=SCOPE_NAMES, required=False) + client_id = CharField(required=False) + + def __init__(self, request, *args, **kwargs): + super(AccessTokenExchangeForm, self).__init__(*args, **kwargs) + self.request = request + + def _require_oauth_field(self, field_name): + """ + Raise an appropriate OAuthValidationError error if the field is missing + """ + field_val = self.cleaned_data.get(field_name) + if not field_val: + raise OAuthValidationError( + { + "error": "invalid_request", + "error_description": "{} is required".format(field_name), + } + ) + return field_val + + def clean_access_token(self): + return self._require_oauth_field("access_token") + + def clean_client_id(self): + return self._require_oauth_field("client_id") + + def clean(self): + if self._errors: + return {} + + backend = self.request.social_strategy.backend + if not isinstance(backend, social_oauth.BaseOAuth2): + raise OAuthValidationError( + { + "error": "invalid_request", + "error_description": "{} is not a supported provider".format(backend.name), + } + ) + + self.request.session[pipeline.AUTH_ENTRY_KEY] = pipeline.AUTH_ENTRY_API + + client_id = self.cleaned_data["client_id"] + try: + client = Client.objects.get(client_id=client_id) + except Client.DoesNotExist: + raise OAuthValidationError( + { + "error": "invalid_client", + "error_description": "{} is not a valid client_id".format(client_id), + } + ) + if client.client_type != provider.constants.PUBLIC: + raise OAuthValidationError( + { + # invalid_client isn't really the right code, but this mirrors + # https://github.com/edx/django-oauth2-provider/blob/edx/provider/oauth2/forms.py#L331 + "error": "invalid_client", + "error_description": "{} is not a public client".format(client_id), + } + ) + self.cleaned_data["client"] = client + + user = None + try: + user = backend.do_auth(self.cleaned_data.get("access_token")) + except HTTPError: + pass + if user and isinstance(user, User): + self.cleaned_data["user"] = user + else: + # Ensure user does not re-enter the pipeline + self.request.social_strategy.clean_partial_pipeline() + raise OAuthValidationError( + { + "error": "invalid_grant", + "error_description": "access_token is not valid", + } + ) + + return self.cleaned_data diff --git a/common/djangoapps/oauth_exchange/models.py b/common/djangoapps/oauth_exchange/models.py new file mode 100644 index 0000000000..d2e8572729 --- /dev/null +++ b/common/djangoapps/oauth_exchange/models.py @@ -0,0 +1,3 @@ +""" +A models.py is required to make this an app (until we move to Django 1.7) +""" diff --git a/common/djangoapps/oauth_exchange/tests/__init__.py b/common/djangoapps/oauth_exchange/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/common/djangoapps/oauth_exchange/tests/test_forms.py b/common/djangoapps/oauth_exchange/tests/test_forms.py new file mode 100644 index 0000000000..8bf3b6d427 --- /dev/null +++ b/common/djangoapps/oauth_exchange/tests/test_forms.py @@ -0,0 +1,73 @@ +""" +Tests for OAuth token exchange forms +""" +import unittest + +from django.conf import settings +from django.contrib.sessions.middleware import SessionMiddleware +from django.test import TestCase +from django.test.client import RequestFactory +import httpretty +from provider import scope +import social.apps.django_app.utils as social_utils + +from oauth_exchange.forms import AccessTokenExchangeForm +from oauth_exchange.tests.utils import ( + AccessTokenExchangeTestMixin, + AccessTokenExchangeMixinFacebook, + AccessTokenExchangeMixinGoogle +) + + +class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin): + """ + Mixin that defines test cases for AccessTokenExchangeForm + """ + def setUp(self): + super(AccessTokenExchangeFormTest, self).setUp() + self.request = RequestFactory().post("dummy_url") + SessionMiddleware().process_request(self.request) + self.request.social_strategy = social_utils.load_strategy(self.request, self.BACKEND) + + def _assert_error(self, data, expected_error, expected_error_description): + form = AccessTokenExchangeForm(request=self.request, data=data) + self.assertEqual( + form.errors, + {"error": expected_error, "error_description": expected_error_description} + ) + self.assertNotIn("partial_pipeline", self.request.session) + + def _assert_success(self, data, expected_scopes): + form = AccessTokenExchangeForm(request=self.request, data=data) + self.assertTrue(form.is_valid()) + self.assertEqual(form.cleaned_data["user"], self.user) + self.assertEqual(form.cleaned_data["client"], self.oauth_client) + self.assertEqual(scope.to_names(form.cleaned_data["scope"]), expected_scopes) + + +# This is necessary because cms does not implement third party auth +@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") +@httpretty.activate +class AccessTokenExchangeFormTestFacebook( + AccessTokenExchangeFormTest, + AccessTokenExchangeMixinFacebook, + TestCase +): + """ + Tests for AccessTokenExchangeForm used with Facebook + """ + pass + + +# This is necessary because cms does not implement third party auth +@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") +@httpretty.activate +class AccessTokenExchangeFormTestGoogle( + AccessTokenExchangeFormTest, + AccessTokenExchangeMixinGoogle, + TestCase +): + """ + Tests for AccessTokenExchangeForm used with Google + """ + pass diff --git a/common/djangoapps/oauth_exchange/tests/test_views.py b/common/djangoapps/oauth_exchange/tests/test_views.py new file mode 100644 index 0000000000..14e49c1a8b --- /dev/null +++ b/common/djangoapps/oauth_exchange/tests/test_views.py @@ -0,0 +1,118 @@ +""" +Tests for OAuth token exchange views +""" +from datetime import timedelta +import json +import mock +import unittest + +from django.conf import settings +from django.core.urlresolvers import reverse +from django.test import TestCase +import httpretty +import provider.constants +from provider import scope +from provider.oauth2.models import AccessToken + +from oauth_exchange.tests.utils import ( + AccessTokenExchangeTestMixin, + AccessTokenExchangeMixinFacebook, + AccessTokenExchangeMixinGoogle +) + + +class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin): + """ + Mixin that defines test cases for AccessTokenExchangeView + """ + def setUp(self): + super(AccessTokenExchangeViewTest, self).setUp() + self.url = reverse("exchange_access_token", kwargs={"backend": self.BACKEND}) + + def _assert_error(self, data, expected_error, expected_error_description): + response = self.client.post(self.url, data) + self.assertEqual(response.status_code, 400) + self.assertEqual(response["Content-Type"], "application/json") + self.assertEqual( + json.loads(response.content), + {"error": expected_error, "error_description": expected_error_description} + ) + self.assertNotIn("partial_pipeline", self.client.session) + + def _assert_success(self, data, expected_scopes): + response = self.client.post(self.url, data) + self.assertEqual(response.status_code, 200) + self.assertEqual(response["Content-Type"], "application/json") + content = json.loads(response.content) + self.assertEqual(set(content.keys()), {"access_token", "token_type", "expires_in", "scope"}) + self.assertEqual(content["token_type"], "Bearer") + self.assertLessEqual( + timedelta(seconds=int(content["expires_in"])), + provider.constants.EXPIRE_DELTA_PUBLIC + ) + self.assertEqual(content["scope"], " ".join(expected_scopes)) + token = AccessToken.objects.get(token=content["access_token"]) + self.assertEqual(token.user, self.user) + self.assertEqual(token.client, self.oauth_client) + self.assertEqual(scope.to_names(token.scope), expected_scopes) + + def test_single_access_token(self): + def extract_token(response): + return json.loads(response.content)["access_token"] + + self._setup_provider_response(success=True) + for single_access_token in [True, False]: + with mock.patch( + "oauth_exchange.views.constants.SINGLE_ACCESS_TOKEN", + single_access_token + ): + first_response = self.client.post(self.url, self.data) + second_response = self.client.post(self.url, self.data) + self.assertEqual( + extract_token(first_response) == extract_token(second_response), + single_access_token + ) + + def test_get_method(self): + response = self.client.get(self.url, self.data) + self.assertEqual(response.status_code, 400) + self.assertEqual( + json.loads(response.content), + { + "error": "invalid_request", + "error_description": "Only POST requests allowed.", + } + ) + + def test_invalid_provider(self): + url = reverse("exchange_access_token", kwargs={"backend": "invalid"}) + response = self.client.post(url, self.data) + self.assertEqual(response.status_code, 404) + + +# This is necessary because cms does not implement third party auth +@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") +@httpretty.activate +class AccessTokenExchangeViewTestFacebook( + AccessTokenExchangeViewTest, + AccessTokenExchangeMixinFacebook, + TestCase +): + """ + Tests for AccessTokenExchangeView used with Facebook + """ + pass + + +# This is necessary because cms does not implement third party auth +@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") +@httpretty.activate +class AccessTokenExchangeViewTestGoogle( + AccessTokenExchangeViewTest, + AccessTokenExchangeMixinGoogle, + TestCase +): + """ + Tests for AccessTokenExchangeView used with Google + """ + pass diff --git a/common/djangoapps/oauth_exchange/tests/utils.py b/common/djangoapps/oauth_exchange/tests/utils.py new file mode 100644 index 0000000000..3ac278e956 --- /dev/null +++ b/common/djangoapps/oauth_exchange/tests/utils.py @@ -0,0 +1,127 @@ +""" +Test utilities for OAuth access token exchange +""" +import json + +import httpretty +import provider.constants +from provider.oauth2.models import Client +from social.apps.django_app.default.models import UserSocialAuth + +from student.tests.factories import UserFactory + + +class AccessTokenExchangeTestMixin(object): + """ + A mixin to define test cases for access token exchange. The following + methods must be implemented by subclasses: + * _assert_error(data, expected_error, expected_error_description) + * _assert_success(data, expected_scopes) + """ + def setUp(self): + super(AccessTokenExchangeTestMixin, self).setUp() + + self.client_id = "test_client_id" + self.oauth_client = Client.objects.create( + client_id=self.client_id, + client_type=provider.constants.PUBLIC + ) + self.social_uid = "test_social_uid" + self.user = UserFactory() + UserSocialAuth.objects.create(user=self.user, provider=self.BACKEND, uid=self.social_uid) + self.access_token = "test_access_token" + # Initialize to minimal data + self.data = { + "access_token": self.access_token, + "client_id": self.client_id, + } + + def _setup_provider_response(self, success): + """ + Register a mock response for the third party user information endpoint; + success indicates whether the response status code should be 200 or 400 + """ + if success: + status = 200 + body = json.dumps({self.UID_FIELD: self.social_uid}) + else: + status = 400 + body = json.dumps({}) + httpretty.register_uri( + httpretty.GET, + self.USER_URL, + body=body, + status=status, + content_type="application/json" + ) + + def _assert_error(self, _data, _expected_error, _expected_error_description): + """ + Given request data, execute a test and check that the expected error + was returned (along with any other appropriate assertions). + """ + raise NotImplementedError() + + def _assert_success(self, data, expected_scopes): + """ + Given request data, execute a test and check that the expected scopes + were returned (along with any other appropriate assertions). + """ + raise NotImplementedError() + + def test_minimal(self): + self._setup_provider_response(success=True) + self._assert_success(self.data, expected_scopes=[]) + + def test_scopes(self): + self._setup_provider_response(success=True) + self.data["scope"] = "profile email" + self._assert_success(self.data, expected_scopes=["profile", "email"]) + + def test_missing_fields(self): + for field in ["access_token", "client_id"]: + data = dict(self.data) + del data[field] + self._assert_error(data, "invalid_request", "{} is required".format(field)) + + def test_invalid_client(self): + self.data["client_id"] = "nonexistent_client" + self._assert_error( + self.data, + "invalid_client", + "nonexistent_client is not a valid client_id" + ) + + def test_confidential_client(self): + self.oauth_client.client_type = provider.constants.CONFIDENTIAL + self.oauth_client.save() + self._assert_error( + self.data, + "invalid_client", + "test_client_id is not a public client" + ) + + def test_invalid_acess_token(self): + self._setup_provider_response(success=False) + self._assert_error(self.data, "invalid_grant", "access_token is not valid") + + def test_no_linked_user(self): + UserSocialAuth.objects.all().delete() + self._setup_provider_response(success=True) + self._assert_error(self.data, "invalid_grant", "access_token is not valid") + + +class AccessTokenExchangeMixinFacebook(object): + """Tests access token exchange with the Facebook backend""" + BACKEND = "facebook" + USER_URL = "https://graph.facebook.com/me" + # In facebook responses, the "id" field is used as the user's identifier + UID_FIELD = "id" + + +class AccessTokenExchangeMixinGoogle(object): + """Tests access token exchange with the Google backend""" + BACKEND = "google-oauth2" + USER_URL = "https://www.googleapis.com/oauth2/v1/userinfo" + # In google-oauth2 responses, the "email" field is used as the user's identifier + UID_FIELD = "email" diff --git a/common/djangoapps/oauth_exchange/views.py b/common/djangoapps/oauth_exchange/views.py new file mode 100644 index 0000000000..aae6d2b12c --- /dev/null +++ b/common/djangoapps/oauth_exchange/views.py @@ -0,0 +1,37 @@ +""" +Views to support third-party to first-party OAuth 2.0 access token exchange +""" +from django.utils.decorators import method_decorator +from django.views.decorators.csrf import csrf_exempt +from provider import constants +from provider.oauth2.views import AccessTokenView as AccessTokenView +import social.apps.django_app.utils as social_utils + +from oauth_exchange.forms import AccessTokenExchangeForm + + +class AccessTokenExchangeView(AccessTokenView): + """View for access token exchange""" + @method_decorator(csrf_exempt) + @method_decorator(social_utils.strategy("social:complete")) + def dispatch(self, *args, **kwargs): + return super(AccessTokenExchangeView, self).dispatch(*args, **kwargs) + + def get(self, request, _backend): + return super(AccessTokenExchangeView, self).get(request) + + def post(self, request, _backend): + form = AccessTokenExchangeForm(request=request, data=request.POST) + if not form.is_valid(): + return self.error_response(form.errors) + + user = form.cleaned_data["user"] + scope = form.cleaned_data["scope"] + client = form.cleaned_data["client"] + + if constants.SINGLE_ACCESS_TOKEN: + edx_access_token = self.get_access_token(request, user, scope, client) + else: + edx_access_token = self.create_access_token(request, user, scope, client) + + return self.access_token_response(edx_access_token) diff --git a/lms/envs/common.py b/lms/envs/common.py index c48ea4724e..20412451ca 100644 --- a/lms/envs/common.py +++ b/lms/envs/common.py @@ -1548,6 +1548,8 @@ INSTALLED_APPS = ( 'provider.oauth2', 'oauth2_provider', + 'oauth_exchange', + # For the wiki 'wiki', # The new django-wiki from benjaoming 'django_notify', diff --git a/lms/urls.py b/lms/urls.py index a7948e3212..1004f78b81 100644 --- a/lms/urls.py +++ b/lms/urls.py @@ -5,6 +5,7 @@ from django.conf.urls.static import static import django.contrib.auth.views from microsite_configuration import microsite +import oauth_exchange.views # Uncomment the next two lines to enable the admin: if settings.DEBUG or settings.FEATURES.get('ENABLE_DJANGO_ADMIN_SITE'): @@ -585,6 +586,11 @@ if settings.FEATURES.get('AUTOMATIC_AUTH_FOR_TESTING'): if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'): urlpatterns += ( url(r'', include('third_party_auth.urls')), + url( + r'^oauth2/exchange_access_token/(?P[^/]+)/$', + oauth_exchange.views.AccessTokenExchangeView.as_view(), + name="exchange_access_token" + ), url(r'^login_oauth_token/(?P[^/]+)/$', 'student.views.login_oauth_token'), )