diff --git a/common/djangoapps/student/models.py b/common/djangoapps/student/models.py index 0fbe70c0b3..6d1cbb5afb 100644 --- a/common/djangoapps/student/models.py +++ b/common/djangoapps/student/models.py @@ -257,8 +257,11 @@ def add_user_to_default_group(user, group): ########################## REPLICATION SIGNALS ################################# @receiver(post_save, sender=User) def replicate_user_save(sender, **kwargs): - user_obj = kwargs['instance'] - return replicate_model(User.save, user_obj, user_obj.id) + user_obj = kwargs['instance'] + if not should_replicate(user_obj): + return + for course_db_name in db_names_to_replicate_to(user_obj.id): + replicate_user(user_obj, course_db_name) @receiver(post_save, sender=CourseEnrollment) def replicate_enrollment_save(sender, **kwargs): @@ -287,8 +290,8 @@ def replicate_enrollment_save(sender, **kwargs): @receiver(post_delete, sender=CourseEnrollment) def replicate_enrollment_delete(sender, **kwargs): - enrollment_obj = kwargs['instance'] - return replicate_model(CourseEnrollment.delete, enrollment_obj, enrollment_obj.user_id) + enrollment_obj = kwargs['instance'] + return replicate_model(CourseEnrollment.delete, enrollment_obj, enrollment_obj.user_id) @receiver(post_save, sender=UserProfile) def replicate_userprofile_save(sender, **kwargs): @@ -311,23 +314,20 @@ def replicate_user(portal_user, course_db_name): overridden. """ try: - # If the user exists in the Course DB, update the appropriate fields and - # save it back out to the Course DB. course_user = User.objects.using(course_db_name).get(id=portal_user.id) - for field in USER_FIELDS_TO_COPY: - setattr(course_user, field, getattr(portal_user, field)) - - mark_handled(course_user) log.debug("User {0} found in Course DB, replicating fields to {1}" .format(course_user, course_db_name)) - course_user.save(using=course_db_name) # Just being explicit. - except User.DoesNotExist: - # Otherwise, just make a straight copy to the Course DB. - mark_handled(portal_user) log.debug("User {0} not found in Course DB, creating copy in {1}" .format(portal_user, course_db_name)) - portal_user.save(using=course_db_name) + course_user = User() + + for field in USER_FIELDS_TO_COPY: + setattr(course_user, field, getattr(portal_user, field)) + + mark_handled(course_user) + course_user.save(using=course_db_name) + unmark(course_user) def replicate_model(model_method, instance, user_id): """ @@ -337,13 +337,14 @@ def replicate_model(model_method, instance, user_id): if not should_replicate(instance): return - mark_handled(instance) course_db_names = db_names_to_replicate_to(user_id) log.debug("Replicating {0} for user {1} to DBs: {2}" .format(model_method, user_id, course_db_names)) + mark_handled(instance) for db_name in course_db_names: model_method(instance, using=db_name) + unmark(instance) ######### Replication Helpers ######### @@ -371,7 +372,7 @@ def db_names_to_replicate_to(user_id): def marked_handled(instance): """Have we marked this instance as being handled to avoid infinite loops caused by saving models in post_save hooks for the same models?""" - return hasattr(instance, '_do_not_copy_to_course_db') + return hasattr(instance, '_do_not_copy_to_course_db') and instance._do_not_copy_to_course_db def mark_handled(instance): """You have to mark your instance with this function or else we'll go into @@ -384,6 +385,11 @@ def mark_handled(instance): """ instance._do_not_copy_to_course_db = True +def unmark(instance): + """If we don't unmark a model after we do replication, then consecutive + save() calls won't be properly replicated.""" + instance._do_not_copy_to_course_db = False + def should_replicate(instance): """Should this instance be replicated? We need to be a Portal server and the instance has to not have been marked_handled.""" @@ -398,9 +404,3 @@ def should_replicate(instance): return False return True - - - - - - diff --git a/common/djangoapps/student/tests.py b/common/djangoapps/student/tests.py index ad7ddb70d1..b33678fbac 100644 --- a/common/djangoapps/student/tests.py +++ b/common/djangoapps/student/tests.py @@ -4,6 +4,7 @@ when you run "manage.py test". Replace this with more appropriate tests for your application. """ +import logging from datetime import datetime from django.test import TestCase @@ -13,6 +14,8 @@ from .models import User, UserProfile, CourseEnrollment, replicate_user, USER_FI COURSE_1 = 'edX/toy/2012_Fall' COURSE_2 = 'edx/full/6.002_Spring_2012' +log = logging.getLogger(__name__) + class ReplicationTest(TestCase): multi_db = True @@ -47,23 +50,18 @@ class ReplicationTest(TestCase): field, portal_user, course_user )) - if hasattr(portal_user, 'seen_response_count'): - # Since it's the first copy over of User data, we should have all of it - self.assertEqual(portal_user.seen_response_count, - course_user.seen_response_count) - - # But if we replicate again, the user already exists in the Course DB, - # so it shouldn't update the seen_response_count (which is Askbot - # controlled). # This hasattr lameness is here because we don't want this test to be # triggered when we're being run by CMS tests (Askbot doesn't exist # there, so the test will fail). + # + # seen_response_count isn't a field we care about, so it shouldn't have + # been copied over. if hasattr(portal_user, 'seen_response_count'): portal_user.seen_response_count = 20 replicate_user(portal_user, COURSE_1) course_user = User.objects.using(COURSE_1).get(id=portal_user.id) self.assertEqual(portal_user.seen_response_count, 20) - self.assertEqual(course_user.seen_response_count, 10) + self.assertEqual(course_user.seen_response_count, 0) # Another replication should work for an email change however, since # it's a field we care about. @@ -123,6 +121,25 @@ class ReplicationTest(TestCase): UserProfile.objects.using(COURSE_2).get, id=portal_user_profile.id) + log.debug("Make sure our seen_response_count is not replicated.") + if hasattr(portal_user, 'seen_response_count'): + portal_user.seen_response_count = 200 + course_user = User.objects.using(COURSE_1).get(id=portal_user.id) + self.assertEqual(portal_user.seen_response_count, 200) + self.assertEqual(course_user.seen_response_count, 0) + portal_user.save() + + course_user = User.objects.using(COURSE_1).get(id=portal_user.id) + self.assertEqual(portal_user.seen_response_count, 200) + self.assertEqual(course_user.seen_response_count, 0) + + portal_user.email = 'jim@edx.org' + portal_user.save() + course_user = User.objects.using(COURSE_1).get(id=portal_user.id) + self.assertEqual(portal_user.email, 'jim@edx.org') + self.assertEqual(course_user.email, 'jim@edx.org') + + def test_enrollment_for_user_info_after_enrollment(self): """Test the effect of modifying User data after you've enrolled."""