Cache cohort info during requests to reduce SQL queries.
TNL-1258
This commit is contained in:
@@ -110,12 +110,12 @@ def is_course_cohorted(course_key):
|
||||
return get_course_cohort_settings(course_key).is_cohorted
|
||||
|
||||
|
||||
def get_cohort_id(user, course_key):
|
||||
def get_cohort_id(user, course_key, use_cached=False):
|
||||
"""
|
||||
Given a course key and a user, return the id of the cohort that user is
|
||||
assigned to in that course. If they don't have a cohort, return None.
|
||||
"""
|
||||
cohort = get_cohort(user, course_key)
|
||||
cohort = get_cohort(user, course_key, use_cached=use_cached)
|
||||
return None if cohort is None else cohort.id
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ def get_cohorted_commentables(course_key):
|
||||
|
||||
|
||||
@transaction.commit_on_success
|
||||
def get_cohort(user, course_key, assign=True):
|
||||
def get_cohort(user, course_key, assign=True, use_cached=False):
|
||||
"""
|
||||
Given a Django user and a CourseKey, return the user's cohort in that
|
||||
cohort.
|
||||
@@ -181,6 +181,7 @@ def get_cohort(user, course_key, assign=True):
|
||||
user: a Django User object.
|
||||
course_key: CourseKey
|
||||
assign (bool): if False then we don't assign a group to user
|
||||
use_cached (bool): Whether to use the cached value or fetch from database.
|
||||
|
||||
Returns:
|
||||
A CourseUserGroup object if the course is cohorted and the User has a
|
||||
@@ -189,23 +190,33 @@ def get_cohort(user, course_key, assign=True):
|
||||
Raises:
|
||||
ValueError if the CourseKey doesn't exist.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
# We cache the cohort on the user object so that we do not have to repeatedly
|
||||
# query the database during a request. If the cached value exists, just return it.
|
||||
if use_cached and hasattr(user, '_cohort'):
|
||||
return user._cohort
|
||||
|
||||
# First check whether the course is cohorted (users shouldn't be in a cohort
|
||||
# in non-cohorted courses, but settings can change after course starts)
|
||||
course = courses.get_course(course_key)
|
||||
course_cohort_settings = get_course_cohort_settings(course.id)
|
||||
|
||||
if not course_cohort_settings.is_cohorted:
|
||||
return None
|
||||
user._cohort = None
|
||||
return user._cohort
|
||||
|
||||
try:
|
||||
return CourseUserGroup.objects.get(
|
||||
user._cohort = CourseUserGroup.objects.get(
|
||||
course_id=course_key,
|
||||
group_type=CourseUserGroup.COHORT,
|
||||
users__id=user.id,
|
||||
)
|
||||
return user._cohort
|
||||
except CourseUserGroup.DoesNotExist:
|
||||
# Didn't find the group. We'll go on to create one if needed.
|
||||
if not assign:
|
||||
# Do not cache the cohort here, because in the next call assign
|
||||
# may be True, and we will have to assign the user a cohort.
|
||||
return None
|
||||
|
||||
cohorts = get_course_cohorts(course, assignment_type=CourseCohort.RANDOM)
|
||||
@@ -220,7 +231,8 @@ def get_cohort(user, course_key, assign=True):
|
||||
|
||||
user.course_groups.add(cohort)
|
||||
|
||||
return cohort
|
||||
user._cohort = cohort
|
||||
return user._cohort
|
||||
|
||||
|
||||
def migrate_cohort_settings(course):
|
||||
@@ -387,7 +399,7 @@ def add_user_to_cohort(cohort, username_or_email):
|
||||
return (user, previous_cohort_name)
|
||||
|
||||
|
||||
def get_group_info_for_cohort(cohort):
|
||||
def get_group_info_for_cohort(cohort, use_cached=False):
|
||||
"""
|
||||
Get the ids of the group and partition to which this cohort has been linked
|
||||
as a tuple of (int, int).
|
||||
@@ -395,9 +407,21 @@ def get_group_info_for_cohort(cohort):
|
||||
If the cohort has not been linked to any group/partition, both values in the
|
||||
tuple will be None.
|
||||
"""
|
||||
res = CourseUserGroupPartitionGroup.objects.filter(course_user_group=cohort)
|
||||
if len(res):
|
||||
return res[0].group_id, res[0].partition_id
|
||||
# pylint: disable=protected-access
|
||||
# We cache the partition group on the cohort object so that we do not have to repeatedly
|
||||
# query the database during a request.
|
||||
if not use_cached and hasattr(cohort, '_partition_group'):
|
||||
delattr(cohort, '_partition_group')
|
||||
|
||||
if not hasattr(cohort, '_partition_group'):
|
||||
try:
|
||||
cohort._partition_group = CourseUserGroupPartitionGroup.objects.get(course_user_group=cohort)
|
||||
except CourseUserGroupPartitionGroup.DoesNotExist:
|
||||
cohort._partition_group = None
|
||||
|
||||
if cohort._partition_group:
|
||||
return cohort._partition_group.group_id, cohort._partition_group.partition_id
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class CohortPartitionScheme(object):
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@classmethod
|
||||
def get_group_for_user(cls, course_key, user, user_partition, track_function=None):
|
||||
def get_group_for_user(cls, course_key, user, user_partition, track_function=None, use_cached=True):
|
||||
"""
|
||||
Returns the Group from the specified user partition to which the user
|
||||
is assigned, via their cohort membership and any mappings from cohorts
|
||||
@@ -48,12 +48,12 @@ class CohortPartitionScheme(object):
|
||||
return None
|
||||
return None
|
||||
|
||||
cohort = get_cohort(user, course_key)
|
||||
cohort = get_cohort(user, course_key, use_cached=use_cached)
|
||||
if cohort is None:
|
||||
# student doesn't have a cohort
|
||||
return None
|
||||
|
||||
group_id, partition_id = get_group_info_for_cohort(cohort)
|
||||
group_id, partition_id = get_group_info_for_cohort(cohort, use_cached=use_cached)
|
||||
if partition_id is None:
|
||||
# cohort isn't mapped to any partition group.
|
||||
return None
|
||||
|
||||
@@ -10,8 +10,7 @@ from opaque_keys.edx.locations import SlashSeparatedCourseKey
|
||||
from xmodule.modulestore.django import modulestore
|
||||
from xmodule.modulestore import ModuleStoreEnum
|
||||
|
||||
import json
|
||||
from ..cohorts import get_course_cohort_settings, set_course_cohort_settings
|
||||
from ..cohorts import set_course_cohort_settings
|
||||
from ..models import CourseUserGroup, CourseCohort, CourseCohortsSettings
|
||||
|
||||
|
||||
@@ -126,6 +125,7 @@ def config_course_cohorts_legacy(
|
||||
pass
|
||||
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
def config_course_cohorts(
|
||||
course,
|
||||
is_cohorted,
|
||||
@@ -154,13 +154,14 @@ def config_course_cohorts(
|
||||
Nothing -- modifies course in place.
|
||||
"""
|
||||
def to_id(name):
|
||||
"""Convert name to id."""
|
||||
return topic_name_to_id(course, name)
|
||||
|
||||
set_course_cohort_settings(
|
||||
course.id,
|
||||
is_cohorted = is_cohorted,
|
||||
cohorted_discussions = [to_id(name) for name in cohorted_discussions],
|
||||
always_cohort_inline_discussions = always_cohort_inline_discussions
|
||||
is_cohorted=is_cohorted,
|
||||
cohorted_discussions=[to_id(name) for name in cohorted_discussions],
|
||||
always_cohort_inline_discussions=always_cohort_inline_discussions
|
||||
)
|
||||
|
||||
for cohort_name in auto_cohorts:
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
Tests for cohorts
|
||||
"""
|
||||
# pylint: disable=no-member
|
||||
import ddt
|
||||
from mock import call, patch
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from django.db import IntegrityError
|
||||
from django.http import Http404
|
||||
from django.test import TestCase
|
||||
from django.test.utils import override_settings
|
||||
from mock import call, patch
|
||||
|
||||
from opaque_keys.edx.locations import SlashSeparatedCourseKey
|
||||
from student.models import CourseEnrollment
|
||||
@@ -121,6 +122,7 @@ class TestCohortSignals(TestCase):
|
||||
self.assertFalse(mock_tracker.emit.called)
|
||||
|
||||
|
||||
@ddt.ddt
|
||||
class TestCohorts(ModuleStoreTestCase):
|
||||
"""
|
||||
Test the cohorts feature
|
||||
@@ -243,12 +245,33 @@ class TestCohorts(ModuleStoreTestCase):
|
||||
cohort.id,
|
||||
"user should be assigned to the correct cohort"
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
cohorts.get_cohort(other_user, course.id).id,
|
||||
cohorts.get_cohort_by_name(course.id, cohorts.DEFAULT_COHORT_NAME).id,
|
||||
"other_user should be assigned to the default cohort"
|
||||
)
|
||||
|
||||
@ddt.data(
|
||||
(True, 2),
|
||||
(False, 6),
|
||||
)
|
||||
@ddt.unpack
|
||||
def test_get_cohort_sql_queries(self, use_cached, num_sql_queries):
|
||||
"""
|
||||
Test number of queries by cohorts.get_cohort() with and without caching.
|
||||
"""
|
||||
course = modulestore().get_course(self.toy_course_key)
|
||||
config_course_cohorts(course, is_cohorted=True)
|
||||
cohort = CohortFactory(course_id=course.id, name="TestCohort")
|
||||
|
||||
user = UserFactory(username="test", email="a@b.com")
|
||||
cohort.users.add(user)
|
||||
|
||||
with self.assertNumQueries(num_sql_queries):
|
||||
for __ in range(3):
|
||||
cohorts.get_cohort(user, course.id, use_cached=use_cached)
|
||||
|
||||
def test_get_cohort_with_assign(self):
|
||||
"""
|
||||
Make sure cohorts.get_cohort() returns None if no group is already
|
||||
@@ -473,7 +496,7 @@ class TestCohorts(ModuleStoreTestCase):
|
||||
config_course_cohorts(
|
||||
course,
|
||||
is_cohorted=True,
|
||||
discussion_topics= ["General", "Feedback"],
|
||||
discussion_topics=["General", "Feedback"],
|
||||
cohorted_discussions=["Feedback"]
|
||||
)
|
||||
|
||||
@@ -497,7 +520,7 @@ class TestCohorts(ModuleStoreTestCase):
|
||||
config_course_cohorts(
|
||||
course,
|
||||
is_cohorted=True,
|
||||
discussion_topics =["General", "Feedback"],
|
||||
discussion_topics=["General", "Feedback"],
|
||||
cohorted_discussions=["Feedback", "random_inline"]
|
||||
)
|
||||
self.assertTrue(
|
||||
@@ -741,6 +764,7 @@ class TestCohorts(ModuleStoreTestCase):
|
||||
)
|
||||
|
||||
|
||||
@ddt.ddt
|
||||
class TestCohortsAndPartitionGroups(ModuleStoreTestCase):
|
||||
"""
|
||||
Test Cohorts and Partitions Groups.
|
||||
@@ -803,6 +827,25 @@ class TestCohortsAndPartitionGroups(ModuleStoreTestCase):
|
||||
(None, None),
|
||||
)
|
||||
|
||||
@ddt.data(
|
||||
(True, 1),
|
||||
(False, 3),
|
||||
)
|
||||
@ddt.unpack
|
||||
def test_get_group_info_for_cohort_queries(self, use_cached, num_sql_queries):
|
||||
"""
|
||||
Basic test of the partition_group_info accessor function
|
||||
"""
|
||||
# create a link for the cohort in the db
|
||||
self._link_cohort_partition_group(
|
||||
self.first_cohort,
|
||||
self.partition_id,
|
||||
self.group1_id
|
||||
)
|
||||
with self.assertNumQueries(num_sql_queries):
|
||||
for __ in range(3):
|
||||
self.assertIsNotNone(cohorts.get_group_info_for_cohort(self.first_cohort, use_cached=use_cached))
|
||||
|
||||
def test_multiple_cohorts(self):
|
||||
"""
|
||||
Test that multiple cohorts can be linked to the same partition group
|
||||
|
||||
@@ -63,6 +63,7 @@ class TestCohortPartitionScheme(ModuleStoreTestCase):
|
||||
self.course_key,
|
||||
self.student,
|
||||
partition or self.user_partition,
|
||||
use_cached=False
|
||||
),
|
||||
group
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user