diff --git a/lms/djangoapps/courseware/model_data.py b/lms/djangoapps/courseware/model_data.py index 8a3bcb7007..6f5c844fb7 100644 --- a/lms/djangoapps/courseware/model_data.py +++ b/lms/djangoapps/courseware/model_data.py @@ -42,6 +42,32 @@ def chunks(items, chunk_size): return (items[i:i + chunk_size] for i in xrange(0, len(items), chunk_size)) +def _query(model_class, select_for_update, **kwargs): + """ + Queries model_class with **kwargs, optionally adding select_for_update if + `select_for_update` is True. + """ + query = model_class.objects + if select_for_update: + query = query.select_for_update() + query = query.filter(**kwargs) + return query + +def _chunked_query(model_class, select_for_update, chunk_field, items, chunk_size=500, **kwargs): + """ + Queries model_class with `chunk_field` set to chunks of size `chunk_size`, + and all other parameters from `**kwargs`. + + This works around a limitation in sqlite3 on the number of parameters + that can be put into a single query. + """ + res = chain.from_iterable( + _query(model_class, select_for_update, **dict([(chunk_field, chunk)] + kwargs.items())) + for chunk in chunks(items, chunk_size) + ) + return res + + class FieldDataCache(object): """ A cache of django model objects needed to supply the data @@ -142,30 +168,6 @@ class FieldDataCache(object): cache.add_descriptor_descendents(descriptor, depth, descriptor_filter) return cache - def _query(self, model_class, **kwargs): - """ - Queries model_class with **kwargs, optionally adding select_for_update if - self.select_for_update is set - """ - query = model_class.objects - if self.select_for_update: - query = query.select_for_update() - query = query.filter(**kwargs) - return query - - def _chunked_query(self, model_class, chunk_field, items, chunk_size=500, **kwargs): - """ - Queries model_class with `chunk_field` set to chunks of size `chunk_size`, - and all other parameters from `**kwargs` - - This works around a limitation in sqlite3 on the number of parameters - that can be put into a single query - """ - res = chain.from_iterable( - self._query(model_class, **dict([(chunk_field, chunk)] + kwargs.items())) - for chunk in chunks(items, chunk_size) - ) - return res def _all_usage_ids(self, descriptors): """ @@ -199,31 +201,35 @@ class FieldDataCache(object): Queries the database for all of the fields in the specified scope """ if scope == Scope.user_state: - return self._chunked_query( + return _chunked_query( StudentModule, + self.select_for_update, 'module_state_key__in', self._all_usage_ids(descriptors), course_id=self.course_id, student=self.user.pk, ) elif scope == Scope.user_state_summary: - return self._chunked_query( + return _chunked_query( XModuleUserStateSummaryField, + self.select_for_update, 'usage_id__in', self._all_usage_ids(descriptors), field_name__in=set(field.name for field in fields), ) elif scope == Scope.preferences: - return self._chunked_query( + return _chunked_query( XModuleStudentPrefsField, + self.select_for_update, 'module_type__in', self._all_block_types(descriptors), student=self.user.pk, field_name__in=set(field.name for field in fields), ) elif scope == Scope.user_info: - return self._query( + return _query( XModuleStudentInfoField, + self.select_for_update, student=self.user.pk, field_name__in=set(field.name for field in fields), )