diff --git a/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py b/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py index 9b820ea1d4..518b208f4d 100644 --- a/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py +++ b/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py @@ -42,16 +42,23 @@ class Command(BaseCommand): # lint-amnesty, pylint: disable=missing-class-docst help='Comma-separated list of application IDs for which tokens will NOT be removed') def clear_table_data(self, query_set, batch_size, model, sleep_time): # lint-amnesty, pylint: disable=missing-function-docstring - message = f'Cleaning {query_set.count()} rows from {model.__name__} table' - logger.info(message) - while query_set.exists(): + total_deletions = 0 + deletion_count = 1 # we expect to delete at least one row to start out with + while deletion_count: + deletion_count = 0 qs = query_set[:batch_size] + if isinstance(qs, list) and not qs: + # if we are seeing an empty list here then that there's nothing more to delete + break batch_ids = qs.values_list('id', flat=True) + batch_id_list = list(batch_ids) with transaction.atomic(): - model.objects.filter(pk__in=list(batch_ids)).delete() - - if query_set.exists(): - sleep(sleep_time) + deletions = model.objects.filter(pk__in=batch_id_list).delete() + deletion_count = deletions[0] + total_deletions += deletion_count + sleep(sleep_time) + message = f'Cleaned {total_deletions} rows from {model.__name__} table' + logger.info(message) def get_expiration_time(self, now): # lint-amnesty, pylint: disable=missing-function-docstring refresh_token_expire_seconds = oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS diff --git a/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py b/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py index b94470c641..d8436eb6e0 100644 --- a/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py +++ b/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py @@ -5,7 +5,9 @@ Tests the ``edx_clear_expired_tokens`` management command. from datetime import timedelta from unittest.mock import patch +import math import pytest +import ddt from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.management import call_command @@ -13,7 +15,7 @@ from django.db.models import QuerySet from django.test import TestCase from django.test.utils import override_settings from django.utils import timezone -from oauth2_provider.models import AccessToken, RefreshToken +from oauth2_provider.models import AccessToken, RefreshToken, Grant from testfixtures import LogCapture from openedx.core.djangoapps.oauth_dispatch.tests import factories @@ -36,6 +38,7 @@ def counter(fn): return _counted +@ddt.ddt @skip_unless_lms class EdxClearExpiredTokensTests(TestCase): # lint-amnesty, pylint: disable=missing-class-docstring @@ -66,17 +69,17 @@ class EdxClearExpiredTokensTests(TestCase): # lint-amnesty, pylint: disable=mis ( LOGGER_NAME, 'INFO', - f'Cleaning {0} rows from {RefreshToken.__name__} table' + f'Cleaned {0} rows from {RefreshToken.__name__} table' ), ( LOGGER_NAME, 'INFO', - f'Cleaning {0} rows from {AccessToken.__name__} table', + f'Cleaned {0} rows from {AccessToken.__name__} table', ), ( LOGGER_NAME, 'INFO', - 'Cleaning 0 rows from Grant table', + f'Cleaned 0 rows from {Grant.__name__} table', ) ) assert RefreshToken.objects.filter(application=application).exists() @@ -96,7 +99,36 @@ class EdxClearExpiredTokensTests(TestCase): # lint-amnesty, pylint: disable=mis QuerySet.delete = counter(QuerySet.delete) try: call_command('edx_clear_expired_tokens', batch_size=1, sleep_time=0) - assert not QuerySet.delete.invocations != initial_count # pylint: disable=no-member + # three being the number of tables we'll end up unnecessarily calling .delete on once + assert QuerySet.delete.invocations == initial_count + 3 # pylint: disable=no-member assert AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count() == 0 finally: QuerySet.delete = original_delete + + @override_settings() + @ddt.unpack + @ddt.data( + (5, 1), + (500, 1), + (7, 5), + (500, 50), + ) + def test_clear_expired_refreshtokens(self, initial_count, batch_size): + settings.OAUTH2_PROVIDER['REFRESH_TOKEN_EXPIRE_SECONDS'] = 3600 + now = timezone.now() + expires = now - timedelta(days=1) + refresh_expires = now - timedelta(seconds=3600) + users = UserFactory.create_batch(initial_count) + for user in users: + application = factories.ApplicationFactory(user=user) + access_token = factories.AccessTokenFactory(user=user, application=application, expires=expires) + factories.RefreshTokenFactory(access_token=access_token, application=application, user=user) + assert RefreshToken.objects.filter(access_token__expires__lt=refresh_expires).count() == initial_count + original_delete = QuerySet.delete + QuerySet.delete = counter(QuerySet.delete) + try: + call_command('edx_clear_expired_tokens', batch_size=batch_size, sleep_time=0) + assert QuerySet.delete.invocations == (math.ceil(initial_count / batch_size) * 2 + 3) + assert RefreshToken.objects.filter(access_token__expires__lt=refresh_expires).count() == 0 + finally: + QuerySet.delete = original_delete