diff --git a/common/djangoapps/student/models.py b/common/djangoapps/student/models.py index ebc9f4028c..f97085c97f 100644 --- a/common/djangoapps/student/models.py +++ b/common/djangoapps/student/models.py @@ -713,7 +713,7 @@ class CourseEnrollment(models.Model): ).format(self.user, self.course_id, self.created, self.is_active) @classmethod - def create_enrollment(cls, user, course_id): + def get_or_create_enrollment(cls, user, course_id): """ Create an enrollment for a user in a class. By default *this enrollment is not active*. This is useful for when an enrollment needs to go @@ -761,11 +761,15 @@ class CourseEnrollment(models.Model): This saves immediately. """ activation_changed = False + # if is_active is None, then the call to update_enrollment didn't specify + # any value, so just leave is_active as it is if self.is_active != is_active and is_active is not None: self.is_active = is_active activation_changed = True mode_changed = False + # if mode is None, the call to update_enrollment didn't specify a new + # mode, so leave as-is if self.mode != mode and mode is not None: self.mode = mode mode_changed = True @@ -819,7 +823,7 @@ class CourseEnrollment(models.Model): It is expected that this method is called from a method which has already verified the user authentication and access. """ - enrollment = cls.create_enrollment(user, course_id) + enrollment = cls.get_or_create_enrollment(user, course_id) enrollment.update_enrollment(is_active=True) return enrollment diff --git a/common/djangoapps/student/tests/tests.py b/common/djangoapps/student/tests/tests.py index 634996fd55..7e19ee359e 100644 --- a/common/djangoapps/student/tests/tests.py +++ b/common/djangoapps/student/tests/tests.py @@ -443,7 +443,7 @@ class EnrollInCourseTest(TestCase): # Creating an enrollment doesn't actually enroll a student # (calling CourseEnrollment.enroll() would have) - enrollment = CourseEnrollment.create_enrollment(user, course_id) + enrollment = CourseEnrollment.get_or_create_enrollment(user, course_id) self.assertFalse(CourseEnrollment.is_enrolled(user, course_id)) self.assert_no_events_were_emitted() diff --git a/lms/djangoapps/shoppingcart/models.py b/lms/djangoapps/shoppingcart/models.py index ceac6c87e0..e072a79764 100644 --- a/lms/djangoapps/shoppingcart/models.py +++ b/lms/djangoapps/shoppingcart/models.py @@ -464,11 +464,8 @@ class CertificateItem(OrderItem): """ super(CertificateItem, cls).add_to_order(order, course_id, cost, currency=currency) - try: - course_enrollment = CourseEnrollment.objects.get(user=order.user, course_id=course_id) - except ObjectDoesNotExist: - course_enrollment = CourseEnrollment.create_enrollment(order.user, course_id) - course_enrollment.update_enrollment(mode=mode) + course_enrollment = CourseEnrollment.get_or_create_enrollment(order.user, course_id) + course_enrollment.update_enrollment(mode=mode) # do some validation on the enrollment mode valid_modes = CourseMode.modes_for_course_dict(course_id)