diff --git a/lms/djangoapps/django_comment_client/tests/utils.py b/lms/djangoapps/django_comment_client/tests/utils.py index 399c8c10e1..e9c0ca966b 100644 --- a/lms/djangoapps/django_comment_client/tests/utils.py +++ b/lms/djangoapps/django_comment_client/tests/utils.py @@ -34,19 +34,19 @@ class CohortedTestCase(SharedModuleStoreTestCase): def setUp(self): super(CohortedTestCase, self).setUp() - self.student_cohort = CohortFactory.create( - name="student_cohort", - course_id=self.course.id - ) - self.moderator_cohort = CohortFactory.create( - name="moderator_cohort", - course_id=self.course.id - ) seed_permissions_roles(self.course.id) self.student = UserFactory.create() self.moderator = UserFactory.create() CourseEnrollmentFactory(user=self.student, course_id=self.course.id) CourseEnrollmentFactory(user=self.moderator, course_id=self.course.id) self.moderator.roles.add(Role.objects.get(name="Moderator", course_id=self.course.id)) - self.student_cohort.users.add(self.student) - self.moderator_cohort.users.add(self.moderator) + self.student_cohort = CohortFactory.create( + name="student_cohort", + course_id=self.course.id, + users=[self.student] + ) + self.moderator_cohort = CohortFactory.create( + name="moderator_cohort", + course_id=self.course.id, + users=[self.moderator] + ) diff --git a/lms/djangoapps/instructor_analytics/tests/test_basic.py b/lms/djangoapps/instructor_analytics/tests/test_basic.py index daa3b45e7a..efd98560d0 100644 --- a/lms/djangoapps/instructor_analytics/tests/test_basic.py +++ b/lms/djangoapps/instructor_analytics/tests/test_basic.py @@ -133,8 +133,8 @@ class TestAnalyticsBasic(ModuleStoreTestCase): course = CourseFactory.create(org="test", course="course1", display_name="run1") course.cohort_config = {'cohorted': True, 'auto_cohort': True, 'auto_cohort_groups': ['cohort']} self.store.update_item(course, self.instructor.id) - cohort = CohortFactory.create(name='cohort', course_id=course.id) cohorted_students = [UserFactory.create() for _ in xrange(10)] + cohort = CohortFactory.create(name='cohort', course_id=course.id, users=cohorted_students) cohorted_usernames = [student.username for student in cohorted_students] non_cohorted_student = UserFactory.create() for student in cohorted_students: diff --git a/lms/djangoapps/instructor_task/tasks_helper.py b/lms/djangoapps/instructor_task/tasks_helper.py index f277f28e2c..8808db2a5a 100644 --- a/lms/djangoapps/instructor_task/tasks_helper.py +++ b/lms/djangoapps/instructor_task/tasks_helper.py @@ -1505,8 +1505,7 @@ def cohort_students_and_upload(_xmodule_instance_args, _entry_id, course_id, tas continue try: - with outer_atomic(): - add_user_to_cohort(cohorts_status[cohort_name]['cohort'], username_or_email) + add_user_to_cohort(cohorts_status[cohort_name]['cohort'], username_or_email) cohorts_status[cohort_name]['Students Added'] += 1 task_progress.succeeded += 1 except User.DoesNotExist: diff --git a/lms/djangoapps/mobile_api/video_outlines/tests.py b/lms/djangoapps/mobile_api/video_outlines/tests.py index 3e8f0a2f14..ef4d160653 100644 --- a/lms/djangoapps/mobile_api/video_outlines/tests.py +++ b/lms/djangoapps/mobile_api/video_outlines/tests.py @@ -17,6 +17,7 @@ from xmodule.partitions.partitions import Group, UserPartition from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup +from openedx.core.djangoapps.course_groups.cohorts import add_user_to_cohort, remove_user_from_cohort from ..testutils import MobileAPITestCase, MobileAuthTestMixin, MobileCourseAccessTestMixin @@ -744,7 +745,7 @@ class TestVideoSummaryList( for cohort_index in range(len(cohorts)): # add user to this cohort - cohorts[cohort_index].users.add(self.user) + add_user_to_cohort(cohorts[cohort_index], self.user.username) # should only see video for this cohort video_outline = self.api_response().data @@ -755,7 +756,7 @@ class TestVideoSummaryList( ) # remove user from this cohort - cohorts[cohort_index].users.remove(self.user) + remove_user_from_cohort(cohorts[cohort_index], self.user.username) # un-cohorted user should see no videos video_outline = self.api_response().data diff --git a/lms/djangoapps/notifier_api/tests.py b/lms/djangoapps/notifier_api/tests.py index 4d040d3564..4b173c77f5 100644 --- a/lms/djangoapps/notifier_api/tests.py +++ b/lms/djangoapps/notifier_api/tests.py @@ -5,7 +5,7 @@ from django.conf import settings from django.test.client import RequestFactory from django.test.utils import override_settings -from openedx.core.djangoapps.course_groups.models import CourseUserGroup +from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory from django_comment_common.models import Role, Permission from lang_pref import LANGUAGE_KEY from notification_prefs import NOTIFICATION_PREF_KEY @@ -46,12 +46,11 @@ class NotifierUsersViewSetTest(UrlResetMixin, ModuleStoreTestCase): self.courses.append(course) CourseEnrollmentFactory(user=self.user, course_id=course.id) if is_user_cohorted: - cohort = CourseUserGroup.objects.create( + cohort = CohortFactory.create( name="Test Cohort", course_id=course.id, - group_type=CourseUserGroup.COHORT + users=[self.user] ) - cohort.users.add(self.user) self.cohorts.append(cohort) if is_moderator: moderator_perm, _ = Permission.objects.get_or_create(name="see_all_cohorts") diff --git a/openedx/core/djangoapps/course_groups/cohorts.py b/openedx/core/djangoapps/course_groups/cohorts.py index 794c41b65f..cf4827cabb 100644 --- a/openedx/core/djangoapps/course_groups/cohorts.py +++ b/openedx/core/djangoapps/course_groups/cohorts.py @@ -181,13 +181,12 @@ def get_cohort(user, course_key, assign=True, use_cached=False): # If course is cohorted, check if the user already has a cohort. try: - cohort = CourseUserGroup.objects.get( + membership = CohortMembership.objects.get( course_id=course_key, - group_type=CourseUserGroup.COHORT, - users__id=user.id, + user_id=user.id, ) - return request_cache.data.setdefault(cache_key, cohort) - except CourseUserGroup.DoesNotExist: + return request_cache.data.setdefault(cache_key, membership.course_user_group) + except CohortMembership.DoesNotExist: # Didn't find the group. If we do not want to assign, return here. if not assign: # Do not cache the cohort here, because in the next call assign @@ -195,6 +194,17 @@ def get_cohort(user, course_key, assign=True, use_cached=False): return None # Otherwise assign the user a cohort. + membership = CohortMembership.objects.create( + user=user, + course_user_group=_get_default_cohort(course_key) + ) + return request_cache.data.setdefault(cache_key, membership.course_user_group) + + +def _get_default_cohort(course_key): + """ + Helper method to get a default cohort for assignment in get_cohort + """ course = courses.get_course(course_key) cohorts = get_course_cohorts(course, assignment_type=CourseCohort.RANDOM) if cohorts: @@ -205,11 +215,7 @@ def get_cohort(user, course_key, assign=True, use_cached=False): course_id=course_key, assignment_type=CourseCohort.RANDOM ).course_user_group - - membership = CohortMembership(course_user_group=cohort, user=user) - membership.save() - - return request_cache.data.setdefault(cache_key, cohort) + return cohort def migrate_cohort_settings(course): @@ -332,6 +338,27 @@ def is_cohort_exists(course_key, name): return CourseUserGroup.objects.filter(course_id=course_key, group_type=CourseUserGroup.COHORT, name=name).exists() +def remove_user_from_cohort(cohort, username_or_email): + """ + Look up the given user, and if successful, remove them from the specified cohort. + + Arguments: + cohort: CourseUserGroup + username_or_email: string. Treated as email if has '@' + + Raises: + User.DoesNotExist if can't find user. + ValueError if user not already present in this cohort. + """ + user = get_user_by_username_or_email(username_or_email) + + try: + membership = CohortMembership.objects.get(course_user_group=cohort, user=user) + membership.delete() + except CohortMembership.DoesNotExist: + raise ValueError("User {} was not present in cohort {}".format(username_or_email, cohort)) + + def add_user_to_cohort(cohort, username_or_email): """ Look up the given user, and if successful, add them to the specified cohort. @@ -350,7 +377,7 @@ def add_user_to_cohort(cohort, username_or_email): user = get_user_by_username_or_email(username_or_email) membership = CohortMembership(course_user_group=cohort, user=user) - membership.save() + membership.save() # This will handle both cases, creation and updating, of a CohortMembership for this user. tracker.emit( "edx.cohort.user_add_requested", diff --git a/openedx/core/djangoapps/course_groups/models.py b/openedx/core/djangoapps/course_groups/models.py index 3fa34f0f47..7d1c15f303 100644 --- a/openedx/core/djangoapps/course_groups/models.py +++ b/openedx/core/djangoapps/course_groups/models.py @@ -7,7 +7,10 @@ import logging from django.contrib.auth.models import User from django.db import models, transaction, IntegrityError +from util.db import outer_atomic from django.core.exceptions import ValidationError +from django.db.models.signals import pre_delete +from django.dispatch import receiver from xmodule_django.models import CourseKeyField log = logging.getLogger(__name__) @@ -85,55 +88,63 @@ class CohortMembership(models.Model): raise ValidationError("Non-matching course_ids provided") def save(self, *args, **kwargs): - # Avoid infinite recursion if creating from get_or_create() call below. - if 'force_insert' in kwargs and kwargs['force_insert'] is True: - super(CohortMembership, self).save(*args, **kwargs) - return - self.full_clean(validate_unique=False) - # This loop has been created to allow for optimistic locking, and retrial in case of losing a race condition. - # The limit is 2, since select_for_update ensures atomic updates. Creation is the only possible race condition. - max_retries = 2 - success = False - for __ in range(max_retries): - + # Avoid infinite recursion if creating from get_or_create() call below. + # This block also allows middleware to use CohortMembership.get_or_create without worrying about outer_atomic + if 'force_insert' in kwargs and kwargs['force_insert'] is True: with transaction.atomic(): - - try: - with transaction.atomic(): - saved_membership, created = CohortMembership.objects.select_for_update().get_or_create( - user__id=self.user.id, - course_id=self.course_id, - defaults={ - 'course_user_group': self.course_user_group, - 'user': self.user - } - ) - except IntegrityError: # This can happen if simultaneous requests try to create a membership - continue - - if not created: - if saved_membership.course_user_group == self.course_user_group: - raise ValueError("User {user_name} already present in cohort {cohort_name}".format( - user_name=self.user.username, - cohort_name=self.course_user_group.name - )) - self.previous_cohort = saved_membership.course_user_group - self.previous_cohort_name = saved_membership.course_user_group.name - self.previous_cohort_id = saved_membership.course_user_group.id - self.previous_cohort.users.remove(self.user) - - saved_membership.course_user_group = self.course_user_group self.course_user_group.users.add(self.user) + self.course_user_group.save() + super(CohortMembership, self).save(*args, **kwargs) + return - super(CohortMembership, saved_membership).save(update_fields=['course_user_group']) + # This block will transactionally commit updates to CohortMembership and underlying course_user_groups. + # Note the use of outer_atomic, which guarantees that operations are committed to the database on block exit. + # If called from a view method, that method must be marked with @transaction.non_atomic_requests. + with outer_atomic(read_committed=True): - success = True - break + saved_membership, created = CohortMembership.objects.select_for_update().get_or_create( + user__id=self.user.id, + course_id=self.course_id, + defaults={ + 'course_user_group': self.course_user_group, + 'user': self.user + } + ) - if not success: - raise IntegrityError("Unable to save membership after {} tries, aborting.".format(max_retries)) + # If the membership was newly created, all the validation and course_user_group logic was settled + # with a call to self.save(force_insert=True), which gets handled above. + if created: + return + + if saved_membership.course_user_group == self.course_user_group: + raise ValueError("User {user_name} already present in cohort {cohort_name}".format( + user_name=self.user.username, + cohort_name=self.course_user_group.name + )) + self.previous_cohort = saved_membership.course_user_group + self.previous_cohort_name = saved_membership.course_user_group.name + self.previous_cohort_id = saved_membership.course_user_group.id + self.previous_cohort.users.remove(self.user) + self.previous_cohort.save() + + saved_membership.course_user_group = self.course_user_group + self.course_user_group.users.add(self.user) + self.course_user_group.save() + + super(CohortMembership, saved_membership).save(update_fields=['course_user_group']) + + +# Needs to exist outside class definition in order to use 'sender=CohortMembership' +@receiver(pre_delete, sender=CohortMembership) +def remove_user_from_cohort(sender, instance, **kwargs): # pylint: disable=unused-argument + """ + Ensures that when a CohortMemebrship is deleted, the underlying CourseUserGroup + has its users list updated to reflect the change as well. + """ + instance.course_user_group.users.remove(instance.user) + instance.course_user_group.save() class CourseUserGroupPartitionGroup(models.Model): diff --git a/openedx/core/djangoapps/course_groups/tests/helpers.py b/openedx/core/djangoapps/course_groups/tests/helpers.py index d5c3cf38a1..f5d09dfad2 100644 --- a/openedx/core/djangoapps/course_groups/tests/helpers.py +++ b/openedx/core/djangoapps/course_groups/tests/helpers.py @@ -32,6 +32,11 @@ class CohortFactory(DjangoModelFactory): """ if extracted: self.users.add(*extracted) + for user in self.users.all(): + CohortMembership.objects.create( + user=user, + course_user_group=self, + ) class CourseCohortFactory(DjangoModelFactory): @@ -41,18 +46,6 @@ class CourseCohortFactory(DjangoModelFactory): class Meta(object): model = CourseCohort - @post_generation - def memberships(self, create, extracted, **kwargs): # pylint: disable=unused-argument - """ - Returns the memberships linking users to this cohort. - """ - for user in self.course_user_group.users.all(): # pylint: disable=E1101 - membership = CohortMembership(user=user, course_user_group=self.course_user_group) - membership.save() - - course_user_group = factory.SubFactory(CohortFactory) - assignment_type = 'manual' - class CourseCohortSettingsFactory(DjangoModelFactory): """ diff --git a/openedx/core/djangoapps/course_groups/tests/test_cohorts.py b/openedx/core/djangoapps/course_groups/tests/test_cohorts.py index d72c53a4c0..3771563702 100644 --- a/openedx/core/djangoapps/course_groups/tests/test_cohorts.py +++ b/openedx/core/djangoapps/course_groups/tests/test_cohorts.py @@ -179,8 +179,7 @@ class TestCohorts(ModuleStoreTestCase): self.assertIsNone(cohorts.get_cohort_id(user, course.id)) config_course_cohorts(course, is_cohorted=True) - cohort = CohortFactory(course_id=course.id, name="TestCohort") - cohort.users.add(user) + cohort = CohortFactory(course_id=course.id, name="TestCohort", users=[user]) self.assertEqual(cohorts.get_cohort_id(user, course.id), cohort.id) self.assertRaises( @@ -237,8 +236,7 @@ class TestCohorts(ModuleStoreTestCase): self.assertIsNone(cohorts.get_cohort(user, course.id), "No cohort created yet") - cohort = CohortFactory(course_id=course.id, name="TestCohort") - cohort.users.add(user) + cohort = CohortFactory(course_id=course.id, name="TestCohort", users=[user]) self.assertIsNone( cohorts.get_cohort(user, course.id), @@ -261,8 +259,8 @@ class TestCohorts(ModuleStoreTestCase): ) @ddt.data( - (True, 2), - (False, 6), + (True, 3), + (False, 9), ) @ddt.unpack def test_get_cohort_sql_queries(self, use_cached, num_sql_queries): @@ -271,10 +269,8 @@ class TestCohorts(ModuleStoreTestCase): """ course = modulestore().get_course(self.toy_course_key) config_course_cohorts(course, is_cohorted=True) - cohort = CohortFactory(course_id=course.id, name="TestCohort") - user = UserFactory(username="test", email="a@b.com") - cohort.users.add(user) + CohortFactory.create(course_id=course.id, name="TestCohort", users=[user]) with self.assertNumQueries(num_sql_queries): for __ in range(3): @@ -314,10 +310,7 @@ class TestCohorts(ModuleStoreTestCase): user1 = UserFactory(username="test", email="a@b.com") user2 = UserFactory(username="test2", email="a2@b.com") - cohort = CohortFactory(course_id=course.id, name="TestCohort") - - # user1 manually added to a cohort - cohort.users.add(user1) + cohort = CohortFactory(course_id=course.id, name="TestCohort", users=[user1]) # Add an auto_cohort_group to the course... config_course_cohorts( diff --git a/openedx/core/djangoapps/course_groups/tests/test_partition_scheme.py b/openedx/core/djangoapps/course_groups/tests/test_partition_scheme.py index 093eca653f..f53708e314 100644 --- a/openedx/core/djangoapps/course_groups/tests/test_partition_scheme.py +++ b/openedx/core/djangoapps/course_groups/tests/test_partition_scheme.py @@ -21,7 +21,7 @@ from openedx.core.djangoapps.user_api.partition_schemes import RandomUserPartiti from ..partition_scheme import CohortPartitionScheme, get_cohorted_user_partition from ..models import CourseUserGroupPartitionGroup from ..views import link_cohort_to_partition_group, unlink_cohort_partition_group -from ..cohorts import add_user_to_cohort, get_course_cohorts +from ..cohorts import add_user_to_cohort, remove_user_from_cohort, get_course_cohorts from .helpers import CohortFactory, config_course_cohorts @@ -100,7 +100,7 @@ class TestCohortPartitionScheme(ModuleStoreTestCase): self.assert_student_in_group(self.groups[1]) # move the student out of the cohort - second_cohort.users.remove(self.student) + remove_user_from_cohort(second_cohort, self.student.username) self.assert_student_in_group(None) def test_cohort_partition_group_assignment(self): diff --git a/openedx/core/djangoapps/course_groups/views.py b/openedx/core/djangoapps/course_groups/views.py index 373991b57e..3e8e0f5ae9 100644 --- a/openedx/core/djangoapps/course_groups/views.py +++ b/openedx/core/djangoapps/course_groups/views.py @@ -10,6 +10,7 @@ from django.core.urlresolvers import reverse from django.http import Http404, HttpResponseBadRequest from django.views.decorators.http import require_http_methods from util.json_request import expect_json, JsonResponse +from django.db import transaction from django.contrib.auth.decorators import login_required from django.utils.translation import ugettext @@ -23,7 +24,7 @@ from edxmako.shortcuts import render_to_response from . import cohorts from lms.djangoapps.django_comment_client.utils import get_discussion_category_map, get_discussion_categories_ids -from .models import CourseUserGroup, CourseUserGroupPartitionGroup +from .models import CourseUserGroup, CourseUserGroupPartitionGroup, CohortMembership log = logging.getLogger(__name__) @@ -299,6 +300,7 @@ def users_in_cohort(request, course_key_string, cohort_id): 'users': user_info}) +@transaction.non_atomic_requests @ensure_csrf_cookie @require_POST def add_users_to_cohort(request, course_key_string, cohort_id): @@ -384,16 +386,22 @@ def remove_user_from_cohort(request, course_key_string, cohort_id): return json_http_response({'success': False, 'msg': 'No username specified'}) - cohort = cohorts.get_cohort_by_id(course_key, cohort_id) try: user = User.objects.get(username=username) - cohort.users.remove(user) - return json_http_response({'success': True}) except User.DoesNotExist: log.debug('no user') return json_http_response({'success': False, 'msg': "No user '{0}'".format(username)}) + try: + membership = CohortMembership.objects.get(user=user, course_id=course_key) + membership.delete() + + except CohortMembership.DoesNotExist: + pass + + return json_http_response({'success': True}) + def debug_cohort_mgmt(request, course_key_string): """