fix: allowing for provider config fields to be provided on create/updates

This commit is contained in:
Alexander Sheehan
2022-05-22 23:20:20 -04:00
parent ccb635733e
commit 47693769e0
2 changed files with 41 additions and 14 deletions

View File

@@ -93,6 +93,19 @@ class SAMLProviderDataTests(APITestCase):
assert len(results) == 1
assert results[0]['sso_url'] == SINGLE_PROVIDER_DATA['sso_url']
def test_get_one_provider_data_with_pk_success(self):
# GET auth/saml/v0/providerdata/<provider data ID>/?enterprise_customer_uuid=id
url_base = reverse('saml_provider_data-list')
query_kwargs = {'enterprise_customer_uuid': ENTERPRISE_ID}
url = f'{url_base}{self.saml_provider_data.id}/?{urlencode(query_kwargs)}'
response = self.client.get(url, format='json')
assert response.status_code == status.HTTP_200_OK
assert response.data.get('id') == self.saml_provider_data.id
assert response.data.get('entity_id') == self.saml_provider_data.entity_id
assert response.data.get('sso_url') == self.saml_provider_data.sso_url
assert response.data.get('public_key') == self.saml_provider_data.public_key
def test_create_one_provider_data_success(self):
# POST auth/saml/v0/providerdata/ -d data
url = reverse('saml_provider_data-list')

View File

@@ -1,8 +1,9 @@
"""
Viewset for auth/saml/v0/samlproviderdata
"""
from datetime import datetime
import logging
from requests.exceptions import SSLError, MissingSchema
from requests.exceptions import SSLError, MissingSchema, HTTPError
from django.http import Http404
from django.shortcuts import get_object_or_404
@@ -71,6 +72,9 @@ class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, vi
slug=convert_saml_slug_provider_id(enterprise_customer_idp.provider_id))
except SAMLProviderConfig.DoesNotExist:
raise Http404('No matching SAML provider found.') # lint-amnesty, pylint: disable=raise-missing-from
provider_data_id = self.request.parser_context.get('kwargs').get('pk')
if provider_data_id:
return SAMLProviderData.objects.filter(id=provider_data_id)
return SAMLProviderData.objects.filter(entity_id=saml_provider.entity_id)
@property
@@ -105,22 +109,32 @@ class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, vi
"""
entity_id = request.POST.get('entity_id')
metadata_url = request.POST.get('metadata_url')
sso_url = request.POST.get('sso_url')
public_key = request.POST.get('public_key')
if not entity_id:
return Response('entity_id is required!', status.HTTP_400_BAD_REQUEST)
if not metadata_url:
return Response('metadata_url is required!', status.HTTP_400_BAD_REQUEST)
return Response('entity_id is required', status.HTTP_400_BAD_REQUEST)
if not metadata_url and not (sso_url and public_key):
return Response('either metadata_url or sso and public key are required', status.HTTP_400_BAD_REQUEST)
if metadata_url and (sso_url or public_key):
return Response(
'either metadata_url or sso and public key can be provided, not both', status.HTTP_400_BAD_REQUEST
)
# part 1: fetch information from remote metadata based on metadataUrl in samlproviderconfig
try:
xml = fetch_metadata_xml(metadata_url)
except (SSLError, MissingSchema) as ex:
msg = f'Could not verify provider metadata url. Exc type: {type(ex).__name__}'
log.warning(msg)
return Response(msg, status.HTTP_406_NOT_ACCEPTABLE)
if metadata_url:
# part 1: fetch information from remote metadata based on metadataUrl in samlproviderconfig
try:
xml = fetch_metadata_xml(metadata_url)
except (SSLError, MissingSchema, HTTPError) as ex:
msg = f'Could not verify provider metadata url. Exc type: {type(ex).__name__}'
log.warning(msg)
return Response(msg, status.HTTP_406_NOT_ACCEPTABLE)
# part 2: create/update samlproviderdata
log.info("Processing IdP with entityID %s", entity_id)
public_key, sso_url, expires_at = parse_metadata_xml(xml, entity_id)
# part 2: create/update samlproviderdata
log.info("Processing IdP with entityID %s", entity_id)
public_key, sso_url, expires_at = parse_metadata_xml(xml, entity_id)
else:
now = datetime.now()
expires_at = now.replace(year=now.year + 10)
changed = create_or_update_saml_provider_data(entity_id, public_key, sso_url, expires_at)
if changed:
str_message = f" Created new record for SAMLProviderData for entityID {entity_id}"