replaced unittest assertions pytest assertions (#26240)

This commit is contained in:
Aarif
2021-02-12 12:31:37 +05:00
committed by GitHub
parent 00a0672f4b
commit e6a0d35009
25 changed files with 448 additions and 506 deletions

View File

@@ -54,7 +54,7 @@ class ThirdPartyAuthPermissionTest(TestCase):
def test_anonymous_fails(self):
request = self._create_request()
response = self.SomeTpaClassView().dispatch(request)
self.assertEqual(response.status_code, 401)
assert response.status_code == 401
@ddt.data(
(True, False, 200),
@@ -68,7 +68,7 @@ class ThirdPartyAuthPermissionTest(TestCase):
self._create_session(request, user)
response = self.SomeTpaClassView().dispatch(request)
self.assertEqual(response.status_code, expected_status_code)
assert response.status_code == expected_status_code
@ddt.data(
# unrestricted (for example, jwt cookies)
@@ -97,7 +97,7 @@ class ThirdPartyAuthPermissionTest(TestCase):
)
response = self.SomeTpaClassView().dispatch(request)
self.assertEqual(response.status_code, expected_response)
assert response.status_code == expected_response
@ddt.data(
# valid scopes
@@ -152,4 +152,4 @@ class ThirdPartyAuthPermissionTest(TestCase):
request = self._create_request(auth_header=auth_header)
response = self.SomeTpaClassView().dispatch(request, provider_id='some_tpa_provider')
self.assertEqual(response.status_code, expected_response)
assert response.status_code == expected_response

View File

@@ -131,9 +131,9 @@ class UserViewsMixin(object):
url = self.make_url({'username': target_user})
response = self.client.get(url)
self.assertEqual(response.status_code, expect_result)
assert response.status_code == expect_result
if expect_result == 200:
self.assertIn("active", response.data)
assert 'active' in response.data
six.assertCountEqual(self, response.data["active"], self.expected_active(target_user))
@ddt.data(
@@ -147,9 +147,9 @@ class UserViewsMixin(object):
def test_list_connected_providers_with_api_key(self, api_key, target_user, expect_result):
url = self.make_url({'username': target_user})
response = self.client.get(url, HTTP_X_EDX_API_KEY=api_key)
self.assertEqual(response.status_code, expect_result)
assert response.status_code == expect_result
if expect_result == 200:
self.assertIn("active", response.data)
assert 'active' in response.data
six.assertCountEqual(self, response.data["active"], self.expected_active(target_user))
@ddt.data(
@@ -164,18 +164,18 @@ class UserViewsMixin(object):
with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=allow_unprivileged):
url = self.make_url({'username': ALICE_USERNAME})
response = self.client.get(url)
self.assertEqual(response.status_code, expect)
assert response.status_code == expect
if response.status_code == 200:
self.assertGreater(len(response.data['active']), 0)
assert len(response.data['active']) > 0
for provider_data in response.data['active']:
self.assertEqual(include_remote_id, 'remote_id' in provider_data)
assert include_remote_id == ('remote_id' in provider_data)
def test_allow_query_by_email(self):
self.client.login(username=ALICE_USERNAME, password=PASSWORD)
url = self.make_url({'email': '{}@example.com'.format(ALICE_USERNAME)})
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertGreater(len(response.data['active']), 0)
assert response.status_code == 200
assert len(response.data['active']) > 0
def test_throttling(self):
# Default throttle is 10/min. Make 11 requests to verify
@@ -185,9 +185,9 @@ class UserViewsMixin(object):
with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=True):
for _ in range(10):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
@override_settings(EDX_API_KEY=VALID_API_KEY)
@@ -332,7 +332,7 @@ class UserMappingViewAPITests(TpaAPITestCase):
url = reverse('third_party_auth_user_mapping_api', kwargs={'provider_id': PROVIDER_ID_TESTSHIB})
response = self.client.get(url, HTTP_X_EDX_API_KEY=VALID_API_KEY)
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
self._verify_response(response, 200, get_mapping_data_by_usernames(LINKED_USERS))
@ddt.data(
@@ -346,14 +346,14 @@ class UserMappingViewAPITests(TpaAPITestCase):
with patch.object(JwtRestrictedApplication, 'has_permission', return_value=has_permission):
with patch.object(JwtHasScope, 'has_permission', return_value=has_permission):
response = self.client.get(url)
self.assertEqual(response.status_code, expect)
assert response.status_code == expect
def _verify_response(self, response, expect_code, expect_result):
""" verify the items in data_list exists in response and data_results matches results in response """
self.assertEqual(response.status_code, expect_code)
assert response.status_code == expect_code
if expect_code == 200:
for item in ['results', 'count', 'num_pages']:
self.assertIn(item, response.data)
assert item in response.data
six.assertCountEqual(self, response.data['results'], expect_result)
@@ -375,15 +375,11 @@ class TestThirdPartyAuthUserStatusView(ThirdPartyAuthTestMixin, APITestCase):
"""
self.client.login(username=self.user.username, password=PASSWORD)
response = self.client.get(self.url, content_type="application/json")
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.data,
[{
"accepts_logins": True,
"name": "Google",
"disconnect_url": "/auth/disconnect/google-oauth2/?",
"connect_url": "/auth/login/google-oauth2/?auth_entry=account_settings&next=%2Faccount%2Fsettings",
"connected": False,
"id": "oa2-google-oauth2"
}]
)
assert response.status_code == 200
assert (response.data ==
[{
'accepts_logins': True, 'name': 'Google',
'disconnect_url': '/auth/disconnect/google-oauth2/?',
'connect_url': '/auth/login/google-oauth2/?auth_entry=account_settings&next=%2Faccount%2Fsettings',
'connected': False, 'id': 'oa2-google-oauth2'
}])

View File

@@ -7,6 +7,7 @@ import sys
import unittest
from contextlib import contextmanager
from uuid import uuid4
import pytest
from django.conf import settings
from django.core.management import call_command
@@ -69,16 +70,16 @@ class TestRemoveSocialAuthUsersCommand(TestCase):
call_command(self.command, self.provider_hogwarts.slug, force=True)
# user with input idp is removed, along with social auth entries
with self.assertRaises(User.DoesNotExist):
with pytest.raises(User.DoesNotExist):
User.objects.get(username='harry')
with self.assertRaises(UserSocialAuth.DoesNotExist):
with pytest.raises(UserSocialAuth.DoesNotExist):
self.find_user_social_auth_entry('harry')
# other users intact
self.user_fleur.refresh_from_db()
self.user_viktor.refresh_from_db()
self.assertIsNotNone(self.user_fleur)
self.assertIsNotNone(self.user_viktor)
assert self.user_fleur is not None
assert self.user_viktor is not None
# other social auth intact
self.find_user_social_auth_entry(self.user_viktor.username)
@@ -96,9 +97,9 @@ class TestRemoveSocialAuthUsersCommand(TestCase):
with self._replace_stdin('confirm'):
call_command(self.command, self.provider_hogwarts.slug)
with self.assertRaises(User.DoesNotExist):
with pytest.raises(User.DoesNotExist):
User.objects.get(username='harry')
with self.assertRaises(UserSocialAuth.DoesNotExist):
with pytest.raises(UserSocialAuth.DoesNotExist):
self.find_user_social_auth_entry('harry')
@override_settings(FEATURES=FEATURES_WITH_ENABLED)
@@ -109,8 +110,8 @@ class TestRemoveSocialAuthUsersCommand(TestCase):
call_command(self.command, self.provider_hogwarts.slug)
# no users should be removed
self.assertEqual(len(User.objects.all()), 3)
self.assertEqual(len(UserSocialAuth.objects.all()), 2)
assert len(User.objects.all()) == 3
assert len(UserSocialAuth.objects.all()) == 2
def test_feature_default_disabled(self):
""" By default this command should not be enabled """

View File

@@ -105,7 +105,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
"""
expected = "\nDone.\n1 provider(s) found in database.\n1 skipped and 0 attempted.\n0 updated and 0 failed.\n"
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn(expected, self.stdout.getvalue())
assert expected in self.stdout.getvalue()
@mock.patch("requests.get", mock_get())
def test_fetch_saml_metadata(self):
@@ -118,7 +118,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n1 updated and 0 failed.\n"
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn(expected, self.stdout.getvalue())
assert expected in self.stdout.getvalue()
@mock.patch("requests.get", mock_get(status_code=404))
def test_fetch_saml_metadata_failure(self):
@@ -133,7 +133,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
with self.assertRaisesRegex(CommandError, r"HTTPError: 404 Client Error"):
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn(expected, self.stdout.getvalue())
assert expected in self.stdout.getvalue()
@mock.patch("requests.get", mock_get(status_code=200))
def test_fetch_multiple_providers_data(self):
@@ -178,7 +178,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
expected = '\n3 provider(s) found in database.\n0 skipped and 3 attempted.\n2 updated and 1 failed.\n'
with self.assertRaisesRegex(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"):
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn(expected, self.stdout.getvalue())
assert expected in self.stdout.getvalue()
# Now add a fourth configuration, and indicate that it should not be included in the update
self.__create_saml_configurations__(
@@ -201,7 +201,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
expected = '\nDone.\n4 provider(s) found in database.\n1 skipped and 3 attempted.\n0 updated and 1 failed.\n'
with self.assertRaisesRegex(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"):
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn(expected, self.stdout.getvalue())
assert expected in self.stdout.getvalue()
@mock.patch("requests.get")
def test_saml_request_exceptions(self, mocked_get):
@@ -217,19 +217,19 @@ class TestSAMLCommand(CacheIsolationTestCase):
with self.assertRaisesRegex(CommandError, "SSLError:"):
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn(expected, self.stdout.getvalue())
assert expected in self.stdout.getvalue()
mocked_get.side_effect = exceptions.ConnectionError
with self.assertRaisesRegex(CommandError, "ConnectionError:"):
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn(expected, self.stdout.getvalue())
assert expected in self.stdout.getvalue()
mocked_get.side_effect = exceptions.HTTPError
with self.assertRaisesRegex(CommandError, "HTTPError:"):
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn(expected, self.stdout.getvalue())
assert expected in self.stdout.getvalue()
@mock.patch("requests.get", mock_get(status_code=200))
def test_saml_parse_exceptions(self):
@@ -254,7 +254,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
with self.assertRaisesRegex(CommandError, "MetadataParseError: Can't find EntityDescriptor for entityID"):
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn(expected, self.stdout.getvalue())
assert expected in self.stdout.getvalue()
@mock.patch("requests.get")
def test_xml_parse_exceptions(self, mocked_get):
@@ -274,4 +274,4 @@ class TestSAMLCommand(CacheIsolationTestCase):
with self.assertRaisesRegex(CommandError, "XMLSyntaxError:"):
call_command("saml", pull=True, stdout=self.stdout)
self.assertIn(expected, self.stdout.getvalue())
assert expected in self.stdout.getvalue()

View File

@@ -83,19 +83,19 @@ class SAMLConfigurationTests(APITestCase):
def test_get_saml_configurations_successful(self):
url = reverse('saml_configuration-list')
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
assert response.status_code == status.HTTP_200_OK
# We ultimately just need ids and slugs, so let's just check those.
results = response.data['results']
self.assertEqual(results[0]['id'], SAML_CONFIGURATIONS[0]['site'])
self.assertEqual(results[0]['slug'], SAML_CONFIGURATIONS[0]['slug'])
self.assertEqual(results[1]['id'], SAML_CONFIGURATIONS[1]['site'])
self.assertEqual(results[1]['slug'], SAML_CONFIGURATIONS[1]['slug'])
assert results[0]['id'] == SAML_CONFIGURATIONS[0]['site']
assert results[0]['slug'] == SAML_CONFIGURATIONS[0]['slug']
assert results[1]['id'] == SAML_CONFIGURATIONS[1]['site']
assert results[1]['slug'] == SAML_CONFIGURATIONS[1]['slug']
def test_get_saml_configurations_noprivate(self):
# Verify we have 3 saml configuration objects: 2 public, 1 private.
total_object_count = SAMLConfiguration.objects.count()
self.assertEqual(total_object_count, 3)
assert total_object_count == 3
url = reverse('saml_configuration-list')
response = self.client.get(url, format='json')
@@ -103,10 +103,10 @@ class SAMLConfigurationTests(APITestCase):
# We should only see 2 results, since 1 out of 3 are private
# and our queryset only returns public configurations.
results = response.data['results']
self.assertEqual(len(results), 2)
assert len(results) == 2
def test_unauthenticated_user_get_saml_configurations(self):
self.client.logout()
url = reverse('saml_configuration-list')
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@@ -91,13 +91,13 @@ class SAMLProviderConfigTests(APITestCase):
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
assert response.status_code == status.HTTP_200_OK
results = response.data['results']
self.assertEqual(len(results), 1)
self.assertEqual(results[0]['entity_id'], SINGLE_PROVIDER_CONFIG['entity_id'])
self.assertEqual(results[0]['metadata_source'], SINGLE_PROVIDER_CONFIG['metadata_source'])
self.assertEqual(response.data['results'][0]['country'], SINGLE_PROVIDER_CONFIG['country'])
self.assertEqual(SAMLProviderConfig.objects.count(), 1)
assert len(results) == 1
assert results[0]['entity_id'] == SINGLE_PROVIDER_CONFIG['entity_id']
assert results[0]['metadata_source'] == SINGLE_PROVIDER_CONFIG['metadata_source']
assert response.data['results'][0]['country'] == SINGLE_PROVIDER_CONFIG['country']
assert SAMLProviderConfig.objects.count() == 1
def test_get_one_config_by_enterprise_uuid_invalid_uuid(self):
"""
@@ -109,7 +109,7 @@ class SAMLProviderConfigTests(APITestCase):
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_get_one_config_by_enterprise_uuid_not_found(self):
"""
@@ -128,8 +128,8 @@ class SAMLProviderConfigTests(APITestCase):
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(SAMLProviderConfig.objects.count(), orig_count)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert SAMLProviderConfig.objects.count() == orig_count
def test_create_one_config(self):
"""
@@ -142,19 +142,14 @@ class SAMLProviderConfigTests(APITestCase):
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(SAMLProviderConfig.objects.count(), orig_count + 1)
assert response.status_code == status.HTTP_201_CREATED
assert SAMLProviderConfig.objects.count() == (orig_count + 1)
provider_config = SAMLProviderConfig.objects.get(slug=SINGLE_PROVIDER_CONFIG_2['slug'])
self.assertEqual(provider_config.name, 'name-of-config-2')
self.assertEqual(provider_config.country, SINGLE_PROVIDER_CONFIG_2['country'])
assert provider_config.name == 'name-of-config-2'
assert provider_config.country == SINGLE_PROVIDER_CONFIG_2['country']
# check association has also been created
self.assertTrue(
EnterpriseCustomerIdentityProvider.objects.filter(
provider_id=convert_saml_slug_provider_id(provider_config.slug)
).exists(),
'Cannot find EnterpriseCustomer-->SAMLProviderConfig association'
)
assert EnterpriseCustomerIdentityProvider.objects.filter(provider_id=convert_saml_slug_provider_id(provider_config.slug)).exists(), 'Cannot find EnterpriseCustomer-->SAMLProviderConfig association'
def test_create_one_config_fail_non_existent_enterprise_uuid(self):
"""
@@ -167,16 +162,11 @@ class SAMLProviderConfigTests(APITestCase):
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(SAMLProviderConfig.objects.count(), orig_count)
assert response.status_code == status.HTTP_403_FORBIDDEN
assert SAMLProviderConfig.objects.count() == orig_count
# check association has NOT been created
self.assertFalse(
EnterpriseCustomerIdentityProvider.objects.filter(
provider_id=convert_saml_slug_provider_id(SINGLE_PROVIDER_CONFIG_2['slug'])
).exists(),
'Did not expect to find EnterpriseCustomer-->SAMLProviderConfig association'
)
assert not EnterpriseCustomerIdentityProvider.objects.filter(provider_id=convert_saml_slug_provider_id(SINGLE_PROVIDER_CONFIG_2['slug'])).exists(), 'Did not expect to find EnterpriseCustomer-->SAMLProviderConfig association'
def test_create_one_config_with_absent_enterprise_uuid(self):
"""
@@ -188,8 +178,8 @@ class SAMLProviderConfigTests(APITestCase):
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(SAMLProviderConfig.objects.count(), orig_count)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert SAMLProviderConfig.objects.count() == orig_count
def test_create_one_config_with_no_country_urn(self):
"""
@@ -206,9 +196,9 @@ class SAMLProviderConfigTests(APITestCase):
}
response = self.client.post(url, provider_config_no_country)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
assert response.status_code == status.HTTP_201_CREATED
provider_config = SAMLProviderConfig.objects.get(slug='test-slug-none')
self.assertEqual(provider_config.country, '')
assert provider_config.country == ''
def test_create_one_config_with_empty_country_urn(self):
"""
@@ -226,9 +216,9 @@ class SAMLProviderConfigTests(APITestCase):
}
response = self.client.post(url, provider_config_blank_country)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
assert response.status_code == status.HTTP_201_CREATED
provider_config = SAMLProviderConfig.objects.get(slug='test-slug-empty')
self.assertEqual(provider_config.country, '')
assert provider_config.country == ''
def test_unauthenticated_request_is_forbidden(self):
self.client.logout()
@@ -237,12 +227,12 @@ class SAMLProviderConfigTests(APITestCase):
url = '{}?{}'.format(urlbase, urlencode(query_kwargs))
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_LEARNER_ROLE, ENTERPRISE_ID)])
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
assert response.status_code == status.HTTP_403_FORBIDDEN
self.client.logout()
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, ENTERPRISE_ID_NON_EXISTENT)])
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_create_one_config_with_samlconfiguration(self):
"""
@@ -255,6 +245,6 @@ class SAMLProviderConfigTests(APITestCase):
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
assert response.status_code == status.HTTP_201_CREATED
provider_config = SAMLProviderConfig.objects.get(slug=SINGLE_PROVIDER_CONFIG_3['slug'])
self.assertEqual(provider_config.saml_configuration, self.samlconfiguration)
assert provider_config.saml_configuration == self.samlconfiguration

View File

@@ -88,10 +88,10 @@ class SAMLProviderDataTests(APITestCase):
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
assert response.status_code == status.HTTP_200_OK
results = response.data['results']
self.assertEqual(len(results), 1)
self.assertEqual(results[0]['sso_url'], SINGLE_PROVIDER_DATA['sso_url'])
assert len(results) == 1
assert results[0]['sso_url'] == SINGLE_PROVIDER_DATA['sso_url']
def test_create_one_provider_data_success(self):
# POST auth/saml/v0/providerdata/ -d data
@@ -102,12 +102,9 @@ class SAMLProviderDataTests(APITestCase):
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(SAMLProviderData.objects.count(), orig_count + 1)
self.assertEqual(
SAMLProviderData.objects.get(entity_id=SINGLE_PROVIDER_DATA_2['entity_id']).sso_url,
SINGLE_PROVIDER_DATA_2['sso_url']
)
assert response.status_code == status.HTTP_201_CREATED
assert SAMLProviderData.objects.count() == (orig_count + 1)
assert SAMLProviderData.objects.get(entity_id=SINGLE_PROVIDER_DATA_2['entity_id']).sso_url == SINGLE_PROVIDER_DATA_2['sso_url']
def test_create_one_data_with_absent_enterprise_uuid(self):
"""
@@ -119,8 +116,8 @@ class SAMLProviderDataTests(APITestCase):
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(SAMLProviderData.objects.count(), orig_count)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert SAMLProviderData.objects.count() == orig_count
def test_patch_one_provider_data(self):
# PATCH auth/saml/v0/providerdata/ -d data
@@ -133,14 +130,14 @@ class SAMLProviderDataTests(APITestCase):
response = self.client.patch(url, data)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(SAMLProviderData.objects.count(), orig_count)
assert response.status_code == status.HTTP_200_OK
assert SAMLProviderData.objects.count() == orig_count
# ensure only the sso_url was updated
fetched_provider_data = SAMLProviderData.objects.get(pk=self.saml_provider_data.id)
self.assertEqual(fetched_provider_data.sso_url, 'http://new.url')
self.assertEqual(fetched_provider_data.fetched_at, SINGLE_PROVIDER_DATA['fetched_at'])
self.assertEqual(fetched_provider_data.entity_id, SINGLE_PROVIDER_DATA['entity_id'])
assert fetched_provider_data.sso_url == 'http://new.url'
assert fetched_provider_data.fetched_at == SINGLE_PROVIDER_DATA['fetched_at']
assert fetched_provider_data.entity_id == SINGLE_PROVIDER_DATA['entity_id']
def test_delete_one_provider_data(self):
# DELETE auth/saml/v0/providerdata/ -d data
@@ -151,12 +148,12 @@ class SAMLProviderDataTests(APITestCase):
response = self.client.delete(url)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
self.assertEqual(SAMLProviderData.objects.count(), orig_count - 1)
assert response.status_code == status.HTTP_204_NO_CONTENT
assert SAMLProviderData.objects.count() == (orig_count - 1)
# ensure only the sso_url was updated
query_set_count = SAMLProviderData.objects.filter(pk=self.saml_provider_data.id).count()
self.assertEqual(query_set_count, 0)
assert query_set_count == 0
def test_get_one_provider_data_failure(self):
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, BAD_ENTERPRISE_ID)])
@@ -167,7 +164,7 @@ class SAMLProviderDataTests(APITestCase):
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_unauthenticated_request_is_forbidden(self):
self.client.logout()
@@ -176,10 +173,10 @@ class SAMLProviderDataTests(APITestCase):
url = '{}?{}'.format(urlbase, urlencode(query_kwargs))
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_LEARNER_ROLE, ENTERPRISE_ID)])
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
assert response.status_code == status.HTTP_403_FORBIDDEN
# manually running second case as DDT is having issues.
self.client.logout()
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, BAD_ENTERPRISE_ID)])
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
assert response.status_code == status.HTTP_403_FORBIDDEN

View File

@@ -6,6 +6,7 @@ Base integration test for provider implementations.
import json
import unittest
from contextlib import contextmanager
import pytest
import mock
from django import test
@@ -55,8 +56,8 @@ class HelperMixin(object):
implementations may optionally strengthen this assertion with, for
example, more details about the format of the Location header.
"""
self.assertEqual(302, response.status_code)
self.assertTrue(response.has_header('Location'))
assert 302 == response.status_code
assert response.has_header('Location')
def assert_register_response_in_pipeline_looks_correct(self, response, pipeline_kwargs, required_fields): # lint-amnesty, pylint: disable=invalid-name
"""Performs spot checks of the rendered register.html page.
@@ -95,16 +96,16 @@ class HelperMixin(object):
its connected state is correct.
"""
if duplicate:
self.assertEqual(context['duplicate_provider'], self.provider.backend_name)
assert context['duplicate_provider'] == self.provider.backend_name
else:
self.assertIsNone(context['duplicate_provider'])
assert context['duplicate_provider'] is None
if linked is not None:
expected_provider = [
provider for provider in context['auth']['providers'] if provider['name'] == self.provider.name
][0]
self.assertIsNotNone(expected_provider)
self.assertEqual(expected_provider['connected'], linked)
assert expected_provider is not None
assert expected_provider['connected'] == linked
def assert_exception_redirect_looks_correct(self, expected_uri, auth_entry=None):
"""Tests middleware conditional redirection.
@@ -118,45 +119,45 @@ class HelperMixin(object):
request, exceptions.AuthCanceled(request.backend))
location = response.get('Location')
self.assertEqual(302, response.status_code)
self.assertIn('canceled', location)
self.assertIn(self.backend_name, location)
self.assertTrue(location.startswith(expected_uri + '?'))
assert 302 == response.status_code
assert 'canceled' in location
assert self.backend_name in location
assert location.startswith((expected_uri + '?'))
def assert_json_failure_response_is_inactive_account(self, response):
"""Asserts failure on /login for inactive account looks right."""
self.assertEqual(400, response.status_code)
assert 400 == response.status_code
payload = json.loads(response.content.decode('utf-8'))
context = {
'platformName': configuration_helpers.get_value('PLATFORM_NAME', settings.PLATFORM_NAME),
'supportLink': configuration_helpers.get_value('SUPPORT_SITE_LINK', settings.SUPPORT_SITE_LINK)
}
self.assertFalse(payload.get('success'))
self.assertIn('inactive-user', payload.get('error_code'))
self.assertEqual(context, payload.get('context'))
assert not payload.get('success')
assert 'inactive-user' in payload.get('error_code')
assert context == payload.get('context')
def assert_json_failure_response_is_missing_social_auth(self, response):
"""Asserts failure on /login for missing social auth looks right."""
self.assertEqual(403, response.status_code)
assert 403 == response.status_code
payload = json.loads(response.content.decode('utf-8'))
self.assertFalse(payload.get('success'))
self.assertEqual(payload.get('error_code'), 'third-party-auth-with-no-linked-account')
assert not payload.get('success')
assert payload.get('error_code') == 'third-party-auth-with-no-linked-account'
def assert_json_failure_response_is_username_collision(self, response):
"""Asserts the json response indicates a username collision."""
self.assertEqual(409, response.status_code)
assert 409 == response.status_code
payload = json.loads(response.content.decode('utf-8'))
self.assertFalse(payload.get('success'))
self.assertIn('belongs to an existing account', payload['username'][0]['user_message'])
assert not payload.get('success')
assert 'belongs to an existing account' in payload['username'][0]['user_message']
def assert_json_success_response_looks_correct(self, response, verify_redirect_url):
"""Asserts the json response indicates success and redirection."""
self.assertEqual(200, response.status_code)
assert 200 == response.status_code
payload = json.loads(response.content.decode('utf-8'))
self.assertTrue(payload.get('success'))
assert payload.get('success')
if verify_redirect_url:
self.assertEqual(pipeline.get_complete_url(self.provider.backend_name), payload.get('redirect_url'))
assert pipeline.get_complete_url(self.provider.backend_name) == payload.get('redirect_url')
def assert_login_response_before_pipeline_looks_correct(self, response):
"""Asserts a GET of /login not in the pipeline looks correct."""
@@ -167,37 +168,36 @@ class HelperMixin(object):
def assert_login_response_in_pipeline_looks_correct(self, response):
"""Asserts a GET of /login in the pipeline looks correct."""
self.assertEqual(200, response.status_code)
assert 200 == response.status_code
def assert_password_overridden_by_pipeline(self, username, password):
"""Verifies that the given password is not correct.
The pipeline overrides POST['password'], if any, with random data.
"""
self.assertIsNone(auth.authenticate(password=password, username=username))
assert auth.authenticate(password=password, username=username) is None
def assert_pipeline_running(self, request):
"""Makes sure the given request is running an auth pipeline."""
self.assertTrue(pipeline.running(request))
assert pipeline.running(request)
def assert_redirect_after_pipeline_completes(self, response, expected_redirect_url=None):
"""Asserts a response would redirect to the expected_redirect_url or SOCIAL_AUTH_LOGIN_REDIRECT_URL."""
self.assertEqual(302, response.status_code)
assert 302 == response.status_code
# NOTE: Ideally we should use assertRedirects(), however it errors out due to the hostname, testserver,
# not being properly set. This may be an issue with the call made by PSA, but we are not certain.
self.assertTrue(response.get('Location').endswith(
expected_redirect_url or django_settings.SOCIAL_AUTH_LOGIN_REDIRECT_URL,
))
assert response.get('Location').endswith((expected_redirect_url or
django_settings.SOCIAL_AUTH_LOGIN_REDIRECT_URL))
def assert_redirect_to_login_looks_correct(self, response):
"""Asserts a response would redirect to /login."""
self.assertEqual(302, response.status_code)
self.assertEqual('/login', response.get('Location'))
assert 302 == response.status_code
assert '/login' == response.get('Location')
def assert_redirect_to_register_looks_correct(self, response):
"""Asserts a response would redirect to /register."""
self.assertEqual(302, response.status_code)
self.assertEqual('/register', response.get('Location'))
assert 302 == response.status_code
assert '/register' == response.get('Location')
def assert_register_response_before_pipeline_looks_correct(self, response):
"""Asserts a GET of /register not in the pipeline looks correct."""
@@ -210,24 +210,21 @@ class HelperMixin(object):
"""Asserts a user does not have an auth with the expected provider."""
social_auths = strategy.storage.user.get_social_auth_for_user(
user, provider=self.provider.backend_name)
self.assertEqual(0, len(social_auths))
assert 0 == len(social_auths)
def assert_social_auth_exists_for_user(self, user, strategy):
"""Asserts a user has a social auth with the expected provider."""
social_auths = strategy.storage.user.get_social_auth_for_user(
user, provider=self.provider.backend_name)
self.assertEqual(1, len(social_auths))
self.assertEqual(self.backend_name, social_auths[0].provider)
assert 1 == len(social_auths)
assert self.backend_name == social_auths[0].provider
def assert_logged_in_cookie_redirect(self, response):
"""Verify that the user was redirected in order to set the logged in cookie. """
self.assertEqual(response.status_code, 302)
self.assertEqual(
response["Location"],
pipeline.get_complete_url(self.provider.backend_name)
)
self.assertEqual(response.cookies[django_settings.EDXMKTG_LOGGED_IN_COOKIE_NAME].value, 'true')
self.assertIn(django_settings.EDXMKTG_USER_INFO_COOKIE_NAME, response.cookies)
assert response.status_code == 302
assert response['Location'] == pipeline.get_complete_url(self.provider.backend_name)
assert response.cookies[django_settings.EDXMKTG_LOGGED_IN_COOKIE_NAME].value == 'true'
assert django_settings.EDXMKTG_USER_INFO_COOKIE_NAME in response.cookies
@property
def backend_name(self):
@@ -384,24 +381,24 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin):
# The user clicks on the Dummy button:
try_login_response = self.client.get(provider_register_url)
# The user should be redirected to the provider's login page:
self.assertEqual(try_login_response.status_code, 302)
assert try_login_response.status_code == 302
provider_response = self.do_provider_login(try_login_response['Location'])
# We should be redirected to the register screen since this account is not linked to an edX account:
self.assertEqual(provider_response.status_code, 302)
self.assertEqual(provider_response['Location'], self.register_page_url)
assert provider_response.status_code == 302
assert provider_response['Location'] == self.register_page_url
register_response = self.client.get(self.register_page_url)
tpa_context = register_response.context["data"]["third_party_auth"]
self.assertEqual(tpa_context["errorMessage"], None)
assert tpa_context['errorMessage'] is None
# Check that the "You've successfully signed into [PROVIDER_NAME]" message is shown.
self.assertEqual(tpa_context["currentProvider"], self.PROVIDER_NAME)
assert tpa_context['currentProvider'] == self.PROVIDER_NAME
# Check that the data (e.g. email) from the provider is displayed in the form:
form_data = register_response.context['data']['registration_form_desc']
form_fields = {field['name']: field for field in form_data['fields']}
self.assertEqual(form_fields['email']['defaultValue'], self.USER_EMAIL)
self.assertEqual(form_fields['name']['defaultValue'], self.USER_NAME)
self.assertEqual(form_fields['username']['defaultValue'], self.USER_USERNAME)
assert form_fields['email']['defaultValue'] == self.USER_EMAIL
assert form_fields['name']['defaultValue'] == self.USER_NAME
assert form_fields['username']['defaultValue'] == self.USER_USERNAME
for field_name, value in extra_defaults.items():
self.assertEqual(form_fields[field_name]['defaultValue'], value)
assert form_fields[field_name]['defaultValue'] == value
registration_values = {
'email': 'email-edited@tpa-test.none',
'name': 'My Customized Name',
@@ -413,12 +410,12 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin):
reverse('user_api_registration'),
registration_values
)
self.assertEqual(ajax_register_response.status_code, 200)
assert ajax_register_response.status_code == 200
# Then the AJAX will finish the third party auth:
continue_response = self.client.get(tpa_context["finishAuthUrl"])
# And we should be redirected to the dashboard:
self.assertEqual(continue_response.status_code, 302)
self.assertEqual(continue_response['Location'], reverse('dashboard'))
assert continue_response.status_code == 302
assert continue_response['Location'] == reverse('dashboard')
# Now check that we can login again, whether or not we have yet verified the account:
self.client.logout()
@@ -437,28 +434,28 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin):
# The user clicks on the provider's button:
try_login_response = self.client.get(provider_login_url)
# The user should be redirected to the provider's login page:
self.assertEqual(try_login_response.status_code, 302)
assert try_login_response.status_code == 302
complete_response = self.do_provider_login(try_login_response['Location'])
# We should be redirected to the login screen since this account is not linked to an edX account:
self.assertEqual(complete_response.status_code, 302)
self.assertEqual(complete_response['Location'], self.login_page_url)
assert complete_response.status_code == 302
assert complete_response['Location'] == self.login_page_url
login_response = self.client.get(self.login_page_url)
tpa_context = login_response.context["data"]["third_party_auth"]
self.assertEqual(tpa_context["errorMessage"], None)
assert tpa_context['errorMessage'] is None
# Check that the "You've successfully signed into [PROVIDER_NAME]" message is shown.
self.assertEqual(tpa_context["currentProvider"], self.PROVIDER_NAME)
assert tpa_context['currentProvider'] == self.PROVIDER_NAME
# Now the user enters their username and password.
# The AJAX on the page will log them in:
ajax_login_response = self.client.post(
reverse('user_api_login_session'),
{'email': self.user.email, 'password': 'test'}
)
self.assertEqual(ajax_login_response.status_code, 200)
assert ajax_login_response.status_code == 200
# Then the AJAX will finish the third party auth:
continue_response = self.client.get(tpa_context["finishAuthUrl"])
# And we should be redirected to the dashboard:
self.assertEqual(continue_response.status_code, 302)
self.assertEqual(continue_response['Location'], reverse('dashboard'))
assert continue_response.status_code == 302
assert continue_response['Location'] == reverse('dashboard')
# Now check that we can login again:
self.client.logout()
@@ -475,30 +472,30 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin):
""" Test logging in to an account that is already linked. """
# Make sure we're not logged in:
dashboard_response = self.client.get(reverse('dashboard'))
self.assertEqual(dashboard_response.status_code, 302)
assert dashboard_response.status_code == 302
# The user goes to the login page, and sees a button to login with this provider:
provider_login_url = self._check_login_page()
# The user clicks on the provider's login button:
try_login_response = self.client.get(provider_login_url)
# The user should be redirected to the provider:
self.assertEqual(try_login_response.status_code, 302)
assert try_login_response.status_code == 302
login_response = self.do_provider_login(try_login_response['Location'])
# If the previous session was manually logged out, there will be one weird redirect
# required to set the login cookie (it sticks around if the main session times out):
if not previous_session_timed_out:
self.assertEqual(login_response.status_code, 302)
self.assertEqual(login_response['Location'], self.complete_url + "?")
assert login_response.status_code == 302
assert login_response['Location'] == (self.complete_url + '?')
# And then we should be redirected to the dashboard:
login_response = self.client.get(login_response['Location'])
self.assertEqual(login_response.status_code, 302)
assert login_response.status_code == 302
if user_is_activated:
url_expected = reverse('dashboard')
else:
url_expected = reverse('third_party_inactive_redirect') + '?next=' + reverse('dashboard')
self.assertEqual(login_response['Location'], url_expected)
assert login_response['Location'] == url_expected
# Now we are logged in:
dashboard_response = self.client.get(reverse('dashboard'))
self.assertEqual(dashboard_response.status_code, 200)
assert dashboard_response.status_code == 200
def _check_login_page(self):
"""
@@ -520,7 +517,7 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin):
self.assertContains(response, self.PROVIDER_NAME)
context_data = response.context['data']['third_party_auth']
provider_urls = {provider['id']: provider[url_to_return] for provider in context_data['providers']}
self.assertIn(self.PROVIDER_ID, provider_urls)
assert self.PROVIDER_ID in provider_urls
return provider_urls[self.PROVIDER_ID]
@property
@@ -668,7 +665,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin):
self.assert_social_auth_exists_for_user(linked_user, strategy)
self.assert_social_auth_does_not_exist_for_user(unlinked_user, strategy)
with self.assertRaises(exceptions.AuthAlreadyAssociated):
with pytest.raises(exceptions.AuthAlreadyAssociated):
# pylint: disable=protected-access
actions.do_complete(backend, social_views._do_login, user=unlinked_user, request=strategy.request)
@@ -720,7 +717,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin):
partial_data = strategy.storage.partial.load(partial_pipeline_token)
self.assert_social_auth_exists_for_user(user, strategy)
self.assertTrue(user.is_active)
assert user.is_active
# Begin! Ensure that the login form contains expected controls before
# the user starts the pipeline.
@@ -862,7 +859,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin):
strategy.request.POST = self.get_registration_post_vars({'email': email})
# The user must not exist yet...
with self.assertRaises(auth_models.User.DoesNotExist):
with pytest.raises(auth_models.User.DoesNotExist):
self.get_user_by_email(strategy, email)
# ...but when we invoke create_account the existing edX view will make
@@ -906,7 +903,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin):
self.assert_redirect_to_login_looks_correct(actions.do_complete(backend, social_views._do_login,
request=request))
distinct_username = pipeline.get(request)['kwargs']['username']
self.assertNotEqual(original_username, distinct_username)
assert original_username != distinct_username
def test_new_account_registration_fails_if_email_exists(self):
request, strategy = self.get_request_and_strategy(
@@ -932,17 +929,17 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin):
def test_pipeline_raises_auth_entry_error_if_auth_entry_invalid(self):
auth_entry = 'invalid'
self.assertNotIn(auth_entry, pipeline._AUTH_ENTRY_CHOICES) # pylint: disable=protected-access
assert auth_entry not in pipeline._AUTH_ENTRY_CHOICES # pylint: disable=protected-access
_, strategy = self.get_request_and_strategy(auth_entry=auth_entry, redirect_uri='social:complete')
with self.assertRaises(pipeline.AuthEntryError):
with pytest.raises(pipeline.AuthEntryError):
strategy.request.backend.auth_complete = mock.MagicMock(return_value=self.fake_auth_complete(strategy))
def test_pipeline_assumes_login_if_auth_entry_missing(self):
_, strategy = self.get_request_and_strategy(auth_entry=None, redirect_uri='social:complete')
response = self.fake_auth_complete(strategy)
self.assertEqual(response.url, reverse('signin_user'))
assert response.url == reverse('signin_user')
def assert_first_party_auth_trumps_third_party_auth(self, email=None, password=None, success=None):
"""Asserts first party auth was used in place of third party auth.
@@ -974,15 +971,15 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin):
if success is None:
# Request malformed -- just one of email/password given.
self.assertFalse(payload.get('success'))
self.assertIn('There was an error receiving your login information', payload.get('value'))
assert not payload.get('success')
assert 'There was an error receiving your login information' in payload.get('value')
elif success:
# Request well-formed and credentials good.
self.assertTrue(payload.get('success'))
assert payload.get('success')
else:
# Request well-formed but credentials bad.
self.assertFalse(payload.get('success'))
self.assertIn('incorrect', payload.get('value'))
assert not payload.get('success')
assert 'incorrect' in payload.get('value')
def get_response_data(self):
"""Gets a dict of response data of the form given by the provider.

View File

@@ -31,5 +31,5 @@ class GenericIntegrationTest(IntegrationTestMixin, testutil.TestCase):
Mock logging in to the Dummy provider
"""
# For the Dummy provider, the provider redirect URL is self.complete_url
self.assertEqual(provider_redirect_url, self.url_prefix + self.complete_url)
assert provider_redirect_url == (self.url_prefix + self.complete_url)
return self.client.get(provider_redirect_url)

View File

@@ -51,7 +51,7 @@ class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest): # lint-amnesty,
def assert_redirect_to_provider_looks_correct(self, response):
super(GoogleOauth2IntegrationTest, self).assert_redirect_to_provider_looks_correct(response) # lint-amnesty, pylint: disable=super-with-arguments
self.assertIn('google.com', response['Location'])
assert 'google.com' in response['Location']
def test_custom_form(self):
"""
@@ -75,27 +75,20 @@ class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest): # lint-amnesty,
with patch.object(self.provider.backend_class, 'auth_complete', fake_auth_complete):
response = self.client.get(complete_url)
# This should redirect to the custom login/register form:
self.assertEqual(response.status_code, 302)
self.assertEqual(response['Location'], '/auth/custom_auth_entry')
assert response.status_code == 302
assert response['Location'] == '/auth/custom_auth_entry'
response = self.client.get(response['Location'])
self.assertEqual(response.status_code, 200)
self.assertIn('action="/misc/my-custom-registration-form" method="post"', response.content.decode('utf-8'))
assert response.status_code == 200
assert 'action="/misc/my-custom-registration-form" method="post"' in response.content.decode('utf-8')
data_decoded = base64.b64decode(response.context['data']).decode('utf-8')
data_parsed = json.loads(data_decoded)
# The user's details get passed to the custom page as a base64 encoded query parameter:
self.assertEqual(data_parsed, {
'auth_entry': 'custom1',
'backend_name': 'google-oauth2',
'provider_id': 'oa2-google-oauth2',
'user_details': {
'username': 'user',
'email': 'user@email.com',
'fullname': 'name_value',
'first_name': 'given_name_value',
'last_name': 'family_name_value',
},
})
assert data_parsed == {'auth_entry': 'custom1', 'backend_name': 'google-oauth2',
'provider_id': 'oa2-google-oauth2',
'user_details': {'username': 'user', 'email': 'user@email.com',
'fullname': 'name_value', 'first_name': 'given_name_value',
'last_name': 'family_name_value'}}
# Check the hash that is used to confirm the user's data in the GET parameter is correct
secret_key = settings.THIRD_PARTY_AUTH_CUSTOM_AUTH_FORMS['custom1']['secret_key']
hmac_expected = hmac.new(
@@ -103,18 +96,18 @@ class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest): # lint-amnesty,
msg=data_decoded.encode('utf-8'),
digestmod=hashlib.sha256
).digest()
self.assertEqual(base64.b64decode(response.context['hmac']), hmac_expected)
assert base64.b64decode(response.context['hmac']) == hmac_expected
# Now our custom registration form creates or logs in the user:
email, password = data_parsed['user_details']['email'], 'random_password'
created_user = UserFactory(email=email, password=password)
login_response = self.client.post(reverse('login_api'), {'email': email, 'password': password})
self.assertEqual(login_response.status_code, 200)
assert login_response.status_code == 200
# Now our custom login/registration page must resume the pipeline:
response = self.client.get(complete_url)
self.assertEqual(response.status_code, 302)
self.assertEqual(response['Location'], '/misc/final-destination')
assert response.status_code == 302
assert response['Location'] == '/misc/final-destination'
_, strategy = self.get_request_and_strategy()
self.assert_social_auth_exists_for_user(created_user, strategy)
@@ -140,5 +133,5 @@ class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest): # lint-amnesty,
with patch.object(self.provider.backend_class, 'auth_complete', fake_auth_complete_error):
response = self.client.get(complete_url)
# This should redirect to the custom error URL
self.assertEqual(response.status_code, 302)
self.assertEqual(response['Location'], '/misc/my-custom-sso-error-page')
assert response.status_code == 302
assert response['Location'] == '/misc/my-custom-sso-error-page'

