diff --git a/common/djangoapps/third_party_auth/admin.py b/common/djangoapps/third_party_auth/admin.py index 35f43ad129..6b0ca785fd 100644 --- a/common/djangoapps/third_party_auth/admin.py +++ b/common/djangoapps/third_party_auth/admin.py @@ -108,8 +108,12 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin): """ Do we have cached metadata for this SAML provider? """ if not inst.is_active: return None # N/A - data = SAMLProviderData.current(inst.entity_id) - return bool(data and data.is_valid()) + records = SAMLProviderData.objects.filter(entity_id=inst.entity_id) + for record in records: + if record.is_valid(): + return True + return False + has_data.short_description = 'Metadata Ready' has_data.boolean = True diff --git a/common/djangoapps/third_party_auth/models.py b/common/djangoapps/third_party_auth/models.py index f0a37616a3..4d1c1f84a9 100644 --- a/common/djangoapps/third_party_auth/models.py +++ b/common/djangoapps/third_party_auth/models.py @@ -784,16 +784,23 @@ class SAMLProviderConfig(ProviderConfig): conf['attr_defaults'][field] = default # Now get the data fetched automatically from the metadata.xml: - data = SAMLProviderData.current(self.entity_id) - if not data or not data.is_valid(): + data_records = SAMLProviderData.objects.filter(entity_id=self.entity_id) + public_keys = [] + for record in data_records: + if record.is_valid(): + public_keys.append(record.public_key) + sso_url = record.sso_url + if not public_keys: log.error( 'No SAMLProviderData found for provider "%s" with entity id "%s" and IdP slug "%s". ' 'Run "manage.py saml pull" to fix or debug.', self.name, self.entity_id, self.slug ) raise AuthNotConfigured(provider_name=self.name) - conf['x509cert'] = data.public_key - conf['url'] = data.sso_url + + conf['x509certMulti'] = {'signing': public_keys} + conf['x509cert'] = '' + conf['url'] = sso_url # Add SAMLConfiguration appropriate for this IdP conf['saml_sp_configuration'] = ( diff --git a/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py b/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py index 8854d320ae..5dbf7a803e 100644 --- a/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py +++ b/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py @@ -203,7 +203,7 @@ class SAMLProviderDataTests(APITestCase): POST auth/saml/v0/provider_data/sync_provider_data -d data """ mock_fetch.return_value = 'tag' - public_key = 'askdjf;sakdjfs;adkfjas;dkfjas;dkfjas;dlkfj' + public_key = ['askdjf;sakdjfs;adkfjas;dkfjas;dkfjas;dlkfj'] sso_url = 'https://fake-test.id' expires_at = datetime.now() mock_parse.return_value = (public_key, sso_url, expires_at) @@ -219,11 +219,11 @@ class SAMLProviderDataTests(APITestCase): response = self.client.post(url, data) assert response.status_code == status.HTTP_201_CREATED - assert response.data == " Created new record for SAMLProviderData for entityID http://entity-id-1" + assert response.data == " Created new record(s) for SAMLProviderData for entityID http://entity-id-1" assert SAMLProviderData.objects.count() == orig_count + 1 # should only update this time response = self.client.post(url, data) assert response.status_code == status.HTTP_200_OK - assert response.data == (" Updated existing SAMLProviderData for entityID http://entity-id-1") + assert response.data == (" Updated existing SAMLProviderData record(s) for entityID http://entity-id-1") assert SAMLProviderData.objects.count() == orig_count + 1 diff --git a/common/djangoapps/third_party_auth/samlproviderdata/views.py b/common/djangoapps/third_party_auth/samlproviderdata/views.py index 09ad65dfa8..f61b237c12 100644 --- a/common/djangoapps/third_party_auth/samlproviderdata/views.py +++ b/common/djangoapps/third_party_auth/samlproviderdata/views.py @@ -18,7 +18,7 @@ from rest_framework.response import Response from common.djangoapps.third_party_auth.utils import ( convert_saml_slug_provider_id, - create_or_update_saml_provider_data, + create_or_update_bulk_saml_provider_data, fetch_metadata_xml, parse_metadata_xml, validate_uuid4_string @@ -110,12 +110,12 @@ class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, vi entity_id = request.POST.get('entity_id') metadata_url = request.POST.get('metadata_url') sso_url = request.POST.get('sso_url') - public_key = request.POST.get('public_key') + public_keys = request.POST.get('public_key') if not entity_id: return Response('entity_id is required', status.HTTP_400_BAD_REQUEST) - if not metadata_url and not (sso_url and public_key): + if not metadata_url and not (sso_url and public_keys): return Response('either metadata_url or sso and public key are required', status.HTTP_400_BAD_REQUEST) - if metadata_url and (sso_url or public_key): + if metadata_url and (sso_url or public_keys): return Response( 'either metadata_url or sso and public key can be provided, not both', status.HTTP_400_BAD_REQUEST ) @@ -131,18 +131,18 @@ class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, vi # part 2: create/update samlproviderdata log.info("Processing IdP with entityID %s", entity_id) - public_key, sso_url, expires_at = parse_metadata_xml(xml, entity_id) + public_keys, sso_url, expires_at = parse_metadata_xml(xml, entity_id) else: now = datetime.now() expires_at = now.replace(year=now.year + 10) - changed = create_or_update_saml_provider_data(entity_id, public_key, sso_url, expires_at) + changed = create_or_update_bulk_saml_provider_data(entity_id, public_keys, sso_url, expires_at) if changed: - str_message = f" Created new record for SAMLProviderData for entityID {entity_id}" + str_message = f" Created new record(s) for SAMLProviderData for entityID {entity_id}" log.info(str_message) response = str_message http_status = status.HTTP_201_CREATED else: - str_message = f" Updated existing SAMLProviderData for entityID {entity_id}" + str_message = f" Updated existing SAMLProviderData record(s) for entityID {entity_id}" log.info(str_message) response = str_message http_status = status.HTTP_200_OK diff --git a/common/djangoapps/third_party_auth/tasks.py b/common/djangoapps/third_party_auth/tasks.py index 88b118a689..d702932f4d 100644 --- a/common/djangoapps/third_party_auth/tasks.py +++ b/common/djangoapps/third_party_auth/tasks.py @@ -14,7 +14,7 @@ from requests import exceptions from common.djangoapps.third_party_auth.models import SAMLConfiguration, SAMLProviderConfig from common.djangoapps.third_party_auth.utils import ( MetadataParseError, - create_or_update_saml_provider_data, + create_or_update_bulk_saml_provider_data, parse_metadata_xml, ) @@ -87,8 +87,8 @@ def fetch_saml_metadata(): for entity_id in entity_ids: log.info("Processing IdP with entityID %s", entity_id) - public_key, sso_url, expires_at = parse_metadata_xml(xml, entity_id) - changed = create_or_update_saml_provider_data(entity_id, public_key, sso_url, expires_at) + public_keys, sso_url, expires_at = parse_metadata_xml(xml, entity_id) + changed = create_or_update_bulk_saml_provider_data(entity_id, public_keys, sso_url, expires_at) if changed: log.info(f"→ Created new record for SAMLProviderData for entityID {entity_id}") num_updated += 1 diff --git a/common/djangoapps/third_party_auth/tests/test_utils.py b/common/djangoapps/third_party_auth/tests/test_utils.py index 2d971e69d0..5d35d92b46 100644 --- a/common/djangoapps/third_party_auth/tests/test_utils.py +++ b/common/djangoapps/third_party_auth/tests/test_utils.py @@ -157,8 +157,44 @@ class TestUtils(TestCase): ''' xml = etree.fromstring(xml_text, parser) - public_key, sso_url, _ = parse_metadata_xml(xml, entity_id) - assert public_key == 'abc+hkIuUktxkg=' + public_keys, sso_url, _ = parse_metadata_xml(xml, entity_id) + assert public_keys == ['abc+hkIuUktxkg='] + assert sso_url == 'https://idp/SSOService.php' + + def test_parse_metadata_uses_multiple_signing_cert(self): + entity_id = 'http://testid' + parser = etree.XMLParser(remove_comments=True) + xml_text = ''' + + + + + + abc+hkIuUktxkg= + + + + + + + xyz+ayylmao= + + + + + + + blachabc+hkIuUktxkg=blaal;skdjf;ksd + + + + + + + ''' + xml = etree.fromstring(xml_text, parser) + public_keys, sso_url, _ = parse_metadata_xml(xml, entity_id) + assert public_keys == ['abc+hkIuUktxkg=', 'xyz+ayylmao='] assert sso_url == 'https://idp/SSOService.php' def test_parse_metadata_with_use_attribute_missing(self): @@ -179,6 +215,6 @@ class TestUtils(TestCase): ''' xml = etree.fromstring(xml_text, parser) - public_key, sso_url, _ = parse_metadata_xml(xml, entity_id) - assert public_key == 'abc+hkIuUktxkg=' + public_keys, sso_url, _ = parse_metadata_xml(xml, entity_id) + assert public_keys == ['abc+hkIuUktxkg='] assert sso_url == 'https://idp/SSOService.php' diff --git a/common/djangoapps/third_party_auth/utils.py b/common/djangoapps/third_party_auth/utils.py index 8517af0328..0cb6981030 100644 --- a/common/djangoapps/third_party_auth/utils.py +++ b/common/djangoapps/third_party_auth/utils.py @@ -101,18 +101,24 @@ def parse_metadata_xml(xml, entity_id): # Now we just need to get the public_key and sso_url # We want the use='signing' cert, not the 'encryption' one - public_key = sso_desc.findtext("./{}[@use='signing']//{}".format( + # There may be multiple signing certs returned by the server so create one record per signing cert found. + certs = sso_desc.findall("./{}[@use='signing']//{}".format( etree.QName(SAML_XML_NS, "KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate" )) - if not public_key: + + if not certs: # it's possible that there is just one keyDescription with no use attribute # that is a shortcut for both signing and encryption combined. So we can use that as fallback. - public_key = sso_desc.findtext("./{}//{}".format( + certs = sso_desc.findall("./{}//{}".format( etree.QName(SAML_XML_NS, "KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate" )) - if not public_key: + if not certs: raise MetadataParseError("Public Key missing. Expected an ") - public_key = public_key.replace(" ", "") + + public_keys = [] + for key in certs: + public_keys.append(key.text.replace(" ", "")) + binding_elements = sso_desc.iterfind("./{}".format(etree.QName(SAML_XML_NS, "SingleSignOnService"))) sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements} try: @@ -120,7 +126,7 @@ def parse_metadata_xml(xml, entity_id): sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect'] except KeyError: raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.") # lint-amnesty, pylint: disable=raise-missing-from - return public_key, sso_url, expires_at + return public_keys, sso_url, expires_at def user_exists(details): @@ -164,29 +170,22 @@ def get_user_from_email(details): return None -def create_or_update_saml_provider_data(entity_id, public_key, sso_url, expires_at): +def create_or_update_bulk_saml_provider_data(entity_id, public_keys, sso_url, expires_at): """ - Update/Create the SAMLProviderData for the given entity ID. - Return value: - False if nothing has changed and existing data's "fetched at" timestamp is just updated. - True if a new record was created. (Either this is a new provider or something changed.) + Placeholder """ - data_obj = SAMLProviderData.current(entity_id) fetched_at = now() - if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url): - data_obj.expires_at = expires_at - data_obj.fetched_at = fetched_at - data_obj.save() - return False - else: - SAMLProviderData.objects.create( - entity_id=entity_id, - fetched_at=fetched_at, - expires_at=expires_at, - sso_url=sso_url, - public_key=public_key, + new_records_created = False + # Create a data record for each of the public keys provided + for key in public_keys: + _, created = SAMLProviderData.objects.update_or_create( + public_key=key, entity_id=entity_id, + defaults={'sso_url': sso_url, 'expires_at': expires_at, 'fetched_at': fetched_at}, ) - return True + if created: + new_records_created = True + + return new_records_created def convert_saml_slug_provider_id(provider): # lint-amnesty, pylint: disable=redefined-outer-name