diff --git a/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py b/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py index 0f22cb2530..bac0f7e34b 100644 --- a/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py +++ b/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py @@ -31,6 +31,12 @@ class Command(BaseCommand): type=int, default=10, help='Sleep time between deletion of batches') + parser.add_argument('--excluded-application-ids', + action='store', + dest='excluded-application-ids', + type=str, + default='', + help='Comma-separated list of application IDs for which tokens will NOT be removed') def clear_table_data(self, query_set, batch_size, model, sleep_time): message = 'Cleaning {} rows from {} table'.format(query_set.count(), model.__name__) @@ -57,11 +63,16 @@ class Command(BaseCommand): def handle(self, *args, **options): batch_size = options['batch_size'] sleep_time = options['sleep_time'] + if options['excluded-application-ids']: + excluded_application_ids = [int(x) for x in options['excluded-application-ids'].split(',')] + else: + excluded_application_ids = [] now = timezone.now() refresh_expire_at = self.get_expiration_time(now) - query_set = RefreshToken.objects.filter(access_token__expires__lt=refresh_expire_at) + query_set = RefreshToken.objects.filter(access_token__expires__lt=refresh_expire_at).exclude( + application_id__in=excluded_application_ids) self.clear_table_data(query_set, batch_size, RefreshToken, sleep_time) query_set = AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now) diff --git a/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py b/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py index a77c42607b..8ba8bc8776 100644 --- a/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py +++ b/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py @@ -1,14 +1,15 @@ -from datetime import timedelta import unittest +from datetime import timedelta from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.management import call_command from django.db.models import QuerySet from django.test import TestCase +from django.test.utils import override_settings from django.utils import timezone from mock import patch -from oauth2_provider.models import AccessToken +from oauth2_provider.models import AccessToken, RefreshToken from testfixtures import LogCapture from openedx.core.djangoapps.oauth_dispatch.tests import factories @@ -33,6 +34,7 @@ def counter(fn): @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') class EdxClearExpiredTokensTests(TestCase): + # patching REFRESH_TOKEN_EXPIRE_SECONDS because override_settings not working. @patch('oauth2_provider.settings.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS', 'xyz') def test_invalid_expiration_time(self): with LogCapture(LOGGER_NAME) as log: @@ -46,8 +48,37 @@ class EdxClearExpiredTokensTests(TestCase): ) ) - @patch('oauth2_provider.settings.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS', 3600) + @override_settings() + def test_excluded_application_ids(self): + settings.OAUTH2_PROVIDER['REFRESH_TOKEN_EXPIRE_SECONDS'] = 3600 + expires = timezone.now() - timedelta(days=1) + application = factories.ApplicationFactory() + access_token = factories.AccessTokenFactory(user=application.user, application=application, expires=expires) + factories.RefreshTokenFactory(user=application.user, application=application, access_token=access_token) + with LogCapture(LOGGER_NAME) as log: + call_command('edx_clear_expired_tokens', sleep_time=0, excluded_application_ids=str(application.id)) + log.check( + ( + LOGGER_NAME, + 'INFO', + 'Cleaning {} rows from {} table'.format(0, RefreshToken.__name__) + ), + ( + LOGGER_NAME, + 'INFO', + 'Cleaning {} rows from {} table'.format(0, AccessToken.__name__), + ), + ( + LOGGER_NAME, + 'INFO', + 'Cleaning 0 rows from Grant table', + ) + ) + self.assertTrue(RefreshToken.objects.filter(application=application).exists()) + + @override_settings() def test_clear_expired_tokens(self): + settings.OAUTH2_PROVIDER['REFRESH_TOKEN_EXPIRE_SECONDS'] = 3600 initial_count = 5 now = timezone.now() expires = now - timedelta(days=1)