View File

@@ -70,8 +70,8 @@ class IntegrationTestLTI(testutil.TestCase):
)
login_response = self.client.post(path=uri, content_type=FORM_ENCODED, data=body)
# The user should be redirected to the registration form
self.assertEqual(login_response.status_code, 302)
self.assertTrue(login_response['Location'].endswith(reverse('signin_user')))
assert login_response.status_code == 302
assert login_response['Location'].endswith(reverse('signin_user'))
register_response = self.client.get(login_response['Location'])
self.assertContains(register_response, '"currentProvider": "LTI Test Tool Consumer"')
self.assertContains(register_response, '"errorMessage": null')
@@ -86,15 +86,12 @@ class IntegrationTestLTI(testutil.TestCase):
'honor_code': True,
}
)
self.assertEqual(ajax_register_response.status_code, 200)
assert ajax_register_response.status_code == 200
continue_response = self.client.get(self.url_prefix + LTI_TPA_COMPLETE_URL)
# The user should be redirected to the finish_auth view which will enroll them.
# FinishAuthView.js reads the URL parameters directly from $.url
self.assertEqual(continue_response.status_code, 302)
self.assertEqual(
continue_response['Location'],
'/account/finish_auth/?course_id=my_course_id&enrollment_action=enroll'
)
assert continue_response.status_code == 302
assert continue_response['Location'] == '/account/finish_auth/?course_id=my_course_id&enrollment_action=enroll'
# Now check that we can login again
self.client.logout()
@@ -106,19 +103,20 @@ class IntegrationTestLTI(testutil.TestCase):
)
login_2_response = self.client.post(path=uri, content_type=FORM_ENCODED, data=body)
# The user should be redirected to the dashboard
self.assertEqual(login_2_response.status_code, 302)
self.assertEqual(login_2_response['Location'], LTI_TPA_COMPLETE_URL + "?")
assert login_2_response.status_code == 302
assert login_2_response['Location'] == (LTI_TPA_COMPLETE_URL + '?')
continue_2_response = self.client.get(login_2_response['Location'])
self.assertEqual(continue_2_response.status_code, 302)
self.assertTrue(continue_2_response['Location'].endswith(reverse('dashboard')))
assert continue_2_response.status_code == 302
assert continue_2_response['Location'].endswith(reverse('dashboard'))
# Check that the user was created correctly
user = User.objects.get(email=EMAIL)
self.assertEqual(user.username, EDX_USER_ID)
assert user.username == EDX_USER_ID
def test_reject_initiating_login(self):
response = self.client.get(self.url_prefix + LTI_TPA_LOGIN_URL)
self.assertEqual(response.status_code, 405) # Not Allowed
assert response.status_code == 405
# Not Allowed
def test_reject_bad_login(self):
login_response = self.client.post(
@@ -127,8 +125,8 @@ class IntegrationTestLTI(testutil.TestCase):
)
# The user should be redirected to the login page with an error message
# (auth_entry defaults to login for this provider)
self.assertEqual(login_response.status_code, 302)
self.assertTrue(login_response['Location'].endswith(reverse('signin_user')))
assert login_response.status_code == 302
assert login_response['Location'].endswith(reverse('signin_user'))
error_response = self.client.get(login_response['Location'])
self.assertContains(
error_response,
@@ -152,8 +150,8 @@ class IntegrationTestLTI(testutil.TestCase):
with self.settings(SOCIAL_AUTH_LTI_CONSUMER_SECRETS={OTHER_LTI_CONSUMER_KEY: OTHER_LTI_CONSUMER_SECRET}):
login_response = self.client.post(path=uri, content_type=FORM_ENCODED, data=body)
# The user should be redirected to the registration form
self.assertEqual(login_response.status_code, 302)
self.assertTrue(login_response['Location'].endswith(reverse('signin_user')))
assert login_response.status_code == 302
assert login_response['Location'].endswith(reverse('signin_user'))
register_response = self.client.get(login_response['Location'])
self.assertContains(
register_response,

View File

@@ -114,21 +114,21 @@ class SamlIntegrationTestUtilities(object):
saml_provider = self.configure_saml_provider(**kwargs) # pylint: disable=no-member
if fetch_metadata:
self.assertTrue(httpretty.is_enabled()) # lint-amnesty, pylint: disable=no-member
assert httpretty.is_enabled() # lint-amnesty, pylint: disable=no-member
num_total, num_skipped, num_attempted, num_updated, num_failed, failure_messages = fetch_saml_metadata()
if assert_metadata_updates:
self.assertEqual(num_total, 1) # lint-amnesty, pylint: disable=no-member
self.assertEqual(num_skipped, 0) # lint-amnesty, pylint: disable=no-member
self.assertEqual(num_attempted, 1) # lint-amnesty, pylint: disable=no-member
self.assertEqual(num_updated, 1) # lint-amnesty, pylint: disable=no-member
self.assertEqual(num_failed, 0) # lint-amnesty, pylint: disable=no-member
self.assertEqual(len(failure_messages), 0) # lint-amnesty, pylint: disable=no-member
assert num_total == 1 # lint-amnesty, pylint: disable=no-member
assert num_skipped == 0 # lint-amnesty, pylint: disable=no-member
assert num_attempted == 1 # lint-amnesty, pylint: disable=no-member
assert num_updated == 1 # lint-amnesty, pylint: disable=no-member
assert num_failed == 0 # lint-amnesty, pylint: disable=no-member
assert len(failure_messages) == 0 # lint-amnesty, pylint: disable=no-member
return saml_provider
def do_provider_login(self, provider_redirect_url):
""" Mocked: the user logs in to TestShib and then gets redirected back """
# The SAML provider (TestShib) will authenticate the user, then get the browser to POST a response:
self.assertTrue(provider_redirect_url.startswith(TESTSHIB_SSO_URL)) # lint-amnesty, pylint: disable=no-member
assert provider_redirect_url.startswith(TESTSHIB_SSO_URL) # lint-amnesty, pylint: disable=no-member
saml_response_xml = utils.read_and_pre_process_xml(
os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data', 'testshib_saml_response.xml')
@@ -188,9 +188,8 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
enterprise_customer = EnterpriseCustomerFactory()
assert EnterpriseCustomerUser.objects.count() == 0, "Precondition check: no link records should exist"
EnterpriseCustomerUser.objects.link_user(enterprise_customer, user.email)
self.assertTrue( # lint-amnesty, pylint: disable=wrong-assert-type
EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 1
)
assert (EnterpriseCustomerUser.objects
.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 1)
EnterpriseCustomerIdentityProvider.objects.get_or_create(enterprise_customer=enterprise_customer,
provider_id=self.provider.provider_id)
@@ -229,10 +228,8 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
redirect_field_name=auth.REDIRECT_FIELD_NAME,
request=request
)
self.assertNotEqual(
EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count(),
0
)
assert EnterpriseCustomerUser.objects\
.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() != 0
# Fire off the disconnect pipeline to unlink.
self.assert_redirect_after_pipeline_completes(
@@ -247,10 +244,8 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
# Now we expect to be in the unlinked state, with no backend entry.
self.assert_account_settings_context_looks_correct(account_settings_context(request), linked=False)
self.assert_social_auth_does_not_exist_for_user(user, strategy)
self.assertEqual(
EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count(),
0
)
assert EnterpriseCustomerUser.objects\
.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 0
def get_response_data(self):
"""Gets dict (string -> object) of merged data about the user."""
@@ -269,8 +264,8 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
# The user clicks on the TestShib button:
try_login_response = self.client.get(testshib_login_url)
# The user should be redirected to back to the login page:
self.assertEqual(try_login_response.status_code, 302)
self.assertEqual(try_login_response['Location'], self.login_page_url)
assert try_login_response.status_code == 302
assert try_login_response['Location'] == self.login_page_url
# When loading the login page, the user will see an error message:
response = self.client.get(self.login_page_url)
self.assertContains(response, 'Authentication with TestShib is currently unavailable.')
@@ -294,12 +289,11 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
user=self.user, provider=self.PROVIDER_BACKEND, uid__startswith=self.PROVIDER_IDP_SLUG
)
attributes = record.extra_data
self.assertEqual(
attributes.get("urn:oid:1.3.6.1.4.1.5923.1.1.1.9"), ["Member@testshib.org", "Staff@testshib.org"]
)
self.assertEqual(attributes.get("urn:oid:2.5.4.3"), ["Me Myself And I"])
self.assertEqual(attributes.get("urn:oid:0.9.2342.19200300.100.1.1"), ["myself"])
self.assertEqual(attributes.get("urn:oid:2.5.4.20"), ["555-5555"]) # Phone number
assert attributes.get('urn:oid:1.3.6.1.4.1.5923.1.1.1.9') == ['Member@testshib.org', 'Staff@testshib.org']
assert attributes.get('urn:oid:2.5.4.3') == ['Me Myself And I']
assert attributes.get('urn:oid:0.9.2342.19200300.100.1.1') == ['myself']
assert attributes.get('urn:oid:2.5.4.20') == ['555-5555']
# Phone number
@ddt.data(True, False)
def test_debug_mode_login(self, debug_mode_enabled):
@@ -310,30 +304,30 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
if debug_mode_enabled:
# We expect that test_login() does two full logins, and each attempt generates two
# logs - one for the request and one for the response
self.assertEqual(mock_log.call_count, 4)
assert mock_log.call_count == 4
expected_next_url = "/dashboard"
(msg, action_type, idp_name, request_data, next_url, xml), _kwargs = mock_log.call_args_list[0]
self.assertTrue(msg.startswith(u"SAML login %s"))
self.assertEqual(action_type, "request")
self.assertEqual(idp_name, self.PROVIDER_IDP_SLUG)
assert msg.startswith(u'SAML login %s')
assert action_type == 'request'
assert idp_name == self.PROVIDER_IDP_SLUG
self.assertDictContainsSubset(
{"idp": idp_name, "auth_entry": "login", "next": expected_next_url},
request_data
)
self.assertEqual(next_url, expected_next_url)
self.assertIn('<samlp:AuthnRequest', xml)
assert next_url == expected_next_url
assert '<samlp:AuthnRequest' in xml
(msg, action_type, idp_name, response_data, next_url, xml), _kwargs = mock_log.call_args_list[1]
self.assertTrue(msg.startswith(u"SAML login %s"))
self.assertEqual(action_type, "response")
self.assertEqual(idp_name, self.PROVIDER_IDP_SLUG)
assert msg.startswith(u'SAML login %s')
assert action_type == 'response'
assert idp_name == self.PROVIDER_IDP_SLUG
self.assertDictContainsSubset({"RelayState": idp_name}, response_data)
self.assertIn('SAMLResponse', response_data)
self.assertEqual(next_url, expected_next_url)
self.assertIn('<saml2p:Response', xml)
assert 'SAMLResponse' in response_data
assert next_url == expected_next_url
assert '<saml2p:Response' in xml
else:
self.assertFalse(mock_log.called)
assert not mock_log.called
def test_configure_testshib_provider_with_cache_duration(self):
""" Enable and configure the TestShib SAML IdP as a third_party_auth provider """
@@ -347,14 +341,14 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
kwargs.setdefault('icon_class', 'fa-university')
kwargs.setdefault('attr_email', 'urn:oid:1.3.6.1.4.1.5923.1.1.1.6') # eduPersonPrincipalName
self.configure_saml_provider(**kwargs)
self.assertTrue(httpretty.is_enabled())
assert httpretty.is_enabled()
num_total, num_skipped, num_attempted, num_updated, num_failed, failure_messages = fetch_saml_metadata()
self.assertEqual(num_total, 1)
self.assertEqual(num_skipped, 0)
self.assertEqual(num_attempted, 1)
self.assertEqual(num_updated, 1)
self.assertEqual(num_failed, 0)
self.assertEqual(len(failure_messages), 0)
assert num_total == 1
assert num_skipped == 0
assert num_attempted == 1
assert num_updated == 1
assert num_failed == 0
assert len(failure_messages) == 0
def test_login_with_testshib_provider_short_session_length(self):
"""
@@ -403,10 +397,10 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes
"""
Return a fake assertion after checking that the input is what we expect.
"""
self.assertIn(b'private_key=fake_private_key_here', _request.body)
self.assertIn(b'user_id=myself', _request.body)
self.assertIn(b'token_url=http%3A%2F%2Fsuccessfactors.com%2Foauth%2Ftoken', _request.body)
self.assertIn(b'client_id=TatVotSEiCMteSNWtSOnLanCtBGwNhGB', _request.body)
assert b'private_key=fake_private_key_here' in _request.body
assert b'user_id=myself' in _request.body
assert b'token_url=http%3A%2F%2Fsuccessfactors.com%2Foauth%2Ftoken' in _request.body
assert b'client_id=TatVotSEiCMteSNWtSOnLanCtBGwNhGB' in _request.body
return (200, headers, 'fake_saml_assertion')
httpretty.register_uri(httpretty.POST, SAPSF_ASSERTION_URL, content_type='text/plain', body=assertion_callback)
@@ -428,10 +422,10 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes
"""
Return a fake assertion after checking that the input is what we expect.
"""
self.assertIn(b'assertion=fake_saml_assertion', _request.body)
self.assertIn(b'company_id=NCC1701D', _request.body)
self.assertIn(b'grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Asaml2-bearer', _request.body)
self.assertIn(b'client_id=TatVotSEiCMteSNWtSOnLanCtBGwNhGB', _request.body)
assert b'assertion=fake_saml_assertion' in _request.body
assert b'company_id=NCC1701D' in _request.body
assert b'grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Asaml2-bearer' in _request.body
assert b'client_id=TatVotSEiCMteSNWtSOnLanCtBGwNhGB' in _request.body
return (200, headers, '{"access_token": "faketoken"}')
httpretty.register_uri(httpretty.POST, SAPSF_TOKEN_URL, content_type='application/json', body=token_callback)
@@ -444,7 +438,7 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes
def user_callback(request, _uri, headers):
auth_header = request.headers.get('Authorization')
self.assertEqual(auth_header, 'Bearer faketoken')
assert auth_header == 'Bearer faketoken'
return (
200,
headers,
@@ -543,7 +537,7 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes
def user_callback(request, _uri, headers):
auth_header = request.headers.get('Authorization')
self.assertEqual(auth_header, 'Bearer faketoken')
assert auth_header == 'Bearer faketoken'
return (
200,
headers,
@@ -714,13 +708,13 @@ class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTes
with LogCapture(level=logging.WARNING) as log_capture:
self._test_register()
logging_messages = str([log_msg.getMessage() for log_msg in log_capture.records]).replace('\\', '')
self.assertIn(odata_company_id, logging_messages)
self.assertIn(mocked_odata_api_url, logging_messages)
self.assertIn(self.USER_USERNAME, logging_messages)
self.assertIn("SAPSuccessFactors", logging_messages)
self.assertIn("Error message", logging_messages)
self.assertIn("System message", logging_messages)
self.assertIn("Headers", logging_messages)
assert odata_company_id in logging_messages
assert mocked_odata_api_url in logging_messages
assert self.USER_USERNAME in logging_messages
assert 'SAPSuccessFactors' in logging_messages
assert 'Error message' in logging_messages
assert 'System message' in logging_messages
assert 'Headers' in logging_messages
@skip('Test not necessary for this subclass')
def test_get_saml_idp_class_with_fake_identifier(self):

View File

@@ -56,8 +56,8 @@ class Oauth2ProviderConfigAdminTest(testutil.TestCase):
# Get the provider instance with active flag
providers = OAuth2ProviderConfig.objects.all()
self.assertEqual(len(providers), 1)
self.assertEqual(providers[pcount].id, provider1.id)
assert len(providers) == 1
assert providers[pcount].id == provider1.id
# Edit the provider via the admin edit link
admin = OAuth2ProviderConfigAdmin(provider1, AdminSite())
@@ -77,14 +77,14 @@ class Oauth2ProviderConfigAdminTest(testutil.TestCase):
# Post the edit form: expecting redirect
response = self.client.post(update_url, post_data)
self.assertEqual(response.status_code, 302)
assert response.status_code == 302
# Editing the existing provider creates a new provider instance
providers = OAuth2ProviderConfig.objects.all()
self.assertEqual(len(providers), pcount + 2)
self.assertEqual(providers[pcount].id, provider1.id)
assert len(providers) == (pcount + 2)
assert providers[pcount].id == provider1.id
provider2 = providers[pcount + 1]
# Ensure the icon_image was preserved on the new provider instance
self.assertEqual(provider2.icon_image, provider1.icon_image)
self.assertEqual(provider2.name, post_data['name'])
assert provider2.icon_image == provider1.icon_image
assert provider2.name == post_data['name']

View File

@@ -48,11 +48,11 @@ class TestXFrameWhitelistDecorator(TestCase):
response = mock_view(request)
self.assertEqual(response['X-Frame-Options'], expected_result)
assert response['X-Frame-Options'] == expected_result
@ddt.data('http://localhost/login', 'http://not-a-real-domain.com', None)
def test_feature_flag_off(self, url):
with self.settings(FEATURES={'ENABLE_THIRD_PARTY_AUTH': False}):
request = self.construct_request(url)
response = mock_view(request)
self.assertEqual(response['X-Frame-Options'], 'DENY')
assert response['X-Frame-Options'] == 'DENY'

View File

@@ -3,7 +3,8 @@ Unit tests for the IdentityServer3 OAuth2 Backend
"""
import json
import ddt
import unittest # lint-amnesty, pylint: disable=unused-import, wrong-import-order
import pytest # pylint: disable=unused-import
from common.djangoapps.third_party_auth.identityserver3 import IdentityServer3
from common.djangoapps.third_party_auth.tests import testutil
from common.djangoapps.third_party_auth.tests.utils import skip_unless_thirdpartyauth
@@ -41,21 +42,21 @@ class IdentityServer3Test(testutil.TestCase):
make sure the "sub" claim works properly to grab user Id
"""
response = {"sub": 1, "email": "example@example.com"}
self.assertEqual(self.id3_instance.get_user_id({}, response), 1)
assert self.id3_instance.get_user_id({}, response) == 1
def test_key_error_thrown_with_no_sub(self):
"""
test that a KeyError is thrown if the "sub" claim does not exist
"""
response = {"id": 1}
self.assertRaises(TypeError, self.id3_instance.get_user_id({}, response))
assert self.id3_instance.get_user_id({}, response) is None
def test_proper_config_access(self):
"""
test that the IdentityServer3 model properly grabs OAuth2Configs
"""
provider_config = self.configure_identityServer3_provider(backend_name="identityServer3")
self.assertEqual(self.id3_instance.get_config(), provider_config)
assert self.id3_instance.get_config() == provider_config
def test_config_after_updating(self):
"""
@@ -66,8 +67,8 @@ class IdentityServer3Test(testutil.TestCase):
slug="updated",
backend_name="identityServer3"
)
self.assertEqual(self.id3_instance.get_config(), updated_provider_config)
self.assertNotEqual(self.id3_instance.get_config(), original_provider_config)
assert self.id3_instance.get_config() == updated_provider_config
assert self.id3_instance.get_config() != original_provider_config
@ddt.data(
('first_name_claim_key', 'given_name', 'first_name', 'Edx'),
@@ -91,7 +92,7 @@ class IdentityServer3Test(testutil.TestCase):
setting_field_key: setting_field_value,
})
)
self.assertEqual(provider_config.backend_class().get_user_details(self.response)[output_name], output_value)
assert provider_config.backend_class().get_user_details(self.response)[output_name] == output_value
def test_user_details_without_settings(self):
"""

View File

@@ -21,9 +21,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
details = lti.get_user_details({LTI_PARAMS_KEY: {
'lis_person_name_full': 'Full name'
}})
self.assertEqual(details, {
'fullname': 'Full name'
})
assert details == {'fullname': 'Full name'}
def test_get_user_details_extra_keys(self):
lti = LTIAuthBackend()
@@ -34,12 +32,8 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
'email': 'user@example.com',
'other': 'something else'
}})
self.assertEqual(details, {
'fullname': 'Full name',
'first_name': 'Given',
'last_name': 'Family',
'email': 'user@example.com'
})
assert details == {'fullname': 'Full name', 'first_name': 'Given',
'last_name': 'Family', 'email': 'user@example.com'}
def test_get_user_id(self):
lti = LTIAuthBackend()
@@ -47,7 +41,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
'oauth_consumer_key': 'consumer',
'user_id': 'user'
}})
self.assertEqual(user_id, 'consumer:user')
assert user_id == 'consumer:user'
def test_validate_lti_valid_request(self):
request = Request(
@@ -60,7 +54,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
lti_consumer_valid=True, lti_consumer_secret='secret',
lti_max_timestamp_age=10
)
self.assertTrue(parameters)
assert parameters
self.assertDictContainsSubset({
'custom_extra': 'parameter',
'user_id': '292832126'
@@ -77,7 +71,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
lti_consumer_valid=True, lti_consumer_secret='secret',
lti_max_timestamp_age=10
)
self.assertTrue(parameters)
assert parameters
self.assertDictContainsSubset({
'custom_extra': 'parameter',
'user_id': '292832126'
@@ -94,7 +88,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
lti_consumer_valid=True, lti_consumer_secret='secret',
lti_max_timestamp_age=10
)
self.assertFalse(parameters)
assert not parameters
def test_validate_lti_invalid_signature(self):
request = Request(
@@ -107,7 +101,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
lti_consumer_valid=True, lti_consumer_secret='secret',
lti_max_timestamp_age=10
)
self.assertFalse(parameters)
assert not parameters
def test_validate_lti_cannot_add_get_params(self):
request = Request(
@@ -120,7 +114,7 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
lti_consumer_valid=True, lti_consumer_secret='secret',
lti_max_timestamp_age=10
)
self.assertFalse(parameters)
assert not parameters
def test_validate_lti_garbage(self):
request = Request(
@@ -133,4 +127,4 @@ class UnitTestLTI(unittest.TestCase, ThirdPartyAuthTestMixin):
lti_consumer_valid=True, lti_consumer_secret='secret',
lti_max_timestamp_age=10
)
self.assertFalse(parameters)
assert not parameters

View File

@@ -40,5 +40,5 @@ class ThirdPartyAuthMiddlewareTestCase(TestCase):
)
target_url = response.url
self.assertEqual(response.status_code, 302)
self.assertTrue(target_url.endswith(login_url))
assert response.status_code == 302
assert target_url.endswith(login_url)

View File

@@ -23,7 +23,7 @@ class ProviderUserStateTestCase(testutil.TestCase):
def test_get_unlink_form_name(self):
google_provider = self.configure_google_provider(enabled=True)
state = pipeline.ProviderUserState(google_provider, object(), None)
self.assertEqual(google_provider.provider_id + '_unlink_form', state.get_unlink_form_name())
assert (google_provider.provider_id + '_unlink_form') == state.get_unlink_form_name()
@ddt.data(
('saml', 'tpa-saml'),
@@ -52,7 +52,7 @@ class ProviderUserStateTestCase(testutil.TestCase):
}
with simulate_running_pipeline("common.djangoapps.third_party_auth.pipeline", backend_name, **kwargs):
logout_url = pipeline.get_idp_logout_url_from_running_pipeline(request)
self.assertEqual(idp_config['logout_url'], logout_url)
assert idp_config['logout_url'] == logout_url
@skip_unless_thirdpartyauth()
@@ -98,4 +98,4 @@ class PipelineOverridesTest(SamlIntegrationTestUtilities, IntegrationTestMixin,
type(uuid4).hex = mock.PropertyMock(return_value='9fe2c4e93f654fdbb24c02b15259716c')
mock_uuid.return_value = uuid4
final_username = pipeline.get_username(strategy, details, self.provider.backend_class())
self.assertEqual(expected_username, final_username['username'])
assert expected_username == final_username['username']

View File

@@ -2,7 +2,7 @@
import datetime
import unittest # lint-amnesty, pylint: disable=unused-import
import pytest
import ddt
import mock
@@ -43,32 +43,32 @@ class GetAuthenticatedUserTestCase(TestCase):
return social_models.DjangoStorage.user.user_model().objects.get(username=username)
def test_raises_does_not_exist_if_user_missing(self):
with self.assertRaises(models.User.DoesNotExist):
with pytest.raises(models.User.DoesNotExist):
pipeline.get_authenticated_user(self.enabled_provider, 'new_' + self.user.username, 'user@example.com')
def test_raises_does_not_exist_if_user_found_but_no_association(self):
backend_name = 'backend'
self.assertIsNotNone(self.get_by_username(self.user.username))
self.assertFalse(any(provider.Registry.get_enabled_by_backend_name(backend_name)))
assert self.get_by_username(self.user.username) is not None
assert not any(provider.Registry.get_enabled_by_backend_name(backend_name))
with self.assertRaises(models.User.DoesNotExist):
with pytest.raises(models.User.DoesNotExist):
pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'user@example.com')
def test_raises_does_not_exist_if_user_and_association_found_but_no_match(self):
self.assertIsNotNone(self.get_by_username(self.user.username))
assert self.get_by_username(self.user.username) is not None
social_models.DjangoStorage.user.create_social_auth(
self.user, 'uid', 'other_' + self.enabled_provider.backend_name)
with self.assertRaises(models.User.DoesNotExist):
with pytest.raises(models.User.DoesNotExist):
pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'uid')
def test_returns_user_with_is_authenticated_and_backend_set_if_match(self):
social_models.DjangoStorage.user.create_social_auth(self.user, 'uid', self.enabled_provider.backend_name)
user = pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'uid')
self.assertEqual(self.user, user)
self.assertEqual(self.enabled_provider.get_authentication_backend(), user.backend)
assert self.user == user
assert self.enabled_provider.get_authentication_backend() == user.backend
class GetProviderUserStatesTestCase(TestCase):
@@ -80,8 +80,8 @@ class GetProviderUserStatesTestCase(TestCase):
self.user = social_models.DjangoStorage.user.create_user(username='username', password='password')
def test_returns_empty_list_if_no_enabled_providers(self):
self.assertFalse(provider.Registry.enabled())
self.assertEqual([], pipeline.get_provider_user_states(self.user))
assert not provider.Registry.enabled()
assert [] == pipeline.get_provider_user_states(self.user)
def test_state_not_returned_for_disabled_provider(self):
disabled_provider = self.configure_google_provider(enabled=False)
@@ -89,9 +89,9 @@ class GetProviderUserStatesTestCase(TestCase):
social_models.DjangoStorage.user.create_social_auth(self.user, 'uid', disabled_provider.backend_name)
states = pipeline.get_provider_user_states(self.user)
self.assertEqual(1, len(states))
self.assertNotIn(disabled_provider.provider_id, (state.provider.provider_id for state in states))
self.assertIn(enabled_provider.provider_id, (state.provider.provider_id for state in states))
assert 1 == len(states)
assert disabled_provider.provider_id not in (state.provider.provider_id for state in states)
assert enabled_provider.provider_id in (state.provider.provider_id for state in states)
def test_states_for_enabled_providers_user_has_accounts_associated_with(self):
# Enable two providers - Google and LinkedIn:
@@ -103,48 +103,48 @@ class GetProviderUserStatesTestCase(TestCase):
self.user, 'uid', linkedin_provider.backend_name)
states = pipeline.get_provider_user_states(self.user)
self.assertEqual(2, len(states))
assert 2 == len(states)
google_state = [state for state in states if state.provider.provider_id == google_provider.provider_id][0]
linkedin_state = [state for state in states if state.provider.provider_id == linkedin_provider.provider_id][0]
self.assertTrue(google_state.has_account)
self.assertEqual(google_provider.provider_id, google_state.provider.provider_id)
assert google_state.has_account
assert google_provider.provider_id == google_state.provider.provider_id
# Also check the row ID. Note this 'id' changes whenever the configuration does:
self.assertEqual(google_provider.id, google_state.provider.id)
self.assertEqual(self.user, google_state.user)
self.assertEqual(user_social_auth_google.id, google_state.association_id)
assert google_provider.id == google_state.provider.id
assert self.user == google_state.user
assert user_social_auth_google.id == google_state.association_id
self.assertTrue(linkedin_state.has_account)
self.assertEqual(linkedin_provider.provider_id, linkedin_state.provider.provider_id)
self.assertEqual(linkedin_provider.id, linkedin_state.provider.id)
self.assertEqual(self.user, linkedin_state.user)
self.assertEqual(user_social_auth_linkedin.id, linkedin_state.association_id)
assert linkedin_state.has_account
assert linkedin_provider.provider_id == linkedin_state.provider.provider_id
assert linkedin_provider.id == linkedin_state.provider.id
assert self.user == linkedin_state.user
assert user_social_auth_linkedin.id == linkedin_state.association_id
def test_states_for_enabled_providers_user_has_no_account_associated_with(self):
# Enable two providers - Google and LinkedIn:
google_provider = self.configure_google_provider(enabled=True)
linkedin_provider = self.configure_linkedin_provider(enabled=True)
self.assertEqual(len(provider.Registry.enabled()), 2)
assert len(provider.Registry.enabled()) == 2
states = pipeline.get_provider_user_states(self.user)
self.assertEqual([], [x for x in social_models.DjangoStorage.user.objects.all()]) # lint-amnesty, pylint: disable=unnecessary-comprehension
self.assertEqual(2, len(states))
assert [] == list(social_models.DjangoStorage.user.objects.all())
assert 2 == len(states)
google_state = [state for state in states if state.provider.provider_id == google_provider.provider_id][0]
linkedin_state = [state for state in states if state.provider.provider_id == linkedin_provider.provider_id][0]
self.assertFalse(google_state.has_account)
self.assertEqual(google_provider.provider_id, google_state.provider.provider_id)
assert not google_state.has_account
assert google_provider.provider_id == google_state.provider.provider_id
# Also check the row ID. Note this 'id' changes whenever the configuration does:
self.assertEqual(google_provider.id, google_state.provider.id)
self.assertEqual(self.user, google_state.user)
assert google_provider.id == google_state.provider.id
assert self.user == google_state.user
self.assertFalse(linkedin_state.has_account)
self.assertEqual(linkedin_provider.provider_id, linkedin_state.provider.provider_id)
self.assertEqual(linkedin_provider.id, linkedin_state.provider.id)
self.assertEqual(self.user, linkedin_state.user)
assert not linkedin_state.has_account
assert linkedin_provider.provider_id == linkedin_state.provider.provider_id
assert linkedin_provider.id == linkedin_state.provider.id
assert self.user == linkedin_state.user
class UrlFormationTestCase(TestCase):
@@ -153,62 +153,59 @@ class UrlFormationTestCase(TestCase):
def test_complete_url_raises_value_error_if_provider_not_enabled(self):
provider_name = 'oa2-not-enabled'
self.assertIsNone(provider.Registry.get(provider_name))
assert provider.Registry.get(provider_name) is None
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
pipeline.get_complete_url(provider_name)
def test_complete_url_returns_expected_format(self):
complete_url = pipeline.get_complete_url(self.enabled_provider.backend_name)
self.assertTrue(complete_url.startswith('/auth/complete'))
self.assertIn(self.enabled_provider.backend_name, complete_url)
assert complete_url.startswith('/auth/complete')
assert self.enabled_provider.backend_name in complete_url
def test_disconnect_url_raises_value_error_if_provider_not_enabled(self):
provider_name = 'oa2-not-enabled'
self.assertIsNone(provider.Registry.get(provider_name))
assert provider.Registry.get(provider_name) is None
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
pipeline.get_disconnect_url(provider_name, 1000)
def test_disconnect_url_returns_expected_format(self):
disconnect_url = pipeline.get_disconnect_url(self.enabled_provider.provider_id, 1000)
disconnect_url = disconnect_url.rstrip('?')
self.assertEqual(
disconnect_url,
'/auth/disconnect/{backend}/{association_id}/'.format(
backend=self.enabled_provider.backend_name, association_id=1000)
)
assert disconnect_url == '/auth/disconnect/{backend}/{association_id}/'\
.format(backend=self.enabled_provider.backend_name, association_id=1000)
def test_login_url_raises_value_error_if_provider_not_enabled(self):
provider_id = 'oa2-not-enabled'
self.assertIsNone(provider.Registry.get(provider_id))
assert provider.Registry.get(provider_id) is None
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
pipeline.get_login_url(provider_id, pipeline.AUTH_ENTRY_LOGIN)
def test_login_url_returns_expected_format(self):
login_url = pipeline.get_login_url(self.enabled_provider.provider_id, pipeline.AUTH_ENTRY_LOGIN)
self.assertTrue(login_url.startswith('/auth/login'))
self.assertIn(self.enabled_provider.backend_name, login_url)
self.assertTrue(login_url.endswith(pipeline.AUTH_ENTRY_LOGIN))
assert login_url.startswith('/auth/login')
assert self.enabled_provider.backend_name in login_url
assert login_url.endswith(pipeline.AUTH_ENTRY_LOGIN)
def test_for_value_error_if_provider_id_invalid(self):
provider_id = 'invalid' # Format is normally "{prefix}-{identifier}"
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
provider.Registry.get(provider_id)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
pipeline.get_login_url(provider_id, pipeline.AUTH_ENTRY_LOGIN)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
pipeline.get_disconnect_url(provider_id, 1000)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
pipeline.get_complete_url(provider_id)
@@ -242,7 +239,7 @@ class TestPipelineUtilityFunctions(TestCase):
with mock.patch('common.djangoapps.third_party_auth.pipeline.get') as get_pipeline:
get_pipeline.return_value = pipeline_partial
real_social = pipeline.get_real_social_auth_object(request)
self.assertEqual(real_social, self.social_auth)
assert real_social == self.social_auth
def test_get_real_social_auth(self):
"""
@@ -259,7 +256,7 @@ class TestPipelineUtilityFunctions(TestCase):
with mock.patch('common.djangoapps.third_party_auth.pipeline.get') as get_pipeline:
get_pipeline.return_value = pipeline_partial
real_social = pipeline.get_real_social_auth_object(request)
self.assertEqual(real_social, self.social_auth)
assert real_social == self.social_auth
def test_get_real_social_auth_no_pipeline(self):
"""
@@ -268,7 +265,7 @@ class TestPipelineUtilityFunctions(TestCase):
"""
request = mock.MagicMock(session={})
real_social = pipeline.get_real_social_auth_object(request)
self.assertEqual(real_social, None)
assert real_social is None
def test_get_real_social_auth_no_social(self):
"""
@@ -283,7 +280,7 @@ class TestPipelineUtilityFunctions(TestCase):
}
)
real_social = pipeline.get_real_social_auth_object(request)
self.assertEqual(real_social, None)
assert real_social is None
def test_quarantine(self):
"""
@@ -294,12 +291,10 @@ class TestPipelineUtilityFunctions(TestCase):
session={}
)
pipeline.quarantine_session(request, locations=('my_totally_real_module', 'other_real_module',))
self.assertEqual(
request.session['third_party_auth_quarantined_modules'],
('my_totally_real_module', 'other_real_module',),
)
assert request.session['third_party_auth_quarantined_modules'] ==\
('my_totally_real_module', 'other_real_module')
pipeline.lift_quarantine(request)
self.assertNotIn('third_party_auth_quarantined_modules', request.session)
assert 'third_party_auth_quarantined_modules' not in request.session
@ddt.ddt
@@ -597,4 +592,4 @@ class SetIDVerificationStatusTestCase(TestCase):
)
# Ensure a verification signal was sent
self.assertEqual(mock_signal.call_count, 1)
assert mock_signal.call_count == 1

