diff --git a/common/djangoapps/student/models.py b/common/djangoapps/student/models.py index e784b91abb..e874b190f6 100644 --- a/common/djangoapps/student/models.py +++ b/common/djangoapps/student/models.py @@ -1030,16 +1030,32 @@ class CourseEnrollment(models.Model): if user.id is None: user.save() - enrollment, created = cls.objects.get_or_create( - user=user, - course_id=course_key, - ) + try: + enrollment, created = cls.objects.get_or_create( + user=user, + course_id=course_key, + ) - # If we *did* just create a new enrollment, set some defaults - if created: - enrollment.mode = CourseMode.DEFAULT_MODE_SLUG - enrollment.is_active = False - enrollment.save() + # If we *did* just create a new enrollment, set some defaults + if created: + enrollment.mode = CourseMode.DEFAULT_MODE_SLUG + enrollment.is_active = False + enrollment.save() + + except IntegrityError: + log.info( + ( + "An integrity error occurred while getting-or-creating the enrollment" + "for course key %s and student %s. This can occur if two processes try to get-or-create " + "the enrollment at the same time and the database is set to REPEATABLE READ. We will try " + "committing the transaction and retrying." + ), + course_key, user + ) + enrollment = cls.objects.get( + user=user, + course_id=course_key, + ) return enrollment diff --git a/common/djangoapps/student/tests/test_enrollment.py b/common/djangoapps/student/tests/test_enrollment.py index acabe344c3..9090c40ecf 100644 --- a/common/djangoapps/student/tests/test_enrollment.py +++ b/common/djangoapps/student/tests/test_enrollment.py @@ -8,12 +8,13 @@ from nose.plugins.attrib import attr from django.conf import settings from django.core.urlresolvers import reverse +from django.db import IntegrityError from course_modes.models import CourseMode from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase from xmodule.modulestore.tests.factories import CourseFactory from util.testing import UrlResetMixin from embargo.test_utils import restrict_course -from student.tests.factories import UserFactory, CourseModeFactory +from student.tests.factories import UserFactory, CourseModeFactory, CourseEnrollmentFactory from student.models import CourseEnrollment, CourseFullError from student.roles import ( CourseInstructorRole, @@ -281,3 +282,18 @@ class EnrollmentTest(UrlResetMixin, SharedModuleStoreTestCase): params['email_opt_in'] = email_opt_in return self.client.post(reverse('change_enrollment'), params) + + def test_get_or_create_integrity_error(self): + """Verify that get_or_create_enrollment handles IntegrityError.""" + + CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) + + with patch.object(CourseEnrollment.objects, "get_or_create") as mock_get_or_create: + mock_get_or_create.side_effect = IntegrityError + enrollment = CourseEnrollment.get_or_create_enrollment( + self.user, + self.course.id + ) + + self.assertEqual(enrollment.user, self.user) + self.assertEqual(enrollment.course.id, self.course.id)