Merge pull request #9328 from edx/clintonb/compressed-text-field-fix

Resolved decompression errors during tests
This commit is contained in:
Clinton Blackburn
2015-08-18 13:58:07 -04:00

View File

@@ -4,12 +4,10 @@ import gzip
import logging
from django.db import models
from django.db.models.signals import post_init
from django.utils.text import compress_string
from config_models.models import ConfigurationModel
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -25,27 +23,31 @@ class RateLimitConfiguration(ConfigurationModel):
pass
def uncompress_string(s):
def decompress_string(value):
"""
Helper function to reverse CompressedTextField.get_prep_value.
"""
try:
val = s.encode('utf').decode('base64')
val = value.encode('utf').decode('base64')
zbuf = cStringIO.StringIO(val)
zfile = gzip.GzipFile(fileobj=zbuf)
ret = zfile.read()
zfile.close()
except Exception as e:
logger.error('String decompression failed. There may be corrupted data in the database: %s', e)
ret = s
ret = value
return ret
class CompressedTextField(models.TextField):
"""transparently compress data before hitting the db and uncompress after fetching"""
""" TextField that transparently compresses data when saving to the database, and decompresses the data
when retrieving it from the database. """
__metaclass__ = models.SubfieldBase
def get_prep_value(self, value):
""" Compress the text data. """
if value is not None:
if isinstance(value, unicode):
value = value.encode('utf8')
@@ -53,28 +55,12 @@ class CompressedTextField(models.TextField):
value = value.encode('base64').decode('utf8')
return value
def post_init(self, instance=None, **kwargs): # pylint: disable=unused-argument
value = self._get_val_from_obj(instance)
if value:
setattr(instance, self.attname, value)
def to_python(self, value):
""" Decompresses the value from the database. """
if isinstance(value, unicode):
value = decompress_string(value)
def contribute_to_class(self, cls, name):
super(CompressedTextField, self).contribute_to_class(cls, name)
post_init.connect(self.post_init, sender=cls)
def _get_val_from_obj(self, obj):
if obj:
value = uncompress_string(getattr(obj, self.attname))
if value is not None:
try:
value = value.decode('utf8')
except UnicodeDecodeError:
pass
return value
else:
return self.get_default()
else:
return self.get_default()
return value
def south_field_triple(self):
"""Returns a suitable description of this field for South."""