Update saml --pull command to raise error when it fails.
This commit is contained in:
@@ -25,9 +25,16 @@ class Command(BaseCommand):
|
||||
log = logging.getLogger('third_party_auth.tasks')
|
||||
log.propagate = False
|
||||
log.addHandler(log_handler)
|
||||
num_changed, num_failed, num_total = fetch_saml_metadata()
|
||||
num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
|
||||
self.stdout.write(
|
||||
"\nDone. Fetched {num_total} total. {num_changed} were updated and {num_failed} failed.\n".format(
|
||||
num_changed=num_changed, num_failed=num_failed, num_total=num_total
|
||||
)
|
||||
)
|
||||
|
||||
if num_failed > 0:
|
||||
raise CommandError(
|
||||
"Command finished with the following exceptions:\n\n{failures}".format(
|
||||
failures="\n\n".join(failure_messages)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from django.core.management.base import CommandError
|
||||
from django.conf import settings
|
||||
from django.utils.six import StringIO
|
||||
|
||||
from requests import exceptions
|
||||
from requests.models import Response
|
||||
|
||||
from third_party_auth.tests.factories import SAMLConfigurationFactory, SAMLProviderConfigFactory
|
||||
@@ -119,10 +120,11 @@ class TestSAMLCommand(TestCase):
|
||||
# Create enabled configurations
|
||||
self.__create_saml_configurations__()
|
||||
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
with self.assertRaisesRegexp(CommandError, r"HTTPError: 404 Client Error"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
|
||||
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
|
||||
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
|
||||
|
||||
@mock.patch("requests.get", mock_get(status_code=200))
|
||||
def test_fetch_multiple_providers_data(self):
|
||||
@@ -160,7 +162,85 @@ class TestSAMLCommand(TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
with self.assertRaisesRegexp(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
|
||||
self.assertIn('Done. Fetched 3 total. 2 were updated and 1 failed.', self.stdout.getvalue())
|
||||
self.assertIn('Done. Fetched 3 total. 2 were updated and 1 failed.', self.stdout.getvalue())
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_saml_request_exceptions(self, mocked_get):
|
||||
"""
|
||||
Test that management command errors out in case of fatal exceptions instead of failing silently.
|
||||
"""
|
||||
# Create enabled configurations
|
||||
self.__create_saml_configurations__()
|
||||
|
||||
mocked_get.side_effect = exceptions.SSLError
|
||||
|
||||
with self.assertRaisesRegexp(CommandError, "SSLError:"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
|
||||
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
|
||||
|
||||
mocked_get.side_effect = exceptions.ConnectionError
|
||||
|
||||
with self.assertRaisesRegexp(CommandError, "ConnectionError:"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
|
||||
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
|
||||
|
||||
mocked_get.side_effect = exceptions.HTTPError
|
||||
|
||||
with self.assertRaisesRegexp(CommandError, "HTTPError:"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
|
||||
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
|
||||
|
||||
@mock.patch("requests.get", mock_get(status_code=200))
|
||||
def test_saml_parse_exceptions(self):
|
||||
"""
|
||||
Test that management command errors out in case of fatal exceptions instead of failing silently.
|
||||
"""
|
||||
# Create enabled configurations, this configuration will raise MetadataParseError.
|
||||
self.__create_saml_configurations__(
|
||||
saml_config={
|
||||
"site__domain": "third.testserver.fake",
|
||||
},
|
||||
saml_provider_config={
|
||||
"site__domain": "third.testserver.fake",
|
||||
"idp_slug": "third-test-shib",
|
||||
# Note: This entity id will not be present in returned response and will cause failed update.
|
||||
"entity_id": "https://idp.testshib.org/idp/non-existent-shibboleth",
|
||||
"metadata_source": "https://www.testshib.org/metadata/third/testshib-providers.xml",
|
||||
}
|
||||
)
|
||||
|
||||
with self.assertRaisesRegexp(CommandError, "MetadataParseError: Can't find EntityDescriptor for entityID"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
|
||||
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_xml_parse_exceptions(self, mocked_get):
|
||||
"""
|
||||
Test that management command errors out in case of fatal exceptions instead of failing silently.
|
||||
"""
|
||||
response = Response()
|
||||
response._content = "" # pylint: disable=protected-access
|
||||
response.status_code = 200
|
||||
|
||||
mocked_get.return_value = response
|
||||
|
||||
# create enabled configuration
|
||||
self.__create_saml_configurations__()
|
||||
|
||||
with self.assertRaisesRegexp(CommandError, "XMLSyntaxError:"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
|
||||
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
|
||||
|
||||
@@ -10,6 +10,7 @@ import pytz
|
||||
import logging
|
||||
from lxml import etree
|
||||
import requests
|
||||
from requests import exceptions
|
||||
from onelogin.saml2.utils import OneLogin_Saml2_Utils
|
||||
from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData
|
||||
|
||||
@@ -32,12 +33,14 @@ def fetch_saml_metadata():
|
||||
It's OK to run this whether or not SAML is enabled.
|
||||
|
||||
Return value:
|
||||
tuple(num_changed, num_failed, num_total)
|
||||
tuple(num_changed, num_failed, num_total, failure_messages)
|
||||
num_changed: Number of providers that are either new or whose metadata has changed
|
||||
num_failed: Number of providers that could not be updated
|
||||
num_total: Total number of providers whose metadata was fetched
|
||||
failure_messages: List of error messages for the providers that could not be updated
|
||||
"""
|
||||
num_changed, num_failed = 0, 0
|
||||
num_changed = 0
|
||||
failure_messages = []
|
||||
|
||||
# First make a list of all the metadata XML URLs:
|
||||
url_map = {}
|
||||
@@ -75,10 +78,38 @@ def fetch_saml_metadata():
|
||||
num_changed += 1
|
||||
else:
|
||||
log.info(u"→ Updated existing SAMLProviderData. Nothing has changed.")
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
log.exception(err.message)
|
||||
num_failed += 1
|
||||
return (num_changed, num_failed, len(url_map))
|
||||
except (exceptions.SSLError, exceptions.HTTPError, exceptions.RequestException, MetadataParseError) as error:
|
||||
# Catch and process exception in case of errors during fetching and processing saml metadata.
|
||||
# Here is a description of each exception.
|
||||
# SSLError is raised in case of errors caused by SSL (e.g. SSL cer verification failure etc.)
|
||||
# HTTPError is raised in case of unexpected status code (e.g. 500 error etc.)
|
||||
# RequestException is the base exception for any request related error that "requests" lib raises.
|
||||
# MetadataParseError is raised if there is error in the fetched meta data (e.g. missing @entityID etc.)
|
||||
|
||||
log.exception(error.message)
|
||||
failure_messages.append(
|
||||
"{error_type}: {error_message}\nMetadata Source: {url}\nEntity IDs: \n{entity_ids}.".format(
|
||||
error_type=type(error).__name__,
|
||||
error_message=error.message,
|
||||
url=url,
|
||||
entity_ids="\n".join(
|
||||
["\t{}: {}".format(count, item) for count, item in enumerate(entity_ids, start=1)],
|
||||
)
|
||||
)
|
||||
)
|
||||
except etree.XMLSyntaxError as error:
|
||||
log.exception(error.message)
|
||||
failure_messages.append(
|
||||
"XMLSyntaxError: {error_message}\nMetadata Source: {url}\nEntity IDs: \n{entity_ids}.".format(
|
||||
error_message=str(error.error_log),
|
||||
url=url,
|
||||
entity_ids="\n".join(
|
||||
["\t{}: {}".format(count, item) for count, item in enumerate(entity_ids, start=1)],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return num_changed, len(failure_messages), len(url_map), failure_messages
|
||||
|
||||
|
||||
def _parse_metadata_xml(xml, entity_id):
|
||||
|
||||
@@ -149,8 +149,9 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase):
|
||||
kwargs.setdefault('attr_email', 'urn:oid:1.3.6.1.4.1.5923.1.1.1.6') # eduPersonPrincipalName
|
||||
self.configure_saml_provider(**kwargs)
|
||||
self.assertTrue(httpretty.is_enabled())
|
||||
num_changed, num_failed, num_total = fetch_saml_metadata()
|
||||
num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
|
||||
self.assertEqual(num_failed, 0)
|
||||
self.assertEqual(len(failure_messages), 0)
|
||||
self.assertEqual(num_changed, 1)
|
||||
self.assertEqual(num_total, 1)
|
||||
|
||||
@@ -176,9 +177,10 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase):
|
||||
|
||||
if fetch_metadata:
|
||||
self.assertTrue(httpretty.is_enabled())
|
||||
num_changed, num_failed, num_total = fetch_saml_metadata()
|
||||
num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
|
||||
if assert_metadata_updates:
|
||||
self.assertEqual(num_failed, 0)
|
||||
self.assertEqual(len(failure_messages), 0)
|
||||
self.assertEqual(num_changed, 1)
|
||||
self.assertEqual(num_total, 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user