From 605a328f2496cf810d0c2841207ccb4e71b1608c Mon Sep 17 00:00:00 2001 From: Alexander Sheehan Date: Mon, 27 Jun 2022 11:25:12 -0400 Subject: [PATCH] fix: accounting for only current configs when checking for uniqueness --- common/djangoapps/third_party_auth/models.py | 2 +- .../samlproviderconfig/serializers.py | 5 ++- .../third_party_auth/tests/test_models.py | 40 +++++++++++++++++++ .../program_enrollments/api/reading.py | 8 +++- .../api/tests/test_reading.py | 11 ++++- 5 files changed, 62 insertions(+), 4 deletions(-) create mode 100644 common/djangoapps/third_party_auth/tests/test_models.py diff --git a/common/djangoapps/third_party_auth/models.py b/common/djangoapps/third_party_auth/models.py index f5c5133cf3..65109a4928 100644 --- a/common/djangoapps/third_party_auth/models.py +++ b/common/djangoapps/third_party_auth/models.py @@ -739,7 +739,7 @@ class SAMLProviderConfig(ProviderConfig): # creating configs that share entity ID's with other enterprises # One consequence of this is that once a provider configuration is created, the slug is essentially locked in # and unchangeable. But I blame that on bad old architecture. - existing_provider_configs = SAMLProviderConfig.objects.filter( + existing_provider_configs = SAMLProviderConfig.objects.current_set().filter( entity_id=self.entity_id, archived=False, ).exclude(slug=self.slug) diff --git a/common/djangoapps/third_party_auth/samlproviderconfig/serializers.py b/common/djangoapps/third_party_auth/samlproviderconfig/serializers.py index 6e3cdbe592..97ed8f1976 100644 --- a/common/djangoapps/third_party_auth/samlproviderconfig/serializers.py +++ b/common/djangoapps/third_party_auth/samlproviderconfig/serializers.py @@ -22,7 +22,10 @@ class SAMLProviderConfigSerializer(serializers.ModelSerializer): # lint-amnesty # are not archived, raise a validation error. We do this to prevent provider configs from sharing entity ID's # which link a provider config to provider data (SAML certificates). An entity ID therefore, is uniquely linked # to a single slug/provider config (which in the case of enterprise provider slug == customer slug). - if SAMLProviderConfig.objects.filter(entity_id=data['entity_id'], archived=False).exclude(slug=data['slug']): + if SAMLProviderConfig.objects.current_set().filter( + entity_id=data['entity_id'], + archived=False, + ).exclude(slug=data['slug']): raise serializers.ValidationError(f"Entity ID: {data['entity_id']} already taken") return data diff --git a/common/djangoapps/third_party_auth/tests/test_models.py b/common/djangoapps/third_party_auth/tests/test_models.py new file mode 100644 index 0000000000..4d9c1c946e --- /dev/null +++ b/common/djangoapps/third_party_auth/tests/test_models.py @@ -0,0 +1,40 @@ +""" +Tests for third_party_auth/models.py. +""" +import pytest +from django.test import TestCase +from django.db.utils import IntegrityError + +from .factories import SAMLProviderConfigFactory +from ..models import SAMLProviderConfig + + +class TestSamlProviderConfigModel(TestCase): + """ + Test model operations for the saml provider config model. + """ + + def setUp(self): + super().setUp() + self.saml_provider_config = SAMLProviderConfigFactory() + + def test_unique_entity_id_enforcement_for_non_current_configs(self): + """ + Test that the unique entity ID enforcement does not apply to noncurrent configs + """ + assert len(SAMLProviderConfig.objects.all()) == 1 + old_entity_id = self.saml_provider_config.entity_id + self.saml_provider_config.entity_id = f'{self.saml_provider_config.entity_id}-ayylmao' + self.saml_provider_config.save() + + # check that we now have two records, one non-current + assert len(SAMLProviderConfig.objects.all()) == 2 + assert len(SAMLProviderConfig.objects.current_set()) == 1 + + # Make sure we can use that old entity id + SAMLProviderConfigFactory(entity_id=old_entity_id) + + # Now if we try and create a new model using a current entity ID then it should throw the integrity error + with pytest.raises(IntegrityError): + bad_config = SAMLProviderConfig(entity_id=self.saml_provider_config.entity_id) + bad_config.save() diff --git a/lms/djangoapps/program_enrollments/api/reading.py b/lms/djangoapps/program_enrollments/api/reading.py index 2b2984e870..b84b83845d 100644 --- a/lms/djangoapps/program_enrollments/api/reading.py +++ b/lms/djangoapps/program_enrollments/api/reading.py @@ -564,6 +564,12 @@ def get_saml_providers_for_organization(organization): return list(provider_configs) +def remove_prefix(text, prefix): + if text.startswith(prefix): + return text[len(prefix):] + return text + + def get_provider_slug(provider_config): """ Returns slug identifying a SAML provider. @@ -573,7 +579,7 @@ def get_provider_slug(provider_config): Returns: str """ - return provider_config.provider_id.strip('saml-') + return remove_prefix(provider_config.provider_id, 'saml-') def is_course_staff_enrollment(program_course_enrollment): diff --git a/lms/djangoapps/program_enrollments/api/tests/test_reading.py b/lms/djangoapps/program_enrollments/api/tests/test_reading.py index f1fcbe87a1..d1fe0937b3 100644 --- a/lms/djangoapps/program_enrollments/api/tests/test_reading.py +++ b/lms/djangoapps/program_enrollments/api/tests/test_reading.py @@ -47,7 +47,8 @@ from ..reading import ( get_program_course_enrollment, get_program_enrollment, get_users_by_external_keys, - is_course_staff_enrollment + is_course_staff_enrollment, + get_provider_slug, ) User = get_user_model() @@ -803,3 +804,11 @@ class IsCourseStaffEnrollmentTest(TestCase): id=program_course_enrollment_id ) assert is_course_staff == is_course_staff_enrollment(program_course_enrollment) + + def test_get_provider_slug_correctly_strips(self): + list_of_providers = [] + for num_provider in range(1000): + list_of_providers.append(SAMLProviderConfigFactory(entity_id=str(num_provider))) + + for provider in list_of_providers: + assert provider.slug == get_provider_slug(provider)