Merge pull request #15657 from open-craft/haikuginger/enterprise-from-pipeline
[ENT-491] Check for Enterprise-linked TPA pipeline when asked for Enterprise on request
This commit is contained in:
@@ -22,6 +22,8 @@ from openedx.core.djangoapps.catalog.models import CatalogIntegration
|
||||
from openedx.core.djangoapps.catalog.utils import create_catalog_api_client
|
||||
from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers
|
||||
from openedx.core.lib.token_utils import JwtBuilder
|
||||
from third_party_auth.pipeline import get as get_partial_pipeline
|
||||
from third_party_auth.provider import Registry
|
||||
|
||||
try:
|
||||
from enterprise import utils as enterprise_utils
|
||||
@@ -241,13 +243,22 @@ def enterprise_customer_for_request(request, tpa_hint=None):
|
||||
|
||||
ec = None
|
||||
|
||||
running_pipeline = get_partial_pipeline(request)
|
||||
if running_pipeline:
|
||||
# Determine if the user is in the middle of a third-party auth pipeline,
|
||||
# and set the tpa_hint parameter to match if so.
|
||||
tpa_hint = Registry.get_from_pipeline(running_pipeline).provider_id
|
||||
|
||||
if tpa_hint:
|
||||
# If we have a third-party auth provider, get the linked enterprise customer.
|
||||
try:
|
||||
ec = EnterpriseCustomer.objects.get(enterprise_customer_identity_provider__provider_id=tpa_hint)
|
||||
except EnterpriseCustomer.DoesNotExist:
|
||||
pass
|
||||
|
||||
ec_uuid = request.GET.get('enterprise_customer') or request.COOKIES.get(settings.ENTERPRISE_CUSTOMER_COOKIE_NAME)
|
||||
# If we haven't obtained an EnterpriseCustomer through the other methods, check the
|
||||
# session cookies and URL parameters for an explicitly-passed EnterpriseCustomer.
|
||||
if not ec and ec_uuid:
|
||||
try:
|
||||
ec = EnterpriseCustomer.objects.get(uuid=ec_uuid)
|
||||
|
||||
@@ -25,7 +25,8 @@ class TestEnterpriseApi(unittest.TestCase):
|
||||
|
||||
@override_settings(ENABLE_ENTERPRISE_INTEGRATION=True)
|
||||
@mock.patch('openedx.features.enterprise_support.api.EnterpriseCustomer')
|
||||
def test_enterprise_customer_for_request(self, ec_class_mock):
|
||||
@mock.patch('openedx.features.enterprise_support.api.get_partial_pipeline')
|
||||
def test_enterprise_customer_for_request(self, pipeline_mock, ec_class_mock):
|
||||
"""
|
||||
Test that the correct EnterpriseCustomer, if any, is returned.
|
||||
"""
|
||||
@@ -43,6 +44,8 @@ class TestEnterpriseApi(unittest.TestCase):
|
||||
ec_class_mock.DoesNotExist = Exception
|
||||
ec_class_mock.objects.get.side_effect = get_ec_mock
|
||||
|
||||
pipeline_mock.return_value = None
|
||||
|
||||
request = mock.MagicMock()
|
||||
request.GET.get.return_value = 'real-uuid'
|
||||
self.assertEqual(enterprise_customer_for_request(request), 'this-is-actually-an-enterprise-customer')
|
||||
@@ -58,6 +61,36 @@ class TestEnterpriseApi(unittest.TestCase):
|
||||
self.assertEqual(enterprise_customer_for_request(request, tpa_hint='fake-provider-id'), None)
|
||||
self.assertEqual(enterprise_customer_for_request(request, tpa_hint=None), None)
|
||||
|
||||
@override_settings(ENABLE_ENTERPRISE_INTEGRATION=True)
|
||||
@mock.patch('openedx.features.enterprise_support.api.EnterpriseCustomer')
|
||||
@mock.patch('openedx.features.enterprise_support.api.Registry')
|
||||
@mock.patch('openedx.features.enterprise_support.api.get_partial_pipeline')
|
||||
def test_get_enterprise_customer_for_request_from_pipeline(self, pipeline_mock, registry_mock, ec_class_mock):
|
||||
"""
|
||||
Test that the correct EnterpriseCustomer, if any, is returned when
|
||||
the user is in the middle of a third-party auth pipeline.
|
||||
"""
|
||||
def get_ec_mock(**kwargs):
|
||||
by_provider_id_kw = 'enterprise_customer_identity_provider__provider_id'
|
||||
provider_id = kwargs.get(by_provider_id_kw, '')
|
||||
uuid = kwargs.get('uuid', '')
|
||||
if uuid == 'real-uuid' or provider_id == 'real-provider-id':
|
||||
# Only return the good value if we get the parameter we expect.
|
||||
return 'this-is-actually-an-enterprise-customer'
|
||||
|
||||
ec_class_mock.DoesNotExist = Exception
|
||||
ec_class_mock.objects.get.side_effect = get_ec_mock
|
||||
|
||||
# Truthy value from the pipeline getter to imitate a running pipeline
|
||||
pipeline_mock.return_value = {"fake_pipeline": "sofake"}
|
||||
|
||||
provider_mock = registry_mock.get_from_pipeline.return_value
|
||||
provider_mock.provider_id = 'real-provider-id'
|
||||
|
||||
request = mock.MagicMock()
|
||||
|
||||
self.assertEqual(enterprise_customer_for_request(request), 'this-is-actually-an-enterprise-customer')
|
||||
|
||||
def check_data_sharing_consent(self, consent_required=False, consent_url=None):
|
||||
"""
|
||||
Used to test the data_sharing_consent_required view decorator.
|
||||
|
||||
Reference in New Issue
Block a user