SAML Configuration API endpoint + Public flag

This commit is contained in:
Talia
2020-09-04 14:57:47 -04:00
parent a06f4dc09d
commit 9cbd1907ea
11 changed files with 251 additions and 4 deletions

View File

@@ -0,0 +1,18 @@
# Generated by Django 2.2.15 on 2020-09-02 15:22
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('third_party_auth', '0002_auto_20200721_1650'),
]
operations = [
migrations.AddField(
model_name='samlconfiguration',
name='is_public',
field=models.BooleanField(default=False, help_text='When checked, customers will be able to choose this SAML Configuration in the admin portal.', verbose_name='Allow customers to see and use this SAML configuration'),
),
]

View File

@@ -457,6 +457,14 @@ class SAMLConfiguration(ConfigurationModel):
"Valid keys that can be set here include: SECURITY_CONFIG and SP_EXTRA"
),
)
is_public = models.BooleanField(
default=False,
verbose_name=u"Allow customers to see and use this SAML configuration",
help_text=(
u"When checked, customers will be able to choose this SAML Configuration "
"in the admin portal."
),
)
class Meta(object):
app_label = "third_party_auth"

View File

@@ -0,0 +1,13 @@
"""
Serializer for SAMLConfiguration
"""
from rest_framework import serializers
from third_party_auth.models import SAMLConfiguration
class SAMLConfigurationSerializer(serializers.ModelSerializer):
class Meta:
model = SAMLConfiguration
fields = ('id', 'slug')

View File

@@ -0,0 +1,112 @@
"""
Tests for SAMLConfiguration endpoints
"""
import unittest
from django.urls import reverse
from django.contrib.sites.models import Site
from django.contrib.auth.models import User
from rest_framework import status
from rest_framework.test import APITestCase
from third_party_auth.models import SAMLConfiguration
from third_party_auth.tests import testutil
SAML_CONFIGURATIONS = [
{
'site': 1,
'slug': 'testing',
'private_key': 'TestingKey',
'public_key': 'TestingKey',
'entity_id': 'example.com',
'is_public': True,
},
{
'site': 2,
'slug': 'testing2',
'private_key': 'TestingKey2',
'public_key': 'TestingKey2',
'entity_id': 'edx.example.com',
'is_public': True,
},
]
PRIV_CONFIGURATIONS = [
{
'site': 1,
'slug': 'testing3',
'private_key': 'TestingKey',
'public_key': 'TestingKey',
'entity_id': 'example.com',
'is_public': False,
},
]
TEST_PASSWORD = 'testpwd'
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
class SAMLConfigurationTests(APITestCase):
"""
API Tests for SAMLConfiguration objects retrieval.
The skip annotation above exists because we currently cannot run this test in
the cms mode in CI builds, where the third_party_auth application is not loaded
"""
@classmethod
def setUpTestData(cls):
super(SAMLConfigurationTests, cls).setUpTestData()
cls.user = User.objects.create_user(username='testuser', password=TEST_PASSWORD)
cls.site, _ = Site.objects.get_or_create(domain='example.com')
for config in SAML_CONFIGURATIONS:
cls.samlconfiguration = SAMLConfiguration.objects.get_or_create(
site=cls.site,
slug=config['slug'],
private_key=config['private_key'],
public_key=config['public_key'],
entity_id=config['entity_id'],
is_public=config['is_public']
)
for config in PRIV_CONFIGURATIONS:
cls.samlconfiguration = SAMLConfiguration.objects.get_or_create(
site=cls.site,
slug=config['slug'],
private_key=config['private_key'],
public_key=config['public_key'],
entity_id=config['entity_id'],
is_public=config['is_public']
)
def setUp(self):
super(SAMLConfigurationTests, self).setUp()
self.client.login(username=self.user.username, password=TEST_PASSWORD)
def test_get_saml_configurations_successful(self):
url = reverse('saml_configuration-list')
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
# We ultimately just need ids and slugs, so let's just check those.
results = response.data['results']
self.assertEqual(results[0]['id'], SAML_CONFIGURATIONS[0]['site'])
self.assertEqual(results[0]['slug'], SAML_CONFIGURATIONS[0]['slug'])
self.assertEqual(results[1]['id'], SAML_CONFIGURATIONS[1]['site'])
self.assertEqual(results[1]['slug'], SAML_CONFIGURATIONS[1]['slug'])
def test_get_saml_configurations_noprivate(self):
# Verify we have 3 saml configuration objects: 2 public, 1 private.
total_object_count = SAMLConfiguration.objects.count()
self.assertEqual(total_object_count, 3)
url = reverse('saml_configuration-list')
response = self.client.get(url, format='json')
# We should only see 2 results, since 1 out of 3 are private
# and our queryset only returns public configurations.
results = response.data['results']
self.assertEqual(len(results), 2)
def test_unauthenticated_user_get_saml_configurations(self):
self.client.logout()
url = reverse('saml_configuration-list')
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

View File

@@ -0,0 +1,11 @@
"""
Viewset for auth/saml/v0/samlconfiguration/
"""
from rest_framework import routers
from .views import SAMLConfigurationViewSet
saml_configuration_router = routers.DefaultRouter()
saml_configuration_router.register(r'saml_configuration', SAMLConfigurationViewSet, basename="saml_configuration")
urlpatterns = saml_configuration_router.urls

View File

