diff --git a/common/djangoapps/course_groups/models.py b/common/djangoapps/course_groups/models.py index 36c9927e18..46859a900e 100644 --- a/common/djangoapps/course_groups/models.py +++ b/common/djangoapps/course_groups/models.py @@ -40,15 +40,28 @@ def get_cohort(user, course_id): Returns: A CourseUserGroup object if the User has a cohort, or None. """ - group_type = CourseUserGroup.COHORT try: - group = CourseUserGroup.objects.get(course_id=course_id, group_type=group_type, - users__id=user.id) + group = CourseUserGroup.objects.get(course_id=course_id, + group_type=CourseUserGroup.COHORT, + users__id=user.id) except CourseUserGroup.DoesNotExist: group = None - + if group: return group # TODO: add auto-cohorting logic here return None + +def get_course_cohorts(course_id): + """ + Get a list of all the cohorts in the given course. + + Arguments: + course_id: string in the format 'org/course/run' + + Returns: + A list of CourseUserGroup objects. Empty if there are no cohorts. + """ + return list(CourseUserGroup.objects.filter(course_id=course_id, + group_type=CourseUserGroup.COHORT)) diff --git a/common/djangoapps/course_groups/tests/tests.py b/common/djangoapps/course_groups/tests/tests.py index 89c77c5b65..676643567d 100644 --- a/common/djangoapps/course_groups/tests/tests.py +++ b/common/djangoapps/course_groups/tests/tests.py @@ -1,21 +1,44 @@ +import django.test from django.contrib.auth.models import User -from nose.tools import assert_equals -from course_groups.models import CourseUserGroup, get_cohort +from course_groups.models import CourseUserGroup, get_cohort, get_course_cohorts -def test_get_cohort(): - course_id = "a/b/c" - cohort = CourseUserGroup.objects.create(name="TestCohort", course_id=course_id, - group_type=CourseUserGroup.COHORT) +class TestCohorts(django.test.TestCase): - user = User.objects.create(username="test", email="a@b.com") - other_user = User.objects.create(username="test2", email="a2@b.com") + def test_get_cohort(self): + course_id = "a/b/c" + cohort = CourseUserGroup.objects.create(name="TestCohort", course_id=course_id, + group_type=CourseUserGroup.COHORT) - cohort.users.add(user) + user = User.objects.create(username="test", email="a@b.com") + other_user = User.objects.create(username="test2", email="a2@b.com") - got = get_cohort(user, course_id) - assert_equals(got.id, cohort.id, "Should find the right cohort") + cohort.users.add(user) - got = get_cohort(other_user, course_id) - assert_equals(got, None, "other_user shouldn't have a cohort") + got = get_cohort(user, course_id) + self.assertEquals(got.id, cohort.id, "Should find the right cohort") + + got = get_cohort(other_user, course_id) + self.assertEquals(got, None, "other_user shouldn't have a cohort") + + + def test_get_course_cohorts(self): + course1_id = "a/b/c" + course2_id = "e/f/g" + + # add some cohorts to course 1 + cohort = CourseUserGroup.objects.create(name="TestCohort", + course_id=course1_id, + group_type=CourseUserGroup.COHORT) + + cohort = CourseUserGroup.objects.create(name="TestCohort2", + course_id=course1_id, + group_type=CourseUserGroup.COHORT) + + + # second course should have no cohorts + self.assertEqual(get_course_cohorts(course2_id), []) + + cohorts = sorted([c.name for c in get_course_cohorts(course1_id)]) + self.assertEqual(cohorts, ['TestCohort', 'TestCohort2'])