From 19f82258aa9714d74f662bdc8c0286ed6a691979 Mon Sep 17 00:00:00 2001 From: Binod Pant Date: Mon, 20 Jul 2020 12:35:04 -0400 Subject: [PATCH] ENT-3160 Automate association to customer on SAMLProviderConfig creation (#24519) * create links ProviderConfig to EnterpriseCustomer * lint * remove extraneous print * don't create samlprovider unless enterprise found, update a test to use valid uuid and fail request * fix test for correct status code as was intended --- .../tests/test_samlproviderconfig.py | 56 ++++++++++++++++--- .../samlproviderconfig/views.py | 35 ++++++++++-- 2 files changed, 79 insertions(+), 12 deletions(-) diff --git a/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py b/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py index 3b9d045d84..22b4cedfe5 100644 --- a/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py +++ b/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py @@ -31,6 +31,7 @@ SINGLE_PROVIDER_CONFIG_2['name'] = 'name-of-config-2' SINGLE_PROVIDER_CONFIG_2['slug'] = 'test-slug-2' ENTERPRISE_ID = str(uuid4()) +ENTERPRISE_ID_NON_EXISTENT = str(uuid4()) @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled') @@ -52,11 +53,8 @@ class SAMLProviderConfigTests(APITestCase): site=cls.site) cls.samlproviderconfig, _ = SAMLProviderConfig.objects.get_or_create( entity_id=SINGLE_PROVIDER_CONFIG['entity_id'], - metadata_source=SINGLE_PROVIDER_CONFIG['metadata_source'] - ) - cls.enterprisecustomeridp, _ = EnterpriseCustomerIdentityProvider.objects.get_or_create( - provider_id=cls.samlproviderconfig.id, - enterprise_customer_id=ENTERPRISE_ID + metadata_source=SINGLE_PROVIDER_CONFIG['metadata_source'], + slug=SINGLE_PROVIDER_CONFIG['slug'] ) def setUp(self): @@ -67,6 +65,12 @@ class SAMLProviderConfigTests(APITestCase): """ GET auth/saml/v0/provider_config/?enterprise_customer_uuid=id=id """ + + # for GET to work, we need an association present + EnterpriseCustomerIdentityProvider.objects.get_or_create( + provider_id=self.samlproviderconfig.slug, + enterprise_customer_id=ENTERPRISE_ID + ) urlbase = reverse('saml_provider_config-list') query_kwargs = {'enterprise_customer_uuid': ENTERPRISE_ID} url = '{}?{}'.format(urlbase, urlencode(query_kwargs)) @@ -94,16 +98,22 @@ class SAMLProviderConfigTests(APITestCase): def test_get_one_config_by_enterprise_uuid_not_found(self): """ - GET auth/saml/v0/provider_config/?enterprise_customer_uuid=id=id + GET auth/saml/v0/provider_config/?enterprise_customer_uuid=valid-but-nonexistent-uuid """ + + # the user must actually be authorized for this enterprise + # since we are testing auth passes but association to samlproviderconfig is not found + set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, ENTERPRISE_ID_NON_EXISTENT)]) + self.client.force_authenticate(user=self.user) + urlbase = reverse('saml_provider_config-list') - query_kwargs = {'enterprise_customer_uuid': 'abc-notfound'} + query_kwargs = {'enterprise_customer_uuid': ENTERPRISE_ID_NON_EXISTENT} url = '{}?{}'.format(urlbase, urlencode(query_kwargs)) orig_count = SAMLProviderConfig.objects.count() response = self.client.get(url, format='json') - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(SAMLProviderConfig.objects.count(), orig_count) def test_create_one_config(self): @@ -122,6 +132,36 @@ class SAMLProviderConfigTests(APITestCase): provider_config = SAMLProviderConfig.objects.get(slug=SINGLE_PROVIDER_CONFIG_2['slug']) self.assertEqual(provider_config.name, 'name-of-config-2') + # check association has also been created + self.assertTrue( + EnterpriseCustomerIdentityProvider.objects.filter( + provider_id=provider_config.slug + ).exists(), + 'Cannot find EnterpriseCustomer-->SAMLProviderConfig association' + ) + + def test_create_one_config_fail_non_existent_enterprise_uuid(self): + """ + POST auth/saml/v0/provider_config/ -d data + """ + url = reverse('saml_provider_config-list') + data = copy.copy(SINGLE_PROVIDER_CONFIG_2) + data['enterprise_customer_uuid'] = ENTERPRISE_ID_NON_EXISTENT + orig_count = SAMLProviderConfig.objects.count() + + response = self.client.post(url, data) + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(SAMLProviderConfig.objects.count(), orig_count) + + # check association has NOT been created + self.assertFalse( + EnterpriseCustomerIdentityProvider.objects.filter( + provider_id=SINGLE_PROVIDER_CONFIG_2['slug'] + ).exists(), + 'Did not expect to find EnterpriseCustomer-->SAMLProviderConfig association' + ) + def test_create_one_config_with_absent_enterprise_uuid(self): """ POST auth/saml/v0/provider_config/ -d data diff --git a/common/djangoapps/third_party_auth/samlproviderconfig/views.py b/common/djangoapps/third_party_auth/samlproviderconfig/views.py index 8293290902..fa41e23174 100644 --- a/common/djangoapps/third_party_auth/samlproviderconfig/views.py +++ b/common/djangoapps/third_party_auth/samlproviderconfig/views.py @@ -5,11 +5,12 @@ Viewset for auth/saml/v0/samlproviderconfig from django.shortcuts import get_object_or_404 from edx_rbac.mixins import PermissionRequiredMixin from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication -from rest_framework import permissions, viewsets +from rest_framework import permissions, viewsets, status +from rest_framework.response import Response from rest_framework.authentication import SessionAuthentication -from rest_framework.exceptions import ParseError +from rest_framework.exceptions import ParseError, ValidationError -from enterprise.models import EnterpriseCustomerIdentityProvider +from enterprise.models import EnterpriseCustomerIdentityProvider, EnterpriseCustomer from third_party_auth.utils import validate_uuid4_string from ..models import SAMLProviderConfig @@ -56,7 +57,7 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view EnterpriseCustomerIdentityProvider, enterprise_customer__uuid=self.requested_enterprise_uuid ) - return SAMLProviderConfig.objects.filter(pk=enterprise_customer_idp.provider_id) + return SAMLProviderConfig.objects.filter(slug=enterprise_customer_idp.provider_id) @property def requested_enterprise_uuid(self): @@ -82,3 +83,29 @@ class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, view can access these endpoints, we have to sort out the operator role use case """ return self.requested_enterprise_uuid + + def create(self, request, *args, **kwargs): + """ + Process POST /auth/saml/v0/provider_config/ {postData} + """ + + customer_uuid = self.requested_enterprise_uuid + try: + enterprise_customer = EnterpriseCustomer.objects.get(pk=customer_uuid) + except EnterpriseCustomer.DoesNotExist: + raise ValidationError('Enterprise customer not found at uuid: {}'.format(customer_uuid)) + + # Create the samlproviderconfig model first + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + self.perform_create(serializer) + + # Associate the enterprise customer with the provider + association_obj = EnterpriseCustomerIdentityProvider( + enterprise_customer=enterprise_customer, + provider_id=serializer.data['slug'] + ) + association_obj.save() + + headers = self.get_success_headers(serializer.data) + return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)