BUG: fixes for saml provider config/data lookup

This commit is contained in:
Talia
2020-07-30 10:05:11 -04:00
parent e4f28debb7
commit 2b956c54a0
6 changed files with 47 additions and 9 deletions

View File

@@ -17,6 +17,7 @@ from enterprise.constants import ENTERPRISE_ADMIN_ROLE, ENTERPRISE_LEARNER_ROLE
from third_party_auth.tests.samlutils import set_jwt_cookie
from third_party_auth.models import SAMLProviderConfig
from third_party_auth.tests import testutil
from third_party_auth.utils import convert_saml_slug_provider_id
# country here refers to the URN provided by a user's IDP
SINGLE_PROVIDER_CONFIG = {
@@ -71,7 +72,7 @@ class SAMLProviderConfigTests(APITestCase):
# for GET to work, we need an association present
EnterpriseCustomerIdentityProvider.objects.get_or_create(
provider_id=self.samlproviderconfig.slug,
provider_id=convert_saml_slug_provider_id(self.samlproviderconfig.slug),
enterprise_customer_id=ENTERPRISE_ID
)
urlbase = reverse('saml_provider_config-list')
@@ -140,7 +141,7 @@ class SAMLProviderConfigTests(APITestCase):
# check association has also been created
self.assertTrue(
EnterpriseCustomerIdentityProvider.objects.filter(
provider_id=provider_config.slug
provider_id=convert_saml_slug_provider_id(provider_config.slug)
).exists(),
'Cannot find EnterpriseCustomer-->SAMLProviderConfig association'
)
@@ -162,7 +163,7 @@ class SAMLProviderConfigTests(APITestCase):
# check association has NOT been created
self.assertFalse(
EnterpriseCustomerIdentityProvider.objects.filter(
provider_id=SINGLE_PROVIDER_CONFIG_2['slug']
provider_id=convert_saml_slug_provider_id(SINGLE_PROVIDER_CONFIG_2['slug'])
).exists(),
'Did not expect to find EnterpriseCustomer-->SAMLProviderConfig association'
)

View File

@@ -15,6 +15,7 @@ from third_party_auth.utils import validate_uuid4_string
from ..models import SAMLProviderConfig
from .serializers import SAMLProviderConfigSerializer
from ..utils import convert_saml_slug_provider_id
class SAMLProviderMixin(object):
@@ -57,7 +58,8 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view
EnterpriseCustomerIdentityProvider,
enterprise_customer__uuid=self.requested_enterprise_uuid
)
return SAMLProviderConfig.objects.current_set().filter(slug=enterprise_customer_idp.provider_id)
return SAMLProviderConfig.objects.current_set().filter(
slug=convert_saml_slug_provider_id(enterprise_customer_idp.provider_id))
@property
def requested_enterprise_uuid(self):
@@ -103,7 +105,7 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view
# Associate the enterprise customer with the provider
association_obj = EnterpriseCustomerIdentityProvider(
enterprise_customer=enterprise_customer,
provider_id=serializer.data['slug']
provider_id=convert_saml_slug_provider_id(serializer.data['slug'])
)
association_obj.save()

View File

@@ -16,6 +16,7 @@ from enterprise.constants import ENTERPRISE_ADMIN_ROLE, ENTERPRISE_LEARNER_ROLE
from third_party_auth.tests import testutil
from third_party_auth.models import SAMLProviderData, SAMLProviderConfig
from third_party_auth.tests.samlutils import set_jwt_cookie
from third_party_auth.utils import convert_saml_slug_provider_id
SINGLE_PROVIDER_CONFIG = {
'entity_id': 'http://entity-id-1',
@@ -68,7 +69,7 @@ class SAMLProviderDataTests(APITestCase):
fetched_at=SINGLE_PROVIDER_DATA['fetched_at']
)
cls.enterprise_customer_idp, _ = EnterpriseCustomerIdentityProvider.objects.get_or_create(
provider_id=cls.saml_provider_config.slug,
provider_id=convert_saml_slug_provider_id(cls.saml_provider_config.slug),
enterprise_customer_id=ENTERPRISE_ID
)

View File

@@ -11,7 +11,7 @@ from rest_framework.authentication import SessionAuthentication
from rest_framework.exceptions import ParseError
from enterprise.models import EnterpriseCustomerIdentityProvider
from third_party_auth.utils import validate_uuid4_string
from third_party_auth.utils import validate_uuid4_string, convert_saml_slug_provider_id
from ..models import SAMLProviderConfig, SAMLProviderData
from .serializers import SAMLProviderDataSerializer
@@ -54,7 +54,8 @@ class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, vi
enterprise_customer__uuid=self.requested_enterprise_uuid
)
try:
saml_provider = SAMLProviderConfig.objects.current_set().get(slug=enterprise_customer_idp.provider_id)
saml_provider = SAMLProviderConfig.objects.current_set().get(
slug=convert_saml_slug_provider_id(enterprise_customer_idp.provider_id))
except SAMLProviderConfig.DoesNotExist:
raise Http404('No matching SAML provider found.')
return SAMLProviderData.objects.filter(entity_id=saml_provider.entity_id)

View File

@@ -9,7 +9,7 @@ from django.conf import settings
from student.tests.factories import UserFactory
from third_party_auth.tests.testutil import TestCase
from third_party_auth.utils import user_exists
from third_party_auth.utils import user_exists, convert_saml_slug_provider_id
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
@@ -38,3 +38,18 @@ class TestUtils(TestCase):
self.assertTrue(
user_exists({'username': 'TesT_User'})
)
def test_convert_saml_slug_provider_id(self):
"""
Verify saml provider id/slug map to each other correctly.
"""
provider_names = {'saml-samltest': 'samltest', 'saml-example': 'example'}
for provider_id in provider_names:
# provider_id -> slug
self.assertEqual(
convert_saml_slug_provider_id(provider_id), provider_names[provider_id]
)
# slug -> provider_id
self.assertEqual(
convert_saml_slug_provider_id(provider_names[provider_id]), provider_id
)

View File

@@ -30,6 +30,24 @@ def user_exists(details):
return False
def convert_saml_slug_provider_id(provider):
"""
Provider id is stored with the backend type prefixed to it (ie "saml-")
Slug is stored without this prefix.
This just converts between them whenever you expect the opposite of what you currently have.
Arguments:
provider (string): provider_id or slug
Returns:
(string): Opposite of what you inputted (slug -> provider_id; provider_id -> slug)
"""
if provider.startswith('saml-'):
return provider[5:]
else:
return 'saml-' + provider
def validate_uuid4_string(uuid_string):
"""
Returns True if valid uuid4 string, or False