fix: allowing for multiple idp data configs

This commit is contained in:
Alexander Sheehan
2022-06-03 17:45:27 -04:00
parent 47693769e0
commit 8d6e041d7e
7 changed files with 95 additions and 49 deletions

View File

@@ -108,8 +108,12 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin):
""" Do we have cached metadata for this SAML provider? """
if not inst.is_active:
return None # N/A
data = SAMLProviderData.current(inst.entity_id)
return bool(data and data.is_valid())
records = SAMLProviderData.objects.filter(entity_id=inst.entity_id)
for record in records:
if record.is_valid():
return True
return False
has_data.short_description = 'Metadata Ready'
has_data.boolean = True

View File

@@ -784,16 +784,23 @@ class SAMLProviderConfig(ProviderConfig):
conf['attr_defaults'][field] = default
# Now get the data fetched automatically from the metadata.xml:
data = SAMLProviderData.current(self.entity_id)
if not data or not data.is_valid():
data_records = SAMLProviderData.objects.filter(entity_id=self.entity_id)
public_keys = []
for record in data_records:
if record.is_valid():
public_keys.append(record.public_key)
sso_url = record.sso_url
if not public_keys:
log.error(
'No SAMLProviderData found for provider "%s" with entity id "%s" and IdP slug "%s". '
'Run "manage.py saml pull" to fix or debug.',
self.name, self.entity_id, self.slug
)
raise AuthNotConfigured(provider_name=self.name)
conf['x509cert'] = data.public_key
conf['url'] = data.sso_url
conf['x509certMulti'] = {'signing': public_keys}
conf['x509cert'] = ''
conf['url'] = sso_url
# Add SAMLConfiguration appropriate for this IdP
conf['saml_sp_configuration'] = (

View File

@@ -203,7 +203,7 @@ class SAMLProviderDataTests(APITestCase):
POST auth/saml/v0/provider_data/sync_provider_data -d data
"""
mock_fetch.return_value = '<?xml><a>tag</a>'
public_key = 'askdjf;sakdjfs;adkfjas;dkfjas;dkfjas;dlkfj'
public_key = ['askdjf;sakdjfs;adkfjas;dkfjas;dkfjas;dlkfj']
sso_url = 'https://fake-test.id'
expires_at = datetime.now()
mock_parse.return_value = (public_key, sso_url, expires_at)
@@ -219,11 +219,11 @@ class SAMLProviderDataTests(APITestCase):
response = self.client.post(url, data)
assert response.status_code == status.HTTP_201_CREATED
assert response.data == " Created new record for SAMLProviderData for entityID http://entity-id-1"
assert response.data == " Created new record(s) for SAMLProviderData for entityID http://entity-id-1"
assert SAMLProviderData.objects.count() == orig_count + 1
# should only update this time
response = self.client.post(url, data)
assert response.status_code == status.HTTP_200_OK
assert response.data == (" Updated existing SAMLProviderData for entityID http://entity-id-1")
assert response.data == (" Updated existing SAMLProviderData record(s) for entityID http://entity-id-1")
assert SAMLProviderData.objects.count() == orig_count + 1

View File

@@ -18,7 +18,7 @@ from rest_framework.response import Response
from common.djangoapps.third_party_auth.utils import (
convert_saml_slug_provider_id,
create_or_update_saml_provider_data,
create_or_update_bulk_saml_provider_data,
fetch_metadata_xml,
parse_metadata_xml,
validate_uuid4_string
@@ -110,12 +110,12 @@ class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, vi
entity_id = request.POST.get('entity_id')
metadata_url = request.POST.get('metadata_url')
sso_url = request.POST.get('sso_url')
public_key = request.POST.get('public_key')
public_keys = request.POST.get('public_key')
if not entity_id:
return Response('entity_id is required', status.HTTP_400_BAD_REQUEST)
if not metadata_url and not (sso_url and public_key):
if not metadata_url and not (sso_url and public_keys):
return Response('either metadata_url or sso and public key are required', status.HTTP_400_BAD_REQUEST)
if metadata_url and (sso_url or public_key):
if metadata_url and (sso_url or public_keys):
return Response(
'either metadata_url or sso and public key can be provided, not both', status.HTTP_400_BAD_REQUEST
)
@@ -131,18 +131,18 @@ class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, vi
# part 2: create/update samlproviderdata
log.info("Processing IdP with entityID %s", entity_id)
public_key, sso_url, expires_at = parse_metadata_xml(xml, entity_id)
public_keys, sso_url, expires_at = parse_metadata_xml(xml, entity_id)
else:
now = datetime.now()
expires_at = now.replace(year=now.year + 10)
changed = create_or_update_saml_provider_data(entity_id, public_key, sso_url, expires_at)
changed = create_or_update_bulk_saml_provider_data(entity_id, public_keys, sso_url, expires_at)
if changed:
str_message = f" Created new record for SAMLProviderData for entityID {entity_id}"
str_message = f" Created new record(s) for SAMLProviderData for entityID {entity_id}"
log.info(str_message)
response = str_message
http_status = status.HTTP_201_CREATED
else:
str_message = f" Updated existing SAMLProviderData for entityID {entity_id}"
str_message = f" Updated existing SAMLProviderData record(s) for entityID {entity_id}"
log.info(str_message)
response = str_message
http_status = status.HTTP_200_OK

View File

@@ -14,7 +14,7 @@ from requests import exceptions
from common.djangoapps.third_party_auth.models import SAMLConfiguration, SAMLProviderConfig
from common.djangoapps.third_party_auth.utils import (
MetadataParseError,
create_or_update_saml_provider_data,
create_or_update_bulk_saml_provider_data,
parse_metadata_xml,
)
@@ -87,8 +87,8 @@ def fetch_saml_metadata():
for entity_id in entity_ids:
log.info("Processing IdP with entityID %s", entity_id)
public_key, sso_url, expires_at = parse_metadata_xml(xml, entity_id)
changed = create_or_update_saml_provider_data(entity_id, public_key, sso_url, expires_at)
public_keys, sso_url, expires_at = parse_metadata_xml(xml, entity_id)
changed = create_or_update_bulk_saml_provider_data(entity_id, public_keys, sso_url, expires_at)
if changed:
log.info(f"→ Created new record for SAMLProviderData for entityID {entity_id}")
num_updated += 1

View File

@@ -157,8 +157,44 @@ class TestUtils(TestCase):
</md:EntityDescriptor>
'''
xml = etree.fromstring(xml_text, parser)
public_key, sso_url, _ = parse_metadata_xml(xml, entity_id)
assert public_key == 'abc+hkIuUktxkg='
public_keys, sso_url, _ = parse_metadata_xml(xml, entity_id)
assert public_keys == ['abc+hkIuUktxkg=']
assert sso_url == 'https://idp/SSOService.php'
def test_parse_metadata_uses_multiple_signing_cert(self):
entity_id = 'http://testid'
parser = etree.XMLParser(remove_comments=True)
xml_text = '''<?xml version="1.0"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" xmlns:ds="http://www.w3.org/2000/09/xmldsig#" entityID="http://testid">
<md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:KeyDescriptor use="signing">
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
<ds:X509Data>
<ds:X509Certificate>abc+hkIuUktxkg=</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:KeyDescriptor use="signing">
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
<ds:X509Data>
<ds:X509Certificate>xyz+ayylmao=</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:KeyDescriptor use="encryption">
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
<ds:X509Data>
<ds:X509Certificate>blachabc+hkIuUktxkg=blaal;skdjf;ksd</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp/SSOService.php"/>
</md:IDPSSODescriptor>
</md:EntityDescriptor>
'''
xml = etree.fromstring(xml_text, parser)
public_keys, sso_url, _ = parse_metadata_xml(xml, entity_id)
assert public_keys == ['abc+hkIuUktxkg=', 'xyz+ayylmao=']
assert sso_url == 'https://idp/SSOService.php'
def test_parse_metadata_with_use_attribute_missing(self):
@@ -179,6 +215,6 @@ class TestUtils(TestCase):
</md:EntityDescriptor>
'''
xml = etree.fromstring(xml_text, parser)
public_key, sso_url, _ = parse_metadata_xml(xml, entity_id)
assert public_key == 'abc+hkIuUktxkg='
public_keys, sso_url, _ = parse_metadata_xml(xml, entity_id)
assert public_keys == ['abc+hkIuUktxkg=']
assert sso_url == 'https://idp/SSOService.php'

View File

@@ -101,18 +101,24 @@ def parse_metadata_xml(xml, entity_id):
# Now we just need to get the public_key and sso_url
# We want the use='signing' cert, not the 'encryption' one
public_key = sso_desc.findtext("./{}[@use='signing']//{}".format(
# There may be multiple signing certs returned by the server so create one record per signing cert found.
certs = sso_desc.findall("./{}[@use='signing']//{}".format(
etree.QName(SAML_XML_NS, "KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate"
))
if not public_key:
if not certs:
# it's possible that there is just one keyDescription with no use attribute
# that is a shortcut for both signing and encryption combined. So we can use that as fallback.
public_key = sso_desc.findtext("./{}//{}".format(
certs = sso_desc.findall("./{}//{}".format(
etree.QName(SAML_XML_NS, "KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate"
))
if not public_key:
if not certs:
raise MetadataParseError("Public Key missing. Expected an <X509Certificate>")
public_key = public_key.replace(" ", "")
public_keys = []
for key in certs:
public_keys.append(key.text.replace(" ", ""))
binding_elements = sso_desc.iterfind("./{}".format(etree.QName(SAML_XML_NS, "SingleSignOnService")))
sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements}
try:
@@ -120,7 +126,7 @@ def parse_metadata_xml(xml, entity_id):
sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect']
except KeyError:
raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.") # lint-amnesty, pylint: disable=raise-missing-from
return public_key, sso_url, expires_at
return public_keys, sso_url, expires_at
def user_exists(details):
@@ -164,29 +170,22 @@ def get_user_from_email(details):
return None
def create_or_update_saml_provider_data(entity_id, public_key, sso_url, expires_at):
def create_or_update_bulk_saml_provider_data(entity_id, public_keys, sso_url, expires_at):
"""
Update/Create the SAMLProviderData for the given entity ID.
Return value:
False if nothing has changed and existing data's "fetched at" timestamp is just updated.
True if a new record was created. (Either this is a new provider or something changed.)
Placeholder
"""
data_obj = SAMLProviderData.current(entity_id)
fetched_at = now()
if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url):
data_obj.expires_at = expires_at
data_obj.fetched_at = fetched_at
data_obj.save()
return False
else:
SAMLProviderData.objects.create(
entity_id=entity_id,
fetched_at=fetched_at,
expires_at=expires_at,
sso_url=sso_url,
public_key=public_key,
new_records_created = False
# Create a data record for each of the public keys provided
for key in public_keys:
_, created = SAMLProviderData.objects.update_or_create(
public_key=key, entity_id=entity_id,
defaults={'sso_url': sso_url, 'expires_at': expires_at, 'fetched_at': fetched_at},
)
return True
if created:
new_records_created = True
return new_records_created
def convert_saml_slug_provider_id(provider): # lint-amnesty, pylint: disable=redefined-outer-name