diff --git a/common/djangoapps/third_party_auth/management/commands/saml.py b/common/djangoapps/third_party_auth/management/commands/saml.py index 56cb138889..b126a2b649 100644 --- a/common/djangoapps/third_party_auth/management/commands/saml.py +++ b/common/djangoapps/third_party_auth/management/commands/saml.py @@ -25,14 +25,19 @@ class Command(BaseCommand): log = logging.getLogger('third_party_auth.tasks') log.propagate = False log.addHandler(log_handler) - num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata() + total, skipped, attempted, updated, failed, failure_messages = fetch_saml_metadata() self.stdout.write( - "\nDone. Fetched {num_total} total. {num_changed} were updated and {num_failed} failed.\n".format( - num_changed=num_changed, num_failed=num_failed, num_total=num_total + "\nDone." + "\n{total} provider(s) found in database." + "\n{skipped} skipped and {attempted} attempted." + "\n{updated} updated and {failed} failed.\n".format( + total=total, + skipped=skipped, attempted=attempted, + updated=updated, failed=failed, ) ) - if num_failed > 0: + if failed > 0: raise CommandError( "Command finished with the following exceptions:\n\n{failures}".format( failures="\n\n".join(failure_messages) diff --git a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py index c3002d4fde..eec4af3fa7 100644 --- a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py +++ b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py @@ -92,10 +92,9 @@ class TestSAMLCommand(TestCase): Test that management command completes without errors and logs correct information when no saml configurations are enabled/present. """ - # Capture command output log for testing. + expected = "\nDone.\n1 provider(s) found in database.\n1 skipped and 0 attempted.\n0 updated and 0 failed.\n" call_command("saml", pull=True, stdout=self.stdout) - - self.assertIn('Done. Fetched 0 total. 0 were updated and 0 failed.', self.stdout.getvalue()) + self.assertIn(expected, self.stdout.getvalue()) @mock.patch("requests.get", mock_get()) def test_fetch_saml_metadata(self): @@ -106,10 +105,9 @@ class TestSAMLCommand(TestCase): # Create enabled configurations self.__create_saml_configurations__() - # Capture command output log for testing. + expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n1 updated and 0 failed.\n" call_command("saml", pull=True, stdout=self.stdout) - - self.assertIn('Done. Fetched 1 total. 1 were updated and 0 failed.', self.stdout.getvalue()) + self.assertIn(expected, self.stdout.getvalue()) @mock.patch("requests.get", mock_get(status_code=404)) def test_fetch_saml_metadata_failure(self): @@ -120,11 +118,11 @@ class TestSAMLCommand(TestCase): # Create enabled configurations self.__create_saml_configurations__() - with self.assertRaisesRegexp(CommandError, r"HTTPError: 404 Client Error"): - # Capture command output log for testing. - call_command("saml", pull=True, stdout=self.stdout) + expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n0 updated and 1 failed.\n" - self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) + with self.assertRaisesRegexp(CommandError, r"HTTPError: 404 Client Error"): + call_command("saml", pull=True, stdout=self.stdout) + self.assertIn(expected, self.stdout.getvalue()) @mock.patch("requests.get", mock_get(status_code=200)) def test_fetch_multiple_providers_data(self): @@ -162,11 +160,31 @@ class TestSAMLCommand(TestCase): } ) + expected = '\n3 provider(s) found in database.\n0 skipped and 3 attempted.\n2 updated and 1 failed.\n' with self.assertRaisesRegexp(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"): - # Capture command output log for testing. call_command("saml", pull=True, stdout=self.stdout) + self.assertIn(expected, self.stdout.getvalue()) - self.assertIn('Done. Fetched 3 total. 2 were updated and 1 failed.', self.stdout.getvalue()) + # Now add a fourth configuration, and indicate that it should not be included in the update + self.__create_saml_configurations__( + saml_config={ + "site__domain": "fourth.testserver.fake", + }, + saml_provider_config={ + "site__domain": "fourth.testserver.fake", + "idp_slug": "fourth-test-shib", + "automatic_refresh_enabled": False, + # Note: This invalid entity id will not be present in the refresh set + "entity_id": "https://idp.testshib.org/idp/fourth-shibboleth", + "metadata_source": "https://www.testshib.org/metadata/fourth/testshib-providers.xml", + } + ) + + # Four configurations -- one will be skipped and three attempted, with similar results. + expected = '\nDone.\n4 provider(s) found in database.\n1 skipped and 3 attempted.\n0 updated and 1 failed.\n' + with self.assertRaisesRegexp(CommandError, r"MetadataParseError: Can't find EntityDescriptor for entityID"): + call_command("saml", pull=True, stdout=self.stdout) + self.assertIn(expected, self.stdout.getvalue()) @mock.patch("requests.get") def test_saml_request_exceptions(self, mocked_get): @@ -178,27 +196,23 @@ class TestSAMLCommand(TestCase): mocked_get.side_effect = exceptions.SSLError - with self.assertRaisesRegexp(CommandError, "SSLError:"): - # Capture command output log for testing. - call_command("saml", pull=True, stdout=self.stdout) + expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n0 updated and 1 failed.\n" - self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) + with self.assertRaisesRegexp(CommandError, "SSLError:"): + call_command("saml", pull=True, stdout=self.stdout) + self.assertIn(expected, self.stdout.getvalue()) mocked_get.side_effect = exceptions.ConnectionError with self.assertRaisesRegexp(CommandError, "ConnectionError:"): - # Capture command output log for testing. call_command("saml", pull=True, stdout=self.stdout) - - self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) + self.assertIn(expected, self.stdout.getvalue()) mocked_get.side_effect = exceptions.HTTPError with self.assertRaisesRegexp(CommandError, "HTTPError:"): - # Capture command output log for testing. call_command("saml", pull=True, stdout=self.stdout) - - self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) + self.assertIn(expected, self.stdout.getvalue()) @mock.patch("requests.get", mock_get(status_code=200)) def test_saml_parse_exceptions(self): @@ -219,11 +233,11 @@ class TestSAMLCommand(TestCase): } ) - with self.assertRaisesRegexp(CommandError, "MetadataParseError: Can't find EntityDescriptor for entityID"): - # Capture command output log for testing. - call_command("saml", pull=True, stdout=self.stdout) + expected = "\nDone.\n2 provider(s) found in database.\n1 skipped and 1 attempted.\n0 updated and 1 failed.\n" - self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) + with self.assertRaisesRegexp(CommandError, "MetadataParseError: Can't find EntityDescriptor for entityID"): + call_command("saml", pull=True, stdout=self.stdout) + self.assertIn(expected, self.stdout.getvalue()) @mock.patch("requests.get") def test_xml_parse_exceptions(self, mocked_get): @@ -239,8 +253,8 @@ class TestSAMLCommand(TestCase): # create enabled configuration self.__create_saml_configurations__() - with self.assertRaisesRegexp(CommandError, "XMLSyntaxError:"): - # Capture command output log for testing. - call_command("saml", pull=True, stdout=self.stdout) + expected = "\nDone.\n1 provider(s) found in database.\n0 skipped and 1 attempted.\n0 updated and 1 failed.\n" - self.assertIn('Done. Fetched 1 total. 0 were updated and 1 failed.', self.stdout.getvalue()) + with self.assertRaisesRegexp(CommandError, "XMLSyntaxError:"): + call_command("saml", pull=True, stdout=self.stdout) + self.assertIn(expected, self.stdout.getvalue()) diff --git a/common/djangoapps/third_party_auth/migrations/0006_samlproviderconfig_automatic_refresh_enabled.py b/common/djangoapps/third_party_auth/migrations/0006_samlproviderconfig_automatic_refresh_enabled.py new file mode 100644 index 0000000000..1a44b112c1 --- /dev/null +++ b/common/djangoapps/third_party_auth/migrations/0006_samlproviderconfig_automatic_refresh_enabled.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('third_party_auth', '0005_add_site_field'), + ] + + operations = [ + migrations.AddField( + model_name='samlproviderconfig', + name='automatic_refresh_enabled', + field=models.BooleanField(default=True, help_text=b"When checked, the SAML provider's metadata will be included in the automatic refresh job, if configured.", verbose_name=b'Enable automatic metadata refresh'), + ), + ] diff --git a/common/djangoapps/third_party_auth/models.py b/common/djangoapps/third_party_auth/models.py index a4f84a0e32..d72533def0 100644 --- a/common/djangoapps/third_party_auth/models.py +++ b/common/djangoapps/third_party_auth/models.py @@ -347,6 +347,9 @@ class SAMLProviderConfig(ProviderConfig): attr_email = models.CharField( max_length=128, blank=True, verbose_name="Email Attribute", help_text="URN of SAML attribute containing the user's email address[es]. Leave blank for default.") + automatic_refresh_enabled = models.BooleanField( + default=True, verbose_name="Enable automatic metadata refresh", + help_text="When checked, the SAML provider's metadata will be included in the automatic refresh job, if configured.") debug_mode = models.BooleanField( default=False, verbose_name="Debug Mode", help_text=( diff --git a/common/djangoapps/third_party_auth/tasks.py b/common/djangoapps/third_party_auth/tasks.py index 2678bb78c5..d1a138769b 100644 --- a/common/djangoapps/third_party_auth/tasks.py +++ b/common/djangoapps/third_party_auth/tasks.py @@ -33,27 +33,42 @@ def fetch_saml_metadata(): It's OK to run this whether or not SAML is enabled. Return value: - tuple(num_changed, num_failed, num_total, failure_messages) - num_changed: Number of providers that are either new or whose metadata has changed + tuple(num_skipped, num_attempted, num_updated, num_failed, failure_messages) + num_total: Total number of providers found in the database + num_skipped: Number of providers skipped for various reasons (see L52) + num_attempted: Number of providers whose metadata was fetched + num_updated: Number of providers that are either new or whose metadata has changed num_failed: Number of providers that could not be updated - num_total: Total number of providers whose metadata was fetched failure_messages: List of error messages for the providers that could not be updated """ - num_changed = 0 - failure_messages = [] # First make a list of all the metadata XML URLs: + saml_providers = SAMLProviderConfig.key_values('idp_slug', flat=True) + num_total = len(saml_providers) + num_skipped = 0 url_map = {} - for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True): + for idp_slug in saml_providers: config = SAMLProviderConfig.current(idp_slug) - if not config.enabled or not SAMLConfiguration.is_enabled(config.site): + + # Skip SAML provider configurations which do not qualify for fetching + if any([ + not config.enabled, + not config.automatic_refresh_enabled, + not SAMLConfiguration.is_enabled(config.site) + ]): + num_skipped += 1 continue + url = config.metadata_source if url not in url_map: url_map[url] = [] if config.entity_id not in url_map[url]: url_map[url].append(config.entity_id) - # Now fetch the metadata: + + # Now attempt to fetch the metadata for the remaining SAML providers: + num_attempted = len(url_map) + num_updated = 0 + failure_messages = [] # We return the length of this array for num_failed for url, entity_ids in url_map.items(): try: log.info("Fetching %s", url) @@ -75,7 +90,7 @@ def fetch_saml_metadata(): changed = _update_data(entity_id, public_key, sso_url, expires_at) if changed: log.info(u"→ Created new record for SAMLProviderData") - num_changed += 1 + num_updated += 1 else: log.info(u"→ Updated existing SAMLProviderData. Nothing has changed.") except (exceptions.SSLError, exceptions.HTTPError, exceptions.RequestException, MetadataParseError) as error: @@ -109,7 +124,8 @@ def fetch_saml_metadata(): ) ) - return num_changed, len(failure_messages), len(url_map), failure_messages + # Return counts for total, skipped, attempted, updated, and failed, along with any failure messages + return num_total, num_skipped, num_attempted, num_updated, len(failure_messages), failure_messages def _parse_metadata_xml(xml, entity_id): diff --git a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py index 13a5e2e32c..99a1150f05 100644 --- a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py +++ b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py @@ -149,11 +149,13 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase): kwargs.setdefault('attr_email', 'urn:oid:1.3.6.1.4.1.5923.1.1.1.6') # eduPersonPrincipalName self.configure_saml_provider(**kwargs) self.assertTrue(httpretty.is_enabled()) - num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata() + num_total, num_skipped, num_attempted, num_updated, num_failed, failure_messages = fetch_saml_metadata() + self.assertEqual(num_total, 1) + self.assertEqual(num_skipped, 0) + self.assertEqual(num_attempted, 1) + self.assertEqual(num_updated, 1) self.assertEqual(num_failed, 0) self.assertEqual(len(failure_messages), 0) - self.assertEqual(num_changed, 1) - self.assertEqual(num_total, 1) def _freeze_time(self, timestamp): """ Mock the current time for SAML, so we can replay canned requests/responses """ @@ -177,12 +179,14 @@ class TestShibIntegrationTest(IntegrationTestMixin, testutil.SAMLTestCase): if fetch_metadata: self.assertTrue(httpretty.is_enabled()) - num_changed, num_failed, num_total, failure_messages = fetch_saml_metadata() + num_total, num_skipped, num_attempted, num_updated, num_failed, failure_messages = fetch_saml_metadata() if assert_metadata_updates: + self.assertEqual(num_total, 1) + self.assertEqual(num_skipped, 0) + self.assertEqual(num_attempted, 1) + self.assertEqual(num_updated, 1) self.assertEqual(num_failed, 0) self.assertEqual(len(failure_messages), 0) - self.assertEqual(num_changed, 1) - self.assertEqual(num_total, 1) def do_provider_login(self, provider_redirect_url): """ Mocked: the user logs in to TestShib and then gets redirected back """