fix: accounting for non-unique entity ID on idp configs + fixing provider data bulk update bug
This commit is contained in:
@@ -315,7 +315,12 @@ class UserMappingViewAPITests(TpaAPITestCase):
|
||||
self._verify_response(response, expect_code, expect_data)
|
||||
|
||||
def test_user_mappings_only_return_requested_idp_mapping_by_provider_id(self):
|
||||
testshib2 = self.configure_saml_provider(name='TestShib2', enabled=True, slug='testshib2')
|
||||
testshib2 = self.configure_saml_provider(
|
||||
name='TestShib2',
|
||||
enabled=True,
|
||||
slug='testshib2',
|
||||
entity_id='entity-id-user-mapping'
|
||||
)
|
||||
username = 'testshib2user'
|
||||
user = UserFactory.create(
|
||||
username=username,
|
||||
|
||||
@@ -37,8 +37,8 @@ class TestRemoveSocialAuthUsersCommand(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.provider_hogwarts = SAMLProviderConfigFactory.create(slug='hogwarts')
|
||||
self.provider_durmstrang = SAMLProviderConfigFactory.create(slug='durmstrang')
|
||||
self.provider_hogwarts = SAMLProviderConfigFactory.create(slug='hogwarts', entity_id='entity-id-hogwarts')
|
||||
self.provider_durmstrang = SAMLProviderConfigFactory.create(slug='durmstrang', entity_id='entity-id-durmstrang')
|
||||
|
||||
self.user_fleur = UserFactory(username='fleur') # no social auth
|
||||
self.user_harry = UserFactory(username='harry') # social auth for Hogwarts
|
||||
|
||||
@@ -65,7 +65,14 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
# disabled saml configuration instance, this is done to verify that disabled configurations are
|
||||
# not processed.
|
||||
SAMLConfigurationFactory.create(enabled=False, site__domain='testserver.fake', site__name='testserver.fake')
|
||||
SAMLProviderConfigFactory.create(site__domain='testserver.fake', site__name='testserver.fake')
|
||||
SAMLProviderConfigFactory.create(
|
||||
site__domain='testserver.fake',
|
||||
site__name='testserver.fake',
|
||||
slug='test-shib',
|
||||
name='TestShib College',
|
||||
entity_id='https://idp.testshib.org/idp/shibboleth',
|
||||
metadata_source='https://www.testshib.org/metadata/testshib-providers.xml',
|
||||
)
|
||||
|
||||
def __create_saml_configurations__(self, saml_config=None, saml_provider_config=None):
|
||||
"""
|
||||
@@ -74,13 +81,17 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
SAMLConfigurationFactory.create(enabled=True, **(
|
||||
saml_config or {
|
||||
'site__domain': 'testserver.fake',
|
||||
'site__name': 'testserver.fake'
|
||||
'site__name': 'testserver.fake',
|
||||
}
|
||||
))
|
||||
SAMLProviderConfigFactory.create(enabled=True, **(
|
||||
saml_provider_config or {
|
||||
'site__domain': 'testserver.fake',
|
||||
'site__name': 'testserver.fake'
|
||||
'site__name': 'testserver.fake',
|
||||
'slug': 'test-shib',
|
||||
'name': 'TestShib College',
|
||||
'entity_id': 'https://idp.testshib.org/idp/shibboleth',
|
||||
'metadata_source': 'https://www.testshib.org/metadata/testshib-providers.xml',
|
||||
}
|
||||
))
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from config_models.models import ConfigurationModel, cache
|
||||
from django.conf import settings
|
||||
from django.contrib.sites.models import Site
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.db import models, IntegrityError
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from organizations.models import Organization
|
||||
@@ -581,55 +581,68 @@ class SAMLProviderConfig(ProviderConfig):
|
||||
prefix = 'saml'
|
||||
display_name = models.CharField(
|
||||
max_length=35, blank=True,
|
||||
help_text=_("A configuration nickname."))
|
||||
help_text=_("A configuration nickname.")
|
||||
)
|
||||
backend_name = models.CharField(
|
||||
max_length=50, default='tpa-saml', blank=True,
|
||||
help_text="Which python-social-auth provider backend to use. 'tpa-saml' is the standard edX SAML backend.")
|
||||
help_text="Which python-social-auth provider backend to use. 'tpa-saml' is the standard edX SAML backend."
|
||||
)
|
||||
entity_id = models.CharField(
|
||||
max_length=255, verbose_name="Entity ID", blank=True,
|
||||
help_text="Example: https://idp.testshib.org/idp/shibboleth")
|
||||
help_text="Example: https://idp.testshib.org/idp/shibboleth"
|
||||
)
|
||||
metadata_source = models.CharField(
|
||||
max_length=255, blank=True,
|
||||
help_text=(
|
||||
"URL to this provider's XML metadata. Should be an HTTPS URL. "
|
||||
"Example: https://www.testshib.org/metadata/testshib-providers.xml"
|
||||
))
|
||||
)
|
||||
)
|
||||
attr_user_permanent_id = models.CharField(
|
||||
max_length=128, blank=True, verbose_name="User ID Attribute",
|
||||
help_text=(
|
||||
"URN of the SAML attribute that we can use as a unique, "
|
||||
"persistent user ID. Leave blank for default."
|
||||
))
|
||||
)
|
||||
)
|
||||
attr_full_name = models.CharField(
|
||||
max_length=128, blank=True, verbose_name="Full Name Attribute",
|
||||
help_text="URN of SAML attribute containing the user's full name. Leave blank for default.")
|
||||
help_text="URN of SAML attribute containing the user's full name. Leave blank for default."
|
||||
)
|
||||
default_full_name = models.CharField(
|
||||
max_length=255, blank=True, verbose_name="Default Value for Full Name",
|
||||
help_text="Default value for full name to be used if not present in SAML response.")
|
||||
help_text="Default value for full name to be used if not present in SAML response."
|
||||
)
|
||||
attr_first_name = models.CharField(
|
||||
max_length=128, blank=True, verbose_name="First Name Attribute",
|
||||
help_text="URN of SAML attribute containing the user's first name. Leave blank for default.")
|
||||
help_text="URN of SAML attribute containing the user's first name. Leave blank for default."
|
||||
)
|
||||
default_first_name = models.CharField(
|
||||
max_length=255, blank=True, verbose_name="Default Value for First Name",
|
||||
help_text="Default value for first name to be used if not present in SAML response.")
|
||||
help_text="Default value for first name to be used if not present in SAML response."
|
||||
)
|
||||
attr_last_name = models.CharField(
|
||||
max_length=128, blank=True, verbose_name="Last Name Attribute",
|
||||
help_text="URN of SAML attribute containing the user's last name. Leave blank for default.")
|
||||
help_text="URN of SAML attribute containing the user's last name. Leave blank for default."
|
||||
)
|
||||
default_last_name = models.CharField(
|
||||
max_length=255, blank=True, verbose_name="Default Value for Last Name",
|
||||
help_text="Default value for last name to be used if not present in SAML response.")
|
||||
attr_username = models.CharField(
|
||||
max_length=128, blank=True, verbose_name="Username Hint Attribute",
|
||||
help_text="URN of SAML attribute to use as a suggested username for this user. Leave blank for default.")
|
||||
help_text="URN of SAML attribute to use as a suggested username for this user. Leave blank for default."
|
||||
)
|
||||
default_username = models.CharField(
|
||||
max_length=255, blank=True, verbose_name="Default Value for Username",
|
||||
help_text="Default value for username to be used if not present in SAML response.")
|
||||
help_text="Default value for username to be used if not present in SAML response."
|
||||
)
|
||||
attr_email = models.CharField(
|
||||
max_length=128, blank=True, verbose_name="Email Attribute",
|
||||
help_text="URN of SAML attribute containing the user's email address[es]. Leave blank for default.")
|
||||
default_email = models.CharField(
|
||||
max_length=255, blank=True, verbose_name="Default Value for Email",
|
||||
help_text="Default value for email to be used if not present in SAML response.")
|
||||
help_text="Default value for email to be used if not present in SAML response."
|
||||
)
|
||||
automatic_refresh_enabled = models.BooleanField(
|
||||
default=True, verbose_name="Enable automatic metadata refresh",
|
||||
help_text="When checked, the SAML provider's metadata will be included "
|
||||
@@ -698,7 +711,8 @@ class SAMLProviderConfig(ProviderConfig):
|
||||
'the relevant values from the SAML response. Custom provider types, as selected '
|
||||
'in the "Identity Provider Type" field, may make use of the information stored '
|
||||
'in this field for additional configuration.'
|
||||
))
|
||||
)
|
||||
)
|
||||
archived = models.BooleanField(default=False)
|
||||
saml_configuration = models.ForeignKey(
|
||||
SAMLConfiguration,
|
||||
@@ -718,6 +732,23 @@ class SAMLProviderConfig(ProviderConfig):
|
||||
verbose_name = "Provider Configuration (SAML IdP)"
|
||||
verbose_name_plural = "Provider Configuration (SAML IdPs)"
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
# Disallowing any new entries that have the same entity ID as an existing provider config unless the slug
|
||||
# matches.
|
||||
# This both allows for the old architecture to create new rows on save but also prevents enterprise users from
|
||||
# creating configs that share entity ID's with other enterprises
|
||||
# One consequence of this is that once a provider configuration is created, the slug is essentially locked in
|
||||
# and unchangeable. But I blame that on bad old architecture.
|
||||
existing_provider_configs = SAMLProviderConfig.objects.filter(
|
||||
entity_id=self.entity_id,
|
||||
archived=False,
|
||||
).exclude(slug=self.slug)
|
||||
# If any exist, raise an integrity error
|
||||
if existing_provider_configs:
|
||||
exc_str = f'Entity ID: {self.entity_id} already in use'
|
||||
raise IntegrityError(exc_str)
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
def get_url_params(self):
|
||||
""" Get a dict of GET parameters to append to login links for this provider """
|
||||
return {'idp': self.slug}
|
||||
|
||||
@@ -14,6 +14,18 @@ class SAMLProviderConfigSerializer(serializers.ModelSerializer): # lint-amnesty
|
||||
model = SAMLProviderConfig
|
||||
fields = '__all__'
|
||||
|
||||
def validate(self, data):
|
||||
"""
|
||||
Validate that no provider config exists with a different slug and same entity ID
|
||||
"""
|
||||
# If there are any existing provider configs that match the payload's entity ID, don't match the slug and
|
||||
# are not archived, raise a validation error. We do this to prevent provider configs from sharing entity ID's
|
||||
# which link a provider config to provider data (SAML certificates). An entity ID therefore, is uniquely linked
|
||||
# to a single slug/provider config (which in the case of enterprise provider slug == customer slug).
|
||||
if SAMLProviderConfig.objects.filter(entity_id=data['entity_id'], archived=False).exclude(slug=data['slug']):
|
||||
raise serializers.ValidationError(f"Entity ID: {data['entity_id']} already taken")
|
||||
return data
|
||||
|
||||
def create(self, validated_data):
|
||||
"""
|
||||
Overwriting create in order to get a SAMLConfiguration object from id.
|
||||
@@ -27,7 +39,6 @@ class SAMLProviderConfigSerializer(serializers.ModelSerializer): # lint-amnesty
|
||||
return SAMLProviderConfig.objects.create(**validated_data)
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
|
||||
if 'saml_config_id' in validated_data:
|
||||
saml_configuration = SAMLConfiguration.objects.current_set().get(id=validated_data['saml_config_id'])
|
||||
del validated_data['saml_config_id']
|
||||
|
||||
@@ -34,10 +34,13 @@ SINGLE_PROVIDER_CONFIG_2 = copy.copy(SINGLE_PROVIDER_CONFIG)
|
||||
SINGLE_PROVIDER_CONFIG_2['name'] = 'name-of-config-2'
|
||||
SINGLE_PROVIDER_CONFIG_2['slug'] = 'test-slug-2'
|
||||
SINGLE_PROVIDER_CONFIG_2['display_name'] = 'display-name'
|
||||
SINGLE_PROVIDER_CONFIG_2['entity_id'] = 'id-2'
|
||||
|
||||
SINGLE_PROVIDER_CONFIG_3 = copy.copy(SINGLE_PROVIDER_CONFIG)
|
||||
SINGLE_PROVIDER_CONFIG_3['name'] = 'name-of-config-3'
|
||||
SINGLE_PROVIDER_CONFIG_3['slug'] = 'test-slug-3'
|
||||
SINGLE_PROVIDER_CONFIG_3['entity_id'] = 'id-3'
|
||||
|
||||
|
||||
ENTERPRISE_ID = str(uuid4())
|
||||
ENTERPRISE_ID_NON_EXISTENT = str(uuid4())
|
||||
@@ -195,7 +198,7 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
"""
|
||||
url = reverse('saml_provider_config-list')
|
||||
provider_config_no_country = {
|
||||
'entity_id': 'id',
|
||||
'entity_id': 'id2',
|
||||
'metadata_source': 'http://test.url',
|
||||
'name': 'name-of-config-no-country',
|
||||
'enabled': 'true',
|
||||
@@ -214,7 +217,7 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
"""
|
||||
url = reverse('saml_provider_config-list')
|
||||
provider_config_blank_country = {
|
||||
'entity_id': 'id',
|
||||
'entity_id': 'id-empty-country-urn',
|
||||
'metadata_source': 'http://test.url',
|
||||
'name': 'name-of-config-blank-country',
|
||||
'enabled': 'true',
|
||||
@@ -256,3 +259,120 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
provider_config = SAMLProviderConfig.objects.get(slug=SINGLE_PROVIDER_CONFIG_3['slug'])
|
||||
assert provider_config.saml_configuration == self.samlconfiguration
|
||||
|
||||
def test_unique_entity_id_constraint_with_different_slug(self):
|
||||
"""
|
||||
Test that a config cannot be created with an entity ID if another config already exists with that entity ID and
|
||||
a different slug
|
||||
"""
|
||||
url = reverse('saml_provider_config-list')
|
||||
data = copy.copy(SINGLE_PROVIDER_CONFIG)
|
||||
data['enterprise_customer_uuid'] = ENTERPRISE_ID
|
||||
data['slug'] = 'some-other-slug'
|
||||
|
||||
response = self.client.post(url, data)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert len(SAMLProviderConfig.objects.all()) == 1
|
||||
assert str(response.data.get('non_field_errors')[0]) == f"Entity ID: {data['entity_id']} already taken"
|
||||
|
||||
def test_unique_entity_id_constraint_with_same_slug(self):
|
||||
"""
|
||||
Test that a config can be created/edited using the same entity ID as an existing config as long as it shares an
|
||||
entity ID.
|
||||
"""
|
||||
url = reverse('saml_provider_config-list')
|
||||
data = copy.copy(SINGLE_PROVIDER_CONFIG)
|
||||
data['enterprise_customer_uuid'] = ENTERPRISE_ID
|
||||
data['name'] = 'some-other-name'
|
||||
|
||||
response = self.client.post(url, data)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert len(SAMLProviderConfig.objects.all()) == 2
|
||||
assert response.data.get('name') == 'some-other-name'
|
||||
|
||||
def test_api_deleting_provider_configs(self):
|
||||
"""
|
||||
Test deleting a provider config.
|
||||
"""
|
||||
EnterpriseCustomerIdentityProvider.objects.get_or_create(
|
||||
provider_id=convert_saml_slug_provider_id(self.samlproviderconfig.slug),
|
||||
enterprise_customer_id=ENTERPRISE_ID
|
||||
)
|
||||
url = reverse('saml_provider_config-list')
|
||||
data = {}
|
||||
data['enterprise_customer_uuid'] = ENTERPRISE_ID
|
||||
|
||||
response = self.client.delete(
|
||||
url + f'{str(self.samlproviderconfig.id)}/?enterprise_customer_uuid={ENTERPRISE_ID}'
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(SAMLProviderConfig.objects.all()) == 1
|
||||
assert SAMLProviderConfig.objects.first().archived
|
||||
|
||||
def test_api_deleting_config_then_using_deleted_entity_id(self):
|
||||
"""
|
||||
Test deleting a config then creating a new config with the entity ID of the deleted config
|
||||
"""
|
||||
EnterpriseCustomerIdentityProvider.objects.get_or_create(
|
||||
provider_id=convert_saml_slug_provider_id(self.samlproviderconfig.slug),
|
||||
enterprise_customer_id=ENTERPRISE_ID
|
||||
)
|
||||
url = reverse('saml_provider_config-list')
|
||||
data = {}
|
||||
data['enterprise_customer_uuid'] = ENTERPRISE_ID
|
||||
|
||||
response = self.client.delete(
|
||||
url + f'{str(self.samlproviderconfig.id)}/?enterprise_customer_uuid={ENTERPRISE_ID}'
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(SAMLProviderConfig.objects.all()) == 1
|
||||
assert SAMLProviderConfig.objects.first().archived
|
||||
|
||||
data = copy.copy(SINGLE_PROVIDER_CONFIG)
|
||||
data['enterprise_customer_uuid'] = ENTERPRISE_ID
|
||||
data['entity_id'] = SINGLE_PROVIDER_CONFIG['entity_id']
|
||||
data['slug'] = 'idk-something-else'
|
||||
|
||||
response = self.client.post(url, data)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert len(SAMLProviderConfig.objects.all()) == 2
|
||||
|
||||
def test_using_an_edited_configs_entity_id_after_deleting(self):
|
||||
"""
|
||||
Test that editing an existing config then removing it still allows new configs to use the deleted config's
|
||||
entity ID
|
||||
"""
|
||||
EnterpriseCustomerIdentityProvider.objects.get_or_create(
|
||||
provider_id=convert_saml_slug_provider_id(self.samlproviderconfig.slug),
|
||||
enterprise_customer_id=ENTERPRISE_ID
|
||||
)
|
||||
url = reverse('saml_provider_config-list')
|
||||
|
||||
data = copy.copy(SINGLE_PROVIDER_CONFIG)
|
||||
data['saml_config_id'] = self.samlconfiguration.id
|
||||
data['name'] = 'a new name'
|
||||
|
||||
response = self.client.patch(
|
||||
url + f'{str(self.samlproviderconfig.id)}/?enterprise_customer_uuid={ENTERPRISE_ID}',
|
||||
data,
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(SAMLProviderConfig.objects.all()) == 2
|
||||
|
||||
data = {}
|
||||
data['enterprise_customer_uuid'] = ENTERPRISE_ID
|
||||
response = self.client.delete(
|
||||
url + f'{str(response.data.get("id"))}/?enterprise_customer_uuid={ENTERPRISE_ID}'
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(SAMLProviderConfig.objects.all()) == 2
|
||||
|
||||
data = copy.copy(SINGLE_PROVIDER_CONFIG_3)
|
||||
data['enterprise_customer_uuid'] = ENTERPRISE_ID
|
||||
data['saml_config_id'] = self.samlconfiguration.id
|
||||
|
||||
response = self.client.post(url, data)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert len(SAMLProviderConfig.objects.all()) == 3
|
||||
|
||||
@@ -3,6 +3,7 @@ Viewset for auth/saml/v0/samlproviderconfig
|
||||
"""
|
||||
|
||||
from django.shortcuts import get_list_or_404
|
||||
from django.db.utils import IntegrityError
|
||||
from edx_rbac.mixins import PermissionRequiredMixin
|
||||
from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication
|
||||
from rest_framework import permissions, viewsets, status
|
||||
@@ -59,7 +60,10 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view
|
||||
enterprise_customer__uuid=self.requested_enterprise_uuid
|
||||
)
|
||||
slug_list = [idp.provider_id for idp in enterprise_customer_idps]
|
||||
return [config for config in SAMLProviderConfig.objects.current_set() if config.provider_id in slug_list]
|
||||
saml_config_ids = [
|
||||
config.id for config in SAMLProviderConfig.objects.current_set() if config.provider_id in slug_list
|
||||
]
|
||||
return SAMLProviderConfig.objects.filter(id__in=saml_config_ids)
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
saml_provider_config = self.get_object()
|
||||
@@ -76,7 +80,7 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view
|
||||
provider_id=provider_config_provider_id,
|
||||
)
|
||||
enterprise_saml_provider.delete()
|
||||
saml_provider_config.delete()
|
||||
SAMLProviderConfig.objects.filter(id=saml_provider_config.id).update(archived=True, enabled=False)
|
||||
return Response(data=config_id, status=status.HTTP_200_OK)
|
||||
|
||||
@property
|
||||
@@ -116,9 +120,12 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view
|
||||
raise ValidationError(f'Enterprise customer not found at uuid: {customer_uuid}') # lint-amnesty, pylint: disable=raise-missing-from
|
||||
|
||||
# Create the samlproviderconfig model first
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_create(serializer)
|
||||
try:
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_create(serializer)
|
||||
except IntegrityError as exc:
|
||||
return Response(str(exc), status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
# Associate the enterprise customer with the provider
|
||||
association_obj = EnterpriseCustomerIdentityProvider(
|
||||
|
||||
@@ -2,12 +2,15 @@
|
||||
Provides factories for third_party_auth models.
|
||||
"""
|
||||
|
||||
|
||||
import factory
|
||||
from factory import SubFactory
|
||||
from factory.django import DjangoModelFactory
|
||||
from faker import Factory as FakerFactory
|
||||
|
||||
from openedx.core.djangoapps.site_configuration.tests.factories import SiteFactory
|
||||
from common.djangoapps.third_party_auth.models import SAMLConfiguration, SAMLProviderConfig
|
||||
from openedx.core.djangoapps.site_configuration.tests.factories import SiteFactory
|
||||
|
||||
FAKER = FakerFactory.create()
|
||||
|
||||
|
||||
class SAMLConfigurationFactory(DjangoModelFactory):
|
||||
@@ -32,8 +35,8 @@ class SAMLProviderConfigFactory(DjangoModelFactory):
|
||||
site = SubFactory(SiteFactory)
|
||||
|
||||
enabled = True
|
||||
slug = "test-shib"
|
||||
name = "TestShib College"
|
||||
slug = factory.LazyAttribute(lambda x: FAKER.slug())
|
||||
name = factory.LazyAttribute(lambda x: FAKER.company())
|
||||
|
||||
entity_id = "https://idp.testshib.org/idp/shibboleth"
|
||||
metadata_source = "https://www.testshib.org/metadata/testshib-providers.xml"
|
||||
entity_id = factory.LazyAttribute(lambda x: FAKER.uri())
|
||||
metadata_source = factory.LazyAttribute(lambda x: FAKER.uri())
|
||||
|
||||
@@ -85,7 +85,7 @@ class RegistryTest(testutil.TestCase):
|
||||
self.enable_saml()
|
||||
provider_count = 5
|
||||
for i in range(provider_count):
|
||||
self.configure_saml_provider(enabled=True, slug="saml-slug-%s" % i)
|
||||
self.configure_saml_provider(enabled=True, slug=f"saml-slug-{i}", entity_id=f"saml-entity-id-{i}")
|
||||
|
||||
with CaptureQueriesContext(connections[DEFAULT_DB_ALIAS]) as cq:
|
||||
enabled_slugs = {p.slug for p in provider.Registry.enabled()}
|
||||
|
||||
@@ -172,16 +172,25 @@ def get_user_from_email(details):
|
||||
|
||||
def create_or_update_bulk_saml_provider_data(entity_id, public_keys, sso_url, expires_at):
|
||||
"""
|
||||
Placeholder
|
||||
Method to bulk update or create provider data entries
|
||||
"""
|
||||
fetched_at = now()
|
||||
new_records_created = False
|
||||
# Create a data record for each of the public keys provided
|
||||
for key in public_keys:
|
||||
_, created = SAMLProviderData.objects.update_or_create(
|
||||
public_key=key, entity_id=entity_id,
|
||||
defaults={'sso_url': sso_url, 'expires_at': expires_at, 'fetched_at': fetched_at},
|
||||
)
|
||||
existing_data_objects = SAMLProviderData.objects.filter(public_key=key, entity_id=entity_id)
|
||||
if len(existing_data_objects) > 1:
|
||||
for obj in existing_data_objects:
|
||||
obj.sso_url = sso_url
|
||||
obj.expires_at = expires_at
|
||||
obj.fetched_at = fetched_at
|
||||
SAMLProviderData.objects.bulk_update(existing_data_objects, ['sso_url', 'expires_at', 'fetched_at'])
|
||||
return True
|
||||
else:
|
||||
_, created = SAMLProviderData.objects.update_or_create(
|
||||
public_key=key, entity_id=entity_id,
|
||||
defaults={'sso_url': sso_url, 'expires_at': expires_at, 'fetched_at': fetched_at},
|
||||
)
|
||||
if created:
|
||||
new_records_created = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user