diff --git a/common/djangoapps/third_party_auth/models.py b/common/djangoapps/third_party_auth/models.py index 1550c54eba..12224c7bbf 100644 --- a/common/djangoapps/third_party_auth/models.py +++ b/common/djangoapps/third_party_auth/models.py @@ -363,11 +363,14 @@ class SAMLConfiguration(ConfigurationModel): return self.public_key if name == "SP_PRIVATE_KEY": return self.private_key - if name == "TECHNICAL_CONTACT": - return {"givenName": "Technical Support", "emailAddress": settings.TECH_SUPPORT_EMAIL} - if name == "SUPPORT_CONTACT": - return {"givenName": "SAML Support", "emailAddress": settings.TECH_SUPPORT_EMAIL} other_config = json.loads(self.other_config_str) + if name in ("TECHNICAL_CONTACT", "SUPPORT_CONTACT"): + contact = { + "givenName": "{} Support".format(settings.PLATFORM_NAME), + "emailAddress": settings.TECH_SUPPORT_EMAIL + } + contact.update(other_config.get(name, {})) + return contact return other_config[name] # SECURITY_CONFIG, SP_EXTRA, or similar extra settings diff --git a/common/djangoapps/third_party_auth/tests/test_views.py b/common/djangoapps/third_party_auth/tests/test_views.py index 8e88629801..583efddb99 100644 --- a/common/djangoapps/third_party_auth/tests/test_views.py +++ b/common/djangoapps/third_party_auth/tests/test_views.py @@ -28,22 +28,42 @@ class SAMLMetadataTest(SAMLTestCase): @ddt.data('saml_key', 'saml_key_alt') # Test two slightly different key pair export formats def test_metadata(self, key_name): - self.enable_saml( - private_key=self._get_private_key(key_name), - public_key=self._get_public_key(key_name), - entity_id="https://saml.example.none", - ) + self.enable_saml() doc = self._fetch_metadata() # Check the ACS URL: acs_node = doc.find(".//{}".format(etree.QName(SAML_XML_NS, 'AssertionConsumerService'))) self.assertIsNotNone(acs_node) self.assertEqual(acs_node.attrib['Location'], 'http://example.none/auth/complete/tpa-saml/') + def test_default_contact_info(self): + self.enable_saml() + self.check_metadata_contacts( + xml=self._fetch_metadata(), + tech_name="edX Support", + tech_email="technical@example.com", + support_name="edX Support", + support_email="technical@example.com" + ) + + def test_custom_contact_info(self): + self.enable_saml( + other_config_str=( + '{' + '"TECHNICAL_CONTACT": {"givenName": "Jane Tech", "emailAddress": "jane@example.com"},' + '"SUPPORT_CONTACT": {"givenName": "Joe Support", "emailAddress": "joe@example.com"}' + '}' + ) + ) + self.check_metadata_contacts( + xml=self._fetch_metadata(), + tech_name="Jane Tech", + tech_email="jane@example.com", + support_name="Joe Support", + support_email="joe@example.com" + ) + def test_signed_metadata(self): self.enable_saml( - private_key=self._get_private_key(), - public_key=self._get_public_key(), - entity_id="https://saml.example.none", other_config_str='{"SECURITY_CONFIG": {"signMetadata": true} }', ) doc = self._fetch_metadata() @@ -62,3 +82,19 @@ class SAMLMetadataTest(SAMLTestCase): self.fail('SAML metadata must be valid XML') self.assertEqual(metadata_doc.tag, etree.QName(SAML_XML_NS, 'EntityDescriptor')) return metadata_doc + + def check_metadata_contacts(self, xml, tech_name, tech_email, support_name, support_email): + """ Validate that the contact info in the metadata has the expected values """ + technical_node = xml.find(".//{}[@contactType='technical']".format(etree.QName(SAML_XML_NS, 'ContactPerson'))) + self.assertIsNotNone(technical_node) + tech_name_node = technical_node.find(etree.QName(SAML_XML_NS, 'GivenName')) + self.assertEqual(tech_name_node.text, tech_name) + tech_email_node = technical_node.find(etree.QName(SAML_XML_NS, 'EmailAddress')) + self.assertEqual(tech_email_node.text, tech_email) + + support_node = xml.find(".//{}[@contactType='support']".format(etree.QName(SAML_XML_NS, 'ContactPerson'))) + self.assertIsNotNone(support_node) + support_name_node = support_node.find(etree.QName(SAML_XML_NS, 'GivenName')) + self.assertEqual(support_name_node.text, support_name) + support_email_node = support_node.find(etree.QName(SAML_XML_NS, 'EmailAddress')) + self.assertEqual(support_email_node.text, support_email) diff --git a/common/djangoapps/third_party_auth/tests/testutil.py b/common/djangoapps/third_party_auth/tests/testutil.py index 5d1a1f38c2..323e57142b 100644 --- a/common/djangoapps/third_party_auth/tests/testutil.py +++ b/common/djangoapps/third_party_auth/tests/testutil.py @@ -114,6 +114,15 @@ class SAMLTestCase(TestCase): with open(os.path.join(os.path.dirname(__file__), 'data', filename)) as f: return f.read() + def enable_saml(self, **kwargs): + """ Enable SAML support (via SAMLConfiguration, not for any particular provider) """ + if 'private_key' not in kwargs: + kwargs['private_key'] = self._get_private_key() + if 'public_key' not in kwargs: + kwargs['public_key'] = self._get_public_key() + kwargs.setdefault('entity_id', "https://saml.example.none") + super(SAMLTestCase, self).enable_saml(**kwargs) + @contextmanager def simulate_running_pipeline(pipeline_target, backend, email=None, fullname=None, username=None):