Files
edx-platform/common/djangoapps/auth_exchange/tests/test_views.py
2015-05-27 07:35:31 -04:00

156 lines
6.0 KiB
Python

# pylint: disable=no-member
"""
Tests for OAuth token exchange views
"""
from datetime import timedelta
import json
import mock
import unittest
from django.conf import settings
from django.core.urlresolvers import reverse
from django.test import TestCase
import httpretty
import provider.constants
from provider import scope
from provider.oauth2.models import AccessToken, Client
from auth_exchange.tests.utils import AccessTokenExchangeTestMixin
from student.tests.factories import UserFactory
from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle
class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
"""
Mixin that defines test cases for AccessTokenExchangeView
"""
def setUp(self):
super(AccessTokenExchangeViewTest, self).setUp()
self.url = reverse("exchange_access_token", kwargs={"backend": self.BACKEND})
def _assert_error(self, data, expected_error, expected_error_description):
response = self.client.post(self.url, data)
self.assertEqual(response.status_code, 400)
self.assertEqual(response["Content-Type"], "application/json")
self.assertEqual(
json.loads(response.content),
{"error": expected_error, "error_description": expected_error_description}
)
self.assertNotIn("partial_pipeline", self.client.session)
def _assert_success(self, data, expected_scopes):
response = self.client.post(self.url, data)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "application/json")
content = json.loads(response.content)
self.assertEqual(set(content.keys()), {"access_token", "token_type", "expires_in", "scope"})
self.assertEqual(content["token_type"], "Bearer")
self.assertLessEqual(
timedelta(seconds=int(content["expires_in"])),
provider.constants.EXPIRE_DELTA_PUBLIC
)
self.assertEqual(content["scope"], " ".join(expected_scopes))
token = AccessToken.objects.get(token=content["access_token"])
self.assertEqual(token.user, self.user)
self.assertEqual(token.client, self.oauth_client)
self.assertEqual(scope.to_names(token.scope), expected_scopes)
def test_single_access_token(self):
def extract_token(response):
"""
Returns the access token from the response payload.
"""
return json.loads(response.content)["access_token"]
self._setup_provider_response(success=True)
for single_access_token in [True, False]:
with mock.patch(
"auth_exchange.views.constants.SINGLE_ACCESS_TOKEN",
single_access_token
):
first_response = self.client.post(self.url, self.data)
second_response = self.client.post(self.url, self.data)
self.assertEqual(
extract_token(first_response) == extract_token(second_response),
single_access_token
)
def test_get_method(self):
response = self.client.get(self.url, self.data)
self.assertEqual(response.status_code, 400)
self.assertEqual(
json.loads(response.content),
{
"error": "invalid_request",
"error_description": "Only POST requests allowed.",
}
)
def test_invalid_provider(self):
url = reverse("exchange_access_token", kwargs={"backend": "invalid"})
response = self.client.post(url, self.data)
self.assertEqual(response.status_code, 404)
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeViewTestFacebook(
AccessTokenExchangeViewTest,
ThirdPartyOAuthTestMixinFacebook,
TestCase
):
"""
Tests for AccessTokenExchangeView used with Facebook
"""
pass
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeViewTestGoogle(
AccessTokenExchangeViewTest,
ThirdPartyOAuthTestMixinGoogle,
TestCase
):
"""
Tests for AccessTokenExchangeView used with Google
"""
pass
@unittest.skipUnless(settings.FEATURES.get("ENABLE_OAUTH2_PROVIDER"), "OAuth2 not enabled")
class TestLoginWithAccessTokenView(TestCase):
"""
Tests for LoginWithAccessTokenView
"""
def setUp(self):
super(TestLoginWithAccessTokenView, self).setUp()
self.user = UserFactory()
self.oauth2_client = Client.objects.create(client_type=provider.constants.CONFIDENTIAL)
def _verify_response(self, access_token, expected_status_code, expected_num_cookies):
"""
Calls the login_with_access_token endpoint and verifies the response given the expected values.
"""
url = reverse("login_with_access_token")
response = self.client.post(url, HTTP_AUTHORIZATION="Bearer {0}".format(access_token))
self.assertEqual(response.status_code, expected_status_code)
self.assertEqual(len(response.cookies), expected_num_cookies)
def test_success(self):
access_token = AccessToken.objects.create(
token="test_access_token",
client=self.oauth2_client,
user=self.user,
)
self._verify_response(access_token, expected_status_code=204, expected_num_cookies=1)
self.assertEqual(len(self.client.cookies), 1)
self.assertEqual(self.client.session['_auth_user_id'], self.user.id)
def test_unauthenticated(self):
self._verify_response("invalid_token", expected_status_code=401, expected_num_cookies=0)
self.assertEqual(len(self.client.cookies), 0)
self.assertNotIn("session_key", self.client.session)