Correct team count for private team-sets in Teams tab (#24216)

* Hide private team-sets from users not on a team

* Modify add team count to factor in team visibility

* Fix bug that broke search w/in private team-sets
This commit is contained in:
Nathan Sprenkle
2020-06-18 15:21:58 -04:00
committed by GitHub
parent c51cc3705b
commit 5cedc64f41
6 changed files with 175 additions and 72 deletions

View File

@@ -6,11 +6,12 @@ The Python API other app should use to work with Teams feature
import logging
from enum import Enum
from django.db.models import Count
from django.db.models import Count, Q
from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import CourseKey
from course_modes.models import CourseMode
from lms.djangoapps.courseware.courses import has_access
from lms.djangoapps.discussion.django_comment_client.utils import has_discussion_privileges
from lms.djangoapps.teams.models import CourseTeam, CourseTeamMembership
from openedx.core.lib.teams_config import TeamsetType
@@ -251,7 +252,7 @@ def user_protection_status_matches_team(user, team):
return OrganizationProtectionStatus.unprotected == protection_status
def get_team_count_query_set(topic_id_set, course_id, organization_protection_status):
def _get_team_filter_query(topic_id_set, course_id, organization_protection_status):
""" Helper function to get the team count query set based on the filters provided """
filter_query = {'course_id': course_id}
@@ -264,16 +265,34 @@ def get_team_count_query_set(topic_id_set, course_id, organization_protection_st
filter_query.update(
{'organization_protected': organization_protection_status == OrganizationProtectionStatus.protected}
)
return CourseTeam.objects.filter(**filter_query)
return filter_query
def add_team_count(topics, course_id, organization_protection_status):
def get_teams_accessible_by_user(user, topic_id_set, course_id, organization_protection_status):
""" Get teams taking for a user, taking into account user visibility privileges """
# Filter by topics, course, and protection status
filter_query = _get_team_filter_query(topic_id_set, course_id, organization_protection_status)
# Staff gets unfiltered list of teams
if has_access(user, 'staff', course_id):
return CourseTeam.objects.filter(**filter_query)
# Private teams should be hidden unless the student is a member
course_module = modulestore().get_course(course_id)
private_teamset_ids = [ts.teamset_id for ts in course_module.teamsets if ts.is_private_managed]
return CourseTeam.objects.filter(**filter_query).exclude(
Q(topic_id__in=private_teamset_ids), ~Q(membership__user=user)
)
def add_team_count(user, topics, course_id, organization_protection_status):
"""
Helper method to add team_count for a list of topics.
This allows for a more efficient single query.
"""
topic_ids = [topic['id'] for topic in topics]
teams_query_set = get_team_count_query_set(
teams_query_set = get_teams_accessible_by_user(
user,
topic_ids,
course_id,
organization_protection_status

View File

@@ -11,7 +11,7 @@ from django.contrib.auth.models import User
from django_countries import countries
from rest_framework import serializers
from lms.djangoapps.teams.api import add_team_count, get_team_count_query_set
from lms.djangoapps.teams.api import add_team_count, get_teams_accessible_by_user
from lms.djangoapps.teams.models import CourseTeam, CourseTeamMembership
from openedx.core.djangoapps.user_api.accounts.serializers import UserReadOnlySerializer
from openedx.core.lib.api.fields import ExpandableField
@@ -194,7 +194,8 @@ class TopicSerializer(BaseTopicSerializer): # pylint: disable=abstract-method
if 'team_count' in topic:
return topic['team_count']
else:
return get_team_count_query_set(
return get_teams_accessible_by_user(
self.context.get('user'),
[topic['id']],
self.context['course_id'],
self.context.get('organization_protection_status')
@@ -209,7 +210,12 @@ class BulkTeamCountTopicListSerializer(serializers.ListSerializer): # pylint: d
def to_representation(self, obj): # pylint: disable=arguments-differ
"""Adds team_count to each topic. """
data = super(BulkTeamCountTopicListSerializer, self).to_representation(obj)
add_team_count(data, self.context['course_id'], self.context.get('organization_protection_status'))
add_team_count(
self.context['request'].user,
data,
self.context['course_id'],
self.context.get('organization_protection_status')
)
return data

View File

@@ -351,7 +351,7 @@ class TeamAccessTests(SharedModuleStoreTestCase):
('user_unenrolled', 3),
)
@ddt.unpack
def test_team_counter_get_team_count_query_set(self, username, expected_count):
def test_team_counter_get_teams_accessible_by_user(self, username, expected_count):
user = self.users[username]
try:
organization_protection_status = teams_api.user_organization_protection_status(
@@ -361,7 +361,8 @@ class TeamAccessTests(SharedModuleStoreTestCase):
except ValueError:
self.assertFalse(CourseEnrollment.is_enrolled(user, COURSE_KEY1))
return
teams_query_set = teams_api.get_team_count_query_set(
teams_query_set = teams_api.get_teams_accessible_by_user(
user,
[self.topic_id],
COURSE_KEY1,
organization_protection_status
@@ -392,6 +393,7 @@ class TeamAccessTests(SharedModuleStoreTestCase):
'id': self.topic_id
}
teams_api.add_team_count(
user,
[topic],
COURSE_KEY1,
organization_protection_status

View File

@@ -79,9 +79,9 @@ class TopicSerializerTestCase(SerializerTestCase):
def test_topic_with_no_team_count(self):
"""
Verifies that the `TopicSerializer` correctly displays a topic with a
team count of 0, and that it only takes one SQL query.
team count of 0, and that it takes a known number of SQL queries.
"""
with self.assertNumQueries(1):
with self.assertNumQueries(2):
serializer = TopicSerializer(
self.course.teamsets[0].cleaned_data,
context={'course_id': self.course.id},
@@ -101,12 +101,12 @@ class TopicSerializerTestCase(SerializerTestCase):
def test_topic_with_team_count(self):
"""
Verifies that the `TopicSerializer` correctly displays a topic with a
positive team count, and that it only takes one SQL query.
positive team count, and that it takes a known number of SQL queries.
"""
CourseTeamFactory.create(
course_id=self.course.id, topic_id=self.course.teamsets[0].teamset_id
)
with self.assertNumQueries(1):
with self.assertNumQueries(2):
serializer = TopicSerializer(
self.course.teamsets[0].cleaned_data,
context={'course_id': self.course.id},
@@ -134,7 +134,7 @@ class TopicSerializerTestCase(SerializerTestCase):
)
CourseTeamFactory.create(course_id=self.course.id, topic_id=duplicate_topic[u'id'])
CourseTeamFactory.create(course_id=second_course.id, topic_id=duplicate_topic[u'id'])
with self.assertNumQueries(1):
with self.assertNumQueries(2):
serializer = TopicSerializer(
self.course.teamsets[0].cleaned_data,
context={'course_id': self.course.id},
@@ -227,13 +227,18 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
NUM_TOPICS = 6
def setUp(self):
super().setUp()
self.user = UserFactory.create()
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
def test_topics_with_no_team_counts(self):
"""
Verify that we serialize topics with no team count, making only one SQL
query.
"""
topics = self.setup_topics(teams_per_topic=0)
self.assert_serializer_output(topics, num_teams_per_topic=0, num_queries=1)
self.assert_serializer_output(topics, num_teams_per_topic=0, num_queries=2)
def test_topics_with_team_counts(self):
"""
@@ -242,7 +247,7 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
"""
teams_per_topic = 10
topics = self.setup_topics(teams_per_topic=teams_per_topic)
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=1)
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=2)
def test_subset_of_topics(self):
"""
@@ -251,7 +256,7 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
"""
teams_per_topic = 10
topics = self.setup_topics(num_topics=self.NUM_TOPICS, teams_per_topic=teams_per_topic)
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=1)
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=2)
def test_scoped_within_course(self):
"""Verify that team counts are scoped within a course."""
@@ -265,7 +270,7 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
}),
)
CourseTeamFactory.create(course_id=second_course.id, topic_id=duplicate_topic[u'id'])
self.assert_serializer_output(first_course_topics, num_teams_per_topic=teams_per_topic, num_queries=1)
self.assert_serializer_output(first_course_topics, num_teams_per_topic=teams_per_topic, num_queries=2)
def _merge_dicts(self, first, second):
"""Convenience method to merge two dicts in a single expression"""
@@ -277,8 +282,19 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
"""
Verify that the serializer produced the expected topics.
"""
# Set a request user
request = RequestFactory().get('/api/team/v0/topics')
request.user = self.user
with self.assertNumQueries(num_queries):
serializer = self.serializer(topics, context={'course_id': self.course.id}, many=True)
serializer = self.serializer(
topics,
context={
'course_id': self.course.id,
'request': request
},
many=True
)
self.assertEqual(
serializer.data,
[self._merge_dicts(topic, {u'team_count': num_teams_per_topic}) for topic in topics]
@@ -290,4 +306,4 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
with no topics.
"""
self.course.teams_configuration = TeamsConfig({'topics': []})
self.assert_serializer_output([], num_teams_per_topic=0, num_queries=0)
self.assert_serializer_output([], num_teams_per_topic=0, num_queries=1)

View File

@@ -242,10 +242,58 @@ class TestDashboard(SharedModuleStoreTestCase):
expected_has_open = "hasOpenTopic: " + "true" if has_open else "false"
expected_has_public = "hasPublicManagedTopic: " + "true" if has_public else "false"
expected_has_managed = "hasManagedTopic: " + "true" if has_public or has_private else "false"
self.assertContains(response, expected_has_open)
self.assertContains(response, expected_has_public)
@ddt.unpack
@ddt.data(
(True, False, False),
(False, True, False),
(False, False, True),
(True, True, True),
(False, True, True),
)
def test_has_managed_topic(self, has_open, has_private, has_public):
topics = []
if has_open:
topics.append({
"name": "test topic 1",
"id": 1,
"description": "Desc1",
"type": "open"
})
if has_private:
topics.append({
"name": "test topic 2",
"id": 2,
"description": "Desc2",
"type": "private_managed"
})
if has_public:
topics.append({
"name": "test topic 3",
"id": 3,
"description": "Desc3",
"type": "public_managed"
})
# Given a staff user browsing the teams tab
course = CourseFactory.create(
teams_configuration=TeamsConfig({"topics": topics})
)
teams_url = reverse('teams_dashboard', args=[course.id])
staff_user = UserFactory(is_staff=True, password=self.test_password)
staff_client = APIClient()
staff_client.login(username=staff_user.username, password=self.test_password)
# When I browse to the team tab
response = staff_client.get(teams_url)
# Then "hasManagedTopic" (which is used to show the "Manage" tab)
# is shown if there are managed team-sets
expected_has_managed = "hasManagedTopic: " + "true" if has_public or has_private else "false"
self.assertContains(response, expected_has_managed)
@@ -1725,7 +1773,7 @@ class TestListTopicsAPI(TeamAPITestCase):
data = {'course_id': str(self.test_course_1.id)}
if field:
data['order_by'] = field
topics = self.get_topics_list(status, data)
topics = self.get_topics_list(status, data, user='student_enrolled')
if status == 200:
self.assertEqual(names, [topic['name'] for topic in topics['results']])
self.assertEqual(topics['sort_order'], expected_ordering)
@@ -1758,28 +1806,37 @@ class TestListTopicsAPI(TeamAPITestCase):
)
# Wind power has the most teams, followed by Solar
topics = self.get_topics_list(data={
'course_id': str(self.test_course_1.id),
'page_size': 2,
'page': 1,
'order_by': 'team_count'
})
topics = self.get_topics_list(
data={
'course_id': str(self.test_course_1.id),
'page_size': 2,
'page': 1,
'order_by': 'team_count'
},
user='student_enrolled'
)
self.assertEqual(["Wind Power", u'Sólar power'], [topic['name'] for topic in topics['results']])
# Coal and Nuclear are tied, so they are alphabetically sorted.
topics = self.get_topics_list(data={
'course_id': str(self.test_course_1.id),
'page_size': 2,
'page': 2,
'order_by': 'team_count'
})
topics = self.get_topics_list(
data={
'course_id': str(self.test_course_1.id),
'page_size': 2,
'page': 2,
'order_by': 'team_count'
},
user='student_enrolled'
)
self.assertEqual(["Coal Power", "Nuclear Power"], [topic['name'] for topic in topics['results']])
def test_pagination(self):
response = self.get_topics_list(data={
'course_id': str(self.test_course_1.id),
'page_size': 2,
})
response = self.get_topics_list(
data={
'course_id': str(self.test_course_1.id),
'page_size': 2,
},
user='student_enrolled'
)
self.assertEqual(2, len(response['results']))
self.assertIn('next', response)
@@ -1793,7 +1850,10 @@ class TestListTopicsAPI(TeamAPITestCase):
def test_team_count(self):
"""Test that team_count is included for each topic"""
response = self.get_topics_list(data={'course_id': str(self.test_course_1.id)})
response = self.get_topics_list(
data={'course_id': str(self.test_course_1.id)},
user='student_enrolled'
)
for topic in response['results']:
self.assertIn('team_count', topic)
if topic['id'] in ('topic_0', 'topic_1', 'topic_2'):
@@ -1827,13 +1887,13 @@ class TestListTopicsAPI(TeamAPITestCase):
@ddt.unpack
@ddt.data(
('student_on_team_1_private_set_1', 2),
('student_on_team_2_private_set_1', 2),
('student_on_team_1_private_set_1', 1),
('student_on_team_2_private_set_1', 1),
('staff', 2)
)
def test_private_teamset_team_count(self, requesting_user, expected_team_count):
"""
TODO: the two students should probably not see that there's another team that they don't see
Students should only see teams they are members of in private team-sets
"""
topics = self.get_topics_list(
data={'course_id': str(self.test_course_1.id)},
@@ -1887,8 +1947,8 @@ class TestDetailTopicAPI(TeamAPITestCase):
@ddt.unpack
@ddt.data(
('student_enrolled', 404, None),
('student_on_team_1_private_set_1', 200, 2),
('student_on_team_2_private_set_1', 200, 2),
('student_on_team_1_private_set_1', 200, 1),
('student_on_team_2_private_set_1', 200, 1),
('student_masters', 404, None),
('staff', 200, 2)
)

View File

@@ -139,6 +139,7 @@ class TeamsDashboardView(GenericAPIView):
# to the serializer so that the paginated results indicate how they were sorted.
sort_order = 'name'
topics = get_alphabetical_topics(course)
topics = _filter_hidden_private_teamsets(user, topics, course)
organization_protection_status = user_organization_protection_status(request.user, course_key)
# We have some frontend logic that needs to know if we have any open, public, or managed teamsets,
@@ -443,10 +444,6 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView):
)
return Response(error, status=status.HTTP_400_BAD_REQUEST)
if course_module.teamsets_by_id[topic_id].is_private_managed \
and not has_access(request.user, 'staff', course_key):
result_filter.update({'membership__user__username': request.user})
result_filter.update({'topic_id': topic_id})
organization_protection_status = user_organization_protection_status(
@@ -1020,15 +1017,15 @@ class TopicListView(GenericAPIView):
# in the case of "team_count".
organization_protection_status = user_organization_protection_status(request.user, course_id)
topics = get_alphabetical_topics(course_module)
topics = self._filter_hidden_private_teamsets(topics, course_module)
topics = _filter_hidden_private_teamsets(request.user, topics, course_module)
if ordering == 'team_count':
add_team_count(topics, course_id, organization_protection_status)
add_team_count(request.user, topics, course_id, organization_protection_status)
topics.sort(key=lambda t: t['team_count'], reverse=True)
page = self.paginate_queryset(topics)
serializer = TopicSerializer(
page,
context={'course_id': course_id},
context={'course_id': course_id, 'user': request.user},
many=True,
)
else:
@@ -1037,6 +1034,7 @@ class TopicListView(GenericAPIView):
serializer = BulkTeamCountTopicSerializer(
page,
context={
'request': request,
'course_id': course_id,
'organization_protection_status': organization_protection_status
},
@@ -1048,25 +1046,26 @@ class TopicListView(GenericAPIView):
return response
def _filter_hidden_private_teamsets(self, teamsets, course_module):
"""
Return a filtered list of teamsets, removing any private teamsets that a user doesn't have access to.
Follows the same logic as `has_specific_teamset_access` but in bulk rather than for one teamset at a time
"""
if has_course_staff_privileges(self.request.user, course_module.id):
return teamsets
private_teamset_ids = [teamset.teamset_id for teamset in course_module.teamsets if teamset.is_private_managed]
teamset_ids_user_has_access_to = set(
CourseTeam.objects.filter(
course_id=course_module.id,
topic_id__in=private_teamset_ids,
membership__user=self.request.user
).values_list('topic_id', flat=True)
)
return [
teamset for teamset in teamsets
if teamset['type'] != TeamsetType.private_managed.value or teamset['id'] in teamset_ids_user_has_access_to
]
def _filter_hidden_private_teamsets(user, teamsets, course_module):
"""
Return a filtered list of teamsets, removing any private teamsets that a user doesn't have access to.
Follows the same logic as `has_specific_teamset_access` but in bulk rather than for one teamset at a time
"""
if has_course_staff_privileges(user, course_module.id):
return teamsets
private_teamset_ids = [teamset.teamset_id for teamset in course_module.teamsets if teamset.is_private_managed]
teamset_ids_user_has_access_to = set(
CourseTeam.objects.filter(
course_id=course_module.id,
topic_id__in=private_teamset_ids,
membership__user=user
).values_list('topic_id', flat=True)
)
return [
teamset for teamset in teamsets
if teamset['type'] != TeamsetType.private_managed.value or teamset['id'] in teamset_ids_user_has_access_to
]
def get_alphabetical_topics(course_module):
@@ -1159,7 +1158,8 @@ class TopicDetailView(APIView):
topic.cleaned_data,
context={
'course_id': course_id,
'organization_protection_status': organization_protection_status
'organization_protection_status': organization_protection_status,
'user': request.user
}
)
return Response(serializer.data)