diff --git a/common/djangoapps/student/models.py b/common/djangoapps/student/models.py index 49d3381303..50175b2ac8 100644 --- a/common/djangoapps/student/models.py +++ b/common/djangoapps/student/models.py @@ -10,14 +10,21 @@ file and check it in at the same time as your model changes. To do that, """ from datetime import datetime import json +import logging import uuid -from django.db import models +from django.conf import settings from django.contrib.auth.models import User +from django.db import models +from django.db.models.signals import post_delete, post_save +from django.dispatch import receiver from django_countries import CountryField +from xmodule.modulestore.django import modulestore + #from cache_toolbox import cache_model, cache_relation +log = logging.getLogger(__name__) class UserProfile(models.Model): class Meta: @@ -203,3 +210,59 @@ def add_user_to_default_group(user, group): utg.save() utg.users.add(User.objects.get(username=user)) utg.save() + + + + + +################################# SIGNALS ###################################### + +def is_valid_course_id(course_id): + """We check to both make sure that it's a valid course_id (and not + 'default', or some other non-course DB name) and that we have a mapping + for what database it belongs to.""" + course_ids = set(course.id for course in modulestore().get_courses()) + return (course_id in course_ids) and (course_id in settings.DATABASES) + +def is_portal(): + """Are we in the portal pool? (in which case we'll have to replicate user + updates). Right now, that means we have more than one database defined.""" + return len(settings.DATABASES) > 1 + +def replicate_enrollment(instance_method, **kwargs): + log.debug("########## Enrollment replication called ############") + instance = kwargs['instance'] + + if not is_portal(): + log.debug("replicate_enrollment triggered, but we're not a portal so " + + "we're not propogating") + return + + if not is_valid_course_id(instance.course_id): + log.error("Don't know where to replicate to for course_id: {0}" + .format(instance.course_id)) + return + + # We create a _replicated attribute to differentiate the first save of this + # model vs. the duplicate save we force on to the course database. + if hasattr(instance, '_replicated'): + log.debug("We've already replicated this -- stopping so we don't go " + + "into an infinite loop.") + return + instance._replicated = True + + # instance_method is either CourseEnrollment.save or CourseEnrollment.delete + # using is the entry in DATABASES we push to (we use course_ids for names) + instance_method(instance, using=instance.course_id) + +@receiver(post_save, sender=CourseEnrollment) +def replicate_enrollment_save(sender, **kwargs): + return replicate_enrollment(CourseEnrollment.save, **kwargs) + +@receiver(post_delete, sender=CourseEnrollment) +def replicate_enrollment_delete(sender, **kwargs): + return replicate_enrollment(CourseEnrollment.delete, **kwargs) + + + +