diff --git a/lms/djangoapps/program_enrollments/api/v1/constants.py b/lms/djangoapps/program_enrollments/api/v1/constants.py index f336cfa924..7feed8dbe1 100644 --- a/lms/djangoapps/program_enrollments/api/v1/constants.py +++ b/lms/djangoapps/program_enrollments/api/v1/constants.py @@ -9,6 +9,8 @@ MAX_ENROLLMENT_RECORDS = 25 # The name of the key that identifies students for POST/PATCH requests REQUEST_STUDENT_KEY = 'student_key' +ENABLE_ENROLLMENT_RESET_FLAG = 'ENABLE_ENROLLMENT_RESET' + class BaseEnrollmentResponseStatuses(object): """ diff --git a/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py b/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py index d1785ff82d..9fadc8b701 100644 --- a/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py +++ b/lms/djangoapps/program_enrollments/api/v1/tests/test_views.py @@ -9,11 +9,14 @@ from uuid import UUID, uuid4 import ddt import mock +from django.conf import settings from django.contrib.auth.models import User from django.core.cache import cache +from django.test import override_settings from django.urls import reverse from freezegun import freeze_time from opaque_keys.edx.keys import CourseKey +from organizations.tests.factories import OrganizationFactory from pytz import UTC from rest_framework import status from rest_framework.test import APITestCase @@ -25,14 +28,18 @@ from course_modes.models import CourseMode from lms.djangoapps.certificates.models import CertificateStatuses from lms.djangoapps.certificates.tests.factories import GeneratedCertificateFactory from lms.djangoapps.courseware.tests.factories import GlobalStaffFactory, InstructorFactory -from lms.djangoapps.program_enrollments.api.v1.constants import MAX_ENROLLMENT_RECORDS, REQUEST_STUDENT_KEY +from lms.djangoapps.program_enrollments.api.v1.constants import ( + ENABLE_ENROLLMENT_RESET_FLAG, + MAX_ENROLLMENT_RECORDS, + REQUEST_STUDENT_KEY +) from lms.djangoapps.program_enrollments.api.v1.constants import CourseEnrollmentResponseStatuses as CourseStatuses from lms.djangoapps.program_enrollments.api.v1.constants import CourseRunProgressStatuses from lms.djangoapps.program_enrollments.api.v1.constants import ProgramEnrollmentResponseStatuses as ProgramStatuses from lms.djangoapps.program_enrollments.models import ProgramCourseEnrollment, ProgramEnrollment from lms.djangoapps.program_enrollments.tests.factories import ProgramCourseEnrollmentFactory, ProgramEnrollmentFactory from lms.djangoapps.program_enrollments.utils import ProviderDoesNotExistException -from openedx.core.djangoapps.catalog.cache import PROGRAM_CACHE_KEY_TPL +from openedx.core.djangoapps.catalog.cache import PROGRAM_CACHE_KEY_TPL, PROGRAMS_BY_ORGANIZATION_CACHE_KEY_TPL from openedx.core.djangoapps.catalog.tests.factories import CourseFactory, CourseRunFactory from openedx.core.djangoapps.catalog.tests.factories import OrganizationFactory as CatalogOrganizationFactory from openedx.core.djangoapps.catalog.tests.factories import ProgramFactory @@ -41,6 +48,7 @@ from openedx.core.djangoapps.content.course_overviews.tests.factories import Cou from openedx.core.djangolib.testing.utils import CacheIsolationMixin from student.roles import CourseStaffRole from student.tests.factories import CourseEnrollmentFactory, UserFactory +from third_party_auth.tests.factories import SAMLProviderConfigFactory from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase from xmodule.modulestore.tests.factories import CourseFactory as ModulestoreCourseFactory from xmodule.modulestore.tests.factories import ItemFactory @@ -69,6 +77,10 @@ class ProgramCacheTestCaseMixin(CacheIsolationMixin): def set_program_in_catalog_cache(program_uuid, program): cache.set(PROGRAM_CACHE_KEY_TPL.format(uuid=program_uuid), program, None) + @staticmethod + def set_org_in_catalog_cache(organization, program_uuids): + cache.set(PROGRAMS_BY_ORGANIZATION_CACHE_KEY_TPL.format(org_key=organization.short_name), program_uuids) + class ListViewTestMixin(ProgramCacheTestCaseMixin): """ @@ -1968,3 +1980,104 @@ class ProgramCourseGradeListTest(ProgramEnrollmentDataMixin, ListViewTestMixin, }, ] self.assertEqual(response.data['results'], expected_results) + + +class EnrollmentDataResetViewTests(ProgramCacheTestCaseMixin, APITestCase): + """ Tests endpoint for resetting enrollments in integration environments """ + + FEATURES_WITH_ENABLED = settings.FEATURES.copy() + FEATURES_WITH_ENABLED[ENABLE_ENROLLMENT_RESET_FLAG] = True + + reset_enrollments_cmd = 'reset_enrollment_data' + reset_users_cmd = 'remove_social_auth_users' + + def setUp(self): + super(EnrollmentDataResetViewTests, self).setUp() + self.start_cache_isolation() + + self.organization = OrganizationFactory(short_name='uox') + self.provider = SAMLProviderConfigFactory(organization=self.organization) + + self.global_staff = GlobalStaffFactory.create(username='global-staff', password='password') + self.client.login(username=self.global_staff.username, password='password') + + def request(self, organization): + return self.client.post( + reverse('programs_api:v1:reset_enrollment_data'), + {'organization': organization}, + format='json', + ) + + def tearDown(self): + self.end_cache_isolation() + super(EnrollmentDataResetViewTests, self).tearDown() + + @mock.patch('lms.djangoapps.program_enrollments.api.v1.views.call_command', autospec=True) + def test_feature_disabled_by_default(self, mock_call_command): + response = self.request(self.organization.short_name) + self.assertEqual(response.status_code, status.HTTP_501_NOT_IMPLEMENTED) + mock_call_command.assert_has_calls([]) + + @override_settings(FEATURES=FEATURES_WITH_ENABLED) + @mock.patch('lms.djangoapps.program_enrollments.api.v1.views.call_command', autospec=True) + def test_403_for_non_staff(self, mock_call_command): + student = UserFactory.create(username='student', password='password') + self.client.login(username=student.username, password='password') + response = self.request(self.organization.short_name) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + mock_call_command.assert_has_calls([]) + + @override_settings(FEATURES=FEATURES_WITH_ENABLED) + @mock.patch('lms.djangoapps.program_enrollments.api.v1.views.call_command', autospec=True) + def test_reset(self, mock_call_command): + programs = [str(uuid4()), str(uuid4())] + self.set_org_in_catalog_cache(self.organization, programs) + + response = self.request(self.organization.short_name) + self.assertEqual(response.status_code, status.HTTP_200_OK) + mock_call_command.assert_has_calls([ + mock.call(self.reset_users_cmd, self.provider.slug, force=True), + mock.call(self.reset_enrollments_cmd, ','.join(programs), force=True), + ]) + + @override_settings(FEATURES=FEATURES_WITH_ENABLED) + @mock.patch('lms.djangoapps.program_enrollments.api.v1.views.call_command', autospec=True) + def test_reset_without_idp(self, mock_call_command): + organization = OrganizationFactory() + programs = [str(uuid4()), str(uuid4())] + self.set_org_in_catalog_cache(organization, programs) + + response = self.request(organization.short_name) + self.assertEqual(response.status_code, status.HTTP_200_OK) + mock_call_command.assert_has_calls([ + mock.call(self.reset_enrollments_cmd, ','.join(programs), force=True), + ]) + + @override_settings(FEATURES=FEATURES_WITH_ENABLED) + @mock.patch('lms.djangoapps.program_enrollments.api.v1.views.call_command', autospec=True) + def test_organization_not_found(self, mock_call_command): + response = self.request('yyz') + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + mock_call_command.assert_has_calls([]) + + @override_settings(FEATURES=FEATURES_WITH_ENABLED) + @mock.patch('lms.djangoapps.program_enrollments.api.v1.views.call_command', autospec=True) + def test_no_programs_doesnt_break(self, mock_call_command): + programs = [] + self.set_org_in_catalog_cache(self.organization, programs) + + response = self.request(self.organization.short_name) + self.assertEqual(response.status_code, status.HTTP_200_OK) + mock_call_command.assert_has_calls([ + mock.call(self.reset_users_cmd, self.provider.slug, force=True), + ]) + + @override_settings(FEATURES=FEATURES_WITH_ENABLED) + @mock.patch('lms.djangoapps.program_enrollments.api.v1.views.call_command', autospec=True) + def test_missing_body_content(self, mock_call_command): + response = self.client.post( + reverse('programs_api:v1:reset_enrollment_data'), + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + mock_call_command.assert_has_calls([]) diff --git a/lms/djangoapps/program_enrollments/api/v1/urls.py b/lms/djangoapps/program_enrollments/api/v1/urls.py index d178227b29..cf8ca082ba 100644 --- a/lms/djangoapps/program_enrollments/api/v1/urls.py +++ b/lms/djangoapps/program_enrollments/api/v1/urls.py @@ -5,6 +5,7 @@ from django.conf.urls import url from lms.djangoapps.program_enrollments.api.v1.constants import PROGRAM_UUID_PATTERN from lms.djangoapps.program_enrollments.api.v1.views import ( + EnrollmentDataResetView, ProgramEnrollmentsView, ProgramCourseEnrollmentsView, ProgramCourseGradesView, @@ -54,4 +55,9 @@ urlpatterns = [ ProgramCourseEnrollmentOverviewView.as_view(), name="program_course_enrollments_overview" ), + url( + r'^integration-reset', + EnrollmentDataResetView.as_view(), + name="reset_enrollment_data", + ) ] diff --git a/lms/djangoapps/program_enrollments/api/v1/views.py b/lms/djangoapps/program_enrollments/api/v1/views.py index 1e94c932b1..d01d3d4450 100644 --- a/lms/djangoapps/program_enrollments/api/v1/views.py +++ b/lms/djangoapps/program_enrollments/api/v1/views.py @@ -7,66 +7,71 @@ from __future__ import absolute_import, unicode_literals import logging from functools import wraps -from django.http import Http404 +from ccx_keys.locator import CCXLocator +from django.conf import settings from django.core.exceptions import PermissionDenied +from django.core.management import call_command +from django.db import transaction +from django.http import Http404 from django.utils.functional import cached_property from edx_rest_framework_extensions import permissions from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication from edx_rest_framework_extensions.auth.session.authentication import SessionAuthenticationAllowInactiveUser from opaque_keys.edx.keys import CourseKey +from organizations.models import Organization from rest_framework import status -from rest_framework.views import APIView from rest_framework.exceptions import ValidationError from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response - +from rest_framework.views import APIView from six import text_type -from ccx_keys.locator import CCXLocator from course_modes.models import CourseMode from lms.djangoapps.certificates.api import get_certificate_for_user -from lms.djangoapps.grades.api import ( - CourseGradeFactory, - clear_prefetched_course_grades, - prefetch_course_grades, -) +from lms.djangoapps.grades.api import CourseGradeFactory, clear_prefetched_course_grades, prefetch_course_grades from lms.djangoapps.grades.rest_api.v1.utils import CourseEnrollmentPagination -from lms.djangoapps.program_enrollments.api.v1.constants import ( - CourseEnrollmentResponseStatuses, - MAX_ENROLLMENT_RECORDS, - ProgramEnrollmentResponseStatuses, -) from lms.djangoapps.program_enrollments.api.api import ( - get_due_dates, + get_course_run_status, get_course_run_url, - get_emails_enabled, - get_course_run_status + get_due_dates, + get_emails_enabled +) +from lms.djangoapps.program_enrollments.api.v1.constants import ( + ENABLE_ENROLLMENT_RESET_FLAG, + MAX_ENROLLMENT_RECORDS, + CourseEnrollmentResponseStatuses, + ProgramEnrollmentResponseStatuses, ) from lms.djangoapps.program_enrollments.api.v1.serializers import ( CourseRunOverviewListSerializer, ProgramCourseEnrollmentListSerializer, ProgramCourseEnrollmentRequestSerializer, - ProgramCourseGradeResult, ProgramCourseGradeErrorResult, + ProgramCourseGradeResult, ProgramCourseGradeResultSerializer, ProgramEnrollmentCreateRequestSerializer, ProgramEnrollmentListSerializer, - ProgramEnrollmentModifyRequestSerializer, + ProgramEnrollmentModifyRequestSerializer ) from lms.djangoapps.program_enrollments.models import ProgramCourseEnrollment, ProgramEnrollment -from lms.djangoapps.program_enrollments.utils import get_user_by_program_id, ProviderDoesNotExistException -from student.helpers import get_resume_urls_for_enrollments -from student.models import CourseEnrollment -from student.roles import CourseInstructorRole, CourseStaffRole, UserBasedRole +from lms.djangoapps.program_enrollments.utils import ( + ProviderDoesNotExistException, + get_provider_slug, + get_user_by_program_id +) from openedx.core.djangoapps.catalog.utils import ( course_run_keys_for_program, get_programs, get_programs_by_type, - normalize_program_type, + get_programs_for_organization, + normalize_program_type ) from openedx.core.djangoapps.content.course_overviews.models import CourseOverview from openedx.core.lib.api.authentication import OAuth2AuthenticationAllowInactiveUser from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin, PaginatedAPIView, verify_course_exists +from student.helpers import get_resume_urls_for_enrollments +from student.models import CourseEnrollment +from student.roles import CourseInstructorRole, CourseStaffRole, UserBasedRole from util.query import use_read_replica_if_available logger = logging.getLogger(__name__) @@ -1111,7 +1116,6 @@ class ProgramCourseGradesView( ------------------------------------------------------------------------------------ **Returns** - * 200: OK - Contains a paginated set of program courserun grades. * 204: No Content - No grades to return * 207: Mixed result - Contains mixed list of program courserun grades @@ -1156,7 +1160,6 @@ class ProgramCourseGradesView( ... ], } - """ authentication_classes = ( JwtAuthentication, @@ -1273,3 +1276,62 @@ class ProgramCourseGradesView( if any(result.is_error for result in grade_results): return status.HTTP_207_MULTI_STATUS return status.HTTP_200_OK + + +class EnrollmentDataResetView(APIView): + """ + Resets enrollments and users for a given organization and set of programs. + Note, this will remove ALL users from the input organization. + + Path: ``/api/program_enrollments/v1/integration-reset/`` + + Accepts: [POST] + + ------------------------------------------------------------------------------------ + POST + ------------------------------------------------------------------------------------ + + **Returns** + * 200: OK - Enrollments and users sucessfully deleted + * 400: Bad Requeset - Program does not match the requested organization + * 401: Unauthorized - The requesting user is not authenticated. + * 404: Not Found - A requested program does not exist. + + **Response** + """ + authentication_classes = ( + JwtAuthentication, + OAuth2AuthenticationAllowInactiveUser, + SessionAuthenticationAllowInactiveUser, + ) + permission_classes = (permissions.JWT_RESTRICTED_APPLICATION_OR_USER_ACCESS,) + + @transaction.atomic + def post(self, request): + """ + Reset enrollment and user data for organization + """ + if not settings.FEATURES.get(ENABLE_ENROLLMENT_RESET_FLAG): + return Response('reset not enabled on this environment', status.HTTP_501_NOT_IMPLEMENTED) + + try: + org_key = request.data['organization'] + except KeyError: + return Response("missing required body content 'organization'", status.HTTP_400_BAD_REQUEST) + + try: + organization = Organization.objects.get(short_name=org_key) + except Organization.DoesNotExist: + return Response('organization {} not found'.format(org_key), status.HTTP_404_NOT_FOUND) + + try: + idp_slug = get_provider_slug(organization) + call_command('remove_social_auth_users', idp_slug, force=True) + except ProviderDoesNotExistException: + pass + + programs = get_programs_for_organization(organization=organization.short_name) + if programs: + call_command('reset_enrollment_data', ','.join(programs), force=True) + + return Response('success') diff --git a/lms/djangoapps/program_enrollments/management/commands/tests/test_reset_enrollment_data.py b/lms/djangoapps/program_enrollments/management/commands/tests/test_reset_enrollment_data.py index c35254e01d..9bf25ee682 100644 --- a/lms/djangoapps/program_enrollments/management/commands/tests/test_reset_enrollment_data.py +++ b/lms/djangoapps/program_enrollments/management/commands/tests/test_reset_enrollment_data.py @@ -5,12 +5,12 @@ from __future__ import absolute_import import sys from contextlib import contextmanager -from six import StringIO from uuid import uuid4 from django.core.management import call_command from django.core.management.base import CommandError from django.test import TestCase +from six import StringIO from lms.djangoapps.program_enrollments.management.commands import reset_enrollment_data from lms.djangoapps.program_enrollments.models import ProgramCourseEnrollment, ProgramEnrollment diff --git a/lms/djangoapps/program_enrollments/tests/test_utils.py b/lms/djangoapps/program_enrollments/tests/test_utils.py index 7e98d990fe..59ea1756ad 100644 --- a/lms/djangoapps/program_enrollments/tests/test_utils.py +++ b/lms/djangoapps/program_enrollments/tests/test_utils.py @@ -17,8 +17,8 @@ from openedx.core.djangolib.testing.utils import CacheIsolationTestCase from program_enrollments.utils import ( OrganizationDoesNotExistException, ProgramDoesNotExistException, + ProviderConfigurationException, ProviderDoesNotExistException, - UserLookupException, get_user_by_program_id ) from student.tests.factories import UserFactory @@ -140,5 +140,5 @@ class GetPlatformUserTests(CacheIsolationTestCase): # create a second active config for the same organization SAMLProviderConfigFactory.create(organization=organization, slug='foox') - with pytest.raises(UserLookupException): + with pytest.raises(ProviderConfigurationException): get_user_by_program_id(self.external_user_id, self.program_uuid) diff --git a/lms/djangoapps/program_enrollments/utils.py b/lms/djangoapps/program_enrollments/utils.py index 9035fb3e03..c8a0421dd3 100644 --- a/lms/djangoapps/program_enrollments/utils.py +++ b/lms/djangoapps/program_enrollments/utils.py @@ -14,19 +14,19 @@ from third_party_auth.models import SAMLProviderConfig log = logging.getLogger(__name__) -class UserLookupException(Exception): +class ProgramDoesNotExistException(Exception): pass -class ProgramDoesNotExistException(UserLookupException): +class OrganizationDoesNotExistException(Exception): pass -class OrganizationDoesNotExistException(UserLookupException): +class ProviderDoesNotExistException(Exception): pass -class ProviderDoesNotExistException(UserLookupException): +class ProviderConfigurationException(Exception): pass @@ -81,8 +81,24 @@ def get_user_by_organization(external_user_id, organization): Raises: ProviderDoesNotExistException if there is no SAML provider configured for the related organization. """ + provider_slug = get_provider_slug(organization) try: - provider_slug = organization.samlproviderconfig_set.current_set().get().provider_id.strip('saml-') + social_auth_uid = '{0}:{1}'.format(provider_slug, external_user_id) + return UserSocialAuth.objects.get(uid=social_auth_uid).user + except UserSocialAuth.DoesNotExist: + return None + + +def get_provider_slug(organization): + """ + Returns slug for the currently configured saml provder on an Organization + + Raises: + ProviderDoesNotExistsException + ProviderConfigurationException + """ + try: + return organization.samlproviderconfig_set.current_set().get().provider_id.strip('saml-') except SAMLProviderConfig.DoesNotExist: log.error(u'No SAML provider found for organization id [%s]', organization.id) raise ProviderDoesNotExistException @@ -91,10 +107,4 @@ def get_user_by_organization(external_user_id, organization): u'Multiple active SAML configurations found for organization=%s. Expected one.', organization.short_name, ) - raise UserLookupException - - try: - social_auth_uid = '{0}:{1}'.format(provider_slug, external_user_id) - return UserSocialAuth.objects.get(uid=social_auth_uid).user - except UserSocialAuth.DoesNotExist: - return None + raise ProviderConfigurationException diff --git a/openedx/core/djangoapps/catalog/cache.py b/openedx/core/djangoapps/catalog/cache.py index 3f53d90a24..bb62b65308 100644 --- a/openedx/core/djangoapps/catalog/cache.py +++ b/openedx/core/djangoapps/catalog/cache.py @@ -18,3 +18,6 @@ COURSE_PROGRAMS_CACHE_KEY_TPL = 'course-programs-{course_run_id}' # because program_type values are likely to be shared between different sites # that live in the same environment). PROGRAMS_BY_TYPE_CACHE_KEY_TPL = 'programs-by-type-{site_id}-{program_type}' + +# Template used to create cache keys for organization to program uuids. +PROGRAMS_BY_ORGANIZATION_CACHE_KEY_TPL = 'organization-programs-{org_key}' diff --git a/openedx/core/djangoapps/catalog/management/commands/cache_programs.py b/openedx/core/djangoapps/catalog/management/commands/cache_programs.py index 3cb53eda30..cd8760d957 100644 --- a/openedx/core/djangoapps/catalog/management/commands/cache_programs.py +++ b/openedx/core/djangoapps/catalog/management/commands/cache_programs.py @@ -1,8 +1,9 @@ """"Management command to add program information to the cache.""" from __future__ import absolute_import -from collections import defaultdict + import logging import sys +from collections import defaultdict from django.contrib.auth import get_user_model from django.contrib.sites.models import Site @@ -14,15 +15,16 @@ from openedx.core.djangoapps.catalog.cache import ( COURSE_PROGRAMS_CACHE_KEY_TPL, PATHWAY_CACHE_KEY_TPL, PROGRAM_CACHE_KEY_TPL, + PROGRAMS_BY_ORGANIZATION_CACHE_KEY_TPL, PROGRAMS_BY_TYPE_CACHE_KEY_TPL, SITE_PATHWAY_IDS_CACHE_KEY_TPL, - SITE_PROGRAM_UUIDS_CACHE_KEY_TPL, + SITE_PROGRAM_UUIDS_CACHE_KEY_TPL ) from openedx.core.djangoapps.catalog.models import CatalogIntegration from openedx.core.djangoapps.catalog.utils import ( - create_catalog_api_client, course_run_keys_for_program, - normalize_program_type, + create_catalog_api_client, + normalize_program_type ) logger = logging.getLogger(__name__) @@ -59,6 +61,7 @@ class Command(BaseCommand): pathways = {} courses = {} programs_by_type = {} + organizations = {} for site in Site.objects.all(): site_config = getattr(site, 'configuration', None) if site_config is None or not site_config.get_value('COURSE_CATALOG_API_URL'): @@ -86,6 +89,7 @@ class Command(BaseCommand): pathways.update(new_pathways) courses.update(self.get_courses(new_programs)) programs_by_type.update(self.get_programs_by_type(site, new_programs)) + organizations.update(self.get_programs_by_organization(new_programs)) logger.info(u'Caching UUIDs for {total} programs for site {site_name}.'.format( total=len(uuids), @@ -112,6 +116,9 @@ class Command(BaseCommand): logger.info(text_type('Caching program UUIDs by {} program types.'.format(len(programs_by_type)))) cache.set_many(programs_by_type, None) + logger.info(u'Caching programs uuids for {} organizations'.format(len(organizations))) + cache.set_many(organizations, None) + if failure: sys.exit(1) @@ -236,3 +243,14 @@ class Command(BaseCommand): cache_key = PROGRAMS_BY_TYPE_CACHE_KEY_TPL.format(site_id=site.id, program_type=program_type) programs_by_type[cache_key].append(program['uuid']) return programs_by_type + + def get_programs_by_organization(self, programs): + """ + Returns a dictionary mapping organization keys to lists of program uuids authored by that org + """ + organizations = defaultdict(list) + for program in programs.values(): + for org in program['authoring_organizations']: + org_cache_key = PROGRAMS_BY_ORGANIZATION_CACHE_KEY_TPL.format(org_key=org['key']) + organizations[org_cache_key].append(program['uuid']) + return organizations diff --git a/openedx/core/djangoapps/catalog/management/commands/tests/test_cache_programs.py b/openedx/core/djangoapps/catalog/management/commands/tests/test_cache_programs.py index 1773934bb6..2994d9ca64 100644 --- a/openedx/core/djangoapps/catalog/management/commands/tests/test_cache_programs.py +++ b/openedx/core/djangoapps/catalog/management/commands/tests/test_cache_programs.py @@ -10,6 +10,7 @@ from django.core.management import call_command from openedx.core.djangoapps.catalog.cache import ( COURSE_PROGRAMS_CACHE_KEY_TPL, + PROGRAMS_BY_ORGANIZATION_CACHE_KEY_TPL, PATHWAY_CACHE_KEY_TPL, PROGRAM_CACHE_KEY_TPL, PROGRAMS_BY_TYPE_CACHE_KEY_TPL, @@ -17,7 +18,7 @@ from openedx.core.djangoapps.catalog.cache import ( SITE_PROGRAM_UUIDS_CACHE_KEY_TPL ) from openedx.core.djangoapps.catalog.utils import normalize_program_type -from openedx.core.djangoapps.catalog.tests.factories import PathwayFactory, ProgramFactory +from openedx.core.djangoapps.catalog.tests.factories import OrganizationFactory, PathwayFactory, ProgramFactory from openedx.core.djangoapps.catalog.tests.mixins import CatalogIntegrationMixin from openedx.core.djangoapps.site_configuration.tests.mixins import SiteMixin from openedx.core.djangolib.testing.utils import CacheIsolationTestCase, skip_unless_lms @@ -57,6 +58,8 @@ class TestCachePrograms(CatalogIntegrationMixin, CacheIsolationTestCase, SiteMix self.programs[0]['curricula'][0]['programs'].append(self.child_program) self.programs.append(self.child_program) + self.programs[0]['authoring_organizations'] = OrganizationFactory.create_batch(2) + for pathway in self.pathways: self.programs += pathway['programs'] @@ -193,7 +196,7 @@ class TestCachePrograms(CatalogIntegrationMixin, CacheIsolationTestCase, SiteMix self.assertIn(self.child_program['uuid'], cache.get(course_run_cache_key)) # for each program, assert that the program's UUID is in a cached list of - # program UUIDS by program type + # program UUIDS by program type and a cached list of UUIDs by authoring organization for program in self.programs: program_type = normalize_program_type(program.get('type', 'None')) program_type_cache_key = PROGRAMS_BY_TYPE_CACHE_KEY_TPL.format( @@ -201,6 +204,12 @@ class TestCachePrograms(CatalogIntegrationMixin, CacheIsolationTestCase, SiteMix ) self.assertIn(program['uuid'], cache.get(program_type_cache_key)) + for organization in program['authoring_organizations']: + organization_cache_key = PROGRAMS_BY_ORGANIZATION_CACHE_KEY_TPL.format( + org_key=organization['key'] + ) + self.assertIn(program['uuid'], cache.get(organization_cache_key)) + def test_handle_pathways(self): """ Verify that the command requests and caches credit pathways diff --git a/openedx/core/djangoapps/catalog/utils.py b/openedx/core/djangoapps/catalog/utils.py index c6e5c2de4f..ce5299fcc2 100644 --- a/openedx/core/djangoapps/catalog/utils.py +++ b/openedx/core/djangoapps/catalog/utils.py @@ -18,6 +18,7 @@ from entitlements.utils import is_course_run_entitlement_fulfillable from openedx.core.constants import COURSE_PUBLISHED from openedx.core.djangoapps.catalog.cache import ( COURSE_PROGRAMS_CACHE_KEY_TPL, + PROGRAMS_BY_ORGANIZATION_CACHE_KEY_TPL, PATHWAY_CACHE_KEY_TPL, PROGRAM_CACHE_KEY_TPL, PROGRAMS_BY_TYPE_CACHE_KEY_TPL, @@ -86,7 +87,7 @@ def check_catalog_integration_and_get_user(error_message_field): # pylint: disable=redefined-outer-name -def get_programs(site=None, uuid=None, uuids=None, course=None): +def get_programs(site=None, uuid=None, uuids=None, course=None, organization=None): """Read programs from the cache. The cache is populated by a management command, cache_programs. @@ -96,12 +97,13 @@ def get_programs(site=None, uuid=None, uuids=None, course=None): uuid (string): UUID identifying a specific program to read from the cache. uuids (list of string): UUIDs identifying a specific programs to read from the cache. course (string): course id identifying a specific course run to read from the cache. + organization (string): short name for specific organization to read from the cache. Returns: list of dict, representing programs. dict, if a specific program is requested. """ - if len([arg for arg in (site, uuid, uuids, course) if arg is not None]) != 1: + if len([arg for arg in (site, uuid, uuids, course, organization) if arg is not None]) != 1: raise TypeError('get_programs takes exactly one argument') if uuid: @@ -120,6 +122,10 @@ def get_programs(site=None, uuid=None, uuids=None, course=None): uuids = cache.get(SITE_PROGRAM_UUIDS_CACHE_KEY_TPL.format(domain=site.domain), []) if not uuids: logger.warning(u'Failed to get program UUIDs from the cache for site {}.'.format(site.domain)) + elif organization: + uuids = get_programs_for_organization(organization) + if not uuids: + return [] return get_programs_by_uuids(uuids) @@ -623,3 +629,10 @@ def _course_runs_from_container(container): def normalize_program_type(program_type): """ Function that normalizes a program type string for use in a cache key. """ return str(program_type).lower() + + +def get_programs_for_organization(organization): + """ + Retrieve list of program uuids authored by a given organization + """ + return cache.get(PROGRAMS_BY_ORGANIZATION_CACHE_KEY_TPL.format(org_key=organization))