From 47693769e083841c0769273a404b8acdec559816 Mon Sep 17 00:00:00 2001 From: Alexander Sheehan Date: Sun, 22 May 2022 23:20:20 -0400 Subject: [PATCH] fix: allowing for provider config fields to be provided on create/updates --- .../tests/test_samlproviderdata.py | 13 ++++++ .../samlproviderdata/views.py | 42 ++++++++++++------- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py b/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py index 7607ee5dd9..8854d320ae 100644 --- a/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py +++ b/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py @@ -93,6 +93,19 @@ class SAMLProviderDataTests(APITestCase): assert len(results) == 1 assert results[0]['sso_url'] == SINGLE_PROVIDER_DATA['sso_url'] + def test_get_one_provider_data_with_pk_success(self): + # GET auth/saml/v0/providerdata//?enterprise_customer_uuid=id + url_base = reverse('saml_provider_data-list') + query_kwargs = {'enterprise_customer_uuid': ENTERPRISE_ID} + url = f'{url_base}{self.saml_provider_data.id}/?{urlencode(query_kwargs)}' + + response = self.client.get(url, format='json') + assert response.status_code == status.HTTP_200_OK + assert response.data.get('id') == self.saml_provider_data.id + assert response.data.get('entity_id') == self.saml_provider_data.entity_id + assert response.data.get('sso_url') == self.saml_provider_data.sso_url + assert response.data.get('public_key') == self.saml_provider_data.public_key + def test_create_one_provider_data_success(self): # POST auth/saml/v0/providerdata/ -d data url = reverse('saml_provider_data-list') diff --git a/common/djangoapps/third_party_auth/samlproviderdata/views.py b/common/djangoapps/third_party_auth/samlproviderdata/views.py index 59439fcd38..09ad65dfa8 100644 --- a/common/djangoapps/third_party_auth/samlproviderdata/views.py +++ b/common/djangoapps/third_party_auth/samlproviderdata/views.py @@ -1,8 +1,9 @@ """ Viewset for auth/saml/v0/samlproviderdata """ +from datetime import datetime import logging -from requests.exceptions import SSLError, MissingSchema +from requests.exceptions import SSLError, MissingSchema, HTTPError from django.http import Http404 from django.shortcuts import get_object_or_404 @@ -71,6 +72,9 @@ class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, vi slug=convert_saml_slug_provider_id(enterprise_customer_idp.provider_id)) except SAMLProviderConfig.DoesNotExist: raise Http404('No matching SAML provider found.') # lint-amnesty, pylint: disable=raise-missing-from + provider_data_id = self.request.parser_context.get('kwargs').get('pk') + if provider_data_id: + return SAMLProviderData.objects.filter(id=provider_data_id) return SAMLProviderData.objects.filter(entity_id=saml_provider.entity_id) @property @@ -105,22 +109,32 @@ class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, vi """ entity_id = request.POST.get('entity_id') metadata_url = request.POST.get('metadata_url') + sso_url = request.POST.get('sso_url') + public_key = request.POST.get('public_key') if not entity_id: - return Response('entity_id is required!', status.HTTP_400_BAD_REQUEST) - if not metadata_url: - return Response('metadata_url is required!', status.HTTP_400_BAD_REQUEST) + return Response('entity_id is required', status.HTTP_400_BAD_REQUEST) + if not metadata_url and not (sso_url and public_key): + return Response('either metadata_url or sso and public key are required', status.HTTP_400_BAD_REQUEST) + if metadata_url and (sso_url or public_key): + return Response( + 'either metadata_url or sso and public key can be provided, not both', status.HTTP_400_BAD_REQUEST + ) - # part 1: fetch information from remote metadata based on metadataUrl in samlproviderconfig - try: - xml = fetch_metadata_xml(metadata_url) - except (SSLError, MissingSchema) as ex: - msg = f'Could not verify provider metadata url. Exc type: {type(ex).__name__}' - log.warning(msg) - return Response(msg, status.HTTP_406_NOT_ACCEPTABLE) + if metadata_url: + # part 1: fetch information from remote metadata based on metadataUrl in samlproviderconfig + try: + xml = fetch_metadata_xml(metadata_url) + except (SSLError, MissingSchema, HTTPError) as ex: + msg = f'Could not verify provider metadata url. Exc type: {type(ex).__name__}' + log.warning(msg) + return Response(msg, status.HTTP_406_NOT_ACCEPTABLE) - # part 2: create/update samlproviderdata - log.info("Processing IdP with entityID %s", entity_id) - public_key, sso_url, expires_at = parse_metadata_xml(xml, entity_id) + # part 2: create/update samlproviderdata + log.info("Processing IdP with entityID %s", entity_id) + public_key, sso_url, expires_at = parse_metadata_xml(xml, entity_id) + else: + now = datetime.now() + expires_at = now.replace(year=now.year + 10) changed = create_or_update_saml_provider_data(entity_id, public_key, sso_url, expires_at) if changed: str_message = f" Created new record for SAMLProviderData for entityID {entity_id}"