diff --git a/common/djangoapps/oauth_exchange/__init__.py b/common/djangoapps/auth_exchange/__init__.py similarity index 100% rename from common/djangoapps/oauth_exchange/__init__.py rename to common/djangoapps/auth_exchange/__init__.py diff --git a/common/djangoapps/oauth_exchange/forms.py b/common/djangoapps/auth_exchange/forms.py similarity index 95% rename from common/djangoapps/oauth_exchange/forms.py rename to common/djangoapps/auth_exchange/forms.py index 772de6d156..2391486f99 100644 --- a/common/djangoapps/oauth_exchange/forms.py +++ b/common/djangoapps/auth_exchange/forms.py @@ -39,9 +39,15 @@ class AccessTokenExchangeForm(ScopeMixin, OAuthForm): return field_val def clean_access_token(self): + """ + Validates and returns the "access_token" field. + """ return self._require_oauth_field("access_token") def clean_client_id(self): + """ + Validates and returns the "client_id" field. + """ return self._require_oauth_field("client_id") def clean(self): diff --git a/common/djangoapps/oauth_exchange/models.py b/common/djangoapps/auth_exchange/models.py similarity index 100% rename from common/djangoapps/oauth_exchange/models.py rename to common/djangoapps/auth_exchange/models.py diff --git a/common/djangoapps/oauth_exchange/tests/__init__.py b/common/djangoapps/auth_exchange/tests/__init__.py similarity index 100% rename from common/djangoapps/oauth_exchange/tests/__init__.py rename to common/djangoapps/auth_exchange/tests/__init__.py diff --git a/common/djangoapps/oauth_exchange/tests/test_forms.py b/common/djangoapps/auth_exchange/tests/test_forms.py similarity index 94% rename from common/djangoapps/oauth_exchange/tests/test_forms.py rename to common/djangoapps/auth_exchange/tests/test_forms.py index 392935c6ab..c89fec9d5f 100644 --- a/common/djangoapps/oauth_exchange/tests/test_forms.py +++ b/common/djangoapps/auth_exchange/tests/test_forms.py @@ -1,3 +1,4 @@ +# pylint: disable=no-member """ Tests for OAuth token exchange forms """ @@ -11,8 +12,8 @@ import httpretty from provider import scope import social.apps.django_app.utils as social_utils -from oauth_exchange.forms import AccessTokenExchangeForm -from oauth_exchange.tests.utils import AccessTokenExchangeTestMixin +from auth_exchange.forms import AccessTokenExchangeForm +from auth_exchange.tests.utils import AccessTokenExchangeTestMixin from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle diff --git a/common/djangoapps/oauth_exchange/tests/test_views.py b/common/djangoapps/auth_exchange/tests/test_views.py similarity index 68% rename from common/djangoapps/oauth_exchange/tests/test_views.py rename to common/djangoapps/auth_exchange/tests/test_views.py index d58aa6edc4..7b7ca5a3a1 100644 --- a/common/djangoapps/oauth_exchange/tests/test_views.py +++ b/common/djangoapps/auth_exchange/tests/test_views.py @@ -1,3 +1,4 @@ +# pylint: disable=no-member """ Tests for OAuth token exchange views """ @@ -12,9 +13,10 @@ from django.test import TestCase import httpretty import provider.constants from provider import scope -from provider.oauth2.models import AccessToken +from provider.oauth2.models import AccessToken, Client -from oauth_exchange.tests.utils import AccessTokenExchangeTestMixin +from auth_exchange.tests.utils import AccessTokenExchangeTestMixin +from student.tests.factories import UserFactory from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle @@ -55,12 +57,15 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin): def test_single_access_token(self): def extract_token(response): + """ + Returns the access token from the response payload. + """ return json.loads(response.content)["access_token"] self._setup_provider_response(success=True) for single_access_token in [True, False]: with mock.patch( - "oauth_exchange.views.constants.SINGLE_ACCESS_TOKEN", + "auth_exchange.views.constants.SINGLE_ACCESS_TOKEN", single_access_token ): first_response = self.client.post(self.url, self.data) @@ -113,3 +118,38 @@ class AccessTokenExchangeViewTestGoogle( Tests for AccessTokenExchangeView used with Google """ pass + + +@unittest.skipUnless(settings.FEATURES.get("ENABLE_OAUTH2_PROVIDER"), "OAuth2 not enabled") +class TestLoginWithAccessTokenView(TestCase): + """ + Tests for LoginWithAccessTokenView + """ + def setUp(self): + super(TestLoginWithAccessTokenView, self).setUp() + self.user = UserFactory() + self.oauth2_client = Client.objects.create(client_type=provider.constants.CONFIDENTIAL) + + def _verify_response(self, access_token, expected_status_code, expected_num_cookies): + """ + Calls the login_with_access_token endpoint and verifies the response given the expected values. + """ + url = reverse("login_with_access_token") + response = self.client.post(url, HTTP_AUTHORIZATION="Bearer {0}".format(access_token)) + self.assertEqual(response.status_code, expected_status_code) + self.assertEqual(len(response.cookies), expected_num_cookies) + + def test_success(self): + access_token = AccessToken.objects.create( + token="test_access_token", + client=self.oauth2_client, + user=self.user, + ) + self._verify_response(access_token, expected_status_code=204, expected_num_cookies=1) + self.assertEqual(len(self.client.cookies), 1) + self.assertEqual(self.client.session['_auth_user_id'], self.user.id) + + def test_unauthenticated(self): + self._verify_response("invalid_token", expected_status_code=401, expected_num_cookies=0) + self.assertEqual(len(self.client.cookies), 0) + self.assertNotIn("session_key", self.client.session) diff --git a/common/djangoapps/oauth_exchange/tests/utils.py b/common/djangoapps/auth_exchange/tests/utils.py similarity index 98% rename from common/djangoapps/oauth_exchange/tests/utils.py rename to common/djangoapps/auth_exchange/tests/utils.py index 74cb02604b..60608f292f 100644 --- a/common/djangoapps/oauth_exchange/tests/utils.py +++ b/common/djangoapps/auth_exchange/tests/utils.py @@ -14,7 +14,7 @@ class AccessTokenExchangeTestMixin(ThirdPartyOAuthTestMixin): * _assert_error(data, expected_error, expected_error_description) * _assert_success(data, expected_scopes) """ - def setUp(self): + def setUp(self): # pylint: disable=arguments-differ super(AccessTokenExchangeTestMixin, self).setUp() # Initialize to minimal data diff --git a/common/djangoapps/auth_exchange/views.py b/common/djangoapps/auth_exchange/views.py new file mode 100644 index 0000000000..abd31de9ec --- /dev/null +++ b/common/djangoapps/auth_exchange/views.py @@ -0,0 +1,85 @@ +# pylint: disable=abstract-method +""" +Views to support exchange of authentication credentials. +The following are currently implemented: + 1. AccessTokenExchangeView: + 3rd party (social-auth) OAuth 2.0 access token -> 1st party (open-edx) OAuth 2.0 access token + 2. LoginWithAccessTokenView: + 1st party (open-edx) OAuth 2.0 access token -> session cookie +""" +from django.conf import settings +from django.contrib.auth import login +import django.contrib.auth as auth +from django.http import HttpResponse +from django.utils.decorators import method_decorator +from django.views.decorators.csrf import csrf_exempt +from provider import constants +from provider.oauth2.views import AccessTokenView as AccessTokenView +from rest_framework import permissions +from rest_framework.views import APIView +import social.apps.django_app.utils as social_utils + +from auth_exchange.forms import AccessTokenExchangeForm +from openedx.core.lib.api.authentication import OAuth2AuthenticationAllowInactiveUser + + +class AccessTokenExchangeView(AccessTokenView): + """ + View for token exchange from 3rd party OAuth access token to 1st party OAuth access token + """ + @method_decorator(csrf_exempt) + @method_decorator(social_utils.strategy("social:complete")) + def dispatch(self, *args, **kwargs): + return super(AccessTokenExchangeView, self).dispatch(*args, **kwargs) + + def get(self, request, _backend): # pylint: disable=arguments-differ + return super(AccessTokenExchangeView, self).get(request) + + def post(self, request, _backend): # pylint: disable=arguments-differ + form = AccessTokenExchangeForm(request=request, data=request.POST) + if not form.is_valid(): + return self.error_response(form.errors) + + user = form.cleaned_data["user"] + scope = form.cleaned_data["scope"] + client = form.cleaned_data["client"] + + if constants.SINGLE_ACCESS_TOKEN: + edx_access_token = self.get_access_token(request, user, scope, client) + else: + edx_access_token = self.create_access_token(request, user, scope, client) + + return self.access_token_response(edx_access_token) + + +class LoginWithAccessTokenView(APIView): + """ + View for exchanging an access token for session cookies + """ + authentication_classes = (OAuth2AuthenticationAllowInactiveUser,) + permission_classes = (permissions.IsAuthenticated,) + + @staticmethod + def _get_path_of_arbitrary_backend_for_user(user): + """ + Return the path to the first found authentication backend that recognizes the given user. + """ + for backend_path in settings.AUTHENTICATION_BACKENDS: + backend = auth.load_backend(backend_path) + if backend.get_user(user.id): + return backend_path + + @method_decorator(csrf_exempt) + def post(self, request): + """ + Handler for the POST method to this view. + """ + # The django login method stores the user's id in request.session[SESSION_KEY] and the + # path to the user's authentication backend in request.session[BACKEND_SESSION_KEY]. + # The login method assumes the backend path had been previously stored in request.user.backend + # in the 'authenticate' call. However, not all authentication providers do so. + # So we explicitly populate the request.user.backend field here. + if not hasattr(request.user, 'backend'): + request.user.backend = self._get_path_of_arbitrary_backend_for_user(request.user) + login(request, request.user) # login generates and stores the user's cookies in the session + return HttpResponse(status=204) # cookies stored in the session are returned with the response diff --git a/common/djangoapps/oauth_exchange/views.py b/common/djangoapps/oauth_exchange/views.py deleted file mode 100644 index aae6d2b12c..0000000000 --- a/common/djangoapps/oauth_exchange/views.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Views to support third-party to first-party OAuth 2.0 access token exchange -""" -from django.utils.decorators import method_decorator -from django.views.decorators.csrf import csrf_exempt -from provider import constants -from provider.oauth2.views import AccessTokenView as AccessTokenView -import social.apps.django_app.utils as social_utils - -from oauth_exchange.forms import AccessTokenExchangeForm - - -class AccessTokenExchangeView(AccessTokenView): - """View for access token exchange""" - @method_decorator(csrf_exempt) - @method_decorator(social_utils.strategy("social:complete")) - def dispatch(self, *args, **kwargs): - return super(AccessTokenExchangeView, self).dispatch(*args, **kwargs) - - def get(self, request, _backend): - return super(AccessTokenExchangeView, self).get(request) - - def post(self, request, _backend): - form = AccessTokenExchangeForm(request=request, data=request.POST) - if not form.is_valid(): - return self.error_response(form.errors) - - user = form.cleaned_data["user"] - scope = form.cleaned_data["scope"] - client = form.cleaned_data["client"] - - if constants.SINGLE_ACCESS_TOKEN: - edx_access_token = self.get_access_token(request, user, scope, client) - else: - edx_access_token = self.create_access_token(request, user, scope, client) - - return self.access_token_response(edx_access_token) diff --git a/lms/envs/common.py b/lms/envs/common.py index d43aa24411..01492633f0 100644 --- a/lms/envs/common.py +++ b/lms/envs/common.py @@ -1665,7 +1665,7 @@ INSTALLED_APPS = ( 'provider.oauth2', 'oauth2_provider', - 'oauth_exchange', + 'auth_exchange', # For the wiki 'wiki', # The new django-wiki from benjaoming diff --git a/lms/urls.py b/lms/urls.py index 99120f1834..4e026d30cc 100644 --- a/lms/urls.py +++ b/lms/urls.py @@ -5,7 +5,7 @@ from django.conf.urls.static import static import django.contrib.auth.views from microsite_configuration import microsite -import oauth_exchange.views +import auth_exchange.views # Uncomment the next two lines to enable the admin: if settings.DEBUG or settings.FEATURES.get('ENABLE_DJANGO_ADMIN_SITE'): @@ -611,12 +611,20 @@ if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'): ) # OAuth token exchange -if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH') and settings.FEATURES.get('ENABLE_OAUTH2_PROVIDER'): +if settings.FEATURES.get('ENABLE_OAUTH2_PROVIDER'): + if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'): + urlpatterns += ( + url( + r'^oauth2/exchange_access_token/(?P[^/]+)/$', + auth_exchange.views.AccessTokenExchangeView.as_view(), + name="exchange_access_token" + ), + ) urlpatterns += ( url( - r'^oauth2/exchange_access_token/(?P[^/]+)/$', - oauth_exchange.views.AccessTokenExchangeView.as_view(), - name="exchange_access_token" + r'^oauth2/login/$', + auth_exchange.views.LoginWithAccessTokenView.as_view(), + name="login_with_access_token" ), )