diff --git a/cms/djangoapps/contentstore/tasks.py b/cms/djangoapps/contentstore/tasks.py index bbc260d968..55341cb22d 100644 --- a/cms/djangoapps/contentstore/tasks.py +++ b/cms/djangoapps/contentstore/tasks.py @@ -13,6 +13,7 @@ from tempfile import NamedTemporaryFile, mkdtemp from celery.task import task from celery.utils.log import get_task_logger +from organizations.models import OrganizationCourse from path import Path as path from pytz import UTC from six import iteritems, text_type @@ -37,6 +38,7 @@ from course_action_state.models import CourseRerunState from models.settings.course_metadata import CourseMetadata from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.locator import LibraryLocator +from openedx.core.djangoapps.embargo.models import RestrictedCourse, CountryAccessRule from openedx.core.lib.extract_tar import safetar_extractall from student.auth import has_course_author_access from xmodule.contentstore.django import contentstore @@ -54,6 +56,28 @@ FILE_READ_CHUNK = 1024 # bytes FULL_COURSE_REINDEX_THRESHOLD = 1 +def clone_instance(instance, field_values): + """ Clones a Django model instance. + + The specified fields are replaced with new values. + + Arguments: + instance (Model): Instance of a Django model. + field_values (dict): Map of field names to new values. + + Returns: + Model: New instance. + """ + instance.pk = None + + for field, value in iteritems(field_values): + setattr(instance, field, value) + + instance.save() + + return instance + + @task() def rerun_course(source_course_key_string, destination_course_key_string, user_id, fields=None): """ @@ -83,6 +107,21 @@ def rerun_course(source_course_key_string, destination_course_key_string, user_i # call edxval to attach videos to the rerun copy_course_videos(source_course_key, destination_course_key) + # Copy OrganizationCourse + organization_course = OrganizationCourse.objects.filter(course_id=source_course_key_string).first() + + if organization_course: + clone_instance(organization_course, {'course_id': destination_course_key_string}) + + # Copy RestrictedCourse + restricted_course = RestrictedCourse.objects.filter(course_key=source_course_key).first() + + if restricted_course: + country_access_rules = CountryAccessRule.objects.filter(restricted_course=restricted_course) + new_restricted_course = clone_instance(restricted_course, {'course_key': destination_course_key}) + for country_access_rule in country_access_rules: + clone_instance(country_access_rule, {'restricted_course': new_restricted_course}) + return "succeeded" except DuplicateCourseError: diff --git a/cms/djangoapps/contentstore/tests/test_tasks.py b/cms/djangoapps/contentstore/tests/test_tasks.py index 86eab6efb5..f2c7cd7056 100644 --- a/cms/djangoapps/contentstore/tests/test_tasks.py +++ b/cms/djangoapps/contentstore/tests/test_tasks.py @@ -5,18 +5,23 @@ from __future__ import absolute_import, division, print_function import copy import json -import mock from uuid import uuid4 +import mock from django.conf import settings from django.contrib.auth.models import User from django.test.utils import override_settings - +from opaque_keys.edx.locator import CourseLocator +from organizations.models import OrganizationCourse +from organizations.tests.factories import OrganizationFactory from user_tasks.models import UserTaskArtifact, UserTaskStatus +from xmodule.modulestore.django import modulestore -from contentstore.tasks import export_olx +from contentstore.tasks import export_olx, rerun_course from contentstore.tests.test_libraries import LibraryTestCase from contentstore.tests.utils import CourseTestCase +from course_action_state.models import CourseRerunState +from openedx.core.djangoapps.embargo.models import RestrictedCourse, CountryAccessRule, Country TEST_DATA_CONTENTSTORE = copy.deepcopy(settings.CONTENTSTORE) TEST_DATA_CONTENTSTORE['DOC_STORE_CONFIG']['db'] = 'test_xcontent_%s' % uuid4().hex @@ -99,10 +104,60 @@ class ExportLibraryTestCase(LibraryTestCase): Verify that a routine library export task succeeds """ key = str(self.lib_key) - result = export_olx.delay(self.user.id, key, u'en') # pylint: disable=no-member + result = export_olx.delay(self.user.id, key, u'en') # pylint: disable=no-member status = UserTaskStatus.objects.get(task_id=result.id) self.assertEqual(status.state, UserTaskStatus.SUCCEEDED) artifacts = UserTaskArtifact.objects.filter(status=status) self.assertEqual(len(artifacts), 1) output = artifacts[0] self.assertEqual(output.name, 'Output') + + +@override_settings(CONTENTSTORE=TEST_DATA_CONTENTSTORE) +class RerunCourseTaskTestCase(CourseTestCase): + def _rerun_course(self, old_course_key, new_course_key): + CourseRerunState.objects.initiated(old_course_key, new_course_key, self.user, 'Test Re-run') + rerun_course(str(old_course_key), str(new_course_key), self.user.id) + + def test_success(self): + """ The task should clone the OrganizationCourse and RestrictedCourse data. """ + old_course_key = self.course.id + new_course_key = CourseLocator(org=old_course_key.org, course=old_course_key.course, run='rerun') + + old_course_id = str(old_course_key) + new_course_id = str(new_course_key) + + organization = OrganizationFactory() + OrganizationCourse.objects.create(course_id=old_course_id, organization=organization) + + restricted_course = RestrictedCourse.objects.create(course_key=self.course.id) + restricted_country = Country.objects.create(country='US') + + CountryAccessRule.objects.create( + rule_type=CountryAccessRule.BLACKLIST_RULE, + restricted_course=restricted_course, + country=restricted_country + ) + + # Run the task! + self._rerun_course(old_course_key, new_course_key) + + # Verify the new course run exists + course = modulestore().get_course(new_course_key) + self.assertIsNotNone(course) + + # Verify the OrganizationCourse is cloned + self.assertEqual(OrganizationCourse.objects.count(), 2) + # This will raise an error if the OrganizationCourse object was not cloned + OrganizationCourse.objects.get(course_id=new_course_id, organization=organization) + + # Verify the RestrictedCourse and related objects are cloned + self.assertEqual(RestrictedCourse.objects.count(), 2) + restricted_course = RestrictedCourse.objects.get(course_key=new_course_key) + + self.assertEqual(CountryAccessRule.objects.count(), 2) + CountryAccessRule.objects.get( + rule_type=CountryAccessRule.BLACKLIST_RULE, + restricted_course=restricted_course, + country=restricted_country + )