Pass serializable data to celery
The expire_old_entitlements management command was passing a Paginator object to celery, which it can't serialize when it is running async. Instead, we'll do the pagination manually inside the task code and simply pass it integers.
This commit is contained in:
committed by
Michael Terry
parent
dd80638736
commit
64b7fb88b8
@@ -5,7 +5,6 @@ Management command for expiring old entitlements.
|
||||
import logging
|
||||
|
||||
from django.core.management import BaseCommand
|
||||
from django.core.paginator import Paginator
|
||||
|
||||
from entitlements.models import CourseEntitlement
|
||||
from entitlements.tasks.v1.tasks import expire_old_entitlements
|
||||
@@ -46,24 +45,22 @@ class Command(BaseCommand):
|
||||
def handle(self, *args, **options):
|
||||
logger.info('Looking for entitlements which may be expirable.')
|
||||
|
||||
# This query could be optimized to return a more narrow set, but at a
|
||||
# complexity cost. See bug LEARNER-3451 about improving it.
|
||||
entitlements = CourseEntitlement.objects.filter(expired_at__isnull=True).order_by('id')
|
||||
|
||||
total = CourseEntitlement.objects.count()
|
||||
batch_size = max(1, options.get('batch_size'))
|
||||
entitlements = Paginator(entitlements, batch_size, allow_empty_first_page=False)
|
||||
num_batches = ((total - 1) / batch_size + 1) if total > 0 else 0
|
||||
|
||||
if options.get('commit'):
|
||||
logger.info('Enqueuing entitlement expiration tasks for %d candidates.', entitlements.count)
|
||||
logger.info('Enqueuing %d entitlement expiration tasks.', num_batches)
|
||||
else:
|
||||
logger.info(
|
||||
'Found %d candidates. To enqueue entitlement expiration tasks, pass the -c or --commit flags.',
|
||||
entitlements.count
|
||||
'Found %d batches. To enqueue entitlement expiration tasks, pass the -c or --commit flags.',
|
||||
num_batches
|
||||
)
|
||||
return
|
||||
|
||||
for page_num in entitlements.page_range:
|
||||
page = entitlements.page(page_num)
|
||||
expire_old_entitlements.delay(page, logid=str(page_num))
|
||||
for batch_num in range(num_batches):
|
||||
start = batch_num * batch_size + 1 # ids are 1-based, so add 1
|
||||
end = min(start + batch_size, total + 1)
|
||||
expire_old_entitlements.delay(start, end, logid=str(batch_num))
|
||||
|
||||
logger.info('Done. Successfully enqueued tasks.')
|
||||
logger.info('Done. Successfully enqueued %d tasks.', num_batches)
|
||||
|
||||
@@ -1,24 +1,14 @@
|
||||
"""Test Entitlements models"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
import mock
|
||||
import pytz
|
||||
|
||||
from django.core.management import call_command
|
||||
from django.test import TestCase
|
||||
|
||||
from openedx.core.djangolib.testing.utils import skip_unless_lms
|
||||
from entitlements.models import CourseEntitlementPolicy
|
||||
from entitlements.tests.factories import CourseEntitlementFactory
|
||||
|
||||
|
||||
def make_entitlement(expired=False):
|
||||
age = CourseEntitlementPolicy.DEFAULT_EXPIRATION_PERIOD_DAYS
|
||||
past_datetime = datetime.now(tz=pytz.UTC) - timedelta(days=age)
|
||||
expired_at = past_datetime if expired else None
|
||||
return CourseEntitlementFactory.create(created=past_datetime, expired_at=expired_at)
|
||||
|
||||
|
||||
@skip_unless_lms
|
||||
@mock.patch('entitlements.tasks.v1.tasks.expire_old_entitlements.delay')
|
||||
class TestExpireOldEntitlementsCommand(TestCase):
|
||||
@@ -30,7 +20,7 @@ class TestExpireOldEntitlementsCommand(TestCase):
|
||||
"""
|
||||
Verify that relevant tasks are only enqueued when the commit option is passed.
|
||||
"""
|
||||
make_entitlement()
|
||||
CourseEntitlementFactory.create()
|
||||
|
||||
call_command('expire_old_entitlements')
|
||||
self.assertEqual(mock_task.call_count, 0)
|
||||
@@ -40,46 +30,27 @@ class TestExpireOldEntitlementsCommand(TestCase):
|
||||
|
||||
def test_no_tasks_if_no_work(self, mock_task):
|
||||
"""
|
||||
Verify that we never try to spin off a task if there are no matching database rows.
|
||||
Verify that we never try to spin off a task if there are no database rows.
|
||||
"""
|
||||
call_command('expire_old_entitlements', commit=True)
|
||||
self.assertEqual(mock_task.call_count, 0)
|
||||
|
||||
# Now confirm that the above test wasn't a fluke and we will create a task if there is work
|
||||
make_entitlement()
|
||||
CourseEntitlementFactory.create()
|
||||
call_command('expire_old_entitlements', commit=True)
|
||||
self.assertEqual(mock_task.call_count, 1)
|
||||
|
||||
def test_only_unexpired(self, mock_task):
|
||||
"""
|
||||
Verify that only unexpired entitlements are included
|
||||
"""
|
||||
# Create an old expired and an old unexpired entitlement
|
||||
entitlement1 = make_entitlement(expired=True)
|
||||
entitlement2 = make_entitlement()
|
||||
|
||||
# Sanity check
|
||||
self.assertIsNotNone(entitlement1.expired_at)
|
||||
self.assertIsNone(entitlement2.expired_at)
|
||||
|
||||
# Run expiration
|
||||
call_command('expire_old_entitlements', commit=True)
|
||||
|
||||
# Make sure only the unexpired one gets used
|
||||
self.assertEqual(mock_task.call_count, 1)
|
||||
self.assertEqual(list(mock_task.call_args[0][0].object_list), [entitlement2])
|
||||
|
||||
def test_pagination(self, mock_task):
|
||||
"""
|
||||
Verify that we chunk up our requests to celery.
|
||||
"""
|
||||
for _ in range(5):
|
||||
make_entitlement()
|
||||
CourseEntitlementFactory.create()
|
||||
|
||||
call_command('expire_old_entitlements', commit=True, batch_size=2)
|
||||
|
||||
args_list = mock_task.call_args_list
|
||||
self.assertEqual(len(args_list), 3)
|
||||
self.assertEqual(len(args_list[0][0][0].object_list), 2)
|
||||
self.assertEqual(len(args_list[1][0][0].object_list), 2)
|
||||
self.assertEqual(len(args_list[2][0][0].object_list), 1)
|
||||
self.assertEqual(args_list[0][0], (1, 3))
|
||||
self.assertEqual(args_list[1][0], (3, 5))
|
||||
self.assertEqual(args_list[2][0], (5, 6))
|
||||
|
||||
@@ -6,6 +6,8 @@ from celery import task
|
||||
from celery.utils.log import get_task_logger
|
||||
from django.conf import settings
|
||||
|
||||
from entitlements.models import CourseEntitlement
|
||||
|
||||
|
||||
LOGGER = get_task_logger(__name__)
|
||||
# Under cms the following setting is not defined, leading to errors during tests.
|
||||
@@ -18,7 +20,7 @@ MAX_RETRIES = 11
|
||||
|
||||
|
||||
@task(bind=True, ignore_result=True, routing_key=ROUTING_KEY)
|
||||
def expire_old_entitlements(self, entitlements, logid='...'):
|
||||
def expire_old_entitlements(self, start, end, logid='...'):
|
||||
"""
|
||||
This task is designed to be called to process a bundle of entitlements
|
||||
that might be expired and confirm if we can do so. This is useful when
|
||||
@@ -28,14 +30,19 @@ def expire_old_entitlements(self, entitlements, logid='...'):
|
||||
run this task every now and then to clear the backlog.)
|
||||
|
||||
Args:
|
||||
entitlements (list): An iterable set of CourseEntitlements to check
|
||||
start (int): The beginning id in the database to examine
|
||||
end (int): The id in the database to stop examining at (i.e. range is exclusive)
|
||||
logid (str): A string to identify this task in the logs
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
LOGGER.info('Running task expire_old_entitlements [%s]', logid)
|
||||
LOGGER.info('Running task expire_old_entitlements %d:%d [%s]', start, end, logid)
|
||||
|
||||
# This query could be optimized to return a more narrow set, but at a
|
||||
# complexity cost. See bug LEARNER-3451 about improving it.
|
||||
entitlements = CourseEntitlement.objects.filter(expired_at__isnull=True, id__gte=start, id__lt=end)
|
||||
|
||||
countdown = 2 ** self.request.retries
|
||||
|
||||
@@ -53,4 +60,4 @@ def expire_old_entitlements(self, entitlements, logid='...'):
|
||||
# The call above is idempotent, so retry at will
|
||||
raise self.retry(exc=exc, countdown=countdown, max_retries=MAX_RETRIES)
|
||||
|
||||
LOGGER.info('Successfully completed the task expire_old_entitlements [%s]', logid)
|
||||
LOGGER.info('Successfully completed the task expire_old_entitlements after examining %d entries [%s]', entitlements.count(), logid)
|
||||
|
||||
@@ -14,11 +14,12 @@ from entitlements.tests.factories import CourseEntitlementFactory
|
||||
from openedx.core.djangolib.testing.utils import skip_unless_lms
|
||||
|
||||
|
||||
def make_entitlement(**kwargs):
|
||||
m = mock.NonCallableMock()
|
||||
p = mock.PropertyMock(**kwargs)
|
||||
type(m).expired_at_datetime = p
|
||||
return m, p
|
||||
def make_entitlement(expired=False):
|
||||
age = CourseEntitlementPolicy.DEFAULT_EXPIRATION_PERIOD_DAYS
|
||||
past_datetime = datetime.now(tz=pytz.UTC) - timedelta(days=age)
|
||||
expired_at = past_datetime if expired else None
|
||||
entitlement = CourseEntitlementFactory.create(created=past_datetime, expired_at=expired_at)
|
||||
return entitlement
|
||||
|
||||
|
||||
def boom():
|
||||
@@ -26,49 +27,67 @@ def boom():
|
||||
|
||||
|
||||
@skip_unless_lms
|
||||
@mock.patch('entitlements.models.CourseEntitlement.expired_at_datetime', new_callable=mock.PropertyMock)
|
||||
class TestExpireOldEntitlementsTask(TestCase):
|
||||
"""
|
||||
Tests for the 'expire_old_entitlements' method.
|
||||
"""
|
||||
|
||||
def test_checks_expiration(self):
|
||||
def test_checks_expiration(self, mock_datetime):
|
||||
"""
|
||||
Test that we actually do check expiration on each entitlement (happy path)
|
||||
"""
|
||||
entitlement1, prop1 = make_entitlement(return_value=None)
|
||||
entitlement2, prop2 = make_entitlement(return_value='some date')
|
||||
tasks.expire_old_entitlements.delay([entitlement1, entitlement2]).get()
|
||||
make_entitlement()
|
||||
make_entitlement()
|
||||
|
||||
# Test that the expired_at_datetime property was accessed
|
||||
self.assertEqual(prop1.call_count, 1)
|
||||
self.assertEqual(prop2.call_count, 1)
|
||||
tasks.expire_old_entitlements.delay(1, 3).get()
|
||||
|
||||
def test_retry(self):
|
||||
self.assertEqual(mock_datetime.call_count, 2)
|
||||
|
||||
def test_only_unexpired(self, mock_datetime):
|
||||
"""
|
||||
Verify that only unexpired entitlements are included
|
||||
"""
|
||||
# Create an old expired and an old unexpired entitlement
|
||||
make_entitlement(expired=True)
|
||||
make_entitlement()
|
||||
|
||||
# Run expiration
|
||||
tasks.expire_old_entitlements.delay(1, 3).get()
|
||||
|
||||
# Make sure only the unexpired one gets used
|
||||
self.assertEqual(mock_datetime.call_count, 1)
|
||||
|
||||
def test_retry(self, mock_datetime):
|
||||
"""
|
||||
Test that we retry when an exception occurs while checking old
|
||||
entitlements.
|
||||
"""
|
||||
entitlement, prop = make_entitlement(side_effect=boom)
|
||||
task = tasks.expire_old_entitlements.delay([entitlement])
|
||||
mock_datetime.side_effect = boom
|
||||
|
||||
make_entitlement()
|
||||
task = tasks.expire_old_entitlements.delay(1, 2)
|
||||
|
||||
self.assertRaises(Exception, task.get)
|
||||
self.assertEqual(prop.call_count, tasks.MAX_RETRIES + 1)
|
||||
self.assertEqual(mock_datetime.call_count, tasks.MAX_RETRIES + 1)
|
||||
|
||||
|
||||
@skip_unless_lms
|
||||
class TestExpireOldEntitlementsTaskIntegration(TestCase):
|
||||
"""
|
||||
Tests for the 'expire_old_entitlements' method without mocking.
|
||||
"""
|
||||
def test_actually_expired(self):
|
||||
"""
|
||||
Integration test with CourseEntitlement to make sure we are calling the
|
||||
correct API.
|
||||
"""
|
||||
# Create an actual old entitlement
|
||||
past_days = CourseEntitlementPolicy.DEFAULT_EXPIRATION_PERIOD_DAYS
|
||||
past_datetime = datetime.now(tz=pytz.UTC) - timedelta(days=past_days)
|
||||
entitlement = CourseEntitlementFactory.create(created=past_datetime)
|
||||
entitlement = make_entitlement()
|
||||
|
||||
# Sanity check
|
||||
self.assertIsNone(entitlement.expired_at)
|
||||
|
||||
# Run enforcement
|
||||
tasks.expire_old_entitlements.delay([entitlement]).get()
|
||||
tasks.expire_old_entitlements.delay(1, 2).get()
|
||||
entitlement.refresh_from_db()
|
||||
|
||||
self.assertIsNotNone(entitlement.expired_at)
|
||||
|
||||
Reference in New Issue
Block a user