diff --git a/common/djangoapps/entitlements/management/__init__.py b/common/djangoapps/entitlements/management/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/common/djangoapps/entitlements/management/commands/__init__.py b/common/djangoapps/entitlements/management/commands/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/common/djangoapps/entitlements/management/commands/expire_old_entitlements.py b/common/djangoapps/entitlements/management/commands/expire_old_entitlements.py new file mode 100644 index 0000000000..a5a6e35ac0 --- /dev/null +++ b/common/djangoapps/entitlements/management/commands/expire_old_entitlements.py @@ -0,0 +1,69 @@ +""" +Management command for expiring old entitlements. +""" + +import logging + +from django.core.management import BaseCommand +from django.core.paginator import Paginator + +from entitlements.models import CourseEntitlement +from entitlements.tasks.v1.tasks import expire_old_entitlements + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +class Command(BaseCommand): + """ + Management command for expiring old entitlements. + + Most entitlements get expired as the user interacts with the platform, + because the LMS checks as it goes. But if the learner has not logged in + for a while, we still want to reap these old entitlements. So this command + should be run every now and then (probably daily) to expire old + entitlements. + + The command's goal is to pass a narrow subset of entitlements to an + idempotent Celery task for further (parallelized) processing. + """ + help = 'Expire old entitlements.' + + def add_arguments(self, parser): + parser.add_argument( + '-c', '--commit', + action='store_true', + default=False, + help='Submit tasks for processing' + ) + + parser.add_argument( + '--batch-size', + type=int, + default=10000, # arbitrary, should be adjusted if it is found to be inadequate + help='How many entitlements to give each celery task' + ) + + def handle(self, *args, **options): + logger.info('Looking for entitlements which may be expirable.') + + # This query could be optimized to return a more narrow set, but at a + # complexity cost. See bug LEARNER-3451 about improving it. + entitlements = CourseEntitlement.objects.filter(expired_at__isnull=True).order_by('id') + + batch_size = max(1, options.get('batch_size')) + entitlements = Paginator(entitlements, batch_size, allow_empty_first_page=False) + + if options.get('commit'): + logger.info('Enqueuing entitlement expiration tasks for %d candidates.', entitlements.count) + else: + logger.info( + 'Found %d candidates. To enqueue entitlement expiration tasks, pass the -c or --commit flags.', + entitlements.count + ) + return + + for page_num in entitlements.page_range: + page = entitlements.page(page_num) + expire_old_entitlements.delay(page, logid=str(page_num)) + + logger.info('Done. Successfully enqueued tasks.') diff --git a/common/djangoapps/entitlements/management/commands/tests/__init__.py b/common/djangoapps/entitlements/management/commands/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/common/djangoapps/entitlements/management/commands/tests/test_expire_old_entitlements.py b/common/djangoapps/entitlements/management/commands/tests/test_expire_old_entitlements.py new file mode 100644 index 0000000000..46965d9c20 --- /dev/null +++ b/common/djangoapps/entitlements/management/commands/tests/test_expire_old_entitlements.py @@ -0,0 +1,85 @@ +"""Test Entitlements models""" + +from datetime import datetime, timedelta +import mock +import pytz + +from django.core.management import call_command +from django.test import TestCase + +from openedx.core.djangolib.testing.utils import skip_unless_lms +from entitlements.models import CourseEntitlementPolicy +from entitlements.tests.factories import CourseEntitlementFactory + + +def make_entitlement(expired=False): + age = CourseEntitlementPolicy.DEFAULT_EXPIRATION_PERIOD_DAYS + past_datetime = datetime.now(tz=pytz.UTC) - timedelta(days=age) + expired_at = past_datetime if expired else None + return CourseEntitlementFactory.create(created=past_datetime, expired_at=expired_at) + + +@skip_unless_lms +@mock.patch('entitlements.tasks.v1.tasks.expire_old_entitlements.delay') +class TestExpireOldEntitlementsCommand(TestCase): + """ + Test expire_old_entitlement management command. + """ + + def test_no_commit(self, mock_task): + """ + Verify that relevant tasks are only enqueued when the commit option is passed. + """ + make_entitlement() + + call_command('expire_old_entitlements') + self.assertEqual(mock_task.call_count, 0) + + call_command('expire_old_entitlements', commit=True) + self.assertEqual(mock_task.call_count, 1) + + def test_no_tasks_if_no_work(self, mock_task): + """ + Verify that we never try to spin off a task if there are no matching database rows. + """ + call_command('expire_old_entitlements', commit=True) + self.assertEqual(mock_task.call_count, 0) + + # Now confirm that the above test wasn't a fluke and we will create a task if there is work + make_entitlement() + call_command('expire_old_entitlements', commit=True) + self.assertEqual(mock_task.call_count, 1) + + def test_only_unexpired(self, mock_task): + """ + Verify that only unexpired entitlements are included + """ + # Create an old expired and an old unexpired entitlement + entitlement1 = make_entitlement(expired=True) + entitlement2 = make_entitlement() + + # Sanity check + self.assertIsNotNone(entitlement1.expired_at) + self.assertIsNone(entitlement2.expired_at) + + # Run expiration + call_command('expire_old_entitlements', commit=True) + + # Make sure only the unexpired one gets used + self.assertEqual(mock_task.call_count, 1) + self.assertEqual(list(mock_task.call_args[0][0].object_list), [entitlement2]) + + def test_pagination(self, mock_task): + """ + Verify that we chunk up our requests to celery. + """ + for _ in range(5): + make_entitlement() + + call_command('expire_old_entitlements', commit=True, batch_size=2) + + args_list = mock_task.call_args_list + self.assertEqual(len(args_list), 3) + self.assertEqual(len(args_list[0][0][0].object_list), 2) + self.assertEqual(len(args_list[1][0][0].object_list), 2) + self.assertEqual(len(args_list[2][0][0].object_list), 1) diff --git a/common/djangoapps/entitlements/tasks/__init__.py b/common/djangoapps/entitlements/tasks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/common/djangoapps/entitlements/tasks/v1/__init__.py b/common/djangoapps/entitlements/tasks/v1/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/common/djangoapps/entitlements/tasks/v1/tasks.py b/common/djangoapps/entitlements/tasks/v1/tasks.py new file mode 100644 index 0000000000..bfe849b235 --- /dev/null +++ b/common/djangoapps/entitlements/tasks/v1/tasks.py @@ -0,0 +1,56 @@ +""" +This file contains celery tasks for entitlements-related functionality. +""" + +from celery import task +from celery.utils.log import get_task_logger +from django.conf import settings + + +LOGGER = get_task_logger(__name__) +# Under cms the following setting is not defined, leading to errors during tests. +ROUTING_KEY = getattr(settings, 'ENTITLEMENTS_EXPIRATION_ROUTING_KEY', None) +# Maximum number of retries before giving up on awarding credentials. +# For reference, 11 retries with exponential backoff yields a maximum waiting +# time of 2047 seconds (about 30 minutes). Setting this to None could yield +# unwanted behavior: infinite retries. +MAX_RETRIES = 11 + + +@task(bind=True, ignore_result=True, routing_key=ROUTING_KEY) +def expire_old_entitlements(self, entitlements, logid='...'): + """ + This task is designed to be called to process a bundle of entitlements + that might be expired and confirm if we can do so. This is useful when + checking if an entitlement has just been abandoned by the learner and needs + to be expired. (In the normal course of a learner using the platform, the + entitlement will expire itself. But if a learner doesn't log in... So we + run this task every now and then to clear the backlog.) + + Args: + entitlements (list): An iterable set of CourseEntitlements to check + logid (str): A string to identify this task in the logs + + Returns: + None + + """ + LOGGER.info('Running task expire_old_entitlements [%s]', logid) + + countdown = 2 ** self.request.retries + + try: + for entitlement in entitlements: + # This property request will update the expiration if necessary as + # a side effect. We could manually call update_expired_at(), but + # let's use the same API the rest of the LMS does, to mimic normal + # usage and allow the update call to be an internal detail. + if entitlement.expired_at_datetime: + LOGGER.info('Expired entitlement with id %d [%s]', entitlement.id, logid) + + except Exception as exc: + LOGGER.exception('Failed to expire entitlements [%s]', logid) + # The call above is idempotent, so retry at will + raise self.retry(exc=exc, countdown=countdown, max_retries=MAX_RETRIES) + + LOGGER.info('Successfully completed the task expire_old_entitlements [%s]', logid) diff --git a/common/djangoapps/entitlements/tasks/v1/tests/__init__.py b/common/djangoapps/entitlements/tasks/v1/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/common/djangoapps/entitlements/tasks/v1/tests/test_tasks.py b/common/djangoapps/entitlements/tasks/v1/tests/test_tasks.py new file mode 100644 index 0000000000..8caababb8b --- /dev/null +++ b/common/djangoapps/entitlements/tasks/v1/tests/test_tasks.py @@ -0,0 +1,74 @@ +""" +Test entitlements tasks +""" + +from datetime import datetime, timedelta +import mock +import pytz + +from django.test import TestCase + +from entitlements.models import CourseEntitlementPolicy +from entitlements.tasks.v1 import tasks +from entitlements.tests.factories import CourseEntitlementFactory +from openedx.core.djangolib.testing.utils import skip_unless_lms + + +def make_entitlement(**kwargs): + m = mock.NonCallableMock() + p = mock.PropertyMock(**kwargs) + type(m).expired_at_datetime = p + return m, p + + +def boom(): + raise Exception('boom') + + +@skip_unless_lms +class TestExpireOldEntitlementsTask(TestCase): + """ + Tests for the 'expire_old_entitlements' method. + """ + + def test_checks_expiration(self): + """ + Test that we actually do check expiration on each entitlement (happy path) + """ + entitlement1, prop1 = make_entitlement(return_value=None) + entitlement2, prop2 = make_entitlement(return_value='some date') + tasks.expire_old_entitlements.delay([entitlement1, entitlement2]).get() + + # Test that the expired_at_datetime property was accessed + self.assertEqual(prop1.call_count, 1) + self.assertEqual(prop2.call_count, 1) + + def test_retry(self): + """ + Test that we retry when an exception occurs while checking old + entitlements. + """ + entitlement, prop = make_entitlement(side_effect=boom) + task = tasks.expire_old_entitlements.delay([entitlement]) + + self.assertRaises(Exception, task.get) + self.assertEqual(prop.call_count, tasks.MAX_RETRIES + 1) + + def test_actually_expired(self): + """ + Integration test with CourseEntitlement to make sure we are calling the + correct API. + """ + # Create an actual old entitlement + past_days = CourseEntitlementPolicy.DEFAULT_EXPIRATION_PERIOD_DAYS + past_datetime = datetime.now(tz=pytz.UTC) - timedelta(days=past_days) + entitlement = CourseEntitlementFactory.create(created=past_datetime) + + # Sanity check + self.assertIsNone(entitlement.expired_at) + + # Run enforcement + tasks.expire_old_entitlements.delay([entitlement]).get() + entitlement.refresh_from_db() + + self.assertIsNotNone(entitlement.expired_at) diff --git a/lms/envs/aws.py b/lms/envs/aws.py index 878cc7f894..979be20693 100644 --- a/lms/envs/aws.py +++ b/lms/envs/aws.py @@ -277,6 +277,9 @@ RECALCULATE_GRADES_ROUTING_KEY = ENV_TOKENS.get('RECALCULATE_GRADES_ROUTING_KEY' # Queue to use for updating grades due to grading policy change POLICY_CHANGE_GRADES_ROUTING_KEY = ENV_TOKENS.get('POLICY_CHANGE_GRADES_ROUTING_KEY', LOW_PRIORITY_QUEUE) +# Queue to use for expiring old entitlements +ENTITLEMENTS_EXPIRATION_ROUTING_KEY = ENV_TOKENS.get('ENTITLEMENTS_EXPIRATION_ROUTING_KEY', LOW_PRIORITY_QUEUE) + # Message expiry time in seconds CELERY_EVENT_QUEUE_TTL = ENV_TOKENS.get('CELERY_EVENT_QUEUE_TTL', None)