View File

@@ -2,7 +2,6 @@
import re
import unittest # lint-amnesty, pylint: disable=unused-import
from django.contrib.sites.models import Site
from django.db import connections, DEFAULT_DB_ALIAS
@@ -23,37 +22,37 @@ class RegistryTest(testutil.TestCase):
def test_configure_once_adds_gettable_providers(self):
facebook_provider = self.configure_facebook_provider(enabled=True)
self.assertEqual(facebook_provider.id, provider.Registry.get(facebook_provider.provider_id).id)
assert facebook_provider.id == provider.Registry.get(facebook_provider.provider_id).id
def test_no_providers_by_default(self):
enabled_providers = provider.Registry.enabled()
self.assertEqual(len(enabled_providers), 0, "By default, no providers are enabled.")
assert len(enabled_providers) == 0, 'By default, no providers are enabled.'
def test_runtime_configuration(self):
self.configure_google_provider(enabled=True)
enabled_providers = provider.Registry.enabled()
self.assertEqual(len(enabled_providers), 1)
self.assertEqual(enabled_providers[0].name, "Google")
self.assertEqual(enabled_providers[0].get_setting("SECRET"), "opensesame")
assert len(enabled_providers) == 1
assert enabled_providers[0].name == 'Google'
assert enabled_providers[0].get_setting('SECRET') == 'opensesame'
self.configure_google_provider(enabled=False)
enabled_providers = provider.Registry.enabled()
self.assertEqual(len(enabled_providers), 0)
assert len(enabled_providers) == 0
self.configure_google_provider(enabled=True, secret="alohomora")
enabled_providers = provider.Registry.enabled()
self.assertEqual(len(enabled_providers), 1)
self.assertEqual(enabled_providers[0].get_setting("SECRET"), "alohomora")
assert len(enabled_providers) == 1
assert enabled_providers[0].get_setting('SECRET') == 'alohomora'
def test_secure_configuration(self):
""" Test that some sensitive values can be configured via Django settings """
self.configure_google_provider(enabled=True, secret="")
enabled_providers = provider.Registry.enabled()
self.assertEqual(len(enabled_providers), 1)
self.assertEqual(enabled_providers[0].name, "Google")
self.assertEqual(enabled_providers[0].get_setting("SECRET"), "")
assert len(enabled_providers) == 1
assert enabled_providers[0].name == 'Google'
assert enabled_providers[0].get_setting('SECRET') == ''
with self.settings(SOCIAL_AUTH_OAUTH_SECRETS={'google-oauth2': 'secret42'}):
self.assertEqual(enabled_providers[0].get_setting("SECRET"), "secret42")
assert enabled_providers[0].get_setting('SECRET') == 'secret42'
def test_cannot_load_arbitrary_backends(self):
""" Test that only backend_names listed in settings.AUTHENTICATION_BACKENDS can be used """
@@ -65,7 +64,7 @@ class RegistryTest(testutil.TestCase):
slug="test",
backend_name="disallowed"
)
self.assertEqual(len(provider.Registry.enabled()), 0)
assert len(provider.Registry.enabled()) == 0
def test_enabled_returns_list_of_enabled_providers_sorted_by_name(self):
provider_names = ["Stack Overflow", "Google", "LinkedIn", "GitHub"]
@@ -76,7 +75,7 @@ class RegistryTest(testutil.TestCase):
self.configure_oauth_provider(enabled=True, name=name, backend_name=backend_name)
with patch('common.djangoapps.third_party_auth.provider._PSA_OAUTH2_BACKENDS', backend_names):
self.assertEqual(sorted(provider_names), [prov.name for prov in provider.Registry.enabled()])
assert sorted(provider_names) == [prov.name for prov in provider.Registry.enabled()]
def test_enabled_doesnt_query_site(self):
"""Regression test for 1+N queries for django_site (ARCHBOM-1139)"""
@@ -90,11 +89,12 @@ class RegistryTest(testutil.TestCase):
with CaptureQueriesContext(connections[DEFAULT_DB_ALIAS]) as cq:
enabled_slugs = {p.slug for p in provider.Registry.enabled()}
self.assertEqual(len(enabled_slugs), provider_count)
assert len(enabled_slugs) == provider_count
# Should not involve any queries for Site, or at least should not *scale* with number of providers
all_queries = [q['sql'] for q in cq.captured_queries]
django_site_queries = list(filter(re_django_site_query.search, all_queries))
self.assertEqual(len(django_site_queries), 0) # previously was == provider_count (1 for each provider)
assert len(django_site_queries) == 0
# previously was == provider_count (1 for each provider)
def test_providers_displayed_for_login(self):
"""
@@ -107,11 +107,11 @@ class RegistryTest(testutil.TestCase):
disabled_provider = self.configure_twitter_provider(visible=True, enabled=False)
no_log_in_provider = self.configure_lti_provider()
provider_ids = [idp.provider_id for idp in provider.Registry.displayed_for_login()]
self.assertNotIn(hidden_provider.provider_id, provider_ids)
self.assertNotIn(implicitly_hidden_provider.provider_id, provider_ids)
self.assertNotIn(disabled_provider.provider_id, provider_ids)
self.assertNotIn(no_log_in_provider.provider_id, provider_ids)
self.assertIn(normal_provider.provider_id, provider_ids)
assert hidden_provider.provider_id not in provider_ids
assert implicitly_hidden_provider.provider_id not in provider_ids
assert disabled_provider.provider_id not in provider_ids
assert no_log_in_provider.provider_id not in provider_ids
assert normal_provider.provider_id in provider_ids
def test_tpa_hint_provider_displayed_for_login(self):
"""
@@ -125,7 +125,7 @@ class RegistryTest(testutil.TestCase):
idp.provider_id
for idp in provider.Registry.displayed_for_login(tpa_hint=hidden_provider.provider_id)
]
self.assertIn(hidden_provider.provider_id, provider_ids)
assert hidden_provider.provider_id in provider_ids
# New providers are hidden (ie, not flagged as 'visible') by default
# The tpa_hint parameter should work for these providers as well
@@ -134,7 +134,7 @@ class RegistryTest(testutil.TestCase):
idp.provider_id
for idp in provider.Registry.displayed_for_login(tpa_hint=implicitly_hidden_provider.provider_id)
]
self.assertIn(implicitly_hidden_provider.provider_id, provider_ids)
assert implicitly_hidden_provider.provider_id in provider_ids
# Disabled providers should not be matched in tpa_hint scenarios
disabled_provider = self.configure_twitter_provider(visible=True, enabled=False)
@@ -142,7 +142,7 @@ class RegistryTest(testutil.TestCase):
idp.provider_id
for idp in provider.Registry.displayed_for_login(tpa_hint=disabled_provider.provider_id)
]
self.assertNotIn(disabled_provider.provider_id, provider_ids)
assert disabled_provider.provider_id not in provider_ids
# Providers not utilized for learner authentication should not match tpa_hint
no_log_in_provider = self.configure_lti_provider()
@@ -150,14 +150,14 @@ class RegistryTest(testutil.TestCase):
idp.provider_id
for idp in provider.Registry.displayed_for_login(tpa_hint=no_log_in_provider.provider_id)
]
self.assertNotIn(no_log_in_provider.provider_id, provider_ids)
assert no_log_in_provider.provider_id not in provider_ids
def test_provider_enabled_for_current_site(self):
"""
Verify that enabled_for_current_site returns True when the provider matches the current site.
"""
prov = self.configure_google_provider(visible=True, enabled=True, site=Site.objects.get_current())
self.assertEqual(prov.enabled_for_current_site, True)
assert prov.enabled_for_current_site is True
@with_site_configuration(SITE_DOMAIN_A)
def test_provider_disabled_for_mismatching_site(self):
@@ -166,11 +166,11 @@ class RegistryTest(testutil.TestCase):
"""
site_b = Site.objects.get_or_create(domain=SITE_DOMAIN_B, name=SITE_DOMAIN_B)[0]
prov = self.configure_google_provider(visible=True, enabled=True, site=site_b)
self.assertEqual(prov.enabled_for_current_site, False)
assert prov.enabled_for_current_site is False
def test_get_returns_enabled_provider(self):
google_provider = self.configure_google_provider(enabled=True)
self.assertEqual(google_provider.id, provider.Registry.get(google_provider.provider_id).id)
assert google_provider.id == provider.Registry.get(google_provider.provider_id).id
def test_oauth2_provider_keyed_by_slug(self):
"""
@@ -178,15 +178,15 @@ class RegistryTest(testutil.TestCase):
which doesn't match any of the possible backend_names.
"""
google_provider = self.configure_google_provider(enabled=True, slug='custom_slug')
self.assertIn(google_provider, provider.Registry._enabled_providers()) # lint-amnesty, pylint: disable=protected-access
self.assertIn(google_provider, provider.Registry.get_enabled_by_backend_name('google-oauth2'))
assert google_provider in provider.Registry._enabled_providers() # pylint: disable=protected-access
assert google_provider in provider.Registry.get_enabled_by_backend_name('google-oauth2')
def test_oath2_different_slug_from_backend_name(self):
"""
Test that an OAuth2 provider can have a slug that differs from the backend name.
"""
dummy_provider = self.configure_oauth_provider(enabled=True, name="dummy", slug="default", backend_name="dummy")
self.assertIn(dummy_provider, provider.Registry.get_enabled_by_backend_name('dummy'))
assert dummy_provider in provider.Registry.get_enabled_by_backend_name('dummy')
def test_oauth2_enabled_only_for_supplied_backend(self):
"""
@@ -195,32 +195,32 @@ class RegistryTest(testutil.TestCase):
"""
facebook_provider = self.configure_facebook_provider(enabled=True)
self.configure_google_provider(enabled=True)
self.assertNotIn(facebook_provider, provider.Registry.get_enabled_by_backend_name('google-oauth2'))
assert facebook_provider not in provider.Registry.get_enabled_by_backend_name('google-oauth2')
def test_get_returns_none_if_provider_id_is_none(self):
self.assertIsNone(provider.Registry.get(None))
assert provider.Registry.get(None) is None
def test_get_returns_none_if_provider_not_enabled(self):
linkedin_provider_id = "oa2-linkedin-oauth2"
# At this point there should be no configuration entries at all so no providers should be enabled
self.assertEqual(provider.Registry.enabled(), [])
self.assertIsNone(provider.Registry.get(linkedin_provider_id))
assert provider.Registry.enabled() == []
assert provider.Registry.get(linkedin_provider_id) is None
# Now explicitly disabled this provider:
self.configure_linkedin_provider(enabled=False)
self.assertIsNone(provider.Registry.get(linkedin_provider_id))
assert provider.Registry.get(linkedin_provider_id) is None
self.configure_linkedin_provider(enabled=True)
self.assertEqual(provider.Registry.get(linkedin_provider_id).provider_id, linkedin_provider_id)
assert provider.Registry.get(linkedin_provider_id).provider_id == linkedin_provider_id
def test_get_from_pipeline_returns_none_if_provider_not_enabled(self):
self.assertEqual(provider.Registry.enabled(), [], "By default, no providers are enabled.")
self.assertIsNone(provider.Registry.get_from_pipeline(Mock()))
assert provider.Registry.enabled() == [], 'By default, no providers are enabled.'
assert provider.Registry.get_from_pipeline(Mock()) is None
def test_get_enabled_by_backend_name_returns_enabled_provider(self):
google_provider = self.configure_google_provider(enabled=True)
found = list(provider.Registry.get_enabled_by_backend_name(google_provider.backend_name))
self.assertEqual(found, [google_provider])
assert found == [google_provider]
def test_get_enabled_by_backend_name_returns_none_if_provider_not_enabled(self):
google_provider = self.configure_google_provider(enabled=False)
found = list(provider.Registry.get_enabled_by_backend_name(google_provider.backend_name))
self.assertEqual(found, [])
assert found == []

View File

@@ -26,9 +26,9 @@ class TestEdXSAMLIdentityProvider(SAMLTestCase):
u'[THIRD_PARTY_AUTH] Invalid EdXSAMLIdentityProvider subclass--'
u'using EdXSAMLIdentityProvider base class. Provider: {provider}'.format(provider='fake_idp_class_option')
)
self.assertIs(idp_class, EdXSAMLIdentityProvider)
assert idp_class is EdXSAMLIdentityProvider
def test_get_user_details(self):
""" test get_attr and get_user_details of EdXSAMLIdentityProvider"""
edx_saml_identity_provider = EdXSAMLIdentityProvider('demo', **mock_conf)
self.assertEqual(edx_saml_identity_provider.get_user_details(mock_attributes), expected_user_details)
assert edx_saml_identity_provider.get_user_details(mock_attributes) == expected_user_details

View File

@@ -37,33 +37,34 @@ class SettingsUnitTest(testutil.TestCase):
def test_apply_settings_adds_exception_middleware(self):
settings.apply_settings(self.settings)
self.assertIn('common.djangoapps.third_party_auth.middleware.ExceptionMiddleware', self.settings.MIDDLEWARE)
assert 'common.djangoapps.third_party_auth.middleware.ExceptionMiddleware' in self.settings.MIDDLEWARE
def test_apply_settings_adds_fields_stored_in_session(self):
settings.apply_settings(self.settings)
self.assertEqual(['auth_entry', 'next'], self.settings.FIELDS_STORED_IN_SESSION)
assert ['auth_entry', 'next'] == self.settings.FIELDS_STORED_IN_SESSION
@skip_unless_thirdpartyauth()
def test_apply_settings_enables_no_providers_by_default(self):
# Providers are only enabled via ConfigurationModels in the database
settings.apply_settings(self.settings)
self.assertEqual([], provider.Registry.enabled())
assert [] == provider.Registry.enabled()
def test_apply_settings_turns_off_raising_social_exceptions(self):
# Guard against submitting a conf change that's convenient in dev but
# bad in prod.
settings.apply_settings(self.settings)
self.assertFalse(self.settings.SOCIAL_AUTH_RAISE_EXCEPTIONS)
assert not self.settings.SOCIAL_AUTH_RAISE_EXCEPTIONS
def test_apply_settings_turns_off_redirect_sanitization(self):
settings.apply_settings(self.settings)
self.assertFalse(self.settings.SOCIAL_AUTH_SANITIZE_REDIRECTS)
assert not self.settings.SOCIAL_AUTH_SANITIZE_REDIRECTS
def test_apply_settings_avoids_default_username_check(self):
# Avoid the default username check where non-ascii characters are not
# allowed when unicode username is enabled
settings.apply_settings(self.settings)
self.assertTrue(self.settings.SOCIAL_AUTH_CLEAN_USERNAMES) # verify default behavior
assert self.settings.SOCIAL_AUTH_CLEAN_USERNAMES
# verify default behavior
with patch.dict('django.conf.settings.FEATURES', {'ENABLE_UNICODE_USERNAME': True}):
settings.apply_settings(self.settings)
self.assertFalse(self.settings.SOCIAL_AUTH_CLEAN_USERNAMES)
assert not self.settings.SOCIAL_AUTH_CLEAN_USERNAMES

View File

@@ -32,21 +32,11 @@ class TestUtils(TestCase):
"""
# Create users from factory
UserFactory(username='test_user', email='test_user@example.com')
self.assertTrue(
user_exists({'username': 'test_user', 'email': 'test_user@example.com'}),
)
self.assertTrue(
user_exists({'username': 'test_user'}),
)
self.assertTrue(
user_exists({'email': 'test_user@example.com'}),
)
self.assertFalse(
user_exists({'username': 'invalid_user'}),
)
self.assertTrue(
user_exists({'username': 'TesT_User'})
)
assert user_exists({'username': 'test_user', 'email': 'test_user@example.com'})
assert user_exists({'username': 'test_user'})
assert user_exists({'email': 'test_user@example.com'})
assert not user_exists({'username': 'invalid_user'})
assert user_exists({'username': 'TesT_User'})
def test_convert_saml_slug_provider_id(self):
"""
@@ -55,13 +45,9 @@ class TestUtils(TestCase):
provider_names = {'saml-samltest': 'samltest', 'saml-example': 'example'}
for provider_id in provider_names:
# provider_id -> slug
self.assertEqual(
convert_saml_slug_provider_id(provider_id), provider_names[provider_id]
)
assert convert_saml_slug_provider_id(provider_id) == provider_names[provider_id]
# slug -> provider_id
self.assertEqual(
convert_saml_slug_provider_id(provider_names[provider_id]), provider_id
)
assert convert_saml_slug_provider_id(provider_names[provider_id]) == provider_id
def test_get_user(self):
"""

