diff --git a/common/djangoapps/student/models.py b/common/djangoapps/student/models.py index 256751eb6d..fc3460628f 100644 --- a/common/djangoapps/student/models.py +++ b/common/djangoapps/student/models.py @@ -36,13 +36,15 @@ from importlib import import_module from opaque_keys.edx.locations import SlashSeparatedCourseKey -from course_modes.models import CourseMode import lms.lib.comment_client as cc from util.query import use_read_replica_if_available from xmodule_django.models import CourseKeyField, NoneToEmptyManager from opaque_keys.edx.keys import CourseKey from functools import total_ordering +from certificates.models import GeneratedCertificate +from course_modes.models import CourseMode + unenroll_done = Signal(providing_args=["course_enrollment"]) log = logging.getLogger(__name__) AUDIT_LOG = logging.getLogger("audit") @@ -953,6 +955,11 @@ class CourseEnrollment(models.Model): # (side-effects are bad) if getattr(self, 'can_refund', None) is not None: return True + + # If the student has already been given a certificate they should not be refunded + if GeneratedCertificate.certificate_for_student(self.user, self.course_id) is not None: + return False + course_mode = CourseMode.mode_for_course(self.course_id, 'verified') if course_mode is None: return False diff --git a/common/djangoapps/student/tests/tests.py b/common/djangoapps/student/tests/tests.py index c5845ac8bc..c241585021 100644 --- a/common/djangoapps/student/tests/tests.py +++ b/common/djangoapps/student/tests/tests.py @@ -30,6 +30,8 @@ from student.views import (process_survey_link, _cert_info, change_enrollment, complete_course_mode_info) from student.tests.factories import UserFactory, CourseModeFactory +from certificates.models import CertificateStatuses +from certificates.tests.factories import GeneratedCertificateFactory import shoppingcart log = logging.getLogger(__name__) @@ -216,6 +218,7 @@ class DashboardTest(TestCase): self.assertFalse(course_mode_info['show_upsell']) self.assertIsNone(course_mode_info['days_for_upsell']) + @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') def test_refundable(self): verified_mode = CourseModeFactory.create( course_id=self.course.id, @@ -231,6 +234,26 @@ class DashboardTest(TestCase): verified_mode.save() self.assertFalse(enrollment.refundable()) + @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') + def test_refundable_when_certificate_exists(self): + verified_mode = CourseModeFactory.create( + course_id=self.course.id, + mode_slug='verified', + mode_display_name='Verified', + expiration_datetime=datetime.now(pytz.UTC) + timedelta(days=1) + ) + enrollment = CourseEnrollment.enroll(self.user, self.course.id, mode='verified') + + self.assertTrue(enrollment.refundable()) + + generated_certificate = GeneratedCertificateFactory.create( + user=self.user, + course_id=self.course.id, + status=CertificateStatuses.downloadable, + mode='verified' + ) + + self.assertFalse(enrollment.refundable()) class EnrollInCourseTest(TestCase): diff --git a/lms/djangoapps/certificates/models.py b/lms/djangoapps/certificates/models.py index 71ab9ffcf4..59af2f8ba9 100644 --- a/lms/djangoapps/certificates/models.py +++ b/lms/djangoapps/certificates/models.py @@ -80,6 +80,8 @@ class CertificateWhitelist(models.Model): whitelist = models.BooleanField(default=0) +MODES = Choices('verified', 'honor', 'audit') + class GeneratedCertificate(models.Model): user = models.ForeignKey(User) course_id = CourseKeyField(max_length=255, blank=True, default=None) @@ -90,7 +92,6 @@ class GeneratedCertificate(models.Model): key = models.CharField(max_length=32, blank=True, default='') distinction = models.BooleanField(default=False) status = models.CharField(max_length=32, default='unavailable') - MODES = Choices('verified', 'honor', 'audit') mode = models.CharField(max_length=32, choices=MODES, default=MODES.honor) name = models.CharField(blank=True, max_length=255) created_date = models.DateTimeField( @@ -102,6 +103,18 @@ class GeneratedCertificate(models.Model): class Meta: unique_together = (('user', 'course_id'),) + @classmethod + def certificate_for_student(cls, student, course_id): + """ + This returns the certificate for a student for a particular course + or None if no such certificate exits. + """ + try: + return cls.objects.get(user=student, course_id=course_id) + except cls.DoesNotExist: + pass + + return None def certificate_status_for_student(student, course_id): ''' diff --git a/lms/djangoapps/certificates/tests/__init__.py b/lms/djangoapps/certificates/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lms/djangoapps/certificates/tests/factories.py b/lms/djangoapps/certificates/tests/factories.py new file mode 100644 index 0000000000..e1d824b22a --- /dev/null +++ b/lms/djangoapps/certificates/tests/factories.py @@ -0,0 +1,16 @@ +from factory.django import DjangoModelFactory + +from opaque_keys.edx.locations import SlashSeparatedCourseKey + +from certificates.models import GeneratedCertificate, CertificateStatuses, MODES + +# Factories don't have __init__ methods, and are self documenting +# pylint: disable=W0232 +class GeneratedCertificateFactory(DjangoModelFactory): + + FACTORY_FOR = GeneratedCertificate + + course_id = None + status = CertificateStatuses.unavailable + mode = MODES.honor + name = ''