BUG: fixes for saml provider config/data lookup
This commit is contained in:
@@ -17,6 +17,7 @@ from enterprise.constants import ENTERPRISE_ADMIN_ROLE, ENTERPRISE_LEARNER_ROLE
|
||||
from third_party_auth.tests.samlutils import set_jwt_cookie
|
||||
from third_party_auth.models import SAMLProviderConfig
|
||||
from third_party_auth.tests import testutil
|
||||
from third_party_auth.utils import convert_saml_slug_provider_id
|
||||
|
||||
# country here refers to the URN provided by a user's IDP
|
||||
SINGLE_PROVIDER_CONFIG = {
|
||||
@@ -71,7 +72,7 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
|
||||
# for GET to work, we need an association present
|
||||
EnterpriseCustomerIdentityProvider.objects.get_or_create(
|
||||
provider_id=self.samlproviderconfig.slug,
|
||||
provider_id=convert_saml_slug_provider_id(self.samlproviderconfig.slug),
|
||||
enterprise_customer_id=ENTERPRISE_ID
|
||||
)
|
||||
urlbase = reverse('saml_provider_config-list')
|
||||
@@ -140,7 +141,7 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
# check association has also been created
|
||||
self.assertTrue(
|
||||
EnterpriseCustomerIdentityProvider.objects.filter(
|
||||
provider_id=provider_config.slug
|
||||
provider_id=convert_saml_slug_provider_id(provider_config.slug)
|
||||
).exists(),
|
||||
'Cannot find EnterpriseCustomer-->SAMLProviderConfig association'
|
||||
)
|
||||
@@ -162,7 +163,7 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
# check association has NOT been created
|
||||
self.assertFalse(
|
||||
EnterpriseCustomerIdentityProvider.objects.filter(
|
||||
provider_id=SINGLE_PROVIDER_CONFIG_2['slug']
|
||||
provider_id=convert_saml_slug_provider_id(SINGLE_PROVIDER_CONFIG_2['slug'])
|
||||
).exists(),
|
||||
'Did not expect to find EnterpriseCustomer-->SAMLProviderConfig association'
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@ from third_party_auth.utils import validate_uuid4_string
|
||||
|
||||
from ..models import SAMLProviderConfig
|
||||
from .serializers import SAMLProviderConfigSerializer
|
||||
from ..utils import convert_saml_slug_provider_id
|
||||
|
||||
|
||||
class SAMLProviderMixin(object):
|
||||
@@ -57,7 +58,8 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view
|
||||
EnterpriseCustomerIdentityProvider,
|
||||
enterprise_customer__uuid=self.requested_enterprise_uuid
|
||||
)
|
||||
return SAMLProviderConfig.objects.current_set().filter(slug=enterprise_customer_idp.provider_id)
|
||||
return SAMLProviderConfig.objects.current_set().filter(
|
||||
slug=convert_saml_slug_provider_id(enterprise_customer_idp.provider_id))
|
||||
|
||||
@property
|
||||
def requested_enterprise_uuid(self):
|
||||
@@ -103,7 +105,7 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view
|
||||
# Associate the enterprise customer with the provider
|
||||
association_obj = EnterpriseCustomerIdentityProvider(
|
||||
enterprise_customer=enterprise_customer,
|
||||
provider_id=serializer.data['slug']
|
||||
provider_id=convert_saml_slug_provider_id(serializer.data['slug'])
|
||||
)
|
||||
association_obj.save()
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from enterprise.constants import ENTERPRISE_ADMIN_ROLE, ENTERPRISE_LEARNER_ROLE
|
||||
from third_party_auth.tests import testutil
|
||||
from third_party_auth.models import SAMLProviderData, SAMLProviderConfig
|
||||
from third_party_auth.tests.samlutils import set_jwt_cookie
|
||||
from third_party_auth.utils import convert_saml_slug_provider_id
|
||||
|
||||
SINGLE_PROVIDER_CONFIG = {
|
||||
'entity_id': 'http://entity-id-1',
|
||||
@@ -68,7 +69,7 @@ class SAMLProviderDataTests(APITestCase):
|
||||
fetched_at=SINGLE_PROVIDER_DATA['fetched_at']
|
||||
)
|
||||
cls.enterprise_customer_idp, _ = EnterpriseCustomerIdentityProvider.objects.get_or_create(
|
||||
provider_id=cls.saml_provider_config.slug,
|
||||
provider_id=convert_saml_slug_provider_id(cls.saml_provider_config.slug),
|
||||
enterprise_customer_id=ENTERPRISE_ID
|
||||
)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from rest_framework.authentication import SessionAuthentication
|
||||
from rest_framework.exceptions import ParseError
|
||||
|
||||
from enterprise.models import EnterpriseCustomerIdentityProvider
|
||||
from third_party_auth.utils import validate_uuid4_string
|
||||
from third_party_auth.utils import validate_uuid4_string, convert_saml_slug_provider_id
|
||||
|
||||
from ..models import SAMLProviderConfig, SAMLProviderData
|
||||
from .serializers import SAMLProviderDataSerializer
|
||||
@@ -54,7 +54,8 @@ class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, vi
|
||||
enterprise_customer__uuid=self.requested_enterprise_uuid
|
||||
)
|
||||
try:
|
||||
saml_provider = SAMLProviderConfig.objects.current_set().get(slug=enterprise_customer_idp.provider_id)
|
||||
saml_provider = SAMLProviderConfig.objects.current_set().get(
|
||||
slug=convert_saml_slug_provider_id(enterprise_customer_idp.provider_id))
|
||||
except SAMLProviderConfig.DoesNotExist:
|
||||
raise Http404('No matching SAML provider found.')
|
||||
return SAMLProviderData.objects.filter(entity_id=saml_provider.entity_id)
|
||||
|
||||
@@ -9,7 +9,7 @@ from django.conf import settings
|
||||
|
||||
from student.tests.factories import UserFactory
|
||||
from third_party_auth.tests.testutil import TestCase
|
||||
from third_party_auth.utils import user_exists
|
||||
from third_party_auth.utils import user_exists, convert_saml_slug_provider_id
|
||||
|
||||
|
||||
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
|
||||
@@ -38,3 +38,18 @@ class TestUtils(TestCase):
|
||||
self.assertTrue(
|
||||
user_exists({'username': 'TesT_User'})
|
||||
)
|
||||
|
||||
def test_convert_saml_slug_provider_id(self):
|
||||
"""
|
||||
Verify saml provider id/slug map to each other correctly.
|
||||
"""
|
||||
provider_names = {'saml-samltest': 'samltest', 'saml-example': 'example'}
|
||||
for provider_id in provider_names:
|
||||
# provider_id -> slug
|
||||
self.assertEqual(
|
||||
convert_saml_slug_provider_id(provider_id), provider_names[provider_id]
|
||||
)
|
||||
# slug -> provider_id
|
||||
self.assertEqual(
|
||||
convert_saml_slug_provider_id(provider_names[provider_id]), provider_id
|
||||
)
|
||||
|
||||
@@ -30,6 +30,24 @@ def user_exists(details):
|
||||
return False
|
||||
|
||||
|
||||
def convert_saml_slug_provider_id(provider):
|
||||
"""
|
||||
Provider id is stored with the backend type prefixed to it (ie "saml-")
|
||||
Slug is stored without this prefix.
|
||||
This just converts between them whenever you expect the opposite of what you currently have.
|
||||
|
||||
Arguments:
|
||||
provider (string): provider_id or slug
|
||||
|
||||
Returns:
|
||||
(string): Opposite of what you inputted (slug -> provider_id; provider_id -> slug)
|
||||
"""
|
||||
if provider.startswith('saml-'):
|
||||
return provider[5:]
|
||||
else:
|
||||
return 'saml-' + provider
|
||||
|
||||
|
||||
def validate_uuid4_string(uuid_string):
|
||||
"""
|
||||
Returns True if valid uuid4 string, or False
|
||||
|
||||
Reference in New Issue
Block a user