replaced unittest assertions pytest assertions (#26242)
This commit is contained in:
@@ -5,6 +5,7 @@ Tests the ``edx_clear_expired_tokens`` management command.
|
||||
|
||||
import unittest
|
||||
from datetime import timedelta
|
||||
import pytest
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
@@ -43,7 +44,7 @@ class EdxClearExpiredTokensTests(TestCase):
|
||||
@patch('oauth2_provider.settings.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS', 'xyz')
|
||||
def test_invalid_expiration_time(self):
|
||||
with LogCapture(LOGGER_NAME) as log:
|
||||
with self.assertRaises(ImproperlyConfigured):
|
||||
with pytest.raises(ImproperlyConfigured):
|
||||
call_command('edx_clear_expired_tokens')
|
||||
log.check(
|
||||
(
|
||||
@@ -79,7 +80,7 @@ class EdxClearExpiredTokensTests(TestCase):
|
||||
'Cleaning 0 rows from Grant table',
|
||||
)
|
||||
)
|
||||
self.assertTrue(RefreshToken.objects.filter(application=application).exists())
|
||||
assert RefreshToken.objects.filter(application=application).exists()
|
||||
|
||||
@override_settings()
|
||||
def test_clear_expired_tokens(self):
|
||||
@@ -91,15 +92,12 @@ class EdxClearExpiredTokensTests(TestCase):
|
||||
for user in users:
|
||||
application = factories.ApplicationFactory(user=user)
|
||||
factories.AccessTokenFactory(user=user, application=application, expires=expires)
|
||||
self.assertEqual(
|
||||
AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count(),
|
||||
initial_count
|
||||
)
|
||||
assert AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count() == initial_count
|
||||
original_delete = QuerySet.delete
|
||||
QuerySet.delete = counter(QuerySet.delete)
|
||||
try:
|
||||
call_command('edx_clear_expired_tokens', batch_size=1, sleep_time=0)
|
||||
self.assertEqual(QuerySet.delete.invocations, initial_count)
|
||||
self.assertEqual(AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count(), 0)
|
||||
assert not QuerySet.delete.invocations != initial_count # pylint: disable=no-member
|
||||
assert AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count() == 0
|
||||
finally:
|
||||
QuerySet.delete = original_delete
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""
|
||||
Tests the ``create_dot_application`` management command.
|
||||
"""
|
||||
|
||||
|
||||
import pytest
|
||||
import ddt
|
||||
from django.core.management import call_command
|
||||
from django.test import TestCase
|
||||
@@ -49,41 +48,41 @@ class TestCreateDotApplication(TestCase):
|
||||
call_args = base_call_args + [URI_OLD]
|
||||
call_command(Command(), *call_args)
|
||||
app = Application.objects.get(name=APP_NAME)
|
||||
self.assertEqual(app.redirect_uris, URI_OLD)
|
||||
with self.assertRaises(ApplicationAccess.DoesNotExist):
|
||||
assert app.redirect_uris == URI_OLD
|
||||
with pytest.raises(ApplicationAccess.DoesNotExist):
|
||||
ApplicationAccess.objects.get(application_id=app.id)
|
||||
|
||||
# Make sure we can call again with no changes
|
||||
call_args = base_call_args + [URI_OLD]
|
||||
call_command(Command(), *call_args)
|
||||
app = Application.objects.get(name=APP_NAME)
|
||||
self.assertEqual(app.redirect_uris, URI_OLD)
|
||||
with self.assertRaises(ApplicationAccess.DoesNotExist):
|
||||
assert app.redirect_uris == URI_OLD
|
||||
with pytest.raises(ApplicationAccess.DoesNotExist):
|
||||
ApplicationAccess.objects.get(application_id=app.id)
|
||||
|
||||
# Make sure calling with new URI changes URI, but does not add access
|
||||
call_args = base_call_args + [URI_NEW]
|
||||
call_command(Command(), *call_args)
|
||||
app = Application.objects.get(name=APP_NAME)
|
||||
self.assertEqual(app.redirect_uris, URI_NEW)
|
||||
with self.assertRaises(ApplicationAccess.DoesNotExist):
|
||||
assert app.redirect_uris == URI_NEW
|
||||
with pytest.raises(ApplicationAccess.DoesNotExist):
|
||||
ApplicationAccess.objects.get(application_id=app.id)
|
||||
|
||||
# Make sure calling with scopes adds access
|
||||
call_args = base_call_args + [URI_NEW, "--scopes", ",".join(SCOPES_X)]
|
||||
call_command(Command(), *call_args)
|
||||
app = Application.objects.get(name=APP_NAME)
|
||||
self.assertEqual(app.redirect_uris, URI_NEW)
|
||||
assert app.redirect_uris == URI_NEW
|
||||
access = ApplicationAccess.objects.get(application_id=app.id)
|
||||
self.assertEqual(access.scopes, SCOPES_X)
|
||||
assert access.scopes == SCOPES_X
|
||||
|
||||
# Make sure calling with new scopes changes them
|
||||
call_args = base_call_args + [URI_NEW, "--scopes", ",".join(SCOPES_Y)]
|
||||
call_command(Command(), *call_args)
|
||||
app = Application.objects.get(name=APP_NAME)
|
||||
self.assertEqual(app.redirect_uris, URI_NEW)
|
||||
assert app.redirect_uris == URI_NEW
|
||||
access = ApplicationAccess.objects.get(application_id=app.id)
|
||||
self.assertEqual(access.scopes, SCOPES_Y)
|
||||
assert access.scopes == SCOPES_Y
|
||||
|
||||
@ddt.data(
|
||||
(None, None, None, None, False, None),
|
||||
@@ -122,31 +121,31 @@ class TestCreateDotApplication(TestCase):
|
||||
call_command(Command(), *call_args)
|
||||
|
||||
apps = Application.objects.filter(name='testing_application')
|
||||
self.assertEqual(1, len(apps))
|
||||
assert 1 == len(apps)
|
||||
application = apps[0]
|
||||
self.assertEqual('testing_application', application.name)
|
||||
self.assertEqual(self.user, application.user)
|
||||
self.assertEqual(grant_type, application.authorization_grant_type)
|
||||
self.assertEqual(client_type, application.client_type)
|
||||
self.assertEqual('', application.redirect_uris)
|
||||
self.assertEqual(skip_auth, application.skip_authorization)
|
||||
assert 'testing_application' == application.name
|
||||
assert self.user == application.user
|
||||
assert grant_type == application.authorization_grant_type
|
||||
assert client_type == application.client_type
|
||||
assert '' == application.redirect_uris
|
||||
assert skip_auth == application.skip_authorization
|
||||
|
||||
if client_id:
|
||||
self.assertEqual(client_id, application.client_id)
|
||||
assert client_id == application.client_id
|
||||
if client_secret:
|
||||
self.assertEqual(client_secret, application.client_secret)
|
||||
assert client_secret == application.client_secret
|
||||
|
||||
if scopes:
|
||||
app_access_list = ApplicationAccess.objects.filter(application_id=application.id)
|
||||
self.assertEqual(1, len(app_access_list))
|
||||
assert 1 == len(app_access_list)
|
||||
app_access = app_access_list[0]
|
||||
self.assertEqual(scopes.split(','), app_access.scopes)
|
||||
assert scopes.split(',') == app_access.scopes
|
||||
|
||||
# When called a second time with the same arguments, the command should
|
||||
# exit gracefully without creating a second application.
|
||||
call_command(Command(), *call_args)
|
||||
apps = Application.objects.filter(name='testing_application')
|
||||
self.assertEqual(1, len(apps))
|
||||
assert 1 == len(apps)
|
||||
if scopes:
|
||||
app_access_list = ApplicationAccess.objects.filter(application_id=application.id)
|
||||
self.assertEqual(1, len(app_access_list))
|
||||
assert 1 == len(app_access_list)
|
||||
|
||||
@@ -44,7 +44,7 @@ class TestGenerateJwtSigningKey(TestCase):
|
||||
message in log_entry[0][0]
|
||||
for log_entry in mock_log.call_args_list
|
||||
)
|
||||
self.assertEqual(log_message_exists, expected_to_exist)
|
||||
assert log_message_exists == expected_to_exist
|
||||
|
||||
def _assert_key_output(self, output_stream, filename, strip_key_prefix):
|
||||
expected_in_output = [
|
||||
@@ -53,23 +53,23 @@ class TestGenerateJwtSigningKey(TestCase):
|
||||
'{}JWT_PUBLIC_SIGNING_JWK_SET'.format('' if strip_key_prefix else 'COMMON_'),
|
||||
]
|
||||
for expected in expected_in_output:
|
||||
self.assertIn(expected, output_stream.getvalue())
|
||||
assert expected in output_stream.getvalue()
|
||||
|
||||
with open(filename) as file_obj: # pylint: disable=open-builtin
|
||||
output_from_yaml = yaml.safe_load(file_obj)
|
||||
for expected in expected_in_output:
|
||||
self.assertIn(expected, output_from_yaml['JWT_AUTH'])
|
||||
assert expected in output_from_yaml['JWT_AUTH']
|
||||
|
||||
def _assert_presence_of_old_keys(self, mock_log, add_previous_public_keys):
|
||||
self._assert_log_message(mock_log, 'Old JWT_PUBLIC_SIGNING_JWK_SET', expected_to_exist=add_previous_public_keys)
|
||||
|
||||
def _assert_presence_of_key_id(self, mock_log, output_stream, provide_key_id, key_id_size):
|
||||
if provide_key_id:
|
||||
self.assertIn(TEST_KEY_IDENTIFIER, output_stream.getvalue())
|
||||
assert TEST_KEY_IDENTIFIER in output_stream.getvalue()
|
||||
else:
|
||||
self.assertNotIn(TEST_KEY_IDENTIFIER, output_stream.getvalue())
|
||||
assert TEST_KEY_IDENTIFIER not in output_stream.getvalue()
|
||||
key_id = mock_log.call_args_list[0][0][1]
|
||||
self.assertEqual(len(key_id), key_id_size or 8)
|
||||
assert len(key_id) == (key_id_size or 8)
|
||||
|
||||
@ddt.data(
|
||||
dict(add_previous_public_keys=True, provide_key_id=False, key_id_size=None, strip_key_prefix=True),
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
OAuth Dispatch test mixins
|
||||
"""
|
||||
|
||||
|
||||
import pytest
|
||||
import jwt
|
||||
from django.conf import settings
|
||||
from jwkest.jwk import KEYS
|
||||
@@ -100,7 +100,7 @@ class AccessTokenMixin(object):
|
||||
# now we should assert that the claim is indeed
|
||||
# expired
|
||||
if should_be_expired:
|
||||
with self.assertRaises(ExpiredSignatureError):
|
||||
with pytest.raises(ExpiredSignatureError):
|
||||
_decode_jwt(verify_expiration=True)
|
||||
|
||||
return payload
|
||||
|
||||
@@ -35,14 +35,14 @@ class TestOAuthDispatchAPI(TestCase):
|
||||
|
||||
def _assert_stored_token(self, stored_token_value, expected_token_user, expected_client):
|
||||
stored_access_token = AccessToken.objects.get(token=stored_token_value)
|
||||
self.assertEqual(stored_access_token.user.id, expected_token_user.id)
|
||||
self.assertEqual(stored_access_token.application.client_id, expected_client.client_id)
|
||||
self.assertEqual(stored_access_token.application.user.id, expected_client.user.id)
|
||||
assert stored_access_token.user.id == expected_token_user.id
|
||||
assert stored_access_token.application.client_id == expected_client.client_id
|
||||
assert stored_access_token.application.user.id == expected_client.user.id
|
||||
|
||||
def test_create_token_success(self):
|
||||
token = api.create_dot_access_token(HttpRequest(), self.user, self.client)
|
||||
self.assertTrue(token['access_token'])
|
||||
self.assertTrue(token['refresh_token'])
|
||||
assert token['access_token']
|
||||
assert token['refresh_token']
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
u'token_type': u'Bearer',
|
||||
|
||||
@@ -43,9 +43,9 @@ class ClientCredentialsTest(mixins.AccessTokenMixin, TestCase):
|
||||
}
|
||||
|
||||
response = self.client.post(reverse('access_token'), data)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
|
||||
content = json.loads(response.content.decode('utf-8'))
|
||||
access_token = content['access_token']
|
||||
self.assertEqual(content['scope'], data['scope'])
|
||||
assert content['scope'] == data['scope']
|
||||
self.assert_valid_jwt_access_token(access_token, self.user, scopes)
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
Tests for DOT Adapter
|
||||
"""
|
||||
|
||||
|
||||
import unittest
|
||||
from datetime import timedelta
|
||||
import pytest
|
||||
|
||||
import ddt
|
||||
import six
|
||||
@@ -56,9 +56,8 @@ class DOTAdapterTestCase(TestCase):
|
||||
"""
|
||||
Make sure unicode representation of RestrictedApplication is correct
|
||||
"""
|
||||
self.assertEqual(six.text_type(self.restricted_app), u"<RestrictedApplication '{name}'>".format(
|
||||
name=self.restricted_client.name
|
||||
))
|
||||
assert six.text_type(self.restricted_app) == "<RestrictedApplication '{name}'>"\
|
||||
.format(name=self.restricted_client.name)
|
||||
|
||||
@ddt.data(
|
||||
('confidential', models.Application.CLIENT_CONFIDENTIAL),
|
||||
@@ -67,9 +66,9 @@ class DOTAdapterTestCase(TestCase):
|
||||
@ddt.unpack
|
||||
def test_create_client(self, client_name, client_type):
|
||||
client = getattr(self, '{}_client'.format(client_name))
|
||||
self.assertIsInstance(client, models.Application)
|
||||
self.assertEqual(client.client_id, '{}-client-id'.format(client_name))
|
||||
self.assertEqual(client.client_type, client_type)
|
||||
assert isinstance(client, models.Application)
|
||||
assert client.client_id == '{}-client-id'.format(client_name)
|
||||
assert client.client_type == client_type
|
||||
|
||||
def test_get_client(self):
|
||||
"""
|
||||
@@ -80,11 +79,11 @@ class DOTAdapterTestCase(TestCase):
|
||||
redirect_uris=DUMMY_REDIRECT_URL,
|
||||
client_type=models.Application.CLIENT_CONFIDENTIAL
|
||||
)
|
||||
self.assertIsInstance(client, models.Application)
|
||||
self.assertEqual(client.client_type, models.Application.CLIENT_CONFIDENTIAL)
|
||||
assert isinstance(client, models.Application)
|
||||
assert client.client_type == models.Application.CLIENT_CONFIDENTIAL
|
||||
|
||||
def test_get_client_not_found(self):
|
||||
with self.assertRaises(models.Application.DoesNotExist):
|
||||
with pytest.raises(models.Application.DoesNotExist):
|
||||
self.adapter.get_client(client_id='not-found')
|
||||
|
||||
def test_get_client_for_token(self):
|
||||
@@ -92,7 +91,7 @@ class DOTAdapterTestCase(TestCase):
|
||||
user=self.user,
|
||||
application=self.public_client,
|
||||
)
|
||||
self.assertEqual(self.adapter.get_client_for_token(token), self.public_client)
|
||||
assert self.adapter.get_client_for_token(token) == self.public_client
|
||||
|
||||
def test_get_access_token(self):
|
||||
token = self.adapter.create_access_token_for_test(
|
||||
@@ -101,7 +100,7 @@ class DOTAdapterTestCase(TestCase):
|
||||
user=self.user,
|
||||
expires=now() + timedelta(days=30),
|
||||
)
|
||||
self.assertEqual(self.adapter.get_access_token(token_string='token-id'), token)
|
||||
assert self.adapter.get_access_token(token_string='token-id') == token
|
||||
|
||||
def test_get_restricted_access_token(self):
|
||||
"""
|
||||
@@ -116,4 +115,4 @@ class DOTAdapterTestCase(TestCase):
|
||||
)
|
||||
|
||||
readback_token = self.adapter.get_access_token(token_string='expired-token-id')
|
||||
self.assertTrue(RestrictedApplication.verify_access_token_as_expired(readback_token))
|
||||
assert RestrictedApplication.verify_access_token_as_expired(readback_token)
|
||||
|
||||
@@ -42,17 +42,11 @@ class AuthenticateTestCase(TestCase):
|
||||
|
||||
def test_authenticate_with_username(self):
|
||||
user = self.validator._authenticate(username='darkhelmet', password='12345')
|
||||
self.assertEqual(
|
||||
self.user,
|
||||
user
|
||||
)
|
||||
assert self.user == user
|
||||
|
||||
def test_authenticate_with_email(self):
|
||||
user = self.validator._authenticate(username='darkhelmet@spaceball_one.org', password='12345')
|
||||
self.assertEqual(
|
||||
self.user,
|
||||
user
|
||||
)
|
||||
assert self.user == user
|
||||
|
||||
|
||||
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
|
||||
@@ -73,15 +67,15 @@ class CustomValidationTestCase(TestCase):
|
||||
self.request_factory = RequestFactory()
|
||||
|
||||
def test_active_user_validates(self):
|
||||
self.assertTrue(self.user.is_active)
|
||||
assert self.user.is_active
|
||||
request = self.request_factory.get('/')
|
||||
self.assertTrue(self.validator.validate_user('darkhelmet', '12345', client=None, request=request))
|
||||
assert self.validator.validate_user('darkhelmet', '12345', client=None, request=request)
|
||||
|
||||
def test_inactive_user_validates(self):
|
||||
self.user.is_active = False
|
||||
self.user.save()
|
||||
request = self.request_factory.get('/')
|
||||
self.assertTrue(self.validator.validate_user('darkhelmet', '12345', client=None, request=request))
|
||||
assert self.validator.validate_user('darkhelmet', '12345', client=None, request=request)
|
||||
|
||||
|
||||
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
|
||||
@@ -137,11 +131,11 @@ class CustomAuthorizationViewTestCase(TestCase):
|
||||
|
||||
def test_no_reprompting(self):
|
||||
response = self._get_authorize(scope='profile')
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertTrue(response.url.startswith(DUMMY_REDIRECT_URL))
|
||||
assert response.status_code == 302
|
||||
assert response.url.startswith(DUMMY_REDIRECT_URL)
|
||||
|
||||
def test_prompting_with_new_scope(self):
|
||||
response = self._get_authorize(scope='email')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
self.assertContains(response, settings.OAUTH2_PROVIDER['SCOPES']['email'])
|
||||
self.assertNotContains(response, settings.OAUTH2_PROVIDER['SCOPES']['profile'])
|
||||
|
||||
@@ -20,7 +20,7 @@ class TestClientFactory(TestCase):
|
||||
def test_client_factory(self):
|
||||
actual_application = factories.ApplicationFactory(user=self.user)
|
||||
expected_application = Application.objects.get(user=self.user)
|
||||
self.assertEqual(actual_application, expected_application)
|
||||
assert actual_application == expected_application
|
||||
|
||||
|
||||
@unittest.skipUnless(settings.FEATURES.get("ENABLE_OAUTH2_PROVIDER"), "OAuth2 not enabled")
|
||||
@@ -33,7 +33,7 @@ class TestAccessTokenFactory(TestCase):
|
||||
application = factories.ApplicationFactory(user=self.user)
|
||||
actual_access_token = factories.AccessTokenFactory(user=self.user, application=application)
|
||||
expected_access_token = AccessToken.objects.get(user=self.user)
|
||||
self.assertEqual(actual_access_token, expected_access_token)
|
||||
assert actual_access_token == expected_access_token
|
||||
|
||||
|
||||
@unittest.skipUnless(settings.FEATURES.get("ENABLE_OAUTH2_PROVIDER"), "OAuth2 not enabled")
|
||||
@@ -49,4 +49,4 @@ class TestRefreshTokenFactory(TestCase):
|
||||
user=self.user, application=application, access_token=access_token
|
||||
)
|
||||
expected_refresh_token = RefreshToken.objects.get(user=self.user, access_token=access_token)
|
||||
self.assertEqual(actual_refresh_token, expected_refresh_token)
|
||||
assert actual_refresh_token == expected_refresh_token
|
||||
|
||||
@@ -93,8 +93,8 @@ class TestCreateJWTs(AccessTokenMixin, TestCase):
|
||||
jwt_token, self.user, self.default_scopes, aud=aud, secret=secret,
|
||||
)
|
||||
self.assertDictContainsSubset(additional_claims, token_payload)
|
||||
self.assertEqual(user_email_verified, token_payload['email_verified'])
|
||||
self.assertEqual(token_payload['roles'], mock_create_roles.return_value)
|
||||
assert user_email_verified == token_payload['email_verified']
|
||||
assert token_payload['roles'] == mock_create_roles.return_value
|
||||
|
||||
def test_scopes(self):
|
||||
"""
|
||||
@@ -115,6 +115,6 @@ class TestCreateJWTs(AccessTokenMixin, TestCase):
|
||||
jwt_scopes_payload = self.assert_valid_jwt_access_token(
|
||||
jwt_scopes, self.user, scopes, aud=aud, secret=secret,
|
||||
)
|
||||
self.assertEqual(jwt_payload['scopes'], self.default_scopes)
|
||||
self.assertEqual(jwt_scopes_payload['scopes'], scopes)
|
||||
self.assertEqual(jwt_scopes_payload['user_id'], self.user.id)
|
||||
assert jwt_payload['scopes'] == self.default_scopes
|
||||
assert jwt_scopes_payload['scopes'] == scopes
|
||||
assert jwt_scopes_payload['user_id'] == self.user.id
|
||||
|
||||
@@ -29,7 +29,5 @@ class ApplicationModelScopesTestCase(TestCase):
|
||||
""" Verify the settings backend returns the expected available scopes. """
|
||||
application_access = ApplicationAccessFactory(scopes=application_scopes)
|
||||
scopes = ApplicationModelScopes()
|
||||
self.assertEqual(
|
||||
set(scopes.get_available_scopes(application_access.application)),
|
||||
set(list(settings.OAUTH2_DEFAULT_SCOPES.keys()) + expected_additional_scopes),
|
||||
)
|
||||
assert set(scopes.get_available_scopes(application_access.application)) == \
|
||||
set((list(settings.OAUTH2_DEFAULT_SCOPES.keys()) + expected_additional_scopes))
|
||||
|
||||
@@ -66,13 +66,13 @@ class AccessTokenLoginMixin(object):
|
||||
"""
|
||||
Asserts that oauth assigned access_token is valid and usable
|
||||
"""
|
||||
self.assertEqual(self.login_with_access_token(access_token=access_token).status_code, 204)
|
||||
assert self.login_with_access_token(access_token=access_token).status_code == 204
|
||||
|
||||
def _assert_access_token_invalidated(self, access_token=None):
|
||||
"""
|
||||
Asserts that oauth assigned access_token is not valid
|
||||
"""
|
||||
self.assertEqual(self.login_with_access_token(access_token=access_token).status_code, 401)
|
||||
assert self.login_with_access_token(access_token=access_token).status_code == 401
|
||||
|
||||
|
||||
@unittest.skipUnless(OAUTH_PROVIDER_ENABLED, 'OAuth2 not enabled')
|
||||
@@ -177,10 +177,10 @@ class TestAccessTokenView(AccessTokenLoginMixin, mixins.AccessTokenMixin, _Dispa
|
||||
"""
|
||||
client = getattr(self, client_attr)
|
||||
response = self._post_request(self.user, client, token_type=token_type, headers=headers or {})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.content.decode('utf-8'))
|
||||
self.assertIn('expires_in', data)
|
||||
self.assertEqual(data['token_type'], 'JWT')
|
||||
assert 'expires_in' in data
|
||||
assert data['token_type'] == 'JWT'
|
||||
self.assert_valid_jwt_access_token(
|
||||
data['access_token'],
|
||||
self.user,
|
||||
@@ -192,29 +192,26 @@ class TestAccessTokenView(AccessTokenLoginMixin, mixins.AccessTokenMixin, _Dispa
|
||||
def test_access_token_fields(self, client_attr):
|
||||
client = getattr(self, client_attr)
|
||||
response = self._post_request(self.user, client)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.content.decode('utf-8'))
|
||||
self.assertIn('access_token', data)
|
||||
self.assertIn('expires_in', data)
|
||||
self.assertIn('scope', data)
|
||||
self.assertIn('token_type', data)
|
||||
assert 'access_token' in data
|
||||
assert 'expires_in' in data
|
||||
assert 'scope' in data
|
||||
assert 'token_type' in data
|
||||
|
||||
def test_restricted_non_jwt_access_token_fields(self):
|
||||
response = self._post_request(self.user, self.restricted_dot_app)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.content.decode('utf-8'))
|
||||
self.assertIn('access_token', data)
|
||||
self.assertIn('expires_in', data)
|
||||
self.assertIn('scope', data)
|
||||
self.assertIn('token_type', data)
|
||||
assert 'access_token' in data
|
||||
assert 'expires_in' in data
|
||||
assert 'scope' in data
|
||||
assert 'token_type' in data
|
||||
|
||||
# Verify token expiration.
|
||||
self.assertEqual(data['expires_in'] < 0, True)
|
||||
assert (data['expires_in'] < 0) is True
|
||||
access_token = dot_models.AccessToken.objects.get(token=data['access_token'])
|
||||
self.assertEqual(
|
||||
models.RestrictedApplication.verify_access_token_as_expired(access_token),
|
||||
True
|
||||
)
|
||||
assert models.RestrictedApplication.verify_access_token_as_expired(access_token) is True
|
||||
|
||||
@ddt.data('dot_app')
|
||||
def test_jwt_access_token_from_parameter(self, client_attr):
|
||||
@@ -236,7 +233,7 @@ class TestAccessTokenView(AccessTokenLoginMixin, mixins.AccessTokenMixin, _Dispa
|
||||
@patch('edx_django_utils.monitoring.set_custom_attribute')
|
||||
def test_access_token_attributes(self, token_type, expected_token_type, mock_set_custom_attribute):
|
||||
response = self._post_request(self.user, self.dot_app, token_type=token_type)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
expected_calls = [
|
||||
call('oauth_token_type', expected_token_type),
|
||||
call('oauth_grant_type', 'password'),
|
||||
@@ -250,7 +247,7 @@ class TestAccessTokenView(AccessTokenLoginMixin, mixins.AccessTokenMixin, _Dispa
|
||||
'grant_type': grant_type.replace('-', '_'),
|
||||
}
|
||||
bad_response = self.client.post(self.url, invalid_body)
|
||||
self.assertEqual(bad_response.status_code, 401)
|
||||
assert bad_response.status_code == 401
|
||||
expected_calls = [
|
||||
call('oauth_token_type', 'no_token_type_supplied'),
|
||||
call('oauth_grant_type', 'password'),
|
||||
@@ -262,12 +259,12 @@ class TestAccessTokenView(AccessTokenLoginMixin, mixins.AccessTokenMixin, _Dispa
|
||||
Verify that we get a restricted JWT that is not expired.
|
||||
"""
|
||||
response = self._post_request(self.user, self.restricted_dot_app, token_type='jwt')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.content.decode('utf-8'))
|
||||
|
||||
self.assertIn('expires_in', data)
|
||||
assert 'expires_in' in data
|
||||
assert data['expires_in'] > 0
|
||||
self.assertEqual(data['token_type'], 'JWT')
|
||||
assert data['token_type'] == 'JWT'
|
||||
self.assert_valid_jwt_access_token(
|
||||
data['access_token'],
|
||||
self.user,
|
||||
@@ -284,14 +281,14 @@ class TestAccessTokenView(AccessTokenLoginMixin, mixins.AccessTokenMixin, _Dispa
|
||||
"""
|
||||
|
||||
response = self._post_request(self.user, self.restricted_dot_app)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.content.decode('utf-8'))
|
||||
|
||||
self.assertIn('expires_in', data)
|
||||
self.assertIn('access_token', data)
|
||||
assert 'expires_in' in data
|
||||
assert 'access_token' in data
|
||||
|
||||
# the payload should indicate that the token is expired
|
||||
self.assertLess(data['expires_in'], 0)
|
||||
assert data['expires_in'] < 0
|
||||
|
||||
# try submitting this expired access_token to an API,
|
||||
# and assert that it fails
|
||||
@@ -299,9 +296,9 @@ class TestAccessTokenView(AccessTokenLoginMixin, mixins.AccessTokenMixin, _Dispa
|
||||
|
||||
def test_dot_access_token_provides_refresh_token(self):
|
||||
response = self._post_request(self.user, self.dot_app)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.content.decode('utf-8'))
|
||||
self.assertIn('refresh_token', data)
|
||||
assert 'refresh_token' in data
|
||||
|
||||
@ddt.data(dot_models.Application.GRANT_CLIENT_CREDENTIALS, dot_models.Application.GRANT_PASSWORD)
|
||||
def test_jwt_access_token_scopes_and_filters(self, grant_type):
|
||||
@@ -325,7 +322,7 @@ class TestAccessTokenView(AccessTokenLoginMixin, mixins.AccessTokenMixin, _Dispa
|
||||
assert 'test:filter' in filters
|
||||
|
||||
response = self._post_request(self.user, dot_app, token_type='jwt', scope=scopes)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.content.decode('utf-8'))
|
||||
self.assert_valid_jwt_access_token(
|
||||
data['access_token'],
|
||||
@@ -359,7 +356,7 @@ class TestAccessTokenExchangeView(ThirdPartyOAuthTestMixinGoogle, ThirdPartyOAut
|
||||
self.oauth_client = client
|
||||
self._setup_provider_response(success=True)
|
||||
response = self._post_request(self.user, client)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# pylint: disable=abstract-method
|
||||
@@ -470,7 +467,8 @@ class TestAuthorizationView(_DispatchingViewTestCase):
|
||||
Check that django-oauth-toolkit gives an appropriate authorization response.
|
||||
"""
|
||||
# django-oauth-toolkit tries to redirect to the user's redirect URL
|
||||
self.assertEqual(response.status_code, 404) # We used a non-existent redirect url.
|
||||
assert response.status_code == 404
|
||||
# We used a non-existent redirect url.
|
||||
expected_redirect_prefix = u'{}?'.format(DUMMY_REDIRECT_URL)
|
||||
self._assert_startswith(self._redirect_destination(response), expected_redirect_prefix)
|
||||
|
||||
@@ -478,7 +476,7 @@ class TestAuthorizationView(_DispatchingViewTestCase):
|
||||
"""
|
||||
Assert that the string starts with the specified prefix.
|
||||
"""
|
||||
self.assertTrue(string.startswith(prefix), u'{} does not start with {}'.format(string, prefix))
|
||||
assert string.startswith(prefix), u'{} does not start with {}'.format(string, prefix)
|
||||
|
||||
@staticmethod
|
||||
def _redirect_destination(response):
|
||||
@@ -516,10 +514,10 @@ class TestViewDispatch(TestCase):
|
||||
_msg_base = u'{view} is not a view: {reason}'
|
||||
msg_not_callable = _msg_base.format(view=view_candidate, reason=u'it is not callable')
|
||||
msg_no_request = _msg_base.format(view=view_candidate, reason=u'it has no request argument')
|
||||
self.assertTrue(hasattr(view_candidate, '__call__'), msg_not_callable)
|
||||
assert hasattr(view_candidate, '__call__'), msg_not_callable
|
||||
args = view_candidate.__code__.co_varnames
|
||||
self.assertTrue(args, msg_no_request)
|
||||
self.assertEqual(args[0], 'request')
|
||||
assert args, msg_no_request
|
||||
assert args[0] == 'request'
|
||||
|
||||
def _post_request(self, client_id):
|
||||
"""
|
||||
@@ -535,19 +533,19 @@ class TestViewDispatch(TestCase):
|
||||
|
||||
def test_dispatching_post_to_dot(self):
|
||||
request = self._post_request('dot-id')
|
||||
self.assertEqual(self.view.select_backend(request), self.dot_adapter.backend)
|
||||
assert self.view.select_backend(request) == self.dot_adapter.backend
|
||||
|
||||
def test_dispatching_get_to_dot(self):
|
||||
request = self._get_request('dot-id')
|
||||
self.assertEqual(self.view.select_backend(request), self.dot_adapter.backend)
|
||||
assert self.view.select_backend(request) == self.dot_adapter.backend
|
||||
|
||||
def test_dispatching_with_no_client(self):
|
||||
request = self._post_request('')
|
||||
self.assertEqual(self.view.select_backend(request), self.dot_adapter.backend)
|
||||
assert self.view.select_backend(request) == self.dot_adapter.backend
|
||||
|
||||
def test_dispatching_with_invalid_client(self):
|
||||
request = self._post_request('abcesdfljh')
|
||||
self.assertEqual(self.view.select_backend(request), self.dot_adapter.backend)
|
||||
assert self.view.select_backend(request) == self.dot_adapter.backend
|
||||
|
||||
def test_get_view_for_dot(self):
|
||||
view_object = views.AccessTokenView()
|
||||
@@ -613,14 +611,14 @@ class TestRevokeTokenView(AccessTokenLoginMixin, _DispatchingViewTestCase): # p
|
||||
self.access_token_url,
|
||||
self.access_token_post_body_with_refresh_token(refresh_token)
|
||||
)
|
||||
self.assertEqual(response.status_code, expected_status_code)
|
||||
assert response.status_code == expected_status_code
|
||||
|
||||
def revoke_token(self, token):
|
||||
"""
|
||||
Revokes the passed access or refresh token
|
||||
"""
|
||||
response = self.client.post(self.revoke_token_url, self.revoke_token_post_body(token))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_revoke_refresh_token_dot(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user