fix: accounting for non-unique entity ID on idp configs + fixing provider data bulk update bug

This commit is contained in:
Alexander Sheehan
2022-06-08 14:58:26 -07:00
parent 290236390b
commit fd6b726a68
10 changed files with 238 additions and 41 deletions

View File

@@ -315,7 +315,12 @@ class UserMappingViewAPITests(TpaAPITestCase):
self._verify_response(response, expect_code, expect_data)
def test_user_mappings_only_return_requested_idp_mapping_by_provider_id(self):
testshib2 = self.configure_saml_provider(name='TestShib2', enabled=True, slug='testshib2')
testshib2 = self.configure_saml_provider(
name='TestShib2',
enabled=True,
slug='testshib2',
entity_id='entity-id-user-mapping'
)
username = 'testshib2user'
user = UserFactory.create(
username=username,

View File

@@ -37,8 +37,8 @@ class TestRemoveSocialAuthUsersCommand(TestCase):
def setUp(self):
super().setUp()
self.provider_hogwarts = SAMLProviderConfigFactory.create(slug='hogwarts')
self.provider_durmstrang = SAMLProviderConfigFactory.create(slug='durmstrang')
self.provider_hogwarts = SAMLProviderConfigFactory.create(slug='hogwarts', entity_id='entity-id-hogwarts')
self.provider_durmstrang = SAMLProviderConfigFactory.create(slug='durmstrang', entity_id='entity-id-durmstrang')
self.user_fleur = UserFactory(username='fleur') # no social auth
self.user_harry = UserFactory(username='harry') # social auth for Hogwarts

View File

@@ -65,7 +65,14 @@ class TestSAMLCommand(CacheIsolationTestCase):
# disabled saml configuration instance, this is done to verify that disabled configurations are
# not processed.
SAMLConfigurationFactory.create(enabled=False, site__domain='testserver.fake', site__name='testserver.fake')
SAMLProviderConfigFactory.create(site__domain='testserver.fake', site__name='testserver.fake')
SAMLProviderConfigFactory.create(
site__domain='testserver.fake',
site__name='testserver.fake',
slug='test-shib',
name='TestShib College',
entity_id='https://idp.testshib.org/idp/shibboleth',
metadata_source='https://www.testshib.org/metadata/testshib-providers.xml',
)
def __create_saml_configurations__(self, saml_config=None, saml_provider_config=None):
"""
@@ -74,13 +81,17 @@ class TestSAMLCommand(CacheIsolationTestCase):
SAMLConfigurationFactory.create(enabled=True, **(
saml_config or {
'site__domain': 'testserver.fake',
'site__name': 'testserver.fake'
'site__name': 'testserver.fake',
}
))
SAMLProviderConfigFactory.create(enabled=True, **(
saml_provider_config or {
'site__domain': 'testserver.fake',
'site__name': 'testserver.fake'
'site__name': 'testserver.fake',
'slug': 'test-shib',
'name': 'TestShib College',
'entity_id': 'https://idp.testshib.org/idp/shibboleth',
'metadata_source': 'https://www.testshib.org/metadata/testshib-providers.xml',
}
))

View File

@@ -12,7 +12,7 @@ from config_models.models import ConfigurationModel, cache
from django.conf import settings
from django.contrib.sites.models import Site
from django.core.exceptions import ValidationError
from django.db import models
from django.db import models, IntegrityError
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
from organizations.models import Organization
@@ -581,55 +581,68 @@ class SAMLProviderConfig(ProviderConfig):
prefix = 'saml'
display_name = models.CharField(
max_length=35, blank=True,
help_text=_("A configuration nickname."))
help_text=_("A configuration nickname.")
)
backend_name = models.CharField(
max_length=50, default='tpa-saml', blank=True,
help_text="Which python-social-auth provider backend to use. 'tpa-saml' is the standard edX SAML backend.")
help_text="Which python-social-auth provider backend to use. 'tpa-saml' is the standard edX SAML backend."
)
entity_id = models.CharField(
max_length=255, verbose_name="Entity ID", blank=True,
help_text="Example: https://idp.testshib.org/idp/shibboleth")
help_text="Example: https://idp.testshib.org/idp/shibboleth"
)
metadata_source = models.CharField(
max_length=255, blank=True,
help_text=(
"URL to this provider's XML metadata. Should be an HTTPS URL. "
"Example: https://www.testshib.org/metadata/testshib-providers.xml"
))
)
)
attr_user_permanent_id = models.CharField(
max_length=128, blank=True, verbose_name="User ID Attribute",
help_text=(
"URN of the SAML attribute that we can use as a unique, "
"persistent user ID. Leave blank for default."
))
)
)
attr_full_name = models.CharField(
max_length=128, blank=True, verbose_name="Full Name Attribute",
help_text="URN of SAML attribute containing the user's full name. Leave blank for default.")
help_text="URN of SAML attribute containing the user's full name. Leave blank for default."
)
default_full_name = models.CharField(
max_length=255, blank=True, verbose_name="Default Value for Full Name",
help_text="Default value for full name to be used if not present in SAML response.")
help_text="Default value for full name to be used if not present in SAML response."
)
attr_first_name = models.CharField(
max_length=128, blank=True, verbose_name="First Name Attribute",
help_text="URN of SAML attribute containing the user's first name. Leave blank for default.")
help_text="URN of SAML attribute containing the user's first name. Leave blank for default."
)
default_first_name = models.CharField(
max_length=255, blank=True, verbose_name="Default Value for First Name",
help_text="Default value for first name to be used if not present in SAML response.")
help_text="Default value for first name to be used if not present in SAML response."
)
attr_last_name = models.CharField(
max_length=128, blank=True, verbose_name="Last Name Attribute",
help_text="URN of SAML attribute containing the user's last name. Leave blank for default.")
help_text="URN of SAML attribute containing the user's last name. Leave blank for default."
)
default_last_name = models.CharField(
max_length=255, blank=True, verbose_name="Default Value for Last Name",
help_text="Default value for last name to be used if not present in SAML response.")
attr_username = models.CharField(
max_length=128, blank=True, verbose_name="Username Hint Attribute",
help_text="URN of SAML attribute to use as a suggested username for this user. Leave blank for default.")
help_text="URN of SAML attribute to use as a suggested username for this user. Leave blank for default."
)
default_username = models.CharField(
max_length=255, blank=True, verbose_name="Default Value for Username",
help_text="Default value for username to be used if not present in SAML response.")
help_text="Default value for username to be used if not present in SAML response."
)
attr_email = models.CharField(
max_length=128, blank=True, verbose_name="Email Attribute",
help_text="URN of SAML attribute containing the user's email address[es]. Leave blank for default.")
default_email = models.CharField(
max_length=255, blank=True, verbose_name="Default Value for Email",
help_text="Default value for email to be used if not present in SAML response.")
help_text="Default value for email to be used if not present in SAML response."
)
automatic_refresh_enabled = models.BooleanField(
default=True, verbose_name="Enable automatic metadata refresh",
help_text="When checked, the SAML provider's metadata will be included "
@@ -698,7 +711,8 @@ class SAMLProviderConfig(ProviderConfig):
'the relevant values from the SAML response. Custom provider types, as selected '
'in the "Identity Provider Type" field, may make use of the information stored '
'in this field for additional configuration.'
))
)
)
archived = models.BooleanField(default=False)
saml_configuration = models.ForeignKey(
SAMLConfiguration,
@@ -718,6 +732,23 @@ class SAMLProviderConfig(ProviderConfig):
verbose_name = "Provider Configuration (SAML IdP)"
verbose_name_plural = "Provider Configuration (SAML IdPs)"
def save(self, *args, **kwargs):
# Disallowing any new entries that have the same entity ID as an existing provider config unless the slug
# matches.
# This both allows for the old architecture to create new rows on save but also prevents enterprise users from
# creating configs that share entity ID's with other enterprises
# One consequence of this is that once a provider configuration is created, the slug is essentially locked in
# and unchangeable. But I blame that on bad old architecture.
existing_provider_configs = SAMLProviderConfig.objects.filter(
entity_id=self.entity_id,
archived=False,
).exclude(slug=self.slug)
# If any exist, raise an integrity error
if existing_provider_configs:
exc_str = f'Entity ID: {self.entity_id} already in use'
raise IntegrityError(exc_str)
super().save(*args, **kwargs)
def get_url_params(self):
""" Get a dict of GET parameters to append to login links for this provider """
return {'idp': self.slug}

View File

@@ -14,6 +14,18 @@ class SAMLProviderConfigSerializer(serializers.ModelSerializer): # lint-amnesty
model = SAMLProviderConfig
fields = '__all__'
def validate(self, data):
"""
Validate that no provider config exists with a different slug and same entity ID
"""
# If there are any existing provider configs that match the payload's entity ID, don't match the slug and
# are not archived, raise a validation error. We do this to prevent provider configs from sharing entity ID's
# which link a provider config to provider data (SAML certificates). An entity ID therefore, is uniquely linked
# to a single slug/provider config (which in the case of enterprise provider slug == customer slug).
if SAMLProviderConfig.objects.filter(entity_id=data['entity_id'], archived=False).exclude(slug=data['slug']):
raise serializers.ValidationError(f"Entity ID: {data['entity_id']} already taken")
return data
def create(self, validated_data):
"""
Overwriting create in order to get a SAMLConfiguration object from id.
@@ -27,7 +39,6 @@ class SAMLProviderConfigSerializer(serializers.ModelSerializer): # lint-amnesty
return SAMLProviderConfig.objects.create(**validated_data)
def update(self, instance, validated_data):
if 'saml_config_id' in validated_data:
saml_configuration = SAMLConfiguration.objects.current_set().get(id=validated_data['saml_config_id'])
del validated_data['saml_config_id']

View File

@@ -34,10 +34,13 @@ SINGLE_PROVIDER_CONFIG_2 = copy.copy(SINGLE_PROVIDER_CONFIG)
SINGLE_PROVIDER_CONFIG_2['name'] = 'name-of-config-2'
SINGLE_PROVIDER_CONFIG_2['slug'] = 'test-slug-2'
SINGLE_PROVIDER_CONFIG_2['display_name'] = 'display-name'
SINGLE_PROVIDER_CONFIG_2['entity_id'] = 'id-2'
SINGLE_PROVIDER_CONFIG_3 = copy.copy(SINGLE_PROVIDER_CONFIG)
SINGLE_PROVIDER_CONFIG_3['name'] = 'name-of-config-3'
SINGLE_PROVIDER_CONFIG_3['slug'] = 'test-slug-3'
SINGLE_PROVIDER_CONFIG_3['entity_id'] = 'id-3'
ENTERPRISE_ID = str(uuid4())
ENTERPRISE_ID_NON_EXISTENT = str(uuid4())
@@ -195,7 +198,7 @@ class SAMLProviderConfigTests(APITestCase):
"""
url = reverse('saml_provider_config-list')
provider_config_no_country = {
'entity_id': 'id',
'entity_id': 'id2',
'metadata_source': 'http://test.url',
'name': 'name-of-config-no-country',
'enabled': 'true',
@@ -214,7 +217,7 @@ class SAMLProviderConfigTests(APITestCase):
"""
url = reverse('saml_provider_config-list')
provider_config_blank_country = {
'entity_id': 'id',
'entity_id': 'id-empty-country-urn',
'metadata_source': 'http://test.url',
'name': 'name-of-config-blank-country',
'enabled': 'true',
@@ -256,3 +259,120 @@ class SAMLProviderConfigTests(APITestCase):
assert response.status_code == status.HTTP_201_CREATED
provider_config = SAMLProviderConfig.objects.get(slug=SINGLE_PROVIDER_CONFIG_3['slug'])
assert provider_config.saml_configuration == self.samlconfiguration
def test_unique_entity_id_constraint_with_different_slug(self):
"""
Test that a config cannot be created with an entity ID if another config already exists with that entity ID and
a different slug
"""
url = reverse('saml_provider_config-list')
data = copy.copy(SINGLE_PROVIDER_CONFIG)
data['enterprise_customer_uuid'] = ENTERPRISE_ID
data['slug'] = 'some-other-slug'
response = self.client.post(url, data)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert len(SAMLProviderConfig.objects.all()) == 1
assert str(response.data.get('non_field_errors')[0]) == f"Entity ID: {data['entity_id']} already taken"
def test_unique_entity_id_constraint_with_same_slug(self):
"""
Test that a config can be created/edited using the same entity ID as an existing config as long as it shares an
entity ID.
"""
url = reverse('saml_provider_config-list')
data = copy.copy(SINGLE_PROVIDER_CONFIG)
data['enterprise_customer_uuid'] = ENTERPRISE_ID
data['name'] = 'some-other-name'
response = self.client.post(url, data)
assert response.status_code == status.HTTP_201_CREATED
assert len(SAMLProviderConfig.objects.all()) == 2
assert response.data.get('name') == 'some-other-name'
def test_api_deleting_provider_configs(self):
"""
Test deleting a provider config.
"""
EnterpriseCustomerIdentityProvider.objects.get_or_create(
provider_id=convert_saml_slug_provider_id(self.samlproviderconfig.slug),
enterprise_customer_id=ENTERPRISE_ID
)
url = reverse('saml_provider_config-list')
data = {}
data['enterprise_customer_uuid'] = ENTERPRISE_ID
response = self.client.delete(
url + f'{str(self.samlproviderconfig.id)}/?enterprise_customer_uuid={ENTERPRISE_ID}'
)
assert response.status_code == status.HTTP_200_OK
assert len(SAMLProviderConfig.objects.all()) == 1
assert SAMLProviderConfig.objects.first().archived
def test_api_deleting_config_then_using_deleted_entity_id(self):
"""
Test deleting a config then creating a new config with the entity ID of the deleted config
"""
EnterpriseCustomerIdentityProvider.objects.get_or_create(
provider_id=convert_saml_slug_provider_id(self.samlproviderconfig.slug),
enterprise_customer_id=ENTERPRISE_ID
)
url = reverse('saml_provider_config-list')
data = {}
data['enterprise_customer_uuid'] = ENTERPRISE_ID
response = self.client.delete(
url + f'{str(self.samlproviderconfig.id)}/?enterprise_customer_uuid={ENTERPRISE_ID}'
)
assert response.status_code == status.HTTP_200_OK
assert len(SAMLProviderConfig.objects.all()) == 1
assert SAMLProviderConfig.objects.first().archived
data = copy.copy(SINGLE_PROVIDER_CONFIG)
data['enterprise_customer_uuid'] = ENTERPRISE_ID
data['entity_id'] = SINGLE_PROVIDER_CONFIG['entity_id']
data['slug'] = 'idk-something-else'
response = self.client.post(url, data)
assert response.status_code == status.HTTP_201_CREATED
assert len(SAMLProviderConfig.objects.all()) == 2
def test_using_an_edited_configs_entity_id_after_deleting(self):
"""
Test that editing an existing config then removing it still allows new configs to use the deleted config's
entity ID
"""
EnterpriseCustomerIdentityProvider.objects.get_or_create(
provider_id=convert_saml_slug_provider_id(self.samlproviderconfig.slug),
enterprise_customer_id=ENTERPRISE_ID
)
url = reverse('saml_provider_config-list')
data = copy.copy(SINGLE_PROVIDER_CONFIG)
data['saml_config_id'] = self.samlconfiguration.id
data['name'] = 'a new name'
response = self.client.patch(
url + f'{str(self.samlproviderconfig.id)}/?enterprise_customer_uuid={ENTERPRISE_ID}',
data,
)
assert response.status_code == status.HTTP_200_OK
assert len(SAMLProviderConfig.objects.all()) == 2
data = {}
data['enterprise_customer_uuid'] = ENTERPRISE_ID
response = self.client.delete(
url + f'{str(response.data.get("id"))}/?enterprise_customer_uuid={ENTERPRISE_ID}'
)
assert response.status_code == status.HTTP_200_OK
assert len(SAMLProviderConfig.objects.all()) == 2
data = copy.copy(SINGLE_PROVIDER_CONFIG_3)
data['enterprise_customer_uuid'] = ENTERPRISE_ID
data['saml_config_id'] = self.samlconfiguration.id
response = self.client.post(url, data)
assert response.status_code == status.HTTP_201_CREATED
assert len(SAMLProviderConfig.objects.all()) == 3

View File

@@ -3,6 +3,7 @@ Viewset for auth/saml/v0/samlproviderconfig
"""
from django.shortcuts import get_list_or_404
from django.db.utils import IntegrityError
from edx_rbac.mixins import PermissionRequiredMixin
from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication
from rest_framework import permissions, viewsets, status
@@ -59,7 +60,10 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view
enterprise_customer__uuid=self.requested_enterprise_uuid
)
slug_list = [idp.provider_id for idp in enterprise_customer_idps]
return [config for config in SAMLProviderConfig.objects.current_set() if config.provider_id in slug_list]
saml_config_ids = [
config.id for config in SAMLProviderConfig.objects.current_set() if config.provider_id in slug_list
]
return SAMLProviderConfig.objects.filter(id__in=saml_config_ids)
def destroy(self, request, *args, **kwargs):
saml_provider_config = self.get_object()
@@ -76,7 +80,7 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view
provider_id=provider_config_provider_id,
)
enterprise_saml_provider.delete()
saml_provider_config.delete()
SAMLProviderConfig.objects.filter(id=saml_provider_config.id).update(archived=True, enabled=False)
return Response(data=config_id, status=status.HTTP_200_OK)
@property
@@ -116,9 +120,12 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view
raise ValidationError(f'Enterprise customer not found at uuid: {customer_uuid}') # lint-amnesty, pylint: disable=raise-missing-from
# Create the samlproviderconfig model first
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
try:
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
except IntegrityError as exc:
return Response(str(exc), status=status.HTTP_400_BAD_REQUEST)
# Associate the enterprise customer with the provider
association_obj = EnterpriseCustomerIdentityProvider(

View File

@@ -2,12 +2,15 @@
Provides factories for third_party_auth models.
"""
import factory
from factory import SubFactory
from factory.django import DjangoModelFactory
from faker import Factory as FakerFactory
from openedx.core.djangoapps.site_configuration.tests.factories import SiteFactory
from common.djangoapps.third_party_auth.models import SAMLConfiguration, SAMLProviderConfig
from openedx.core.djangoapps.site_configuration.tests.factories import SiteFactory
FAKER = FakerFactory.create()
class SAMLConfigurationFactory(DjangoModelFactory):
@@ -32,8 +35,8 @@ class SAMLProviderConfigFactory(DjangoModelFactory):
site = SubFactory(SiteFactory)
enabled = True
slug = "test-shib"
name = "TestShib College"
slug = factory.LazyAttribute(lambda x: FAKER.slug())
name = factory.LazyAttribute(lambda x: FAKER.company())
entity_id = "https://idp.testshib.org/idp/shibboleth"
metadata_source = "https://www.testshib.org/metadata/testshib-providers.xml"
entity_id = factory.LazyAttribute(lambda x: FAKER.uri())
metadata_source = factory.LazyAttribute(lambda x: FAKER.uri())

View File

@@ -85,7 +85,7 @@ class RegistryTest(testutil.TestCase):
self.enable_saml()
provider_count = 5
for i in range(provider_count):
self.configure_saml_provider(enabled=True, slug="saml-slug-%s" % i)
self.configure_saml_provider(enabled=True, slug=f"saml-slug-{i}", entity_id=f"saml-entity-id-{i}")
with CaptureQueriesContext(connections[DEFAULT_DB_ALIAS]) as cq:
enabled_slugs = {p.slug for p in provider.Registry.enabled()}

View File

@@ -172,16 +172,25 @@ def get_user_from_email(details):
def create_or_update_bulk_saml_provider_data(entity_id, public_keys, sso_url, expires_at):
"""
Placeholder
Method to bulk update or create provider data entries
"""
fetched_at = now()
new_records_created = False
# Create a data record for each of the public keys provided
for key in public_keys:
_, created = SAMLProviderData.objects.update_or_create(
public_key=key, entity_id=entity_id,
defaults={'sso_url': sso_url, 'expires_at': expires_at, 'fetched_at': fetched_at},
)
existing_data_objects = SAMLProviderData.objects.filter(public_key=key, entity_id=entity_id)
if len(existing_data_objects) > 1:
for obj in existing_data_objects:
obj.sso_url = sso_url
obj.expires_at = expires_at
obj.fetched_at = fetched_at
SAMLProviderData.objects.bulk_update(existing_data_objects, ['sso_url', 'expires_at', 'fetched_at'])
return True
else:
_, created = SAMLProviderData.objects.update_or_create(
public_key=key, entity_id=entity_id,
defaults={'sso_url': sso_url, 'expires_at': expires_at, 'fetched_at': fetched_at},
)
if created:
new_records_created = True