View File

@@ -4,6 +4,7 @@ Test the views served by third_party_auth.
import unittest
import pytest
import ddt
from django.conf import settings
@@ -32,15 +33,15 @@ class SAMLMetadataTest(SAMLTestCase):
""" When SAML is not enabled, the metadata view should return 404 """
self.enable_saml(enabled=False)
response = self.client.get(self.METADATA_URL)
self.assertEqual(response.status_code, 404)
assert response.status_code == 404
def test_metadata(self):
self.enable_saml()
doc = self._fetch_metadata()
# Check the ACS URL:
acs_node = doc.find(".//{}".format(etree.QName(SAML_XML_NS, 'AssertionConsumerService')))
self.assertIsNotNone(acs_node)
self.assertEqual(acs_node.attrib['Location'], 'http://example.none/auth/complete/tpa-saml/')
assert acs_node is not None
assert acs_node.attrib['Location'] == 'http://example.none/auth/complete/tpa-saml/'
def test_default_contact_info(self):
self.enable_saml()
@@ -90,7 +91,7 @@ class SAMLMetadataTest(SAMLTestCase):
private_key='',
other_config_str='{"SECURITY_CONFIG": {"signMetadata": true} }',
)
with self.assertRaises(OneLogin_Saml2_Error):
with pytest.raises(OneLogin_Saml2_Error):
self._fetch_metadata() # OneLogin_Saml2_Error: Cannot sign metadata: missing SP private key.
with self.settings(
SOCIAL_AUTH_SAML_SP_PRIVATE_KEY=self._get_private_key('saml_key'),
@@ -102,40 +103,40 @@ class SAMLMetadataTest(SAMLTestCase):
""" Fetch the SAML metadata and do some validation """
doc = self._fetch_metadata()
sig_node = doc.find(".//{}".format(etree.QName(XMLDSIG_XML_NS, 'SignatureValue')))
self.assertIsNotNone(sig_node)
assert sig_node is not None
# Check that the right public key was used:
pub_key_node = doc.find(".//{}".format(etree.QName(XMLDSIG_XML_NS, 'X509Certificate')))
self.assertIsNotNone(pub_key_node)
self.assertIn(pub_key_starts_with, pub_key_node.text)
assert pub_key_node is not None
assert pub_key_starts_with in pub_key_node.text
def _fetch_metadata(self):
""" Fetch and parse the metadata XML at self.METADATA_URL """
response = self.client.get(self.METADATA_URL)
self.assertEqual(response.status_code, 200)
self.assertEqual(response['Content-Type'], 'text/xml')
assert response.status_code == 200
assert response['Content-Type'] == 'text/xml'
# The result should be valid XML:
try:
metadata_doc = etree.fromstring(response.content)
except etree.LxmlError:
self.fail('SAML metadata must be valid XML')
self.assertEqual(metadata_doc.tag, etree.QName(SAML_XML_NS, 'EntityDescriptor'))
assert metadata_doc.tag == etree.QName(SAML_XML_NS, 'EntityDescriptor')
return metadata_doc
def check_metadata_contacts(self, xml, tech_name, tech_email, support_name, support_email):
""" Validate that the contact info in the metadata has the expected values """
technical_node = xml.find(".//{}[@contactType='technical']".format(etree.QName(SAML_XML_NS, 'ContactPerson')))
self.assertIsNotNone(technical_node)
assert technical_node is not None
tech_name_node = technical_node.find(etree.QName(SAML_XML_NS, 'GivenName'))
self.assertEqual(tech_name_node.text, tech_name)
assert tech_name_node.text == tech_name
tech_email_node = technical_node.find(etree.QName(SAML_XML_NS, 'EmailAddress'))
self.assertEqual(tech_email_node.text, tech_email)
assert tech_email_node.text == tech_email
support_node = xml.find(".//{}[@contactType='support']".format(etree.QName(SAML_XML_NS, 'ContactPerson')))
self.assertIsNotNone(support_node)
assert support_node is not None
support_name_node = support_node.find(etree.QName(SAML_XML_NS, 'GivenName'))
self.assertEqual(support_name_node.text, support_name)
assert support_name_node.text == support_name
support_email_node = support_node.find(etree.QName(SAML_XML_NS, 'EmailAddress'))
self.assertEqual(support_email_node.text, support_email)
assert support_email_node.text == support_email
@unittest.skipUnless(AUTH_FEATURE_ENABLED, AUTH_FEATURES_KEY + ' not enabled')
@@ -149,13 +150,13 @@ class SAMLAuthTest(SAMLTestCase):
""" Accessing the login endpoint without an idp query param should return 302 """
self.enable_saml()
response = self.client.get(self.LOGIN_URL)
self.assertEqual(response.status_code, 302)
assert response.status_code == 302
def test_login_disabled(self):
""" When SAML is not enabled, the login view should return 404 """
self.enable_saml(enabled=False)
response = self.client.get(self.LOGIN_URL)
self.assertEqual(response.status_code, 404)
assert response.status_code == 404
@unittest.skipUnless(AUTH_FEATURE_ENABLED, AUTH_FEATURES_KEY + ' not enabled')
@@ -179,15 +180,15 @@ class IdPRedirectViewTest(SAMLTestCase):
response = self.client.get(endpoint_url)
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, expected_url)
assert response.status_code == 302
assert response.url == expected_url
def test_with_invalid_provider_slug(self):
endpoint_url = self.get_idp_redirect_url('saml-test-invalid')
response = self.client.get(endpoint_url)
self.assertEqual(response.status_code, 404)
assert response.status_code == 404
@staticmethod
def get_idp_redirect_url(provider_slug, next_destination=None):

View File

@@ -85,10 +85,8 @@ class ThirdPartyAuthTestMixin(object):
def configure_saml_provider(self, **kwargs):
""" Update the settings for a SAML-based third party auth provider """
self.assertTrue(
SAMLConfiguration.is_enabled(Site.objects.get_current(), 'default'),
"SAML Provider Configuration only works if SAML is enabled."
)
assert SAMLConfiguration.is_enabled(Site.objects.get_current(), 'default'), \
'SAML Provider Configuration only works if SAML is enabled.'
obj = SAMLProviderConfig(**kwargs)
obj.save()
return obj