changed oauth2 authentication class in bookmarks app (#22908)

* Added new authentication class(meets drf standards)

The new class replaces the deprecated oauth2authetnication class from rest_framework_auth library(repo django-rest-framework-oauth).
Majority of the code is combination of copy-pasta from old oauth2authentication class and Oauth2AuthenticationAllowInactiveUser class

* Added ability to switch to new authentication class in bookmarks app

* Changed error type reported by Outh class. It now outputs a json rather than a string.
This commit is contained in:
Manjinder Singh
2020-02-04 08:49:26 -05:00
committed by GitHub
parent f09e9fdc57
commit e0981025b2
3 changed files with 189 additions and 28 deletions

View File

@@ -22,8 +22,8 @@ from rest_framework.authentication import SessionAuthentication
from rest_framework.generics import ListCreateAPIView
from rest_framework.response import Response
from rest_framework.views import APIView
from openedx.core.lib.api.authentication import OAuth2AuthenticationDeprecated
from openedx.core.lib.api.authentication import OAuth2AuthenticationDeprecated, OAuth2Authentication
from openedx.core.djangoapps.bookmarks.api import BookmarksLimitReachedError
from openedx.core.lib.api.permissions import IsUserInUrl
from openedx.core.lib.url_utils import unquote_slashes
@@ -34,6 +34,22 @@ from .serializers import BookmarkSerializer
log = logging.getLogger(__name__)
# .. toggle_name: BOOKMARKS_USE_NEW_OAUTH2_CLASS
# .. toggle_implementation: DjangoSetting
# .. toggle_default: False
# .. toggle_description: Toggle for replacing OAuth2AuthenticationDeprecated with OAuth2Authentication for bookmarks.
# .. toggle_category: n/a
# .. toggle_use_cases: Monitored Rollout
# .. toggle_creation_date: 2020-01-31
# .. toggle_expiration_date: 2020-02-28
# .. toggle_warnings: None
# .. toggle_tickets: BOM-1037
# .. toggle_status: supported
if getattr(settings, "BOOKMARKS_USE_NEW_OAUTH2_CLASS", False):
_bookmarks_configured_authentication_classes = (OAuth2Authentication, SessionAuthentication)
else:
_bookmarks_configured_authentication_classes = (OAuth2AuthenticationDeprecated, SessionAuthentication)
# Default error message for user
DEFAULT_USER_MESSAGE = ugettext_noop(u'An error has occurred. Please try again.')
@@ -99,7 +115,7 @@ class BookmarksViewMixin(object):
class BookmarksListView(ListCreateAPIView, BookmarksViewMixin):
"""REST endpoints for lists of bookmarks."""
authentication_classes = (OAuth2AuthenticationDeprecated, SessionAuthentication)
authentication_classes = _bookmarks_configured_authentication_classes
pagination_class = BookmarksPagination
permission_classes = (permissions.IsAuthenticated,)
serializer_class = BookmarkSerializer
@@ -290,7 +306,8 @@ class BookmarksDetailView(APIView, BookmarksViewMixin):
to a requesting user's bookmark a 404 is returned. 404 will also be returned
if the bookmark does not exist.
"""
authentication_classes = (OAuth2AuthenticationDeprecated, SessionAuthentication)
authentication_classes = _bookmarks_configured_authentication_classes
permission_classes = (permissions.IsAuthenticated, IsUserInUrl)
serializer_class = BookmarkSerializer

View File

@@ -7,21 +7,22 @@ import django.utils.timezone
from oauth2_provider import models as dot_models
from provider.oauth2 import models as dop_models
from rest_framework.exceptions import AuthenticationFailed
from rest_framework_oauth.authentication import OAuth2Authentication
from rest_framework_oauth.authentication import OAuth2Authentication as OAuth2AuthenticationDeprecatedBase
from rest_framework.authentication import BaseAuthentication, get_authorization_header
from edx_django_utils.monitoring import set_custom_metric
OAUTH2_TOKEN_ERROR = u'token_error'
OAUTH2_TOKEN_ERROR_EXPIRED = u'token_expired'
OAUTH2_TOKEN_ERROR_MALFORMED = u'token_malformed'
OAUTH2_TOKEN_ERROR_NONEXISTENT = u'token_nonexistent'
OAUTH2_TOKEN_ERROR_NOT_PROVIDED = u'token_not_provided'
OAUTH2_TOKEN_ERROR = 'token_error'
OAUTH2_TOKEN_ERROR_EXPIRED = 'token_expired'
OAUTH2_TOKEN_ERROR_MALFORMED = 'token_malformed'
OAUTH2_TOKEN_ERROR_NONEXISTENT = 'token_nonexistent'
OAUTH2_TOKEN_ERROR_NOT_PROVIDED = 'token_not_provided'
OAUTH2_USER_NOT_ACTIVE_ERROR = 'user_not_active'
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class OAuth2AuthenticationDeprecated(OAuth2Authentication):
class OAuth2AuthenticationDeprecated(OAuth2AuthenticationDeprecatedBase):
"""
This child class was added to add new_relic metrics to OAuth2Authentication. This should be very temporary.
"""
@@ -125,3 +126,128 @@ class OAuth2AuthenticationAllowInactiveUser(OAuth2AuthenticationDeprecated):
"""
token_query = dot_models.AccessToken.objects.select_related('user')
return token_query.filter(token=access_token).first()
class OAuth2Authentication(BaseAuthentication):
"""
OAuth 2 authentication backend using either `django-oauth2-provider` or 'django-oauth-toolkit'
"""
www_authenticate_realm = 'api'
def authenticate(self, request):
"""
Returns tuple (user, token) if access token authentication succeeds,
returns None if the user did not try to authenticate using an access
token, or raises an AuthenticationFailed (HTTP 401) if authentication
fails.
"""
set_custom_metric("OAuth2Authentication", "Failed") # default value
auth = get_authorization_header(request).split()
if len(auth) == 1:
raise AuthenticationFailed({
'error_code': OAUTH2_TOKEN_ERROR_NOT_PROVIDED,
'developer_message': 'Invalid token header. No credentials provided.'})
elif len(auth) > 2:
raise AuthenticationFailed({
'error_code': OAUTH2_TOKEN_ERROR_MALFORMED,
'developer_message': 'Invalid token header. Token string should not contain spaces.'})
if auth and auth[0].lower() == b'bearer':
access_token = auth[1].decode('utf8')
set_custom_metric('OAuth2Authentication_token_location', 'bearer-in-header')
elif 'access_token' in request.POST:
access_token = request.POST['access_token']
set_custom_metric('OAuth2Authentication_token_location', 'post-token')
else:
set_custom_metric("OAuth2Authentication", "None")
return None
user, token = self.authenticate_credentials(access_token)
set_custom_metric("OAuth2Authentication", "Success")
return user, token
def authenticate_credentials(self, access_token):
"""
Authenticate the request, given the access token.
Overrides base class implementation to discard failure if user is
inactive.
"""
try:
token = self.get_access_token(access_token)
except AuthenticationFailed as exc:
raise AuthenticationFailed({
u'error_code': OAUTH2_TOKEN_ERROR,
u'developer_message': exc.detail
})
if not token:
raise AuthenticationFailed({
'error_code': OAUTH2_TOKEN_ERROR_NONEXISTENT,
'developer_message': 'The provided access token does not match any valid tokens.'
})
elif token.expires < django.utils.timezone.now():
raise AuthenticationFailed({
'error_code': OAUTH2_TOKEN_ERROR_EXPIRED,
'developer_message': 'The provided access token has expired and is no longer valid.',
})
else:
user = token.user
# Check to make sure the users have activated their account(by confirming their email)
if not user.is_active:
set_custom_metric("OAuth2Authentication_user_active", False)
msg = 'User inactive or deleted: %s' % user.get_username()
raise AuthenticationFailed({
'error_code': OAUTH2_USER_NOT_ACTIVE_ERROR,
'developer_message': msg})
else:
set_custom_metric("OAuth2Authentication_user_active", True)
return user, token
def get_access_token(self, access_token):
"""
Return a valid access token that exists in one of our OAuth2 libraries,
or None if no matching token is found.
"""
dot_token_return = self._get_dot_token(access_token)
if dot_token_return is not None:
set_custom_metric('OAuth2Authentication_token_type', 'dot')
return dot_token_return
dop_token_return = self._get_dop_token(access_token)
if dop_token_return is not None:
set_custom_metric('OAuth2Authentication_token_type', 'dop')
return dop_token_return
set_custom_metric('OAuth2Authentication_token_type', 'None')
return None
def _get_dop_token(self, access_token):
"""
Return a valid access token stored by django-oauth2-provider (DOP), or
None if no matching token is found.
"""
token_query = dop_models.AccessToken.objects.select_related('user')
return token_query.filter(token=access_token).first()
def _get_dot_token(self, access_token):
"""
Return a valid access token stored by django-oauth-toolkit (DOT), or
None if no matching token is found.
"""
token_query = dot_models.AccessToken.objects.select_related('user')
return token_query.filter(token=access_token).first()
def authenticate_header(self, request):
"""
Return a string to be used as the value of the `WWW-Authenticate`
header in a `401 Unauthenticated` response
"""
return 'Bearer realm="%s"' % self.www_authenticate_realm

View File

@@ -57,9 +57,14 @@ class OAuth2AuthenticationDebug(authentication.OAuth2AuthenticationAllowInactive
urlpatterns = [
url(r'^oauth2/', include(('provider.oauth2.urls', 'oauth2'), namespace='oauth2')),
url(
r'^oauth2-test/$',
r'^oauth2-deprecated-test/$',
MockView.as_view(authentication_classes=[authentication.OAuth2AuthenticationAllowInactiveUser])
),
url(
r'^oauth2-test/$',
MockView.as_view(authentication_classes=[authentication.OAuth2Authentication])
),
# TODO(jinder): remove url when OAuth2AuthenticationDeprecated is fully removed
url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
url(
r'^oauth2-with-scope-test/$',
@@ -71,14 +76,15 @@ urlpatterns = [
]
@ddt.ddt
@ddt.ddt # pylint: disable=missing-docstring
@unittest.skipUnless(settings.FEATURES.get("ENABLE_OAUTH2_PROVIDER"), "OAuth2 not enabled")
@override_settings(ROOT_URLCONF=__name__)
class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication"""
class OAuth2AllowInActiveUsersTests(TestCase):
OAUTH2_BASE_TESTING_URL = '/oauth2-deprecated-test/'
def setUp(self):
super(OAuth2Tests, self).setUp()
super(OAuth2AllowInActiveUsersTests, self).setUp()
self.dop_adapter = adapters.DOPAdapter()
self.dot_adapter = adapters.DOTAdapter()
self.csrf_client = APIClient(enforce_csrf_checks=True)
@@ -172,7 +178,7 @@ class OAuth2Tests(TestCase):
def test_get_form_with_wrong_authorization_header_token_type_failing(self, params):
"""Ensure that a wrong token type lead to the correct HTTP error status code"""
response = self.csrf_client.get(
'/oauth2-test/',
self.OAUTH2_BASE_TESTING_URL,
params,
HTTP_AUTHORIZATION='Wrong token-type-obviously'
)
@@ -187,22 +193,23 @@ class OAuth2Tests(TestCase):
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_passing_auth(self):
"""Ensure GETing form over OAuth with correct client credentials succeed"""
response = self.get_with_bearer_token('/oauth2-test/')
response = self.get_with_bearer_token(self.OAUTH2_BASE_TESTING_URL)
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_get_form_passing_auth_with_dot(self):
response = self.get_with_bearer_token('/oauth2-test/', token=self.dot_access_token.token)
response = self.get_with_bearer_token(self.OAUTH2_BASE_TESTING_URL, token=self.dot_access_token.token)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in form data succeed"""
response = self.csrf_client.post(
'/oauth2-test/',
self.OAUTH2_BASE_TESTING_URL,
data={'access_token': self.access_token.token}
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# TODO(jinder): remove test when OAuth2AuthenticationDeprecated is fully removed
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_passing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True"""
@@ -214,7 +221,7 @@ class OAuth2Tests(TestCase):
def test_get_form_failing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False"""
query = urlencode({'access_token': self.access_token.token})
response = self.csrf_client.get('/oauth2-test/?%s' % query)
response = self.csrf_client.get(self.OAUTH2_BASE_TESTING_URL + '?%s' % query)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
# This case is handled directly by DRF so no error_code is provided (yet).
@@ -223,14 +230,14 @@ class OAuth2Tests(TestCase):
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth(self):
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
response = self.post_with_bearer_token('/oauth2-test/')
response = self.post_with_bearer_token(self.OAUTH2_BASE_TESTING_URL)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_token_removed_failing_auth(self):
"""Ensure POSTing when there is no OAuth access token in db fails"""
self.access_token.delete()
response = self.post_with_bearer_token('/oauth2-test/')
response = self.post_with_bearer_token(self.OAUTH2_BASE_TESTING_URL)
self.check_error_codes(
response,
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -240,7 +247,7 @@ class OAuth2Tests(TestCase):
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_with_refresh_token_failing_auth(self):
"""Ensure POSTing with refresh token instead of access token fails"""
response = self.post_with_bearer_token('/oauth2-test/', token=self.refresh_token.token)
response = self.post_with_bearer_token(self.OAUTH2_BASE_TESTING_URL, token=self.refresh_token.token)
self.check_error_codes(
response,
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -252,7 +259,7 @@ class OAuth2Tests(TestCase):
"""Ensure POSTing with expired access token fails with a 'token_expired' error"""
self.access_token.expires = now() - timedelta(seconds=10) # 10 seconds late
self.access_token.save()
response = self.post_with_bearer_token('/oauth2-test/')
response = self.post_with_bearer_token(self.OAUTH2_BASE_TESTING_URL)
self.check_error_codes(
response,
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -274,7 +281,7 @@ class OAuth2Tests(TestCase):
@ddt.unpack
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_response_for_get_request_with_bad_auth_token(self, http_params, token_error):
response = self.get_with_bearer_token('/oauth2-test/', http_params, token=token_error.token)
response = self.get_with_bearer_token(self.OAUTH2_BASE_TESTING_URL, http_params, token=token_error.token)
self.check_error_codes(
response,
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -287,7 +294,7 @@ class OAuth2Tests(TestCase):
TokenErrorDDT('', authentication.OAUTH2_TOKEN_ERROR_NOT_PROVIDED),
)
def test_response_for_post_request_with_bad_auth_token(self, token_error):
response = self.post_with_bearer_token('/oauth2-test/', token=token_error.token)
response = self.post_with_bearer_token(self.OAUTH2_BASE_TESTING_URL, token=token_error.token)
self.check_error_codes(response, status_code=status.HTTP_401_UNAUTHORIZED, error_code=token_error.error_code)
ScopeStatusDDT = namedtuple('ScopeStatusDDT', ['scope', 'read_status', 'write_status'])
@@ -304,3 +311,14 @@ class OAuth2Tests(TestCase):
self.assertEqual(response.status_code, scope_statuses.read_status)
response = self.post_with_bearer_token('/oauth2-with-scope-test/', token=self.access_token.token)
self.assertEqual(response.status_code, scope_statuses.write_status)
class OAuth2AuthenticationTests(OAuth2AllowInActiveUsersTests): # pylint: disable=test-inherits-tests
OAUTH2_BASE_TESTING_URL = '/oauth2-test/'
def setUp(self):
super(OAuth2AuthenticationTests, self).setUp()
# Since this is testing back to previous version, user should be set to true
self.user.is_active = True
self.user.save()