diff --git a/common/djangoapps/third_party_auth/saml.py b/common/djangoapps/third_party_auth/saml.py index 6cb31ab42e..0061a6dea3 100644 --- a/common/djangoapps/third_party_auth/saml.py +++ b/common/djangoapps/third_party_auth/saml.py @@ -11,7 +11,6 @@ from django.contrib.sites.models import Site from django.http import Http404 from django.utils.functional import cached_property from django_countries import countries -from enterprise.models import EnterpriseCustomerIdentityProvider, EnterpriseCustomerUser, PendingEnterpriseCustomerUser from onelogin.saml2.settings import OneLogin_Saml2_Settings from six import text_type from social_core.backends.saml import OID_EDU_PERSON_ENTITLEMENT, SAMLAuth, SAMLIdentityProvider @@ -134,28 +133,9 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method """ Override of SAMLAuth.disconnect to unlink the learner from enterprise customer if associated. """ - from . import pipeline, provider - running_pipeline = pipeline.get(self.strategy.request) - provider_id = provider.Registry.get_from_pipeline(running_pipeline).provider_id - try: - user_email = kwargs.get('user').email - except AttributeError: - user_email = None - - try: - enterprise_customer_idp = EnterpriseCustomerIdentityProvider.objects.get(provider_id=provider_id) - except EnterpriseCustomerIdentityProvider.DoesNotExist: - enterprise_customer_idp = None - - if enterprise_customer_idp and user_email: - try: - # Unlink user email from Enterprise Customer. - EnterpriseCustomerUser.objects.unlink_user( - enterprise_customer=enterprise_customer_idp.enterprise_customer, user_email=user_email - ) - except (EnterpriseCustomerUser.DoesNotExist, PendingEnterpriseCustomerUser.DoesNotExist): - pass - + from openedx.features.enterprise_support.api import unlink_enterprise_user_from_idp + user = kwargs.get('user', None) + unlink_enterprise_user_from_idp(self.strategy.request, user, self.name) return super(SAMLAuthBackend, self).disconnect(*args, **kwargs) def _check_entitlements(self, idp, attributes): diff --git a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py index 6da7319361..df4a195017 100644 --- a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py +++ b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py @@ -12,6 +12,7 @@ from unittest import skip import ddt import httpretty +from django.conf import settings from django.contrib import auth from freezegun import freeze_time from mock import MagicMock, patch @@ -158,7 +159,15 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin 'attributes': {u'urn:oid:0.9.2342.19200300.100.1.1': [u'myself']} } - def test_full_pipeline_succeeds_for_unlinking_testshib_account(self): + @patch('openedx.features.enterprise_support.api.enterprise_customer_for_request') + @patch('openedx.features.enterprise_support.api.get_enterprise_customer_for_learner') + @patch('openedx.core.djangoapps.user_api.accounts.settings_views.get_enterprise_customer_for_learner') + def test_full_pipeline_succeeds_for_unlinking_testshib_account( + self, + mock_get_enterprise_customer_for_learner_settings_view, + mock_get_enterprise_customer_for_learner, + mock_enterprise_customer_for_request, + ): # First, create, the request and strategy that store pipeline state, # configure the backend, and mock out wire traffic. @@ -185,6 +194,15 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin EnterpriseCustomerIdentityProvider.objects.get_or_create(enterprise_customer=enterprise_customer, provider_id=self.provider.provider_id) + enterprise_customer_data = { + 'uuid': enterprise_customer.uuid, + 'name': enterprise_customer.name, + 'identity_provider': 'saml-default', + } + mock_enterprise_customer_for_request.return_value = enterprise_customer_data + mock_get_enterprise_customer_for_learner.return_value = enterprise_customer_data + mock_get_enterprise_customer_for_learner_settings_view.return_value = enterprise_customer_data + # Instrument the pipeline to get to the dashboard with the full expected state. self.client.get( pipeline.get_login_url(self.provider.provider_id, pipeline.AUTH_ENTRY_LOGIN)) @@ -201,34 +219,39 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin self.assert_account_settings_context_looks_correct(account_settings_context(request), linked=True) self.assert_social_auth_exists_for_user(request.user, strategy) - # Fire off the disconnect pipeline without the user information. - actions.do_disconnect( - request.backend, - None, - None, - redirect_field_name=auth.REDIRECT_FIELD_NAME, - request=request - ) - self.assertFalse( - EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 0 - ) - - # Fire off the disconnect pipeline to unlink. - self.assert_redirect_after_pipeline_completes( + FEATURES_WITH_ENTERPRISE_ENABLED = settings.FEATURES.copy() + FEATURES_WITH_ENTERPRISE_ENABLED['ENABLE_ENTERPRISE_INTEGRATION'] = True + with patch.dict("django.conf.settings.FEATURES", FEATURES_WITH_ENTERPRISE_ENABLED): + # Fire off the disconnect pipeline without the user information. actions.do_disconnect( request.backend, - user, + None, None, redirect_field_name=auth.REDIRECT_FIELD_NAME, request=request ) - ) - # Now we expect to be in the unlinked state, with no backend entry. - self.assert_account_settings_context_looks_correct(account_settings_context(request), linked=False) - self.assert_social_auth_does_not_exist_for_user(user, strategy) - self.assertTrue( - EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 0 - ) + self.assertNotEqual( + EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count(), + 0 + ) + + # Fire off the disconnect pipeline to unlink. + self.assert_redirect_after_pipeline_completes( + actions.do_disconnect( + request.backend, + user, + None, + redirect_field_name=auth.REDIRECT_FIELD_NAME, + request=request + ) + ) + # Now we expect to be in the unlinked state, with no backend entry. + self.assert_account_settings_context_looks_correct(account_settings_context(request), linked=False) + self.assert_social_auth_does_not_exist_for_user(user, strategy) + self.assertEqual( + EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count(), + 0 + ) def get_response_data(self): """Gets dict (string -> object) of merged data about the user.""" diff --git a/openedx/features/enterprise_support/api.py b/openedx/features/enterprise_support/api.py index 52d827a34b..8df92d8459 100644 --- a/openedx/features/enterprise_support/api.py +++ b/openedx/features/enterprise_support/api.py @@ -25,7 +25,12 @@ from third_party_auth.pipeline import get as get_partial_pipeline from third_party_auth.provider import Registry try: - from enterprise.models import EnterpriseCustomer, EnterpriseCustomerUser + from enterprise.models import ( + EnterpriseCustomer, + EnterpriseCustomerIdentityProvider, + EnterpriseCustomerUser, + PendingEnterpriseCustomerUser + ) from consent.models import DataSharingConsent, DataSharingConsentTextOverrides except ImportError: pass @@ -664,3 +669,35 @@ def insert_enterprise_pipeline_elements(pipeline): insert_point = pipeline.index('social_core.pipeline.social_auth.load_extra_data') for index, element in enumerate(additional_elements): pipeline.insert(insert_point + index, element) + + +@enterprise_is_enabled() +def unlink_enterprise_user_from_idp(request, user, idp_backend_name): + """ + Un-links learner from their enterprise identity provider + Args: + request (wsgi request): request object + user (User): user who initiated disconnect request + idp_backend_name (str): Name of identity provider's backend + + Returns: None + + """ + enterprise_customer = enterprise_customer_for_request(request) + if user and enterprise_customer: + enabled_providers = Registry.get_enabled_by_backend_name(idp_backend_name) + provider_ids = [enabled_provider.provider_id for enabled_provider in enabled_providers] + enterprise_customer_idps = EnterpriseCustomerIdentityProvider.objects.filter( + enterprise_customer__uuid=enterprise_customer['uuid'], + provider_id__in=provider_ids + ) + + if enterprise_customer_idps: + try: + # Unlink user email from each Enterprise Customer. + for enterprise_customer_idp in enterprise_customer_idps: + EnterpriseCustomerUser.objects.unlink_user( + enterprise_customer=enterprise_customer_idp.enterprise_customer, user_email=user.email + ) + except (EnterpriseCustomerUser.DoesNotExist, PendingEnterpriseCustomerUser.DoesNotExist): + pass