Merge pull request #18237 from open-craft/cliff/sso-discovery-upstream
Add endpoint to get SAML providers for a user.
This commit is contained in:
@@ -4,15 +4,16 @@ Tests for the Third Party Auth REST API
|
||||
import unittest
|
||||
|
||||
import ddt
|
||||
from django.urls import reverse
|
||||
import six
|
||||
from django.conf import settings
|
||||
from django.http import QueryDict
|
||||
from django.test.utils import override_settings
|
||||
from django.urls import reverse
|
||||
from mock import patch
|
||||
from provider.constants import CONFIDENTIAL
|
||||
from provider.oauth2.models import Client, AccessToken
|
||||
from openedx.core.lib.api.permissions import ApiKeyHeaderPermission
|
||||
from rest_framework.test import APITestCase
|
||||
from django.conf import settings
|
||||
from django.test.utils import override_settings
|
||||
from social_django.models import UserSocialAuth
|
||||
|
||||
from student.tests.factories import UserFactory
|
||||
@@ -29,6 +30,7 @@ ALICE_USERNAME = "alice"
|
||||
CARL_USERNAME = "carl"
|
||||
STAFF_USERNAME = "staff"
|
||||
ADMIN_USERNAME = "admin"
|
||||
NONEXISTENT_USERNAME = "nobody"
|
||||
# These users will be created and linked to third party accounts:
|
||||
LINKED_USERS = (ALICE_USERNAME, STAFF_USERNAME, ADMIN_USERNAME)
|
||||
PASSWORD = "edx"
|
||||
@@ -62,9 +64,10 @@ class TpaAPITestCase(ThirdPartyAuthTestMixin, APITestCase):
|
||||
make_staff = (username == STAFF_USERNAME) or make_superuser
|
||||
user = UserFactory.create(
|
||||
username=username,
|
||||
email='{}@example.com'.format(username),
|
||||
password=PASSWORD,
|
||||
is_staff=make_staff,
|
||||
is_superuser=make_superuser
|
||||
is_superuser=make_superuser,
|
||||
)
|
||||
UserSocialAuth.objects.create(
|
||||
user=user,
|
||||
@@ -77,15 +80,13 @@ class TpaAPITestCase(ThirdPartyAuthTestMixin, APITestCase):
|
||||
uid='{}:remote_{}'.format(testshib.slug, username),
|
||||
)
|
||||
# Create another user not linked to any providers:
|
||||
UserFactory.create(username=CARL_USERNAME, password=PASSWORD)
|
||||
UserFactory.create(username=CARL_USERNAME, email='{}@example.com'.format(CARL_USERNAME), password=PASSWORD)
|
||||
|
||||
|
||||
@override_settings(EDX_API_KEY=VALID_API_KEY)
|
||||
@ddt.ddt
|
||||
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
|
||||
class UserViewAPITests(TpaAPITestCase):
|
||||
class UserViewsMixin(object):
|
||||
"""
|
||||
Test the Third Party Auth User REST API
|
||||
Generic TestCase to exercise the v1 and v2 UserViews.
|
||||
"""
|
||||
|
||||
def expected_active(self, username):
|
||||
@@ -124,7 +125,7 @@ class UserViewAPITests(TpaAPITestCase):
|
||||
@ddt.unpack
|
||||
def test_list_connected_providers(self, request_user, target_user, expect_result):
|
||||
self.client.login(username=request_user, password=PASSWORD)
|
||||
url = reverse('third_party_auth_users_api', kwargs={'username': target_user})
|
||||
url = self.make_url({'username': target_user})
|
||||
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, expect_result)
|
||||
@@ -140,14 +141,87 @@ class UserViewAPITests(TpaAPITestCase):
|
||||
(None, ALICE_USERNAME, 403),
|
||||
)
|
||||
@ddt.unpack
|
||||
def test_list_connected_providers__withapi_key(self, api_key, target_user, expect_result):
|
||||
url = reverse('third_party_auth_users_api', kwargs={'username': target_user})
|
||||
def test_list_connected_providers_with_api_key(self, api_key, target_user, expect_result):
|
||||
url = self.make_url({'username': target_user})
|
||||
response = self.client.get(url, HTTP_X_EDX_API_KEY=api_key)
|
||||
self.assertEqual(response.status_code, expect_result)
|
||||
if expect_result == 200:
|
||||
self.assertIn("active", response.data)
|
||||
self.assertItemsEqual(response.data["active"], self.expected_active(target_user))
|
||||
|
||||
@ddt.data(
|
||||
(True, ALICE_USERNAME, 200, True),
|
||||
(True, CARL_USERNAME, 200, False),
|
||||
(False, ALICE_USERNAME, 200, True),
|
||||
(False, CARL_USERNAME, 403, None),
|
||||
)
|
||||
@ddt.unpack
|
||||
def test_allow_unprivileged_response(self, allow_unprivileged, requesting_user, expect, include_remote_id):
|
||||
self.client.login(username=requesting_user, password=PASSWORD)
|
||||
with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=allow_unprivileged):
|
||||
url = self.make_url({'username': ALICE_USERNAME})
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, expect)
|
||||
if response.status_code == 200:
|
||||
self.assertGreater(len(response.data['active']), 0)
|
||||
for provider_data in response.data['active']:
|
||||
self.assertEqual(include_remote_id, 'remote_id' in provider_data)
|
||||
|
||||
def test_allow_query_by_email(self):
|
||||
self.client.login(username=ALICE_USERNAME, password=PASSWORD)
|
||||
url = self.make_url({'email': '{}@example.com'.format(ALICE_USERNAME)})
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertGreater(len(response.data['active']), 0)
|
||||
|
||||
def test_throttling(self):
|
||||
# Default throttle is 10/min. Make 11 requests to verify
|
||||
throttling_user = UserFactory.create(password=PASSWORD)
|
||||
self.client.login(username=throttling_user.username, password=PASSWORD)
|
||||
url = self.make_url({'username': ALICE_USERNAME})
|
||||
with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=True):
|
||||
for _ in range(10):
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
|
||||
@override_settings(EDX_API_KEY=VALID_API_KEY)
|
||||
@ddt.ddt
|
||||
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
|
||||
class UserViewAPITests(UserViewsMixin, TpaAPITestCase):
|
||||
"""
|
||||
Test the Third Party Auth User REST API
|
||||
"""
|
||||
|
||||
def make_url(self, identifier):
|
||||
"""
|
||||
Return the view URL, with the identifier provided
|
||||
"""
|
||||
return reverse(
|
||||
'third_party_auth_users_api',
|
||||
kwargs={'username': identifier.values()[0]}
|
||||
)
|
||||
|
||||
|
||||
@override_settings(EDX_API_KEY=VALID_API_KEY)
|
||||
@ddt.ddt
|
||||
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
|
||||
class UserViewV2APITests(UserViewsMixin, TpaAPITestCase):
|
||||
"""
|
||||
Test the Third Party Auth User REST API
|
||||
"""
|
||||
|
||||
def make_url(self, identifier):
|
||||
"""
|
||||
Return the view URL, with the identifier provided
|
||||
"""
|
||||
return '?'.join([
|
||||
reverse('third_party_auth_users_api_v2'),
|
||||
six.moves.urllib.parse.urlencode(identifier)
|
||||
])
|
||||
|
||||
|
||||
@override_settings(EDX_API_KEY=VALID_API_KEY)
|
||||
@ddt.ddt
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from django.conf import settings
|
||||
from django.conf.urls import url
|
||||
|
||||
from .views import UserMappingView, UserView
|
||||
from .views import UserMappingView, UserView, UserViewV2
|
||||
|
||||
|
||||
PROVIDER_PATTERN = r'(?P<provider_id>[\w.+-]+)(?:\:(?P<idp_slug>[\w.+-]+))?'
|
||||
@@ -14,6 +14,7 @@ urlpatterns = [
|
||||
UserView.as_view(),
|
||||
name='third_party_auth_users_api',
|
||||
),
|
||||
url(r'^v0/users/', UserViewV2.as_view(), name='third_party_auth_users_api_v2'),
|
||||
url(
|
||||
r'^v0/providers/{provider_pattern}/users$'.format(provider_pattern=PROVIDER_PATTERN),
|
||||
UserMappingView.as_view(),
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""
|
||||
Third Party Auth REST API views
|
||||
"""
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import User
|
||||
from django.db.models import Q
|
||||
from django.http import Http404
|
||||
from edx_rest_framework_extensions.auth.session.authentication import SessionAuthenticationAllowInactiveUser
|
||||
from rest_framework import exceptions, status
|
||||
from rest_framework import exceptions, status, throttling
|
||||
from rest_framework.generics import ListAPIView
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
@@ -20,13 +24,132 @@ from third_party_auth.api.permissions import ThirdPartyAuthProviderApiPermission
|
||||
from third_party_auth.provider import Registry
|
||||
|
||||
|
||||
class UserView(APIView):
|
||||
class ProviderBaseThrottle(throttling.UserRateThrottle):
|
||||
"""
|
||||
Base throttle for provider queries
|
||||
"""
|
||||
|
||||
def allow_request(self, request, view):
|
||||
"""
|
||||
Only throttle unprivileged requests.
|
||||
"""
|
||||
if view.is_unprivileged_query(request, view.get_identifier_for_requested_user(request)):
|
||||
return super(ProviderBaseThrottle, self).allow_request(request, view)
|
||||
return True
|
||||
|
||||
|
||||
class ProviderBurstThrottle(ProviderBaseThrottle):
|
||||
"""
|
||||
Maximum number of provider requests in a quick burst.
|
||||
"""
|
||||
rate = settings.TPA_PROVIDER_BURST_THROTTLE # Default '10/min'
|
||||
|
||||
|
||||
class ProviderSustainedThrottle(ProviderBaseThrottle):
|
||||
"""
|
||||
Maximum number of provider requests over time.
|
||||
"""
|
||||
rate = settings.TPA_PROVIDER_SUSTAINED_THROTTLE # Default '50/day'
|
||||
|
||||
|
||||
class BaseUserView(APIView):
|
||||
"""
|
||||
Common core of UserView and UserViewV2
|
||||
"""
|
||||
identifier = namedtuple('identifier', ['kind', 'value'])
|
||||
identifier_kinds = ['email', 'username']
|
||||
|
||||
authentication_classes = (
|
||||
# Users may want to view/edit the providers used for authentication before they've
|
||||
# activated their account, so we allow inactive users.
|
||||
OAuth2AuthenticationAllowInactiveUser,
|
||||
SessionAuthenticationAllowInactiveUser,
|
||||
)
|
||||
throttle_classes = [ProviderSustainedThrottle, ProviderBurstThrottle]
|
||||
|
||||
def do_get(self, request, identifier):
|
||||
"""
|
||||
Fulfill the request, now that the identifier has been specified.
|
||||
"""
|
||||
is_unprivileged = self.is_unprivileged_query(request, identifier)
|
||||
|
||||
if is_unprivileged:
|
||||
if not getattr(settings, 'ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY', False):
|
||||
return Response(status=status.HTTP_403_FORBIDDEN)
|
||||
try:
|
||||
user = User.objects.get(**{identifier.kind: identifier.value})
|
||||
except User.DoesNotExist:
|
||||
return Response(status=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
providers = pipeline.get_provider_user_states(user)
|
||||
|
||||
active_providers = [
|
||||
self.get_provider_data(assoc, is_unprivileged)
|
||||
for assoc in providers if assoc.has_account
|
||||
]
|
||||
|
||||
# In the future this can be trivially modified to return the inactive/disconnected providers as well.
|
||||
|
||||
return Response({
|
||||
"active": active_providers
|
||||
})
|
||||
|
||||
def get_provider_data(self, assoc, is_unprivileged):
|
||||
"""
|
||||
Return the data for the specified provider.
|
||||
|
||||
If the request is unprivileged, do not return the remote ID of the user.
|
||||
"""
|
||||
provider_data = {
|
||||
"provider_id": assoc.provider.provider_id,
|
||||
"name": assoc.provider.name,
|
||||
}
|
||||
if not is_unprivileged:
|
||||
provider_data["remote_id"] = assoc.remote_id
|
||||
return provider_data
|
||||
|
||||
def is_unprivileged_query(self, request, identifier):
|
||||
"""
|
||||
Return True if a non-superuser requests information about another user.
|
||||
|
||||
Params must be a dict that includes only one of 'username' or 'email'
|
||||
"""
|
||||
if identifier.kind not in self.identifier_kinds:
|
||||
# This is already checked before we get here, so raise a 500 error
|
||||
# if the check fails.
|
||||
raise ValueError("Identifier kind {} not in {}".format(identifier.kind, self.identifier_kinds))
|
||||
|
||||
self_request = False
|
||||
if identifier == self.identifier('username', request.user.username):
|
||||
self_request = True
|
||||
elif identifier.kind == 'email' and getattr(identifier, 'value', object()) == request.user.email:
|
||||
# AnonymousUser does not have an email attribute, so fall back to
|
||||
# something that will never compare equal to the provided email.
|
||||
self_request = True
|
||||
if self_request:
|
||||
# We can always ask for our own provider
|
||||
return False
|
||||
# We are querying permissions for a user other than the current user.
|
||||
if not request.user.is_superuser and not ApiKeyHeaderPermission().has_permission(request, self):
|
||||
# The user does not have elevated permissions.
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class UserView(BaseUserView):
|
||||
"""
|
||||
List the third party auth accounts linked to the specified user account.
|
||||
|
||||
[DEPRECATED]
|
||||
|
||||
This view uses heuristics to guess whether the provided identifier is a
|
||||
username or email address. Instead, use /api/third_party_auth/v0/users/
|
||||
and specify ?username=foo or ?email=foo@exmaple.com.
|
||||
|
||||
**Example Request**
|
||||
|
||||
GET /api/third_party_auth/v0/users/{username}
|
||||
GET /api/third_party_auth/v0/users/{email@example.com}
|
||||
|
||||
**Response Values**
|
||||
|
||||
@@ -45,18 +168,11 @@ class UserView(APIView):
|
||||
is what is used to link the user to their edX account during
|
||||
login.
|
||||
"""
|
||||
authentication_classes = (
|
||||
# Users may want to view/edit the providers used for authentication before they've
|
||||
# activated their account, so we allow inactive users.
|
||||
OAuth2AuthenticationAllowInactiveUser,
|
||||
SessionAuthenticationAllowInactiveUser,
|
||||
)
|
||||
|
||||
def get(self, request, username):
|
||||
"""Create, read, or update enrollment information for a user.
|
||||
"""Read provider information for a user.
|
||||
|
||||
HTTP Endpoint for all CRUD operations for a user course enrollment. Allows creation, reading, and
|
||||
updates of the current enrollment for a particular course.
|
||||
Allows reading the list of providers for a specified user.
|
||||
|
||||
Args:
|
||||
request (Request): The HTTP GET request
|
||||
@@ -66,34 +182,80 @@ class UserView(APIView):
|
||||
JSON serialized list of the providers linked to this user.
|
||||
|
||||
"""
|
||||
if request.user.username != username:
|
||||
# We are querying permissions for a user other than the current user.
|
||||
if not request.user.is_superuser and not ApiKeyHeaderPermission().has_permission(request, self):
|
||||
# Return a 403 (Unauthorized) without validating 'username', so that we
|
||||
# do not let users probe the existence of other user accounts.
|
||||
return Response(status=status.HTTP_403_FORBIDDEN)
|
||||
identifier = self.get_identifier_for_requested_user(request)
|
||||
return self.do_get(request, identifier)
|
||||
|
||||
try:
|
||||
user = User.objects.get(username=username)
|
||||
except User.DoesNotExist:
|
||||
return Response(status=status.HTTP_404_NOT_FOUND)
|
||||
def get_identifier_for_requested_user(self, _request):
|
||||
"""
|
||||
Return an identifier namedtuple for the requested user.
|
||||
"""
|
||||
if u'@' in self.kwargs[u'username']:
|
||||
id_kind = u'email'
|
||||
else:
|
||||
id_kind = u'username'
|
||||
return self.identifier(id_kind, self.kwargs[u'username'])
|
||||
|
||||
providers = pipeline.get_provider_user_states(user)
|
||||
|
||||
active_providers = [
|
||||
{
|
||||
"provider_id": assoc.provider.provider_id,
|
||||
"name": assoc.provider.name,
|
||||
"remote_id": assoc.remote_id,
|
||||
}
|
||||
for assoc in providers if assoc.has_account
|
||||
]
|
||||
# TODO: When removing deprecated UserView, rename this view to UserView.
|
||||
class UserViewV2(BaseUserView):
|
||||
"""
|
||||
List the third party auth accounts linked to the specified user account.
|
||||
|
||||
# In the future this can be trivially modified to return the inactive/disconnected providers as well.
|
||||
**Example Request**
|
||||
|
||||
return Response({
|
||||
"active": active_providers
|
||||
})
|
||||
GET /api/third_party_auth/v0/users/?username={username}
|
||||
GET /api/third_party_auth/v0/users/?email={email@example.com}
|
||||
|
||||
**Response Values**
|
||||
|
||||
If the request for information about the user is successful, an HTTP 200 "OK" response
|
||||
is returned.
|
||||
|
||||
The HTTP 200 response has the following values.
|
||||
|
||||
* active: A list of all the third party auth providers currently linked
|
||||
to the given user's account. Each object in this list has the
|
||||
following attributes:
|
||||
|
||||
* provider_id: The unique identifier of this provider (string)
|
||||
* name: The name of this provider (string)
|
||||
* remote_id: The ID of the user according to the provider. This ID
|
||||
is what is used to link the user to their edX account during
|
||||
login.
|
||||
"""
|
||||
|
||||
def get(self, request):
|
||||
"""
|
||||
Read provider information for a user.
|
||||
|
||||
Allows reading the list of providers for a specified user.
|
||||
|
||||
Args:
|
||||
request (Request): The HTTP GET request
|
||||
|
||||
Request Parameters:
|
||||
Must provide one of 'email' or 'username'. If both are provided,
|
||||
the username will be ignored.
|
||||
|
||||
Return:
|
||||
JSON serialized list of the providers linked to this user.
|
||||
|
||||
"""
|
||||
identifier = self.get_identifier_for_requested_user(request)
|
||||
return self.do_get(request, identifier)
|
||||
|
||||
def get_identifier_for_requested_user(self, request):
|
||||
"""
|
||||
Return an identifier namedtuple for the requested user.
|
||||
"""
|
||||
identifier = None
|
||||
for id_kind in self.identifier_kinds:
|
||||
if id_kind in request.GET:
|
||||
identifier = self.identifier(id_kind, request.GET[id_kind])
|
||||
break
|
||||
if identifier is None:
|
||||
raise exceptions.ValidationError(u"Must provide one of {}".format(self.identifier_kinds))
|
||||
return identifier
|
||||
|
||||
|
||||
class UserMappingView(ListAPIView):
|
||||
@@ -195,7 +357,7 @@ class UserMappingView(ListAPIView):
|
||||
# When using multi-IdP backend, we only retrieve the ones that are for current IdP.
|
||||
# test if the current provider has a slug
|
||||
uid = self.provider.get_social_auth_uid('uid')
|
||||
if uid is not 'uid':
|
||||
if uid != 'uid':
|
||||
# if yes, we add a filter for the slug on uid column
|
||||
query_set = query_set.filter(uid__startswith=uid[:-3])
|
||||
|
||||
@@ -207,13 +369,13 @@ class UserMappingView(ListAPIView):
|
||||
if usernames:
|
||||
usernames = ','.join(usernames)
|
||||
usernames = set(usernames.split(',')) if usernames else set()
|
||||
if len(usernames):
|
||||
if usernames:
|
||||
query = query | Q(user__username__in=usernames)
|
||||
|
||||
if remote_ids:
|
||||
remote_ids = ','.join(remote_ids)
|
||||
remote_ids = set(remote_ids.split(',')) if remote_ids else set()
|
||||
if len(remote_ids):
|
||||
if remote_ids:
|
||||
query = query | Q(uid__in=[self.provider.get_social_auth_uid(remote_id) for remote_id in remote_ids])
|
||||
|
||||
return query_set.filter(query)
|
||||
|
||||
@@ -707,6 +707,9 @@ if FEATURES.get('ENABLE_THIRD_PARTY_AUTH'):
|
||||
# dict with an arbitrary 'secret_key' and a 'url'.
|
||||
THIRD_PARTY_AUTH_CUSTOM_AUTH_FORMS = AUTH_TOKENS.get('THIRD_PARTY_AUTH_CUSTOM_AUTH_FORMS', {})
|
||||
|
||||
# Whether to allow unprivileged users to discover SSO providers for arbitrary usernames.
|
||||
ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY = ENV_TOKENS.get('ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY', False)
|
||||
|
||||
##### OAUTH2 Provider ##############
|
||||
if FEATURES.get('ENABLE_OAUTH2_PROVIDER'):
|
||||
OAUTH_OIDC_ISSUER = ENV_TOKENS['OAUTH_OIDC_ISSUER']
|
||||
@@ -722,6 +725,10 @@ if FEATURES.get('ENABLE_OAUTH2_PROVIDER'):
|
||||
OAUTH_ID_TOKEN_EXPIRATION = ENV_TOKENS.get('OAUTH_ID_TOKEN_EXPIRATION', OAUTH_ID_TOKEN_EXPIRATION)
|
||||
OAUTH_DELETE_EXPIRED = ENV_TOKENS.get('OAUTH_DELETE_EXPIRED', OAUTH_DELETE_EXPIRED)
|
||||
|
||||
##### THIRD_PARTY_AUTH #############
|
||||
TPA_PROVIDER_BURST_THROTTLE = ENV_TOKENS.get('TPA_PROVIDER_BURST_THROTTLE', TPA_PROVIDER_BURST_THROTTLE)
|
||||
TPA_PROVIDER_SUSTAINED_THROTTLE = ENV_TOKENS.get('TPA_PROVIDER_SUSTAINED_THROTTLE', TPA_PROVIDER_SUSTAINED_THROTTLE)
|
||||
|
||||
##### ADVANCED_SECURITY_CONFIG #####
|
||||
ADVANCED_SECURITY_CONFIG = ENV_TOKENS.get('ADVANCED_SECURITY_CONFIG', {})
|
||||
|
||||
|
||||
@@ -521,6 +521,10 @@ OAUTH2_PROVIDER_APPLICATION_MODEL = 'oauth2_provider.Application'
|
||||
OAUTH_DELETE_EXPIRED = True
|
||||
OAUTH_ID_TOKEN_EXPIRATION = 60 * 60
|
||||
|
||||
################################## THIRD_PARTY_AUTH CONFIGURATION #############################
|
||||
TPA_PROVIDER_BURST_THROTTLE = '10/min'
|
||||
TPA_PROVIDER_SUSTAINED_THROTTLE = '50/hr'
|
||||
|
||||
################################## TEMPLATE CONFIGURATION #####################################
|
||||
# Mako templating
|
||||
import tempfile
|
||||
|
||||
Reference in New Issue
Block a user