Pass serializable data to celery

The expire_old_entitlements management command was passing a
Paginator object to celery, which it can't serialize when it is
running async. Instead, we'll do the pagination manually inside the
task code and simply pass it integers.
This commit is contained in:
Michael Terry
2018-01-03 16:40:50 -05:00
committed by Michael Terry
parent dd80638736
commit 64b7fb88b8
4 changed files with 69 additions and 75 deletions

View File

@@ -5,7 +5,6 @@ Management command for expiring old entitlements.
import logging
from django.core.management import BaseCommand
from django.core.paginator import Paginator
from entitlements.models import CourseEntitlement
from entitlements.tasks.v1.tasks import expire_old_entitlements
@@ -46,24 +45,22 @@ class Command(BaseCommand):
def handle(self, *args, **options):
logger.info('Looking for entitlements which may be expirable.')
# This query could be optimized to return a more narrow set, but at a
# complexity cost. See bug LEARNER-3451 about improving it.
entitlements = CourseEntitlement.objects.filter(expired_at__isnull=True).order_by('id')
total = CourseEntitlement.objects.count()
batch_size = max(1, options.get('batch_size'))
entitlements = Paginator(entitlements, batch_size, allow_empty_first_page=False)
num_batches = ((total - 1) / batch_size + 1) if total > 0 else 0
if options.get('commit'):
logger.info('Enqueuing entitlement expiration tasks for %d candidates.', entitlements.count)
logger.info('Enqueuing %d entitlement expiration tasks.', num_batches)
else:
logger.info(
'Found %d candidates. To enqueue entitlement expiration tasks, pass the -c or --commit flags.',
entitlements.count
'Found %d batches. To enqueue entitlement expiration tasks, pass the -c or --commit flags.',
num_batches
)
return
for page_num in entitlements.page_range:
page = entitlements.page(page_num)
expire_old_entitlements.delay(page, logid=str(page_num))
for batch_num in range(num_batches):
start = batch_num * batch_size + 1 # ids are 1-based, so add 1
end = min(start + batch_size, total + 1)
expire_old_entitlements.delay(start, end, logid=str(batch_num))
logger.info('Done. Successfully enqueued tasks.')
logger.info('Done. Successfully enqueued %d tasks.', num_batches)

View File

@@ -1,24 +1,14 @@
"""Test Entitlements models"""
from datetime import datetime, timedelta
import mock
import pytz
from django.core.management import call_command
from django.test import TestCase
from openedx.core.djangolib.testing.utils import skip_unless_lms
from entitlements.models import CourseEntitlementPolicy
from entitlements.tests.factories import CourseEntitlementFactory
def make_entitlement(expired=False):
age = CourseEntitlementPolicy.DEFAULT_EXPIRATION_PERIOD_DAYS
past_datetime = datetime.now(tz=pytz.UTC) - timedelta(days=age)
expired_at = past_datetime if expired else None
return CourseEntitlementFactory.create(created=past_datetime, expired_at=expired_at)
@skip_unless_lms
@mock.patch('entitlements.tasks.v1.tasks.expire_old_entitlements.delay')
class TestExpireOldEntitlementsCommand(TestCase):
@@ -30,7 +20,7 @@ class TestExpireOldEntitlementsCommand(TestCase):
"""
Verify that relevant tasks are only enqueued when the commit option is passed.
"""
make_entitlement()
CourseEntitlementFactory.create()
call_command('expire_old_entitlements')
self.assertEqual(mock_task.call_count, 0)
@@ -40,46 +30,27 @@ class TestExpireOldEntitlementsCommand(TestCase):
def test_no_tasks_if_no_work(self, mock_task):
"""
Verify that we never try to spin off a task if there are no matching database rows.
Verify that we never try to spin off a task if there are no database rows.
"""
call_command('expire_old_entitlements', commit=True)
self.assertEqual(mock_task.call_count, 0)
# Now confirm that the above test wasn't a fluke and we will create a task if there is work
make_entitlement()
CourseEntitlementFactory.create()
call_command('expire_old_entitlements', commit=True)
self.assertEqual(mock_task.call_count, 1)
def test_only_unexpired(self, mock_task):
"""
Verify that only unexpired entitlements are included
"""
# Create an old expired and an old unexpired entitlement
entitlement1 = make_entitlement(expired=True)
entitlement2 = make_entitlement()
# Sanity check
self.assertIsNotNone(entitlement1.expired_at)
self.assertIsNone(entitlement2.expired_at)
# Run expiration
call_command('expire_old_entitlements', commit=True)
# Make sure only the unexpired one gets used
self.assertEqual(mock_task.call_count, 1)
self.assertEqual(list(mock_task.call_args[0][0].object_list), [entitlement2])
def test_pagination(self, mock_task):
"""
Verify that we chunk up our requests to celery.
"""
for _ in range(5):
make_entitlement()
CourseEntitlementFactory.create()
call_command('expire_old_entitlements', commit=True, batch_size=2)
args_list = mock_task.call_args_list
self.assertEqual(len(args_list), 3)
self.assertEqual(len(args_list[0][0][0].object_list), 2)
self.assertEqual(len(args_list[1][0][0].object_list), 2)
self.assertEqual(len(args_list[2][0][0].object_list), 1)
self.assertEqual(args_list[0][0], (1, 3))
self.assertEqual(args_list[1][0], (3, 5))
self.assertEqual(args_list[2][0], (5, 6))

View File

