From 0dbdac3262d0d6b730e05b57ec959c0043749d33 Mon Sep 17 00:00:00 2001 From: asadiqbal Date: Wed, 8 Aug 2018 12:52:17 +0500 Subject: [PATCH] Unlink learner from Enterprise Customer when learner unlinks from IDP --- common/djangoapps/third_party_auth/saml.py | 33 + .../third_party_auth/tests/specs/base.py | 712 +++++++++--------- .../tests/specs/test_testshib.py | 146 +++- 3 files changed, 516 insertions(+), 375 deletions(-) diff --git a/common/djangoapps/third_party_auth/saml.py b/common/djangoapps/third_party_auth/saml.py index 35a786e6c5..6d44958abd 100644 --- a/common/djangoapps/third_party_auth/saml.py +++ b/common/djangoapps/third_party_auth/saml.py @@ -13,6 +13,11 @@ from onelogin.saml2.settings import OneLogin_Saml2_Settings from six import text_type from social_core.backends.saml import OID_EDU_PERSON_ENTITLEMENT, SAMLAuth, SAMLIdentityProvider from social_core.exceptions import AuthForbidden +from enterprise.models import ( + EnterpriseCustomerUser, + EnterpriseCustomerIdentityProvider, + PendingEnterpriseCustomerUser +) from openedx.core.djangoapps.theming.helpers import get_current_request @@ -110,6 +115,34 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method return super(SAMLAuthBackend, self).auth_url() + def disconnect(self, *args, **kwargs): + """ + Override of SAMLAuth.disconnect to unlink the learner from enterprise customer if associated. + """ + from . import pipeline, provider + running_pipeline = pipeline.get(self.strategy.request) + provider_id = provider.Registry.get_from_pipeline(running_pipeline).provider_id + try: + user_email = kwargs.get('user').email + except AttributeError: + user_email = None + + try: + enterprise_customer_idp = EnterpriseCustomerIdentityProvider.objects.get(provider_id=provider_id) + except EnterpriseCustomerIdentityProvider.DoesNotExist: + enterprise_customer_idp = None + + if enterprise_customer_idp and user_email: + try: + # Unlink user email from Enterprise Customer. + EnterpriseCustomerUser.objects.unlink_user( + enterprise_customer=enterprise_customer_idp.enterprise_customer, user_email=user_email + ) + except (EnterpriseCustomerUser.DoesNotExist, PendingEnterpriseCustomerUser.DoesNotExist): + pass + + return super(SAMLAuthBackend, self).disconnect(*args, **kwargs) + def _check_entitlements(self, idp, attributes): """ Check if we require the presence of any specific eduPersonEntitlement. diff --git a/common/djangoapps/third_party_auth/tests/specs/base.py b/common/djangoapps/third_party_auth/tests/specs/base.py index 2933fa5e54..f9dbc69cb4 100644 --- a/common/djangoapps/third_party_auth/tests/specs/base.py +++ b/common/djangoapps/third_party_auth/tests/specs/base.py @@ -30,7 +30,291 @@ from third_party_auth import middleware, pipeline from third_party_auth.tests import testutil -class IntegrationTestMixin(object): +class HelperMixin(object): + """ + Contains helper methods for IntegrationTestMixin and IntegrationTest classes below. + """ + + provider = None + + def assert_redirect_to_provider_looks_correct(self, response): + """Asserts the redirect to the provider's site looks correct. + + When we hit /auth/login/, we should be redirected to the + provider's site. Here we check that we're redirected, but we don't know + enough about the provider to check what we're redirected to. Child test + implementations may optionally strengthen this assertion with, for + example, more details about the format of the Location header. + """ + self.assertEqual(302, response.status_code) + self.assertTrue(response.has_header('Location')) + + def assert_register_response_in_pipeline_looks_correct(self, response, pipeline_kwargs, required_fields): + """Performs spot checks of the rendered register.html page. + + When we display the new account registration form after the user signs + in with a third party, we prepopulate the form with values sent back + from the provider. The exact set of values varies on a provider-by- + provider basis and is generated by + provider.BaseProvider.get_register_form_data. We provide some stock + assertions based on the provider's implementation; if you want more + assertions in your test, override this method. + """ + self.assertEqual(200, response.status_code) + # Check that the correct provider was selected. + self.assertIn('successfully signed in with %s' % self.provider.name, response.content) + # Expect that each truthy value we've prepopulated the register form + # with is actually present. + form_field_data = self.provider.get_register_form_data(pipeline_kwargs) + for prepopulated_form_data in form_field_data: + if prepopulated_form_data in required_fields: + self.assertIn(form_field_data[prepopulated_form_data], response.content.decode('utf-8')) + + # pylint: disable=invalid-name + def assert_account_settings_context_looks_correct(self, context, duplicate=False, linked=None): + """Asserts the user's account settings page context is in the expected state. + + If duplicate is True, we expect context['duplicate_provider'] to contain + the duplicate provider backend name. If linked is passed, we conditionally + check that the provider is included in context['auth']['providers'] and + its connected state is correct. + """ + if duplicate: + self.assertEqual(context['duplicate_provider'], self.provider.backend_name) + else: + self.assertIsNone(context['duplicate_provider']) + + if linked is not None: + expected_provider = [ + provider for provider in context['auth']['providers'] if provider['name'] == self.provider.name + ][0] + self.assertIsNotNone(expected_provider) + self.assertEqual(expected_provider['connected'], linked) + + def assert_exception_redirect_looks_correct(self, expected_uri, auth_entry=None): + """Tests middleware conditional redirection. + + middleware.ExceptionMiddleware makes sure the user ends up in the right + place when they cancel authentication via the provider's UX. + """ + exception_middleware = middleware.ExceptionMiddleware() + request, _ = self.get_request_and_strategy(auth_entry=auth_entry) + response = exception_middleware.process_exception( + request, exceptions.AuthCanceled(request.backend)) + location = response.get('Location') + + self.assertEqual(302, response.status_code) + self.assertIn('canceled', location) + self.assertIn(self.backend_name, location) + self.assertTrue(location.startswith(expected_uri + '?')) + + def assert_json_failure_response_is_inactive_account(self, response): + """Asserts failure on /login for inactive account looks right.""" + self.assertEqual(200, response.status_code) # Yes, it's a 200 even though it's a failure. + payload = json.loads(response.content) + self.assertFalse(payload.get('success')) + self.assertIn('In order to sign in, you need to activate your account.', payload.get('value')) + + def assert_json_failure_response_is_missing_social_auth(self, response): + """Asserts failure on /login for missing social auth looks right.""" + self.assertEqual(403, response.status_code) + self.assertIn( + "successfully logged into your %s account, but this account isn't linked" % self.provider.name, + response.content + ) + + def assert_json_failure_response_is_username_collision(self, response): + """Asserts the json response indicates a username collision.""" + self.assertEqual(400, response.status_code) + payload = json.loads(response.content) + self.assertFalse(payload.get('success')) + self.assertIn('belongs to an existing account', payload.get('value')) + + def assert_json_success_response_looks_correct(self, response): + """Asserts the json response indicates success and redirection.""" + self.assertEqual(200, response.status_code) + payload = json.loads(response.content) + self.assertTrue(payload.get('success')) + self.assertEqual(pipeline.get_complete_url(self.provider.backend_name), payload.get('redirect_url')) + + def assert_login_response_before_pipeline_looks_correct(self, response): + """Asserts a GET of /login not in the pipeline looks correct.""" + self.assertEqual(200, response.status_code) + # The combined login/registration page dynamically generates the login button, + # but we can still check that the provider name is passed in the data attribute + # for the container element. + self.assertIn(self.provider.name, response.content) + + def assert_login_response_in_pipeline_looks_correct(self, response): + """Asserts a GET of /login in the pipeline looks correct.""" + self.assertEqual(200, response.status_code) + + def assert_password_overridden_by_pipeline(self, username, password): + """Verifies that the given password is not correct. + + The pipeline overrides POST['password'], if any, with random data. + """ + self.assertIsNone(auth.authenticate(password=password, username=username)) + + def assert_pipeline_running(self, request): + """Makes sure the given request is running an auth pipeline.""" + self.assertTrue(pipeline.running(request)) + + def assert_redirect_to_dashboard_looks_correct(self, response): + """Asserts a response would redirect to /dashboard.""" + self.assertEqual(302, response.status_code) + # NOTE: Ideally we should use assertRedirects(), however it errors out due to the hostname, testserver, + # not being properly set. This may be an issue with the call made by PSA, but we are not certain. + self.assertTrue(response.get('Location').endswith(django_settings.SOCIAL_AUTH_LOGIN_REDIRECT_URL)) + + def assert_redirect_to_login_looks_correct(self, response): + """Asserts a response would redirect to /login.""" + self.assertEqual(302, response.status_code) + self.assertEqual('/login', response.get('Location')) + + def assert_redirect_to_register_looks_correct(self, response): + """Asserts a response would redirect to /register.""" + self.assertEqual(302, response.status_code) + self.assertEqual('/register', response.get('Location')) + + def assert_register_response_before_pipeline_looks_correct(self, response): + """Asserts a GET of /register not in the pipeline looks correct.""" + self.assertEqual(200, response.status_code) + # The combined login/registration page dynamically generates the register button, + # but we can still check that the provider name is passed in the data attribute + # for the container element. + self.assertIn(self.provider.name, response.content) + + def assert_social_auth_does_not_exist_for_user(self, user, strategy): + """Asserts a user does not have an auth with the expected provider.""" + social_auths = strategy.storage.user.get_social_auth_for_user( + user, provider=self.provider.backend_name) + self.assertEqual(0, len(social_auths)) + + def assert_social_auth_exists_for_user(self, user, strategy): + """Asserts a user has a social auth with the expected provider.""" + social_auths = strategy.storage.user.get_social_auth_for_user( + user, provider=self.provider.backend_name) + self.assertEqual(1, len(social_auths)) + self.assertEqual(self.backend_name, social_auths[0].provider) + + def assert_logged_in_cookie_redirect(self, response): + """Verify that the user was redirected in order to set the logged in cookie. """ + self.assertEqual(response.status_code, 302) + self.assertEqual( + response["Location"], + pipeline.get_complete_url(self.provider.backend_name) + ) + self.assertEqual(response.cookies[django_settings.EDXMKTG_LOGGED_IN_COOKIE_NAME].value, 'true') + self.assertIn(django_settings.EDXMKTG_USER_INFO_COOKIE_NAME, response.cookies) + + @property + def backend_name(self): + """ Shortcut for the backend name """ + return self.provider.backend_name + + def get_registration_post_vars(self, overrides=None): + """POST vars generated by the registration form.""" + defaults = { + 'username': 'username', + 'name': 'First Last', + 'gender': '', + 'year_of_birth': '', + 'level_of_education': '', + 'goals': '', + 'honor_code': 'true', + 'terms_of_service': 'true', + 'password': 'password', + 'mailing_address': '', + 'email': 'user@email.com', + } + + if overrides: + defaults.update(overrides) + + return defaults + + def get_request_and_strategy(self, auth_entry=None, redirect_uri=None): + """Gets a fully-configured request and strategy. + + These two objects contain circular references, so we create them + together. The references themselves are a mixture of normal __init__ + stuff and monkey-patching done by python-social-auth. See, for example, + social_django.utils.strategy(). + """ + request = self.request_factory.get( + pipeline.get_complete_url(self.backend_name) + + '?redirect_state=redirect_state_value&code=code_value&state=state_value') + request.site = SiteFactory.create() + request.user = auth_models.AnonymousUser() + request.session = cache.SessionStore() + request.session[self.backend_name + '_state'] = 'state_value' + + if auth_entry: + request.session[pipeline.AUTH_ENTRY_KEY] = auth_entry + + strategy = social_utils.load_strategy(request=request) + request.social_strategy = strategy + request.backend = social_utils.load_backend(strategy, self.backend_name, redirect_uri) + + return request, strategy + + @contextmanager + def _patch_edxmako_current_request(self, request): + """Make ``request`` be the current request for edxmako template rendering.""" + + with mock.patch('edxmako.request_context.get_current_request', return_value=request): + yield + + def get_user_by_email(self, strategy, email): + """Gets a user by email, using the given strategy.""" + return strategy.storage.user.user_model().objects.get(email=email) + + def set_logged_in_cookies(self, request): + """Simulate setting the marketing site cookie on the request. """ + request.COOKIES[django_settings.EDXMKTG_LOGGED_IN_COOKIE_NAME] = 'true' + request.COOKIES[django_settings.EDXMKTG_USER_INFO_COOKIE_NAME] = json.dumps({ + 'version': django_settings.EDXMKTG_USER_INFO_COOKIE_VERSION, + }) + + def create_user_models_for_existing_account(self, strategy, email, password, username, skip_social_auth=False): + """Creates user, profile, registration, and (usually) social auth. + + This synthesizes what happens during /register. + See student.views.register and student.helpers.do_create_account. + """ + response_data = self.get_response_data() + uid = strategy.request.backend.get_user_id(response_data, response_data) + user = social_utils.Storage.user.create_user(email=email, password=password, username=username) + profile = student_models.UserProfile(user=user) + profile.save() + registration = student_models.Registration() + registration.register(user) + registration.save() + + if not skip_social_auth: + social_utils.Storage.user.create_social_auth(user, uid, self.provider.backend_name) + + return user + + def fake_auth_complete(self, strategy): + """Fake implementation of social_core.backends.BaseAuth.auth_complete. + + Unlike what the docs say, it does not need to return a user instance. + Sometimes (like when directing users to the /register form) it instead + returns a response that 302s to /register. + """ + args = () + kwargs = { + 'request': strategy.request, + 'backend': strategy.request.backend, + 'user': None, + 'response': self.get_response_data(), + } + return strategy.authenticate(*args, **kwargs) + + +class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin): """ Mixin base class for third_party_auth integration tests. This class is newer and simpler than the 'IntegrationTest' alternative below, but it is @@ -48,6 +332,8 @@ class IntegrationTestMixin(object): def setUp(self): super(IntegrationTestMixin, self).setUp() + + self.request_factory = test.RequestFactory() self.login_page_url = reverse('signin_user') self.register_page_url = reverse('register_user') patcher = testutil.patch_mako_templates() @@ -55,8 +341,10 @@ class IntegrationTestMixin(object): self.addCleanup(patcher.stop) # Override this method in a subclass and enable at least one provider. - def test_register(self, **extra_defaults): - # The user goes to the register page, and sees a button to register with the provider: + def _test_register(self, **extra_defaults): + """ + The user goes to the register page, and sees a button to register with the provider. + """ provider_register_url = self._check_register_page() # The user clicks on the Dummy button: try_login_response = self.client.get(provider_register_url) @@ -105,9 +393,11 @@ class IntegrationTestMixin(object): self.verify_user_email('email-edited@tpa-test.none') self._test_return_login(user_is_activated=True) - def test_login(self): + def _test_login(self): + """ + The user goes to the login page, and sees a button to login with the provider. + """ self.user = UserFactory.create() - # The user goes to the login page, and sees a button to login with this provider: provider_login_url = self._check_login_page() # The user clicks on the provider's button: try_login_response = self.client.get(provider_login_url) @@ -208,361 +498,13 @@ class IntegrationTestMixin(object): @unittest.skipUnless( testutil.AUTH_FEATURES_KEY in django_settings.FEATURES, testutil.AUTH_FEATURES_KEY + ' not in settings.FEATURES') @django_utils.override_settings() # For settings reversion on a method-by-method basis. -class IntegrationTest(testutil.TestCase, test.TestCase): +class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin): """Abstract base class for provider integration tests.""" - # Override setUp and set this: - provider = None - - # Methods you must override in your children. - - def get_response_data(self): - """Gets a dict of response data of the form given by the provider. - - To determine what the provider returns, drop into a debugger in your - provider's do_auth implementation. Providers may merge different kinds - of data (for example, data about the user and data about the user's - credentials). - """ - raise NotImplementedError - - def get_username(self): - """Gets username based on response data from a provider. - - Each provider has different logic for username generation. Sadly, - this is not extracted into its own method in python-social-auth, so we - must provide a getter ourselves. - - Note that this is the *initial* value the framework will attempt to use. - If it collides, the pipeline will generate a new username. We extract - it here so we can force collisions in a polymorphic way. - """ - raise NotImplementedError - - # Asserts you can optionally override and make more specific. - - def assert_redirect_to_provider_looks_correct(self, response): - """Asserts the redirect to the provider's site looks correct. - - When we hit /auth/login/, we should be redirected to the - provider's site. Here we check that we're redirected, but we don't know - enough about the provider to check what we're redirected to. Child test - implementations may optionally strengthen this assertion with, for - example, more details about the format of the Location header. - """ - self.assertEqual(302, response.status_code) - self.assertTrue(response.has_header('Location')) - - def assert_register_response_in_pipeline_looks_correct(self, response, pipeline_kwargs, required_fields): - """Performs spot checks of the rendered register.html page. - - When we display the new account registration form after the user signs - in with a third party, we prepopulate the form with values sent back - from the provider. The exact set of values varies on a provider-by- - provider basis and is generated by - provider.BaseProvider.get_register_form_data. We provide some stock - assertions based on the provider's implementation; if you want more - assertions in your test, override this method. - """ - self.assertEqual(200, response.status_code) - # Check that the correct provider was selected. - self.assertIn('successfully signed in with %s' % self.provider.name, response.content) - # Expect that each truthy value we've prepopulated the register form - # with is actually present. - form_field_data = self.provider.get_register_form_data(pipeline_kwargs) - for prepopulated_form_data in form_field_data: - if prepopulated_form_data in required_fields: - self.assertIn(form_field_data[prepopulated_form_data], response.content.decode('utf-8')) - - # Implementation details and actual tests past this point -- no more - # configuration needed. - def setUp(self): super(IntegrationTest, self).setUp() self.request_factory = test.RequestFactory() - @property - def backend_name(self): - """ Shortcut for the backend name """ - return self.provider.backend_name - - # pylint: disable=invalid-name - def assert_account_settings_context_looks_correct(self, context, duplicate=False, linked=None): - """Asserts the user's account settings page context is in the expected state. - - If duplicate is True, we expect context['duplicate_provider'] to contain - the duplicate provider backend name. If linked is passed, we conditionally - check that the provider is included in context['auth']['providers'] and - its connected state is correct. - """ - if duplicate: - self.assertEqual(context['duplicate_provider'], self.provider.backend_name) - else: - self.assertIsNone(context['duplicate_provider']) - - if linked is not None: - expected_provider = [ - provider for provider in context['auth']['providers'] if provider['name'] == self.provider.name - ][0] - self.assertIsNotNone(expected_provider) - self.assertEqual(expected_provider['connected'], linked) - - def assert_exception_redirect_looks_correct(self, expected_uri, auth_entry=None): - """Tests middleware conditional redirection. - - middleware.ExceptionMiddleware makes sure the user ends up in the right - place when they cancel authentication via the provider's UX. - """ - exception_middleware = middleware.ExceptionMiddleware() - request, _ = self.get_request_and_strategy(auth_entry=auth_entry) - response = exception_middleware.process_exception( - request, exceptions.AuthCanceled(request.backend)) - location = response.get('Location') - - self.assertEqual(302, response.status_code) - self.assertIn('canceled', location) - self.assertIn(self.backend_name, location) - self.assertTrue(location.startswith(expected_uri + '?')) - - def assert_first_party_auth_trumps_third_party_auth(self, email=None, password=None, success=None): - """Asserts first party auth was used in place of third party auth. - - Args: - email: string. The user's email. If not None, will be set on POST. - password: string. The user's password. If not None, will be set on - POST. - success: None or bool. Whether we expect auth to be successful. Set - to None to indicate we expect the request to be invalid (meaning - one of username or password will be missing). - """ - _, strategy = self.get_request_and_strategy( - auth_entry=pipeline.AUTH_ENTRY_LOGIN, redirect_uri='social:complete') - strategy.request.backend.auth_complete = mock.MagicMock(return_value=self.fake_auth_complete(strategy)) - self.create_user_models_for_existing_account( - strategy, email, password, self.get_username(), skip_social_auth=True) - - strategy.request.POST = dict(strategy.request.POST) - - if email: - strategy.request.POST['email'] = email - if password: - strategy.request.POST['password'] = 'bad_' + password if success is False else password - - self.assert_pipeline_running(strategy.request) - payload = json.loads(student_views.login_user(strategy.request).content) - - if success is None: - # Request malformed -- just one of email/password given. - self.assertFalse(payload.get('success')) - self.assertIn('There was an error receiving your login information', payload.get('value')) - elif success: - # Request well-formed and credentials good. - self.assertTrue(payload.get('success')) - else: - # Request well-formed but credentials bad. - self.assertFalse(payload.get('success')) - self.assertIn('incorrect', payload.get('value')) - - def assert_json_failure_response_is_inactive_account(self, response): - """Asserts failure on /login for inactive account looks right.""" - self.assertEqual(200, response.status_code) # Yes, it's a 200 even though it's a failure. - payload = json.loads(response.content) - self.assertFalse(payload.get('success')) - self.assertIn('In order to sign in, you need to activate your account.', payload.get('value')) - - def assert_json_failure_response_is_missing_social_auth(self, response): - """Asserts failure on /login for missing social auth looks right.""" - self.assertEqual(403, response.status_code) - self.assertIn( - "successfully logged into your %s account, but this account isn't linked" % self.provider.name, - response.content - ) - - def assert_json_failure_response_is_username_collision(self, response): - """Asserts the json response indicates a username collision.""" - self.assertEqual(400, response.status_code) - payload = json.loads(response.content) - self.assertFalse(payload.get('success')) - self.assertIn('belongs to an existing account', payload.get('value')) - - def assert_json_success_response_looks_correct(self, response): - """Asserts the json response indicates success and redirection.""" - self.assertEqual(200, response.status_code) - payload = json.loads(response.content) - self.assertTrue(payload.get('success')) - self.assertEqual(pipeline.get_complete_url(self.provider.backend_name), payload.get('redirect_url')) - - def assert_login_response_before_pipeline_looks_correct(self, response): - """Asserts a GET of /login not in the pipeline looks correct.""" - self.assertEqual(200, response.status_code) - # The combined login/registration page dynamically generates the login button, - # but we can still check that the provider name is passed in the data attribute - # for the container element. - self.assertIn(self.provider.name, response.content) - - def assert_login_response_in_pipeline_looks_correct(self, response): - """Asserts a GET of /login in the pipeline looks correct.""" - self.assertEqual(200, response.status_code) - - def assert_password_overridden_by_pipeline(self, username, password): - """Verifies that the given password is not correct. - - The pipeline overrides POST['password'], if any, with random data. - """ - self.assertIsNone(auth.authenticate(password=password, username=username)) - - def assert_pipeline_running(self, request): - """Makes sure the given request is running an auth pipeline.""" - self.assertTrue(pipeline.running(request)) - - def assert_redirect_to_dashboard_looks_correct(self, response): - """Asserts a response would redirect to /dashboard.""" - self.assertEqual(302, response.status_code) - # NOTE: Ideally we should use assertRedirects(), however it errors out due to the hostname, testserver, - # not being properly set. This may be an issue with the call made by PSA, but we are not certain. - self.assertTrue(response.get('Location').endswith(django_settings.SOCIAL_AUTH_LOGIN_REDIRECT_URL)) - - def assert_redirect_to_login_looks_correct(self, response): - """Asserts a response would redirect to /login.""" - self.assertEqual(302, response.status_code) - self.assertEqual('/login', response.get('Location')) - - def assert_redirect_to_register_looks_correct(self, response): - """Asserts a response would redirect to /register.""" - self.assertEqual(302, response.status_code) - self.assertEqual('/register', response.get('Location')) - - def assert_register_response_before_pipeline_looks_correct(self, response): - """Asserts a GET of /register not in the pipeline looks correct.""" - self.assertEqual(200, response.status_code) - # The combined login/registration page dynamically generates the register button, - # but we can still check that the provider name is passed in the data attribute - # for the container element. - self.assertIn(self.provider.name, response.content) - - def assert_social_auth_does_not_exist_for_user(self, user, strategy): - """Asserts a user does not have an auth with the expected provider.""" - social_auths = strategy.storage.user.get_social_auth_for_user( - user, provider=self.provider.backend_name) - self.assertEqual(0, len(social_auths)) - - def assert_social_auth_exists_for_user(self, user, strategy): - """Asserts a user has a social auth with the expected provider.""" - social_auths = strategy.storage.user.get_social_auth_for_user( - user, provider=self.provider.backend_name) - self.assertEqual(1, len(social_auths)) - self.assertEqual(self.backend_name, social_auths[0].provider) - - def create_user_models_for_existing_account(self, strategy, email, password, username, skip_social_auth=False): - """Creates user, profile, registration, and (usually) social auth. - - This synthesizes what happens during /register. - See student.views.register and student.helpers.do_create_account. - """ - response_data = self.get_response_data() - uid = strategy.request.backend.get_user_id(response_data, response_data) - user = social_utils.Storage.user.create_user(email=email, password=password, username=username) - profile = student_models.UserProfile(user=user) - profile.save() - registration = student_models.Registration() - registration.register(user) - registration.save() - - if not skip_social_auth: - social_utils.Storage.user.create_social_auth(user, uid, self.provider.backend_name) - - return user - - def fake_auth_complete(self, strategy): - """Fake implementation of social_core.backends.BaseAuth.auth_complete. - - Unlike what the docs say, it does not need to return a user instance. - Sometimes (like when directing users to the /register form) it instead - returns a response that 302s to /register. - """ - args = () - kwargs = { - 'request': strategy.request, - 'backend': strategy.request.backend, - 'user': None, - 'response': self.get_response_data(), - } - return strategy.authenticate(*args, **kwargs) - - def get_registration_post_vars(self, overrides=None): - """POST vars generated by the registration form.""" - defaults = { - 'username': 'username', - 'name': 'First Last', - 'gender': '', - 'year_of_birth': '', - 'level_of_education': '', - 'goals': '', - 'honor_code': 'true', - 'terms_of_service': 'true', - 'password': 'password', - 'mailing_address': '', - 'email': 'user@email.com', - } - - if overrides: - defaults.update(overrides) - - return defaults - - def get_request_and_strategy(self, auth_entry=None, redirect_uri=None): - """Gets a fully-configured request and strategy. - - These two objects contain circular references, so we create them - together. The references themselves are a mixture of normal __init__ - stuff and monkey-patching done by python-social-auth. See, for example, - social_django.utils.strategy(). - """ - request = self.request_factory.get( - pipeline.get_complete_url(self.backend_name) + - '?redirect_state=redirect_state_value&code=code_value&state=state_value') - request.site = SiteFactory.create() - request.user = auth_models.AnonymousUser() - request.session = cache.SessionStore() - request.session[self.backend_name + '_state'] = 'state_value' - - if auth_entry: - request.session[pipeline.AUTH_ENTRY_KEY] = auth_entry - - strategy = social_utils.load_strategy(request=request) - request.social_strategy = strategy - request.backend = social_utils.load_backend(strategy, self.backend_name, redirect_uri) - - return request, strategy - - @contextmanager - def _patch_edxmako_current_request(self, request): - """Make ``request`` be the current request for edxmako template rendering.""" - - with mock.patch('edxmako.request_context.get_current_request', return_value=request): - yield - - def get_user_by_email(self, strategy, email): - """Gets a user by email, using the given strategy.""" - return strategy.storage.user.user_model().objects.get(email=email) - - def assert_logged_in_cookie_redirect(self, response): - """Verify that the user was redirected in order to set the logged in cookie. """ - self.assertEqual(response.status_code, 302) - self.assertEqual( - response["Location"], - pipeline.get_complete_url(self.provider.backend_name) - ) - self.assertEqual(response.cookies[django_settings.EDXMKTG_LOGGED_IN_COOKIE_NAME].value, 'true') - self.assertIn(django_settings.EDXMKTG_USER_INFO_COOKIE_NAME, response.cookies) - - def set_logged_in_cookies(self, request): - """Simulate setting the marketing site cookie on the request. """ - request.COOKIES[django_settings.EDXMKTG_LOGGED_IN_COOKIE_NAME] = 'true' - request.COOKIES[django_settings.EDXMKTG_USER_INFO_COOKIE_NAME] = json.dumps({ - 'version': django_settings.EDXMKTG_USER_INFO_COOKIE_VERSION, - }) - # Actual tests, executed once per child. def test_canceling_authentication_redirects_to_login_when_auth_entry_login(self): @@ -947,6 +889,68 @@ class IntegrationTest(testutil.TestCase, test.TestCase): response = self.fake_auth_complete(strategy) self.assertEqual(response.url, reverse('signin_user')) + def assert_first_party_auth_trumps_third_party_auth(self, email=None, password=None, success=None): + """Asserts first party auth was used in place of third party auth. + + Args: + email: string. The user's email. If not None, will be set on POST. + password: string. The user's password. If not None, will be set on + POST. + success: None or bool. Whether we expect auth to be successful. Set + to None to indicate we expect the request to be invalid (meaning + one of username or password will be missing). + """ + _, strategy = self.get_request_and_strategy( + auth_entry=pipeline.AUTH_ENTRY_LOGIN, redirect_uri='social:complete') + strategy.request.backend.auth_complete = mock.MagicMock(return_value=self.fake_auth_complete(strategy)) + self.create_user_models_for_existing_account( + strategy, email, password, self.get_username(), skip_social_auth=True) + + strategy.request.POST = dict(strategy.request.POST) + + if email: + strategy.request.POST['email'] = email + if password: + strategy.request.POST['password'] = 'bad_' + password if success is False else password + + self.assert_pipeline_running(strategy.request) + payload = json.loads(student_views.login_user(strategy.request).content) + + if success is None: + # Request malformed -- just one of email/password given. + self.assertFalse(payload.get('success')) + self.assertIn('There was an error receiving your login information', payload.get('value')) + elif success: + # Request well-formed and credentials good. + self.assertTrue(payload.get('success')) + else: + # Request well-formed but credentials bad. + self.assertFalse(payload.get('success')) + self.assertIn('incorrect', payload.get('value')) + + def get_response_data(self): + """Gets a dict of response data of the form given by the provider. + + To determine what the provider returns, drop into a debugger in your + provider's do_auth implementation. Providers may merge different kinds + of data (for example, data about the user and data about the user's + credentials). + """ + raise NotImplementedError + + def get_username(self): + """Gets username based on response data from a provider. + + Each provider has different logic for username generation. Sadly, + this is not extracted into its own method in python-social-auth, so we + must provide a getter ourselves. + + Note that this is the *initial* value the framework will attempt to use. + If it collides, the pipeline will generate a new username. We extract + it here so we can force collisions in a polymorphic way. + """ + raise NotImplementedError + # pylint: disable=abstract-method @django_utils.override_settings(ECOMMERCE_API_URL=TEST_API_URL) diff --git a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py index 8691a98618..d21c845bf5 100644 --- a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py +++ b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py @@ -2,24 +2,32 @@ Third_party_auth integration tests using a mock version of the TestShib provider """ import datetime -import ddt -import unittest -import httpretty import json import logging -from mock import patch -from freezegun import freeze_time -from social_django.models import UserSocialAuth -from testfixtures import LogCapture +import unittest from unittest import skip -from third_party_auth.saml import log as saml_log, SapSuccessFactorsIdentityProvider +import ddt +import httpretty +from django.contrib import auth +from freezegun import freeze_time +from mock import MagicMock, patch +from social_core import actions +from social_django import views as social_views +from social_django.models import UserSocialAuth +from testfixtures import LogCapture + +from enterprise.models import EnterpriseCustomerIdentityProvider, EnterpriseCustomerUser +from openedx.features.enterprise_support.tests.factories import EnterpriseCustomerFactory +from student import views as student_views +from student_account.views import account_settings_context +from third_party_auth import pipeline +from third_party_auth.saml import SapSuccessFactorsIdentityProvider, log as saml_log from third_party_auth.tasks import fetch_saml_metadata from third_party_auth.tests import testutil from .base import IntegrationTestMixin - TESTSHIB_ENTITY_ID = 'https://idp.testshib.org/idp/shibboleth' TESTSHIB_METADATA_URL = 'https://mock.testshib.org/metadata/testshib-providers.xml' TESTSHIB_METADATA_URL_WITH_CACHE_DURATION = 'https://mock.testshib.org/metadata/testshib-providers-cache.xml' @@ -90,13 +98,14 @@ class SamlIntegrationTestUtilities(object): kwargs.setdefault('name', self.PROVIDER_NAME) kwargs.setdefault('enabled', True) kwargs.setdefault('visible', True) + kwargs.setdefault("backend_name", "tpa-saml") kwargs.setdefault('slug', self.PROVIDER_IDP_SLUG) kwargs.setdefault('entity_id', TESTSHIB_ENTITY_ID) kwargs.setdefault('metadata_source', TESTSHIB_METADATA_URL) kwargs.setdefault('icon_class', 'fa-university') kwargs.setdefault('attr_email', 'urn:oid:1.3.6.1.4.1.5923.1.1.1.6') # eduPersonPrincipalName kwargs.setdefault('max_session_length', None) - self.configure_saml_provider(**kwargs) + saml_provider = self.configure_saml_provider(**kwargs) # pylint: disable=no-member if fetch_metadata: self.assertTrue(httpretty.is_enabled()) @@ -108,6 +117,7 @@ class SamlIntegrationTestUtilities(object): self.assertEqual(num_updated, 1) self.assertEqual(num_failed, 0) self.assertEqual(len(failure_messages), 0) + return saml_provider def do_provider_login(self, provider_redirect_url): """ Mocked: the user logs in to TestShib and then gets redirected back """ @@ -127,6 +137,100 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin TestShib provider Integration Test, to test SAML functionality """ + TOKEN_RESPONSE_DATA = { + 'access_token': 'access_token_value', + 'expires_in': 'expires_in_value', + } + USER_RESPONSE_DATA = { + 'lastName': 'lastName_value', + 'id': 'id_value', + 'firstName': 'firstName_value', + 'idp_name': 'testshib', + 'attributes': {u'urn:oid:0.9.2342.19200300.100.1.1': [u'myself']} + } + + def test_full_pipeline_succeeds_for_unlinking_testshib_account(self): + + # First, create, the request and strategy that store pipeline state, + # configure the backend, and mock out wire traffic. + self.provider = self._configure_testshib_provider() + request, strategy = self.get_request_and_strategy( + auth_entry=pipeline.AUTH_ENTRY_LOGIN, redirect_uri='social:complete') + request.backend.auth_complete = MagicMock(return_value=self.fake_auth_complete(strategy)) + user = self.create_user_models_for_existing_account( + strategy, 'user@example.com', 'password', self.get_username()) + self.assert_social_auth_exists_for_user(user, strategy) + + request.user = user + + # We're already logged in, so simulate that the cookie is set correctly + self.set_logged_in_cookies(request) + + # linking a learner with enterprise customer. + enterprise_customer = EnterpriseCustomerFactory() + assert EnterpriseCustomerUser.objects.count() == 0, "Precondition check: no link records should exist" + EnterpriseCustomerUser.objects.link_user(enterprise_customer, user.email) + self.assertTrue( + EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 1 + ) + EnterpriseCustomerIdentityProvider.objects.get_or_create(enterprise_customer=enterprise_customer, + provider_id=self.provider.provider_id) + + # Instrument the pipeline to get to the dashboard with the full expected state. + self.client.get( + pipeline.get_login_url(self.provider.provider_id, pipeline.AUTH_ENTRY_LOGIN)) + actions.do_complete(request.backend, social_views._do_login, # pylint: disable=protected-access + request=request) + + with self._patch_edxmako_current_request(strategy.request): + student_views.signin_user(strategy.request) + student_views.login_user(strategy.request) + actions.do_complete(request.backend, social_views._do_login, user=user, # pylint: disable=protected-access + request=request) + + # First we expect that we're in the linked state, with a backend entry. + self.assert_account_settings_context_looks_correct(account_settings_context(request), linked=True) + self.assert_social_auth_exists_for_user(request.user, strategy) + + # Fire off the disconnect pipeline without the user information. + actions.do_disconnect( + request.backend, + None, + None, + redirect_field_name=auth.REDIRECT_FIELD_NAME, + request=request + ) + self.assertFalse( + EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 0 + ) + + # Fire off the disconnect pipeline to unlink. + self.assert_redirect_to_dashboard_looks_correct( + actions.do_disconnect( + request.backend, + user, + None, + redirect_field_name=auth.REDIRECT_FIELD_NAME, + request=request + ) + ) + # Now we expect to be in the unlinked state, with no backend entry. + self.assert_account_settings_context_looks_correct(account_settings_context(request), linked=False) + self.assert_social_auth_does_not_exist_for_user(user, strategy) + self.assertTrue( + EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 0 + ) + + def get_response_data(self): + """Gets dict (string -> object) of merged data about the user.""" + response_data = dict(self.TOKEN_RESPONSE_DATA) + response_data.update(self.USER_RESPONSE_DATA) + return response_data + + def get_username(self): + response_data = self.get_response_data() + return response_data.get('idp_name') + def test_login_before_metadata_fetched(self): self._configure_testshib_provider(fetch_metadata=False) # The user goes to the login page, and sees a button to login with TestShib: @@ -144,12 +248,12 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin def test_login(self): """ Configure TestShib before running the login test """ self._configure_testshib_provider() - super(TestShibIntegrationTest, self).test_login() + self._test_login() def test_register(self): """ Configure TestShib before running the register test """ self._configure_testshib_provider() - super(TestShibIntegrationTest, self).test_register() + self._test_register() def test_login_records_attributes(self): """ @@ -172,7 +276,7 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin """ Test SAML login logs with debug mode enabled or not """ self._configure_testshib_provider(debug_mode=debug_mode_enabled) with patch.object(saml_log, 'info') as mock_log: - super(TestShibIntegrationTest, self).test_login() + self._test_login() if debug_mode_enabled: # We expect that test_login() does two full logins, and each attempt generates two # logs - one for the request and one for the response @@ -225,7 +329,7 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin now = datetime.datetime.utcnow() with freeze_time(now): # Test the login flow, adding the user in the process - super(TestShibIntegrationTest, self).test_login() + self._test_login() # Wait 30 seconds; longer than the manually-set 10-second timeout later = now + datetime.timedelta(seconds=30) @@ -355,7 +459,7 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes self.USER_EMAIL = "myself@testshib.org" self.USER_NAME = "Me Myself And I" self.USER_USERNAME = "myself" - super(SuccessFactorsIntegrationTest, self).test_register() + self._test_register() @patch.dict('django.conf.settings.REGISTRATION_EXTRA_FIELDS', country='optional') def test_register_sapsf_metadata_present(self): @@ -381,7 +485,7 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes metadata_source=TESTSHIB_METADATA_URL, other_settings=json.dumps(provider_settings) ) - super(SuccessFactorsIntegrationTest, self).test_register(country=expected_country) + self._test_register(country=expected_country) @patch.dict('django.conf.settings.REGISTRATION_EXTRA_FIELDS', country='optional') def test_register_sapsf_metadata_present_override_relevant_value(self): @@ -410,7 +514,7 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes metadata_source=TESTSHIB_METADATA_URL, other_settings=json.dumps(provider_settings) ) - super(SuccessFactorsIntegrationTest, self).test_register(country=expected_country) + self._test_register(country=expected_country) @patch.dict('django.conf.settings.REGISTRATION_EXTRA_FIELDS', country='optional') def test_register_sapsf_metadata_present_override_other_value(self): @@ -439,7 +543,7 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes metadata_source=TESTSHIB_METADATA_URL, other_settings=json.dumps(provider_settings) ) - super(SuccessFactorsIntegrationTest, self).test_register(country=expected_country) + self._test_register(country=expected_country) @patch.dict('django.conf.settings.REGISTRATION_EXTRA_FIELDS', country='optional') def test_register_sapsf_metadata_present_empty_value_override(self): @@ -469,7 +573,7 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes metadata_source=TESTSHIB_METADATA_URL, other_settings=json.dumps(provider_settings) ) - super(SuccessFactorsIntegrationTest, self).test_register(country=expected_country) + self._test_register(country=expected_country) def test_register_http_failure(self): """ @@ -491,7 +595,7 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes self.USER_EMAIL = "myself@testshib.org" self.USER_NAME = "Me Myself And I" self.USER_USERNAME = "myself" - super(SuccessFactorsIntegrationTest, self).test_register() + self._test_register() def test_register_http_failure_in_odata(self): """ @@ -518,7 +622,7 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes }) ) with LogCapture(level=logging.WARNING) as log_capture: - super(SuccessFactorsIntegrationTest, self).test_register() + self._test_register() logging_messages = str([log_msg.getMessage() for log_msg in log_capture.records]).replace('\\', '') self.assertIn(odata_company_id, logging_messages) self.assertIn(mocked_odata_api_url, logging_messages)