diff --git a/lms/djangoapps/mobile_api/users/tests.py b/lms/djangoapps/mobile_api/users/tests.py index 2cfefaa058..7c4b3e437d 100644 --- a/lms/djangoapps/mobile_api/users/tests.py +++ b/lms/djangoapps/mobile_api/users/tests.py @@ -285,6 +285,35 @@ class TestUserEnrollmentApi(UrlResetMixin, MobileAPITestCase, MobileAuthUserTest for entry in courses: assert entry['course']['org'] == 'edX' + @ddt.data(API_V05, API_V1, API_V2) + @patch('lms.djangoapps.mobile_api.users.views.get_current_site_orgs', return_value=['edX']) + def test_filter_by_current_site_orgs(self, api_version, get_current_site_orgs_mock): + self.login() + + # Create list of courses with various organizations + courses = [ + CourseFactory.create(org='edX', mobile_available=True), + CourseFactory.create(org='edX', mobile_available=True), + CourseFactory.create(org='edX', mobile_available=True, visible_to_staff_only=True), + CourseFactory.create(org='Proversity.org', mobile_available=True), + CourseFactory.create(org='MITx', mobile_available=True), + CourseFactory.create(org='HarvardX', mobile_available=True), + ] + + # Enroll in all the courses + for course in courses: + self.enroll(course.id) + + response = self.api_response(api_version=api_version) + courses = response.data['enrollments'] if api_version == API_V2 else response.data + + # Test for 3 expected courses + self.assertEqual(len(courses), 3) + + # Verify only edX courses are returned + for entry in courses: + self.assertEqual(entry['course']['org'], 'edX') + def create_enrollment(self, expired): """ Create an enrollment diff --git a/lms/djangoapps/mobile_api/users/views.py b/lms/djangoapps/mobile_api/users/views.py index c86f3add9d..324db83a37 100644 --- a/lms/djangoapps/mobile_api/users/views.py +++ b/lms/djangoapps/mobile_api/users/views.py @@ -40,6 +40,7 @@ from lms.djangoapps.courseware.models import StudentModule from lms.djangoapps.courseware.views.index import save_positions_recursively_up from lms.djangoapps.mobile_api.models import MobileConfig from lms.djangoapps.mobile_api.utils import API_V1, API_V05, API_V2, API_V3, API_V4 +from openedx.core.djangoapps.site_configuration.helpers import get_current_site_orgs from openedx.features.course_duration_limits.access import check_course_expired from xmodule.modulestore.django import modulestore # lint-amnesty, pylint: disable=wrong-import-order from xmodule.modulestore.exceptions import ItemNotFoundError # lint-amnesty, pylint: disable=wrong-import-order @@ -350,6 +351,11 @@ class UserCourseEnrollmentsList(generics.ListAPIView): """ Check course org matches request org param or no param provided """ + current_orgs = get_current_site_orgs() + + if current_orgs and course_org not in current_orgs: + return False + return check_org is None or (check_org.lower() == course_org.lower()) def get_serializer_context(self):