diff --git a/common/djangoapps/third_party_auth/admin.py b/common/djangoapps/third_party_auth/admin.py
index 35f43ad129..6b0ca785fd 100644
--- a/common/djangoapps/third_party_auth/admin.py
+++ b/common/djangoapps/third_party_auth/admin.py
@@ -108,8 +108,12 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin):
""" Do we have cached metadata for this SAML provider? """
if not inst.is_active:
return None # N/A
- data = SAMLProviderData.current(inst.entity_id)
- return bool(data and data.is_valid())
+ records = SAMLProviderData.objects.filter(entity_id=inst.entity_id)
+ for record in records:
+ if record.is_valid():
+ return True
+ return False
+
has_data.short_description = 'Metadata Ready'
has_data.boolean = True
diff --git a/common/djangoapps/third_party_auth/models.py b/common/djangoapps/third_party_auth/models.py
index f0a37616a3..4d1c1f84a9 100644
--- a/common/djangoapps/third_party_auth/models.py
+++ b/common/djangoapps/third_party_auth/models.py
@@ -784,16 +784,23 @@ class SAMLProviderConfig(ProviderConfig):
conf['attr_defaults'][field] = default
# Now get the data fetched automatically from the metadata.xml:
- data = SAMLProviderData.current(self.entity_id)
- if not data or not data.is_valid():
+ data_records = SAMLProviderData.objects.filter(entity_id=self.entity_id)
+ public_keys = []
+ for record in data_records:
+ if record.is_valid():
+ public_keys.append(record.public_key)
+ sso_url = record.sso_url
+ if not public_keys:
log.error(
'No SAMLProviderData found for provider "%s" with entity id "%s" and IdP slug "%s". '
'Run "manage.py saml pull" to fix or debug.',
self.name, self.entity_id, self.slug
)
raise AuthNotConfigured(provider_name=self.name)
- conf['x509cert'] = data.public_key
- conf['url'] = data.sso_url
+
+ conf['x509certMulti'] = {'signing': public_keys}
+ conf['x509cert'] = ''
+ conf['url'] = sso_url
# Add SAMLConfiguration appropriate for this IdP
conf['saml_sp_configuration'] = (
diff --git a/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py b/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py
index 8854d320ae..5dbf7a803e 100644
--- a/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py
+++ b/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py
@@ -203,7 +203,7 @@ class SAMLProviderDataTests(APITestCase):
POST auth/saml/v0/provider_data/sync_provider_data -d data
"""
mock_fetch.return_value = 'tag'
- public_key = 'askdjf;sakdjfs;adkfjas;dkfjas;dkfjas;dlkfj'
+ public_key = ['askdjf;sakdjfs;adkfjas;dkfjas;dkfjas;dlkfj']
sso_url = 'https://fake-test.id'
expires_at = datetime.now()
mock_parse.return_value = (public_key, sso_url, expires_at)
@@ -219,11 +219,11 @@ class SAMLProviderDataTests(APITestCase):
response = self.client.post(url, data)
assert response.status_code == status.HTTP_201_CREATED
- assert response.data == " Created new record for SAMLProviderData for entityID http://entity-id-1"
+ assert response.data == " Created new record(s) for SAMLProviderData for entityID http://entity-id-1"
assert SAMLProviderData.objects.count() == orig_count + 1
# should only update this time
response = self.client.post(url, data)
assert response.status_code == status.HTTP_200_OK
- assert response.data == (" Updated existing SAMLProviderData for entityID http://entity-id-1")
+ assert response.data == (" Updated existing SAMLProviderData record(s) for entityID http://entity-id-1")
assert SAMLProviderData.objects.count() == orig_count + 1
diff --git a/common/djangoapps/third_party_auth/samlproviderdata/views.py b/common/djangoapps/third_party_auth/samlproviderdata/views.py
index 09ad65dfa8..f61b237c12 100644
--- a/common/djangoapps/third_party_auth/samlproviderdata/views.py
+++ b/common/djangoapps/third_party_auth/samlproviderdata/views.py
@@ -18,7 +18,7 @@ from rest_framework.response import Response
from common.djangoapps.third_party_auth.utils import (
convert_saml_slug_provider_id,
- create_or_update_saml_provider_data,
+ create_or_update_bulk_saml_provider_data,
fetch_metadata_xml,
parse_metadata_xml,
validate_uuid4_string
@@ -110,12 +110,12 @@ class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, vi
entity_id = request.POST.get('entity_id')
metadata_url = request.POST.get('metadata_url')
sso_url = request.POST.get('sso_url')
- public_key = request.POST.get('public_key')
+ public_keys = request.POST.get('public_key')
if not entity_id:
return Response('entity_id is required', status.HTTP_400_BAD_REQUEST)
- if not metadata_url and not (sso_url and public_key):
+ if not metadata_url and not (sso_url and public_keys):
return Response('either metadata_url or sso and public key are required', status.HTTP_400_BAD_REQUEST)
- if metadata_url and (sso_url or public_key):
+ if metadata_url and (sso_url or public_keys):
return Response(
'either metadata_url or sso and public key can be provided, not both', status.HTTP_400_BAD_REQUEST
)
@@ -131,18 +131,18 @@ class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, vi
# part 2: create/update samlproviderdata
log.info("Processing IdP with entityID %s", entity_id)
- public_key, sso_url, expires_at = parse_metadata_xml(xml, entity_id)
+ public_keys, sso_url, expires_at = parse_metadata_xml(xml, entity_id)
else:
now = datetime.now()
expires_at = now.replace(year=now.year + 10)
- changed = create_or_update_saml_provider_data(entity_id, public_key, sso_url, expires_at)
+ changed = create_or_update_bulk_saml_provider_data(entity_id, public_keys, sso_url, expires_at)
if changed:
- str_message = f" Created new record for SAMLProviderData for entityID {entity_id}"
+ str_message = f" Created new record(s) for SAMLProviderData for entityID {entity_id}"
log.info(str_message)
response = str_message
http_status = status.HTTP_201_CREATED
else:
- str_message = f" Updated existing SAMLProviderData for entityID {entity_id}"
+ str_message = f" Updated existing SAMLProviderData record(s) for entityID {entity_id}"
log.info(str_message)
response = str_message
http_status = status.HTTP_200_OK
diff --git a/common/djangoapps/third_party_auth/tasks.py b/common/djangoapps/third_party_auth/tasks.py
index 88b118a689..d702932f4d 100644
--- a/common/djangoapps/third_party_auth/tasks.py
+++ b/common/djangoapps/third_party_auth/tasks.py
@@ -14,7 +14,7 @@ from requests import exceptions
from common.djangoapps.third_party_auth.models import SAMLConfiguration, SAMLProviderConfig
from common.djangoapps.third_party_auth.utils import (
MetadataParseError,
- create_or_update_saml_provider_data,
+ create_or_update_bulk_saml_provider_data,
parse_metadata_xml,
)
@@ -87,8 +87,8 @@ def fetch_saml_metadata():
for entity_id in entity_ids:
log.info("Processing IdP with entityID %s", entity_id)
- public_key, sso_url, expires_at = parse_metadata_xml(xml, entity_id)
- changed = create_or_update_saml_provider_data(entity_id, public_key, sso_url, expires_at)
+ public_keys, sso_url, expires_at = parse_metadata_xml(xml, entity_id)
+ changed = create_or_update_bulk_saml_provider_data(entity_id, public_keys, sso_url, expires_at)
if changed:
log.info(f"→ Created new record for SAMLProviderData for entityID {entity_id}")
num_updated += 1
diff --git a/common/djangoapps/third_party_auth/tests/test_utils.py b/common/djangoapps/third_party_auth/tests/test_utils.py
index 2d971e69d0..5d35d92b46 100644
--- a/common/djangoapps/third_party_auth/tests/test_utils.py
+++ b/common/djangoapps/third_party_auth/tests/test_utils.py
@@ -157,8 +157,44 @@ class TestUtils(TestCase):
'''
xml = etree.fromstring(xml_text, parser)
- public_key, sso_url, _ = parse_metadata_xml(xml, entity_id)
- assert public_key == 'abc+hkIuUktxkg='
+ public_keys, sso_url, _ = parse_metadata_xml(xml, entity_id)
+ assert public_keys == ['abc+hkIuUktxkg=']
+ assert sso_url == 'https://idp/SSOService.php'
+
+ def test_parse_metadata_uses_multiple_signing_cert(self):
+ entity_id = 'http://testid'
+ parser = etree.XMLParser(remove_comments=True)
+ xml_text = '''
+
+
+
+
+
+ abc+hkIuUktxkg=
+
+
+
+
+
+
+ xyz+ayylmao=
+
+
+
+
+
+
+ blachabc+hkIuUktxkg=blaal;skdjf;ksd
+
+
+
+
+
+
+ '''
+ xml = etree.fromstring(xml_text, parser)
+ public_keys, sso_url, _ = parse_metadata_xml(xml, entity_id)
+ assert public_keys == ['abc+hkIuUktxkg=', 'xyz+ayylmao=']
assert sso_url == 'https://idp/SSOService.php'
def test_parse_metadata_with_use_attribute_missing(self):
@@ -179,6 +215,6 @@ class TestUtils(TestCase):
'''
xml = etree.fromstring(xml_text, parser)
- public_key, sso_url, _ = parse_metadata_xml(xml, entity_id)
- assert public_key == 'abc+hkIuUktxkg='
+ public_keys, sso_url, _ = parse_metadata_xml(xml, entity_id)
+ assert public_keys == ['abc+hkIuUktxkg=']
assert sso_url == 'https://idp/SSOService.php'
diff --git a/common/djangoapps/third_party_auth/utils.py b/common/djangoapps/third_party_auth/utils.py
index 8517af0328..0cb6981030 100644
--- a/common/djangoapps/third_party_auth/utils.py
+++ b/common/djangoapps/third_party_auth/utils.py
@@ -101,18 +101,24 @@ def parse_metadata_xml(xml, entity_id):
# Now we just need to get the public_key and sso_url
# We want the use='signing' cert, not the 'encryption' one
- public_key = sso_desc.findtext("./{}[@use='signing']//{}".format(
+ # There may be multiple signing certs returned by the server so create one record per signing cert found.
+ certs = sso_desc.findall("./{}[@use='signing']//{}".format(
etree.QName(SAML_XML_NS, "KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate"
))
- if not public_key:
+
+ if not certs:
# it's possible that there is just one keyDescription with no use attribute
# that is a shortcut for both signing and encryption combined. So we can use that as fallback.
- public_key = sso_desc.findtext("./{}//{}".format(
+ certs = sso_desc.findall("./{}//{}".format(
etree.QName(SAML_XML_NS, "KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate"
))
- if not public_key:
+ if not certs:
raise MetadataParseError("Public Key missing. Expected an ")
- public_key = public_key.replace(" ", "")
+
+ public_keys = []
+ for key in certs:
+ public_keys.append(key.text.replace(" ", ""))
+
binding_elements = sso_desc.iterfind("./{}".format(etree.QName(SAML_XML_NS, "SingleSignOnService")))
sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements}
try:
@@ -120,7 +126,7 @@ def parse_metadata_xml(xml, entity_id):
sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect']
except KeyError:
raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.") # lint-amnesty, pylint: disable=raise-missing-from
- return public_key, sso_url, expires_at
+ return public_keys, sso_url, expires_at
def user_exists(details):
@@ -164,29 +170,22 @@ def get_user_from_email(details):
return None
-def create_or_update_saml_provider_data(entity_id, public_key, sso_url, expires_at):
+def create_or_update_bulk_saml_provider_data(entity_id, public_keys, sso_url, expires_at):
"""
- Update/Create the SAMLProviderData for the given entity ID.
- Return value:
- False if nothing has changed and existing data's "fetched at" timestamp is just updated.
- True if a new record was created. (Either this is a new provider or something changed.)
+ Placeholder
"""
- data_obj = SAMLProviderData.current(entity_id)
fetched_at = now()
- if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url):
- data_obj.expires_at = expires_at
- data_obj.fetched_at = fetched_at
- data_obj.save()
- return False
- else:
- SAMLProviderData.objects.create(
- entity_id=entity_id,
- fetched_at=fetched_at,
- expires_at=expires_at,
- sso_url=sso_url,
- public_key=public_key,
+ new_records_created = False
+ # Create a data record for each of the public keys provided
+ for key in public_keys:
+ _, created = SAMLProviderData.objects.update_or_create(
+ public_key=key, entity_id=entity_id,
+ defaults={'sso_url': sso_url, 'expires_at': expires_at, 'fetched_at': fetched_at},
)
- return True
+ if created:
+ new_records_created = True
+
+ return new_records_created
def convert_saml_slug_provider_id(provider): # lint-amnesty, pylint: disable=redefined-outer-name