diff --git a/common/djangoapps/entitlements/tasks.py b/common/djangoapps/entitlements/tasks.py index 1b6a0b13e5..1b268e941b 100644 --- a/common/djangoapps/entitlements/tasks.py +++ b/common/djangoapps/entitlements/tasks.py @@ -65,7 +65,7 @@ def expire_old_entitlements(self, start, end, logid='...'): LOGGER.info('Successfully completed the task expire_old_entitlements after examining %d entries [%s]', entitlements.count(), logid) # lint-amnesty, pylint: disable=line-too-long -@shared_task(bind=True, ignore_result=True) +@shared_task(bind=True) @set_code_owner_attribute def expire_and_create_entitlements(self, entitlement_ids, support_username): """ @@ -100,7 +100,7 @@ def expire_and_create_entitlements(self, entitlement_ids, support_username): try: for entitlement_id in entitlement_ids: - entitlement = CourseEntitlement.object.get(_id=entitlement_id) + entitlement = CourseEntitlement.objects.get(id=entitlement_id) LOGGER.info('Started expiring entitlement with id %d', entitlement.id) entitlement.expire_entitlement() LOGGER.info('Expired entitlement with id %d as expiration period has reached', entitlement.id) diff --git a/common/djangoapps/entitlements/tests/test_tasks.py b/common/djangoapps/entitlements/tests/test_tasks.py index 02a1f9468d..fe45055151 100644 --- a/common/djangoapps/entitlements/tests/test_tasks.py +++ b/common/djangoapps/entitlements/tests/test_tasks.py @@ -11,8 +11,9 @@ import pytz from django.test import TestCase from common.djangoapps.entitlements import tasks -from common.djangoapps.entitlements.models import CourseEntitlementPolicy +from common.djangoapps.entitlements.models import CourseEntitlement, CourseEntitlementPolicy from common.djangoapps.entitlements.tests.factories import CourseEntitlementFactory +from common.djangoapps.student.tests.factories import AdminFactory from openedx.core.djangolib.testing.utils import skip_unless_lms @@ -103,3 +104,62 @@ class TestExpireOldEntitlementsTaskIntegration(TestCase): entitlement.refresh_from_db() assert entitlement.expired_at is not None + + +@skip_unless_lms +class TestExpireAndCreateEntitlementsTaskIntegration(TestCase): + """ + Tests for the 'expire_and_create_entitlements' method. + """ + SUPPORT_USERNAME = 'support_username' + + def setUp(self): + """ + Set up user for tests. + """ + # Mock support user + AdminFactory.create(username=self.SUPPORT_USERNAME) + + def test_actually_expired(self): + """ + Integration test with CourseEntitlement to make sure we are calling the + correct API. + """ + entitlement = make_entitlement() + + # Sanity check + assert entitlement.expired_at is None + + # Run enforcement + tasks.expire_and_create_entitlements.delay( + [entitlement.id], + self.SUPPORT_USERNAME, + ).get() + entitlement.refresh_from_db() + + assert entitlement.expired_at is not None + + def test_actually_created(self): + """ + Integration test with CourseEntitlement to make sure we are creating an + entitlement after expiring it. + """ + entitlement = make_entitlement() + + # Sanity check + assert not CourseEntitlement.objects.filter( + course_uuid=entitlement.course_uuid, + expired_at=None + ).exclude(id=entitlement.id).exists() + + # Run enforcement + tasks.expire_and_create_entitlements.delay( + [entitlement.id], + self.SUPPORT_USERNAME, + ).get() + entitlement.refresh_from_db() + + assert CourseEntitlement.objects.filter( + course_uuid=entitlement.course_uuid, + expired_at=None + ).exclude(id=entitlement.id).exists()