Merge pull request #16279 from edx/LEARNER-717-2
Management Command to clear DOT expired data
This commit is contained in:
@@ -0,0 +1,71 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from time import sleep
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db import transaction
|
||||
from django.utils import timezone
|
||||
from oauth2_provider.models import AccessToken, Grant, RefreshToken
|
||||
from oauth2_provider.settings import oauth2_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Clear expired access tokens and refresh tokens for Django OAuth Toolkit"
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument('--batch_size',
|
||||
action='store',
|
||||
dest='batch_size',
|
||||
type=int,
|
||||
default=1000,
|
||||
help='Maximum number of database rows to delete per query. '
|
||||
'This helps avoid locking the database when deleting large amounts of data.')
|
||||
parser.add_argument('--sleep_time',
|
||||
action='store',
|
||||
dest='sleep_time',
|
||||
type=int,
|
||||
default=10,
|
||||
help='Sleep time between deletion of batches')
|
||||
|
||||
def clear_table_data(self, query_set, batch_size, model, sleep_time):
|
||||
message = 'Cleaning {} rows from {} table'.format(query_set.count(), model.__name__)
|
||||
logger.info(message)
|
||||
while query_set.exists():
|
||||
qs = query_set[:batch_size]
|
||||
batch_ids = qs.values_list('id', flat=True)
|
||||
with transaction.atomic():
|
||||
model.objects.filter(pk__in=list(batch_ids)).delete()
|
||||
|
||||
if query_set.exists():
|
||||
sleep(sleep_time)
|
||||
|
||||
def get_expiration_time(self, now):
|
||||
refresh_token_expire_seconds = oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS
|
||||
if not isinstance(refresh_token_expire_seconds, timedelta):
|
||||
try:
|
||||
refresh_token_expire_seconds = timedelta(seconds=refresh_token_expire_seconds)
|
||||
except TypeError:
|
||||
e = "REFRESH_TOKEN_EXPIRE_SECONDS must be either a timedelta or seconds"
|
||||
raise ImproperlyConfigured(e)
|
||||
return now - refresh_token_expire_seconds
|
||||
|
||||
def handle(self, *args, **options):
|
||||
batch_size = options['batch_size']
|
||||
sleep_time = options['sleep_time']
|
||||
|
||||
now = timezone.now()
|
||||
refresh_expire_at = self.get_expiration_time(now)
|
||||
|
||||
query_set = RefreshToken.objects.filter(access_token__expires__lt=refresh_expire_at)
|
||||
self.clear_table_data(query_set, batch_size, RefreshToken, sleep_time)
|
||||
|
||||
query_set = AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now)
|
||||
self.clear_table_data(query_set, batch_size, AccessToken, sleep_time)
|
||||
|
||||
query_set = Grant.objects.filter(expires__lt=now)
|
||||
self.clear_table_data(query_set, batch_size, Grant, sleep_time)
|
||||
@@ -0,0 +1,66 @@
|
||||
from datetime import timedelta
|
||||
import unittest
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.management import call_command
|
||||
from django.db.models import QuerySet
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
from mock import patch
|
||||
from oauth2_provider.models import AccessToken
|
||||
from testfixtures import LogCapture
|
||||
|
||||
from openedx.core.djangoapps.oauth_dispatch.tests import factories
|
||||
from student.tests.factories import UserFactory
|
||||
|
||||
LOGGER_NAME = 'openedx.core.djangoapps.oauth_dispatch.management.commands.edx_clear_expired_tokens'
|
||||
|
||||
|
||||
def counter(fn):
|
||||
"""
|
||||
Adds a call counter to the given function.
|
||||
Source: http://code.activestate.com/recipes/577534-counting-decorator/
|
||||
"""
|
||||
def _counted(*largs, **kargs):
|
||||
_counted.invocations += 1
|
||||
fn(*largs, **kargs)
|
||||
|
||||
_counted.invocations = 0
|
||||
return _counted
|
||||
|
||||
|
||||
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
|
||||
class EdxClearExpiredTokensTests(TestCase):
|
||||
|
||||
@patch('oauth2_provider.settings.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS', 'xyz')
|
||||
def test_invalid_expiration_time(self):
|
||||
with LogCapture(LOGGER_NAME) as log:
|
||||
with self.assertRaises(ImproperlyConfigured):
|
||||
call_command('edx_clear_expired_tokens')
|
||||
log.check(
|
||||
(
|
||||
LOGGER_NAME,
|
||||
'EXCEPTION',
|
||||
'REFRESH_TOKEN_EXPIRE_SECONDS must be either a timedelta or seconds'
|
||||
)
|
||||
)
|
||||
|
||||
@patch('oauth2_provider.settings.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS', 3600)
|
||||
def test_clear_expired_tokens(self):
|
||||
initial_count = 5
|
||||
now = timezone.now()
|
||||
expires = now - timedelta(days=1)
|
||||
users = UserFactory.create_batch(initial_count)
|
||||
for user in users:
|
||||
application = factories.ApplicationFactory(user=user)
|
||||
factories.AccessTokenFactory(user=user, application=application, expires=expires)
|
||||
self.assertEqual(
|
||||
AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count(),
|
||||
initial_count
|
||||
)
|
||||
QuerySet.delete = counter(QuerySet.delete)
|
||||
|
||||
call_command('edx_clear_expired_tokens', batch_size=1, sleep_time=0)
|
||||
self.assertEqual(QuerySet.delete.invocations, initial_count)
|
||||
self.assertEqual(AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count(), 0)
|
||||
Reference in New Issue
Block a user