@@ -0,0 +1,31 @@
"""
Viewset for auth/saml/v0/saml_configuration
"""
from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication
from rest_framework import permissions, viewsets
from rest_framework.authentication import SessionAuthentication
from ..models import SAMLConfiguration
from .serializers import SAMLConfigurationSerializer
class SAMLConfigurationMixin(object):
authentication_classes = (JwtAuthentication, SessionAuthentication,)
permission_classes = (permissions.IsAuthenticated,)
serializer_class = SAMLConfigurationSerializer
class SAMLConfigurationViewSet(SAMLConfigurationMixin, viewsets.ModelViewSet):
"""
A View to handle SAMLConfiguration GETs
Usage:
GET /auth/saml/v0/saml_configuration/
"""
def get_queryset(self):
"""
Find and return all saml configurations that are listed as public.
"""
return SAMLConfiguration.objects.current_set().filter(is_public=True)

View File

@@ -4,10 +4,39 @@ Serializer for SAMLProviderConfig
from rest_framework import serializers
from third_party_auth.models import SAMLProviderConfig
from third_party_auth.models import SAMLProviderConfig, SAMLConfiguration
class SAMLProviderConfigSerializer(serializers.ModelSerializer):
saml_config_id = serializers.IntegerField(required=False)
class Meta:
model = SAMLProviderConfig
fields = '__all__'
def create(self, validated_data):
"""
Overwriting create in order to get a SAMLConfiguration object from id.
"""
if 'saml_config_id' in validated_data:
saml_configuration = SAMLConfiguration.objects.current_set().get(id=validated_data['saml_config_id'])
del validated_data['saml_config_id']
validated_data['saml_configuration'] = saml_configuration
return SAMLProviderConfig.objects.create(**validated_data)
def update(self, instance, validated_data):
if 'saml_config_id' in validated_data:
saml_configuration = SAMLConfiguration.objects.current_set().get(id=validated_data['saml_config_id'])
del validated_data['saml_config_id']
validated_data['saml_configuration'] = saml_configuration
for modifiable_field in validated_data:
setattr(
instance,
modifiable_field,
validated_data.get(modifiable_field, getattr(instance, modifiable_field))
)
instance.save()
return instance

View File

@@ -15,7 +15,7 @@ from rest_framework.test import APITestCase
from enterprise.models import EnterpriseCustomerIdentityProvider, EnterpriseCustomer
from enterprise.constants import ENTERPRISE_ADMIN_ROLE, ENTERPRISE_LEARNER_ROLE
from third_party_auth.tests.samlutils import set_jwt_cookie
from third_party_auth.models import SAMLProviderConfig
from third_party_auth.models import SAMLProviderConfig, SAMLConfiguration
from third_party_auth.tests import testutil
from third_party_auth.utils import convert_saml_slug_provider_id
@@ -26,13 +26,17 @@ SINGLE_PROVIDER_CONFIG = {
'name': 'name-of-config',
'enabled': 'true',
'slug': 'test-slug',
'country': 'https://example.customer.com/countrycode'
'country': 'https://example.customer.com/countrycode',
}
SINGLE_PROVIDER_CONFIG_2 = copy.copy(SINGLE_PROVIDER_CONFIG)
SINGLE_PROVIDER_CONFIG_2['name'] = 'name-of-config-2'
SINGLE_PROVIDER_CONFIG_2['slug'] = 'test-slug-2'
SINGLE_PROVIDER_CONFIG_3 = copy.copy(SINGLE_PROVIDER_CONFIG)
SINGLE_PROVIDER_CONFIG_3['name'] = 'name-of-config-3'
SINGLE_PROVIDER_CONFIG_3['slug'] = 'test-slug-3'
ENTERPRISE_ID = str(uuid4())
ENTERPRISE_ID_NON_EXISTENT = str(uuid4())
@@ -60,6 +64,11 @@ class SAMLProviderConfigTests(APITestCase):
slug=SINGLE_PROVIDER_CONFIG['slug'],
country=SINGLE_PROVIDER_CONFIG['country'],
)
cls.samlconfiguration, _ = SAMLConfiguration.objects.get_or_create(
enabled=True,
site=cls.site,
slug='edxSideTest',
)
def setUp(self):
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, ENTERPRISE_ID)])
@@ -233,3 +242,18 @@ class SAMLProviderConfigTests(APITestCase):
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, ENTERPRISE_ID_NON_EXISTENT)])
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_create_one_config_with_samlconfiguration(self):
"""
POST auth/saml/v0/provider_config/ -d data
"""
url = reverse('saml_provider_config-list')
data = copy.copy(SINGLE_PROVIDER_CONFIG_3)
data['enterprise_customer_uuid'] = ENTERPRISE_ID
data['saml_config_id'] = self.samlconfiguration.id
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
provider_config = SAMLProviderConfig.objects.get(slug=SINGLE_PROVIDER_CONFIG_3['slug'])
self.assertEqual(provider_config.saml_configuration, self.samlconfiguration)

View File

@@ -18,5 +18,6 @@ urlpatterns = [
url(r'^auth/idp_redirect/(?P<provider_slug>[\w-]+)', IdPRedirectView.as_view(), name="idp_redirect"),
url(r'^auth/', include('social_django.urls', namespace='social')),
url(r'^auth/saml/v0/', include('third_party_auth.samlproviderconfig.urls')),
url(r'^auth/saml/v0/', include('third_party_auth.samlproviderdata.urls'))
url(r'^auth/saml/v0/', include('third_party_auth.samlproviderdata.urls')),
url(r'^auth/saml/v0/', include('third_party_auth.saml_configuration.urls')),
]