@@ -6,6 +6,8 @@ from celery import task
from celery.utils.log import get_task_logger
from django.conf import settings
from entitlements.models import CourseEntitlement
LOGGER = get_task_logger(__name__)
# Under cms the following setting is not defined, leading to errors during tests.
@@ -18,7 +20,7 @@ MAX_RETRIES = 11
@task(bind=True, ignore_result=True, routing_key=ROUTING_KEY)
def expire_old_entitlements(self, entitlements, logid='...'):
def expire_old_entitlements(self, start, end, logid='...'):
"""
This task is designed to be called to process a bundle of entitlements
that might be expired and confirm if we can do so. This is useful when
@@ -28,14 +30,19 @@ def expire_old_entitlements(self, entitlements, logid='...'):
run this task every now and then to clear the backlog.)
Args:
entitlements (list): An iterable set of CourseEntitlements to check
start (int): The beginning id in the database to examine
end (int): The id in the database to stop examining at (i.e. range is exclusive)
logid (str): A string to identify this task in the logs
Returns:
None
"""
LOGGER.info('Running task expire_old_entitlements [%s]', logid)
LOGGER.info('Running task expire_old_entitlements %d:%d [%s]', start, end, logid)
# This query could be optimized to return a more narrow set, but at a
# complexity cost. See bug LEARNER-3451 about improving it.
entitlements = CourseEntitlement.objects.filter(expired_at__isnull=True, id__gte=start, id__lt=end)
countdown = 2 ** self.request.retries
@@ -53,4 +60,4 @@ def expire_old_entitlements(self, entitlements, logid='...'):
# The call above is idempotent, so retry at will
raise self.retry(exc=exc, countdown=countdown, max_retries=MAX_RETRIES)
LOGGER.info('Successfully completed the task expire_old_entitlements [%s]', logid)
LOGGER.info('Successfully completed the task expire_old_entitlements after examining %d entries [%s]', entitlements.count(), logid)

View File

@@ -14,11 +14,12 @@ from entitlements.tests.factories import CourseEntitlementFactory
from openedx.core.djangolib.testing.utils import skip_unless_lms
def make_entitlement(**kwargs):
m = mock.NonCallableMock()
p = mock.PropertyMock(**kwargs)
type(m).expired_at_datetime = p
return m, p
def make_entitlement(expired=False):
age = CourseEntitlementPolicy.DEFAULT_EXPIRATION_PERIOD_DAYS
past_datetime = datetime.now(tz=pytz.UTC) - timedelta(days=age)
expired_at = past_datetime if expired else None
entitlement = CourseEntitlementFactory.create(created=past_datetime, expired_at=expired_at)
return entitlement
def boom():
@@ -26,49 +27,67 @@ def boom():
@skip_unless_lms
@mock.patch('entitlements.models.CourseEntitlement.expired_at_datetime', new_callable=mock.PropertyMock)
class TestExpireOldEntitlementsTask(TestCase):
"""
Tests for the 'expire_old_entitlements' method.
"""
def test_checks_expiration(self):
def test_checks_expiration(self, mock_datetime):
"""
Test that we actually do check expiration on each entitlement (happy path)
"""
entitlement1, prop1 = make_entitlement(return_value=None)
entitlement2, prop2 = make_entitlement(return_value='some date')
tasks.expire_old_entitlements.delay([entitlement1, entitlement2]).get()
make_entitlement()
make_entitlement()
# Test that the expired_at_datetime property was accessed
self.assertEqual(prop1.call_count, 1)
self.assertEqual(prop2.call_count, 1)
tasks.expire_old_entitlements.delay(1, 3).get()
def test_retry(self):
self.assertEqual(mock_datetime.call_count, 2)
def test_only_unexpired(self, mock_datetime):
"""
Verify that only unexpired entitlements are included
"""
# Create an old expired and an old unexpired entitlement
make_entitlement(expired=True)
make_entitlement()
# Run expiration
tasks.expire_old_entitlements.delay(1, 3).get()
# Make sure only the unexpired one gets used
self.assertEqual(mock_datetime.call_count, 1)
def test_retry(self, mock_datetime):
"""
Test that we retry when an exception occurs while checking old
entitlements.
"""
entitlement, prop = make_entitlement(side_effect=boom)
task = tasks.expire_old_entitlements.delay([entitlement])
mock_datetime.side_effect = boom
make_entitlement()
task = tasks.expire_old_entitlements.delay(1, 2)
self.assertRaises(Exception, task.get)
self.assertEqual(prop.call_count, tasks.MAX_RETRIES + 1)
self.assertEqual(mock_datetime.call_count, tasks.MAX_RETRIES + 1)
@skip_unless_lms
class TestExpireOldEntitlementsTaskIntegration(TestCase):
"""
Tests for the 'expire_old_entitlements' method without mocking.
"""
def test_actually_expired(self):
"""
Integration test with CourseEntitlement to make sure we are calling the
correct API.
"""
# Create an actual old entitlement
past_days = CourseEntitlementPolicy.DEFAULT_EXPIRATION_PERIOD_DAYS
past_datetime = datetime.now(tz=pytz.UTC) - timedelta(days=past_days)
entitlement = CourseEntitlementFactory.create(created=past_datetime)
entitlement = make_entitlement()
# Sanity check
self.assertIsNone(entitlement.expired_at)
# Run enforcement
tasks.expire_old_entitlements.delay([entitlement]).get()
tasks.expire_old_entitlements.delay(1, 2).get()
entitlement.refresh_from_db()
self.assertIsNotNone(entitlement.expired_at)