ENT-3160 Automate association to customer on SAMLProviderConfig creation (#24519)
* create links ProviderConfig to EnterpriseCustomer * lint * remove extraneous print * don't create samlprovider unless enterprise found, update a test to use valid uuid and fail request * fix test for correct status code as was intended
This commit is contained in:
@@ -31,6 +31,7 @@ SINGLE_PROVIDER_CONFIG_2['name'] = 'name-of-config-2'
|
||||
SINGLE_PROVIDER_CONFIG_2['slug'] = 'test-slug-2'
|
||||
|
||||
ENTERPRISE_ID = str(uuid4())
|
||||
ENTERPRISE_ID_NON_EXISTENT = str(uuid4())
|
||||
|
||||
|
||||
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
|
||||
@@ -52,11 +53,8 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
site=cls.site)
|
||||
cls.samlproviderconfig, _ = SAMLProviderConfig.objects.get_or_create(
|
||||
entity_id=SINGLE_PROVIDER_CONFIG['entity_id'],
|
||||
metadata_source=SINGLE_PROVIDER_CONFIG['metadata_source']
|
||||
)
|
||||
cls.enterprisecustomeridp, _ = EnterpriseCustomerIdentityProvider.objects.get_or_create(
|
||||
provider_id=cls.samlproviderconfig.id,
|
||||
enterprise_customer_id=ENTERPRISE_ID
|
||||
metadata_source=SINGLE_PROVIDER_CONFIG['metadata_source'],
|
||||
slug=SINGLE_PROVIDER_CONFIG['slug']
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
@@ -67,6 +65,12 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
"""
|
||||
GET auth/saml/v0/provider_config/?enterprise_customer_uuid=id=id
|
||||
"""
|
||||
|
||||
# for GET to work, we need an association present
|
||||
EnterpriseCustomerIdentityProvider.objects.get_or_create(
|
||||
provider_id=self.samlproviderconfig.slug,
|
||||
enterprise_customer_id=ENTERPRISE_ID
|
||||
)
|
||||
urlbase = reverse('saml_provider_config-list')
|
||||
query_kwargs = {'enterprise_customer_uuid': ENTERPRISE_ID}
|
||||
url = '{}?{}'.format(urlbase, urlencode(query_kwargs))
|
||||
@@ -94,16 +98,22 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
|
||||
def test_get_one_config_by_enterprise_uuid_not_found(self):
|
||||
"""
|
||||
GET auth/saml/v0/provider_config/?enterprise_customer_uuid=id=id
|
||||
GET auth/saml/v0/provider_config/?enterprise_customer_uuid=valid-but-nonexistent-uuid
|
||||
"""
|
||||
|
||||
# the user must actually be authorized for this enterprise
|
||||
# since we are testing auth passes but association to samlproviderconfig is not found
|
||||
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, ENTERPRISE_ID_NON_EXISTENT)])
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
urlbase = reverse('saml_provider_config-list')
|
||||
query_kwargs = {'enterprise_customer_uuid': 'abc-notfound'}
|
||||
query_kwargs = {'enterprise_customer_uuid': ENTERPRISE_ID_NON_EXISTENT}
|
||||
url = '{}?{}'.format(urlbase, urlencode(query_kwargs))
|
||||
orig_count = SAMLProviderConfig.objects.count()
|
||||
|
||||
response = self.client.get(url, format='json')
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
self.assertEqual(SAMLProviderConfig.objects.count(), orig_count)
|
||||
|
||||
def test_create_one_config(self):
|
||||
@@ -122,6 +132,36 @@ class SAMLProviderConfigTests(APITestCase):
|
||||
provider_config = SAMLProviderConfig.objects.get(slug=SINGLE_PROVIDER_CONFIG_2['slug'])
|
||||
self.assertEqual(provider_config.name, 'name-of-config-2')
|
||||
|
||||
# check association has also been created
|
||||
self.assertTrue(
|
||||
EnterpriseCustomerIdentityProvider.objects.filter(
|
||||
provider_id=provider_config.slug
|
||||
).exists(),
|
||||
'Cannot find EnterpriseCustomer-->SAMLProviderConfig association'
|
||||
)
|
||||
|
||||
def test_create_one_config_fail_non_existent_enterprise_uuid(self):
|
||||
"""
|
||||
POST auth/saml/v0/provider_config/ -d data
|
||||
"""
|
||||
url = reverse('saml_provider_config-list')
|
||||
data = copy.copy(SINGLE_PROVIDER_CONFIG_2)
|
||||
data['enterprise_customer_uuid'] = ENTERPRISE_ID_NON_EXISTENT
|
||||
orig_count = SAMLProviderConfig.objects.count()
|
||||
|
||||
response = self.client.post(url, data)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
self.assertEqual(SAMLProviderConfig.objects.count(), orig_count)
|
||||
|
||||
# check association has NOT been created
|
||||
self.assertFalse(
|
||||
EnterpriseCustomerIdentityProvider.objects.filter(
|
||||
provider_id=SINGLE_PROVIDER_CONFIG_2['slug']
|
||||
).exists(),
|
||||
'Did not expect to find EnterpriseCustomer-->SAMLProviderConfig association'
|
||||
)
|
||||
|
||||
def test_create_one_config_with_absent_enterprise_uuid(self):
|
||||
"""
|
||||
POST auth/saml/v0/provider_config/ -d data
|
||||
|
||||
@@ -5,11 +5,12 @@ Viewset for auth/saml/v0/samlproviderconfig
|
||||
from django.shortcuts import get_object_or_404
|
||||
from edx_rbac.mixins import PermissionRequiredMixin
|
||||
from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication
|
||||
from rest_framework import permissions, viewsets
|
||||
from rest_framework import permissions, viewsets, status
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.authentication import SessionAuthentication
|
||||
from rest_framework.exceptions import ParseError
|
||||
from rest_framework.exceptions import ParseError, ValidationError
|
||||
|
||||
from enterprise.models import EnterpriseCustomerIdentityProvider
|
||||
from enterprise.models import EnterpriseCustomerIdentityProvider, EnterpriseCustomer
|
||||
from third_party_auth.utils import validate_uuid4_string
|
||||
|
||||
from ..models import SAMLProviderConfig
|
||||
@@ -56,7 +57,7 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view
|
||||
EnterpriseCustomerIdentityProvider,
|
||||
enterprise_customer__uuid=self.requested_enterprise_uuid
|
||||
)
|
||||
return SAMLProviderConfig.objects.filter(pk=enterprise_customer_idp.provider_id)
|
||||
return SAMLProviderConfig.objects.filter(slug=enterprise_customer_idp.provider_id)
|
||||
|
||||
@property
|
||||
def requested_enterprise_uuid(self):
|
||||
@@ -82,3 +83,29 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view
|
||||
can access these endpoints, we have to sort out the operator role use case
|
||||
"""
|
||||
return self.requested_enterprise_uuid
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
"""
|
||||
Process POST /auth/saml/v0/provider_config/ {postData}
|
||||
"""
|
||||
|
||||
customer_uuid = self.requested_enterprise_uuid
|
||||
try:
|
||||
enterprise_customer = EnterpriseCustomer.objects.get(pk=customer_uuid)
|
||||
except EnterpriseCustomer.DoesNotExist:
|
||||
raise ValidationError('Enterprise customer not found at uuid: {}'.format(customer_uuid))
|
||||
|
||||
# Create the samlproviderconfig model first
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_create(serializer)
|
||||
|
||||
# Associate the enterprise customer with the provider
|
||||
association_obj = EnterpriseCustomerIdentityProvider(
|
||||
enterprise_customer=enterprise_customer,
|
||||
provider_id=serializer.data['slug']
|
||||
)
|
||||
association_obj.save()
|
||||
|
||||
headers = self.get_success_headers(serializer.data)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
|
||||
|
||||
Reference in New Issue
Block a user