diff --git a/common/djangoapps/third_party_auth/api/tests/test_permissions.py b/common/djangoapps/third_party_auth/api/tests/test_permissions.py index cc4d00c3e4..87528cf0e0 100644 --- a/common/djangoapps/third_party_auth/api/tests/test_permissions.py +++ b/common/djangoapps/third_party_auth/api/tests/test_permissions.py @@ -54,7 +54,7 @@ class ThirdPartyAuthPermissionTest(TestCase): def test_anonymous_fails(self): request = self._create_request() response = self.SomeTpaClassView().dispatch(request) - self.assertEqual(response.status_code, 401) + assert response.status_code == 401 @ddt.data( (True, False, 200), @@ -68,7 +68,7 @@ class ThirdPartyAuthPermissionTest(TestCase): self._create_session(request, user) response = self.SomeTpaClassView().dispatch(request) - self.assertEqual(response.status_code, expected_status_code) + assert response.status_code == expected_status_code @ddt.data( # unrestricted (for example, jwt cookies) @@ -97,7 +97,7 @@ class ThirdPartyAuthPermissionTest(TestCase): ) response = self.SomeTpaClassView().dispatch(request) - self.assertEqual(response.status_code, expected_response) + assert response.status_code == expected_response @ddt.data( # valid scopes @@ -152,4 +152,4 @@ class ThirdPartyAuthPermissionTest(TestCase): request = self._create_request(auth_header=auth_header) response = self.SomeTpaClassView().dispatch(request, provider_id='some_tpa_provider') - self.assertEqual(response.status_code, expected_response) + assert response.status_code == expected_response diff --git a/common/djangoapps/third_party_auth/api/tests/test_views.py b/common/djangoapps/third_party_auth/api/tests/test_views.py index ed138d4fec..6a57edef0b 100644 --- a/common/djangoapps/third_party_auth/api/tests/test_views.py +++ b/common/djangoapps/third_party_auth/api/tests/test_views.py @@ -131,9 +131,9 @@ class UserViewsMixin(object): url = self.make_url({'username': target_user}) response = self.client.get(url) - self.assertEqual(response.status_code, expect_result) + assert response.status_code == expect_result if expect_result == 200: - self.assertIn("active", response.data) + assert 'active' in response.data six.assertCountEqual(self, response.data["active"], self.expected_active(target_user)) @ddt.data( @@ -147,9 +147,9 @@ class UserViewsMixin(object): def test_list_connected_providers_with_api_key(self, api_key, target_user, expect_result): url = self.make_url({'username': target_user}) response = self.client.get(url, HTTP_X_EDX_API_KEY=api_key) - self.assertEqual(response.status_code, expect_result) + assert response.status_code == expect_result if expect_result == 200: - self.assertIn("active", response.data) + assert 'active' in response.data six.assertCountEqual(self, response.data["active"], self.expected_active(target_user)) @ddt.data( @@ -164,18 +164,18 @@ class UserViewsMixin(object): with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=allow_unprivileged): url = self.make_url({'username': ALICE_USERNAME}) response = self.client.get(url) - self.assertEqual(response.status_code, expect) + assert response.status_code == expect if response.status_code == 200: - self.assertGreater(len(response.data['active']), 0) + assert len(response.data['active']) > 0 for provider_data in response.data['active']: - self.assertEqual(include_remote_id, 'remote_id' in provider_data) + assert include_remote_id == ('remote_id' in provider_data) def test_allow_query_by_email(self): self.client.login(username=ALICE_USERNAME, password=PASSWORD) url = self.make_url({'email': '{}@example.com'.format(ALICE_USERNAME)}) response = self.client.get(url) - self.assertEqual(response.status_code, 200) - self.assertGreater(len(response.data['active']), 0) + assert response.status_code == 200 + assert len(response.data['active']) > 0 def test_throttling(self): # Default throttle is 10/min. Make 11 requests to verify @@ -185,9 +185,9 @@ class UserViewsMixin(object): with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=True): for _ in range(10): response = self.client.get(url) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 response = self.client.get(url) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 @override_settings(EDX_API_KEY=VALID_API_KEY) @@ -332,7 +332,7 @@ class UserMappingViewAPITests(TpaAPITestCase): url = reverse('third_party_auth_user_mapping_api', kwargs={'provider_id': PROVIDER_ID_TESTSHIB}) response = self.client.get(url, HTTP_X_EDX_API_KEY=VALID_API_KEY) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 self._verify_response(response, 200, get_mapping_data_by_usernames(LINKED_USERS)) @ddt.data( @@ -346,14 +346,14 @@ class UserMappingViewAPITests(TpaAPITestCase): with patch.object(JwtRestrictedApplication, 'has_permission', return_value=has_permission): with patch.object(JwtHasScope, 'has_permission', return_value=has_permission): response = self.client.get(url) - self.assertEqual(response.status_code, expect) + assert response.status_code == expect def _verify_response(self, response, expect_code, expect_result): """ verify the items in data_list exists in response and data_results matches results in response """ - self.assertEqual(response.status_code, expect_code) + assert response.status_code == expect_code if expect_code == 200: for item in ['results', 'count', 'num_pages']: - self.assertIn(item, response.data) + assert item in response.data six.assertCountEqual(self, response.data['results'], expect_result) @@ -375,15 +375,11 @@ class TestThirdPartyAuthUserStatusView(ThirdPartyAuthTestMixin, APITestCase): """ self.client.login(username=self.user.username, password=PASSWORD) response = self.client.get(self.url, content_type="application/json") - self.assertEqual(response.status_code, 200) - self.assertEqual( - response.data, - [{ - "accepts_logins": True, - "name": "Google", - "disconnect_url": "/auth/disconnect/google-oauth2/?", - "connect_url": "/auth/login/google-oauth2/?auth_entry=account_settings&next=%2Faccount%2Fsettings", - "connected": False, - "id": "oa2-google-oauth2" - }] - ) + assert response.status_code == 200 + assert (response.data == + [{ + 'accepts_logins': True, 'name': 'Google', + 'disconnect_url': '/auth/disconnect/google-oauth2/?', + 'connect_url': '/auth/login/google-oauth2/?auth_entry=account_settings&next=%2Faccount%2Fsettings', + 'connected': False, 'id': 'oa2-google-oauth2' + }]) diff --git a/common/djangoapps/third_party_auth/management/commands/tests/test_remove_social_auth_users.py b/common/djangoapps/third_party_auth/management/commands/tests/test_remove_social_auth_users.py index 8b5b4d4fd1..e809e1c1e4 100644 --- a/common/djangoapps/third_party_auth/management/commands/tests/test_remove_social_auth_users.py +++ b/common/djangoapps/third_party_auth/management/commands/tests/test_remove_social_auth_users.py @@ -7,6 +7,7 @@ import sys import unittest from contextlib import contextmanager from uuid import uuid4 +import pytest from django.conf import settings from django.core.management import call_command @@ -69,16 +70,16 @@ class TestRemoveSocialAuthUsersCommand(TestCase): call_command(self.command, self.provider_hogwarts.slug, force=True) # user with input idp is removed, along with social auth entries - with self.assertRaises(User.DoesNotExist): + with pytest.raises(User.DoesNotExist): User.objects.get(username='harry') - with self.assertRaises(UserSocialAuth.DoesNotExist): + with pytest.raises(UserSocialAuth.DoesNotExist): self.find_user_social_auth_entry('harry') # other users intact self.user_fleur.refresh_from_db() self.user_viktor.refresh_from_db() - self.assertIsNotNone(self.user_fleur) - self.assertIsNotNone(self.user_viktor) + assert self.user_fleur is not None + assert self.user_viktor is not None # other social auth intact self.find_user_social_auth_entry(self.user_viktor.username) @@ -96,9 +97,9 @@ class TestRemoveSocialAuthUsersCommand(TestCase): with self._replace_stdin('confirm'): call_command(self.command, self.provider_hogwarts.slug) - with self.assertRaises(User.DoesNotExist): + with pytest.raises(User.DoesNotExist): User.objects.get(username='harry') - with self.assertRaises(UserSocialAuth.DoesNotExist): + with pytest.raises(UserSocialAuth.DoesNotExist): self.find_user_social_auth_entry('harry') @override_settings(FEATURES=FEATURES_WITH_ENABLED) @@ -109,8 +110,8 @@ class TestRemoveSocialAuthUsersCommand(TestCase): call_command(self.command, self.provider_hogwarts.slug) # no users should be removed - self.assertEqual(len(User.objects.all()), 3) - self.assertEqual(len(UserSocialAuth.objects.all()), 2) + assert len(User.objects.all()) == 3 + assert len(UserSocialAuth.objects.all()) == 2 def test_feature_default_disabled(self): """ By default this command should not be enabled """ diff --git a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py index 1d65a215ea..1c752a7ec4 100644 --- a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py +++ b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py @@ -105,7 +105,7 @@ class TestSAMLCommand(CacheIsolationTestCase): """ expected = "\nDone.\n1 provider(s) found in database.\n1 skipped and 0 attempted.\n0 updated and 0 failed.\n" call_command("saml", pull=True, stdout=self.stdout) - self.assertIn(expected, self.stdout.getvalue()) + assert expected in self.stdout.getvalue() @mock.patch("requests.get", mock_get()) def test_fetch_saml_metadata(self): @@ -118,7 +118,7 @@ class TestSAMLCommand(CacheIsolationTestCase): expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n1 updated and 0 failed.\n" call_command("saml", pull=True, stdout=self.stdout) - self.assertIn(expected, self.stdout.getvalue()) + assert expected in self.stdout.getvalue() @mock.patch("requests.get", mock_get(status_code=404)) def test_fetch_saml_metadata_failure(self): @@ -133,7 +133,7 @@ class TestSAMLCommand(CacheIsolationTestCase): with self.assertRaisesRegex(CommandError, r"HTTPError: 404 Client Error"): call_command("saml", pull=True, stdout=self.stdout) - self.assertIn(expected, self.stdout.getvalue()) + assert expected in self.stdout.getvalue() @mock.patch("requests.get", mock_get(status_code=200)) def test_fetch_multiple_providers_data(self): @@ -178,7 +178,7 @@ class TestSAMLCommand(CacheIsolationTestCase): expected = '\n3 provider(s) found in database.\n0 skipped and 3 attempted.\n2 updated and 1 failed.\n' with self.assertRaisesRegex(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"): call_command("saml", pull=True, stdout=self.stdout) - self.assertIn(expected, self.stdout.getvalue()) + assert expected in self.stdout.getvalue() # Now add a fourth configuration, and indicate that it should not be included in the update self.__create_saml_configurations__( @@ -201,7 +201,7 @@ class TestSAMLCommand(CacheIsolationTestCase): expected = '\nDone.\n4 provider(s) found in database.\n1 skipped and 3 attempted.\n0 updated and 1 failed.\n' with self.assertRaisesRegex(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"): call_command("saml", pull=True, stdout=self.stdout) - self.assertIn(expected, self.stdout.getvalue()) + assert expected in self.stdout.getvalue() @mock.patch("requests.get") def test_saml_request_exceptions(self, mocked_get): @@ -217,19 +217,19 @@ class TestSAMLCommand(CacheIsolationTestCase): with self.assertRaisesRegex(CommandError, "SSLError:"): call_command("saml", pull=True, stdout=self.stdout) - self.assertIn(expected, self.stdout.getvalue()) + assert expected in self.stdout.getvalue() mocked_get.side_effect = exceptions.ConnectionError with self.assertRaisesRegex(CommandError, "ConnectionError:"): call_command("saml", pull=True, stdout=self.stdout) - self.assertIn(expected, self.stdout.getvalue()) + assert expected in self.stdout.getvalue() mocked_get.side_effect = exceptions.HTTPError with self.assertRaisesRegex(CommandError, "HTTPError:"): call_command("saml", pull=True, stdout=self.stdout) - self.assertIn(expected, self.stdout.getvalue()) + assert expected in self.stdout.getvalue() @mock.patch("requests.get", mock_get(status_code=200)) def test_saml_parse_exceptions(self): @@ -254,7 +254,7 @@ class TestSAMLCommand(CacheIsolationTestCase): with self.assertRaisesRegex(CommandError, "MetadataParseError: Can't find EntityDescriptor for entityID"): call_command("saml", pull=True, stdout=self.stdout) - self.assertIn(expected, self.stdout.getvalue()) + assert expected in self.stdout.getvalue() @mock.patch("requests.get") def test_xml_parse_exceptions(self, mocked_get): @@ -274,4 +274,4 @@ class TestSAMLCommand(CacheIsolationTestCase): with self.assertRaisesRegex(CommandError, "XMLSyntaxError:"): call_command("saml", pull=True, stdout=self.stdout) - self.assertIn(expected, self.stdout.getvalue()) + assert expected in self.stdout.getvalue() diff --git a/common/djangoapps/third_party_auth/saml_configuration/tests/test_saml_configuration.py b/common/djangoapps/third_party_auth/saml_configuration/tests/test_saml_configuration.py index b42b5b2b97..49e57ca546 100644 --- a/common/djangoapps/third_party_auth/saml_configuration/tests/test_saml_configuration.py +++ b/common/djangoapps/third_party_auth/saml_configuration/tests/test_saml_configuration.py @@ -83,19 +83,19 @@ class SAMLConfigurationTests(APITestCase): def test_get_saml_configurations_successful(self): url = reverse('saml_configuration-list') response = self.client.get(url, format='json') - self.assertEqual(response.status_code, status.HTTP_200_OK) + assert response.status_code == status.HTTP_200_OK # We ultimately just need ids and slugs, so let's just check those. results = response.data['results'] - self.assertEqual(results[0]['id'], SAML_CONFIGURATIONS[0]['site']) - self.assertEqual(results[0]['slug'], SAML_CONFIGURATIONS[0]['slug']) - self.assertEqual(results[1]['id'], SAML_CONFIGURATIONS[1]['site']) - self.assertEqual(results[1]['slug'], SAML_CONFIGURATIONS[1]['slug']) + assert results[0]['id'] == SAML_CONFIGURATIONS[0]['site'] + assert results[0]['slug'] == SAML_CONFIGURATIONS[0]['slug'] + assert results[1]['id'] == SAML_CONFIGURATIONS[1]['site'] + assert results[1]['slug'] == SAML_CONFIGURATIONS[1]['slug'] def test_get_saml_configurations_noprivate(self): # Verify we have 3 saml configuration objects: 2 public, 1 private. total_object_count = SAMLConfiguration.objects.count() - self.assertEqual(total_object_count, 3) + assert total_object_count == 3 url = reverse('saml_configuration-list') response = self.client.get(url, format='json') @@ -103,10 +103,10 @@ class SAMLConfigurationTests(APITestCase): # We should only see 2 results, since 1 out of 3 are private # and our queryset only returns public configurations. results = response.data['results'] - self.assertEqual(len(results), 2) + assert len(results) == 2 def test_unauthenticated_user_get_saml_configurations(self): self.client.logout() url = reverse('saml_configuration-list') response = self.client.get(url, format='json') - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + assert response.status_code == status.HTTP_401_UNAUTHORIZED diff --git a/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py b/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py index 49f9c86923..91ef4c2b2a 100644 --- a/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py +++ b/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py @@ -91,13 +91,13 @@ class SAMLProviderConfigTests(APITestCase): response = self.client.get(url, format='json') - self.assertEqual(response.status_code, status.HTTP_200_OK) + assert response.status_code == status.HTTP_200_OK results = response.data['results'] - self.assertEqual(len(results), 1) - self.assertEqual(results[0]['entity_id'], SINGLE_PROVIDER_CONFIG['entity_id']) - self.assertEqual(results[0]['metadata_source'], SINGLE_PROVIDER_CONFIG['metadata_source']) - self.assertEqual(response.data['results'][0]['country'], SINGLE_PROVIDER_CONFIG['country']) - self.assertEqual(SAMLProviderConfig.objects.count(), 1) + assert len(results) == 1 + assert results[0]['entity_id'] == SINGLE_PROVIDER_CONFIG['entity_id'] + assert results[0]['metadata_source'] == SINGLE_PROVIDER_CONFIG['metadata_source'] + assert response.data['results'][0]['country'] == SINGLE_PROVIDER_CONFIG['country'] + assert SAMLProviderConfig.objects.count() == 1 def test_get_one_config_by_enterprise_uuid_invalid_uuid(self): """ @@ -109,7 +109,7 @@ class SAMLProviderConfigTests(APITestCase): response = self.client.get(url, format='json') - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + assert response.status_code == status.HTTP_400_BAD_REQUEST def test_get_one_config_by_enterprise_uuid_not_found(self): """ @@ -128,8 +128,8 @@ class SAMLProviderConfigTests(APITestCase): response = self.client.get(url, format='json') - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - self.assertEqual(SAMLProviderConfig.objects.count(), orig_count) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert SAMLProviderConfig.objects.count() == orig_count def test_create_one_config(self): """ @@ -142,19 +142,14 @@ class SAMLProviderConfigTests(APITestCase): response = self.client.post(url, data) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(SAMLProviderConfig.objects.count(), orig_count + 1) + assert response.status_code == status.HTTP_201_CREATED + assert SAMLProviderConfig.objects.count() == (orig_count + 1) provider_config = SAMLProviderConfig.objects.get(slug=SINGLE_PROVIDER_CONFIG_2['slug']) - self.assertEqual(provider_config.name, 'name-of-config-2') - self.assertEqual(provider_config.country, SINGLE_PROVIDER_CONFIG_2['country']) + assert provider_config.name == 'name-of-config-2' + assert provider_config.country == SINGLE_PROVIDER_CONFIG_2['country'] # check association has also been created - self.assertTrue( - EnterpriseCustomerIdentityProvider.objects.filter( - provider_id=convert_saml_slug_provider_id(provider_config.slug) - ).exists(), - 'Cannot find EnterpriseCustomer-->SAMLProviderConfig association' - ) + assert EnterpriseCustomerIdentityProvider.objects.filter(provider_id=convert_saml_slug_provider_id(provider_config.slug)).exists(), 'Cannot find EnterpriseCustomer-->SAMLProviderConfig association' def test_create_one_config_fail_non_existent_enterprise_uuid(self): """ @@ -167,16 +162,11 @@ class SAMLProviderConfigTests(APITestCase): response = self.client.post(url, data) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertEqual(SAMLProviderConfig.objects.count(), orig_count) + assert response.status_code == status.HTTP_403_FORBIDDEN + assert SAMLProviderConfig.objects.count() == orig_count # check association has NOT been created - self.assertFalse( - EnterpriseCustomerIdentityProvider.objects.filter( - provider_id=convert_saml_slug_provider_id(SINGLE_PROVIDER_CONFIG_2['slug']) - ).exists(), - 'Did not expect to find EnterpriseCustomer-->SAMLProviderConfig association' - ) + assert not EnterpriseCustomerIdentityProvider.objects.filter(provider_id=convert_saml_slug_provider_id(SINGLE_PROVIDER_CONFIG_2['slug'])).exists(), 'Did not expect to find EnterpriseCustomer-->SAMLProviderConfig association' def test_create_one_config_with_absent_enterprise_uuid(self): """ @@ -188,8 +178,8 @@ class SAMLProviderConfigTests(APITestCase): response = self.client.post(url, data) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual(SAMLProviderConfig.objects.count(), orig_count) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert SAMLProviderConfig.objects.count() == orig_count def test_create_one_config_with_no_country_urn(self): """ @@ -206,9 +196,9 @@ class SAMLProviderConfigTests(APITestCase): } response = self.client.post(url, provider_config_no_country) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + assert response.status_code == status.HTTP_201_CREATED provider_config = SAMLProviderConfig.objects.get(slug='test-slug-none') - self.assertEqual(provider_config.country, '') + assert provider_config.country == '' def test_create_one_config_with_empty_country_urn(self): """ @@ -226,9 +216,9 @@ class SAMLProviderConfigTests(APITestCase): } response = self.client.post(url, provider_config_blank_country) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + assert response.status_code == status.HTTP_201_CREATED provider_config = SAMLProviderConfig.objects.get(slug='test-slug-empty') - self.assertEqual(provider_config.country, '') + assert provider_config.country == '' def test_unauthenticated_request_is_forbidden(self): self.client.logout() @@ -237,12 +227,12 @@ class SAMLProviderConfigTests(APITestCase): url = '{}?{}'.format(urlbase, urlencode(query_kwargs)) set_jwt_cookie(self.client, self.user, [(ENTERPRISE_LEARNER_ROLE, ENTERPRISE_ID)]) response = self.client.get(url, format='json') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + assert response.status_code == status.HTTP_403_FORBIDDEN self.client.logout() set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, ENTERPRISE_ID_NON_EXISTENT)]) response = self.client.get(url, format='json') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + assert response.status_code == status.HTTP_403_FORBIDDEN def test_create_one_config_with_samlconfiguration(self): """ @@ -255,6 +245,6 @@ class SAMLProviderConfigTests(APITestCase): response = self.client.post(url, data) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + assert response.status_code == status.HTTP_201_CREATED provider_config = SAMLProviderConfig.objects.get(slug=SINGLE_PROVIDER_CONFIG_3['slug']) - self.assertEqual(provider_config.saml_configuration, self.samlconfiguration) + assert provider_config.saml_configuration == self.samlconfiguration diff --git a/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py b/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py index 37437b0581..ad9d14770f 100644 --- a/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py +++ b/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py @@ -88,10 +88,10 @@ class SAMLProviderDataTests(APITestCase): response = self.client.get(url, format='json') - self.assertEqual(response.status_code, status.HTTP_200_OK) + assert response.status_code == status.HTTP_200_OK results = response.data['results'] - self.assertEqual(len(results), 1) - self.assertEqual(results[0]['sso_url'], SINGLE_PROVIDER_DATA['sso_url']) + assert len(results) == 1 + assert results[0]['sso_url'] == SINGLE_PROVIDER_DATA['sso_url'] def test_create_one_provider_data_success(self): # POST auth/saml/v0/providerdata/ -d data @@ -102,12 +102,9 @@ class SAMLProviderDataTests(APITestCase): response = self.client.post(url, data) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(SAMLProviderData.objects.count(), orig_count + 1) - self.assertEqual( - SAMLProviderData.objects.get(entity_id=SINGLE_PROVIDER_DATA_2['entity_id']).sso_url, - SINGLE_PROVIDER_DATA_2['sso_url'] - ) + assert response.status_code == status.HTTP_201_CREATED + assert SAMLProviderData.objects.count() == (orig_count + 1) + assert SAMLProviderData.objects.get(entity_id=SINGLE_PROVIDER_DATA_2['entity_id']).sso_url == SINGLE_PROVIDER_DATA_2['sso_url'] def test_create_one_data_with_absent_enterprise_uuid(self): """ @@ -119,8 +116,8 @@ class SAMLProviderDataTests(APITestCase): response = self.client.post(url, data) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual(SAMLProviderData.objects.count(), orig_count) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert SAMLProviderData.objects.count() == orig_count def test_patch_one_provider_data(self): # PATCH auth/saml/v0/providerdata/ -d data @@ -133,14 +130,14 @@ class SAMLProviderDataTests(APITestCase): response = self.client.patch(url, data) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(SAMLProviderData.objects.count(), orig_count) + assert response.status_code == status.HTTP_200_OK + assert SAMLProviderData.objects.count() == orig_count # ensure only the sso_url was updated fetched_provider_data = SAMLProviderData.objects.get(pk=self.saml_provider_data.id) - self.assertEqual(fetched_provider_data.sso_url, 'http://new.url') - self.assertEqual(fetched_provider_data.fetched_at, SINGLE_PROVIDER_DATA['fetched_at']) - self.assertEqual(fetched_provider_data.entity_id, SINGLE_PROVIDER_DATA['entity_id']) + assert fetched_provider_data.sso_url == 'http://new.url' + assert fetched_provider_data.fetched_at == SINGLE_PROVIDER_DATA['fetched_at'] + assert fetched_provider_data.entity_id == SINGLE_PROVIDER_DATA['entity_id'] def test_delete_one_provider_data(self): # DELETE auth/saml/v0/providerdata/ -d data @@ -151,12 +148,12 @@ class SAMLProviderDataTests(APITestCase): response = self.client.delete(url) - self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - self.assertEqual(SAMLProviderData.objects.count(), orig_count - 1) + assert response.status_code == status.HTTP_204_NO_CONTENT + assert SAMLProviderData.objects.count() == (orig_count - 1) # ensure only the sso_url was updated query_set_count = SAMLProviderData.objects.filter(pk=self.saml_provider_data.id).count() - self.assertEqual(query_set_count, 0) + assert query_set_count == 0 def test_get_one_provider_data_failure(self): set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, BAD_ENTERPRISE_ID)]) @@ -167,7 +164,7 @@ class SAMLProviderDataTests(APITestCase): response = self.client.get(url, format='json') - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + assert response.status_code == status.HTTP_404_NOT_FOUND def test_unauthenticated_request_is_forbidden(self): self.client.logout() @@ -176,10 +173,10 @@ class SAMLProviderDataTests(APITestCase): url = '{}?{}'.format(urlbase, urlencode(query_kwargs)) set_jwt_cookie(self.client, self.user, [(ENTERPRISE_LEARNER_ROLE, ENTERPRISE_ID)]) response = self.client.get(url, format='json') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + assert response.status_code == status.HTTP_403_FORBIDDEN # manually running second case as DDT is having issues. self.client.logout() set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, BAD_ENTERPRISE_ID)]) response = self.client.get(url, format='json') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + assert response.status_code == status.HTTP_403_FORBIDDEN diff --git a/common/djangoapps/third_party_auth/tests/specs/base.py b/common/djangoapps/third_party_auth/tests/specs/base.py index faaea58cf7..66ad493202 100644 --- a/common/djangoapps/third_party_auth/tests/specs/base.py +++ b/common/djangoapps/third_party_auth/tests/specs/base.py @@ -6,6 +6,7 @@ Base integration test for provider implementations. import json import unittest from contextlib import contextmanager +import pytest import mock from django import test @@ -55,8 +56,8 @@ class HelperMixin(object): implementations may optionally strengthen this assertion with, for example, more details about the format of the Location header. """ - self.assertEqual(302, response.status_code) - self.assertTrue(response.has_header('Location')) + assert 302 == response.status_code + assert response.has_header('Location') def assert_register_response_in_pipeline_looks_correct(self, response, pipeline_kwargs, required_fields): # lint-amnesty, pylint: disable=invalid-name """Performs spot checks of the rendered register.html page. @@ -95,16 +96,16 @@ class HelperMixin(object): its connected state is correct. """ if duplicate: - self.assertEqual(context['duplicate_provider'], self.provider.backend_name) + assert context['duplicate_provider'] == self.provider.backend_name else: - self.assertIsNone(context['duplicate_provider']) + assert context['duplicate_provider'] is None if linked is not None: expected_provider = [ provider for provider in context['auth']['providers'] if provider['name'] == self.provider.name ][0] - self.assertIsNotNone(expected_provider) - self.assertEqual(expected_provider['connected'], linked) + assert expected_provider is not None + assert expected_provider['connected'] == linked def assert_exception_redirect_looks_correct(self, expected_uri, auth_entry=None): """Tests middleware conditional redirection. @@ -118,45 +119,45 @@ class HelperMixin(object): request, exceptions.AuthCanceled(request.backend)) location = response.get('Location') - self.assertEqual(302, response.status_code) - self.assertIn('canceled', location) - self.assertIn(self.backend_name, location) - self.assertTrue(location.startswith(expected_uri + '?')) + assert 302 == response.status_code + assert 'canceled' in location + assert self.backend_name in location + assert location.startswith((expected_uri + '?')) def assert_json_failure_response_is_inactive_account(self, response): """Asserts failure on /login for inactive account looks right.""" - self.assertEqual(400, response.status_code) + assert 400 == response.status_code payload = json.loads(response.content.decode('utf-8')) context = { 'platformName': configuration_helpers.get_value('PLATFORM_NAME', settings.PLATFORM_NAME), 'supportLink': configuration_helpers.get_value('SUPPORT_SITE_LINK', settings.SUPPORT_SITE_LINK) } - self.assertFalse(payload.get('success')) - self.assertIn('inactive-user', payload.get('error_code')) - self.assertEqual(context, payload.get('context')) + assert not payload.get('success') + assert 'inactive-user' in payload.get('error_code') + assert context == payload.get('context') def assert_json_failure_response_is_missing_social_auth(self, response): """Asserts failure on /login for missing social auth looks right.""" - self.assertEqual(403, response.status_code) + assert 403 == response.status_code payload = json.loads(response.content.decode('utf-8')) - self.assertFalse(payload.get('success')) - self.assertEqual(payload.get('error_code'), 'third-party-auth-with-no-linked-account') + assert not payload.get('success') + assert payload.get('error_code') == 'third-party-auth-with-no-linked-account' def assert_json_failure_response_is_username_collision(self, response): """Asserts the json response indicates a username collision.""" - self.assertEqual(409, response.status_code) + assert 409 == response.status_code payload = json.loads(response.content.decode('utf-8')) - self.assertFalse(payload.get('success')) - self.assertIn('belongs to an existing account', payload['username'][0]['user_message']) + assert not payload.get('success') + assert 'belongs to an existing account' in payload['username'][0]['user_message'] def assert_json_success_response_looks_correct(self, response, verify_redirect_url): """Asserts the json response indicates success and redirection.""" - self.assertEqual(200, response.status_code) + assert 200 == response.status_code payload = json.loads(response.content.decode('utf-8')) - self.assertTrue(payload.get('success')) + assert payload.get('success') if verify_redirect_url: - self.assertEqual(pipeline.get_complete_url(self.provider.backend_name), payload.get('redirect_url')) + assert pipeline.get_complete_url(self.provider.backend_name) == payload.get('redirect_url') def assert_login_response_before_pipeline_looks_correct(self, response): """Asserts a GET of /login not in the pipeline looks correct.""" @@ -167,37 +168,36 @@ class HelperMixin(object): def assert_login_response_in_pipeline_looks_correct(self, response): """Asserts a GET of /login in the pipeline looks correct.""" - self.assertEqual(200, response.status_code) + assert 200 == response.status_code def assert_password_overridden_by_pipeline(self, username, password): """Verifies that the given password is not correct. The pipeline overrides POST['password'], if any, with random data. """ - self.assertIsNone(auth.authenticate(password=password, username=username)) + assert auth.authenticate(password=password, username=username) is None def assert_pipeline_running(self, request): """Makes sure the given request is running an auth pipeline.""" - self.assertTrue(pipeline.running(request)) + assert pipeline.running(request) def assert_redirect_after_pipeline_completes(self, response, expected_redirect_url=None): """Asserts a response would redirect to the expected_redirect_url or SOCIAL_AUTH_LOGIN_REDIRECT_URL.""" - self.assertEqual(302, response.status_code) + assert 302 == response.status_code # NOTE: Ideally we should use assertRedirects(), however it errors out due to the hostname, testserver, # not being properly set. This may be an issue with the call made by PSA, but we are not certain. - self.assertTrue(response.get('Location').endswith( - expected_redirect_url or django_settings.SOCIAL_AUTH_LOGIN_REDIRECT_URL, - )) + assert response.get('Location').endswith((expected_redirect_url or + django_settings.SOCIAL_AUTH_LOGIN_REDIRECT_URL)) def assert_redirect_to_login_looks_correct(self, response): """Asserts a response would redirect to /login.""" - self.assertEqual(302, response.status_code) - self.assertEqual('/login', response.get('Location')) + assert 302 == response.status_code + assert '/login' == response.get('Location') def assert_redirect_to_register_looks_correct(self, response): """Asserts a response would redirect to /register.""" - self.assertEqual(302, response.status_code) - self.assertEqual('/register', response.get('Location')) + assert 302 == response.status_code + assert '/register' == response.get('Location') def assert_register_response_before_pipeline_looks_correct(self, response): """Asserts a GET of /register not in the pipeline looks correct.""" @@ -210,24 +210,21 @@ class HelperMixin(object): """Asserts a user does not have an auth with the expected provider.""" social_auths = strategy.storage.user.get_social_auth_for_user( user, provider=self.provider.backend_name) - self.assertEqual(0, len(social_auths)) + assert 0 == len(social_auths) def assert_social_auth_exists_for_user(self, user, strategy): """Asserts a user has a social auth with the expected provider.""" social_auths = strategy.storage.user.get_social_auth_for_user( user, provider=self.provider.backend_name) - self.assertEqual(1, len(social_auths)) - self.assertEqual(self.backend_name, social_auths[0].provider) + assert 1 == len(social_auths) + assert self.backend_name == social_auths[0].provider def assert_logged_in_cookie_redirect(self, response): """Verify that the user was redirected in order to set the logged in cookie. """ - self.assertEqual(response.status_code, 302) - self.assertEqual( - response["Location"], - pipeline.get_complete_url(self.provider.backend_name) - ) - self.assertEqual(response.cookies[django_settings.EDXMKTG_LOGGED_IN_COOKIE_NAME].value, 'true') - self.assertIn(django_settings.EDXMKTG_USER_INFO_COOKIE_NAME, response.cookies) + assert response.status_code == 302 + assert response['Location'] == pipeline.get_complete_url(self.provider.backend_name) + assert response.cookies[django_settings.EDXMKTG_LOGGED_IN_COOKIE_NAME].value == 'true' + assert django_settings.EDXMKTG_USER_INFO_COOKIE_NAME in response.cookies @property def backend_name(self): @@ -384,24 +381,24 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin): # The user clicks on the Dummy button: try_login_response = self.client.get(provider_register_url) # The user should be redirected to the provider's login page: - self.assertEqual(try_login_response.status_code, 302) + assert try_login_response.status_code == 302 provider_response = self.do_provider_login(try_login_response['Location']) # We should be redirected to the register screen since this account is not linked to an edX account: - self.assertEqual(provider_response.status_code, 302) - self.assertEqual(provider_response['Location'], self.register_page_url) + assert provider_response.status_code == 302 + assert provider_response['Location'] == self.register_page_url register_response = self.client.get(self.register_page_url) tpa_context = register_response.context["data"]["third_party_auth"] - self.assertEqual(tpa_context["errorMessage"], None) + assert tpa_context['errorMessage'] is None # Check that the "You've successfully signed into [PROVIDER_NAME]" message is shown. - self.assertEqual(tpa_context["currentProvider"], self.PROVIDER_NAME) + assert tpa_context['currentProvider'] == self.PROVIDER_NAME # Check that the data (e.g. email) from the provider is displayed in the form: form_data = register_response.context['data']['registration_form_desc'] form_fields = {field['name']: field for field in form_data['fields']} - self.assertEqual(form_fields['email']['defaultValue'], self.USER_EMAIL) - self.assertEqual(form_fields['name']['defaultValue'], self.USER_NAME) - self.assertEqual(form_fields['username']['defaultValue'], self.USER_USERNAME) + assert form_fields['email']['defaultValue'] == self.USER_EMAIL + assert form_fields['name']['defaultValue'] == self.USER_NAME + assert form_fields['username']['defaultValue'] == self.USER_USERNAME for field_name, value in extra_defaults.items(): - self.assertEqual(form_fields[field_name]['defaultValue'], value) + assert form_fields[field_name]['defaultValue'] == value registration_values = { 'email': 'email-edited@tpa-test.none', 'name': 'My Customized Name', @@ -413,12 +410,12 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin): reverse('user_api_registration'), registration_values ) - self.assertEqual(ajax_register_response.status_code, 200) + assert ajax_register_response.status_code == 200 # Then the AJAX will finish the third party auth: continue_response = self.client.get(tpa_context["finishAuthUrl"]) # And we should be redirected to the dashboard: - self.assertEqual(continue_response.status_code, 302) - self.assertEqual(continue_response['Location'], reverse('dashboard')) + assert continue_response.status_code == 302 + assert continue_response['Location'] == reverse('dashboard') # Now check that we can login again, whether or not we have yet verified the account: self.client.logout() @@ -437,28 +434,28 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin): # The user clicks on the provider's button: try_login_response = self.client.get(provider_login_url) # The user should be redirected to the provider's login page: - self.assertEqual(try_login_response.status_code, 302) + assert try_login_response.status_code == 302 complete_response = self.do_provider_login(try_login_response['Location']) # We should be redirected to the login screen since this account is not linked to an edX account: - self.assertEqual(complete_response.status_code, 302) - self.assertEqual(complete_response['Location'], self.login_page_url) + assert complete_response.status_code == 302 + assert complete_response['Location'] == self.login_page_url login_response = self.client.get(self.login_page_url) tpa_context = login_response.context["data"]["third_party_auth"] - self.assertEqual(tpa_context["errorMessage"], None) + assert tpa_context['errorMessage'] is None # Check that the "You've successfully signed into [PROVIDER_NAME]" message is shown. - self.assertEqual(tpa_context["currentProvider"], self.PROVIDER_NAME) + assert tpa_context['currentProvider'] == self.PROVIDER_NAME # Now the user enters their username and password. # The AJAX on the page will log them in: ajax_login_response = self.client.post( reverse('user_api_login_session'), {'email': self.user.email, 'password': 'test'} ) - self.assertEqual(ajax_login_response.status_code, 200) + assert ajax_login_response.status_code == 200 # Then the AJAX will finish the third party auth: continue_response = self.client.get(tpa_context["finishAuthUrl"]) # And we should be redirected to the dashboard: - self.assertEqual(continue_response.status_code, 302) - self.assertEqual(continue_response['Location'], reverse('dashboard')) + assert continue_response.status_code == 302 + assert continue_response['Location'] == reverse('dashboard') # Now check that we can login again: self.client.logout() @@ -475,30 +472,30 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin): """ Test logging in to an account that is already linked. """ # Make sure we're not logged in: dashboard_response = self.client.get(reverse('dashboard')) - self.assertEqual(dashboard_response.status_code, 302) + assert dashboard_response.status_code == 302 # The user goes to the login page, and sees a button to login with this provider: provider_login_url = self._check_login_page() # The user clicks on the provider's login button: try_login_response = self.client.get(provider_login_url) # The user should be redirected to the provider: - self.assertEqual(try_login_response.status_code, 302) + assert try_login_response.status_code == 302 login_response = self.do_provider_login(try_login_response['Location']) # If the previous session was manually logged out, there will be one weird redirect # required to set the login cookie (it sticks around if the main session times out): if not previous_session_timed_out: - self.assertEqual(login_response.status_code, 302) - self.assertEqual(login_response['Location'], self.complete_url + "?") + assert login_response.status_code == 302 + assert login_response['Location'] == (self.complete_url + '?') # And then we should be redirected to the dashboard: login_response = self.client.get(login_response['Location']) - self.assertEqual(login_response.status_code, 302) + assert login_response.status_code == 302 if user_is_activated: url_expected = reverse('dashboard') else: url_expected = reverse('third_party_inactive_redirect') + '?next=' + reverse('dashboard') - self.assertEqual(login_response['Location'], url_expected) + assert login_response['Location'] == url_expected # Now we are logged in: dashboard_response = self.client.get(reverse('dashboard')) - self.assertEqual(dashboard_response.status_code, 200) + assert dashboard_response.status_code == 200 def _check_login_page(self): """ @@ -520,7 +517,7 @@ class IntegrationTestMixin(testutil.TestCase, test.TestCase, HelperMixin): self.assertContains(response, self.PROVIDER_NAME) context_data = response.context['data']['third_party_auth'] provider_urls = {provider['id']: provider[url_to_return] for provider in context_data['providers']} - self.assertIn(self.PROVIDER_ID, provider_urls) + assert self.PROVIDER_ID in provider_urls return provider_urls[self.PROVIDER_ID] @property @@ -668,7 +665,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin): self.assert_social_auth_exists_for_user(linked_user, strategy) self.assert_social_auth_does_not_exist_for_user(unlinked_user, strategy) - with self.assertRaises(exceptions.AuthAlreadyAssociated): + with pytest.raises(exceptions.AuthAlreadyAssociated): # pylint: disable=protected-access actions.do_complete(backend, social_views._do_login, user=unlinked_user, request=strategy.request) @@ -720,7 +717,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin): partial_data = strategy.storage.partial.load(partial_pipeline_token) self.assert_social_auth_exists_for_user(user, strategy) - self.assertTrue(user.is_active) + assert user.is_active # Begin! Ensure that the login form contains expected controls before # the user starts the pipeline. @@ -862,7 +859,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin): strategy.request.POST = self.get_registration_post_vars({'email': email}) # The user must not exist yet... - with self.assertRaises(auth_models.User.DoesNotExist): + with pytest.raises(auth_models.User.DoesNotExist): self.get_user_by_email(strategy, email) # ...but when we invoke create_account the existing edX view will make @@ -906,7 +903,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin): self.assert_redirect_to_login_looks_correct(actions.do_complete(backend, social_views._do_login, request=request)) distinct_username = pipeline.get(request)['kwargs']['username'] - self.assertNotEqual(original_username, distinct_username) + assert original_username != distinct_username def test_new_account_registration_fails_if_email_exists(self): request, strategy = self.get_request_and_strategy( @@ -932,17 +929,17 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin): def test_pipeline_raises_auth_entry_error_if_auth_entry_invalid(self): auth_entry = 'invalid' - self.assertNotIn(auth_entry, pipeline._AUTH_ENTRY_CHOICES) # pylint: disable=protected-access + assert auth_entry not in pipeline._AUTH_ENTRY_CHOICES # pylint: disable=protected-access _, strategy = self.get_request_and_strategy(auth_entry=auth_entry, redirect_uri='social:complete') - with self.assertRaises(pipeline.AuthEntryError): + with pytest.raises(pipeline.AuthEntryError): strategy.request.backend.auth_complete = mock.MagicMock(return_value=self.fake_auth_complete(strategy)) def test_pipeline_assumes_login_if_auth_entry_missing(self): _, strategy = self.get_request_and_strategy(auth_entry=None, redirect_uri='social:complete') response = self.fake_auth_complete(strategy) - self.assertEqual(response.url, reverse('signin_user')) + assert response.url == reverse('signin_user') def assert_first_party_auth_trumps_third_party_auth(self, email=None, password=None, success=None): """Asserts first party auth was used in place of third party auth. @@ -974,15 +971,15 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin): if success is None: # Request malformed -- just one of email/password given. - self.assertFalse(payload.get('success')) - self.assertIn('There was an error receiving your login information', payload.get('value')) + assert not payload.get('success') + assert 'There was an error receiving your login information' in payload.get('value') elif success: # Request well-formed and credentials good. - self.assertTrue(payload.get('success')) + assert payload.get('success') else: # Request well-formed but credentials bad. - self.assertFalse(payload.get('success')) - self.assertIn('incorrect', payload.get('value')) + assert not payload.get('success') + assert 'incorrect' in payload.get('value') def get_response_data(self): """Gets a dict of response data of the form given by the provider. diff --git a/common/djangoapps/third_party_auth/tests/specs/test_generic.py b/common/djangoapps/third_party_auth/tests/specs/test_generic.py index e896700b6e..bb05947e63 100644 --- a/common/djangoapps/third_party_auth/tests/specs/test_generic.py +++ b/common/djangoapps/third_party_auth/tests/specs/test_generic.py @@ -31,5 +31,5 @@ class GenericIntegrationTest(IntegrationTestMixin, testutil.TestCase): Mock logging in to the Dummy provider """ # For the Dummy provider, the provider redirect URL is self.complete_url - self.assertEqual(provider_redirect_url, self.url_prefix + self.complete_url) + assert provider_redirect_url == (self.url_prefix + self.complete_url) return self.client.get(provider_redirect_url) diff --git a/common/djangoapps/third_party_auth/tests/specs/test_google.py b/common/djangoapps/third_party_auth/tests/specs/test_google.py index 1941481cde..c3cf331c92 100644 --- a/common/djangoapps/third_party_auth/tests/specs/test_google.py +++ b/common/djangoapps/third_party_auth/tests/specs/test_google.py @@ -51,7 +51,7 @@ class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest): # lint-amnesty, def assert_redirect_to_provider_looks_correct(self, response): super(GoogleOauth2IntegrationTest, self).assert_redirect_to_provider_looks_correct(response) # lint-amnesty, pylint: disable=super-with-arguments - self.assertIn('google.com', response['Location']) + assert 'google.com' in response['Location'] def test_custom_form(self): """ @@ -75,27 +75,20 @@ class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest): # lint-amnesty, with patch.object(self.provider.backend_class, 'auth_complete', fake_auth_complete): response = self.client.get(complete_url) # This should redirect to the custom login/register form: - self.assertEqual(response.status_code, 302) - self.assertEqual(response['Location'], '/auth/custom_auth_entry') + assert response.status_code == 302 + assert response['Location'] == '/auth/custom_auth_entry' response = self.client.get(response['Location']) - self.assertEqual(response.status_code, 200) - self.assertIn('action="/misc/my-custom-registration-form" method="post"', response.content.decode('utf-8')) + assert response.status_code == 200 + assert 'action="/misc/my-custom-registration-form" method="post"' in response.content.decode('utf-8') data_decoded = base64.b64decode(response.context['data']).decode('utf-8') data_parsed = json.loads(data_decoded) # The user's details get passed to the custom page as a base64 encoded query parameter: - self.assertEqual(data_parsed, { - 'auth_entry': 'custom1', - 'backend_name': 'google-oauth2', - 'provider_id': 'oa2-google-oauth2', - 'user_details': { - 'username': 'user', - 'email': 'user@email.com', - 'fullname': 'name_value', - 'first_name': 'given_name_value', - 'last_name': 'family_name_value', - }, - }) + assert data_parsed == {'auth_entry': 'custom1', 'backend_name': 'google-oauth2', + 'provider_id': 'oa2-google-oauth2', + 'user_details': {'username': 'user', 'email': 'user@email.com', + 'fullname': 'name_value', 'first_name': 'given_name_value', + 'last_name': 'family_name_value'}} # Check the hash that is used to confirm the user's data in the GET parameter is correct secret_key = settings.THIRD_PARTY_AUTH_CUSTOM_AUTH_FORMS['custom1']['secret_key'] hmac_expected = hmac.new( @@ -103,18 +96,18 @@ class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest): # lint-amnesty, msg=data_decoded.encode('utf-8'), digestmod=hashlib.sha256 ).digest() - self.assertEqual(base64.b64decode(response.context['hmac']), hmac_expected) + assert base64.b64decode(response.context['hmac']) == hmac_expected # Now our custom registration form creates or logs in the user: email, password = data_parsed['user_details']['email'], 'random_password' created_user = UserFactory(email=email, password=password) login_response = self.client.post(reverse('login_api'), {'email': email, 'password': password}) - self.assertEqual(login_response.status_code, 200) + assert login_response.status_code == 200 # Now our custom login/registration page must resume the pipeline: response = self.client.get(complete_url) - self.assertEqual(response.status_code, 302) - self.assertEqual(response['Location'], '/misc/final-destination') + assert response.status_code == 302 + assert response['Location'] == '/misc/final-destination' _, strategy = self.get_request_and_strategy() self.assert_social_auth_exists_for_user(created_user, strategy) @@ -140,5 +133,5 @@ class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest): # lint-amnesty, with patch.object(self.provider.backend_class, 'auth_complete', fake_auth_complete_error): response = self.client.get(complete_url) # This should redirect to the custom error URL - self.assertEqual(response.status_code, 302) - self.assertEqual(response['Location'], '/misc/my-custom-sso-error-page') + assert response.status_code == 302 + assert response['Location'] == '/misc/my-custom-sso-error-page' diff --git a/common/djangoapps/third_party_auth/tests/specs/test_lti.py b/common/djangoapps/third_party_auth/tests/specs/test_lti.py index 4547604cb1..adfc070f7f 100644 --- a/common/djangoapps/third_party_auth/tests/specs/test_lti.py +++ b/common/djangoapps/third_party_auth/tests/specs/test_lti.py @@ -70,8 +70,8 @@ class IntegrationTestLTI(testutil.TestCase): ) login_response = self.client.post(path=uri, content_type=FORM_ENCODED, data=body) # The user should be redirected to the registration form - self.assertEqual(login_response.status_code, 302) - self.assertTrue(login_response['Location'].endswith(reverse('signin_user'))) + assert login_response.status_code == 302 + assert login_response['Location'].endswith(reverse('signin_user')) register_response = self.client.get(login_response['Location']) self.assertContains(register_response, '"currentProvider": "LTI Test Tool Consumer"') self.assertContains(register_response, '"errorMessage": null') @@ -86,15 +86,12 @@ class IntegrationTestLTI(testutil.TestCase): 'honor_code': True, } ) - self.assertEqual(ajax_register_response.status_code, 200) + assert ajax_register_response.status_code == 200 continue_response = self.client.get(self.url_prefix + LTI_TPA_COMPLETE_URL) # The user should be redirected to the finish_auth view which will enroll them. # FinishAuthView.js reads the URL parameters directly from $.url - self.assertEqual(continue_response.status_code, 302) - self.assertEqual( - continue_response['Location'], - '/account/finish_auth/?course_id=my_course_id&enrollment_action=enroll' - ) + assert continue_response.status_code == 302 + assert continue_response['Location'] == '/account/finish_auth/?course_id=my_course_id&enrollment_action=enroll' # Now check that we can login again self.client.logout() @@ -106,19 +103,20 @@ class IntegrationTestLTI(testutil.TestCase): ) login_2_response = self.client.post(path=uri, content_type=FORM_ENCODED, data=body) # The user should be redirected to the dashboard - self.assertEqual(login_2_response.status_code, 302) - self.assertEqual(login_2_response['Location'], LTI_TPA_COMPLETE_URL + "?") + assert login_2_response.status_code == 302 + assert login_2_response['Location'] == (LTI_TPA_COMPLETE_URL + '?') continue_2_response = self.client.get(login_2_response['Location']) - self.assertEqual(continue_2_response.status_code, 302) - self.assertTrue(continue_2_response['Location'].endswith(reverse('dashboard'))) + assert continue_2_response.status_code == 302 + assert continue_2_response['Location'].endswith(reverse('dashboard')) # Check that the user was created correctly user = User.objects.get(email=EMAIL) - self.assertEqual(user.username, EDX_USER_ID) + assert user.username == EDX_USER_ID def test_reject_initiating_login(self): response = self.client.get(self.url_prefix + LTI_TPA_LOGIN_URL) - self.assertEqual(response.status_code, 405) # Not Allowed + assert response.status_code == 405 + # Not Allowed def test_reject_bad_login(self): login_response = self.client.post( @@ -127,8 +125,8 @@ class IntegrationTestLTI(testutil.TestCase): ) # The user should be redirected to the login page with an error message # (auth_entry defaults to login for this provider) - self.assertEqual(login_response.status_code, 302) - self.assertTrue(login_response['Location'].endswith(reverse('signin_user'))) + assert login_response.status_code == 302 + assert login_response['Location'].endswith(reverse('signin_user')) error_response = self.client.get(login_response['Location']) self.assertContains( error_response, @@ -152,8 +150,8 @@ class IntegrationTestLTI(testutil.TestCase): with self.settings(SOCIAL_AUTH_LTI_CONSUMER_SECRETS={OTHER_LTI_CONSUMER_KEY: OTHER_LTI_CONSUMER_SECRET}): login_response = self.client.post(path=uri, content_type=FORM_ENCODED, data=body) # The user should be redirected to the registration form - self.assertEqual(login_response.status_code, 302) - self.assertTrue(login_response['Location'].endswith(reverse('signin_user'))) + assert login_response.status_code == 302 + assert login_response['Location'].endswith(reverse('signin_user')) register_response = self.client.get(login_response['Location']) self.assertContains( register_response, diff --git a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py index 8d63f7206f..12290c4ada 100644 --- a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py +++ b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py @@ -114,21 +114,21 @@ class SamlIntegrationTestUtilities(object): saml_provider = self.configure_saml_provider(**kwargs) # pylint: disable=no-member if fetch_metadata: - self.assertTrue(httpretty.is_enabled()) # lint-amnesty, pylint: disable=no-member + assert httpretty.is_enabled() # lint-amnesty, pylint: disable=no-member num_total, num_skipped, num_attempted, num_updated, num_failed, failure_messages = fetch_saml_metadata() if assert_metadata_updates: - self.assertEqual(num_total, 1) # lint-amnesty, pylint: disable=no-member - self.assertEqual(num_skipped, 0) # lint-amnesty, pylint: disable=no-member - self.assertEqual(num_attempted, 1) # lint-amnesty, pylint: disable=no-member - self.assertEqual(num_updated, 1) # lint-amnesty, pylint: disable=no-member - self.assertEqual(num_failed, 0) # lint-amnesty, pylint: disable=no-member - self.assertEqual(len(failure_messages), 0) # lint-amnesty, pylint: disable=no-member + assert num_total == 1 # lint-amnesty, pylint: disable=no-member + assert num_skipped == 0 # lint-amnesty, pylint: disable=no-member + assert num_attempted == 1 # lint-amnesty, pylint: disable=no-member + assert num_updated == 1 # lint-amnesty, pylint: disable=no-member + assert num_failed == 0 # lint-amnesty, pylint: disable=no-member + assert len(failure_messages) == 0 # lint-amnesty, pylint: disable=no-member return saml_provider def do_provider_login(self, provider_redirect_url): """ Mocked: the user logs in to TestShib and then gets redirected back """ # The SAML provider (TestShib) will authenticate the user, then get the browser to POST a response: - self.assertTrue(provider_redirect_url.startswith(TESTSHIB_SSO_URL)) # lint-amnesty, pylint: disable=no-member + assert provider_redirect_url.startswith(TESTSHIB_SSO_URL) # lint-amnesty, pylint: disable=no-member saml_response_xml = utils.read_and_pre_process_xml( os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data', 'testshib_saml_response.xml') @@ -188,9 +188,8 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin enterprise_customer = EnterpriseCustomerFactory() assert EnterpriseCustomerUser.objects.count() == 0, "Precondition check: no link records should exist" EnterpriseCustomerUser.objects.link_user(enterprise_customer, user.email) - self.assertTrue( # lint-amnesty, pylint: disable=wrong-assert-type - EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 1 - ) + assert (EnterpriseCustomerUser.objects + .filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 1) EnterpriseCustomerIdentityProvider.objects.get_or_create(enterprise_customer=enterprise_customer, provider_id=self.provider.provider_id) @@ -229,10 +228,8 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin redirect_field_name=auth.REDIRECT_FIELD_NAME, request=request ) - self.assertNotEqual( - EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count(), - 0 - ) + assert EnterpriseCustomerUser.objects\ + .filter(enterprise_customer=enterprise_customer, user_id=user.id).count() != 0 # Fire off the disconnect pipeline to unlink. self.assert_redirect_after_pipeline_completes( @@ -247,10 +244,8 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin # Now we expect to be in the unlinked state, with no backend entry. self.assert_account_settings_context_looks_correct(account_settings_context(request), linked=False) self.assert_social_auth_does_not_exist_for_user(user, strategy) - self.assertEqual( - EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_customer, user_id=user.id).count(), - 0 - ) + assert EnterpriseCustomerUser.objects\ + .filter(enterprise_customer=enterprise_customer, user_id=user.id).count() == 0 def get_response_data(self): """Gets dict (string -> object) of merged data about the user.""" @@ -269,8 +264,8 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin # The user clicks on the TestShib button: try_login_response = self.client.get(testshib_login_url) # The user should be redirected to back to the login page: - self.assertEqual(try_login_response.status_code, 302) - self.assertEqual(try_login_response['Location'], self.login_page_url) + assert try_login_response.status_code == 302 + assert try_login_response['Location'] == self.login_page_url # When loading the login page, the user will see an error message: response = self.client.get(self.login_page_url) self.assertContains(response, 'Authentication with TestShib is currently unavailable.') @@ -294,12 +289,11 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin user=self.user, provider=self.PROVIDER_BACKEND, uid__startswith=self.PROVIDER_IDP_SLUG ) attributes = record.extra_data - self.assertEqual( - attributes.get("urn:oid:1.3.6.1.4.1.5923.1.1.1.9"), ["Member@testshib.org", "Staff@testshib.org"] - ) - self.assertEqual(attributes.get("urn:oid:2.5.4.3"), ["Me Myself And I"]) - self.assertEqual(attributes.get("urn:oid:0.9.2342.19200300.100.1.1"), ["myself"]) - self.assertEqual(attributes.get("urn:oid:2.5.4.20"), ["555-5555"]) # Phone number + assert attributes.get('urn:oid:1.3.6.1.4.1.5923.1.1.1.9') == ['Member@testshib.org', 'Staff@testshib.org'] + assert attributes.get('urn:oid:2.5.4.3') == ['Me Myself And I'] + assert attributes.get('urn:oid:0.9.2342.19200300.100.1.1') == ['myself'] + assert attributes.get('urn:oid:2.5.4.20') == ['555-5555'] + # Phone number @ddt.data(True, False) def test_debug_mode_login(self, debug_mode_enabled): @@ -310,30 +304,30 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin if debug_mode_enabled: # We expect that test_login() does two full logins, and each attempt generates two # logs - one for the request and one for the response - self.assertEqual(mock_log.call_count, 4) + assert mock_log.call_count == 4 expected_next_url = "/dashboard" (msg, action_type, idp_name, request_data, next_url, xml), _kwargs = mock_log.call_args_list[0] - self.assertTrue(msg.startswith(u"SAML login %s")) - self.assertEqual(action_type, "request") - self.assertEqual(idp_name, self.PROVIDER_IDP_SLUG) + assert msg.startswith(u'SAML login %s') + assert action_type == 'request' + assert idp_name == self.PROVIDER_IDP_SLUG self.assertDictContainsSubset( {"idp": idp_name, "auth_entry": "login", "next": expected_next_url}, request_data ) - self.assertEqual(next_url, expected_next_url) - self.assertIn('