156 lines
6.0 KiB
Python
156 lines
6.0 KiB
Python
# pylint: disable=no-member
|
|
"""
|
|
Tests for OAuth token exchange views
|
|
"""
|
|
from datetime import timedelta
|
|
import json
|
|
import mock
|
|
import unittest
|
|
|
|
from django.conf import settings
|
|
from django.core.urlresolvers import reverse
|
|
from django.test import TestCase
|
|
import httpretty
|
|
import provider.constants
|
|
from provider import scope
|
|
from provider.oauth2.models import AccessToken, Client
|
|
|
|
from auth_exchange.tests.utils import AccessTokenExchangeTestMixin
|
|
from student.tests.factories import UserFactory
|
|
from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle
|
|
|
|
|
|
class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
    """
    Shared test cases for AccessTokenExchangeView.

    This mixin is combined with a backend-specific third party OAuth mixin
    and ``TestCase`` by the concrete test classes in this module.
    """
    def setUp(self):
        """
        Resolve the exchange URL for the backend under test.
        """
        super(AccessTokenExchangeViewTest, self).setUp()
        self.url = reverse("exchange_access_token", kwargs={"backend": self.BACKEND})

    def _assert_error(self, data, expected_error, expected_error_description):
        """
        POST ``data`` to the exchange URL and verify that a 400 JSON error
        with the expected payload is returned and that no partial pipeline
        is present in the session.
        """
        response = self.client.post(self.url, data)
        self.assertEqual(response.status_code, 400)
        self.assertEqual(response["Content-Type"], "application/json")
        expected_payload = {
            "error": expected_error,
            "error_description": expected_error_description,
        }
        self.assertEqual(json.loads(response.content), expected_payload)
        self.assertNotIn("partial_pipeline", self.client.session)

    def _assert_success(self, data, expected_scopes):
        """
        POST ``data`` to the exchange URL and verify the 200 JSON response:
        its keys, token type, lifetime (bounded by EXPIRE_DELTA_PUBLIC) and
        scope, plus the user, client and scope of the stored AccessToken.
        """
        response = self.client.post(self.url, data)
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response["Content-Type"], "application/json")

        payload = json.loads(response.content)
        expected_keys = {"access_token", "token_type", "expires_in", "scope"}
        self.assertEqual(set(payload.keys()), expected_keys)
        self.assertEqual(payload["token_type"], "Bearer")

        lifetime = timedelta(seconds=int(payload["expires_in"]))
        self.assertLessEqual(lifetime, provider.constants.EXPIRE_DELTA_PUBLIC)
        self.assertEqual(payload["scope"], " ".join(expected_scopes))

        # The returned token must match a stored token for this user and client.
        stored_token = AccessToken.objects.get(token=payload["access_token"])
        self.assertEqual(stored_token.user, self.user)
        self.assertEqual(stored_token.client, self.oauth_client)
        self.assertEqual(scope.to_names(stored_token.scope), expected_scopes)

    def test_single_access_token(self):
        """
        Two consecutive exchanges return the same token exactly when
        SINGLE_ACCESS_TOKEN is patched to True.
        """
        def extract_token(response):
            """
            Returns the access token from the response payload.
            """
            return json.loads(response.content)["access_token"]

        self._setup_provider_response(success=True)
        for single_access_token in [True, False]:
            setting_patch = mock.patch(
                "auth_exchange.views.constants.SINGLE_ACCESS_TOKEN",
                single_access_token
            )
            with setting_patch:
                first_response = self.client.post(self.url, self.data)
                second_response = self.client.post(self.url, self.data)
            tokens_match = extract_token(first_response) == extract_token(second_response)
            self.assertEqual(tokens_match, single_access_token)

    def test_get_method(self):
        """
        GET requests are rejected with a 400 "invalid_request" error.
        """
        response = self.client.get(self.url, self.data)
        self.assertEqual(response.status_code, 400)
        expected_payload = {
            "error": "invalid_request",
            "error_description": "Only POST requests allowed.",
        }
        self.assertEqual(json.loads(response.content), expected_payload)

    def test_invalid_provider(self):
        """
        Posting to an unknown backend name yields a 404.
        """
        invalid_url = reverse("exchange_access_token", kwargs={"backend": "invalid"})
        response = self.client.post(invalid_url, self.data)
        self.assertEqual(response.status_code, 404)
|
|
|
|
|
|
# Skipped unless third party auth is enabled, because cms does not implement it
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeViewTestFacebook(AccessTokenExchangeViewTest, ThirdPartyOAuthTestMixinFacebook, TestCase):
    """
    Runs the AccessTokenExchangeView test cases against the Facebook backend
    """
|
|
|
|
|
|
# Skipped unless third party auth is enabled, because cms does not implement it
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeViewTestGoogle(AccessTokenExchangeViewTest, ThirdPartyOAuthTestMixinGoogle, TestCase):
    """
    Runs the AccessTokenExchangeView test cases against the Google backend
    """
|
|
|
|
|
|
@unittest.skipUnless(settings.FEATURES.get("ENABLE_OAUTH2_PROVIDER"), "OAuth2 not enabled")
class TestLoginWithAccessTokenView(TestCase):
    """
    Test cases for LoginWithAccessTokenView
    """
    def setUp(self):
        """
        Create the user and the confidential OAuth2 client used by the tests.
        """
        super(TestLoginWithAccessTokenView, self).setUp()
        self.user = UserFactory()
        self.oauth2_client = Client.objects.create(client_type=provider.constants.CONFIDENTIAL)

    def _verify_response(self, access_token, expected_status_code, expected_num_cookies):
        """
        POSTs to the login_with_access_token endpoint with ``access_token`` as a
        Bearer credential and checks the status code and number of response cookies.
        """
        auth_header = "Bearer {0}".format(access_token)
        response = self.client.post(
            reverse("login_with_access_token"),
            HTTP_AUTHORIZATION=auth_header,
        )
        self.assertEqual(response.status_code, expected_status_code)
        self.assertEqual(len(response.cookies), expected_num_cookies)

    def test_success(self):
        """
        A valid access token returns 204, sets one cookie and logs the user in.
        """
        token_fields = {
            "token": "test_access_token",
            "client": self.oauth2_client,
            "user": self.user,
        }
        access_token = AccessToken.objects.create(**token_fields)
        self._verify_response(access_token, expected_status_code=204, expected_num_cookies=1)
        self.assertEqual(len(self.client.cookies), 1)
        self.assertEqual(self.client.session["_auth_user_id"], self.user.id)

    def test_unauthenticated(self):
        """
        An unknown access token returns 401 and sets no cookies.
        """
        self._verify_response("invalid_token", expected_status_code=401, expected_num_cookies=0)
        self.assertEqual(len(self.client.cookies), 0)
        self.assertNotIn("session_key", self.client.session)
|