diff --git a/common/djangoapps/student/api.py b/common/djangoapps/student/api.py index 2bf42f4828..c5bd7b8618 100644 --- a/common/djangoapps/student/api.py +++ b/common/djangoapps/student/api.py @@ -4,6 +4,7 @@ Python APIs exposed by the student app to other in-process apps. """ +from typing import TYPE_CHECKING import logging from django.contrib.auth import get_user_model @@ -32,6 +33,10 @@ from common.djangoapps.student.roles import ( ) from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers +if TYPE_CHECKING: + from django.contrib.auth.models import AnonymousUser, User # pylint: disable=imported-auth-user + from django.db.models.query import QuerySet + # This is done so that if these strings change within the app, we can keep exported constants the same ENROLLED_TO_ENROLLED = _ENROLLED_TO_ENROLLED @@ -92,13 +97,7 @@ def create_manual_enrollment_audit( else: enrollment = None - _create_manual_enrollment_audit( - enrolled_by, - user_email, - transition_state, - reason, - enrollment - ) + _create_manual_enrollment_audit(enrolled_by, user_email, transition_state, reason, enrollment) def get_access_role_by_role_name(role_name): @@ -132,7 +131,31 @@ def is_user_staff_or_instructor_in_course(user, course_key): course_key = CourseKey.from_string(course_key) return ( - GlobalStaff().has_user(user) or - CourseStaffRole(course_key).has_user(user) or - CourseInstructorRole(course_key).has_user(user) + GlobalStaff().has_user(user) + or CourseStaffRole(course_key).has_user(user) + or CourseInstructorRole(course_key).has_user(user) ) + + +def get_course_enrollments( + user: "AnonymousUser | User", + is_filtered: bool = False, + course_ids: list[str | None] | None = None, +) -> "QuerySet[CourseEnrollment]": + """ + Return enrollments for a user, potentially filtered by course_id. + + Because an empty `course_ids` value is a meaningful filter, the easiest way to verify + that the list should be filtered intentionally is to specify `is_filtered`. + + Arguments: + + * is_filtered (bool): whether or not the list is filtered + * course_ids (list): a list of course IDs to filter by. + """ + course_enrollments = CourseEnrollment.enrollments_for_user(user).select_related("course") + + if is_filtered: + course_enrollments = course_enrollments.filter(course_id__in=course_ids) + + return course_enrollments diff --git a/common/djangoapps/student/tests/test_api.py b/common/djangoapps/student/tests/test_api.py index ad462830a1..7cb20380cf 100644 --- a/common/djangoapps/student/tests/test_api.py +++ b/common/djangoapps/student/tests/test_api.py @@ -1,10 +1,16 @@ """ Test Student api.py """ + from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase from xmodule.modulestore.tests.factories import CourseFactory -from common.djangoapps.student.api import is_user_enrolled_in_course, is_user_staff_or_instructor_in_course +from common.djangoapps.student.api import ( + is_user_enrolled_in_course, + is_user_staff_or_instructor_in_course, + get_course_enrollments, +) +from common.djangoapps.student.models import CourseEnrollment from common.djangoapps.student.tests.factories import ( CourseEnrollmentFactory, GlobalStaffFactory, @@ -33,10 +39,7 @@ class TestStudentApi(SharedModuleStoreTestCase): """ Verify the correct value is returned when a learner is actively enrolled in a course-run. """ - CourseEnrollmentFactory.create( - user_id=self.user.id, - course_id=self.course.id - ) + CourseEnrollmentFactory.create(user_id=self.user.id, course_id=self.course.id) result = is_user_enrolled_in_course(self.user, self.course_run_key) assert result @@ -45,11 +48,7 @@ class TestStudentApi(SharedModuleStoreTestCase): """ Verify the correct value is returned when a learner is not actively enrolled in a course-run. """ - CourseEnrollmentFactory.create( - user_id=self.user.id, - course_id=self.course.id, - is_active=False - ) + CourseEnrollmentFactory.create(user_id=self.user.id, course_id=self.course.id, is_active=False) result = is_user_enrolled_in_course(self.user, self.course_run_key) assert not result @@ -79,3 +78,25 @@ class TestStudentApi(SharedModuleStoreTestCase): assert is_user_staff_or_instructor_in_course(instructor, self.course_run_key) assert not is_user_staff_or_instructor_in_course(self.user, self.course_run_key) assert not is_user_staff_or_instructor_in_course(instructor_different_course, self.course_run_key) + + def test_get_course_enrollments(self): + """Verify all enrollments can be retrieved""" + course_2 = CourseFactory.create() + CourseEnrollmentFactory.create(user_id=self.user.id, course_id=self.course.id) + CourseEnrollmentFactory.create(user_id=self.user.id, course_id=course_2.id) + expected = CourseEnrollment.objects.all() + + result = get_course_enrollments(self.user) + + self.assertQuerySetEqual(expected, result) + + def test_get_filtered_course_enrollments(self): + """Verify a filtered subset of enrollments can be retrieved""" + course_2 = CourseFactory.create() + CourseEnrollmentFactory.create(user_id=self.user.id, course_id=self.course.id) + ce_2 = CourseEnrollmentFactory.create(user_id=self.user.id, course_id=course_2.id) + expected = CourseEnrollment.objects.filter(id=ce_2.id) + + result = get_course_enrollments(self.user, True, course_ids=[course_2.id]) + + self.assertQuerySetEqual(expected, result) diff --git a/openedx/core/djangoapps/programs/rest_api/v1/tests/test_views.py b/openedx/core/djangoapps/programs/rest_api/v1/tests/test_views.py index e80d4c615a..2864f41a92 100644 --- a/openedx/core/djangoapps/programs/rest_api/v1/tests/test_views.py +++ b/openedx/core/djangoapps/programs/rest_api/v1/tests/test_views.py @@ -1,11 +1,12 @@ """ -Unit tests for Learner Dashboard REST APIs and Views +Unit tests for Programs REST APIs and Views """ from unittest import mock from uuid import uuid4 from django.core.cache import cache +from django.test.utils import override_settings from django.urls import reverse_lazy from enterprise.models import EnterpriseCourseEnrollment @@ -31,6 +32,7 @@ from openedx.core.djangoapps.site_configuration.tests.test_util import ( with_site_configuration, ) from openedx.core.djangolib.testing.utils import skip_unless_lms +from openedx.features.enterprise_support.api import enterprise_is_enabled from openedx.features.enterprise_support.tests.factories import ( EnterpriseCourseEnrollmentFactory, EnterpriseCustomerFactory, @@ -192,6 +194,8 @@ class TestProgramsView(SharedModuleStoreTestCase, ProgramCacheMixin): ) @with_site_configuration(configuration={"COURSE_CATALOG_API_URL": "foo"}) + @override_settings(FEATURES=dict(ENABLE_ENTERPRISE_INTEGRATION=True)) + @enterprise_is_enabled() def test_program_list(self): """ Verify API returns proper response. @@ -221,6 +225,8 @@ class TestProgramsView(SharedModuleStoreTestCase, ProgramCacheMixin): } @with_site_configuration(configuration={"COURSE_CATALOG_API_URL": "foo"}) + @override_settings(FEATURES=dict(ENABLE_ENTERPRISE_INTEGRATION=True)) + @enterprise_is_enabled() def test_program_empty_list_if_no_enterprise_enrollments(self): """ Verify API returns empty response if no enterprise enrollments exists for a learner. diff --git a/openedx/core/djangoapps/programs/rest_api/v1/views.py b/openedx/core/djangoapps/programs/rest_api/v1/views.py index 69f47d268a..a5bf939e1e 100644 --- a/openedx/core/djangoapps/programs/rest_api/v1/views.py +++ b/openedx/core/djangoapps/programs/rest_api/v1/views.py @@ -3,12 +3,12 @@ from typing import Any, TYPE_CHECKING import logging -from enterprise.models import EnterpriseCourseEnrollment +from django.db.models.query import EmptyQuerySet from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView -from common.djangoapps.student.models import CourseEnrollment +from common.djangoapps.student.api import get_course_enrollments from openedx.core.djangoapps.programs.utils import ( ProgramProgressMeter, get_certificates, @@ -16,11 +16,14 @@ from openedx.core.djangoapps.programs.utils import ( get_program_and_course_data, get_program_urls, ) +from openedx.features.enterprise_support.api import get_enterprise_course_enrollments, enterprise_is_enabled if TYPE_CHECKING: from django.http import HttpRequest, HttpResponse from django.contrib.auth.models import AnonymousUser, User # pylint: disable=imported-auth-user from django.contrib.sites.models import Site + from django.db.models.query import QuerySet + from common.djangoapps.student.models import CourseEnrollment logger = logging.getLogger(__name__) @@ -86,7 +89,7 @@ class Programs(APIView): """ user: "AnonymousUser | User" = request.user - enrollments = self._get_enterprise_course_enrollments(enterprise_uuid, user) + enrollments = list(self._get_enterprise_course_enrollments(enterprise_uuid, user)) # return empty reponse if no enterprise enrollments exists for a user if not enrollments: return Response([]) @@ -170,26 +173,22 @@ class Programs(APIView): return programs + @enterprise_is_enabled(otherwise=EmptyQuerySet) def _get_enterprise_course_enrollments( self, enterprise_uuid: str, user: "AnonymousUser | User" - ) -> list[CourseEnrollment]: + ) -> "QuerySet[CourseEnrollment]": """ Return only enterprise enrollments for a user. """ - enterprise_enrollment_course_ids = list( - EnterpriseCourseEnrollment.objects.filter( - enterprise_customer_user__user_id=user.id, - enterprise_customer_user__enterprise_customer__uuid=enterprise_uuid, - ).values_list("course_id", flat=True) + enterprise_enrollment_course_ids = ( + get_enterprise_course_enrollments(user) + .filter(enterprise_customer_user__enterprise_customer__uuid=enterprise_uuid) + .values_list("course_id", flat=True) ) - course_enrollments = ( - CourseEnrollment.enrollments_for_user(user) - .filter(course_id__in=enterprise_enrollment_course_ids) - .select_related("course") - ) + course_enrollments = get_course_enrollments(user, True, list(enterprise_enrollment_course_ids)) - return list(course_enrollments) + return course_enrollments class ProgramProgressDetailView(APIView):