From 790150a8aae1c0a5a6d2333080824695f149df22 Mon Sep 17 00:00:00 2001 From: ayub-khan Date: Thu, 19 Oct 2017 17:13:59 +0500 Subject: [PATCH] Management Command to Clear DOT expired Tokens LEARNER-717 --- .../oauth_dispatch/management/__init__.py | 0 .../management/commands/__init__.py | 0 .../commands/edx_clear_expired_tokens.py | 71 +++++++++++++++++++ .../management/commands/tests/__init__.py | 0 .../tests/test_clear_expired_tokens.py | 66 +++++++++++++++++ 5 files changed, 137 insertions(+) create mode 100644 openedx/core/djangoapps/oauth_dispatch/management/__init__.py create mode 100644 openedx/core/djangoapps/oauth_dispatch/management/commands/__init__.py create mode 100644 openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py create mode 100644 openedx/core/djangoapps/oauth_dispatch/management/commands/tests/__init__.py create mode 100644 openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py diff --git a/openedx/core/djangoapps/oauth_dispatch/management/__init__.py b/openedx/core/djangoapps/oauth_dispatch/management/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openedx/core/djangoapps/oauth_dispatch/management/commands/__init__.py b/openedx/core/djangoapps/oauth_dispatch/management/commands/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py b/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py new file mode 100644 index 0000000000..0f22cb2530 --- /dev/null +++ b/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py @@ -0,0 +1,71 @@ +from __future__ import unicode_literals + +import logging +from datetime import timedelta +from time import sleep + +from django.core.exceptions import ImproperlyConfigured +from django.core.management.base import BaseCommand +from django.db import transaction +from django.utils import timezone +from oauth2_provider.models import AccessToken, Grant, RefreshToken +from oauth2_provider.settings import oauth2_settings + +logger = logging.getLogger(__name__) + + +class Command(BaseCommand): + help = "Clear expired access tokens and refresh tokens for Django OAuth Toolkit" + + def add_arguments(self, parser): + parser.add_argument('--batch_size', + action='store', + dest='batch_size', + type=int, + default=1000, + help='Maximum number of database rows to delete per query. ' + 'This helps avoid locking the database when deleting large amounts of data.') + parser.add_argument('--sleep_time', + action='store', + dest='sleep_time', + type=int, + default=10, + help='Sleep time between deletion of batches') + + def clear_table_data(self, query_set, batch_size, model, sleep_time): + message = 'Cleaning {} rows from {} table'.format(query_set.count(), model.__name__) + logger.info(message) + while query_set.exists(): + qs = query_set[:batch_size] + batch_ids = qs.values_list('id', flat=True) + with transaction.atomic(): + model.objects.filter(pk__in=list(batch_ids)).delete() + + if query_set.exists(): + sleep(sleep_time) + + def get_expiration_time(self, now): + refresh_token_expire_seconds = oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS + if not isinstance(refresh_token_expire_seconds, timedelta): + try: + refresh_token_expire_seconds = timedelta(seconds=refresh_token_expire_seconds) + except TypeError: + e = "REFRESH_TOKEN_EXPIRE_SECONDS must be either a timedelta or seconds" + raise ImproperlyConfigured(e) + return now - refresh_token_expire_seconds + + def handle(self, *args, **options): + batch_size = options['batch_size'] + sleep_time = options['sleep_time'] + + now = timezone.now() + refresh_expire_at = self.get_expiration_time(now) + + query_set = RefreshToken.objects.filter(access_token__expires__lt=refresh_expire_at) + self.clear_table_data(query_set, batch_size, RefreshToken, sleep_time) + + query_set = AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now) + self.clear_table_data(query_set, batch_size, AccessToken, sleep_time) + + query_set = Grant.objects.filter(expires__lt=now) + self.clear_table_data(query_set, batch_size, Grant, sleep_time) diff --git a/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/__init__.py b/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py b/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py new file mode 100644 index 0000000000..a77c42607b --- /dev/null +++ b/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py @@ -0,0 +1,66 @@ +from datetime import timedelta +import unittest + +from django.conf import settings +from django.core.exceptions import ImproperlyConfigured +from django.core.management import call_command +from django.db.models import QuerySet +from django.test import TestCase +from django.utils import timezone +from mock import patch +from oauth2_provider.models import AccessToken +from testfixtures import LogCapture + +from openedx.core.djangoapps.oauth_dispatch.tests import factories +from student.tests.factories import UserFactory + +LOGGER_NAME = 'openedx.core.djangoapps.oauth_dispatch.management.commands.edx_clear_expired_tokens' + + +def counter(fn): + """ + Adds a call counter to the given function. + Source: http://code.activestate.com/recipes/577534-counting-decorator/ + """ + def _counted(*largs, **kargs): + _counted.invocations += 1 + fn(*largs, **kargs) + + _counted.invocations = 0 + return _counted + + +@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') +class EdxClearExpiredTokensTests(TestCase): + + @patch('oauth2_provider.settings.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS', 'xyz') + def test_invalid_expiration_time(self): + with LogCapture(LOGGER_NAME) as log: + with self.assertRaises(ImproperlyConfigured): + call_command('edx_clear_expired_tokens') + log.check( + ( + LOGGER_NAME, + 'EXCEPTION', + 'REFRESH_TOKEN_EXPIRE_SECONDS must be either a timedelta or seconds' + ) + ) + + @patch('oauth2_provider.settings.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS', 3600) + def test_clear_expired_tokens(self): + initial_count = 5 + now = timezone.now() + expires = now - timedelta(days=1) + users = UserFactory.create_batch(initial_count) + for user in users: + application = factories.ApplicationFactory(user=user) + factories.AccessTokenFactory(user=user, application=application, expires=expires) + self.assertEqual( + AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count(), + initial_count + ) + QuerySet.delete = counter(QuerySet.delete) + + call_command('edx_clear_expired_tokens', batch_size=1, sleep_time=0) + self.assertEqual(QuerySet.delete.invocations, initial_count) + self.assertEqual(AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count(), 0)