diff --git a/openedx/features/enterprise_support/api.py b/openedx/features/enterprise_support/api.py index 293ab6294d..31b6d74cb6 100644 --- a/openedx/features/enterprise_support/api.py +++ b/openedx/features/enterprise_support/api.py @@ -22,6 +22,8 @@ from openedx.core.djangoapps.catalog.models import CatalogIntegration from openedx.core.djangoapps.catalog.utils import create_catalog_api_client from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers from openedx.core.lib.token_utils import JwtBuilder +from third_party_auth.pipeline import get as get_partial_pipeline +from third_party_auth.provider import Registry try: from enterprise import utils as enterprise_utils @@ -241,13 +243,22 @@ def enterprise_customer_for_request(request, tpa_hint=None): ec = None + running_pipeline = get_partial_pipeline(request) + if running_pipeline: + # Determine if the user is in the middle of a third-party auth pipeline, + # and set the tpa_hint parameter to match if so. + tpa_hint = Registry.get_from_pipeline(running_pipeline).provider_id + if tpa_hint: + # If we have a third-party auth provider, get the linked enterprise customer. try: ec = EnterpriseCustomer.objects.get(enterprise_customer_identity_provider__provider_id=tpa_hint) except EnterpriseCustomer.DoesNotExist: pass ec_uuid = request.GET.get('enterprise_customer') or request.COOKIES.get(settings.ENTERPRISE_CUSTOMER_COOKIE_NAME) + # If we haven't obtained an EnterpriseCustomer through the other methods, check the + # session cookies and URL parameters for an explicitly-passed EnterpriseCustomer. if not ec and ec_uuid: try: ec = EnterpriseCustomer.objects.get(uuid=ec_uuid) diff --git a/openedx/features/enterprise_support/tests/test_api.py b/openedx/features/enterprise_support/tests/test_api.py index fbe49f2231..8cffcb9e28 100644 --- a/openedx/features/enterprise_support/tests/test_api.py +++ b/openedx/features/enterprise_support/tests/test_api.py @@ -25,7 +25,8 @@ class TestEnterpriseApi(unittest.TestCase): @override_settings(ENABLE_ENTERPRISE_INTEGRATION=True) @mock.patch('openedx.features.enterprise_support.api.EnterpriseCustomer') - def test_enterprise_customer_for_request(self, ec_class_mock): + @mock.patch('openedx.features.enterprise_support.api.get_partial_pipeline') + def test_enterprise_customer_for_request(self, pipeline_mock, ec_class_mock): """ Test that the correct EnterpriseCustomer, if any, is returned. """ @@ -43,6 +44,8 @@ class TestEnterpriseApi(unittest.TestCase): ec_class_mock.DoesNotExist = Exception ec_class_mock.objects.get.side_effect = get_ec_mock + pipeline_mock.return_value = None + request = mock.MagicMock() request.GET.get.return_value = 'real-uuid' self.assertEqual(enterprise_customer_for_request(request), 'this-is-actually-an-enterprise-customer') @@ -58,6 +61,36 @@ class TestEnterpriseApi(unittest.TestCase): self.assertEqual(enterprise_customer_for_request(request, tpa_hint='fake-provider-id'), None) self.assertEqual(enterprise_customer_for_request(request, tpa_hint=None), None) + @override_settings(ENABLE_ENTERPRISE_INTEGRATION=True) + @mock.patch('openedx.features.enterprise_support.api.EnterpriseCustomer') + @mock.patch('openedx.features.enterprise_support.api.Registry') + @mock.patch('openedx.features.enterprise_support.api.get_partial_pipeline') + def test_get_enterprise_customer_for_request_from_pipeline(self, pipeline_mock, registry_mock, ec_class_mock): + """ + Test that the correct EnterpriseCustomer, if any, is returned when + the user is in the middle of a third-party auth pipeline. + """ + def get_ec_mock(**kwargs): + by_provider_id_kw = 'enterprise_customer_identity_provider__provider_id' + provider_id = kwargs.get(by_provider_id_kw, '') + uuid = kwargs.get('uuid', '') + if uuid == 'real-uuid' or provider_id == 'real-provider-id': + # Only return the good value if we get the parameter we expect. + return 'this-is-actually-an-enterprise-customer' + + ec_class_mock.DoesNotExist = Exception + ec_class_mock.objects.get.side_effect = get_ec_mock + + # Truthy value from the pipeline getter to imitate a running pipeline + pipeline_mock.return_value = {"fake_pipeline": "sofake"} + + provider_mock = registry_mock.get_from_pipeline.return_value + provider_mock.provider_id = 'real-provider-id' + + request = mock.MagicMock() + + self.assertEqual(enterprise_customer_for_request(request), 'this-is-actually-an-enterprise-customer') + def check_data_sharing_consent(self, consent_required=False, consent_url=None): """ Used to test the data_sharing_consent_required view decorator.