diff --git a/openedx/core/djangoapps/content/course_overviews/models.py b/openedx/core/djangoapps/content/course_overviews/models.py index bf0747f5d2..ac4363a2f7 100644 --- a/openedx/core/djangoapps/content/course_overviews/models.py +++ b/openedx/core/djangoapps/content/course_overviews/models.py @@ -680,8 +680,9 @@ class CourseOverviewImageSet(TimeStampedModel): # an error or the course has no source course_image), our url fields # just keep their blank defaults. try: - image_set.save() - course_overview.image_set = image_set + with transaction.atomic(): + image_set.save() + course_overview.image_set = image_set except (IntegrityError, ValueError): # In the event of a race condition that tries to save two image sets # to the same CourseOverview, we'll just silently pass on the one diff --git a/openedx/core/djangoapps/content/course_overviews/tests.py b/openedx/core/djangoapps/content/course_overviews/tests.py index 02a426c37f..a87518ff2f 100644 --- a/openedx/core/djangoapps/content/course_overviews/tests.py +++ b/openedx/core/djangoapps/content/course_overviews/tests.py @@ -32,7 +32,7 @@ from xmodule.modulestore.django import modulestore from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.factories import CourseFactory, check_mongo_calls, check_mongo_calls_range -from .models import CourseOverview, CourseOverviewImageConfig +from .models import CourseOverview, CourseOverviewImageSet, CourseOverviewImageConfig @ddt.ddt @@ -494,6 +494,27 @@ class CourseOverviewTestCase(ModuleStoreTestCase): {c.id for c in org_courses[1]}, ) + def test_get_all_courses_by_mobile_available(self): + non_mobile_course = CourseFactory.create(emit_signals=True) + mobile_course = CourseFactory.create(mobile_available=True, emit_signals=True) + + test_cases = ( + (None, {non_mobile_course.id, mobile_course.id}), + (dict(mobile_available=True), {mobile_course.id}), + (dict(mobile_available=False), {non_mobile_course.id}), + ) + + for filter_, expected_courses in test_cases: + self.assertEqual( + { + course_overview.id + for course_overview in + CourseOverview.get_all_courses(filter_=filter_) + }, + expected_courses, + "testing CourseOverview.get_all_courses with filter_={}".format(filter_), + ) + @ddt.ddt class CourseOverviewImageSetTestCase(ModuleStoreTestCase): @@ -797,6 +818,43 @@ class CourseOverviewImageSetTestCase(ModuleStoreTestCase): self.assertEqual(src_x, image_x) self.assertEqual(src_y, image_y) + def test_image_creation_race_condition(self): + """ + Test for race condition in CourseOverviewImageSet creation. + + CourseOverviewTestCase already tests for race conditions with + CourseOverview as a whole, but we still need to test the case where a + CourseOverview already exists and we have a race condition purely in the + part that adds a new CourseOverviewImageSet. + """ + # Set config to False so that we don't create the image yet + self.set_config(False) + course = CourseFactory.create() + + # First create our CourseOverview + overview = CourseOverview.get_from_id(course.id) + self.assertFalse(hasattr(overview, 'image_set')) + + # Now create an ImageSet by hand... + CourseOverviewImageSet.objects.create(course_overview=overview) + + # Now do it the normal way -- this will cause an IntegrityError to be + # thrown and suppressed in create_for_course() + self.set_config(True) + CourseOverviewImageSet.create_for_course(overview) + self.assertTrue(hasattr(overview, 'image_set')) + + # The following is actually very important for this test because + # set_config() does a model insert after create_for_course() has caught + # and supressed an IntegrityError above. If create_for_course() properly + # wraps that operation in a transaction.atomic() block, the following + # will execute fine. If create_for_course() doesn't use an atomic block, + # the following line will cause a TransactionManagementError because + # Django will detect that something has already been rolled back in this + # transaction. So we don't really care about setting the config -- it's + # just a convenient way to cause a database write operation to happen. + self.set_config(False) + def _assert_image_urls_all_default(self, modulestore_type, raw_course_image_name, expected_url=None): """ Helper for asserting that all image_urls are defaulting to a particular value. @@ -826,24 +884,3 @@ class CourseOverviewImageSetTestCase(ModuleStoreTestCase): } ) return course_overview - - def test_get_all_courses_by_mobile_available(self): - non_mobile_course = CourseFactory.create(emit_signals=True) - mobile_course = CourseFactory.create(mobile_available=True, emit_signals=True) - - test_cases = ( - (None, {non_mobile_course.id, mobile_course.id}), - (dict(mobile_available=True), {mobile_course.id}), - (dict(mobile_available=False), {non_mobile_course.id}), - ) - - for filter_, expected_courses in test_cases: - self.assertEqual( - { - course_overview.id - for course_overview in - CourseOverview.get_all_courses(filter_=filter_) - }, - expected_courses, - "testing CourseOverview.get_all_courses with filter_={}".format(filter_), - )