Correct team count for private team-sets in Teams tab (#24216)
* Hide private team-sets from users not on a team * Modify add team count to factor in team visibility * Fix bug that broke search w/in private team-sets
This commit is contained in:
@@ -6,11 +6,12 @@ The Python API other app should use to work with Teams feature
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from django.db.models import Count
|
||||
from django.db.models import Count, Q
|
||||
from opaque_keys import InvalidKeyError
|
||||
from opaque_keys.edx.keys import CourseKey
|
||||
|
||||
from course_modes.models import CourseMode
|
||||
from lms.djangoapps.courseware.courses import has_access
|
||||
from lms.djangoapps.discussion.django_comment_client.utils import has_discussion_privileges
|
||||
from lms.djangoapps.teams.models import CourseTeam, CourseTeamMembership
|
||||
from openedx.core.lib.teams_config import TeamsetType
|
||||
@@ -251,7 +252,7 @@ def user_protection_status_matches_team(user, team):
|
||||
return OrganizationProtectionStatus.unprotected == protection_status
|
||||
|
||||
|
||||
def get_team_count_query_set(topic_id_set, course_id, organization_protection_status):
|
||||
def _get_team_filter_query(topic_id_set, course_id, organization_protection_status):
|
||||
""" Helper function to get the team count query set based on the filters provided """
|
||||
|
||||
filter_query = {'course_id': course_id}
|
||||
@@ -264,16 +265,34 @@ def get_team_count_query_set(topic_id_set, course_id, organization_protection_st
|
||||
filter_query.update(
|
||||
{'organization_protected': organization_protection_status == OrganizationProtectionStatus.protected}
|
||||
)
|
||||
return CourseTeam.objects.filter(**filter_query)
|
||||
return filter_query
|
||||
|
||||
|
||||
def add_team_count(topics, course_id, organization_protection_status):
|
||||
def get_teams_accessible_by_user(user, topic_id_set, course_id, organization_protection_status):
|
||||
""" Get teams taking for a user, taking into account user visibility privileges """
|
||||
# Filter by topics, course, and protection status
|
||||
filter_query = _get_team_filter_query(topic_id_set, course_id, organization_protection_status)
|
||||
|
||||
# Staff gets unfiltered list of teams
|
||||
if has_access(user, 'staff', course_id):
|
||||
return CourseTeam.objects.filter(**filter_query)
|
||||
|
||||
# Private teams should be hidden unless the student is a member
|
||||
course_module = modulestore().get_course(course_id)
|
||||
private_teamset_ids = [ts.teamset_id for ts in course_module.teamsets if ts.is_private_managed]
|
||||
return CourseTeam.objects.filter(**filter_query).exclude(
|
||||
Q(topic_id__in=private_teamset_ids), ~Q(membership__user=user)
|
||||
)
|
||||
|
||||
|
||||
def add_team_count(user, topics, course_id, organization_protection_status):
|
||||
"""
|
||||
Helper method to add team_count for a list of topics.
|
||||
This allows for a more efficient single query.
|
||||
"""
|
||||
topic_ids = [topic['id'] for topic in topics]
|
||||
teams_query_set = get_team_count_query_set(
|
||||
teams_query_set = get_teams_accessible_by_user(
|
||||
user,
|
||||
topic_ids,
|
||||
course_id,
|
||||
organization_protection_status
|
||||
|
||||
@@ -11,7 +11,7 @@ from django.contrib.auth.models import User
|
||||
from django_countries import countries
|
||||
from rest_framework import serializers
|
||||
|
||||
from lms.djangoapps.teams.api import add_team_count, get_team_count_query_set
|
||||
from lms.djangoapps.teams.api import add_team_count, get_teams_accessible_by_user
|
||||
from lms.djangoapps.teams.models import CourseTeam, CourseTeamMembership
|
||||
from openedx.core.djangoapps.user_api.accounts.serializers import UserReadOnlySerializer
|
||||
from openedx.core.lib.api.fields import ExpandableField
|
||||
@@ -194,7 +194,8 @@ class TopicSerializer(BaseTopicSerializer): # pylint: disable=abstract-method
|
||||
if 'team_count' in topic:
|
||||
return topic['team_count']
|
||||
else:
|
||||
return get_team_count_query_set(
|
||||
return get_teams_accessible_by_user(
|
||||
self.context.get('user'),
|
||||
[topic['id']],
|
||||
self.context['course_id'],
|
||||
self.context.get('organization_protection_status')
|
||||
@@ -209,7 +210,12 @@ class BulkTeamCountTopicListSerializer(serializers.ListSerializer): # pylint: d
|
||||
def to_representation(self, obj): # pylint: disable=arguments-differ
|
||||
"""Adds team_count to each topic. """
|
||||
data = super(BulkTeamCountTopicListSerializer, self).to_representation(obj)
|
||||
add_team_count(data, self.context['course_id'], self.context.get('organization_protection_status'))
|
||||
add_team_count(
|
||||
self.context['request'].user,
|
||||
data,
|
||||
self.context['course_id'],
|
||||
self.context.get('organization_protection_status')
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@@ -351,7 +351,7 @@ class TeamAccessTests(SharedModuleStoreTestCase):
|
||||
('user_unenrolled', 3),
|
||||
)
|
||||
@ddt.unpack
|
||||
def test_team_counter_get_team_count_query_set(self, username, expected_count):
|
||||
def test_team_counter_get_teams_accessible_by_user(self, username, expected_count):
|
||||
user = self.users[username]
|
||||
try:
|
||||
organization_protection_status = teams_api.user_organization_protection_status(
|
||||
@@ -361,7 +361,8 @@ class TeamAccessTests(SharedModuleStoreTestCase):
|
||||
except ValueError:
|
||||
self.assertFalse(CourseEnrollment.is_enrolled(user, COURSE_KEY1))
|
||||
return
|
||||
teams_query_set = teams_api.get_team_count_query_set(
|
||||
teams_query_set = teams_api.get_teams_accessible_by_user(
|
||||
user,
|
||||
[self.topic_id],
|
||||
COURSE_KEY1,
|
||||
organization_protection_status
|
||||
@@ -392,6 +393,7 @@ class TeamAccessTests(SharedModuleStoreTestCase):
|
||||
'id': self.topic_id
|
||||
}
|
||||
teams_api.add_team_count(
|
||||
user,
|
||||
[topic],
|
||||
COURSE_KEY1,
|
||||
organization_protection_status
|
||||
|
||||
@@ -79,9 +79,9 @@ class TopicSerializerTestCase(SerializerTestCase):
|
||||
def test_topic_with_no_team_count(self):
|
||||
"""
|
||||
Verifies that the `TopicSerializer` correctly displays a topic with a
|
||||
team count of 0, and that it only takes one SQL query.
|
||||
team count of 0, and that it takes a known number of SQL queries.
|
||||
"""
|
||||
with self.assertNumQueries(1):
|
||||
with self.assertNumQueries(2):
|
||||
serializer = TopicSerializer(
|
||||
self.course.teamsets[0].cleaned_data,
|
||||
context={'course_id': self.course.id},
|
||||
@@ -101,12 +101,12 @@ class TopicSerializerTestCase(SerializerTestCase):
|
||||
def test_topic_with_team_count(self):
|
||||
"""
|
||||
Verifies that the `TopicSerializer` correctly displays a topic with a
|
||||
positive team count, and that it only takes one SQL query.
|
||||
positive team count, and that it takes a known number of SQL queries.
|
||||
"""
|
||||
CourseTeamFactory.create(
|
||||
course_id=self.course.id, topic_id=self.course.teamsets[0].teamset_id
|
||||
)
|
||||
with self.assertNumQueries(1):
|
||||
with self.assertNumQueries(2):
|
||||
serializer = TopicSerializer(
|
||||
self.course.teamsets[0].cleaned_data,
|
||||
context={'course_id': self.course.id},
|
||||
@@ -134,7 +134,7 @@ class TopicSerializerTestCase(SerializerTestCase):
|
||||
)
|
||||
CourseTeamFactory.create(course_id=self.course.id, topic_id=duplicate_topic[u'id'])
|
||||
CourseTeamFactory.create(course_id=second_course.id, topic_id=duplicate_topic[u'id'])
|
||||
with self.assertNumQueries(1):
|
||||
with self.assertNumQueries(2):
|
||||
serializer = TopicSerializer(
|
||||
self.course.teamsets[0].cleaned_data,
|
||||
context={'course_id': self.course.id},
|
||||
@@ -227,13 +227,18 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
|
||||
|
||||
NUM_TOPICS = 6
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.user = UserFactory.create()
|
||||
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
|
||||
|
||||
def test_topics_with_no_team_counts(self):
|
||||
"""
|
||||
Verify that we serialize topics with no team count, making only one SQL
|
||||
query.
|
||||
"""
|
||||
topics = self.setup_topics(teams_per_topic=0)
|
||||
self.assert_serializer_output(topics, num_teams_per_topic=0, num_queries=1)
|
||||
self.assert_serializer_output(topics, num_teams_per_topic=0, num_queries=2)
|
||||
|
||||
def test_topics_with_team_counts(self):
|
||||
"""
|
||||
@@ -242,7 +247,7 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
|
||||
"""
|
||||
teams_per_topic = 10
|
||||
topics = self.setup_topics(teams_per_topic=teams_per_topic)
|
||||
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=1)
|
||||
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=2)
|
||||
|
||||
def test_subset_of_topics(self):
|
||||
"""
|
||||
@@ -251,7 +256,7 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
|
||||
"""
|
||||
teams_per_topic = 10
|
||||
topics = self.setup_topics(num_topics=self.NUM_TOPICS, teams_per_topic=teams_per_topic)
|
||||
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=1)
|
||||
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=2)
|
||||
|
||||
def test_scoped_within_course(self):
|
||||
"""Verify that team counts are scoped within a course."""
|
||||
@@ -265,7 +270,7 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
|
||||
}),
|
||||
)
|
||||
CourseTeamFactory.create(course_id=second_course.id, topic_id=duplicate_topic[u'id'])
|
||||
self.assert_serializer_output(first_course_topics, num_teams_per_topic=teams_per_topic, num_queries=1)
|
||||
self.assert_serializer_output(first_course_topics, num_teams_per_topic=teams_per_topic, num_queries=2)
|
||||
|
||||
def _merge_dicts(self, first, second):
|
||||
"""Convenience method to merge two dicts in a single expression"""
|
||||
@@ -277,8 +282,19 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
|
||||
"""
|
||||
Verify that the serializer produced the expected topics.
|
||||
"""
|
||||
# Set a request user
|
||||
request = RequestFactory().get('/api/team/v0/topics')
|
||||
request.user = self.user
|
||||
|
||||
with self.assertNumQueries(num_queries):
|
||||
serializer = self.serializer(topics, context={'course_id': self.course.id}, many=True)
|
||||
serializer = self.serializer(
|
||||
topics,
|
||||
context={
|
||||
'course_id': self.course.id,
|
||||
'request': request
|
||||
},
|
||||
many=True
|
||||
)
|
||||
self.assertEqual(
|
||||
serializer.data,
|
||||
[self._merge_dicts(topic, {u'team_count': num_teams_per_topic}) for topic in topics]
|
||||
@@ -290,4 +306,4 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
|
||||
with no topics.
|
||||
"""
|
||||
self.course.teams_configuration = TeamsConfig({'topics': []})
|
||||
self.assert_serializer_output([], num_teams_per_topic=0, num_queries=0)
|
||||
self.assert_serializer_output([], num_teams_per_topic=0, num_queries=1)
|
||||
|
||||
@@ -242,10 +242,58 @@ class TestDashboard(SharedModuleStoreTestCase):
|
||||
|
||||
expected_has_open = "hasOpenTopic: " + "true" if has_open else "false"
|
||||
expected_has_public = "hasPublicManagedTopic: " + "true" if has_public else "false"
|
||||
expected_has_managed = "hasManagedTopic: " + "true" if has_public or has_private else "false"
|
||||
|
||||
self.assertContains(response, expected_has_open)
|
||||
self.assertContains(response, expected_has_public)
|
||||
|
||||
@ddt.unpack
|
||||
@ddt.data(
|
||||
(True, False, False),
|
||||
(False, True, False),
|
||||
(False, False, True),
|
||||
(True, True, True),
|
||||
(False, True, True),
|
||||
)
|
||||
def test_has_managed_topic(self, has_open, has_private, has_public):
|
||||
topics = []
|
||||
if has_open:
|
||||
topics.append({
|
||||
"name": "test topic 1",
|
||||
"id": 1,
|
||||
"description": "Desc1",
|
||||
"type": "open"
|
||||
})
|
||||
if has_private:
|
||||
topics.append({
|
||||
"name": "test topic 2",
|
||||
"id": 2,
|
||||
"description": "Desc2",
|
||||
"type": "private_managed"
|
||||
})
|
||||
if has_public:
|
||||
topics.append({
|
||||
"name": "test topic 3",
|
||||
"id": 3,
|
||||
"description": "Desc3",
|
||||
"type": "public_managed"
|
||||
})
|
||||
|
||||
# Given a staff user browsing the teams tab
|
||||
course = CourseFactory.create(
|
||||
teams_configuration=TeamsConfig({"topics": topics})
|
||||
)
|
||||
teams_url = reverse('teams_dashboard', args=[course.id])
|
||||
|
||||
staff_user = UserFactory(is_staff=True, password=self.test_password)
|
||||
staff_client = APIClient()
|
||||
staff_client.login(username=staff_user.username, password=self.test_password)
|
||||
|
||||
# When I browse to the team tab
|
||||
response = staff_client.get(teams_url)
|
||||
|
||||
# Then "hasManagedTopic" (which is used to show the "Manage" tab)
|
||||
# is shown if there are managed team-sets
|
||||
expected_has_managed = "hasManagedTopic: " + "true" if has_public or has_private else "false"
|
||||
self.assertContains(response, expected_has_managed)
|
||||
|
||||
|
||||
@@ -1725,7 +1773,7 @@ class TestListTopicsAPI(TeamAPITestCase):
|
||||
data = {'course_id': str(self.test_course_1.id)}
|
||||
if field:
|
||||
data['order_by'] = field
|
||||
topics = self.get_topics_list(status, data)
|
||||
topics = self.get_topics_list(status, data, user='student_enrolled')
|
||||
if status == 200:
|
||||
self.assertEqual(names, [topic['name'] for topic in topics['results']])
|
||||
self.assertEqual(topics['sort_order'], expected_ordering)
|
||||
@@ -1758,28 +1806,37 @@ class TestListTopicsAPI(TeamAPITestCase):
|
||||
)
|
||||
|
||||
# Wind power has the most teams, followed by Solar
|
||||
topics = self.get_topics_list(data={
|
||||
'course_id': str(self.test_course_1.id),
|
||||
'page_size': 2,
|
||||
'page': 1,
|
||||
'order_by': 'team_count'
|
||||
})
|
||||
topics = self.get_topics_list(
|
||||
data={
|
||||
'course_id': str(self.test_course_1.id),
|
||||
'page_size': 2,
|
||||
'page': 1,
|
||||
'order_by': 'team_count'
|
||||
},
|
||||
user='student_enrolled'
|
||||
)
|
||||
self.assertEqual(["Wind Power", u'Sólar power'], [topic['name'] for topic in topics['results']])
|
||||
|
||||
# Coal and Nuclear are tied, so they are alphabetically sorted.
|
||||
topics = self.get_topics_list(data={
|
||||
'course_id': str(self.test_course_1.id),
|
||||
'page_size': 2,
|
||||
'page': 2,
|
||||
'order_by': 'team_count'
|
||||
})
|
||||
topics = self.get_topics_list(
|
||||
data={
|
||||
'course_id': str(self.test_course_1.id),
|
||||
'page_size': 2,
|
||||
'page': 2,
|
||||
'order_by': 'team_count'
|
||||
},
|
||||
user='student_enrolled'
|
||||
)
|
||||
self.assertEqual(["Coal Power", "Nuclear Power"], [topic['name'] for topic in topics['results']])
|
||||
|
||||
def test_pagination(self):
|
||||
response = self.get_topics_list(data={
|
||||
'course_id': str(self.test_course_1.id),
|
||||
'page_size': 2,
|
||||
})
|
||||
response = self.get_topics_list(
|
||||
data={
|
||||
'course_id': str(self.test_course_1.id),
|
||||
'page_size': 2,
|
||||
},
|
||||
user='student_enrolled'
|
||||
)
|
||||
|
||||
self.assertEqual(2, len(response['results']))
|
||||
self.assertIn('next', response)
|
||||
@@ -1793,7 +1850,10 @@ class TestListTopicsAPI(TeamAPITestCase):
|
||||
|
||||
def test_team_count(self):
|
||||
"""Test that team_count is included for each topic"""
|
||||
response = self.get_topics_list(data={'course_id': str(self.test_course_1.id)})
|
||||
response = self.get_topics_list(
|
||||
data={'course_id': str(self.test_course_1.id)},
|
||||
user='student_enrolled'
|
||||
)
|
||||
for topic in response['results']:
|
||||
self.assertIn('team_count', topic)
|
||||
if topic['id'] in ('topic_0', 'topic_1', 'topic_2'):
|
||||
@@ -1827,13 +1887,13 @@ class TestListTopicsAPI(TeamAPITestCase):
|
||||
|
||||
@ddt.unpack
|
||||
@ddt.data(
|
||||
('student_on_team_1_private_set_1', 2),
|
||||
('student_on_team_2_private_set_1', 2),
|
||||
('student_on_team_1_private_set_1', 1),
|
||||
('student_on_team_2_private_set_1', 1),
|
||||
('staff', 2)
|
||||
)
|
||||
def test_private_teamset_team_count(self, requesting_user, expected_team_count):
|
||||
"""
|
||||
TODO: the two students should probably not see that there's another team that they don't see
|
||||
Students should only see teams they are members of in private team-sets
|
||||
"""
|
||||
topics = self.get_topics_list(
|
||||
data={'course_id': str(self.test_course_1.id)},
|
||||
@@ -1887,8 +1947,8 @@ class TestDetailTopicAPI(TeamAPITestCase):
|
||||
@ddt.unpack
|
||||
@ddt.data(
|
||||
('student_enrolled', 404, None),
|
||||
('student_on_team_1_private_set_1', 200, 2),
|
||||
('student_on_team_2_private_set_1', 200, 2),
|
||||
('student_on_team_1_private_set_1', 200, 1),
|
||||
('student_on_team_2_private_set_1', 200, 1),
|
||||
('student_masters', 404, None),
|
||||
('staff', 200, 2)
|
||||
)
|
||||
|
||||
@@ -139,6 +139,7 @@ class TeamsDashboardView(GenericAPIView):
|
||||
# to the serializer so that the paginated results indicate how they were sorted.
|
||||
sort_order = 'name'
|
||||
topics = get_alphabetical_topics(course)
|
||||
topics = _filter_hidden_private_teamsets(user, topics, course)
|
||||
organization_protection_status = user_organization_protection_status(request.user, course_key)
|
||||
|
||||
# We have some frontend logic that needs to know if we have any open, public, or managed teamsets,
|
||||
@@ -443,10 +444,6 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView):
|
||||
)
|
||||
return Response(error, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if course_module.teamsets_by_id[topic_id].is_private_managed \
|
||||
and not has_access(request.user, 'staff', course_key):
|
||||
result_filter.update({'membership__user__username': request.user})
|
||||
|
||||
result_filter.update({'topic_id': topic_id})
|
||||
|
||||
organization_protection_status = user_organization_protection_status(
|
||||
@@ -1020,15 +1017,15 @@ class TopicListView(GenericAPIView):
|
||||
# in the case of "team_count".
|
||||
organization_protection_status = user_organization_protection_status(request.user, course_id)
|
||||
topics = get_alphabetical_topics(course_module)
|
||||
topics = self._filter_hidden_private_teamsets(topics, course_module)
|
||||
topics = _filter_hidden_private_teamsets(request.user, topics, course_module)
|
||||
|
||||
if ordering == 'team_count':
|
||||
add_team_count(topics, course_id, organization_protection_status)
|
||||
add_team_count(request.user, topics, course_id, organization_protection_status)
|
||||
topics.sort(key=lambda t: t['team_count'], reverse=True)
|
||||
page = self.paginate_queryset(topics)
|
||||
serializer = TopicSerializer(
|
||||
page,
|
||||
context={'course_id': course_id},
|
||||
context={'course_id': course_id, 'user': request.user},
|
||||
many=True,
|
||||
)
|
||||
else:
|
||||
@@ -1037,6 +1034,7 @@ class TopicListView(GenericAPIView):
|
||||
serializer = BulkTeamCountTopicSerializer(
|
||||
page,
|
||||
context={
|
||||
'request': request,
|
||||
'course_id': course_id,
|
||||
'organization_protection_status': organization_protection_status
|
||||
},
|
||||
@@ -1048,25 +1046,26 @@ class TopicListView(GenericAPIView):
|
||||
|
||||
return response
|
||||
|
||||
def _filter_hidden_private_teamsets(self, teamsets, course_module):
|
||||
"""
|
||||
Return a filtered list of teamsets, removing any private teamsets that a user doesn't have access to.
|
||||
Follows the same logic as `has_specific_teamset_access` but in bulk rather than for one teamset at a time
|
||||
"""
|
||||
if has_course_staff_privileges(self.request.user, course_module.id):
|
||||
return teamsets
|
||||
private_teamset_ids = [teamset.teamset_id for teamset in course_module.teamsets if teamset.is_private_managed]
|
||||
teamset_ids_user_has_access_to = set(
|
||||
CourseTeam.objects.filter(
|
||||
course_id=course_module.id,
|
||||
topic_id__in=private_teamset_ids,
|
||||
membership__user=self.request.user
|
||||
).values_list('topic_id', flat=True)
|
||||
)
|
||||
return [
|
||||
teamset for teamset in teamsets
|
||||
if teamset['type'] != TeamsetType.private_managed.value or teamset['id'] in teamset_ids_user_has_access_to
|
||||
]
|
||||
|
||||
def _filter_hidden_private_teamsets(user, teamsets, course_module):
|
||||
"""
|
||||
Return a filtered list of teamsets, removing any private teamsets that a user doesn't have access to.
|
||||
Follows the same logic as `has_specific_teamset_access` but in bulk rather than for one teamset at a time
|
||||
"""
|
||||
if has_course_staff_privileges(user, course_module.id):
|
||||
return teamsets
|
||||
private_teamset_ids = [teamset.teamset_id for teamset in course_module.teamsets if teamset.is_private_managed]
|
||||
teamset_ids_user_has_access_to = set(
|
||||
CourseTeam.objects.filter(
|
||||
course_id=course_module.id,
|
||||
topic_id__in=private_teamset_ids,
|
||||
membership__user=user
|
||||
).values_list('topic_id', flat=True)
|
||||
)
|
||||
return [
|
||||
teamset for teamset in teamsets
|
||||
if teamset['type'] != TeamsetType.private_managed.value or teamset['id'] in teamset_ids_user_has_access_to
|
||||
]
|
||||
|
||||
|
||||
def get_alphabetical_topics(course_module):
|
||||
@@ -1159,7 +1158,8 @@ class TopicDetailView(APIView):
|
||||
topic.cleaned_data,
|
||||
context={
|
||||
'course_id': course_id,
|
||||
'organization_protection_status': organization_protection_status
|
||||
'organization_protection_status': organization_protection_status,
|
||||
'user': request.user
|
||||
}
|
||||
)
|
||||
return Response(serializer.data)
|
||||
|
||||
Reference in New Issue
Block a user