Add SAML metadata refresh control flag
mattdrayer: Change model fieldname, revise code, fix bad tests.
This commit is contained in:
@@ -25,14 +25,19 @@ class Command(BaseCommand):
|
||||
log = logging.getLogger('third_party_auth.tasks')
|
||||
log.propagate = False
|
||||
log.addHandler(log_handler)
|
||||
num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
|
||||
total, skipped, attempted, updated, failed, failure_messages = fetch_saml_metadata()
|
||||
self.stdout.write(
|
||||
"\nDone. Fetched {num_total} total. {num_changed} were updated and {num_failed} failed.\n".format(
|
||||
num_changed=num_changed, num_failed=num_failed, num_total=num_total
|
||||
"\nDone."
|
||||
"\n{total} provider(s) found in database."
|
||||
"\n{skipped} skipped and {attempted} attempted."
|
||||
"\n{updated} updated and {failed} failed.\n".format(
|
||||
total=total,
|
||||
skipped=skipped, attempted=attempted,
|
||||
updated=updated, failed=failed,
|
||||
)
|
||||
)
|
||||
|
||||
if num_failed > 0:
|
||||
if failed > 0:
|
||||
raise CommandError(
|
||||
"Command finished with the following exceptions:\n\n{failures}".format(
|
||||
failures="\n\n".join(failure_messages)
|
||||
|
||||
@@ -92,10 +92,9 @@ class TestSAMLCommand(TestCase):
|
||||
Test that management command completes without errors and logs correct information when no
|
||||
saml configurations are enabled/present.
|
||||
"""
|
||||
# Capture command output log for testing.
|
||||
expected = "\nDone.\n1 provider(s) found in database.\n1 skipped and 0 attempted.\n0 updated and 0 failed.\n"
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
|
||||
self.assertIn('Done. Fetched 0 total. 0 were updated and 0 failed.', self.stdout.getvalue())
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
|
||||
@mock.patch("requests.get", mock_get())
|
||||
def test_fetch_saml_metadata(self):
|
||||
@@ -106,10 +105,9 @@ class TestSAMLCommand(TestCase):
|
||||
# Create enabled configurations
|
||||
self.__create_saml_configurations__()
|
||||
|
||||
# Capture command output log for testing.
|
||||
expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n1 updated and 0 failed.\n"
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
|
||||
self.assertIn('Done. Fetched 1 total. 1 were updated and 0 failed.', self.stdout.getvalue())
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
|
||||
@mock.patch("requests.get", mock_get(status_code=404))
|
||||
def test_fetch_saml_metadata_failure(self):
|
||||
@@ -120,11 +118,11 @@ class TestSAMLCommand(TestCase):
|
||||
# Create enabled configurations
|
||||
self.__create_saml_configurations__()
|
||||
|
||||
with self.assertRaisesRegexp(CommandError, r"HTTPError: 404 Client Error"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n0 updated and 1 failed.\n"
|
||||
|
||||
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
|
||||
with self.assertRaisesRegexp(CommandError, r"HTTPError: 404 Client Error"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
|
||||
@mock.patch("requests.get", mock_get(status_code=200))
|
||||
def test_fetch_multiple_providers_data(self):
|
||||
@@ -162,11 +160,31 @@ class TestSAMLCommand(TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
expected = '\n3 provider(s) found in database.\n0 skipped and 3 attempted.\n2 updated and 1 failed.\n'
|
||||
with self.assertRaisesRegexp(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
|
||||
self.assertIn('Done. Fetched 3 total. 2 were updated and 1 failed.', self.stdout.getvalue())
|
||||
# Now add a fourth configuration, and indicate that it should not be included in the update
|
||||
self.__create_saml_configurations__(
|
||||
saml_config={
|
||||
"site__domain": "fourth.testserver.fake",
|
||||
},
|
||||
saml_provider_config={
|
||||
"site__domain": "fourth.testserver.fake",
|
||||
"idp_slug": "fourth-test-shib",
|
||||
"automatic_refresh_enabled": False,
|
||||
# Note: This invalid entity id will not be present in the refresh set
|
||||
"entity_id": "https://idp.testshib.org/idp/fourth-shibboleth",
|
||||
"metadata_source": "https://www.testshib.org/metadata/fourth/testshib-providers.xml",
|
||||
}
|
||||
)
|
||||
|
||||
# Four configurations -- one will be skipped and three attempted, with similar results.
|
||||
expected = '\nDone.\n4 provider(s) found in database.\n1 skipped and 3 attempted.\n0 updated and 1 failed.\n'
|
||||
with self.assertRaisesRegexp(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_saml_request_exceptions(self, mocked_get):
|
||||
@@ -178,27 +196,23 @@ class TestSAMLCommand(TestCase):
|
||||
|
||||
mocked_get.side_effect = exceptions.SSLError
|
||||
|
||||
with self.assertRaisesRegexp(CommandError, "SSLError:"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n0 updated and 1 failed.\n"
|
||||
|
||||
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
|
||||
with self.assertRaisesRegexp(CommandError, "SSLError:"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
|
||||
mocked_get.side_effect = exceptions.ConnectionError
|
||||
|
||||
with self.assertRaisesRegexp(CommandError, "ConnectionError:"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
|
||||
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
|
||||
mocked_get.side_effect = exceptions.HTTPError
|
||||
|
||||
with self.assertRaisesRegexp(CommandError, "HTTPError:"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
|
||||
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
|
||||
@mock.patch("requests.get", mock_get(status_code=200))
|
||||
def test_saml_parse_exceptions(self):
|
||||
@@ -219,11 +233,11 @@ class TestSAMLCommand(TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
with self.assertRaisesRegexp(CommandError, "MetadataParseError: Can't find EntityDescriptor for entityID"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
expected = "\nDone.\n2 provider(s) found in database.\n1 skipped and 1 attempted.\n0 updated and 1 failed.\n"
|
||||
|
||||
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
|
||||
with self.assertRaisesRegexp(CommandError, "MetadataParseError: Can't find EntityDescriptor for entityID"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_xml_parse_exceptions(self, mocked_get):
|
||||
@@ -239,8 +253,8 @@ class TestSAMLCommand(TestCase):
|
||||
# create enabled configuration
|
||||
self.__create_saml_configurations__()
|
||||
|
||||
with self.assertRaisesRegexp(CommandError, "XMLSyntaxError:"):
|
||||
# Capture command output log for testing.
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n0 updated and 1 failed.\n"
|
||||
|
||||
self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue())
|
||||
with self.assertRaisesRegexp(CommandError, "XMLSyntaxError:"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
self.assertIn(expected, self.stdout.getvalue())
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('third_party_auth', '0005_add_site_field'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='samlproviderconfig',
|
||||
name='automatic_refresh_enabled',
|
||||
field=models.BooleanField(default=True, help_text=b"When checked, the SAML provider's metadata will be included in the automatic refresh job, if configured.", verbose_name=b'Enable automatic metadata refresh'),
|
||||
),
|
||||
]
|
||||
@@ -347,6 +347,9 @@ class SAMLProviderConfig(ProviderConfig):
|
||||
attr_email = models.CharField(
|
||||
max_length=128, blank=True, verbose_name="Email Attribute",
|
||||
help_text="URN of SAML attribute containing the user's email address[es]. Leave blank for default.")
|
||||
automatic_refresh_enabled = models.BooleanField(
|
||||
default=True, verbose_name="Enable automatic metadata refresh",
|
||||
help_text="When checked, the SAML provider's metadata will be included in the automatic refresh job, if configured.")
|
||||
debug_mode = models.BooleanField(
|
||||
default=False, verbose_name="Debug Mode",
|
||||
help_text=(
|
||||
|
||||
@@ -33,27 +33,42 @@ def fetch_saml_metadata():
|
||||
It's OK to run this whether or not SAML is enabled.
|
||||
|
||||
Return value:
|
||||
tuple(num_changed, num_failed, num_total, failure_messages)
|
||||
num_changed: Number of providers that are either new or whose metadata has changed
|
||||
tuple(num_skipped, num_attempted, num_updated, num_failed, failure_messages)
|
||||
num_total: Total number of providers found in the database
|
||||
num_skipped: Number of providers skipped for various reasons (see L52)
|
||||
num_attempted: Number of providers whose metadata was fetched
|
||||
num_updated: Number of providers that are either new or whose metadata has changed
|
||||
num_failed: Number of providers that could not be updated
|
||||
num_total: Total number of providers whose metadata was fetched
|
||||
failure_messages: List of error messages for the providers that could not be updated
|
||||
"""
|
||||
num_changed = 0
|
||||
failure_messages = []
|
||||
|
||||
# First make a list of all the metadata XML URLs:
|
||||
saml_providers = SAMLProviderConfig.key_values('idp_slug', flat=True)
|
||||
num_total = len(saml_providers)
|
||||
num_skipped = 0
|
||||
url_map = {}
|
||||
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
|
||||
for idp_slug in saml_providers:
|
||||
config = SAMLProviderConfig.current(idp_slug)
|
||||
if not config.enabled or not SAMLConfiguration.is_enabled(config.site):
|
||||
|
||||
# Skip SAML provider configurations which do not qualify for fetching
|
||||
if any([
|
||||
not config.enabled,
|
||||
not config.automatic_refresh_enabled,
|
||||
not SAMLConfiguration.is_enabled(config.site)
|
||||
]):
|
||||
num_skipped += 1
|
||||
continue
|
||||
|
||||
url = config.metadata_source
|
||||
if url not in url_map:
|
||||
url_map[url] = []
|
||||
if config.entity_id not in url_map[url]:
|
||||
url_map[url].append(config.entity_id)
|
||||
# Now fetch the metadata:
|
||||
|
||||
# Now attempt to fetch the metadata for the remaining SAML providers:
|
||||
num_attempted = len(url_map)
|
||||
num_updated = 0
|
||||
failure_messages = [] # We return the length of this array for num_failed
|
||||
for url, entity_ids in url_map.items():
|
||||
try:
|
||||
log.info("Fetching %s", url)
|
||||
@@ -75,7 +90,7 @@ def fetch_saml_metadata():
|
||||
changed = _update_data(entity_id, public_key, sso_url, expires_at)
|
||||
if changed:
|
||||
log.info(u"→ Created new record for SAMLProviderData")
|
||||
num_changed += 1
|
||||
num_updated += 1
|
||||
else:
|
||||
log.info(u"→ Updated existing SAMLProviderData. Nothing has changed.")
|
||||
except (exceptions.SSLError, exceptions.HTTPError, exceptions.RequestException, MetadataParseError) as error:
|
||||
@@ -109,7 +124,8 @@ def fetch_saml_metadata():
|
||||
)
|
||||
)
|
||||
|
||||
return num_changed, len(failure_messages), len(url_map), failure_messages
|
||||
# Return counts for total, skipped, attempted, updated, and failed, along with any failure messages
|
||||
return num_total, num_skipped, num_attempted, num_updated, len(failure_messages), failure_messages
|
||||
|
||||
|
||||
def _parse_metadata_xml(xml, entity_id):
|
||||
|
||||
@@ -149,11 +149,13 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase):
|
||||
kwargs.setdefault('attr_email', 'urn:oid:1.3.6.1.4.1.5923.1.1.1.6') # eduPersonPrincipalName
|
||||
self.configure_saml_provider(**kwargs)
|
||||
self.assertTrue(httpretty.is_enabled())
|
||||
num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
|
||||
num_total, num_skipped, num_attempted, num_updated, num_failed, failure_messages = fetch_saml_metadata()
|
||||
self.assertEqual(num_total, 1)
|
||||
self.assertEqual(num_skipped, 0)
|
||||
self.assertEqual(num_attempted, 1)
|
||||
self.assertEqual(num_updated, 1)
|
||||
self.assertEqual(num_failed, 0)
|
||||
self.assertEqual(len(failure_messages), 0)
|
||||
self.assertEqual(num_changed, 1)
|
||||
self.assertEqual(num_total, 1)
|
||||
|
||||
def _freeze_time(self, timestamp):
|
||||
""" Mock the current time for SAML, so we can replay canned requests/responses """
|
||||
@@ -177,12 +179,14 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase):
|
||||
|
||||
if fetch_metadata:
|
||||
self.assertTrue(httpretty.is_enabled())
|
||||
num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata()
|
||||
num_total, num_skipped, num_attempted, num_updated, num_failed, failure_messages = fetch_saml_metadata()
|
||||
if assert_metadata_updates:
|
||||
self.assertEqual(num_total, 1)
|
||||
self.assertEqual(num_skipped, 0)
|
||||
self.assertEqual(num_attempted, 1)
|
||||
self.assertEqual(num_updated, 1)
|
||||
self.assertEqual(num_failed, 0)
|
||||
self.assertEqual(len(failure_messages), 0)
|
||||
self.assertEqual(num_changed, 1)
|
||||
self.assertEqual(num_total, 1)
|
||||
|
||||
def do_provider_login(self, provider_redirect_url):
|
||||
""" Mocked: the user logs in to TestShib and then gets redirected back """
|
||||
|
||||
Reference in New Issue
Block a user