diff --git a/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py b/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py index 3d14636d5d..aa48e0d094 100644 --- a/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py +++ b/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py @@ -17,6 +17,7 @@ from enterprise.constants import ENTERPRISE_ADMIN_ROLE, ENTERPRISE_LEARNER_ROLE from third_party_auth.tests.samlutils import set_jwt_cookie from third_party_auth.models import SAMLProviderConfig from third_party_auth.tests import testutil +from third_party_auth.utils import convert_saml_slug_provider_id # country here refers to the URN provided by a user's IDP SINGLE_PROVIDER_CONFIG = { @@ -71,7 +72,7 @@ class SAMLProviderConfigTests(APITestCase): # for GET to work, we need an association present EnterpriseCustomerIdentityProvider.objects.get_or_create( - provider_id=self.samlproviderconfig.slug, + provider_id=convert_saml_slug_provider_id(self.samlproviderconfig.slug), enterprise_customer_id=ENTERPRISE_ID ) urlbase = reverse('saml_provider_config-list') @@ -140,7 +141,7 @@ class SAMLProviderConfigTests(APITestCase): # check association has also been created self.assertTrue( EnterpriseCustomerIdentityProvider.objects.filter( - provider_id=provider_config.slug + provider_id=convert_saml_slug_provider_id(provider_config.slug) ).exists(), 'Cannot find EnterpriseCustomer-->SAMLProviderConfig association' ) @@ -162,7 +163,7 @@ class SAMLProviderConfigTests(APITestCase): # check association has NOT been created self.assertFalse( EnterpriseCustomerIdentityProvider.objects.filter( - provider_id=SINGLE_PROVIDER_CONFIG_2['slug'] + provider_id=convert_saml_slug_provider_id(SINGLE_PROVIDER_CONFIG_2['slug']) ).exists(), 'Did not expect to find EnterpriseCustomer-->SAMLProviderConfig association' ) diff --git a/common/djangoapps/third_party_auth/samlproviderconfig/views.py b/common/djangoapps/third_party_auth/samlproviderconfig/views.py index 26de21607e..605864a350 100644 --- a/common/djangoapps/third_party_auth/samlproviderconfig/views.py +++ b/common/djangoapps/third_party_auth/samlproviderconfig/views.py @@ -15,6 +15,7 @@ from third_party_auth.utils import validate_uuid4_string from ..models import SAMLProviderConfig from .serializers import SAMLProviderConfigSerializer +from ..utils import convert_saml_slug_provider_id class SAMLProviderMixin(object): @@ -57,7 +58,8 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view EnterpriseCustomerIdentityProvider, enterprise_customer__uuid=self.requested_enterprise_uuid ) - return SAMLProviderConfig.objects.current_set().filter(slug=enterprise_customer_idp.provider_id) + return SAMLProviderConfig.objects.current_set().filter( + slug=convert_saml_slug_provider_id(enterprise_customer_idp.provider_id)) @property def requested_enterprise_uuid(self): @@ -103,7 +105,7 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view # Associate the enterprise customer with the provider association_obj = EnterpriseCustomerIdentityProvider( enterprise_customer=enterprise_customer, - provider_id=serializer.data['slug'] + provider_id=convert_saml_slug_provider_id(serializer.data['slug']) ) association_obj.save() diff --git a/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py b/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py index 7ba16ab72c..99eac25301 100644 --- a/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py +++ b/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py @@ -16,6 +16,7 @@ from enterprise.constants import ENTERPRISE_ADMIN_ROLE, ENTERPRISE_LEARNER_ROLE from third_party_auth.tests import testutil from third_party_auth.models import SAMLProviderData, SAMLProviderConfig from third_party_auth.tests.samlutils import set_jwt_cookie +from third_party_auth.utils import convert_saml_slug_provider_id SINGLE_PROVIDER_CONFIG = { 'entity_id': 'http://entity-id-1', @@ -68,7 +69,7 @@ class SAMLProviderDataTests(APITestCase): fetched_at=SINGLE_PROVIDER_DATA['fetched_at'] ) cls.enterprise_customer_idp, _ = EnterpriseCustomerIdentityProvider.objects.get_or_create( - provider_id=cls.saml_provider_config.slug, + provider_id=convert_saml_slug_provider_id(cls.saml_provider_config.slug), enterprise_customer_id=ENTERPRISE_ID ) diff --git a/common/djangoapps/third_party_auth/samlproviderdata/views.py b/common/djangoapps/third_party_auth/samlproviderdata/views.py index e569cfe4a1..eacab0a07e 100644 --- a/common/djangoapps/third_party_auth/samlproviderdata/views.py +++ b/common/djangoapps/third_party_auth/samlproviderdata/views.py @@ -11,7 +11,7 @@ from rest_framework.authentication import SessionAuthentication from rest_framework.exceptions import ParseError from enterprise.models import EnterpriseCustomerIdentityProvider -from third_party_auth.utils import validate_uuid4_string +from third_party_auth.utils import validate_uuid4_string, convert_saml_slug_provider_id from ..models import SAMLProviderConfig, SAMLProviderData from .serializers import SAMLProviderDataSerializer @@ -54,7 +54,8 @@ class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, vi enterprise_customer__uuid=self.requested_enterprise_uuid ) try: - saml_provider = SAMLProviderConfig.objects.current_set().get(slug=enterprise_customer_idp.provider_id) + saml_provider = SAMLProviderConfig.objects.current_set().get( + slug=convert_saml_slug_provider_id(enterprise_customer_idp.provider_id)) except SAMLProviderConfig.DoesNotExist: raise Http404('No matching SAML provider found.') return SAMLProviderData.objects.filter(entity_id=saml_provider.entity_id) diff --git a/common/djangoapps/third_party_auth/tests/test_utils.py b/common/djangoapps/third_party_auth/tests/test_utils.py index c0e21f1b47..7d18e64e7e 100644 --- a/common/djangoapps/third_party_auth/tests/test_utils.py +++ b/common/djangoapps/third_party_auth/tests/test_utils.py @@ -9,7 +9,7 @@ from django.conf import settings from student.tests.factories import UserFactory from third_party_auth.tests.testutil import TestCase -from third_party_auth.utils import user_exists +from third_party_auth.utils import user_exists, convert_saml_slug_provider_id @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') @@ -38,3 +38,18 @@ class TestUtils(TestCase): self.assertTrue( user_exists({'username': 'TesT_User'}) ) + + def test_convert_saml_slug_provider_id(self): + """ + Verify saml provider id/slug map to each other correctly. + """ + provider_names = {'saml-samltest': 'samltest', 'saml-example': 'example'} + for provider_id in provider_names: + # provider_id -> slug + self.assertEqual( + convert_saml_slug_provider_id(provider_id), provider_names[provider_id] + ) + # slug -> provider_id + self.assertEqual( + convert_saml_slug_provider_id(provider_names[provider_id]), provider_id + ) diff --git a/common/djangoapps/third_party_auth/utils.py b/common/djangoapps/third_party_auth/utils.py index ee525f0f33..457ff47726 100644 --- a/common/djangoapps/third_party_auth/utils.py +++ b/common/djangoapps/third_party_auth/utils.py @@ -30,6 +30,24 @@ def user_exists(details): return False +def convert_saml_slug_provider_id(provider): + """ + Provider id is stored with the backend type prefixed to it (ie "saml-") + Slug is stored without this prefix. + This just converts between them whenever you expect the opposite of what you currently have. + + Arguments: + provider (string): provider_id or slug + + Returns: + (string): Opposite of what you inputted (slug -> provider_id; provider_id -> slug) + """ + if provider.startswith('saml-'): + return provider[5:] + else: + return 'saml-' + provider + + def validate_uuid4_string(uuid_string): """ Returns True if valid uuid4 string, or False