diff --git a/common/djangoapps/util/models.py b/common/djangoapps/util/models.py index 1c1ade043e..ff353a8039 100644 --- a/common/djangoapps/util/models.py +++ b/common/djangoapps/util/models.py @@ -4,12 +4,10 @@ import gzip import logging from django.db import models -from django.db.models.signals import post_init from django.utils.text import compress_string from config_models.models import ConfigurationModel - logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -25,27 +23,31 @@ class RateLimitConfiguration(ConfigurationModel): pass -def uncompress_string(s): +def decompress_string(value): """ Helper function to reverse CompressedTextField.get_prep_value. """ try: - val = s.encode('utf').decode('base64') + val = value.encode('utf').decode('base64') zbuf = cStringIO.StringIO(val) zfile = gzip.GzipFile(fileobj=zbuf) ret = zfile.read() zfile.close() except Exception as e: logger.error('String decompression failed. There may be corrupted data in the database: %s', e) - ret = s + ret = value return ret class CompressedTextField(models.TextField): - """transparently compress data before hitting the db and uncompress after fetching""" + """ TextField that transparently compresses data when saving to the database, and decompresses the data + when retrieving it from the database. """ + + __metaclass__ = models.SubfieldBase def get_prep_value(self, value): + """ Compress the text data. """ if value is not None: if isinstance(value, unicode): value = value.encode('utf8') @@ -53,28 +55,12 @@ class CompressedTextField(models.TextField): value = value.encode('base64').decode('utf8') return value - def post_init(self, instance=None, **kwargs): # pylint: disable=unused-argument - value = self._get_val_from_obj(instance) - if value: - setattr(instance, self.attname, value) + def to_python(self, value): + """ Decompresses the value from the database. """ + if isinstance(value, unicode): + value = decompress_string(value) - def contribute_to_class(self, cls, name): - super(CompressedTextField, self).contribute_to_class(cls, name) - post_init.connect(self.post_init, sender=cls) - - def _get_val_from_obj(self, obj): - if obj: - value = uncompress_string(getattr(obj, self.attname)) - if value is not None: - try: - value = value.decode('utf8') - except UnicodeDecodeError: - pass - return value - else: - return self.get_default() - else: - return self.get_default() + return value def south_field_triple(self): """Returns a suitable description of this field for South."""