diff --git a/openedx/core/djangoapps/auth_exchange/tests/test_views.py b/openedx/core/djangoapps/auth_exchange/tests/test_views.py index 845614feaf..6448c7ffd9 100644 --- a/openedx/core/djangoapps/auth_exchange/tests/test_views.py +++ b/openedx/core/djangoapps/auth_exchange/tests/test_views.py @@ -19,6 +19,7 @@ from provider.oauth2.models import AccessToken, Client from rest_framework.test import APIClient from social_django.models import Partial +from openedx.core.djangoapps.oauth_dispatch.tests import factories as dot_factories from student.tests.factories import UserFactory from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle from .mixins import DOPAdapterMixin, DOTAdapterMixin @@ -185,15 +186,37 @@ class TestLoginWithAccessTokenView(TestCase): if expected_cookie_name: self.assertIn(expected_cookie_name, response.cookies) - def test_success(self): - access_token = AccessToken.objects.create( + def _create_dot_access_token(self, grant_type='Client credentials'): + """ + Create dot based access token + """ + dot_application = dot_factories.ApplicationFactory(user=self.user, authorization_grant_type=grant_type) + return dot_factories.AccessTokenFactory(user=self.user, application=dot_application) + + def _create_dop_access_token(self): + """ + Create dop based access token + """ + return AccessToken.objects.create( token="test_access_token", client=self.oauth2_client, user=self.user, ) + + def test_dop_unsupported(self): + access_token = self._create_dop_access_token() + self._verify_response(access_token, expected_status_code=401) + + def test_invalid_token(self): + self._verify_response("invalid_token", expected_status_code=401) + self.assertNotIn("session_key", self.client.session) + + def test_dot_password_grant_supported(self): + access_token = self._create_dot_access_token(grant_type='password') + self._verify_response(access_token, expected_status_code=204, expected_cookie_name='sessionid') self.assertEqual(int(self.client.session['_auth_user_id']), self.user.id) - def test_unauthenticated(self): - self._verify_response("invalid_token", expected_status_code=401) - self.assertNotIn("session_key", self.client.session) + def test_dot_client_credentials_unsupported(self): + access_token = self._create_dot_access_token() + self._verify_response(access_token, expected_status_code=401) diff --git a/openedx/core/djangoapps/auth_exchange/views.py b/openedx/core/djangoapps/auth_exchange/views.py index e9d6f58a1a..38bf6292fb 100644 --- a/openedx/core/djangoapps/auth_exchange/views.py +++ b/openedx/core/djangoapps/auth_exchange/views.py @@ -17,12 +17,14 @@ from django.http import HttpResponse from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt from edx_oauth2_provider.constants import SCOPE_VALUE_DICT +from oauth2_provider import models as dot_models from oauth2_provider.settings import oauth2_settings from oauth2_provider.views.base import TokenView as DOTAccessTokenView from oauthlib.oauth2.rfc6749.tokens import BearerToken from provider import constants from provider.oauth2.views import AccessTokenView as DOPAccessTokenView from rest_framework import permissions +from rest_framework.exceptions import AuthenticationFailed from rest_framework.response import Response from rest_framework.views import APIView @@ -161,6 +163,18 @@ class LoginWithAccessTokenView(APIView): if backend.get_user(user.id): return backend_path + @staticmethod + def _is_grant_password(access_token): + """ + Check if the access token provided is DOT based and has password type grant. + """ + token_query = dot_models.AccessToken.objects.select_related('user') + dot_token = token_query.filter(token=access_token).first() + if dot_token and dot_token.application.authorization_grant_type == dot_models.Application.GRANT_PASSWORD: + return True + + return False + @method_decorator(csrf_exempt) def post(self, request): """ @@ -171,7 +185,15 @@ class LoginWithAccessTokenView(APIView): # The login method assumes the backend path had been previously stored in request.user.backend # in the 'authenticate' call. However, not all authentication providers do so. # So we explicitly populate the request.user.backend field here. + if not hasattr(request.user, 'backend'): request.user.backend = self._get_path_of_arbitrary_backend_for_user(request.user) + + if not self._is_grant_password(request.auth): + raise AuthenticationFailed({ + u'error_code': u'non_supported_token', + u'developer_message': u'Only support DOT type access token with grant type password. ' + }) + login(request, request.user) # login generates and stores the user's cookies in the session return HttpResponse(status=204) # cookies stored in the session are returned with the response