diff --git a/common/djangoapps/entitlements/api/v1/views.py b/common/djangoapps/entitlements/api/v1/views.py index 8914365c9e..d55bae93b0 100644 --- a/common/djangoapps/entitlements/api/v1/views.py +++ b/common/djangoapps/entitlements/api/v1/views.py @@ -57,13 +57,12 @@ def _process_revoke_and_unenroll_entitlement(course_entitlement, is_refund=False IntegrityError if there is an issue that should reverse the database changes """ if course_entitlement.expired_at is None: - course_entitlement.expired_at = timezone.now() + course_entitlement.expire_entitlement() log.info( 'Set expired_at to [%s] for course entitlement [%s]', course_entitlement.expired_at, course_entitlement.uuid ) - course_entitlement.save() if course_entitlement.enrollment_course_run is not None: course_id = course_entitlement.enrollment_course_run.course_id diff --git a/common/djangoapps/entitlements/models.py b/common/djangoapps/entitlements/models.py index a56d3a00f7..f9b532ea40 100644 --- a/common/djangoapps/entitlements/models.py +++ b/common/djangoapps/entitlements/models.py @@ -210,8 +210,7 @@ class CourseEntitlement(TimeStampedModel): if not self.expired_at: if (self.policy.get_days_until_expiration(self) < 0 or (self.enrollment_course_run and not self.is_entitlement_regainable())): - self.expired_at = now() - self.save() + self.expire_entitlement() def get_days_until_expiration(self): """ @@ -269,6 +268,13 @@ class CourseEntitlement(TimeStampedModel): self.enrollment_course_run = enrollment self.save() + def expire_entitlement(self): + """ + Expire the entitlement. + """ + self.expired_at = now() + self.save() + @classmethod def unexpired_entitlements_for_user(cls, user): return cls.objects.filter(user=user, expired_at=None).select_related('user') @@ -412,11 +418,13 @@ class CourseEntitlement(TimeStampedModel): """ course_uuid = get_course_uuid_for_course(course_enrollment.course_id) course_entitlement = cls.get_entitlement_if_active(course_enrollment.user, course_uuid) - if course_entitlement: + if course_entitlement and course_entitlement.enrollment_course_run == course_enrollment: course_entitlement.set_enrollment(None) if not skip_refund and course_entitlement.is_entitlement_refundable(): course_entitlement.refund() + course_entitlement.expire_entitlement() + def refund(self): """ Initiate refund process for the entitlement. diff --git a/common/djangoapps/entitlements/tests/test_models.py b/common/djangoapps/entitlements/tests/test_models.py index e6fefa9028..7b09a535ce 100644 --- a/common/djangoapps/entitlements/tests/test_models.py +++ b/common/djangoapps/entitlements/tests/test_models.py @@ -300,3 +300,31 @@ class TestModels(TestCase): expired_at_datetime = entitlement.expired_at_datetime assert expired_at_datetime assert entitlement.expired_at + + @patch("entitlements.models.get_course_uuid_for_course") + @patch("entitlements.models.CourseEntitlement.refund") + def test_unenroll_entitlement_with_audit_course_enrollment(self, mock_refund, mock_get_course_uuid): + """ + Test that entitlement is not refunded if un-enroll is called on audit course un-enroll. + """ + self.enrollment.mode = CourseMode.AUDIT + self.enrollment.user = self.user + self.enrollment.save() + entitlement = CourseEntitlementFactory.create(user=self.user) + mock_get_course_uuid.return_value = entitlement.course_uuid + CourseEnrollment.unenroll(self.user, self.course.id) + + assert not mock_refund.called + entitlement.refresh_from_db() + assert entitlement.expired_at is None + + self.enrollment.mode = CourseMode.VERIFIED + self.enrollment.is_active = True + self.enrollment.save() + entitlement.enrollment_course_run = self.enrollment + entitlement.save() + CourseEnrollment.unenroll(self.user, self.course.id) + + assert mock_refund.called + entitlement.refresh_from_db() + assert entitlement.expired_at < now()