reaplced unittest assertions pytest assertions (#26241)

This commit is contained in:
Aarif
2021-02-04 09:55:55 +05:00
committed by GitHub
parent 22eb29dabe
commit 277d52982c
3 changed files with 29 additions and 46 deletions

View File

@@ -60,7 +60,7 @@ class DOTAdapterMixin(object):
request.session = {}
view = DOTAccessTokenExchangeView.as_view()
response = view(request, backend='facebook')
self.assertEqual(response.status_code, 400)
assert response.status_code == 400
@expectedFailure
def test_single_access_token(self):

View File

@@ -39,17 +39,14 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin):
def _assert_error(self, data, expected_error, expected_error_description): # lint-amnesty, pylint: disable=arguments-differ
form = AccessTokenExchangeForm(request=self.request, oauth2_adapter=self.oauth2_adapter, data=data)
self.assertEqual(
form.errors,
{"error": expected_error, "error_description": expected_error_description}
)
assert form.errors == {'error': expected_error, 'error_description': expected_error_description}
def _assert_success(self, data, expected_scopes):
form = AccessTokenExchangeForm(request=self.request, oauth2_adapter=self.oauth2_adapter, data=data)
self.assertTrue(form.is_valid())
self.assertEqual(form.cleaned_data["user"], self.user)
self.assertEqual(form.cleaned_data["client"], self.oauth_client)
self.assertEqual(set(form.cleaned_data["scope"]), set(expected_scopes))
assert form.is_valid()
assert form.cleaned_data['user'] == self.user
assert form.cleaned_data['client'] == self.oauth_client
assert set(form.cleaned_data['scope']) == set(expected_scopes)
# This is necessary because cms does not implement third party auth

View File

@@ -44,37 +44,31 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
def _assert_error(self, data, expected_error, expected_error_description, error_code=None):
response = self.csrf_client.post(self.url, data)
self.assertEqual(response.status_code, error_code if error_code else 400)
self.assertEqual(response["Content-Type"], "application/json")
assert response.status_code == (error_code if error_code else 400)
assert response['Content-Type'] == 'application/json'
expected_data = {u"error": expected_error, u"error_description": expected_error_description}
if error_code:
expected_data['error_code'] = error_code
self.assertEqual(
json.loads(response.content.decode('utf-8')),
expected_data
)
assert json.loads(response.content.decode('utf-8')) == expected_data
def _assert_success(self, data, expected_scopes):
response = self.csrf_client.post(self.url, data)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "application/json")
assert response.status_code == 200
assert response['Content-Type'] == 'application/json'
content = json.loads(response.content.decode('utf-8'))
self.assertEqual(set(content.keys()), self.get_token_response_keys())
self.assertEqual(content["token_type"], "Bearer")
self.assertLessEqual(
timedelta(seconds=int(content["expires_in"])),
timedelta(days=30)
)
assert set(content.keys()) == self.get_token_response_keys()
assert content['token_type'] == 'Bearer'
assert timedelta(seconds=int(content['expires_in'])) <= timedelta(days=30)
actual_scopes = content["scope"]
if actual_scopes:
actual_scopes = actual_scopes.split(' ')
else:
actual_scopes = []
self.assertEqual(set(actual_scopes), set(expected_scopes))
assert set(actual_scopes) == set(expected_scopes)
token = self.oauth2_adapter.get_access_token(token_string=content["access_token"])
self.assertEqual(token.user, self.user)
self.assertEqual(self.oauth2_adapter.get_client_for_token(token), self.oauth_client)
self.assertEqual(set(self.oauth2_adapter.get_token_scope_names(token)), set(expected_scopes))
assert token.user == self.user
assert self.oauth2_adapter.get_client_for_token(token) == self.oauth_client
assert set(self.oauth2_adapter.get_token_scope_names(token)) == set(expected_scopes)
def test_single_access_token(self):
def extract_token(response):
@@ -91,28 +85,20 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
):
first_response = self.client.post(self.url, self.data)
second_response = self.client.post(self.url, self.data)
self.assertEqual(first_response.status_code, 200)
self.assertEqual(second_response.status_code, 200)
self.assertEqual(
extract_token(first_response) == extract_token(second_response),
single_access_token
)
assert first_response.status_code == 200
assert second_response.status_code == 200
assert (extract_token(first_response) == extract_token(second_response)) == single_access_token
def test_get_method(self):
response = self.client.get(self.url, self.data)
self.assertEqual(response.status_code, 400)
self.assertEqual(
json.loads(response.content.decode('utf-8')),
{
"error": "invalid_request",
"error_description": "Only POST requests allowed.",
}
)
assert response.status_code == 400
assert json.loads(response.content.decode('utf-8')) ==\
{'error': 'invalid_request', 'error_description': 'Only POST requests allowed.'}
def test_invalid_provider(self):
url = reverse("exchange_access_token", kwargs={"backend": "invalid"})
response = self.client.post(url, self.data)
self.assertEqual(response.status_code, 404)
assert response.status_code == 404
@pytest.mark.skip(reason="this is very entangled with dop use in third_party_auth")
def test_invalid_client(self):
@@ -186,9 +172,9 @@ class TestLoginWithAccessTokenView(TestCase):
"""
url = reverse("login_with_access_token")
response = self.client.post(url, HTTP_AUTHORIZATION=u"Bearer {0}".format(access_token).encode('utf-8'))
self.assertEqual(response.status_code, expected_status_code)
assert response.status_code == expected_status_code
if expected_cookie_name:
self.assertIn(expected_cookie_name, response.cookies)
assert expected_cookie_name in response.cookies
def _create_dot_access_token(self, grant_type='Client credentials'):
"""
@@ -199,13 +185,13 @@ class TestLoginWithAccessTokenView(TestCase):
def test_invalid_token(self):
self._verify_response("invalid_token", expected_status_code=401)
self.assertNotIn("session_key", self.client.session)
assert 'session_key' not in self.client.session
def test_dot_password_grant_supported(self):
access_token = self._create_dot_access_token(grant_type='password')
self._verify_response(access_token, expected_status_code=204, expected_cookie_name='sessionid')
self.assertEqual(int(self.client.session['_auth_user_id']), self.user.id)
assert int(self.client.session['_auth_user_id']) == self.user.id
def test_dot_client_credentials_unsupported(self):
access_token = self._create_dot_access_token()