From 64b7fb88b889dc4b20000a3f1f730b212c1653bd Mon Sep 17 00:00:00 2001 From: Michael Terry Date: Wed, 3 Jan 2018 16:40:50 -0500 Subject: [PATCH] Pass serializable data to celery The expire_old_entitlements management command was passing a Paginator object to celery, which it can't serialize when it is running async. Instead, we'll do the pagination manually inside the task code and simply pass it integers. --- .../commands/expire_old_entitlements.py | 23 +++---- .../tests/test_expire_old_entitlements.py | 43 +++---------- .../djangoapps/entitlements/tasks/v1/tasks.py | 15 +++-- .../entitlements/tasks/v1/tests/test_tasks.py | 63 ++++++++++++------- 4 files changed, 69 insertions(+), 75 deletions(-) diff --git a/common/djangoapps/entitlements/management/commands/expire_old_entitlements.py b/common/djangoapps/entitlements/management/commands/expire_old_entitlements.py index a5a6e35ac0..0902814f1b 100644 --- a/common/djangoapps/entitlements/management/commands/expire_old_entitlements.py +++ b/common/djangoapps/entitlements/management/commands/expire_old_entitlements.py @@ -5,7 +5,6 @@ Management command for expiring old entitlements. import logging from django.core.management import BaseCommand -from django.core.paginator import Paginator from entitlements.models import CourseEntitlement from entitlements.tasks.v1.tasks import expire_old_entitlements @@ -46,24 +45,22 @@ class Command(BaseCommand): def handle(self, *args, **options): logger.info('Looking for entitlements which may be expirable.') - # This query could be optimized to return a more narrow set, but at a - # complexity cost. See bug LEARNER-3451 about improving it. - entitlements = CourseEntitlement.objects.filter(expired_at__isnull=True).order_by('id') - + total = CourseEntitlement.objects.count() batch_size = max(1, options.get('batch_size')) - entitlements = Paginator(entitlements, batch_size, allow_empty_first_page=False) + num_batches = ((total - 1) / batch_size + 1) if total > 0 else 0 if options.get('commit'): - logger.info('Enqueuing entitlement expiration tasks for %d candidates.', entitlements.count) + logger.info('Enqueuing %d entitlement expiration tasks.', num_batches) else: logger.info( - 'Found %d candidates. To enqueue entitlement expiration tasks, pass the -c or --commit flags.', - entitlements.count + 'Found %d batches. To enqueue entitlement expiration tasks, pass the -c or --commit flags.', + num_batches ) return - for page_num in entitlements.page_range: - page = entitlements.page(page_num) - expire_old_entitlements.delay(page, logid=str(page_num)) + for batch_num in range(num_batches): + start = batch_num * batch_size + 1 # ids are 1-based, so add 1 + end = min(start + batch_size, total + 1) + expire_old_entitlements.delay(start, end, logid=str(batch_num)) - logger.info('Done. Successfully enqueued tasks.') + logger.info('Done. Successfully enqueued %d tasks.', num_batches) diff --git a/common/djangoapps/entitlements/management/commands/tests/test_expire_old_entitlements.py b/common/djangoapps/entitlements/management/commands/tests/test_expire_old_entitlements.py index 46965d9c20..2026ce04ed 100644 --- a/common/djangoapps/entitlements/management/commands/tests/test_expire_old_entitlements.py +++ b/common/djangoapps/entitlements/management/commands/tests/test_expire_old_entitlements.py @@ -1,24 +1,14 @@ """Test Entitlements models""" -from datetime import datetime, timedelta import mock -import pytz from django.core.management import call_command from django.test import TestCase from openedx.core.djangolib.testing.utils import skip_unless_lms -from entitlements.models import CourseEntitlementPolicy from entitlements.tests.factories import CourseEntitlementFactory -def make_entitlement(expired=False): - age = CourseEntitlementPolicy.DEFAULT_EXPIRATION_PERIOD_DAYS - past_datetime = datetime.now(tz=pytz.UTC) - timedelta(days=age) - expired_at = past_datetime if expired else None - return CourseEntitlementFactory.create(created=past_datetime, expired_at=expired_at) - - @skip_unless_lms @mock.patch('entitlements.tasks.v1.tasks.expire_old_entitlements.delay') class TestExpireOldEntitlementsCommand(TestCase): @@ -30,7 +20,7 @@ class TestExpireOldEntitlementsCommand(TestCase): """ Verify that relevant tasks are only enqueued when the commit option is passed. """ - make_entitlement() + CourseEntitlementFactory.create() call_command('expire_old_entitlements') self.assertEqual(mock_task.call_count, 0) @@ -40,46 +30,27 @@ class TestExpireOldEntitlementsCommand(TestCase): def test_no_tasks_if_no_work(self, mock_task): """ - Verify that we never try to spin off a task if there are no matching database rows. + Verify that we never try to spin off a task if there are no database rows. """ call_command('expire_old_entitlements', commit=True) self.assertEqual(mock_task.call_count, 0) # Now confirm that the above test wasn't a fluke and we will create a task if there is work - make_entitlement() + CourseEntitlementFactory.create() call_command('expire_old_entitlements', commit=True) self.assertEqual(mock_task.call_count, 1) - def test_only_unexpired(self, mock_task): - """ - Verify that only unexpired entitlements are included - """ - # Create an old expired and an old unexpired entitlement - entitlement1 = make_entitlement(expired=True) - entitlement2 = make_entitlement() - - # Sanity check - self.assertIsNotNone(entitlement1.expired_at) - self.assertIsNone(entitlement2.expired_at) - - # Run expiration - call_command('expire_old_entitlements', commit=True) - - # Make sure only the unexpired one gets used - self.assertEqual(mock_task.call_count, 1) - self.assertEqual(list(mock_task.call_args[0][0].object_list), [entitlement2]) - def test_pagination(self, mock_task): """ Verify that we chunk up our requests to celery. """ for _ in range(5): - make_entitlement() + CourseEntitlementFactory.create() call_command('expire_old_entitlements', commit=True, batch_size=2) args_list = mock_task.call_args_list self.assertEqual(len(args_list), 3) - self.assertEqual(len(args_list[0][0][0].object_list), 2) - self.assertEqual(len(args_list[1][0][0].object_list), 2) - self.assertEqual(len(args_list[2][0][0].object_list), 1) + self.assertEqual(args_list[0][0], (1, 3)) + self.assertEqual(args_list[1][0], (3, 5)) + self.assertEqual(args_list[2][0], (5, 6)) diff --git a/common/djangoapps/entitlements/tasks/v1/tasks.py b/common/djangoapps/entitlements/tasks/v1/tasks.py index bfe849b235..bbef3aa757 100644 --- a/common/djangoapps/entitlements/tasks/v1/tasks.py +++ b/common/djangoapps/entitlements/tasks/v1/tasks.py @@ -6,6 +6,8 @@ from celery import task from celery.utils.log import get_task_logger from django.conf import settings +from entitlements.models import CourseEntitlement + LOGGER = get_task_logger(__name__) # Under cms the following setting is not defined, leading to errors during tests. @@ -18,7 +20,7 @@ MAX_RETRIES = 11 @task(bind=True, ignore_result=True, routing_key=ROUTING_KEY) -def expire_old_entitlements(self, entitlements, logid='...'): +def expire_old_entitlements(self, start, end, logid='...'): """ This task is designed to be called to process a bundle of entitlements that might be expired and confirm if we can do so. This is useful when @@ -28,14 +30,19 @@ def expire_old_entitlements(self, entitlements, logid='...'): run this task every now and then to clear the backlog.) Args: - entitlements (list): An iterable set of CourseEntitlements to check + start (int): The beginning id in the database to examine + end (int): The id in the database to stop examining at (i.e. range is exclusive) logid (str): A string to identify this task in the logs Returns: None """ - LOGGER.info('Running task expire_old_entitlements [%s]', logid) + LOGGER.info('Running task expire_old_entitlements %d:%d [%s]', start, end, logid) + + # This query could be optimized to return a more narrow set, but at a + # complexity cost. See bug LEARNER-3451 about improving it. + entitlements = CourseEntitlement.objects.filter(expired_at__isnull=True, id__gte=start, id__lt=end) countdown = 2 ** self.request.retries @@ -53,4 +60,4 @@ def expire_old_entitlements(self, entitlements, logid='...'): # The call above is idempotent, so retry at will raise self.retry(exc=exc, countdown=countdown, max_retries=MAX_RETRIES) - LOGGER.info('Successfully completed the task expire_old_entitlements [%s]', logid) + LOGGER.info('Successfully completed the task expire_old_entitlements after examining %d entries [%s]', entitlements.count(), logid) diff --git a/common/djangoapps/entitlements/tasks/v1/tests/test_tasks.py b/common/djangoapps/entitlements/tasks/v1/tests/test_tasks.py index 8caababb8b..0e6240bf5a 100644 --- a/common/djangoapps/entitlements/tasks/v1/tests/test_tasks.py +++ b/common/djangoapps/entitlements/tasks/v1/tests/test_tasks.py @@ -14,11 +14,12 @@ from entitlements.tests.factories import CourseEntitlementFactory from openedx.core.djangolib.testing.utils import skip_unless_lms -def make_entitlement(**kwargs): - m = mock.NonCallableMock() - p = mock.PropertyMock(**kwargs) - type(m).expired_at_datetime = p - return m, p +def make_entitlement(expired=False): + age = CourseEntitlementPolicy.DEFAULT_EXPIRATION_PERIOD_DAYS + past_datetime = datetime.now(tz=pytz.UTC) - timedelta(days=age) + expired_at = past_datetime if expired else None + entitlement = CourseEntitlementFactory.create(created=past_datetime, expired_at=expired_at) + return entitlement def boom(): @@ -26,49 +27,67 @@ def boom(): @skip_unless_lms +@mock.patch('entitlements.models.CourseEntitlement.expired_at_datetime', new_callable=mock.PropertyMock) class TestExpireOldEntitlementsTask(TestCase): """ Tests for the 'expire_old_entitlements' method. """ - - def test_checks_expiration(self): + def test_checks_expiration(self, mock_datetime): """ Test that we actually do check expiration on each entitlement (happy path) """ - entitlement1, prop1 = make_entitlement(return_value=None) - entitlement2, prop2 = make_entitlement(return_value='some date') - tasks.expire_old_entitlements.delay([entitlement1, entitlement2]).get() + make_entitlement() + make_entitlement() - # Test that the expired_at_datetime property was accessed - self.assertEqual(prop1.call_count, 1) - self.assertEqual(prop2.call_count, 1) + tasks.expire_old_entitlements.delay(1, 3).get() - def test_retry(self): + self.assertEqual(mock_datetime.call_count, 2) + + def test_only_unexpired(self, mock_datetime): + """ + Verify that only unexpired entitlements are included + """ + # Create an old expired and an old unexpired entitlement + make_entitlement(expired=True) + make_entitlement() + + # Run expiration + tasks.expire_old_entitlements.delay(1, 3).get() + + # Make sure only the unexpired one gets used + self.assertEqual(mock_datetime.call_count, 1) + + def test_retry(self, mock_datetime): """ Test that we retry when an exception occurs while checking old entitlements. """ - entitlement, prop = make_entitlement(side_effect=boom) - task = tasks.expire_old_entitlements.delay([entitlement]) + mock_datetime.side_effect = boom + + make_entitlement() + task = tasks.expire_old_entitlements.delay(1, 2) self.assertRaises(Exception, task.get) - self.assertEqual(prop.call_count, tasks.MAX_RETRIES + 1) + self.assertEqual(mock_datetime.call_count, tasks.MAX_RETRIES + 1) + +@skip_unless_lms +class TestExpireOldEntitlementsTaskIntegration(TestCase): + """ + Tests for the 'expire_old_entitlements' method without mocking. + """ def test_actually_expired(self): """ Integration test with CourseEntitlement to make sure we are calling the correct API. """ - # Create an actual old entitlement - past_days = CourseEntitlementPolicy.DEFAULT_EXPIRATION_PERIOD_DAYS - past_datetime = datetime.now(tz=pytz.UTC) - timedelta(days=past_days) - entitlement = CourseEntitlementFactory.create(created=past_datetime) + entitlement = make_entitlement() # Sanity check self.assertIsNone(entitlement.expired_at) # Run enforcement - tasks.expire_old_entitlements.delay([entitlement]).get() + tasks.expire_old_entitlements.delay(1, 2).get() entitlement.refresh_from_db() self.assertIsNotNone(entitlement.expired_at)