replaced unittest assertions pytest assertions (#26240)
This commit is contained in:
@@ -54,7 +54,7 @@ class ThirdPartyAuthPermissionTest(TestCase):
|
||||
def test_anonymous_fails(self):
|
||||
request = self._create_request()
|
||||
response = self.SomeTpaClassView().dispatch(request)
|
||||
self.assertEqual(response.status_code, 401)
|
||||
assert response.status_code == 401
|
||||
|
||||
@ddt.data(
|
||||
(True, False, 200),
|
||||
@@ -68,7 +68,7 @@ class ThirdPartyAuthPermissionTest(TestCase):
|
||||
self._create_session(request, user)
|
||||
|
||||
response = self.SomeTpaClassView().dispatch(request)
|
||||
self.assertEqual(response.status_code, expected_status_code)
|
||||
assert response.status_code == expected_status_code
|
||||
|
||||
@ddt.data(
|
||||
# unrestricted (for example, jwt cookies)
|
||||
@@ -97,7 +97,7 @@ class ThirdPartyAuthPermissionTest(TestCase):
|
||||
)
|
||||
|
||||
response = self.SomeTpaClassView().dispatch(request)
|
||||
self.assertEqual(response.status_code, expected_response)
|
||||
assert response.status_code == expected_response
|
||||
|
||||
@ddt.data(
|
||||
# valid scopes
|
||||
@@ -152,4 +152,4 @@ class ThirdPartyAuthPermissionTest(TestCase):
|
||||
request = self._create_request(auth_header=auth_header)
|
||||
|
||||
response = self.SomeTpaClassView().dispatch(request, provider_id='some_tpa_provider')
|
||||
self.assertEqual(response.status_code, expected_response)
|
||||
assert response.status_code == expected_response
|
||||
|
||||
@@ -131,9 +131,9 @@ class UserViewsMixin(object):
|
||||
url = self.make_url({'username': target_user})
|
||||
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, expect_result)
|
||||
assert response.status_code == expect_result
|
||||
if expect_result == 200:
|
||||
self.assertIn("active", response.data)
|
||||
assert 'active' in response.data
|
||||
six.assertCountEqual(self, response.data["active"], self.expected_active(target_user))
|
||||
|
||||
@ddt.data(
|
||||
@@ -147,9 +147,9 @@ class UserViewsMixin(object):
|
||||
def test_list_connected_providers_with_api_key(self, api_key, target_user, expect_result):
|
||||
url = self.make_url({'username': target_user})
|
||||
response = self.client.get(url, HTTP_X_EDX_API_KEY=api_key)
|
||||
self.assertEqual(response.status_code, expect_result)
|
||||
assert response.status_code == expect_result
|
||||
if expect_result == 200:
|
||||
self.assertIn("active", response.data)
|
||||
assert 'active' in response.data
|
||||
six.assertCountEqual(self, response.data["active"], self.expected_active(target_user))
|
||||
|
||||
@ddt.data(
|
||||
@@ -164,18 +164,18 @@ class UserViewsMixin(object):
|
||||
with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=allow_unprivileged):
|
||||
url = self.make_url({'username': ALICE_USERNAME})
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, expect)
|
||||
assert response.status_code == expect
|
||||
if response.status_code == 200:
|
||||
self.assertGreater(len(response.data['active']), 0)
|
||||
assert len(response.data['active']) > 0
|
||||
for provider_data in response.data['active']:
|
||||
self.assertEqual(include_remote_id, 'remote_id' in provider_data)
|
||||
assert include_remote_id == ('remote_id' in provider_data)
|
||||
|
||||
def test_allow_query_by_email(self):
|
||||
self.client.login(username=ALICE_USERNAME, password=PASSWORD)
|
||||
url = self.make_url({'email': '{}@example.com'.format(ALICE_USERNAME)})
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertGreater(len(response.data['active']), 0)
|
||||
assert response.status_code == 200
|
||||
assert len(response.data['active']) > 0
|
||||
|
||||
def test_throttling(self):
|
||||
# Default throttle is 10/min. Make 11 requests to verify
|
||||
@@ -185,9 +185,9 @@ class UserViewsMixin(object):
|
||||
with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=True):
|
||||
for _ in range(10):
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@override_settings(EDX_API_KEY=VALID_API_KEY)
|
||||
@@ -332,7 +332,7 @@ class UserMappingViewAPITests(TpaAPITestCase):
|
||||
|
||||
url = reverse('third_party_auth_user_mapping_api', kwargs={'provider_id': PROVIDER_ID_TESTSHIB})
|
||||
response = self.client.get(url, HTTP_X_EDX_API_KEY=VALID_API_KEY)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
self._verify_response(response, 200, get_mapping_data_by_usernames(LINKED_USERS))
|
||||
|
||||
@ddt.data(
|
||||
@@ -346,14 +346,14 @@ class UserMappingViewAPITests(TpaAPITestCase):
|
||||
with patch.object(JwtRestrictedApplication, 'has_permission', return_value=has_permission):
|
||||
with patch.object(JwtHasScope, 'has_permission', return_value=has_permission):
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, expect)
|
||||
assert response.status_code == expect
|
||||
|
||||
def _verify_response(self, response, expect_code, expect_result):
|
||||
""" verify the items in data_list exists in response and data_results matches results in response """
|
||||
self.assertEqual(response.status_code, expect_code)
|
||||
assert response.status_code == expect_code
|
||||
if expect_code == 200:
|
||||
for item in ['results', 'count', 'num_pages']:
|
||||
self.assertIn(item, response.data)
|
||||
assert item in response.data
|
||||
six.assertCountEqual(self, response.data['results'], expect_result)
|
||||
|
||||
|
||||
@@ -375,15 +375,11 @@ class TestThirdPartyAuthUserStatusView(ThirdPartyAuthTestMixin, APITestCase):
|
||||
"""
|
||||
self.client.login(username=self.user.username, password=PASSWORD)
|
||||
response = self.client.get(self.url, content_type="application/json")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(
|
||||
response.data,
|
||||
[{
|
||||
"accepts_logins": True,
|
||||
"name": "Google",
|
||||
"disconnect_url": "/auth/disconnect/google-oauth2/?",
|
||||
"connect_url": "/auth/login/google-oauth2/?auth_entry=account_settings&next=%2Faccount%2Fsettings",
|
||||
"connected": False,
|
||||
"id": "oa2-google-oauth2"
|
||||
}]
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert (response.data ==
|
||||
[{
|
||||
'accepts_logins': True, 'name': 'Google',
|
||||
'disconnect_url': '/auth/disconnect/google-oauth2/?',
|
||||
'connect_url': '/auth/login/google-oauth2/?auth_entry=account_settings&next=%2Faccount%2Fsettings',
|
||||
'connected': False, 'id': 'oa2-google-oauth2'
|
||||
}])
|
||||
|
||||
@@ -7,6 +7,7 @@ import sys
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from uuid import uuid4
|
||||
import pytest
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.management import call_command
|
||||
@@ -69,16 +70,16 @@ class TestRemoveSocialAuthUsersCommand(TestCase):
|
||||
call_command(self.command, self.provider_hogwarts.slug, force=True)
|
||||
|
||||
# user with input idp is removed, along with social auth entries
|
||||
with self.assertRaises(User.DoesNotExist):
|
||||
with pytest.raises(User.DoesNotExist):
|
||||
User.objects.get(username='harry')
|
||||
with self.assertRaises(UserSocialAuth.DoesNotExist):
|
||||
with pytest.raises(UserSocialAuth.DoesNotExist):
|
||||
self.find_user_social_auth_entry('harry')
|
||||
|
||||
# other users intact
|
||||
self.user_fleur.refresh_from_db()
|
||||
self.user_viktor.refresh_from_db()
|
||||
self.assertIsNotNone(self.user_fleur)
|
||||
self.assertIsNotNone(self.user_viktor)
|
||||
assert self.user_fleur is not None
|
||||
assert self.user_viktor is not None
|
||||
|
||||
# other social auth intact
|
||||
self.find_user_social_auth_entry(self.user_viktor.username)
|
||||
@@ -96,9 +97,9 @@ class TestRemoveSocialAuthUsersCommand(TestCase):
|
||||
with self._replace_stdin('confirm'):
|
||||
call_command(self.command, self.provider_hogwarts.slug)
|
||||
|
||||
with self.assertRaises(User.DoesNotExist):
|
||||
with pytest.raises(User.DoesNotExist):
|
||||
User.objects.get(username='harry')
|
||||
with self.assertRaises(UserSocialAuth.DoesNotExist):
|
||||
with pytest.raises(UserSocialAuth.DoesNotExist):
|
||||
self.find_user_social_auth_entry('harry')
|
||||
|
||||
@override_settings(FEATURES=FEATURES_WITH_ENABLED)
|
||||
@@ -109,8 +110,8 @@ class TestRemoveSocialAuthUsersCommand(TestCase):
|
||||
call_command(self.command, self.provider_hogwarts.slug)
|
||||
|
||||
# no users should be removed
|
||||
self.assertEqual(len(User.objects.all()), 3)
|
||||
self.assertEqual(len(UserSocialAuth.objects.all()), 2)
|
||||
assert len(User.objects.all()) == 3
|
||||
assert len(UserSocialAuth.objects.all()) == 2
|
||||
|
||||
def test_feature_default_disabled(self):
|
||||
""" By default this command should not be enabled """
|
||||
|
||||
@@ -105,7 +105,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
"""
|
||||
expected = "\nDone.\n1 provider(s) found in database.\n1 skipped and 0 attempted.\n0 updated and 0 failed.\n"
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
assert expected in self.stdout.getvalue()
|
||||
|
||||
@mock.patch("requests.get", mock_get())
|
||||
def test_fetch_saml_metadata(self):
|
||||
@@ -118,7 +118,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
|
||||
expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n1 updated and 0 failed.\n"
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
assert expected in self.stdout.getvalue()
|
||||
|
||||
@mock.patch("requests.get", mock_get(status_code=404))
|
||||
def test_fetch_saml_metadata_failure(self):
|
||||
@@ -133,7 +133,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
|
||||
with self.assertRaisesRegex(CommandError, r"HTTPError: 404 Client Error"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
assert expected in self.stdout.getvalue()
|
||||
|
||||
@mock.patch("requests.get", mock_get(status_code=200))
|
||||
def test_fetch_multiple_providers_data(self):
|
||||
@@ -178,7 +178,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
expected = '\n3 provider(s) found in database.\n0 skipped and 3 attempted.\n2 updated and 1 failed.\n'
|
||||
with self.assertRaisesRegex(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
assert expected in self.stdout.getvalue()
|
||||
|
||||
# Now add a fourth configuration, and indicate that it should not be included in the update
|
||||
self.__create_saml_configurations__(
|
||||
@@ -201,7 +201,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
expected = '\nDone.\n4 provider(s) found in database.\n1 skipped and 3 attempted.\n0 updated and 1 failed.\n'
|
||||
with self.assertRaisesRegex(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
assert expected in self.stdout.getvalue()
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_saml_request_exceptions(self, mocked_get):
|
||||
@@ -217,19 +217,19 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
|
||||
with self.assertRaisesRegex(CommandError, "SSLError:"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
assert expected in self.stdout.getvalue()
|
||||
|
||||
mocked_get.side_effect = exceptions.ConnectionError
|
||||
|
||||
with self.assertRaisesRegex(CommandError, "ConnectionError:"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
assert expected in self.stdout.getvalue()
|
||||
|
||||
mocked_get.side_effect = exceptions.HTTPError
|
||||
|
||||
with self.assertRaisesRegex(CommandError, "HTTPError:"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
assert expected in self.stdout.getvalue()
|
||||
|
||||
@mock.patch("requests.get", mock_get(status_code=200))
|
||||
def test_saml_parse_exceptions(self):
|
||||
@@ -254,7 +254,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
|
||||
with self.assertRaisesRegex(CommandError, "MetadataParseError: Can't find EntityDescriptor for entityID"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
assert expected in self.stdout.getvalue()
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_xml_parse_exceptions(self, mocked_get):
|
||||
@@ -274,4 +274,4 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
|
||||
with self.assertRaisesRegex(CommandError, "XMLSyntaxError:"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
assert expected in self.stdout.getvalue()
|
||||
|
||||
@@ -83,19 +83,19 @@ class SAMLConfigurationTests(APITestCase):
|
||||
def test_get_saml_configurations_successful(self):
|
||||
url = reverse('saml_configuration-list')
|
||||
response = self.client.get(url, format='json')
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# We ultimately just need ids and slugs, so let's just check those.
|
||||
results = response.data['results']
|
||||
self.assertEqual(results[0]['id'], SAML_CONFIGURATIONS[0]['site'])
|
||||
self.assertEqual(results[0]['slug'], SAML_CONFIGURATIONS[0]['slug'])
|
||||
self.assertEqual(results[1]['id'], SAML_CONFIGURATIONS[1]['site'])
|
||||
self.assertEqual(results[1]['slug'], SAML_CONFIGURATIONS[1]['slug'])
|
||||
assert results[0]['id'] == SAML_CONFIGURATIONS[0]['site']
|
||||
assert results[0]['slug'] == SAML_CONFIGURATIONS[0]['slug']
|
||||
assert results[1]['id'] == SAML_CONFIGURATIONS[1]['site']
|
||||
assert results[1]['slug'] == SAML_CONFIGURATIONS[1]['slug']
|
||||
|
||||
def test_get_saml_configurations_noprivate(self):
|
||||
# Verify we have 3 saml configuration objects: 2 public, 1 private.
|
||||
total_object_count = SAMLConfiguration.objects.count()
|
||||
self.assertEqual(total_object_count, 3)
|
||||
assert total_object_count == 3
|
||||
|
||||
url = reverse('saml_configuration-list')
|
||||
response = self.client.get(url, format='json')
|
||||
@@ -103,10 +103,10 @@ class SAMLConfigurationTests(APITestCase):
|
||||
# We should only see 2 results, since 1 out of 3 are private
|
||||
# and our queryset only returns public configurations.
|
||||
results = response.data['results']
|
||||
self.assertEqual(len(results), 2)
|
||||
assert len(results) == 2
|
||||
|
||||
def test_unauthenticated_user_get_saml_configurations(self):
|
||||
self.client.logout()
|
||||
url = reverse('saml_configuration-list')
|
||||
response = self.client.get(url, format='json')
|
||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@@ -91,13 +91,13 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
|
||||
response = self.client.get(url, format='json')
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
results = response.data['results']
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0]['entity_id'], SINGLE_PROVIDER_CONFIG['entity_id'])
|
||||
self.assertEqual(results[0]['metadata_source'], SINGLE_PROVIDER_CONFIG['metadata_source'])
|
||||
self.assertEqual(response.data['results'][0]['country'], SINGLE_PROVIDER_CONFIG['country'])
|
||||
self.assertEqual(SAMLProviderConfig.objects.count(), 1)
|
||||
assert len(results) == 1
|
||||
assert results[0]['entity_id'] == SINGLE_PROVIDER_CONFIG['entity_id']
|
||||
assert results[0]['metadata_source'] == SINGLE_PROVIDER_CONFIG['metadata_source']
|
||||
assert response.data['results'][0]['country'] == SINGLE_PROVIDER_CONFIG['country']
|
||||
assert SAMLProviderConfig.objects.count() == 1
|
||||
|
||||
def test_get_one_config_by_enterprise_uuid_invalid_uuid(self):
|
||||
"""
|
||||
@@ -109,7 +109,7 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
|
||||
response = self.client.get(url, format='json')
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_get_one_config_by_enterprise_uuid_not_found(self):
|
||||
"""
|
||||
@@ -128,8 +128,8 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
|
||||
response = self.client.get(url, format='json')
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
self.assertEqual(SAMLProviderConfig.objects.count(), orig_count)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert SAMLProviderConfig.objects.count() == orig_count
|
||||
|
||||
def test_create_one_config(self):
|
||||
"""
|
||||
@@ -142,19 +142,14 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
|
||||
response = self.client.post(url, data)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
|
||||
self.assertEqual(SAMLProviderConfig.objects.count(), orig_count + 1)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert SAMLProviderConfig.objects.count() == (orig_count + 1)
|
||||
provider_config = SAMLProviderConfig.objects.get(slug=SINGLE_PROVIDER_CONFIG_2['slug'])
|
||||
self.assertEqual(provider_config.name, 'name-of-config-2')
|
||||
self.assertEqual(provider_config.country, SINGLE_PROVIDER_CONFIG_2['country'])
|
||||
assert provider_config.name == 'name-of-config-2'
|
||||
assert provider_config.country == SINGLE_PROVIDER_CONFIG_2['country']
|
||||
|
||||
# check association has also been created
|
||||
self.assertTrue(
|
||||
EnterpriseCustomerIdentityProvider.objects.filter(
|
||||
provider_id=convert_saml_slug_provider_id(provider_config.slug)
|
||||
).exists(),
|
||||
'Cannot find EnterpriseCustomer-->SAMLProviderConfig association'
|
||||
)
|
||||
assert EnterpriseCustomerIdentityProvider.objects.filter(provider_id=convert_saml_slug_provider_id(provider_config.slug)).exists(), 'Cannot find EnterpriseCustomer-->SAMLProviderConfig association'
|
||||
|
||||
def test_create_one_config_fail_non_existent_enterprise_uuid(self):
|
||||
"""
|
||||
@@ -167,16 +162,11 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
|
||||
response = self.client.post(url, data)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
self.assertEqual(SAMLProviderConfig.objects.count(), orig_count)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert SAMLProviderConfig.objects.count() == orig_count
|
||||
|
||||
# check association has NOT been created
|
||||
self.assertFalse(
|
||||
EnterpriseCustomerIdentityProvider.objects.filter(
|
||||
provider_id=convert_saml_slug_provider_id(SINGLE_PROVIDER_CONFIG_2['slug'])
|
||||
).exists(),
|
||||
'Did not expect to find EnterpriseCustomer-->SAMLProviderConfig association'
|
||||
)
|
||||
assert not EnterpriseCustomerIdentityProvider.objects.filter(provider_id=convert_saml_slug_provider_id(SINGLE_PROVIDER_CONFIG_2['slug'])).exists(), 'Did not expect to find EnterpriseCustomer-->SAMLProviderConfig association'
|
||||
|
||||
def test_create_one_config_with_absent_enterprise_uuid(self):
|
||||
"""
|
||||
@@ -188,8 +178,8 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
|
||||
response = self.client.post(url, data)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(SAMLProviderConfig.objects.count(), orig_count)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert SAMLProviderConfig.objects.count() == orig_count
|
||||
|
||||
def test_create_one_config_with_no_country_urn(self):
|
||||
"""
|
||||
@@ -206,9 +196,9 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
}
|
||||
|
||||
response = self.client.post(url, provider_config_no_country)
|
||||
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
provider_config = SAMLProviderConfig.objects.get(slug='test-slug-none')
|
||||
self.assertEqual(provider_config.country, '')
|
||||
assert provider_config.country == ''
|
||||
|
||||
def test_create_one_config_with_empty_country_urn(self):
|
||||
"""
|
||||
@@ -226,9 +216,9 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
}
|
||||
|
||||
response = self.client.post(url, provider_config_blank_country)
|
||||
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
provider_config = SAMLProviderConfig.objects.get(slug='test-slug-empty')
|
||||
self.assertEqual(provider_config.country, '')
|
||||
assert provider_config.country == ''
|
||||
|
||||
def test_unauthenticated_request_is_forbidden(self):
|
||||
self.client.logout()
|
||||
@@ -237,12 +227,12 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
url = '{}?{}'.format(urlbase, urlencode(query_kwargs))
|
||||
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_LEARNER_ROLE, ENTERPRISE_ID)])
|
||||
response = self.client.get(url, format='json')
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
self.client.logout()
|
||||
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, ENTERPRISE_ID_NON_EXISTENT)])
|
||||
response = self.client.get(url, format='json')
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_create_one_config_with_samlconfiguration(self):
|
||||
"""
|
||||
@@ -255,6 +245,6 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
|
||||
response = self.client.post(url, data)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
provider_config = SAMLProviderConfig.objects.get(slug=SINGLE_PROVIDER_CONFIG_3['slug'])
|
||||
self.assertEqual(provider_config.saml_configuration, self.samlconfiguration)
|
||||
assert provider_config.saml_configuration == self.samlconfiguration
|
||||
|
||||
@@ -88,10 +88,10 @@ class SAMLProviderDataTests(APITestCase):
|
||||
|
||||
response = self.client.get(url, format='json')
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
results = response.data['results']
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0]['sso_url'], SINGLE_PROVIDER_DATA['sso_url'])
|
||||
assert len(results) == 1
|
||||
assert results[0]['sso_url'] == SINGLE_PROVIDER_DATA['sso_url']
|
||||
|
||||
def test_create_one_provider_data_success(self):
|
||||
# POST auth/saml/v0/providerdata/ -d data
|
||||
@@ -102,12 +102,9 @@ class SAMLProviderDataTests(APITestCase):
|
||||
|
||||
response = self.client.post(url, data)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
|
||||
self.assertEqual(SAMLProviderData.objects.count(), orig_count + 1)
|
||||
self.assertEqual(
|
||||
SAMLProviderData.objects.get(entity_id=SINGLE_PROVIDER_DATA_2['entity_id']).sso_url,
|
||||
SINGLE_PROVIDER_DATA_2['sso_url']
|
||||
)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert SAMLProviderData.objects.count() == (orig_count + 1)
|
||||
assert SAMLProviderData.objects.get(entity_id=SINGLE_PROVIDER_DATA_2['entity_id']).sso_url == SINGLE_PROVIDER_DATA_2['sso_url']
|
||||
|
||||
def test_create_one_data_with_absent_enterprise_uuid(self):
|
||||
"""
|
||||
@@ -119,8 +116,8 @@ class SAMLProviderDataTests(APITestCase):
|
||||
|
||||
response = self.client.post(url, data)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(SAMLProviderData.objects.count(), orig_count)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert SAMLProviderData.objects.count() == orig_count
|
||||
|
||||
def test_patch_one_provider_data(self):
|
||||
# PATCH auth/saml/v0/providerdata/ -d data
|
||||
@@ -133,14 +130,14 @@ class SAMLProviderDataTests(APITestCase):
|
||||
|
||||
response = self.client.patch(url, data)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(SAMLProviderData.objects.count(), orig_count)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert SAMLProviderData.objects.count() == orig_count
|
||||
|
||||
# ensure only the sso_url was updated
|
||||
fetched_provider_data = SAMLProviderData.objects.get(pk=self.saml_provider_data.id)
|
||||
self.assertEqual(fetched_provider_data.sso_url, 'http://new.url')
|
||||
self.assertEqual(fetched_provider_data.fetched_at, SINGLE_PROVIDER_DATA['fetched_at'])
|
||||
self.assertEqual(fetched_provider_data.entity_id, SINGLE_PROVIDER_DATA['entity_id'])
|
||||
assert fetched_provider_data.sso_url == 'http://new.url'
|
||||
assert fetched_provider_data.fetched_at == SINGLE_PROVIDER_DATA['fetched_at']
|
||||
assert fetched_provider_data.entity_id == SINGLE_PROVIDER_DATA['entity_id']
|
||||
|
||||
def test_delete_one_provider_data(self):
|
||||
# DELETE auth/saml/v0/providerdata/ -d data
|
||||
@@ -151,12 +148,12 @@ class SAMLProviderDataTests(APITestCase):
|
||||
|
||||
response = self.client.delete(url)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
|
||||
self.assertEqual(SAMLProviderData.objects.count(), orig_count - 1)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
assert SAMLProviderData.objects.count() == (orig_count - 1)
|
||||
|
||||
# ensure only the sso_url was updated
|
||||
query_set_count = SAMLProviderData.objects.filter(pk=self.saml_provider_data.id).count()
|
||||
self.assertEqual(query_set_count, 0)
|
||||
assert query_set_count == 0
|
||||
|
||||
def test_get_one_provider_data_failure(self):
|
||||
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, BAD_ENTERPRISE_ID)])
|
||||
@@ -167,7 +164,7 @@ class SAMLProviderDataTests(APITestCase):
|
||||
|
||||
response = self.client.get(url, format='json')
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_unauthenticated_request_is_forbidden(self):
|
||||
self.client.logout()
|
||||
@@ -176,10 +173,10 @@ class SAMLProviderDataTests(APITestCase):
|
||||
url = '{}?{}'.format(urlbase, urlencode(query_kwargs))
|
||||
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_LEARNER_ROLE, ENTERPRISE_ID)])
|
||||
response = self.client.get(url, format='json')
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
# manually running second case as DDT is having issues.
|
||||
self.client.logout()
|
||||
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, BAD_ENTERPRISE_ID)])
|
||||
response = self.client.get(url, format='json')
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@@ -6,6 +6,7 @@ Base integration test for provider implementations.
|
||||
import json
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
import pytest
|
||||
|
||||
import mock
|
||||
from django import test
|
||||
@@ -55,8 +56,8 @@ class HelperMixin(object):
|
||||
implementations may optionally strengthen this assertion with, for
|
||||
example, more details about the format of the Location header.
|
||||
"""
|
||||
self.assertEqual(302, response.status_code)
|
||||
self.assertTrue(response.has_header('Location'))
|
||||
assert 302 == response.status_code
|
||||
assert response.has_header('Location')
|
||||
|
||||
def assert_register_response_in_pipeline_looks_correct(self, response, pipeline_kwargs, required_fields): # lint-amnesty, pylint: disable=invalid-name
|
||||
"""Performs spot checks of the rendered register.html page.
|
||||
@@ -95,16 +96,16 @@ class HelperMixin(object):
|
||||
its connected state is correct.
|
||||
"""
|
||||
if duplicate:
|
||||
self.assertEqual(context['duplicate_provider'], self.provider.backend_name)
|
||||
assert context['duplicate_provider'] == self.provider.backend_name
|
||||
else:
|
||||
self.assertIsNone(context['duplicate_provider'])
|
||||
assert context['duplicate_provider'] is None
|
||||
|
||||
if linked is not None:
|
||||
expected_provider = [
|
||||
provider for provider in context['auth']['providers'] if provider['name'] == self.provider.name
|
||||
][0]
|
||||
self.assertIsNotNone(expected_provider)
|
||||
self.assertEqual(expected_provider['connected'], linked)
|
||||
assert expected_provider is not None
|
||||
assert expected_provider['connected'] == linked
|
||||
|
||||
def assert_exception_redirect_looks_correct(self, expected_uri, auth_entry=None):
|
||||
"""Tests middleware conditional redirection.
|
||||
@@ -118,45 +119,45 @@ class HelperMixin(object):
|
||||
request, exceptions.AuthCanceled(request.backend))
|
||||
location = response.get('Location')
|
||||
|
||||
self.assertEqual(302, response.status_code)
|
||||
self.assertIn('canceled', location)
|
||||
self.assertIn(self.backend_name, location)
|
||||
self.assertTrue(location.startswith(expected_uri + '?'))
|
||||
assert 302 == response.status_code
|
||||
assert 'canceled' in location
|
||||
assert self.backend_name in location
|
||||
assert location.startswith((expected_uri + '?'))
|
||||
|
||||
def assert_json_failure_response_is_inactive_account(self, response):
|
||||
"""Asserts failure on /login for inactive account looks right."""
|
||||
self.assertEqual(400, response.status_code)
|
||||
assert 400 == response.status_code
|
||||
payload = json.loads(response.content.decode('utf-8'))
|
||||
context = {
|
||||
'platformName': configuration_helpers.get_value('PLATFORM_NAME', settings.PLATFORM_NAME),
|
||||
'supportLink': configuration_helpers.get_value('SUPPORT_SITE_LINK', settings.SUPPORT_SITE_LINK)
|
||||
}
|
||||
|
||||
self.assertFalse(payload.get('success'))
|
||||
self.assertIn('inactive-user', payload.get('error_code'))
|
||||
self.assertEqual(context, payload.get('context'))
|
||||
assert not payload.get('success')
|
||||
assert 'inactive-user' in payload.get('error_code')
|
||||
assert context == payload.get('context')
|
||||
|
||||
def assert_json_failure_response_is_missing_social_auth(self, response):
|
||||
"""Asserts failure on /login for missing social auth looks right."""
|
||||
self.assertEqual(403, response.status_code)
|
||||
assert 403 == response.status_code
|
||||
payload = json.loads(response.content.decode('utf-8'))
|
||||
self.assertFalse(payload.get('success'))
|
||||
self.assertEqual(payload.get('error_code'), 'third-party-auth-with-no-linked-account')
|
||||
assert not payload.get('success')
|
||||
assert payload.get('error_code') == 'third-party-auth-with-no-linked-account'
|
||||
|
||||
def assert_json_failure_response_is_username_collision(self, response):
|
||||
"""Asserts the json response indicates a username collision."""
|
||||
self.assertEqual(409, response.status_code)
|
||||
assert 409 == response.status_code
|
||||
payload = json.loads(response.content.decode('utf-8'))
|
||||
self.assertFalse(payload.get('success'))
|
||||
self.assertIn('belongs to an existing account', payload['username'][0]['user_message'])
|
||||
assert not payload.get('success')
|
||||
assert 'belongs to an existing account' in payload['username'][0]['user_message']
|
||||
|
||||
def assert_json_success_response_looks_correct(self, response, verify_redirect_url):
|
||||
"""Asserts the json response indicates success and redirection."""
|
||||
self.assertEqual(200, response.status_code)
|
||||
assert 200 == response.status_code
|
||||
payload = json.loads(response.content.decode('utf-8'))
|
||||
self.assertTrue(payload.get('success'))
|
||||
assert payload.get('success')
|
||||
if verify_redirect_url:
|
||||
self.assertEqual(pipeline.get_complete_url(self.provider.backend_name), payload.get('redirect_url'))
|
||||
assert pipeline.get_complete_url(self.provider.backend_name) == payload.get('redirect_url')
|
||||
|
||||
def assert_login_response_before_pipeline_looks_correct(self, response):
|
||||
"""Asserts a GET of /login not in the pipeline looks correct."""
|
||||
@@ -167,37 +168,36 @@ class HelperMixin(object):
|
||||
|
||||
def assert_login_response_in_pipeline_looks_correct(self, response):
|
||||
"""Asserts a GET of /login in the pipeline looks correct."""
|
||||
self.assertEqual(200, response.status_code)
|
||||
assert 200 == response.status_code
|
||||
|
||||
def assert_password_overridden_by_pipeline(self, username, password):
|
||||
"""Verifies that the given password is not correct.
|
||||
|
||||
The pipeline overrides POST['password'], if any, with random data.
|
||||
"""
|
||||
self.assertIsNone(auth.authenticate(password=password, username=username))
|
||||
assert auth.authenticate(password=password, username=username) is None
|
||||
|
||||
def assert_pipeline_running(self, request):
|
||||
"""Makes sure the given request is running an auth pipeline."""
|
||||
self.assertTrue(pipeline.running(request))
|
||||
assert pipeline.running(request)
|
||||
|
||||
def assert_redirect_after_pipeline_completes(self, response, expected_redirect_url=None):
|
||||
"""Asserts a response would redirect to the expected_redirect_url or SOCIAL_AUTH_LOGIN_REDIRECT_URL."""
|
||||
self.assertEqual(302, response.status_code)
|
||||
assert 302 == response.status_code
|
||||
# NOTE: Ideally we should use assertRedirects(), however it errors out due to the hostname, testserver,
|
||||
# not being properly set. This may be an issue with the call made by PSA, but we are not certain.
|
||||
self.assertTrue(response.get('Location').endswith(
|
||||
expected_redirect_url or django_settings.SOCIAL_AUTH_LOGIN_REDIRECT_URL,
|
||||
))
|
||||
assert response.get('Location').endswith((expected_redirect_url or
|
||||
django_settings.SOCIAL_AUTH_LOGIN_REDIRECT_URL))
|
||||
|
||||
def assert_redirect_to_login_looks_correct(self, response):
|
||||
"""Asserts a response would redirect to /login."""
|
||||
self.assertEqual(302, response.status_code)
|
||||
self.assertEqual('/login', response.get('Location'))
|
||||
assert 302 == response.status_code
|
||||
assert '/login' == response.get('Location')
|
||||
|
||||
def assert_redirect_to_register_looks_correct(self, response):
|
||||
"""Asserts a response would redirect to /register."""
|
||||
self.assertEqual(302, response.status_code)
|
||||
self.assertEqual('/register', response.get('Location'))
|
||||
assert 302 == response.status_code
|
||||
assert '/register' == response.get('Location')
|
||||
|
||||
def assert_register_response_before_pipeline_looks_correct(self, response):
|
||||
"""Asserts a GET of /register not in the pipeline looks correct."""
|
||||
@@ -210,24 +210,21 @@ class HelperMixin(object):
|
||||
"""Asserts a user does not have an auth with the expected provider."""
|
||||
social_auths = strategy.storage.user.get_social_auth_for_user(
|
||||
user, provider=self.provider.backend_name)
|
||||
self.assertEqual(0, len(social_auths))
|
||||
assert 0 == len(social_auths)
|
||||
|
||||
def assert_social_auth_exists_for_user(self, user, strategy):
|
||||
"""Asserts a user has a social auth with the expected provider."""
|
||||
social_auths = strategy.storage.user.get_social_auth_for_user(
|
||||
user, provider=self.provider.backend_name)
|
||||
self.assertEqual(1, len(social_auths))
|
||||
self.assertEqual(self.backend_name, social_auths[0].provider)
|
||||
assert 1 == len(social_auths)
|
||||
assert self.backend_name == social_auths[0].provider
|
||||
|
||||
def assert_logged_in_cookie_redirect(self, response):
|
||||
"""Verify that the user was redirected in order to set the logged in cookie. """
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(
|
||||
response["Location"],
|
||||
pipeline.get_complete_url(self.provider.backend_name)
|
||||
)
|
||||
self.assertEqual(response.cookies[django_settings.EDXMKTG_LOGGED_IN_COOKIE_NAME].value, 'true')
|
||||
self.assertIn(django_settings.EDXMKTG_USER_INFO_COOKIE_NAME, response.cookies)
|
||||
assert response.status_code == 302
|
||||
assert response['Location'] == pipeline.get_complete_url(self.provider.backend_name)
|
||||
assert response.cookies[django_settings.EDXMKTG_LOGGED_IN_COOKIE_NAME].value == 'true'
|
||||
assert django_settings.EDXMKTG_USER_INFO_COOKIE_NAME in response.cookies
|
||||
|
||||
@property
|
||||
def backend_name(self):
|
||||
@@ -384,24 +381,24 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin):
|
||||
# The user clicks on the Dummy button:
|
||||
try_login_response = self.client.get(provider_register_url)
|
||||
# The user should be redirected to the provider's login page:
|
||||
self.assertEqual(try_login_response.status_code, 302)
|
||||
assert try_login_response.status_code == 302
|
||||
provider_response = self.do_provider_login(try_login_response['Location'])
|
||||
# We should be redirected to the register screen since this account is not linked to an edX account:
|
||||
self.assertEqual(provider_response.status_code, 302)
|
||||
self.assertEqual(provider_response['Location'], self.register_page_url)
|
||||
assert provider_response.status_code == 302
|
||||
assert provider_response['Location'] == self.register_page_url
|
||||
register_response = self.client.get(self.register_page_url)
|
||||
tpa_context = register_response.context["data"]["third_party_auth"]
|
||||
self.assertEqual(tpa_context["errorMessage"], None)
|
||||
assert tpa_context['errorMessage'] is None
|
||||
# Check that the "You've successfully signed into [PROVIDER_NAME]" message is shown.
|
||||
self.assertEqual(tpa_context["currentProvider"], self.PROVIDER_NAME)
|
||||
assert tpa_context['currentProvider'] == self.PROVIDER_NAME
|
||||
# Check that the data (e.g. email) from the provider is displayed in the form:
|
||||
form_data = register_response.context['data']['registration_form_desc']
|
||||
form_fields = {field['name']: field for field in form_data['fields']}
|
||||
self.assertEqual(form_fields['email']['defaultValue'], self.USER_EMAIL)
|
||||
self.assertEqual(form_fields['name']['defaultValue'], self.USER_NAME)
|
||||
self.assertEqual(form_fields['username']['defaultValue'], self.USER_USERNAME)
|
||||
assert form_fields['email']['defaultValue'] == self.USER_EMAIL
|
||||
assert form_fields['name']['defaultValue'] == self.USER_NAME
|
||||
assert form_fields['username']['defaultValue'] == self.USER_USERNAME
|
||||
for field_name, value in extra_defaults.items():
|
||||
self.assertEqual(form_fields[field_name]['defaultValue'], value)
|
||||
assert form_fields[field_name]['defaultValue'] == value
|
||||
registration_values = {
|
||||
'email': 'email-edited@tpa-test.none',
|
||||
'name': 'My Customized Name',
|
||||
@@ -413,12 +410,12 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin):
|
||||
reverse('user_api_registration'),
|
||||
registration_values
|
||||
)
|
||||
self.assertEqual(ajax_register_response.status_code, 200)
|
||||
assert ajax_register_response.status_code == 200
|
||||
# Then the AJAX will finish the third party auth:
|
||||
continue_response = self.client.get(tpa_context["finishAuthUrl"])
|
||||
# And we should be redirected to the dashboard:
|
||||
self.assertEqual(continue_response.status_code, 302)
|
||||
self.assertEqual(continue_response['Location'], reverse('dashboard'))
|
||||
assert continue_response.status_code == 302
|
||||
assert continue_response['Location'] == reverse('dashboard')
|
||||
|
||||
# Now check that we can login again, whether or not we have yet verified the account:
|
||||
self.client.logout()
|
||||
@@ -437,28 +434,28 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin):
|
||||
# The user clicks on the provider's button:
|
||||
try_login_response = self.client.get(provider_login_url)
|
||||
# The user should be redirected to the provider's login page:
|
||||
self.assertEqual(try_login_response.status_code, 302)
|
||||
assert try_login_response.status_code == 302
|
||||
complete_response = self.do_provider_login(try_login_response['Location'])
|
||||
# We should be redirected to the login screen since this account is not linked to an edX account:
|
||||
self.assertEqual(complete_response.status_code, 302)
|
||||
self.assertEqual(complete_response['Location'], self.login_page_url)
|
||||
assert complete_response.status_code == 302
|
||||
assert complete_response['Location'] == self.login_page_url
|
||||
login_response = self.client.get(self.login_page_url)
|
||||
tpa_context = login_response.context["data"]["third_party_auth"]
|
||||
self.assertEqual(tpa_context["errorMessage"], None)
|
||||
assert tpa_context['errorMessage'] is None
|
||||
# Check that the "You've successfully signed into [PROVIDER_NAME]" message is shown.
|
||||
self.assertEqual(tpa_context["currentProvider"], self.PROVIDER_NAME)
|
||||
assert tpa_context['currentProvider'] == self.PROVIDER_NAME
|
||||
# Now the user enters their username and password.
|
||||
# The AJAX on the page will log them in:
|
||||
ajax_login_response = self.client.post(
|
||||
reverse('user_api_login_session'),
|
||||
{'email': self.user.email, 'password': 'test'}
|
||||
)
|
||||
self.assertEqual(ajax_login_response.status_code, 200)
|
||||
assert ajax_login_response.status_code == 200
|
||||
# Then the AJAX will finish the third party auth:
|
||||
continue_response = self.client.get(tpa_context["finishAuthUrl"])
|
||||
# And we should be redirected to the dashboard:
|
||||
self.assertEqual(continue_response.status_code, 302)
|
||||
self.assertEqual(continue_response['Location'], reverse('dashboard'))
|
||||
assert continue_response.status_code == 302
|
||||
assert continue_response['Location'] == reverse('dashboard')
|
||||
|
||||
# Now check that we can login again:
|
||||
self.client.logout()
|
||||
@@ -475,30 +472,30 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin):
|
||||
""" Test logging in to an account that is already linked. """
|
||||
# Make sure we're not logged in:
|
||||
dashboard_response = self.client.get(reverse('dashboard'))
|
||||
self.assertEqual(dashboard_response.status_code, 302)
|
||||
assert dashboard_response.status_code == 302
|
||||
# The user goes to the login page, and sees a button to login with this provider:
|
||||
provider_login_url = self._check_login_page()
|
||||
# The user clicks on the provider's login button:
|
||||
try_login_response = self.client.get(provider_login_url)
|
||||
# The user should be redirected to the provider:
|
||||
self.assertEqual(try_login_response.status_code, 302)
|
||||
assert try_login_response.status_code == 302
|
||||
login_response = self.do_provider_login(try_login_response['Location'])
|
||||
# If the previous session was manually logged out, there will be one weird redirect
|
||||
# required to set the login cookie (it sticks around if the main session times out):
|
||||
if not previous_session_timed_out:
|
||||
self.assertEqual(login_response.status_code, 302)
|
||||
self.assertEqual(login_response['Location'], self.complete_url + "?")
|
||||
assert login_response.status_code == 302
|
||||
assert login_response['Location'] == (self.complete_url + '?')
|
||||
# And then we should be redirected to the dashboard:
|
||||
login_response = self.client.get(login_response['Location'])
|
||||
self.assertEqual(login_response.status_code, 302)
|
||||
assert login_response.status_code == 302
|
||||
if user_is_activated:
|
||||
url_expected = reverse('dashboard')
|
||||
else:
|
||||
url_expected = reverse('third_party_inactive_redirect') + '?next=' + reverse('dashboard')
|
||||
self.assertEqual(login_response['Location'], url_expected)
|
||||
assert login_response['Location'] == url_expected
|
||||
# Now we are logged in:
|
||||
dashboard_response = self.client.get(reverse('dashboard'))
|
||||
self.assertEqual(dashboard_response.status_code, 200)
|
||||
assert dashboard_response.status_code == 200
|
||||
|
||||
def _check_login_page(self):
|
||||
"""
|
||||
@@ -520,7 +517,7 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin):
|
||||
self.assertContains(response, self.PROVIDER_NAME)
|
||||
context_data = response.context['data']['third_party_auth']
|
||||
provider_urls = {provider['id']: provider[url_to_return] for provider in context_data['providers']}
|
||||
self.assertIn(self.PROVIDER_ID, provider_urls)
|
||||
assert self.PROVIDER_ID in provider_urls
|
||||
return provider_urls[self.PROVIDER_ID]
|
||||
|
||||
@property
|
||||
@@ -668,7 +665,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin):
|
||||
self.assert_social_auth_exists_for_user(linked_user, strategy)
|
||||
self.assert_social_auth_does_not_exist_for_user(unlinked_user, strategy)
|
||||
|
||||
with self.assertRaises(exceptions.AuthAlreadyAssociated):
|
||||
with pytest.raises(exceptions.AuthAlreadyAssociated):
|
||||
# pylint: disable=protected-access
|
||||
actions.do_complete(backend, social_views._do_login, user=unlinked_user, request=strategy.request)
|
||||
|
||||
@@ -720,7 +717,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin):
|
||||
partial_data = strategy.storage.partial.load(partial_pipeline_token)
|
||||
|
||||
self.assert_social_auth_exists_for_user(user, strategy)
|
||||
self.assertTrue(user.is_active)
|
||||
assert user.is_active
|
||||
|
||||
# Begin! Ensure that the login form contains expected controls before
|
||||
# the user starts the pipeline.
|
||||
@@ -862,7 +859,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin):
|
||||
strategy.request.POST = self.get_registration_post_vars({'email': email})
|
||||
|
||||
# The user must not exist yet...
|
||||
with self.assertRaises(auth_models.User.DoesNotExist):
|
||||
with pytest.raises(auth_models.User.DoesNotExist):
|
||||
self.get_user_by_email(strategy, email)
|
||||
|
||||
# ...but when we invoke create_account the existing edX view will make
|
||||
@@ -906,7 +903,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin):
|
||||
self.assert_redirect_to_login_looks_correct(actions.do_complete(backend, social_views._do_login,
|
||||
request=request))
|
||||
distinct_username = pipeline.get(request)['kwargs']['username']
|
||||
self.assertNotEqual(original_username, distinct_username)
|
||||
assert original_username != distinct_username
|
||||
|
||||
def test_new_account_registration_fails_if_email_exists(self):
|
||||
request, strategy = self.get_request_and_strategy(
|
||||
@@ -932,17 +929,17 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin):
|
||||
|
||||
def test_pipeline_raises_auth_entry_error_if_auth_entry_invalid(self):
|
||||
auth_entry = 'invalid'
|
||||
self.assertNotIn(auth_entry, pipeline._AUTH_ENTRY_CHOICES) # pylint: disable=protected-access
|
||||
assert auth_entry not in pipeline._AUTH_ENTRY_CHOICES # pylint: disable=protected-access
|
||||
|
||||
_, strategy = self.get_request_and_strategy(auth_entry=auth_entry, redirect_uri='social:complete')
|
||||
|
||||
with self.assertRaises(pipeline.AuthEntryError):
|
||||
with pytest.raises(pipeline.AuthEntryError):
|
||||
strategy.request.backend.auth_complete = mock.MagicMock(return_value=self.fake_auth_complete(strategy))
|
||||
|
||||
def test_pipeline_assumes_login_if_auth_entry_missing(self):
|
||||
_, strategy = self.get_request_and_strategy(auth_entry=None, redirect_uri='social:complete')
|
||||
response = self.fake_auth_complete(strategy)
|
||||
self.assertEqual(response.url, reverse('signin_user'))
|
||||
assert response.url == reverse('signin_user')
|
||||
|
||||
def assert_first_party_auth_trumps_third_party_auth(self, email=None, password=None, success=None):
|
||||
"""Asserts first party auth was used in place of third party auth.
|
||||
@@ -974,15 +971,15 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin):
|
||||
|
||||
if success is None:
|
||||
# Request malformed -- just one of email/password given.
|
||||
self.assertFalse(payload.get('success'))
|
||||
self.assertIn('There was an error receiving your login information', payload.get('value'))
|
||||
assert not payload.get('success')
|
||||
assert 'There was an error receiving your login information' in payload.get('value')
|
||||
elif success:
|
||||
# Request well-formed and credentials good.
|
||||
self.assertTrue(payload.get('success'))
|
||||
assert payload.get('success')
|
||||
else:
|
||||
# Request well-formed but credentials bad.
|
||||
self.assertFalse(payload.get('success'))
|
||||
self.assertIn('incorrect', payload.get('value'))
|
||||
assert not payload.get('success')
|
||||
assert 'incorrect' in payload.get('value')
|
||||
|
||||
def get_response_data(self):
|
||||
"""Gets a dict of response data of the form given by the provider.
|
||||
|
||||
@@ -31,5 +31,5 @@ class GenericIntegrationTest(IntegrationTestMixin, testutil.TestCase):
|
||||
Mock logging in to the Dummy provider
|
||||
"""
|
||||
# For the Dummy provider, the provider redirect URL is self.complete_url
|
||||
self.assertEqual(provider_redirect_url, self.url_prefix + self.complete_url)
|
||||
assert provider_redirect_url == (self.url_prefix + self.complete_url)
|
||||
return self.client.get(provider_redirect_url)
|
||||
|
||||
@@ -51,7 +51,7 @@ class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest): # lint-amnesty,
|
||||
|
||||
def assert_redirect_to_provider_looks_correct(self, response):
|
||||
super(GoogleOauth2IntegrationTest, self).assert_redirect_to_provider_looks_correct(response) # lint-amnesty, pylint: disable=super-with-arguments
|
||||
self.assertIn('google.com', response['Location'])
|
||||
assert 'google.com' in response['Location']
|
||||
|
||||
def test_custom_form(self):
|
||||
"""
|
||||
@@ -75,27 +75,20 @@ class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest): # lint-amnesty,
|
||||
with patch.object(self.provider.backend_class, 'auth_complete', fake_auth_complete):
|
||||
response = self.client.get(complete_url)
|
||||
# This should redirect to the custom login/register form:
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response['Location'], '/auth/custom_auth_entry')
|
||||
assert response.status_code == 302
|
||||
assert response['Location'] == '/auth/custom_auth_entry'
|
||||
|
||||
response = self.client.get(response['Location'])
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn('action="/misc/my-custom-registration-form" method="post"', response.content.decode('utf-8'))
|
||||
assert response.status_code == 200
|
||||
assert 'action="/misc/my-custom-registration-form" method="post"' in response.content.decode('utf-8')
|
||||
data_decoded = base64.b64decode(response.context['data']).decode('utf-8')
|
||||
data_parsed = json.loads(data_decoded)
|
||||
# The user's details get passed to the custom page as a base64 encoded query parameter:
|
||||
self.assertEqual(data_parsed, {
|
||||
'auth_entry': 'custom1',
|
||||
'backend_name': 'google-oauth2',
|
||||
'provider_id': 'oa2-google-oauth2',
|
||||
'user_details': {
|
||||
'username': 'user',
|
||||
'email': 'user@email.com',
|
||||
'fullname': 'name_value',
|
||||
'first_name': 'given_name_value',
|
||||
'last_name': 'family_name_value',
|
||||
},
|
||||
})
|
||||
assert data_parsed == {'auth_entry': 'custom1', 'backend_name': 'google-oauth2',
|
||||
'provider_id': 'oa2-google-oauth2',
|
||||
'user_details': {'username': 'user', 'email': 'user@email.com',
|
||||
'fullname': 'name_value', 'first_name': 'given_name_value',
|
||||
'last_name': 'family_name_value'}}
|
||||
# Check the hash that is used to confirm the user's data in the GET parameter is correct
|
||||
secret_key = settings.THIRD_PARTY_AUTH_CUSTOM_AUTH_FORMS['custom1']['secret_key']
|
||||
hmac_expected = hmac.new(
|
||||
@@ -103,18 +96,18 @@ class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest): # lint-amnesty,
|
||||
msg=data_decoded.encode('utf-8'),
|
||||
digestmod=hashlib.sha256
|
||||
).digest()
|
||||
self.assertEqual(base64.b64decode(response.context['hmac']), hmac_expected)
|
||||
assert base64.b64decode(response.context['hmac']) == hmac_expected
|
||||
|
||||
# Now our custom registration form creates or logs in the user:
|
||||
email, password = data_parsed['user_details']['email'], 'random_password'
|
||||
created_user = UserFactory(email=email, password=password)
|
||||
login_response = self.client.post(reverse('login_api'), {'email': email, 'password': password})
|
||||
self.assertEqual(login_response.status_code, 200)
|
||||
assert login_response.status_code == 200
|
||||
|
||||
# Now our custom login/registration page must resume the pipeline:
|
||||
response = self.client.get(complete_url)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response['Location'], '/misc/final-destination')
|
||||
assert response.status_code == 302
|
||||
assert response['Location'] == '/misc/final-destination'
|
||||
|
||||
_, strategy = self.get_request_and_strategy()
|
||||
self.assert_social_auth_exists_for_user(created_user, strategy)
|
||||
@@ -140,5 +133,5 @@ class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest): # lint-amnesty,
|
||||
with patch.object(self.provider.backend_class, 'auth_complete', fake_auth_complete_error):
|
||||
response = self.client.get(complete_url)
|
||||
# This should redirect to the custom error URL
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response['Location'], '/misc/my-custom-sso-error-page')
|
||||
assert response.status_code == 302
|
||||
assert response['Location'] == '/misc/my-custom-sso-error-page'
|
||||
|
||||
@@ -70,8 +70,8 @@ class IntegrationTestLTI(testutil.TestCase):
|
||||
)
|
||||
login_response = self.client.post(path=uri, content_type=FORM_ENCODED, data=body)
|
||||
# The user should be redirected to the registration form
|
||||
self.assertEqual(login_response.status_code, 302)
|
||||
self.assertTrue(login_response['Location'].endswith(reverse('signin_user')))
|
||||
assert login_response.status_code == 302
|
||||
assert login_response['Location'].endswith(reverse('signin_user'))
|
||||
register_response = self.client.get(login_response['Location'])
|
||||
self.assertContains(register_response, '"currentProvider": "LTI Test Tool Consumer"')
|
||||
self.assertContains(register_response, '"errorMessage": null')
|
||||
@@ -86,15 +86,12 @@ class IntegrationTestLTI(testutil.TestCase):
|
||||
'honor_code': True,
|
||||
}
|
||||
)
|
||||
self.assertEqual(ajax_register_response.status_code, 200)
|
||||
assert ajax_register_response.status_code == 200
|
||||
continue_response = self.client.get(self.url_prefix + LTI_TPA_COMPLETE_URL)
|
||||
# The user should be redirected to the finish_auth view which will enroll them.
|
||||
# FinishAuthView.js reads the URL parameters directly from $.url
|
||||
self.assertEqual(continue_response.status_code, 302)
|
||||
self.assertEqual(
|
||||
continue_response['Location'],
|
||||
'/account/finish_auth/?course_id=my_course_id&enrollment_action=enroll'
|
||||
)
|
||||
assert continue_response.status_code == 302
|
||||
assert continue_response['Location'] == '/account/finish_auth/?course_id=my_course_id&enrollment_action=enroll'
|
||||
|
||||
# Now check that we can login again
|
||||
self.client.logout()
|
||||
@@ -106,19 +103,20 @@ class IntegrationTestLTI(testutil.TestCase):
|
||||
)
|
||||
login_2_response = self.client.post(path=uri, content_type=FORM_ENCODED, data=body)
|
||||
# The user should be redirected to the dashboard
|
||||
self.assertEqual(login_2_response.status_code, 302)
|
||||
self.assertEqual(login_2_response['Location'], LTI_TPA_COMPLETE_URL + "?")
|
||||
assert login_2_response.status_code == 302
|
||||
assert login_2_response['Location'] == (LTI_TPA_COMPLETE_URL + '?')
|
||||
continue_2_response = self.client.get(login_2_response['Location'])
|
||||
self.assertEqual(continue_2_response.status_code, 302)
|
||||
self.assertTrue(continue_2_response['Location'].endswith(reverse('dashboard')))
|
||||
assert continue_2_response.status_code == 302
|
||||
assert continue_2_response['Location'].endswith(reverse('dashboard'))
|
||||
|
||||
# Check that the user was created correctly
|
||||
user = User.objects.get(email=EMAIL)
|
||||
self.assertEqual(user.username, EDX_USER_ID)
|
||||
assert user.username == EDX_USER_ID
|
||||
|
||||
def test_reject_initiating_login(self):
|
||||
response = self.client.get(self.url_prefix + LTI_TPA_LOGIN_URL)
|
||||
self.assertEqual(response.status_code, 405) # Not Allowed
|
||||
assert response.status_code == 405
|
||||
# Not Allowed
|
||||
|
||||
def test_reject_bad_login(self):
|
||||
login_response = self.client.post(
|
||||
@@ -127,8 +125,8 @@ class IntegrationTestLTI(testutil.TestCase):
|
||||
)
|
||||
# The user should be redirected to the login page with an error message
|
||||
# (auth_entry defaults to login for this provider)
|
||||
self.assertEqual(login_response.status_code, 302)
|
||||
self.assertTrue(login_response['Location'].endswith(reverse('signin_user')))
|
||||
assert login_response.status_code == 302
|
||||
assert login_response['Location'].endswith(reverse('signin_user'))
|
||||
error_response = self.client.get(login_response['Location'])
|
||||
self.assertContains(
|
||||
error_response,
|
||||
@@ -152,8 +150,8 @@ class IntegrationTestLTI(testutil.TestCase):
|
||||
with self.settings(SOCIAL_AUTH_LTI_CONSUMER_SECRETS={OTHER_LTI_CONSUMER_KEY: OTHER_LTI_CONSUMER_SECRET}):
|
||||
login_response = self.client.post(path=uri, content_type=FORM_ENCODED, data=body)
|
||||
# The user should be redirected to the registration form
|
||||
self.assertEqual(login_response.status_code, 302)
|
||||
self.assertTrue(login_response['Location'].endswith(reverse('signin_user')))
|
||||
assert login_response.status_code == 302
|
||||
assert login_response['Location'].endswith(reverse('signin_user'))
|
||||
register_response = self.client.get(login_response['Location'])
|
||||
self.assertContains(
|
||||
register_response,
|
||||
|
||||
@@ -114,21 +114,21 @@ class SamlIntegrationTestUtilities(object):
|
||||
saml_provider = self.configure_saml_provider(**kwargs) # pylint: disable=no-member
|
||||
|
||||
if fetch_metadata:
|
||||
self.assertTrue(httpretty.is_enabled()) # lint-amnesty, pylint: disable=no-member
|
||||
assert httpretty.is_enabled() # lint-amnesty, pylint: disable=no-member
|
||||
num_total, num_skipped, num_attempted, num_updated, num_failed, failure_messages = fetch_saml_metadata()
|
||||
if assert_metadata_updates:
|
||||
self.assertEqual(num_total, 1) # lint-amnesty, pylint: disable=no-member
|
||||
self.assertEqual(num_skipped, 0) # lint-amnesty, pylint: disable=no-member
|
||||
self.assertEqual(num_attempted, 1) # lint-amnesty, pylint: disable=no-member
|
||||
self.assertEqual(num_updated, 1) # lint-amnesty, pylint: disable=no-member
|
||||
self.assertEqual(num_failed, 0) # lint-amnesty, pylint: disable=no-member
|
||||
self.assertEqual(len(failure_messages), 0) # lint-amnesty, pylint: disable=no-member
|
||||
assert num_total == 1 # lint-amnesty, pylint: disable=no-member
|
||||
assert num_skipped == 0 # lint-amnesty, pylint: disable=no-member
|
||||
assert num_attempted == 1 # lint-amnesty, pylint: disable=no-member
|
||||
assert num_updated == 1 # lint-amnesty, pylint: disable=no-member
|
||||
assert num_failed == 0 # lint-amnesty, pylint: disable=no-member
|
||||
assert len(failure_messages) == 0 # lint-amnesty, pylint: disable=no-member
|
||||
return saml_provider
|
||||
|
||||
def do_provider_login(self, provider_redirect_url):
|
||||
""" Mocked: the user logs in to TestShib and then gets redirected back """
|
||||
# The SAML provider (TestShib) will authenticate the user, then get the browser to POST a response:
|
||||
self.assertTrue(provider_redirect_url.startswith(TESTSHIB_SSO_URL)) # lint-amnesty, pylint: disable=no-member
|
||||
assert provider_redirect_url.startswith(TESTSHIB_SSO_URL) # lint-amnesty, pylint: disable=no-member
|
||||
|
||||
saml_response_xml = utils.read_and_pre_process_xml(
|
||||
os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data', 'testshib_saml_response.xml')
|
||||
@@ -188,9 +188,8 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
|
||||
enterprise_customer = EnterpriseCustomerFactory()
|
||||
assert EnterpriseCustomerUser.objects.count() == 0, "Precondition check: no link records should exist"
|
||||
EnterpriseCustomerUser.objects.link_user(enterprise_customer, user.email)
|
||||
self.assertTrue( # lint-amnesty, pylint: disable=wrong-assert-type
|
||||
EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 1
|
||||
)
|
||||
assert (EnterpriseCustomerUser.objects
|
||||
.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 1)
|
||||
EnterpriseCustomerIdentityProvider.objects.get_or_create(enterprise_customer=enterprise_customer,
|
||||
provider_id=self.provider.provider_id)
|
||||
|
||||
@@ -229,10 +228,8 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
|
||||
redirect_field_name=auth.REDIRECT_FIELD_NAME,
|
||||
request=request
|
||||
)
|
||||
self.assertNotEqual(
|
||||
EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count(),
|
||||
0
|
||||
)
|
||||
assert EnterpriseCustomerUser.objects\
|
||||
.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() != 0
|
||||
|
||||
# Fire off the disconnect pipeline to unlink.
|
||||
self.assert_redirect_after_pipeline_completes(
|
||||
@@ -247,10 +244,8 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
|
||||
# Now we expect to be in the unlinked state, with no backend entry.
|
||||
self.assert_account_settings_context_looks_correct(account_settings_context(request), linked=False)
|
||||
self.assert_social_auth_does_not_exist_for_user(user, strategy)
|
||||
self.assertEqual(
|
||||
EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count(),
|
||||
0
|
||||
)
|
||||
assert EnterpriseCustomerUser.objects\
|
||||
.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 0
|
||||
|
||||
def get_response_data(self):
|
||||
"""Gets dict (string -> object) of merged data about the user."""
|
||||
@@ -269,8 +264,8 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
|
||||
# The user clicks on the TestShib button:
|
||||
try_login_response = self.client.get(testshib_login_url)
|
||||
# The user should be redirected to back to the login page:
|
||||
self.assertEqual(try_login_response.status_code, 302)
|
||||
self.assertEqual(try_login_response['Location'], self.login_page_url)
|
||||
assert try_login_response.status_code == 302
|
||||
assert try_login_response['Location'] == self.login_page_url
|
||||
# When loading the login page, the user will see an error message:
|
||||
response = self.client.get(self.login_page_url)
|
||||
self.assertContains(response, 'Authentication with TestShib is currently unavailable.')
|
||||
@@ -294,12 +289,11 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
|
||||
user=self.user, provider=self.PROVIDER_BACKEND, uid__startswith=self.PROVIDER_IDP_SLUG
|
||||
)
|
||||
attributes = record.extra_data
|
||||
self.assertEqual(
|
||||
attributes.get("urn:oid:1.3.6.1.4.1.5923.1.1.1.9"), ["Member@testshib.org", "Staff@testshib.org"]
|
||||
)
|
||||
self.assertEqual(attributes.get("urn:oid:2.5.4.3"), ["Me Myself And I"])
|
||||
self.assertEqual(attributes.get("urn:oid:0.9.2342.19200300.100.1.1"), ["myself"])
|
||||
self.assertEqual(attributes.get("urn:oid:2.5.4.20"), ["555-5555"]) # Phone number
|
||||
assert attributes.get('urn:oid:1.3.6.1.4.1.5923.1.1.1.9') == ['Member@testshib.org', 'Staff@testshib.org']
|
||||
assert attributes.get('urn:oid:2.5.4.3') == ['Me Myself And I']
|
||||
assert attributes.get('urn:oid:0.9.2342.19200300.100.1.1') == ['myself']
|
||||
assert attributes.get('urn:oid:2.5.4.20') == ['555-5555']
|
||||
# Phone number
|
||||
|
||||
@ddt.data(True, False)
|
||||
def test_debug_mode_login(self, debug_mode_enabled):
|
||||
@@ -310,30 +304,30 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
|
||||
if debug_mode_enabled:
|
||||
# We expect that test_login() does two full logins, and each attempt generates two
|
||||
# logs - one for the request and one for the response
|
||||
self.assertEqual(mock_log.call_count, 4)
|
||||
assert mock_log.call_count == 4
|
||||
|
||||
expected_next_url = "/dashboard"
|
||||
(msg, action_type, idp_name, request_data, next_url, xml), _kwargs = mock_log.call_args_list[0]
|
||||
self.assertTrue(msg.startswith(u"SAML login %s"))
|
||||
self.assertEqual(action_type, "request")
|
||||
self.assertEqual(idp_name, self.PROVIDER_IDP_SLUG)
|
||||
assert msg.startswith(u'SAML login %s')
|
||||
assert action_type == 'request'
|
||||
assert idp_name == self.PROVIDER_IDP_SLUG
|
||||
self.assertDictContainsSubset(
|
||||
{"idp": idp_name, "auth_entry": "login", "next": expected_next_url},
|
||||
request_data
|
||||
)
|
||||
self.assertEqual(next_url, expected_next_url)
|
||||
self.assertIn('<samlp:AuthnRequest', xml)
|
||||
assert next_url == expected_next_url
|
||||
assert '<samlp:AuthnRequest' in xml
|
||||
|
||||
(msg, action_type, idp_name, response_data, next_url, xml), _kwargs = mock_log.call_args_list[1]
|
||||
self.assertTrue(msg.startswith(u"SAML login %s"))
|
||||
self.assertEqual(action_type, "response")
|
||||
self.assertEqual(idp_name, self.PROVIDER_IDP_SLUG)
|
||||
assert msg.startswith(u'SAML login %s')
|
||||
assert action_type == 'response'
|
||||
assert idp_name == self.PROVIDER_IDP_SLUG
|
||||
self.assertDictContainsSubset({"RelayState": idp_name}, response_data)
|
||||
self.assertIn('SAMLResponse', response_data)
|
||||
self.assertEqual(next_url, expected_next_url)
|
||||
self.assertIn('<saml2p:Response', xml)
|
||||
assert 'SAMLResponse' in response_data
|
||||
assert next_url == expected_next_url
|
||||
assert '<saml2p:Response' in xml
|
||||
else:
|
||||
self.assertFalse(mock_log.called)
|
||||
assert not mock_log.called
|
||||
|
||||
def test_configure_testshib_provider_with_cache_duration(self):
|
||||
""" Enable and configure the TestShib SAML IdP as a third_party_auth provider """
|
||||
@@ -347,14 +341,14 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
|
||||
kwargs.setdefault('icon_class', 'fa-university')
|
||||
kwargs.setdefault('attr_email', 'urn:oid:1.3.6.1.4.1.5923.1.1.1.6') # eduPersonPrincipalName
|
||||
self.configure_saml_provider(**kwargs)
|
||||
self.assertTrue(httpretty.is_enabled())
|
||||
assert httpretty.is_enabled()
|
||||
num_total, num_skipped, num_attempted, num_updated, num_failed, failure_messages = fetch_saml_metadata()
|
||||
self.assertEqual(num_total, 1)
|
||||
self.assertEqual(num_skipped, 0)
|
||||
self.assertEqual(num_attempted, 1)
|
||||
self.assertEqual(num_updated, 1)
|
||||
self.assertEqual(num_failed, 0)
|
||||
self.assertEqual(len(failure_messages), 0)
|
||||
assert num_total == 1
|
||||
assert num_skipped == 0
|
||||
assert num_attempted == 1
|
||||
assert num_updated == 1
|
||||
assert num_failed == 0
|
||||
assert len(failure_messages) == 0
|
||||
|
||||
def test_login_with_testshib_provider_short_session_length(self):
|
||||
"""
|
||||
@@ -403,10 +397,10 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes
|
||||
"""
|
||||
Return a fake assertion after checking that the input is what we expect.
|
||||
"""
|
||||
self.assertIn(b'private_key=fake_private_key_here', _request.body)
|
||||
self.assertIn(b'user_id=myself', _request.body)
|
||||
self.assertIn(b'token_url=http%3A%2F%2Fsuccessfactors.com%2Foauth%2Ftoken', _request.body)
|
||||
self.assertIn(b'client_id=TatVotSEiCMteSNWtSOnLanCtBGwNhGB', _request.body)
|
||||
assert b'private_key=fake_private_key_here' in _request.body
|
||||
assert b'user_id=myself' in _request.body
|
||||
assert b'token_url=http%3A%2F%2Fsuccessfactors.com%2Foauth%2Ftoken' in _request.body
|
||||
assert b'client_id=TatVotSEiCMteSNWtSOnLanCtBGwNhGB' in _request.body
|
||||
return (200, headers, 'fake_saml_assertion')
|
||||
|
||||
httpretty.register_uri(httpretty.POST, SAPSF_ASSERTION_URL, content_type='text/plain', body=assertion_callback)
|
||||
@@ -428,10 +422,10 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes
|
||||
"""
|
||||
Return a fake assertion after checking that the input is what we expect.
|
||||
"""
|
||||
self.assertIn(b'assertion=fake_saml_assertion', _request.body)
|
||||
self.assertIn(b'company_id=NCC1701D', _request.body)
|
||||
self.assertIn(b'grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Asaml2-bearer', _request.body)
|
||||
self.assertIn(b'client_id=TatVotSEiCMteSNWtSOnLanCtBGwNhGB', _request.body)
|
||||
assert b'assertion=fake_saml_assertion' in _request.body
|
||||
assert b'company_id=NCC1701D' in _request.body
|
||||
assert b'grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Asaml2-bearer' in _request.body
|
||||
assert b'client_id=TatVotSEiCMteSNWtSOnLanCtBGwNhGB' in _request.body
|
||||
return (200, headers, '{"access_token": "faketoken"}')
|
||||
|
||||
httpretty.register_uri(httpretty.POST, SAPSF_TOKEN_URL, content_type='application/json', body=token_callback)
|
||||
@@ -444,7 +438,7 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes
|
||||
|
||||
def user_callback(request, _uri, headers):
|
||||
auth_header = request.headers.get('Authorization')
|
||||
self.assertEqual(auth_header, 'Bearer faketoken')
|
||||
assert auth_header == 'Bearer faketoken'
|
||||
return (
|
||||
200,
|
||||
headers,
|
||||
@@ -543,7 +537,7 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes
|
||||
|
||||
def user_callback(request, _uri, headers):
|
||||
auth_header = request.headers.get('Authorization')
|
||||
self.assertEqual(auth_header, 'Bearer faketoken')
|
||||
assert auth_header == 'Bearer faketoken'
|
||||
return (
|
||||
200,
|
||||
headers,
|
||||
@@ -714,13 +708,13 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes
|
||||
with LogCapture(level=logging.WARNING) as log_capture:
|
||||
self._test_register()
|
||||
logging_messages = str([log_msg.getMessage() for log_msg in log_capture.records]).replace('\\', '')
|
||||
self.assertIn(odata_company_id, logging_messages)
|
||||
self.assertIn(mocked_odata_api_url, logging_messages)
|
||||
self.assertIn(self.USER_USERNAME, logging_messages)
|
||||
self.assertIn("SAPSuccessFactors", logging_messages)
|
||||
self.assertIn("Error message", logging_messages)
|
||||
self.assertIn("System message", logging_messages)
|
||||
self.assertIn("Headers", logging_messages)
|
||||
assert odata_company_id in logging_messages
|
||||
assert mocked_odata_api_url in logging_messages
|
||||
assert self.USER_USERNAME in logging_messages
|
||||
assert 'SAPSuccessFactors' in logging_messages
|
||||
assert 'Error message' in logging_messages
|
||||
assert 'System message' in logging_messages
|
||||
assert 'Headers' in logging_messages
|
||||
|
||||
@skip('Test not necessary for this subclass')
|
||||
def test_get_saml_idp_class_with_fake_identifier(self):
|
||||
|
||||
@@ -56,8 +56,8 @@ class Oauth2ProviderConfigAdminTest(testutil.TestCase):
|
||||
|
||||
# Get the provider instance with active flag
|
||||
providers = OAuth2ProviderConfig.objects.all()
|
||||
self.assertEqual(len(providers), 1)
|
||||
self.assertEqual(providers[pcount].id, provider1.id)
|
||||
assert len(providers) == 1
|
||||
assert providers[pcount].id == provider1.id
|
||||
|
||||
# Edit the provider via the admin edit link
|
||||
admin = OAuth2ProviderConfigAdmin(provider1, AdminSite())
|
||||
@@ -77,14 +77,14 @@ class Oauth2ProviderConfigAdminTest(testutil.TestCase):
|
||||
|
||||
# Post the edit form: expecting redirect
|
||||
response = self.client.post(update_url, post_data)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
assert response.status_code == 302
|
||||
|
||||
# Editing the existing provider creates a new provider instance
|
||||
providers = OAuth2ProviderConfig.objects.all()
|
||||
self.assertEqual(len(providers), pcount + 2)
|
||||
self.assertEqual(providers[pcount].id, provider1.id)
|
||||
assert len(providers) == (pcount + 2)
|
||||
assert providers[pcount].id == provider1.id
|
||||
provider2 = providers[pcount + 1]
|
||||
|
||||
# Ensure the icon_image was preserved on the new provider instance
|
||||
self.assertEqual(provider2.icon_image, provider1.icon_image)
|
||||
self.assertEqual(provider2.name, post_data['name'])
|
||||
assert provider2.icon_image == provider1.icon_image
|
||||
assert provider2.name == post_data['name']
|
||||
|
||||
@@ -48,11 +48,11 @@ class TestXFrameWhitelistDecorator(TestCase):
|
||||
|
||||
response = mock_view(request)
|
||||
|
||||
self.assertEqual(response['X-Frame-Options'], expected_result)
|
||||
assert response['X-Frame-Options'] == expected_result
|
||||
|
||||
@ddt.data('http://localhost/login', 'http://not-a-real-domain.com', None)
|
||||
def test_feature_flag_off(self, url):
|
||||
with self.settings(FEATURES={'ENABLE_THIRD_PARTY_AUTH': False}):
|
||||
request = self.construct_request(url)
|
||||
response = mock_view(request)
|
||||
self.assertEqual(response['X-Frame-Options'], 'DENY')
|
||||
assert response['X-Frame-Options'] == 'DENY'
|
||||
|
||||
@@ -3,7 +3,8 @@ Unit tests for the IdentityServer3 OAuth2 Backend
|
||||
"""
|
||||
import json
|
||||
import ddt
|
||||
import unittest # lint-amnesty, pylint: disable=unused-import, wrong-import-order
|
||||
import pytest # pylint: disable=unused-import
|
||||
|
||||
from common.djangoapps.third_party_auth.identityserver3 import IdentityServer3
|
||||
from common.djangoapps.third_party_auth.tests import testutil
|
||||
from common.djangoapps.third_party_auth.tests.utils import skip_unless_thirdpartyauth
|
||||
@@ -41,21 +42,21 @@ class IdentityServer3Test(testutil.TestCase):
|
||||
make sure the "sub" claim works properly to grab user Id
|
||||
"""
|
||||
response = {"sub": 1, "email": "example@example.com"}
|
||||
self.assertEqual(self.id3_instance.get_user_id({}, response), 1)
|
||||
assert self.id3_instance.get_user_id({}, response) == 1
|
||||
|
||||
def test_key_error_thrown_with_no_sub(self):
|
||||
"""
|
||||
test that a KeyError is thrown if the "sub" claim does not exist
|
||||
"""
|
||||
response = {"id": 1}
|
||||
self.assertRaises(TypeError, self.id3_instance.get_user_id({}, response))
|
||||
assert self.id3_instance.get_user_id({}, response) is None
|
||||
|
||||
def test_proper_config_access(self):
|
||||
"""
|
||||
test that the IdentityServer3 model properly grabs OAuth2Configs
|
||||
"""
|
||||
provider_config = self.configure_identityServer3_provider(backend_name="identityServer3")
|
||||
self.assertEqual(self.id3_instance.get_config(), provider_config)
|
||||
assert self.id3_instance.get_config() == provider_config
|
||||
|
||||
def test_config_after_updating(self):
|
||||
"""
|
||||
@@ -66,8 +67,8 @@ class IdentityServer3Test(testutil.TestCase):
|
||||
slug="updated",
|
||||
backend_name="identityServer3"
|
||||
)
|
||||
self.assertEqual(self.id3_instance.get_config(), updated_provider_config)
|
||||
self.assertNotEqual(self.id3_instance.get_config(), original_provider_config)
|
||||
assert self.id3_instance.get_config() == updated_provider_config
|
||||
assert self.id3_instance.get_config() != original_provider_config
|
||||
|
||||
@ddt.data(
|
||||
('first_name_claim_key', 'given_name', 'first_name', 'Edx'),
|
||||
@@ -91,7 +92,7 @@ class IdentityServer3Test(testutil.TestCase):
|
||||
setting_field_key: setting_field_value,
|
||||
})
|
||||
)
|
||||
self.assertEqual(provider_config.backend_class().get_user_details(self.response)[output_name], output_value)
|
||||
assert provider_config.backend_class().get_user_details(self.response)[output_name] == output_value
|
||||
|
||||
def test_user_details_without_settings(self):
|
||||
"""
|
||||
|
||||
@@ -21,9 +21,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
|
||||
details = lti.get_user_details({LTI_PARAMS_KEY: {
|
||||
'lis_person_name_full': 'Full name'
|
||||
}})
|
||||
self.assertEqual(details, {
|
||||
'fullname': 'Full name'
|
||||
})
|
||||
assert details == {'fullname': 'Full name'}
|
||||
|
||||
def test_get_user_details_extra_keys(self):
|
||||
lti = LTIAuthBackend()
|
||||
@@ -34,12 +32,8 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
|
||||
'email': 'user@example.com',
|
||||
'other': 'something else'
|
||||
}})
|
||||
self.assertEqual(details, {
|
||||
'fullname': 'Full name',
|
||||
'first_name': 'Given',
|
||||
'last_name': 'Family',
|
||||
'email': 'user@example.com'
|
||||
})
|
||||
assert details == {'fullname': 'Full name', 'first_name': 'Given',
|
||||
'last_name': 'Family', 'email': 'user@example.com'}
|
||||
|
||||
def test_get_user_id(self):
|
||||
lti = LTIAuthBackend()
|
||||
@@ -47,7 +41,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
|
||||
'oauth_consumer_key': 'consumer',
|
||||
'user_id': 'user'
|
||||
}})
|
||||
self.assertEqual(user_id, 'consumer:user')
|
||||
assert user_id == 'consumer:user'
|
||||
|
||||
def test_validate_lti_valid_request(self):
|
||||
request = Request(
|
||||
@@ -60,7 +54,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
|
||||
lti_consumer_valid=True, lti_consumer_secret='secret',
|
||||
lti_max_timestamp_age=10
|
||||
)
|
||||
self.assertTrue(parameters)
|
||||
assert parameters
|
||||
self.assertDictContainsSubset({
|
||||
'custom_extra': 'parameter',
|
||||
'user_id': '292832126'
|
||||
@@ -77,7 +71,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
|
||||
lti_consumer_valid=True, lti_consumer_secret='secret',
|
||||
lti_max_timestamp_age=10
|
||||
)
|
||||
self.assertTrue(parameters)
|
||||
assert parameters
|
||||
self.assertDictContainsSubset({
|
||||
'custom_extra': 'parameter',
|
||||
'user_id': '292832126'
|
||||
@@ -94,7 +88,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
|
||||
lti_consumer_valid=True, lti_consumer_secret='secret',
|
||||
lti_max_timestamp_age=10
|
||||
)
|
||||
self.assertFalse(parameters)
|
||||
assert not parameters
|
||||
|
||||
def test_validate_lti_invalid_signature(self):
|
||||
request = Request(
|
||||
@@ -107,7 +101,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
|
||||
lti_consumer_valid=True, lti_consumer_secret='secret',
|
||||
lti_max_timestamp_age=10
|
||||
)
|
||||
self.assertFalse(parameters)
|
||||
assert not parameters
|
||||
|
||||
def test_validate_lti_cannot_add_get_params(self):
|
||||
request = Request(
|
||||
@@ -120,7 +114,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
|
||||
lti_consumer_valid=True, lti_consumer_secret='secret',
|
||||
lti_max_timestamp_age=10
|
||||
)
|
||||
self.assertFalse(parameters)
|
||||
assert not parameters
|
||||
|
||||
def test_validate_lti_garbage(self):
|
||||
request = Request(
|
||||
@@ -133,4 +127,4 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
|
||||
lti_consumer_valid=True, lti_consumer_secret='secret',
|
||||
lti_max_timestamp_age=10
|
||||
)
|
||||
self.assertFalse(parameters)
|
||||
assert not parameters
|
||||
|
||||
@@ -40,5 +40,5 @@ class ThirdPartyAuthMiddlewareTestCase(TestCase):
|
||||
)
|
||||
target_url = response.url
|
||||
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertTrue(target_url.endswith(login_url))
|
||||
assert response.status_code == 302
|
||||
assert target_url.endswith(login_url)
|
||||
|
||||
@@ -23,7 +23,7 @@ class ProviderUserStateTestCase(testutil.TestCase):
|
||||
def test_get_unlink_form_name(self):
|
||||
google_provider = self.configure_google_provider(enabled=True)
|
||||
state = pipeline.ProviderUserState(google_provider, object(), None)
|
||||
self.assertEqual(google_provider.provider_id + '_unlink_form', state.get_unlink_form_name())
|
||||
assert (google_provider.provider_id + '_unlink_form') == state.get_unlink_form_name()
|
||||
|
||||
@ddt.data(
|
||||
('saml', 'tpa-saml'),
|
||||
@@ -52,7 +52,7 @@ class ProviderUserStateTestCase(testutil.TestCase):
|
||||
}
|
||||
with simulate_running_pipeline("common.djangoapps.third_party_auth.pipeline", backend_name, **kwargs):
|
||||
logout_url = pipeline.get_idp_logout_url_from_running_pipeline(request)
|
||||
self.assertEqual(idp_config['logout_url'], logout_url)
|
||||
assert idp_config['logout_url'] == logout_url
|
||||
|
||||
|
||||
@skip_unless_thirdpartyauth()
|
||||
@@ -98,4 +98,4 @@ class PipelineOverridesTest(SamlIntegrationTestUtilities, IntegrationTestMixin,
|
||||
type(uuid4).hex = mock.PropertyMock(return_value='9fe2c4e93f654fdbb24c02b15259716c')
|
||||
mock_uuid.return_value = uuid4
|
||||
final_username = pipeline.get_username(strategy, details, self.provider.backend_class())
|
||||
self.assertEqual(expected_username, final_username['username'])
|
||||
assert expected_username == final_username['username']
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
|
||||
import datetime
|
||||
import unittest # lint-amnesty, pylint: disable=unused-import
|
||||
import pytest
|
||||
|
||||
import ddt
|
||||
import mock
|
||||
@@ -43,32 +43,32 @@ class GetAuthenticatedUserTestCase(TestCase):
|
||||
return social_models.DjangoStorage.user.user_model().objects.get(username=username)
|
||||
|
||||
def test_raises_does_not_exist_if_user_missing(self):
|
||||
with self.assertRaises(models.User.DoesNotExist):
|
||||
with pytest.raises(models.User.DoesNotExist):
|
||||
pipeline.get_authenticated_user(self.enabled_provider, 'new_' + self.user.username, 'user@example.com')
|
||||
|
||||
def test_raises_does_not_exist_if_user_found_but_no_association(self):
|
||||
backend_name = 'backend'
|
||||
|
||||
self.assertIsNotNone(self.get_by_username(self.user.username))
|
||||
self.assertFalse(any(provider.Registry.get_enabled_by_backend_name(backend_name)))
|
||||
assert self.get_by_username(self.user.username) is not None
|
||||
assert not any(provider.Registry.get_enabled_by_backend_name(backend_name))
|
||||
|
||||
with self.assertRaises(models.User.DoesNotExist):
|
||||
with pytest.raises(models.User.DoesNotExist):
|
||||
pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'user@example.com')
|
||||
|
||||
def test_raises_does_not_exist_if_user_and_association_found_but_no_match(self):
|
||||
self.assertIsNotNone(self.get_by_username(self.user.username))
|
||||
assert self.get_by_username(self.user.username) is not None
|
||||
social_models.DjangoStorage.user.create_social_auth(
|
||||
self.user, 'uid', 'other_' + self.enabled_provider.backend_name)
|
||||
|
||||
with self.assertRaises(models.User.DoesNotExist):
|
||||
with pytest.raises(models.User.DoesNotExist):
|
||||
pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'uid')
|
||||
|
||||
def test_returns_user_with_is_authenticated_and_backend_set_if_match(self):
|
||||
social_models.DjangoStorage.user.create_social_auth(self.user, 'uid', self.enabled_provider.backend_name)
|
||||
user = pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'uid')
|
||||
|
||||
self.assertEqual(self.user, user)
|
||||
self.assertEqual(self.enabled_provider.get_authentication_backend(), user.backend)
|
||||
assert self.user == user
|
||||
assert self.enabled_provider.get_authentication_backend() == user.backend
|
||||
|
||||
|
||||
class GetProviderUserStatesTestCase(TestCase):
|
||||
@@ -80,8 +80,8 @@ class GetProviderUserStatesTestCase(TestCase):
|
||||
self.user = social_models.DjangoStorage.user.create_user(username='username', password='password')
|
||||
|
||||
def test_returns_empty_list_if_no_enabled_providers(self):
|
||||
self.assertFalse(provider.Registry.enabled())
|
||||
self.assertEqual([], pipeline.get_provider_user_states(self.user))
|
||||
assert not provider.Registry.enabled()
|
||||
assert [] == pipeline.get_provider_user_states(self.user)
|
||||
|
||||
def test_state_not_returned_for_disabled_provider(self):
|
||||
disabled_provider = self.configure_google_provider(enabled=False)
|
||||
@@ -89,9 +89,9 @@ class GetProviderUserStatesTestCase(TestCase):
|
||||
social_models.DjangoStorage.user.create_social_auth(self.user, 'uid', disabled_provider.backend_name)
|
||||
states = pipeline.get_provider_user_states(self.user)
|
||||
|
||||
self.assertEqual(1, len(states))
|
||||
self.assertNotIn(disabled_provider.provider_id, (state.provider.provider_id for state in states))
|
||||
self.assertIn(enabled_provider.provider_id, (state.provider.provider_id for state in states))
|
||||
assert 1 == len(states)
|
||||
assert disabled_provider.provider_id not in (state.provider.provider_id for state in states)
|
||||
assert enabled_provider.provider_id in (state.provider.provider_id for state in states)
|
||||
|
||||
def test_states_for_enabled_providers_user_has_accounts_associated_with(self):
|
||||
# Enable two providers - Google and LinkedIn:
|
||||
@@ -103,48 +103,48 @@ class GetProviderUserStatesTestCase(TestCase):
|
||||
self.user, 'uid', linkedin_provider.backend_name)
|
||||
states = pipeline.get_provider_user_states(self.user)
|
||||
|
||||
self.assertEqual(2, len(states))
|
||||
assert 2 == len(states)
|
||||
|
||||
google_state = [state for state in states if state.provider.provider_id == google_provider.provider_id][0]
|
||||
linkedin_state = [state for state in states if state.provider.provider_id == linkedin_provider.provider_id][0]
|
||||
|
||||
self.assertTrue(google_state.has_account)
|
||||
self.assertEqual(google_provider.provider_id, google_state.provider.provider_id)
|
||||
assert google_state.has_account
|
||||
assert google_provider.provider_id == google_state.provider.provider_id
|
||||
# Also check the row ID. Note this 'id' changes whenever the configuration does:
|
||||
self.assertEqual(google_provider.id, google_state.provider.id)
|
||||
self.assertEqual(self.user, google_state.user)
|
||||
self.assertEqual(user_social_auth_google.id, google_state.association_id)
|
||||
assert google_provider.id == google_state.provider.id
|
||||
assert self.user == google_state.user
|
||||
assert user_social_auth_google.id == google_state.association_id
|
||||
|
||||
self.assertTrue(linkedin_state.has_account)
|
||||
self.assertEqual(linkedin_provider.provider_id, linkedin_state.provider.provider_id)
|
||||
self.assertEqual(linkedin_provider.id, linkedin_state.provider.id)
|
||||
self.assertEqual(self.user, linkedin_state.user)
|
||||
self.assertEqual(user_social_auth_linkedin.id, linkedin_state.association_id)
|
||||
assert linkedin_state.has_account
|
||||
assert linkedin_provider.provider_id == linkedin_state.provider.provider_id
|
||||
assert linkedin_provider.id == linkedin_state.provider.id
|
||||
assert self.user == linkedin_state.user
|
||||
assert user_social_auth_linkedin.id == linkedin_state.association_id
|
||||
|
||||
def test_states_for_enabled_providers_user_has_no_account_associated_with(self):
|
||||
# Enable two providers - Google and LinkedIn:
|
||||
google_provider = self.configure_google_provider(enabled=True)
|
||||
linkedin_provider = self.configure_linkedin_provider(enabled=True)
|
||||
self.assertEqual(len(provider.Registry.enabled()), 2)
|
||||
assert len(provider.Registry.enabled()) == 2
|
||||
|
||||
states = pipeline.get_provider_user_states(self.user)
|
||||
|
||||
self.assertEqual([], [x for x in social_models.DjangoStorage.user.objects.all()]) # lint-amnesty, pylint: disable=unnecessary-comprehension
|
||||
self.assertEqual(2, len(states))
|
||||
assert [] == list(social_models.DjangoStorage.user.objects.all())
|
||||
assert 2 == len(states)
|
||||
|
||||
google_state = [state for state in states if state.provider.provider_id == google_provider.provider_id][0]
|
||||
linkedin_state = [state for state in states if state.provider.provider_id == linkedin_provider.provider_id][0]
|
||||
|
||||
self.assertFalse(google_state.has_account)
|
||||
self.assertEqual(google_provider.provider_id, google_state.provider.provider_id)
|
||||
assert not google_state.has_account
|
||||
assert google_provider.provider_id == google_state.provider.provider_id
|
||||
# Also check the row ID. Note this 'id' changes whenever the configuration does:
|
||||
self.assertEqual(google_provider.id, google_state.provider.id)
|
||||
self.assertEqual(self.user, google_state.user)
|
||||
assert google_provider.id == google_state.provider.id
|
||||
assert self.user == google_state.user
|
||||
|
||||
self.assertFalse(linkedin_state.has_account)
|
||||
self.assertEqual(linkedin_provider.provider_id, linkedin_state.provider.provider_id)
|
||||
self.assertEqual(linkedin_provider.id, linkedin_state.provider.id)
|
||||
self.assertEqual(self.user, linkedin_state.user)
|
||||
assert not linkedin_state.has_account
|
||||
assert linkedin_provider.provider_id == linkedin_state.provider.provider_id
|
||||
assert linkedin_provider.id == linkedin_state.provider.id
|
||||
assert self.user == linkedin_state.user
|
||||
|
||||
|
||||
class UrlFormationTestCase(TestCase):
|
||||
@@ -153,62 +153,59 @@ class UrlFormationTestCase(TestCase):
|
||||
def test_complete_url_raises_value_error_if_provider_not_enabled(self):
|
||||
provider_name = 'oa2-not-enabled'
|
||||
|
||||
self.assertIsNone(provider.Registry.get(provider_name))
|
||||
assert provider.Registry.get(provider_name) is None
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
pipeline.get_complete_url(provider_name)
|
||||
|
||||
def test_complete_url_returns_expected_format(self):
|
||||
complete_url = pipeline.get_complete_url(self.enabled_provider.backend_name)
|
||||
|
||||
self.assertTrue(complete_url.startswith('/auth/complete'))
|
||||
self.assertIn(self.enabled_provider.backend_name, complete_url)
|
||||
assert complete_url.startswith('/auth/complete')
|
||||
assert self.enabled_provider.backend_name in complete_url
|
||||
|
||||
def test_disconnect_url_raises_value_error_if_provider_not_enabled(self):
|
||||
provider_name = 'oa2-not-enabled'
|
||||
|
||||
self.assertIsNone(provider.Registry.get(provider_name))
|
||||
assert provider.Registry.get(provider_name) is None
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
pipeline.get_disconnect_url(provider_name, 1000)
|
||||
|
||||
def test_disconnect_url_returns_expected_format(self):
|
||||
disconnect_url = pipeline.get_disconnect_url(self.enabled_provider.provider_id, 1000)
|
||||
disconnect_url = disconnect_url.rstrip('?')
|
||||
self.assertEqual(
|
||||
disconnect_url,
|
||||
'/auth/disconnect/{backend}/{association_id}/'.format(
|
||||
backend=self.enabled_provider.backend_name, association_id=1000)
|
||||
)
|
||||
assert disconnect_url == '/auth/disconnect/{backend}/{association_id}/'\
|
||||
.format(backend=self.enabled_provider.backend_name, association_id=1000)
|
||||
|
||||
def test_login_url_raises_value_error_if_provider_not_enabled(self):
|
||||
provider_id = 'oa2-not-enabled'
|
||||
|
||||
self.assertIsNone(provider.Registry.get(provider_id))
|
||||
assert provider.Registry.get(provider_id) is None
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
pipeline.get_login_url(provider_id, pipeline.AUTH_ENTRY_LOGIN)
|
||||
|
||||
def test_login_url_returns_expected_format(self):
|
||||
login_url = pipeline.get_login_url(self.enabled_provider.provider_id, pipeline.AUTH_ENTRY_LOGIN)
|
||||
|
||||
self.assertTrue(login_url.startswith('/auth/login'))
|
||||
self.assertIn(self.enabled_provider.backend_name, login_url)
|
||||
self.assertTrue(login_url.endswith(pipeline.AUTH_ENTRY_LOGIN))
|
||||
assert login_url.startswith('/auth/login')
|
||||
assert self.enabled_provider.backend_name in login_url
|
||||
assert login_url.endswith(pipeline.AUTH_ENTRY_LOGIN)
|
||||
|
||||
def test_for_value_error_if_provider_id_invalid(self):
|
||||
provider_id = 'invalid' # Format is normally "{prefix}-{identifier}"
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
provider.Registry.get(provider_id)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
pipeline.get_login_url(provider_id, pipeline.AUTH_ENTRY_LOGIN)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
pipeline.get_disconnect_url(provider_id, 1000)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
pipeline.get_complete_url(provider_id)
|
||||
|
||||
|
||||
@@ -242,7 +239,7 @@ class TestPipelineUtilityFunctions(TestCase):
|
||||
with mock.patch('common.djangoapps.third_party_auth.pipeline.get') as get_pipeline:
|
||||
get_pipeline.return_value = pipeline_partial
|
||||
real_social = pipeline.get_real_social_auth_object(request)
|
||||
self.assertEqual(real_social, self.social_auth)
|
||||
assert real_social == self.social_auth
|
||||
|
||||
def test_get_real_social_auth(self):
|
||||
"""
|
||||
@@ -259,7 +256,7 @@ class TestPipelineUtilityFunctions(TestCase):
|
||||
with mock.patch('common.djangoapps.third_party_auth.pipeline.get') as get_pipeline:
|
||||
get_pipeline.return_value = pipeline_partial
|
||||
real_social = pipeline.get_real_social_auth_object(request)
|
||||
self.assertEqual(real_social, self.social_auth)
|
||||
assert real_social == self.social_auth
|
||||
|
||||
def test_get_real_social_auth_no_pipeline(self):
|
||||
"""
|
||||
@@ -268,7 +265,7 @@ class TestPipelineUtilityFunctions(TestCase):
|
||||
"""
|
||||
request = mock.MagicMock(session={})
|
||||
real_social = pipeline.get_real_social_auth_object(request)
|
||||
self.assertEqual(real_social, None)
|
||||
assert real_social is None
|
||||
|
||||
def test_get_real_social_auth_no_social(self):
|
||||
"""
|
||||
@@ -283,7 +280,7 @@ class TestPipelineUtilityFunctions(TestCase):
|
||||
}
|
||||
)
|
||||
real_social = pipeline.get_real_social_auth_object(request)
|
||||
self.assertEqual(real_social, None)
|
||||
assert real_social is None
|
||||
|
||||
def test_quarantine(self):
|
||||
"""
|
||||
@@ -294,12 +291,10 @@ class TestPipelineUtilityFunctions(TestCase):
|
||||
session={}
|
||||
)
|
||||
pipeline.quarantine_session(request, locations=('my_totally_real_module', 'other_real_module',))
|
||||
self.assertEqual(
|
||||
request.session['third_party_auth_quarantined_modules'],
|
||||
('my_totally_real_module', 'other_real_module',),
|
||||
)
|
||||
assert request.session['third_party_auth_quarantined_modules'] ==\
|
||||
('my_totally_real_module', 'other_real_module')
|
||||
pipeline.lift_quarantine(request)
|
||||
self.assertNotIn('third_party_auth_quarantined_modules', request.session)
|
||||
assert 'third_party_auth_quarantined_modules' not in request.session
|
||||
|
||||
|
||||
@ddt.ddt
|
||||
@@ -597,4 +592,4 @@ class SetIDVerificationStatusTestCase(TestCase):
|
||||
)
|
||||
|
||||
# Ensure a verification signal was sent
|
||||
self.assertEqual(mock_signal.call_count, 1)
|
||||
assert mock_signal.call_count == 1
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
|
||||
import re
|
||||
import unittest # lint-amnesty, pylint: disable=unused-import
|
||||
|
||||
from django.contrib.sites.models import Site
|
||||
from django.db import connections, DEFAULT_DB_ALIAS
|
||||
@@ -23,37 +22,37 @@ class RegistryTest(testutil.TestCase):
|
||||
|
||||
def test_configure_once_adds_gettable_providers(self):
|
||||
facebook_provider = self.configure_facebook_provider(enabled=True)
|
||||
self.assertEqual(facebook_provider.id, provider.Registry.get(facebook_provider.provider_id).id)
|
||||
assert facebook_provider.id == provider.Registry.get(facebook_provider.provider_id).id
|
||||
|
||||
def test_no_providers_by_default(self):
|
||||
enabled_providers = provider.Registry.enabled()
|
||||
self.assertEqual(len(enabled_providers), 0, "By default, no providers are enabled.")
|
||||
assert len(enabled_providers) == 0, 'By default, no providers are enabled.'
|
||||
|
||||
def test_runtime_configuration(self):
|
||||
self.configure_google_provider(enabled=True)
|
||||
enabled_providers = provider.Registry.enabled()
|
||||
self.assertEqual(len(enabled_providers), 1)
|
||||
self.assertEqual(enabled_providers[0].name, "Google")
|
||||
self.assertEqual(enabled_providers[0].get_setting("SECRET"), "opensesame")
|
||||
assert len(enabled_providers) == 1
|
||||
assert enabled_providers[0].name == 'Google'
|
||||
assert enabled_providers[0].get_setting('SECRET') == 'opensesame'
|
||||
|
||||
self.configure_google_provider(enabled=False)
|
||||
enabled_providers = provider.Registry.enabled()
|
||||
self.assertEqual(len(enabled_providers), 0)
|
||||
assert len(enabled_providers) == 0
|
||||
|
||||
self.configure_google_provider(enabled=True, secret="alohomora")
|
||||
enabled_providers = provider.Registry.enabled()
|
||||
self.assertEqual(len(enabled_providers), 1)
|
||||
self.assertEqual(enabled_providers[0].get_setting("SECRET"), "alohomora")
|
||||
assert len(enabled_providers) == 1
|
||||
assert enabled_providers[0].get_setting('SECRET') == 'alohomora'
|
||||
|
||||
def test_secure_configuration(self):
|
||||
""" Test that some sensitive values can be configured via Django settings """
|
||||
self.configure_google_provider(enabled=True, secret="")
|
||||
enabled_providers = provider.Registry.enabled()
|
||||
self.assertEqual(len(enabled_providers), 1)
|
||||
self.assertEqual(enabled_providers[0].name, "Google")
|
||||
self.assertEqual(enabled_providers[0].get_setting("SECRET"), "")
|
||||
assert len(enabled_providers) == 1
|
||||
assert enabled_providers[0].name == 'Google'
|
||||
assert enabled_providers[0].get_setting('SECRET') == ''
|
||||
with self.settings(SOCIAL_AUTH_OAUTH_SECRETS={'google-oauth2': 'secret42'}):
|
||||
self.assertEqual(enabled_providers[0].get_setting("SECRET"), "secret42")
|
||||
assert enabled_providers[0].get_setting('SECRET') == 'secret42'
|
||||
|
||||
def test_cannot_load_arbitrary_backends(self):
|
||||
""" Test that only backend_names listed in settings.AUTHENTICATION_BACKENDS can be used """
|
||||
@@ -65,7 +64,7 @@ class RegistryTest(testutil.TestCase):
|
||||
slug="test",
|
||||
backend_name="disallowed"
|
||||
)
|
||||
self.assertEqual(len(provider.Registry.enabled()), 0)
|
||||
assert len(provider.Registry.enabled()) == 0
|
||||
|
||||
def test_enabled_returns_list_of_enabled_providers_sorted_by_name(self):
|
||||
provider_names = ["Stack Overflow", "Google", "LinkedIn", "GitHub"]
|
||||
@@ -76,7 +75,7 @@ class RegistryTest(testutil.TestCase):
|
||||
self.configure_oauth_provider(enabled=True, name=name, backend_name=backend_name)
|
||||
|
||||
with patch('common.djangoapps.third_party_auth.provider._PSA_OAUTH2_BACKENDS', backend_names):
|
||||
self.assertEqual(sorted(provider_names), [prov.name for prov in provider.Registry.enabled()])
|
||||
assert sorted(provider_names) == [prov.name for prov in provider.Registry.enabled()]
|
||||
|
||||
def test_enabled_doesnt_query_site(self):
|
||||
"""Regression test for 1+N queries for django_site (ARCHBOM-1139)"""
|
||||
@@ -90,11 +89,12 @@ class RegistryTest(testutil.TestCase):
|
||||
with CaptureQueriesContext(connections[DEFAULT_DB_ALIAS]) as cq:
|
||||
enabled_slugs = {p.slug for p in provider.Registry.enabled()}
|
||||
|
||||
self.assertEqual(len(enabled_slugs), provider_count)
|
||||
assert len(enabled_slugs) == provider_count
|
||||
# Should not involve any queries for Site, or at least should not *scale* with number of providers
|
||||
all_queries = [q['sql'] for q in cq.captured_queries]
|
||||
django_site_queries = list(filter(re_django_site_query.search, all_queries))
|
||||
self.assertEqual(len(django_site_queries), 0) # previously was == provider_count (1 for each provider)
|
||||
assert len(django_site_queries) == 0
|
||||
# previously was == provider_count (1 for each provider)
|
||||
|
||||
def test_providers_displayed_for_login(self):
|
||||
"""
|
||||
@@ -107,11 +107,11 @@ class RegistryTest(testutil.TestCase):
|
||||
disabled_provider = self.configure_twitter_provider(visible=True, enabled=False)
|
||||
no_log_in_provider = self.configure_lti_provider()
|
||||
provider_ids = [idp.provider_id for idp in provider.Registry.displayed_for_login()]
|
||||
self.assertNotIn(hidden_provider.provider_id, provider_ids)
|
||||
self.assertNotIn(implicitly_hidden_provider.provider_id, provider_ids)
|
||||
self.assertNotIn(disabled_provider.provider_id, provider_ids)
|
||||
self.assertNotIn(no_log_in_provider.provider_id, provider_ids)
|
||||
self.assertIn(normal_provider.provider_id, provider_ids)
|
||||
assert hidden_provider.provider_id not in provider_ids
|
||||
assert implicitly_hidden_provider.provider_id not in provider_ids
|
||||
assert disabled_provider.provider_id not in provider_ids
|
||||
assert no_log_in_provider.provider_id not in provider_ids
|
||||
assert normal_provider.provider_id in provider_ids
|
||||
|
||||
def test_tpa_hint_provider_displayed_for_login(self):
|
||||
"""
|
||||
@@ -125,7 +125,7 @@ class RegistryTest(testutil.TestCase):
|
||||
idp.provider_id
|
||||
for idp in provider.Registry.displayed_for_login(tpa_hint=hidden_provider.provider_id)
|
||||
]
|
||||
self.assertIn(hidden_provider.provider_id, provider_ids)
|
||||
assert hidden_provider.provider_id in provider_ids
|
||||
|
||||
# New providers are hidden (ie, not flagged as 'visible') by default
|
||||
# The tpa_hint parameter should work for these providers as well
|
||||
@@ -134,7 +134,7 @@ class RegistryTest(testutil.TestCase):
|
||||
idp.provider_id
|
||||
for idp in provider.Registry.displayed_for_login(tpa_hint=implicitly_hidden_provider.provider_id)
|
||||
]
|
||||
self.assertIn(implicitly_hidden_provider.provider_id, provider_ids)
|
||||
assert implicitly_hidden_provider.provider_id in provider_ids
|
||||
|
||||
# Disabled providers should not be matched in tpa_hint scenarios
|
||||
disabled_provider = self.configure_twitter_provider(visible=True, enabled=False)
|
||||
@@ -142,7 +142,7 @@ class RegistryTest(testutil.TestCase):
|
||||
idp.provider_id
|
||||
for idp in provider.Registry.displayed_for_login(tpa_hint=disabled_provider.provider_id)
|
||||
]
|
||||
self.assertNotIn(disabled_provider.provider_id, provider_ids)
|
||||
assert disabled_provider.provider_id not in provider_ids
|
||||
|
||||
# Providers not utilized for learner authentication should not match tpa_hint
|
||||
no_log_in_provider = self.configure_lti_provider()
|
||||
@@ -150,14 +150,14 @@ class RegistryTest(testutil.TestCase):
|
||||
idp.provider_id
|
||||
for idp in provider.Registry.displayed_for_login(tpa_hint=no_log_in_provider.provider_id)
|
||||
]
|
||||
self.assertNotIn(no_log_in_provider.provider_id, provider_ids)
|
||||
assert no_log_in_provider.provider_id not in provider_ids
|
||||
|
||||
def test_provider_enabled_for_current_site(self):
|
||||
"""
|
||||
Verify that enabled_for_current_site returns True when the provider matches the current site.
|
||||
"""
|
||||
prov = self.configure_google_provider(visible=True, enabled=True, site=Site.objects.get_current())
|
||||
self.assertEqual(prov.enabled_for_current_site, True)
|
||||
assert prov.enabled_for_current_site is True
|
||||
|
||||
@with_site_configuration(SITE_DOMAIN_A)
|
||||
def test_provider_disabled_for_mismatching_site(self):
|
||||
@@ -166,11 +166,11 @@ class RegistryTest(testutil.TestCase):
|
||||
"""
|
||||
site_b = Site.objects.get_or_create(domain=SITE_DOMAIN_B, name=SITE_DOMAIN_B)[0]
|
||||
prov = self.configure_google_provider(visible=True, enabled=True, site=site_b)
|
||||
self.assertEqual(prov.enabled_for_current_site, False)
|
||||
assert prov.enabled_for_current_site is False
|
||||
|
||||
def test_get_returns_enabled_provider(self):
|
||||
google_provider = self.configure_google_provider(enabled=True)
|
||||
self.assertEqual(google_provider.id, provider.Registry.get(google_provider.provider_id).id)
|
||||
assert google_provider.id == provider.Registry.get(google_provider.provider_id).id
|
||||
|
||||
def test_oauth2_provider_keyed_by_slug(self):
|
||||
"""
|
||||
@@ -178,15 +178,15 @@ class RegistryTest(testutil.TestCase):
|
||||
which doesn't match any of the possible backend_names.
|
||||
"""
|
||||
google_provider = self.configure_google_provider(enabled=True, slug='custom_slug')
|
||||
self.assertIn(google_provider, provider.Registry._enabled_providers()) # lint-amnesty, pylint: disable=protected-access
|
||||
self.assertIn(google_provider, provider.Registry.get_enabled_by_backend_name('google-oauth2'))
|
||||
assert google_provider in provider.Registry._enabled_providers() # pylint: disable=protected-access
|
||||
assert google_provider in provider.Registry.get_enabled_by_backend_name('google-oauth2')
|
||||
|
||||
def test_oath2_different_slug_from_backend_name(self):
|
||||
"""
|
||||
Test that an OAuth2 provider can have a slug that differs from the backend name.
|
||||
"""
|
||||
dummy_provider = self.configure_oauth_provider(enabled=True, name="dummy", slug="default", backend_name="dummy")
|
||||
self.assertIn(dummy_provider, provider.Registry.get_enabled_by_backend_name('dummy'))
|
||||
assert dummy_provider in provider.Registry.get_enabled_by_backend_name('dummy')
|
||||
|
||||
def test_oauth2_enabled_only_for_supplied_backend(self):
|
||||
"""
|
||||
@@ -195,32 +195,32 @@ class RegistryTest(testutil.TestCase):
|
||||
"""
|
||||
facebook_provider = self.configure_facebook_provider(enabled=True)
|
||||
self.configure_google_provider(enabled=True)
|
||||
self.assertNotIn(facebook_provider, provider.Registry.get_enabled_by_backend_name('google-oauth2'))
|
||||
assert facebook_provider not in provider.Registry.get_enabled_by_backend_name('google-oauth2')
|
||||
|
||||
def test_get_returns_none_if_provider_id_is_none(self):
|
||||
self.assertIsNone(provider.Registry.get(None))
|
||||
assert provider.Registry.get(None) is None
|
||||
|
||||
def test_get_returns_none_if_provider_not_enabled(self):
|
||||
linkedin_provider_id = "oa2-linkedin-oauth2"
|
||||
# At this point there should be no configuration entries at all so no providers should be enabled
|
||||
self.assertEqual(provider.Registry.enabled(), [])
|
||||
self.assertIsNone(provider.Registry.get(linkedin_provider_id))
|
||||
assert provider.Registry.enabled() == []
|
||||
assert provider.Registry.get(linkedin_provider_id) is None
|
||||
# Now explicitly disabled this provider:
|
||||
self.configure_linkedin_provider(enabled=False)
|
||||
self.assertIsNone(provider.Registry.get(linkedin_provider_id))
|
||||
assert provider.Registry.get(linkedin_provider_id) is None
|
||||
self.configure_linkedin_provider(enabled=True)
|
||||
self.assertEqual(provider.Registry.get(linkedin_provider_id).provider_id, linkedin_provider_id)
|
||||
assert provider.Registry.get(linkedin_provider_id).provider_id == linkedin_provider_id
|
||||
|
||||
def test_get_from_pipeline_returns_none_if_provider_not_enabled(self):
|
||||
self.assertEqual(provider.Registry.enabled(), [], "By default, no providers are enabled.")
|
||||
self.assertIsNone(provider.Registry.get_from_pipeline(Mock()))
|
||||
assert provider.Registry.enabled() == [], 'By default, no providers are enabled.'
|
||||
assert provider.Registry.get_from_pipeline(Mock()) is None
|
||||
|
||||
def test_get_enabled_by_backend_name_returns_enabled_provider(self):
|
||||
google_provider = self.configure_google_provider(enabled=True)
|
||||
found = list(provider.Registry.get_enabled_by_backend_name(google_provider.backend_name))
|
||||
self.assertEqual(found, [google_provider])
|
||||
assert found == [google_provider]
|
||||
|
||||
def test_get_enabled_by_backend_name_returns_none_if_provider_not_enabled(self):
|
||||
google_provider = self.configure_google_provider(enabled=False)
|
||||
found = list(provider.Registry.get_enabled_by_backend_name(google_provider.backend_name))
|
||||
self.assertEqual(found, [])
|
||||
assert found == []
|
||||
|
||||
@@ -26,9 +26,9 @@ class TestEdXSAMLIdentityProvider(SAMLTestCase):
|
||||
u'[THIRD_PARTY_AUTH] Invalid EdXSAMLIdentityProvider subclass--'
|
||||
u'using EdXSAMLIdentityProvider base class. Provider: {provider}'.format(provider='fake_idp_class_option')
|
||||
)
|
||||
self.assertIs(idp_class, EdXSAMLIdentityProvider)
|
||||
assert idp_class is EdXSAMLIdentityProvider
|
||||
|
||||
def test_get_user_details(self):
|
||||
""" test get_attr and get_user_details of EdXSAMLIdentityProvider"""
|
||||
edx_saml_identity_provider = EdXSAMLIdentityProvider('demo', **mock_conf)
|
||||
self.assertEqual(edx_saml_identity_provider.get_user_details(mock_attributes), expected_user_details)
|
||||
assert edx_saml_identity_provider.get_user_details(mock_attributes) == expected_user_details
|
||||
|
||||
@@ -37,33 +37,34 @@ class SettingsUnitTest(testutil.TestCase):
|
||||
|
||||
def test_apply_settings_adds_exception_middleware(self):
|
||||
settings.apply_settings(self.settings)
|
||||
self.assertIn('common.djangoapps.third_party_auth.middleware.ExceptionMiddleware', self.settings.MIDDLEWARE)
|
||||
assert 'common.djangoapps.third_party_auth.middleware.ExceptionMiddleware' in self.settings.MIDDLEWARE
|
||||
|
||||
def test_apply_settings_adds_fields_stored_in_session(self):
|
||||
settings.apply_settings(self.settings)
|
||||
self.assertEqual(['auth_entry', 'next'], self.settings.FIELDS_STORED_IN_SESSION)
|
||||
assert ['auth_entry', 'next'] == self.settings.FIELDS_STORED_IN_SESSION
|
||||
|
||||
@skip_unless_thirdpartyauth()
|
||||
def test_apply_settings_enables_no_providers_by_default(self):
|
||||
# Providers are only enabled via ConfigurationModels in the database
|
||||
settings.apply_settings(self.settings)
|
||||
self.assertEqual([], provider.Registry.enabled())
|
||||
assert [] == provider.Registry.enabled()
|
||||
|
||||
def test_apply_settings_turns_off_raising_social_exceptions(self):
|
||||
# Guard against submitting a conf change that's convenient in dev but
|
||||
# bad in prod.
|
||||
settings.apply_settings(self.settings)
|
||||
self.assertFalse(self.settings.SOCIAL_AUTH_RAISE_EXCEPTIONS)
|
||||
assert not self.settings.SOCIAL_AUTH_RAISE_EXCEPTIONS
|
||||
|
||||
def test_apply_settings_turns_off_redirect_sanitization(self):
|
||||
settings.apply_settings(self.settings)
|
||||
self.assertFalse(self.settings.SOCIAL_AUTH_SANITIZE_REDIRECTS)
|
||||
assert not self.settings.SOCIAL_AUTH_SANITIZE_REDIRECTS
|
||||
|
||||
def test_apply_settings_avoids_default_username_check(self):
|
||||
# Avoid the default username check where non-ascii characters are not
|
||||
# allowed when unicode username is enabled
|
||||
settings.apply_settings(self.settings)
|
||||
self.assertTrue(self.settings.SOCIAL_AUTH_CLEAN_USERNAMES) # verify default behavior
|
||||
assert self.settings.SOCIAL_AUTH_CLEAN_USERNAMES
|
||||
# verify default behavior
|
||||
with patch.dict('django.conf.settings.FEATURES', {'ENABLE_UNICODE_USERNAME': True}):
|
||||
settings.apply_settings(self.settings)
|
||||
self.assertFalse(self.settings.SOCIAL_AUTH_CLEAN_USERNAMES)
|
||||
assert not self.settings.SOCIAL_AUTH_CLEAN_USERNAMES
|
||||
|
||||
@@ -32,21 +32,11 @@ class TestUtils(TestCase):
|
||||
"""
|
||||
# Create users from factory
|
||||
UserFactory(username='test_user', email='test_user@example.com')
|
||||
self.assertTrue(
|
||||
user_exists({'username': 'test_user', 'email': 'test_user@example.com'}),
|
||||
)
|
||||
self.assertTrue(
|
||||
user_exists({'username': 'test_user'}),
|
||||
)
|
||||
self.assertTrue(
|
||||
user_exists({'email': 'test_user@example.com'}),
|
||||
)
|
||||
self.assertFalse(
|
||||
user_exists({'username': 'invalid_user'}),
|
||||
)
|
||||
self.assertTrue(
|
||||
user_exists({'username': 'TesT_User'})
|
||||
)
|
||||
assert user_exists({'username': 'test_user', 'email': 'test_user@example.com'})
|
||||
assert user_exists({'username': 'test_user'})
|
||||
assert user_exists({'email': 'test_user@example.com'})
|
||||
assert not user_exists({'username': 'invalid_user'})
|
||||
assert user_exists({'username': 'TesT_User'})
|
||||
|
||||
def test_convert_saml_slug_provider_id(self):
|
||||
"""
|
||||
@@ -55,13 +45,9 @@ class TestUtils(TestCase):
|
||||
provider_names = {'saml-samltest': 'samltest', 'saml-example': 'example'}
|
||||
for provider_id in provider_names:
|
||||
# provider_id -> slug
|
||||
self.assertEqual(
|
||||
convert_saml_slug_provider_id(provider_id), provider_names[provider_id]
|
||||
)
|
||||
assert convert_saml_slug_provider_id(provider_id) == provider_names[provider_id]
|
||||
# slug -> provider_id
|
||||
self.assertEqual(
|
||||
convert_saml_slug_provider_id(provider_names[provider_id]), provider_id
|
||||
)
|
||||
assert convert_saml_slug_provider_id(provider_names[provider_id]) == provider_id
|
||||
|
||||
def test_get_user(self):
|
||||
"""
|
||||
|
||||
@@ -4,6 +4,7 @@ Test the views served by third_party_auth.
|
||||
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
import ddt
|
||||
from django.conf import settings
|
||||
@@ -32,15 +33,15 @@ class SAMLMetadataTest(SAMLTestCase):
|
||||
""" When SAML is not enabled, the metadata view should return 404 """
|
||||
self.enable_saml(enabled=False)
|
||||
response = self.client.get(self.METADATA_URL)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_metadata(self):
|
||||
self.enable_saml()
|
||||
doc = self._fetch_metadata()
|
||||
# Check the ACS URL:
|
||||
acs_node = doc.find(".//{}".format(etree.QName(SAML_XML_NS, 'AssertionConsumerService')))
|
||||
self.assertIsNotNone(acs_node)
|
||||
self.assertEqual(acs_node.attrib['Location'], 'http://example.none/auth/complete/tpa-saml/')
|
||||
assert acs_node is not None
|
||||
assert acs_node.attrib['Location'] == 'http://example.none/auth/complete/tpa-saml/'
|
||||
|
||||
def test_default_contact_info(self):
|
||||
self.enable_saml()
|
||||
@@ -90,7 +91,7 @@ class SAMLMetadataTest(SAMLTestCase):
|
||||
private_key='',
|
||||
other_config_str='{"SECURITY_CONFIG": {"signMetadata": true} }',
|
||||
)
|
||||
with self.assertRaises(OneLogin_Saml2_Error):
|
||||
with pytest.raises(OneLogin_Saml2_Error):
|
||||
self._fetch_metadata() # OneLogin_Saml2_Error: Cannot sign metadata: missing SP private key.
|
||||
with self.settings(
|
||||
SOCIAL_AUTH_SAML_SP_PRIVATE_KEY=self._get_private_key('saml_key'),
|
||||
@@ -102,40 +103,40 @@ class SAMLMetadataTest(SAMLTestCase):
|
||||
""" Fetch the SAML metadata and do some validation """
|
||||
doc = self._fetch_metadata()
|
||||
sig_node = doc.find(".//{}".format(etree.QName(XMLDSIG_XML_NS, 'SignatureValue')))
|
||||
self.assertIsNotNone(sig_node)
|
||||
assert sig_node is not None
|
||||
# Check that the right public key was used:
|
||||
pub_key_node = doc.find(".//{}".format(etree.QName(XMLDSIG_XML_NS, 'X509Certificate')))
|
||||
self.assertIsNotNone(pub_key_node)
|
||||
self.assertIn(pub_key_starts_with, pub_key_node.text)
|
||||
assert pub_key_node is not None
|
||||
assert pub_key_starts_with in pub_key_node.text
|
||||
|
||||
def _fetch_metadata(self):
|
||||
""" Fetch and parse the metadata XML at self.METADATA_URL """
|
||||
response = self.client.get(self.METADATA_URL)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response['Content-Type'], 'text/xml')
|
||||
assert response.status_code == 200
|
||||
assert response['Content-Type'] == 'text/xml'
|
||||
# The result should be valid XML:
|
||||
try:
|
||||
metadata_doc = etree.fromstring(response.content)
|
||||
except etree.LxmlError:
|
||||
self.fail('SAML metadata must be valid XML')
|
||||
self.assertEqual(metadata_doc.tag, etree.QName(SAML_XML_NS, 'EntityDescriptor'))
|
||||
assert metadata_doc.tag == etree.QName(SAML_XML_NS, 'EntityDescriptor')
|
||||
return metadata_doc
|
||||
|
||||
def check_metadata_contacts(self, xml, tech_name, tech_email, support_name, support_email):
|
||||
""" Validate that the contact info in the metadata has the expected values """
|
||||
technical_node = xml.find(".//{}[@contactType='technical']".format(etree.QName(SAML_XML_NS, 'ContactPerson')))
|
||||
self.assertIsNotNone(technical_node)
|
||||
assert technical_node is not None
|
||||
tech_name_node = technical_node.find(etree.QName(SAML_XML_NS, 'GivenName'))
|
||||
self.assertEqual(tech_name_node.text, tech_name)
|
||||
assert tech_name_node.text == tech_name
|
||||
tech_email_node = technical_node.find(etree.QName(SAML_XML_NS, 'EmailAddress'))
|
||||
self.assertEqual(tech_email_node.text, tech_email)
|
||||
assert tech_email_node.text == tech_email
|
||||
|
||||
support_node = xml.find(".//{}[@contactType='support']".format(etree.QName(SAML_XML_NS, 'ContactPerson')))
|
||||
self.assertIsNotNone(support_node)
|
||||
assert support_node is not None
|
||||
support_name_node = support_node.find(etree.QName(SAML_XML_NS, 'GivenName'))
|
||||
self.assertEqual(support_name_node.text, support_name)
|
||||
assert support_name_node.text == support_name
|
||||
support_email_node = support_node.find(etree.QName(SAML_XML_NS, 'EmailAddress'))
|
||||
self.assertEqual(support_email_node.text, support_email)
|
||||
assert support_email_node.text == support_email
|
||||
|
||||
|
||||
@unittest.skipUnless(AUTH_FEATURE_ENABLED, AUTH_FEATURES_KEY + ' not enabled')
|
||||
@@ -149,13 +150,13 @@ class SAMLAuthTest(SAMLTestCase):
|
||||
""" Accessing the login endpoint without an idp query param should return 302 """
|
||||
self.enable_saml()
|
||||
response = self.client.get(self.LOGIN_URL)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
assert response.status_code == 302
|
||||
|
||||
def test_login_disabled(self):
|
||||
""" When SAML is not enabled, the login view should return 404 """
|
||||
self.enable_saml(enabled=False)
|
||||
response = self.client.get(self.LOGIN_URL)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@unittest.skipUnless(AUTH_FEATURE_ENABLED, AUTH_FEATURES_KEY + ' not enabled')
|
||||
@@ -179,15 +180,15 @@ class IdPRedirectViewTest(SAMLTestCase):
|
||||
|
||||
response = self.client.get(endpoint_url)
|
||||
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, expected_url)
|
||||
assert response.status_code == 302
|
||||
assert response.url == expected_url
|
||||
|
||||
def test_with_invalid_provider_slug(self):
|
||||
endpoint_url = self.get_idp_redirect_url('saml-test-invalid')
|
||||
|
||||
response = self.client.get(endpoint_url)
|
||||
|
||||
self.assertEqual(response.status_code, 404)
|
||||
assert response.status_code == 404
|
||||
|
||||
@staticmethod
|
||||
def get_idp_redirect_url(provider_slug, next_destination=None):
|
||||
|
||||
@@ -85,10 +85,8 @@ class ThirdPartyAuthTestMixin(object):
|
||||
|
||||
def configure_saml_provider(self, **kwargs):
|
||||
""" Update the settings for a SAML-based third party auth provider """
|
||||
self.assertTrue(
|
||||
SAMLConfiguration.is_enabled(Site.objects.get_current(), 'default'),
|
||||
"SAML Provider Configuration only works if SAML is enabled."
|
||||
)
|
||||
assert SAMLConfiguration.is_enabled(Site.objects.get_current(), 'default'), \
|
||||
'SAML Provider Configuration only works if SAML is enabled.'
|
||||
obj = SAMLProviderConfig(**kwargs)
|
||||
obj.save()
|
||||
return obj
|
||||
|
||||
Reference in New Issue
Block a user