reaplced unittest assertions pytest assertions (#26241)
This commit is contained in:
@@ -60,7 +60,7 @@ class DOTAdapterMixin(object):
|
||||
request.session = {}
|
||||
view = DOTAccessTokenExchangeView.as_view()
|
||||
response = view(request, backend='facebook')
|
||||
self.assertEqual(response.status_code, 400)
|
||||
assert response.status_code == 400
|
||||
|
||||
@expectedFailure
|
||||
def test_single_access_token(self):
|
||||
|
||||
@@ -39,17 +39,14 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin):
|
||||
|
||||
def _assert_error(self, data, expected_error, expected_error_description): # lint-amnesty, pylint: disable=arguments-differ
|
||||
form = AccessTokenExchangeForm(request=self.request, oauth2_adapter=self.oauth2_adapter, data=data)
|
||||
self.assertEqual(
|
||||
form.errors,
|
||||
{"error": expected_error, "error_description": expected_error_description}
|
||||
)
|
||||
assert form.errors == {'error': expected_error, 'error_description': expected_error_description}
|
||||
|
||||
def _assert_success(self, data, expected_scopes):
|
||||
form = AccessTokenExchangeForm(request=self.request, oauth2_adapter=self.oauth2_adapter, data=data)
|
||||
self.assertTrue(form.is_valid())
|
||||
self.assertEqual(form.cleaned_data["user"], self.user)
|
||||
self.assertEqual(form.cleaned_data["client"], self.oauth_client)
|
||||
self.assertEqual(set(form.cleaned_data["scope"]), set(expected_scopes))
|
||||
assert form.is_valid()
|
||||
assert form.cleaned_data['user'] == self.user
|
||||
assert form.cleaned_data['client'] == self.oauth_client
|
||||
assert set(form.cleaned_data['scope']) == set(expected_scopes)
|
||||
|
||||
|
||||
# This is necessary because cms does not implement third party auth
|
||||
|
||||
@@ -44,37 +44,31 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
|
||||
|
||||
def _assert_error(self, data, expected_error, expected_error_description, error_code=None):
|
||||
response = self.csrf_client.post(self.url, data)
|
||||
self.assertEqual(response.status_code, error_code if error_code else 400)
|
||||
self.assertEqual(response["Content-Type"], "application/json")
|
||||
assert response.status_code == (error_code if error_code else 400)
|
||||
assert response['Content-Type'] == 'application/json'
|
||||
expected_data = {u"error": expected_error, u"error_description": expected_error_description}
|
||||
if error_code:
|
||||
expected_data['error_code'] = error_code
|
||||
self.assertEqual(
|
||||
json.loads(response.content.decode('utf-8')),
|
||||
expected_data
|
||||
)
|
||||
assert json.loads(response.content.decode('utf-8')) == expected_data
|
||||
|
||||
def _assert_success(self, data, expected_scopes):
|
||||
response = self.csrf_client.post(self.url, data)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response["Content-Type"], "application/json")
|
||||
assert response.status_code == 200
|
||||
assert response['Content-Type'] == 'application/json'
|
||||
content = json.loads(response.content.decode('utf-8'))
|
||||
self.assertEqual(set(content.keys()), self.get_token_response_keys())
|
||||
self.assertEqual(content["token_type"], "Bearer")
|
||||
self.assertLessEqual(
|
||||
timedelta(seconds=int(content["expires_in"])),
|
||||
timedelta(days=30)
|
||||
)
|
||||
assert set(content.keys()) == self.get_token_response_keys()
|
||||
assert content['token_type'] == 'Bearer'
|
||||
assert timedelta(seconds=int(content['expires_in'])) <= timedelta(days=30)
|
||||
actual_scopes = content["scope"]
|
||||
if actual_scopes:
|
||||
actual_scopes = actual_scopes.split(' ')
|
||||
else:
|
||||
actual_scopes = []
|
||||
self.assertEqual(set(actual_scopes), set(expected_scopes))
|
||||
assert set(actual_scopes) == set(expected_scopes)
|
||||
token = self.oauth2_adapter.get_access_token(token_string=content["access_token"])
|
||||
self.assertEqual(token.user, self.user)
|
||||
self.assertEqual(self.oauth2_adapter.get_client_for_token(token), self.oauth_client)
|
||||
self.assertEqual(set(self.oauth2_adapter.get_token_scope_names(token)), set(expected_scopes))
|
||||
assert token.user == self.user
|
||||
assert self.oauth2_adapter.get_client_for_token(token) == self.oauth_client
|
||||
assert set(self.oauth2_adapter.get_token_scope_names(token)) == set(expected_scopes)
|
||||
|
||||
def test_single_access_token(self):
|
||||
def extract_token(response):
|
||||
@@ -91,28 +85,20 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
|
||||
):
|
||||
first_response = self.client.post(self.url, self.data)
|
||||
second_response = self.client.post(self.url, self.data)
|
||||
self.assertEqual(first_response.status_code, 200)
|
||||
self.assertEqual(second_response.status_code, 200)
|
||||
self.assertEqual(
|
||||
extract_token(first_response) == extract_token(second_response),
|
||||
single_access_token
|
||||
)
|
||||
assert first_response.status_code == 200
|
||||
assert second_response.status_code == 200
|
||||
assert (extract_token(first_response) == extract_token(second_response)) == single_access_token
|
||||
|
||||
def test_get_method(self):
|
||||
response = self.client.get(self.url, self.data)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
json.loads(response.content.decode('utf-8')),
|
||||
{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Only POST requests allowed.",
|
||||
}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert json.loads(response.content.decode('utf-8')) ==\
|
||||
{'error': 'invalid_request', 'error_description': 'Only POST requests allowed.'}
|
||||
|
||||
def test_invalid_provider(self):
|
||||
url = reverse("exchange_access_token", kwargs={"backend": "invalid"})
|
||||
response = self.client.post(url, self.data)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.skip(reason="this is very entangled with dop use in third_party_auth")
|
||||
def test_invalid_client(self):
|
||||
@@ -186,9 +172,9 @@ class TestLoginWithAccessTokenView(TestCase):
|
||||
"""
|
||||
url = reverse("login_with_access_token")
|
||||
response = self.client.post(url, HTTP_AUTHORIZATION=u"Bearer {0}".format(access_token).encode('utf-8'))
|
||||
self.assertEqual(response.status_code, expected_status_code)
|
||||
assert response.status_code == expected_status_code
|
||||
if expected_cookie_name:
|
||||
self.assertIn(expected_cookie_name, response.cookies)
|
||||
assert expected_cookie_name in response.cookies
|
||||
|
||||
def _create_dot_access_token(self, grant_type='Client credentials'):
|
||||
"""
|
||||
@@ -199,13 +185,13 @@ class TestLoginWithAccessTokenView(TestCase):
|
||||
|
||||
def test_invalid_token(self):
|
||||
self._verify_response("invalid_token", expected_status_code=401)
|
||||
self.assertNotIn("session_key", self.client.session)
|
||||
assert 'session_key' not in self.client.session
|
||||
|
||||
def test_dot_password_grant_supported(self):
|
||||
access_token = self._create_dot_access_token(grant_type='password')
|
||||
|
||||
self._verify_response(access_token, expected_status_code=204, expected_cookie_name='sessionid')
|
||||
self.assertEqual(int(self.client.session['_auth_user_id']), self.user.id)
|
||||
assert int(self.client.session['_auth_user_id']) == self.user.id
|
||||
|
||||
def test_dot_client_credentials_unsupported(self):
|
||||
access_token = self._create_dot_access_token()
|
||||
|
||||
Reference in New Issue
Block a user