Merge pull request #19074 from edx/arch/fix-expiration
Fix overriding of token expiration in DOT (ARCH-246)
This commit is contained in:
@@ -7,7 +7,6 @@ from oauthlib.oauth2.rfc6749.errors import OAuth2Error
|
||||
from oauthlib.oauth2.rfc6749.tokens import BearerToken
|
||||
from oauth2_provider.models import AccessToken as dot_access_token
|
||||
from oauth2_provider.models import RefreshToken as dot_refresh_token
|
||||
from oauth2_provider.oauth2_backends import get_oauthlib_core
|
||||
from oauth2_provider.settings import oauth2_settings as dot_settings
|
||||
from provider.oauth2.models import AccessToken as dop_access_token
|
||||
from provider.oauth2.models import RefreshToken as dop_refresh_token
|
||||
@@ -51,8 +50,8 @@ def refresh_dot_access_token(request, client_id, refresh_token, expires_in=None)
|
||||
Create and return a new (persisted) access token, given a previously created
|
||||
refresh_token, possibly returned from create_dot_access_token above.
|
||||
"""
|
||||
auth_core = get_oauthlib_core()
|
||||
expires_in = _get_expires_in_value(expires_in)
|
||||
auth_core = _get_oauthlib_core(expires_in)
|
||||
_populate_refresh_token_request(request, client_id, refresh_token)
|
||||
|
||||
# Note: Unlike create_dot_access_token, we use the top-level auth library
|
||||
@@ -70,13 +69,7 @@ def _get_expires_in_value(expires_in):
|
||||
"""
|
||||
Returns the expires_in value to use for the token.
|
||||
"""
|
||||
# TODO (ARCH-246) Fix expiration configuration as this does not actually
|
||||
# override the token's expiration. Rather, DOT's save_bearer_token method
|
||||
# will always use dot_settings.ACCESS_TOKEN_EXPIRE_SECONDS.
|
||||
if not expires_in:
|
||||
seconds_in_a_day = 24 * 60 * 60
|
||||
expires_in = settings.OAUTH_EXPIRE_PUBLIC_CLIENT_DAYS * seconds_in_a_day
|
||||
return expires_in
|
||||
return expires_in or dot_settings.ACCESS_TOKEN_EXPIRE_SECONDS
|
||||
|
||||
|
||||
def _populate_create_access_token_request(request, user, client, scope=None):
|
||||
@@ -105,3 +98,13 @@ def _populate_refresh_token_request(request, client_id, refresh_token):
|
||||
refresh_token=refresh_token,
|
||||
grant_type='refresh_token',
|
||||
)
|
||||
|
||||
|
||||
def _get_oauthlib_core(expires_in):
|
||||
"""
|
||||
Based on oauth2_provider.oauth2_backends.get_oauthlib_core, but allows
|
||||
passing in a value for token_expires_in.
|
||||
"""
|
||||
validator = dot_settings.OAUTH2_VALIDATOR_CLASS()
|
||||
server = dot_settings.OAUTH2_SERVER_CLASS(validator, token_expires_in=expires_in)
|
||||
return dot_settings.OAUTH2_BACKEND_CLASS(server)
|
||||
|
||||
@@ -3,7 +3,7 @@ Classes that override default django-oauth-toolkit behavior
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from django.contrib.auth import authenticate, get_user_model
|
||||
from django.db.models.signals import pre_save
|
||||
@@ -82,21 +82,9 @@ class EdxOAuth2Validator(OAuth2Validator):
|
||||
|
||||
super(EdxOAuth2Validator, self).save_bearer_token(token, request, *args, **kwargs)
|
||||
|
||||
if RestrictedApplication.should_expire_access_token(request.client):
|
||||
# Since RestrictedApplications will override the DOT defined expiry, so that access_tokens
|
||||
# are always expired, we need to re-read the token from the database and then calculate the
|
||||
# expires_in (in seconds) from what we stored in the database. This value should be a negative
|
||||
#value, meaning that it is already expired
|
||||
|
||||
access_token = AccessToken.objects.get(token=token['access_token'])
|
||||
utc_now = datetime.utcnow().replace(tzinfo=utc)
|
||||
expires_in = (access_token.expires - utc_now).total_seconds()
|
||||
|
||||
# assert that RestrictedApplications only issue expired tokens
|
||||
# blow up processing if we see otherwise
|
||||
assert expires_in < 0
|
||||
|
||||
token['expires_in'] = expires_in
|
||||
is_restricted_client = self._update_token_expiry_if_restricted_client(token, request.client)
|
||||
if not is_restricted_client:
|
||||
self._update_token_expiry_if_overridden_in_request(token, request)
|
||||
|
||||
# Restore the original request attributes
|
||||
request.grant_type = grant_type
|
||||
@@ -108,3 +96,43 @@ class EdxOAuth2Validator(OAuth2Validator):
|
||||
"""
|
||||
available_scopes = get_scopes_backend().get_available_scopes(application=client, request=request)
|
||||
return set(scopes).issubset(set(available_scopes))
|
||||
|
||||
def _update_token_expiry_if_restricted_client(self, token, client):
|
||||
"""
|
||||
Update the token's expires_in value if the given client is a
|
||||
RestrictedApplication and return whether the given client is restricted.
|
||||
"""
|
||||
# Since RestrictedApplications override the DOT defined expiry such that
|
||||
# access_tokens are always expired, re-read the token from the database
|
||||
# and calculate expires_in (in seconds) from the database value. This
|
||||
# value should be a negative value, meaning that it is already expired.
|
||||
if RestrictedApplication.should_expire_access_token(client):
|
||||
access_token = AccessToken.objects.get(token=token['access_token'])
|
||||
expires_in = (access_token.expires - _get_utc_now()).total_seconds()
|
||||
assert expires_in < 0
|
||||
token['expires_in'] = expires_in
|
||||
return True
|
||||
|
||||
def _update_token_expiry_if_overridden_in_request(self, token, request):
|
||||
"""
|
||||
Update the token's expires_in value if the request specifies an
|
||||
expiration value and update the expires value on the stored AccessToken
|
||||
object.
|
||||
|
||||
This is needed since DOT's save_bearer_token method always uses
|
||||
the dot_settings.ACCESS_TOKEN_EXPIRE_SECONDS value instead of applying
|
||||
the requesting expiration value.
|
||||
"""
|
||||
expires_in = getattr(request, 'expires_in', None)
|
||||
if expires_in:
|
||||
access_token = AccessToken.objects.get(token=token['access_token'])
|
||||
access_token.expires = _get_utc_now() + timedelta(seconds=expires_in)
|
||||
access_token.save()
|
||||
token['expires_in'] = expires_in
|
||||
|
||||
|
||||
def _get_utc_now():
|
||||
"""
|
||||
Return current time in UTC.
|
||||
"""
|
||||
return datetime.utcnow().replace(tzinfo=utc)
|
||||
|
||||
@@ -30,7 +30,6 @@ class TestOAuthDispatchAPI(TestCase):
|
||||
redirect_uri=DUMMY_REDIRECT_URL,
|
||||
client_id='public-client-id',
|
||||
)
|
||||
self.request = HttpRequest()
|
||||
|
||||
def _assert_stored_token(self, stored_token_value, expected_token_user, expected_client):
|
||||
stored_access_token = AccessToken.objects.get(token=stored_token_value)
|
||||
@@ -39,7 +38,7 @@ class TestOAuthDispatchAPI(TestCase):
|
||||
self.assertEqual(stored_access_token.application.user.id, expected_client.user.id)
|
||||
|
||||
def test_create_token_success(self):
|
||||
token = api.create_dot_access_token(self.request, self.user, self.client)
|
||||
token = api.create_dot_access_token(HttpRequest(), self.user, self.client)
|
||||
self.assertTrue(token['access_token'])
|
||||
self.assertTrue(token['refresh_token'])
|
||||
self.assertDictContainsSubset(
|
||||
@@ -54,20 +53,18 @@ class TestOAuthDispatchAPI(TestCase):
|
||||
|
||||
def test_create_token_another_user(self):
|
||||
another_user = UserFactory()
|
||||
token = api.create_dot_access_token(self.request, another_user, self.client)
|
||||
token = api.create_dot_access_token(HttpRequest(), another_user, self.client)
|
||||
self._assert_stored_token(token['access_token'], another_user, self.client)
|
||||
|
||||
def test_create_token_overrides(self):
|
||||
expires_in = 4800
|
||||
token = api.create_dot_access_token(self.request, self.user, self.client, expires_in=expires_in, scope=2)
|
||||
token = api.create_dot_access_token(HttpRequest(), self.user, self.client, expires_in=expires_in, scope=2)
|
||||
self.assertDictContainsSubset({u'scope': u'profile'}, token)
|
||||
with self.assertRaises(AssertionError): # TODO (ARCH-246) expiration override does not actually work
|
||||
self.assertDictContainsSubset({u'expires_in': expires_in}, token)
|
||||
self.assertDictContainsSubset({u'expires_in': EXPECTED_DEFAULT_EXPIRES_IN}, token)
|
||||
self.assertDictContainsSubset({u'expires_in': expires_in}, token)
|
||||
|
||||
def test_refresh_token_success(self):
|
||||
old_token = api.create_dot_access_token(self.request, self.user, self.client)
|
||||
new_token = api.refresh_dot_access_token(self.request, self.client.client_id, old_token['refresh_token'])
|
||||
old_token = api.create_dot_access_token(HttpRequest(), self.user, self.client)
|
||||
new_token = api.refresh_dot_access_token(HttpRequest(), self.client.client_id, old_token['refresh_token'])
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
u'token_type': u'Bearer',
|
||||
@@ -87,17 +84,17 @@ class TestOAuthDispatchAPI(TestCase):
|
||||
self._assert_stored_token(new_token['access_token'], self.user, self.client)
|
||||
|
||||
def test_refresh_token_invalid_client(self):
|
||||
token = api.create_dot_access_token(self.request, self.user, self.client)
|
||||
token = api.create_dot_access_token(HttpRequest(), self.user, self.client)
|
||||
with self.assertRaises(api.OAuth2Error) as error:
|
||||
api.refresh_dot_access_token(
|
||||
self.request, 'invalid_client_id', token['refresh_token'],
|
||||
HttpRequest(), 'invalid_client_id', token['refresh_token'],
|
||||
)
|
||||
self.assertIn('invalid_client', error.exception.description)
|
||||
|
||||
def test_refresh_token_invalid_token(self):
|
||||
api.create_dot_access_token(self.request, self.user, self.client)
|
||||
api.create_dot_access_token(HttpRequest(), self.user, self.client)
|
||||
with self.assertRaises(api.OAuth2Error) as error:
|
||||
api.refresh_dot_access_token(
|
||||
self.request, self.client.client_id, 'invalid_refresh_token',
|
||||
HttpRequest(), self.client.client_id, 'invalid_refresh_token',
|
||||
)
|
||||
self.assertIn('invalid_grant', error.exception.description)
|
||||
|
||||
Reference in New Issue
Block a user