diff --git a/common/djangoapps/enrollment/views.py b/common/djangoapps/enrollment/views.py index 1c41506a3e..8bd64cf177 100644 --- a/common/djangoapps/enrollment/views.py +++ b/common/djangoapps/enrollment/views.py @@ -30,8 +30,8 @@ from openedx.core.lib.api.permissions import ApiKeyHeaderPermission, ApiKeyHeade from openedx.core.lib.exceptions import CourseNotFoundError from openedx.core.lib.log_utils import audit_log from openedx.features.enterprise_support.api import ( - ConsentApiClient, - EnterpriseApiClient, + ConsentApiServiceClient, + EnterpriseApiServiceClient, EnterpriseApiException, enterprise_enabled ) @@ -598,8 +598,8 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): enterprise_course_consent = request.data.get('enterprise_course_consent') explicit_linked_enterprise = request.data.get('linked_enterprise_customer') if (enterprise_course_consent or explicit_linked_enterprise) and has_api_key_permissions and enterprise_enabled(): - enterprise_api_client = EnterpriseApiClient() - consent_client = ConsentApiClient() + enterprise_api_client = EnterpriseApiServiceClient() + consent_client = ConsentApiServiceClient() # We received an explicitly-linked EnterpriseCustomer for the enrollment if explicit_linked_enterprise is not None: try: diff --git a/openedx/features/enterprise_support/api.py b/openedx/features/enterprise_support/api.py index d2adab1430..ce03291252 100644 --- a/openedx/features/enterprise_support/api.py +++ b/openedx/features/enterprise_support/api.py @@ -1,11 +1,9 @@ """ APIs providing support for enterprise functionality. """ -import hashlib import logging from functools import wraps -import six from django.conf import settings from django.contrib.auth.models import User from django.core.cache import cache @@ -15,13 +13,11 @@ from django.template.loader import render_to_string from django.utils.http import urlencode from django.utils.translation import ugettext as _ from edx_rest_api_client.client import EdxRestApiClient -from requests.exceptions import ConnectionError, Timeout from slumber.exceptions import HttpClientError, HttpNotFoundError, HttpServerError, SlumberBaseException -from openedx.core.djangoapps.catalog.models import CatalogIntegration -from openedx.core.djangoapps.catalog.utils import create_catalog_api_client from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers from openedx.core.lib.token_utils import JwtBuilder +from openedx.features.enterprise_support.utils import get_cache_key from third_party_auth.pipeline import get as get_partial_pipeline from third_party_auth.provider import Registry @@ -30,6 +26,7 @@ try: except ImportError: pass + CONSENT_FAILED_PARAMETER = 'consent_failed' LOGGER = logging.getLogger("edx.enterprise_helpers") @@ -46,12 +43,12 @@ class ConsentApiClient(object): Class for producing an Enterprise Consent service API client """ - def __init__(self): + def __init__(self, user): """ - Initialize a consent service API client, authenticated using the Enterprise worker username. + Initialize an authenticated Consent service API client by using the + provided user. """ - self.user = User.objects.get(username=settings.ENTERPRISE_SERVICE_WORKER_USERNAME) - jwt = JwtBuilder(self.user).build_token([]) + jwt = JwtBuilder(user).build_token([]) url = configuration_helpers.get_value('ENTERPRISE_CONSENT_API_URL', settings.ENTERPRISE_CONSENT_API_URL) self.client = EdxRestApiClient( url, @@ -97,17 +94,38 @@ class ConsentApiClient(object): return response['consent_required'] +class EnterpriseServiceClientMixin(object): + """ + Class for initializing an Enterprise API clients with service user. + """ + + def __init__(self): + """ + Initialize an authenticated Enterprise API client by using the + Enterprise worker user by default. + """ + user = User.objects.get(username=settings.ENTERPRISE_SERVICE_WORKER_USERNAME) + super(EnterpriseServiceClientMixin, self).__init__(user) + + +class ConsentApiServiceClient(EnterpriseServiceClientMixin, ConsentApiClient): + """ + Class for producing an Enterprise Consent API client with service user. + """ + pass + + class EnterpriseApiClient(object): """ Class for producing an Enterprise service API client. """ - def __init__(self): + def __init__(self, user): """ - Initialize an Enterprise service API client, authenticated using the Enterprise worker username. + Initialize an authenticated Enterprise service API client by using the + provided user. """ - self.user = User.objects.get(username=settings.ENTERPRISE_SERVICE_WORKER_USERNAME) - jwt = JwtBuilder(self.user).build_token([]) + jwt = JwtBuilder(user).build_token([]) self.client = EdxRestApiClient( configuration_helpers.get_value('ENTERPRISE_API_URL', settings.ENTERPRISE_API_URL), jwt=jwt @@ -238,6 +256,30 @@ class EnterpriseApiClient(object): return response +class EnterpriseApiServiceClient(EnterpriseServiceClientMixin, EnterpriseApiClient): + """ + Class for producing an Enterprise service API client with service user. + """ + + def get_enterprise_customer(self, uuid): + """ + Fetch enterprise customer with enterprise service user and cache the + API response`. + """ + cache_key = get_cache_key( + resource='enterprise-customer', + username=settings.ENTERPRISE_SERVICE_WORKER_USERNAME, + ) + enterprise_customer = cache.get(cache_key) + if not enterprise_customer: + endpoint = getattr(self.client, 'enterprise-customer') + enterprise_customer = endpoint(uuid).get() + if enterprise_customer: + cache.set(cache_key, enterprise_customer, settings.ENTERPRISE_API_CACHE_TIMEOUT) + + return enterprise_customer + + def data_sharing_consent_required(view_func): """ Decorator which makes a view method redirect to the Data Sharing Consent form if: @@ -291,7 +333,7 @@ def enterprise_customer_for_request(request): if not enterprise_enabled(): return None - ec = None + enterprise_customer = None sso_provider_id = request.GET.get('tpa_hint') running_pipeline = get_partial_pipeline(request) @@ -308,34 +350,34 @@ def enterprise_customer_for_request(request): # Check if there's an Enterprise Customer such that the linked SSO provider # has an ID equal to the ID we got from the running pipeline or from the # request tpa_hint URL parameter. - ec_uuid = EnterpriseCustomer.objects.get( + enterprise_customer_uuid = EnterpriseCustomer.objects.get( enterprise_customer_identity_provider__provider_id=sso_provider_id ).uuid except EnterpriseCustomer.DoesNotExist: # If there is not an EnterpriseCustomer linked to this SSO provider, set # the UUID variable to be null. - ec_uuid = None + enterprise_customer_uuid = None else: # Check if we got an Enterprise UUID passed directly as either a query # parameter, or as a value in the Enterprise cookie. - ec_uuid = request.GET.get('enterprise_customer') or request.COOKIES.get(settings.ENTERPRISE_CUSTOMER_COOKIE_NAME) + enterprise_customer_uuid = request.GET.get('enterprise_customer') or request.COOKIES.get( + settings.ENTERPRISE_CUSTOMER_COOKIE_NAME + ) - if not ec_uuid and request.user.is_authenticated(): - # If there's no way to get an Enterprise UUID for the request, check to see - # if there's already an Enterprise attached to the requesting user on the backend. - learner_data = get_enterprise_learner_data(request.site, request.user) - if learner_data: - ec_uuid = learner_data[0]['enterprise_customer']['uuid'] - if ec_uuid: + if enterprise_customer_uuid: # If we were able to obtain an EnterpriseCustomer UUID, go ahead # and use it to attempt to retrieve EnterpriseCustomer details # from the EnterpriseCustomer API. - try: - ec = EnterpriseApiClient().get_enterprise_customer(ec_uuid) - except HttpNotFoundError: - ec = None + enterprise_api_client = EnterpriseApiServiceClient() + if request.user.is_authenticated(): + enterprise_api_client = EnterpriseApiClient(user=request.user) - return ec + try: + enterprise_customer = enterprise_api_client.get_enterprise_customer(enterprise_customer_uuid) + except HttpNotFoundError: + enterprise_customer = None + + return enterprise_customer def consent_needed_for_course(request, user, course_id, enrollment_exists=False): @@ -355,7 +397,7 @@ def consent_needed_for_course(request, user, course_id, enrollment_exists=False) if not enterprise_learner_details: consent_needed = False else: - client = ConsentApiClient() + client = ConsentApiClient(user=request.user) consent_needed = any( client.consent_required( username=user.username, @@ -423,7 +465,7 @@ def get_enterprise_learner_data(site, user): if not enterprise_enabled(): return None - enterprise_learner_data = EnterpriseApiClient().fetch_enterprise_learner_data(site=site, user=user) + enterprise_learner_data = EnterpriseApiClient(user=user).fetch_enterprise_learner_data(site=site, user=user) if enterprise_learner_data: return enterprise_learner_data['results'] @@ -458,7 +500,7 @@ def get_dashboard_consent_notification(request, user, course_enrollments): enrollment = course_enrollment break - client = ConsentApiClient() + client = ConsentApiClient(user=request.user) consent_needed = client.consent_required( enterprise_customer_uuid=enterprise_customer['uuid'], username=user.username, diff --git a/openedx/features/enterprise_support/tests/test_api.py b/openedx/features/enterprise_support/tests/test_api.py index 7d46e0c4f8..a1e6700efb 100644 --- a/openedx/features/enterprise_support/tests/test_api.py +++ b/openedx/features/enterprise_support/tests/test_api.py @@ -8,21 +8,26 @@ import ddt import httpretty import mock from django.conf import settings +from django.contrib.auth.models import User +from django.core.cache import cache from django.core.urlresolvers import reverse from django.http import HttpResponseRedirect -from django.test import TestCase from django.test.utils import override_settings +from openedx.core.djangolib.testing.utils import CacheIsolationTestCase from openedx.features.enterprise_support.api import ( + ConsentApiClient, + ConsentApiServiceClient, consent_needed_for_course, data_sharing_consent_required, + EnterpriseApiClient, + EnterpriseApiServiceClient, enterprise_customer_for_request, - enterprise_enabled, get_dashboard_consent_notification, get_enterprise_consent_url, ) - from openedx.features.enterprise_support.tests.mixins.enterprise import EnterpriseServiceMockMixin +from openedx.features.enterprise_support.utils import get_cache_key from student.tests.factories import UserFactory @@ -38,27 +43,123 @@ class MockEnrollment(mock.MagicMock): @ddt.ddt @override_settings(ENABLE_ENTERPRISE_INTEGRATION=True) @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') -class TestEnterpriseApi(EnterpriseServiceMockMixin, TestCase): +class TestEnterpriseApi(EnterpriseServiceMockMixin, CacheIsolationTestCase): """ Test enterprise support APIs. """ + ENABLED_CACHES = ['default'] + @classmethod def setUpTestData(cls): - UserFactory.create( - username='enterprise_worker', + cls.user = UserFactory.create( + username=settings.ENTERPRISE_SERVICE_WORKER_USERNAME, email='ent_worker@example.com', password='password123', ) super(TestEnterpriseApi, cls).setUpTestData() + def _assert_api_service_client(self, api_client, mocked_jwt_builder): + """ + Verify that the provided api client uses the enterprise service user to generate + JWT token for auth. + """ + mocked_jwt_builder.return_value.build_token.return_value = 'test-token' + enterprise_service_user = User.objects.get(username=settings.ENTERPRISE_SERVICE_WORKER_USERNAME) + enterprise_api_service_client = api_client() + + mocked_jwt_builder.assert_called_once_with(enterprise_service_user) + # pylint: disable=protected-access + self.assertEqual(enterprise_api_service_client.client._store['session'].auth.token, 'test-token') + + def _assert_api_client_with_user(self, api_client, mocked_jwt_builder): + """ + Verify that the provided api client uses the expected user to generate + JWT token for auth. + """ + mocked_jwt_builder.return_value.build_token.return_value = 'test-token' + dummy_enterprise_user = UserFactory.create( + username='dummy-enterprise-user', + email='dummy-enterprise-user@example.com', + password='password123', + ) + enterprise_api_service_client = api_client(dummy_enterprise_user) + + mocked_jwt_builder.assert_called_once_with(dummy_enterprise_user) + # pylint: disable=protected-access + self.assertEqual(enterprise_api_service_client.client._store['session'].auth.token, 'test-token') + + def _assert_get_enterprise_customer(self, api_client): + """ + DRY method to verify caching for get enterprise customer method. + """ + dummy_enterprise_api_data = {'name': 'dummy-enterprise-customer', 'uuid': 'enterprise-uuid'} + cache_key = get_cache_key( + resource='enterprise-customer', + username=settings.ENTERPRISE_SERVICE_WORKER_USERNAME, + ) + self.mock_get_enterprise_customer('enterprise-uuid', dummy_enterprise_api_data, 200) + self._assert_get_enterprise_customer_with_cache(api_client, dummy_enterprise_api_data, cache_key) + + def _assert_get_enterprise_customer_with_cache(self, api_client, enterprise_customer_data, cache_key): + """ + DRY method to verify that get enterprise customer response is cached. + """ + cached_enterprise_customer = cache.get(cache_key) + self.assertIsNone(cached_enterprise_customer) + + enterprise_customer = api_client.get_enterprise_customer(enterprise_customer_data['uuid']) + self.assertEqual(enterprise_customer_data, enterprise_customer) + cached_enterprise_customer = cache.get(cache_key) + self.assertEqual(cached_enterprise_customer, enterprise_customer) + + @httpretty.activate + @mock.patch('openedx.features.enterprise_support.api.JwtBuilder') + def test_enterprise_api_client_with_service_user(self, mock_jwt_builder): + """ + Verify that enterprise API service client uses enterprise service user + by default to authenticate and access enterprise API. + """ + self._assert_api_service_client(EnterpriseApiServiceClient, mock_jwt_builder) + + # Now verify that enterprise customer data is cached properly for + # the enterprise api client. + enterprise_api_client = EnterpriseApiServiceClient() + self._assert_get_enterprise_customer(enterprise_api_client) + + @httpretty.activate + @mock.patch('openedx.features.enterprise_support.api.JwtBuilder') + def test_enterprise_api_client_with_user(self, mock_jwt_builder): + """ + Verify that enterprise API client uses the provided user to + authenticate and access enterprise API. + """ + self._assert_api_client_with_user(EnterpriseApiClient, mock_jwt_builder) + + @httpretty.activate + @mock.patch('openedx.features.enterprise_support.api.JwtBuilder') + def test_enterprise_consent_api_client_with_service_user(self, mock_jwt_builder): + """ + Verify that enterprise API consent service client uses enterprise + service user by default to authenticate and access enterprise API. + """ + self._assert_api_service_client(ConsentApiServiceClient, mock_jwt_builder) + + @httpretty.activate + @mock.patch('openedx.features.enterprise_support.api.JwtBuilder') + def test_enterprise_consent_api_client_with_user(self, mock_jwt_builder): + """ + Verify that enterprise API consent service client uses the provided + user to authenticate and access enterprise API. + """ + self._assert_api_client_with_user(ConsentApiClient, mock_jwt_builder) + @httpretty.activate - @override_settings(ENTERPRISE_SERVICE_WORKER_USERNAME='enterprise_worker') def test_consent_needed_for_course(self): user = mock.MagicMock( username='janedoe', is_authenticated=lambda: True, ) - request = mock.MagicMock(session={}) + request = mock.MagicMock(session={}, user=user) self.mock_enterprise_learner_api() self.mock_consent_missing(user.username, 'fake-course', 'cf246b88-d5f6-4908-a522-fc307e0b0c59') self.assertTrue(consent_needed_for_course(request, user, 'fake-course')) @@ -70,67 +171,61 @@ class TestEnterpriseApi(EnterpriseServiceMockMixin, TestCase): self.assertFalse(consent_needed_for_course(request, user, 'fake-course')) @httpretty.activate - @mock.patch('openedx.features.enterprise_support.api.get_enterprise_learner_data') @mock.patch('openedx.features.enterprise_support.api.EnterpriseCustomer') @mock.patch('openedx.features.enterprise_support.api.get_partial_pipeline') @mock.patch('openedx.features.enterprise_support.api.Registry') - @override_settings(ENTERPRISE_SERVICE_WORKER_USERNAME='enterprise_worker') def test_enterprise_customer_for_request( self, mock_registry, mock_partial, - mock_ec_model, - mock_get_el_data + mock_enterprise_customer_model, ): - def mock_get_ec(**kwargs): + def mock_get_enterprise_customer(**kwargs): uuid = kwargs.get('enterprise_customer_identity_provider__provider_id') if uuid: - return mock.MagicMock(uuid=uuid) + return mock.MagicMock(uuid=uuid, user=self.user) raise Exception - mock_ec_model.objects.get.side_effect = mock_get_ec - mock_ec_model.DoesNotExist = Exception - + dummy_request = mock.MagicMock(session={}, user=self.user) + mock_enterprise_customer_model.objects.get.side_effect = mock_get_enterprise_customer + mock_enterprise_customer_model.DoesNotExist = Exception mock_partial.return_value = True mock_registry.get_from_pipeline.return_value.provider_id = 'real-ent-uuid' - self.mock_get_enterprise_customer('real-ent-uuid', {"real": "enterprisecustomer"}, 200) - - ec = enterprise_customer_for_request(mock.MagicMock()) - - self.assertEqual(ec, {"real": "enterprisecustomer"}) + # Verify that the method `enterprise_customer_for_request` returns + # expected enterprise customer against the requesting user. + self.mock_get_enterprise_customer('real-ent-uuid', {'real': 'enterprisecustomer'}, 200) + enterprise_customer = enterprise_customer_for_request(dummy_request) + self.assertEqual(enterprise_customer, {'real': 'enterprisecustomer'}) httpretty.reset() - self.mock_get_enterprise_customer('real-ent-uuid', {"detail": "Not found."}, 404) + # Verify that the method `enterprise_customer_for_request` returns no + # enterprise customer if the enterprise customer API throws 404. + self.mock_get_enterprise_customer('real-ent-uuid', {'detail': 'Not found.'}, 404) + enterprise_customer = enterprise_customer_for_request(dummy_request) + self.assertIsNone(enterprise_customer) - ec = enterprise_customer_for_request(mock.MagicMock()) - - self.assertIsNone(ec) + httpretty.reset() + # Verify that the method `enterprise_customer_for_request` returns + # expected enterprise customer against the requesting user even if + # the third-party auth pipeline has no `provider_id`. mock_registry.get_from_pipeline.return_value.provider_id = None - - httpretty.reset() - - self.mock_get_enterprise_customer('real-ent-uuid', {"real": "enterprisecustomer"}, 200) - - ec = enterprise_customer_for_request(mock.MagicMock(GET={"enterprise_customer": 'real-ent-uuid'})) - - self.assertEqual(ec, {"real": "enterprisecustomer"}) - - ec = enterprise_customer_for_request( - mock.MagicMock(GET={}, COOKIES={settings.ENTERPRISE_CUSTOMER_COOKIE_NAME: 'real-ent-uuid'}) + self.mock_get_enterprise_customer('real-ent-uuid', {'real': 'enterprisecustomer'}, 200) + enterprise_customer = enterprise_customer_for_request( + mock.MagicMock(GET={'enterprise_customer': 'real-ent-uuid'}, user=self.user) ) + self.assertEqual(enterprise_customer, {'real': 'enterprisecustomer'}) - self.assertEqual(ec, {"real": "enterprisecustomer"}) - - mock_get_el_data.return_value = [{'enterprise_customer': {'uuid': 'real-ent-uuid'}}] - - ec = enterprise_customer_for_request( - mock.MagicMock(GET={}, COOKIES={}, user=mock.MagicMock(is_authenticated=lambda: True), site=1) + # Verify that the method `enterprise_customer_for_request` returns + # expected enterprise customer against the requesting user even if + # the third-party auth pipeline has no `provider_id` but there is + # enterprise customer UUID in the cookie. + enterprise_customer = enterprise_customer_for_request( + mock.MagicMock(GET={}, COOKIES={settings.ENTERPRISE_CUSTOMER_COOKIE_NAME: 'real-ent-uuid'}, user=self.user) ) - - self.assertEqual(ec, {"real": "enterprisecustomer"}) + self.assertEqual(enterprise_customer, {'real': 'enterprisecustomer'}) def check_data_sharing_consent(self, consent_required=False, consent_url=None): """ diff --git a/openedx/features/enterprise_support/utils.py b/openedx/features/enterprise_support/utils.py new file mode 100644 index 0000000000..614f027d9a --- /dev/null +++ b/openedx/features/enterprise_support/utils.py @@ -0,0 +1,29 @@ +from __future__ import unicode_literals + +import hashlib + +import six + + +def get_cache_key(**kwargs): + """ + Get MD5 encoded cache key for given arguments. + + Here is the format of key before MD5 encryption. + key1:value1__key2:value2 ... + + Example: + >>> get_cache_key(site_domain="example.com", resource="enterprise-learner") + # Here is key format for above call + # "site_domain:example.com__resource:enterprise-learner" + a54349175618ff1659dee0978e3149ca + + Arguments: + **kwargs: Key word arguments that need to be present in cache key. + + Returns: + An MD5 encoded key uniquely identified by the key word arguments. + """ + key = '__'.join(['{}:{}'.format(item, value) for item, value in six.iteritems(kwargs)]) + + return hashlib.md5(key).hexdigest()