diff --git a/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py b/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py index 2f9aaf8b52..7fea5e92c5 100644 --- a/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py +++ b/openedx/core/djangoapps/oauth_dispatch/management/commands/edx_clear_expired_tokens.py @@ -48,6 +48,7 @@ class Command(BaseCommand): # lint-amnesty, pylint: disable=missing-class-docst self._add_boolean_flag(parser, 'refresh-tokens', True) self._add_boolean_flag(parser, 'access-tokens', True) self._add_boolean_flag(parser, 'grants', True) + self._add_boolean_flag(parser, 'revoked-tokens', True) def clear_table_data(self, query_set, batch_size, model, sleep_time): # lint-amnesty, pylint: disable=missing-function-docstring total_deletions = 0 @@ -89,6 +90,12 @@ class Command(BaseCommand): # lint-amnesty, pylint: disable=missing-class-docst now = timezone.now() refresh_expire_at = self.get_expiration_time(now) + if options['revoked-tokens']: + # remove revoked, as opposed to expired, RefreshTokens + revoked = RefreshToken.objects.filter(revoked__lt=refresh_expire_at).exclude( + application_id__in=excluded_application_ids) + self.clear_table_data(revoked, batch_size, RefreshToken, sleep_time) + if options['refresh-tokens']: query_set = RefreshToken.objects.filter(access_token__expires__lt=refresh_expire_at).exclude( application_id__in=excluded_application_ids) diff --git a/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py b/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py index d8436eb6e0..c9486c5cc8 100644 --- a/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py +++ b/openedx/core/djangoapps/oauth_dispatch/management/commands/tests/test_clear_expired_tokens.py @@ -38,6 +38,13 @@ def counter(fn): return _counted +def create_factory_refresh_token_for_user(user, expires, revoked=None): + application = factories.ApplicationFactory(user=user) + access_token = factories.AccessTokenFactory(user=user, application=application, expires=expires) + return factories.RefreshTokenFactory(access_token=access_token, application=application, user=user, + revoked=revoked) + + @ddt.ddt @skip_unless_lms class EdxClearExpiredTokensTests(TestCase): # lint-amnesty, pylint: disable=missing-class-docstring @@ -71,6 +78,11 @@ class EdxClearExpiredTokensTests(TestCase): # lint-amnesty, pylint: disable=mis 'INFO', f'Cleaned {0} rows from {RefreshToken.__name__} table' ), + ( + LOGGER_NAME, + 'INFO', + f'Cleaned {0} rows from {RefreshToken.__name__} table' + ), ( LOGGER_NAME, 'INFO', @@ -99,12 +111,40 @@ class EdxClearExpiredTokensTests(TestCase): # lint-amnesty, pylint: disable=mis QuerySet.delete = counter(QuerySet.delete) try: call_command('edx_clear_expired_tokens', batch_size=1, sleep_time=0) - # three being the number of tables we'll end up unnecessarily calling .delete on once - assert QuerySet.delete.invocations == initial_count + 3 # pylint: disable=no-member + # four being the number of tables we'll end up unnecessarily calling .delete on once + assert QuerySet.delete.invocations == initial_count + 4 # pylint: disable=no-member assert AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count() == 0 finally: QuerySet.delete = original_delete + @override_settings() + def test_clear_revoked_refresh_tokens(self): + settings.OAUTH2_PROVIDER['REFRESH_TOKEN_EXPIRE_SECONDS'] = 3600 + now = timezone.now() + # expiry date in the future because we only want to check revoked tokens, not expired ones + expires = now + timedelta(days=1) + refresh_expires = now - timedelta(seconds=3600) + user_keep = UserFactory() + user_revoke = UserFactory() + keep_token = create_factory_refresh_token_for_user(user_keep, expires=expires) + revoke_token = create_factory_refresh_token_for_user(user_revoke, expires=expires, + revoked=refresh_expires - timedelta(seconds=1)) + original_delete = QuerySet.delete + QuerySet.delete = counter(QuerySet.delete) + try: + call_command('edx_clear_expired_tokens', sleep_time=0, access_tokens=False, refresh_tokens=False, + grants=False) + # 1 overhead call, 1 real call + assert QuerySet.delete.invocations == 2 + assert RefreshToken.objects.filter(revoked__lt=refresh_expires).count() == 0 + # revoked token has been deleted + with self.assertRaises(RefreshToken.DoesNotExist): + RefreshToken.objects.get(token=revoke_token.token) + # normal token is still there + assert RefreshToken.objects.get(token=keep_token.token) == keep_token + finally: + QuerySet.delete = original_delete + @override_settings() @ddt.unpack @ddt.data( @@ -128,7 +168,7 @@ class EdxClearExpiredTokensTests(TestCase): # lint-amnesty, pylint: disable=mis QuerySet.delete = counter(QuerySet.delete) try: call_command('edx_clear_expired_tokens', batch_size=batch_size, sleep_time=0) - assert QuerySet.delete.invocations == (math.ceil(initial_count / batch_size) * 2 + 3) + assert QuerySet.delete.invocations == (math.ceil(initial_count / batch_size) * 2 + 4) assert RefreshToken.objects.filter(access_token__expires__lt=refresh_expires).count() == 0 finally: QuerySet.delete = original_delete