diff --git a/common/djangoapps/third_party_auth/management/commands/saml.py b/common/djangoapps/third_party_auth/management/commands/saml.py index 238599e2cb..56cb138889 100644 --- a/common/djangoapps/third_party_auth/management/commands/saml.py +++ b/common/djangoapps/third_party_auth/management/commands/saml.py @@ -25,9 +25,16 @@ class Command(BaseCommand): log = logging.getLogger('third_party_auth.tasks') log.propagate = False log.addHandler(log_handler) - num_changed, num_failed, num_total = fetch_saml_metadata() + num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata() self.stdout.write( "\nDone. Fetched {num_total} total. {num_changed} were updated and {num_failed} failed.\n".format( num_changed=num_changed, num_failed=num_failed, num_total=num_total ) ) + + if num_failed > 0: + raise CommandError( + "Command finished with the following exceptions:\n\n{failures}".format( + failures="\n\n".join(failure_messages) + ) + ) diff --git a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py index 59f28f5817..c3002d4fde 100644 --- a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py +++ b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py @@ -12,6 +12,7 @@ from django.core.management.base import CommandError from django.conf import settings from django.utils.six import StringIO +from requests import exceptions from requests.models import Response from third_party_auth.tests.factories import SAMLConfigurationFactory, SAMLProviderConfigFactory @@ -119,10 +120,11 @@ class TestSAMLCommand(TestCase): # Create enabled configurations self.__create_saml_configurations__() - # Capture command output log for testing. - call_command("saml", pull=True, stdout=self.stdout) + with self.assertRaisesRegexp(CommandError, r"HTTPError: 404 Client Error"): + # Capture command output log for testing. + call_command("saml", pull=True, stdout=self.stdout) - self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) + self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) @mock.patch("requests.get", mock_get(status_code=200)) def test_fetch_multiple_providers_data(self): @@ -160,7 +162,85 @@ class TestSAMLCommand(TestCase): } ) - # Capture command output log for testing. - call_command("saml", pull=True, stdout=self.stdout) + with self.assertRaisesRegexp(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"): + # Capture command output log for testing. + call_command("saml", pull=True, stdout=self.stdout) - self.assertIn('Done. Fetched 3 total. 2 were updated and 1 failed.', self.stdout.getvalue()) + self.assertIn('Done. Fetched 3 total. 2 were updated and 1 failed.', self.stdout.getvalue()) + + @mock.patch("requests.get") + def test_saml_request_exceptions(self, mocked_get): + """ + Test that management command errors out in case of fatal exceptions instead of failing silently. + """ + # Create enabled configurations + self.__create_saml_configurations__() + + mocked_get.side_effect = exceptions.SSLError + + with self.assertRaisesRegexp(CommandError, "SSLError:"): + # Capture command output log for testing. + call_command("saml", pull=True, stdout=self.stdout) + + self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) + + mocked_get.side_effect = exceptions.ConnectionError + + with self.assertRaisesRegexp(CommandError, "ConnectionError:"): + # Capture command output log for testing. + call_command("saml", pull=True, stdout=self.stdout) + + self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) + + mocked_get.side_effect = exceptions.HTTPError + + with self.assertRaisesRegexp(CommandError, "HTTPError:"): + # Capture command output log for testing. + call_command("saml", pull=True, stdout=self.stdout) + + self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) + + @mock.patch("requests.get", mock_get(status_code=200)) + def test_saml_parse_exceptions(self): + """ + Test that management command errors out in case of fatal exceptions instead of failing silently. + """ + # Create enabled configurations, this configuration will raise MetadataParseError. + self.__create_saml_configurations__( + saml_config={ + "site__domain": "third.testserver.fake", + }, + saml_provider_config={ + "site__domain": "third.testserver.fake", + "idp_slug": "third-test-shib", + # Note: This entity id will not be present in returned response and will cause failed update. + "entity_id": "https://idp.testshib.org/idp/non-existent-shibboleth", + "metadata_source": "https://www.testshib.org/metadata/third/testshib-providers.xml", + } + ) + + with self.assertRaisesRegexp(CommandError, "MetadataParseError: Can't find EntityDescriptor for entityID"): + # Capture command output log for testing. + call_command("saml", pull=True, stdout=self.stdout) + + self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) + + @mock.patch("requests.get") + def test_xml_parse_exceptions(self, mocked_get): + """ + Test that management command errors out in case of fatal exceptions instead of failing silently. + """ + response = Response() + response._content = "" # pylint: disable=protected-access + response.status_code = 200 + + mocked_get.return_value = response + + # create enabled configuration + self.__create_saml_configurations__() + + with self.assertRaisesRegexp(CommandError, "XMLSyntaxError:"): + # Capture command output log for testing. + call_command("saml", pull=True, stdout=self.stdout) + + self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) diff --git a/common/djangoapps/third_party_auth/tasks.py b/common/djangoapps/third_party_auth/tasks.py index 3363d2614d..2678bb78c5 100644 --- a/common/djangoapps/third_party_auth/tasks.py +++ b/common/djangoapps/third_party_auth/tasks.py @@ -10,6 +10,7 @@ import pytz import logging from lxml import etree import requests +from requests import exceptions from onelogin.saml2.utils import OneLogin_Saml2_Utils from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData @@ -32,12 +33,14 @@ def fetch_saml_metadata(): It's OK to run this whether or not SAML is enabled. Return value: - tuple(num_changed, num_failed, num_total) + tuple(num_changed, num_failed, num_total, failure_messages) num_changed: Number of providers that are either new or whose metadata has changed num_failed: Number of providers that could not be updated num_total: Total number of providers whose metadata was fetched + failure_messages: List of error messages for the providers that could not be updated """ - num_changed, num_failed = 0, 0 + num_changed = 0 + failure_messages = [] # First make a list of all the metadata XML URLs: url_map = {} @@ -75,10 +78,38 @@ def fetch_saml_metadata(): num_changed += 1 else: log.info(u"→ Updated existing SAMLProviderData. Nothing has changed.") - except Exception as err: # pylint: disable=broad-except - log.exception(err.message) - num_failed += 1 - return (num_changed, num_failed, len(url_map)) + except (exceptions.SSLError, exceptions.HTTPError, exceptions.RequestException, MetadataParseError) as error: + # Catch and process exception in case of errors during fetching and processing saml metadata. + # Here is a description of each exception. + # SSLError is raised in case of errors caused by SSL (e.g. SSL cer verification failure etc.) + # HTTPError is raised in case of unexpected status code (e.g. 500 error etc.) + # RequestException is the base exception for any request related error that "requests" lib raises. + # MetadataParseError is raised if there is error in the fetched meta data (e.g. missing @entityID etc.) + + log.exception(error.message) + failure_messages.append( + "{error_type}: {error_message}\nMetadata Source: {url}\nEntity IDs: \n{entity_ids}.".format( + error_type=type(error).__name__, + error_message=error.message, + url=url, + entity_ids="\n".join( + ["\t{}: {}".format(count, item) for count, item in enumerate(entity_ids, start=1)], + ) + ) + ) + except etree.XMLSyntaxError as error: + log.exception(error.message) + failure_messages.append( + "XMLSyntaxError: {error_message}\nMetadata Source: {url}\nEntity IDs: \n{entity_ids}.".format( + error_message=str(error.error_log), + url=url, + entity_ids="\n".join( + ["\t{}: {}".format(count, item) for count, item in enumerate(entity_ids, start=1)], + ) + ) + ) + + return num_changed, len(failure_messages), len(url_map), failure_messages def _parse_metadata_xml(xml, entity_id): diff --git a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py index e8e874dea7..13a5e2e32c 100644 --- a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py +++ b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py @@ -149,8 +149,9 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase): kwargs.setdefault('attr_email', 'urn:oid:1.3.6.1.4.1.5923.1.1.1.6') # eduPersonPrincipalName self.configure_saml_provider(**kwargs) self.assertTrue(httpretty.is_enabled()) - num_changed, num_failed, num_total = fetch_saml_metadata() + num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata() self.assertEqual(num_failed, 0) + self.assertEqual(len(failure_messages), 0) self.assertEqual(num_changed, 1) self.assertEqual(num_total, 1) @@ -176,9 +177,10 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase): if fetch_metadata: self.assertTrue(httpretty.is_enabled()) - num_changed, num_failed, num_total = fetch_saml_metadata() + num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata() if assert_metadata_updates: self.assertEqual(num_failed, 0) + self.assertEqual(len(failure_messages), 0) self.assertEqual(num_changed, 1) self.assertEqual(num_total, 1)