feat: update saml management command (#37330)
The SAML management command has been refactored from an auto-update tool to a comprehensive report-only audit system. The changes introduce a new --run-checks option that provides detailed reporting on SAML configuration issues without making any automatic changes.
This commit is contained in:
@@ -6,6 +6,7 @@ Management commands for third_party_auth
|
||||
import logging
|
||||
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from edx_django_utils.monitoring import set_custom_attribute
|
||||
|
||||
from common.djangoapps.third_party_auth.tasks import fetch_saml_metadata
|
||||
from common.djangoapps.third_party_auth.models import SAMLProviderConfig, SAMLConfiguration
|
||||
@@ -18,34 +19,24 @@ class Command(BaseCommand):
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument('--pull', action='store_true', help="Pull updated metadata from external IDPs")
|
||||
parser.add_argument(
|
||||
'--fix-references',
|
||||
'--run-checks',
|
||||
action='store_true',
|
||||
help="Fix SAMLProviderConfig references to use current SAMLConfiguration versions"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--site-id',
|
||||
type=int,
|
||||
help='Only fix configurations for a specific site ID (to be used with --fix-references)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dry-run',
|
||||
action='store_true',
|
||||
help='Show what would be changed, but do not make any changes.'
|
||||
help="Run checks on SAMLProviderConfig configurations and report potential issues"
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
should_pull_saml_metadata = options.get('pull', False)
|
||||
should_fix_references = options.get('fix_references', False)
|
||||
dry_run = options.get('dry_run', False)
|
||||
|
||||
if not should_pull_saml_metadata and not should_fix_references:
|
||||
raise CommandError("Command must be used with '--pull' or '--fix-references' option.")
|
||||
should_run_checks = options.get('run_checks', False)
|
||||
|
||||
if should_pull_saml_metadata:
|
||||
self._handle_pull_metadata()
|
||||
return
|
||||
|
||||
if should_fix_references:
|
||||
self._handle_fix_references(options, dry_run=dry_run)
|
||||
if should_run_checks:
|
||||
self._handle_run_checks()
|
||||
return
|
||||
|
||||
raise CommandError("Command must be used with '--pull' or '--run-checks' option.")
|
||||
|
||||
def _handle_pull_metadata(self):
|
||||
"""
|
||||
@@ -76,45 +67,139 @@ class Command(BaseCommand):
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_fix_references(self, options, dry_run=False):
|
||||
"""Handle the --fix-references option for fixing outdated SAML configuration references."""
|
||||
site_id = options.get('site_id')
|
||||
updated_count = 0
|
||||
error_count = 0
|
||||
def _handle_run_checks(self):
|
||||
"""
|
||||
Handle the --run-checks option for checking SAMLProviderConfig configuration issues.
|
||||
|
||||
This is a report-only command. It identifies potential configuration problems such as:
|
||||
- Outdated SAMLConfiguration references (provider pointing to old config version)
|
||||
- Site ID mismatches between SAMLProviderConfig and its SAMLConfiguration
|
||||
- Slug mismatches (except 'default' slugs) # noqa: E501
|
||||
- SAMLProviderConfig objects with null SAMLConfiguration references (informational)
|
||||
|
||||
Includes observability attributes for monitoring.
|
||||
"""
|
||||
# Set custom attributes for monitoring the check operation
|
||||
# .. custom_attribute_name: saml_management_command.operation
|
||||
# .. custom_attribute_description: Records current SAML operation ('run_checks').
|
||||
set_custom_attribute('saml_management_command.operation', 'run_checks')
|
||||
|
||||
metrics = self._check_provider_configurations()
|
||||
self._report_check_summary(metrics)
|
||||
|
||||
def _check_provider_configurations(self):
|
||||
"""
|
||||
Check each provider configuration for potential issues.
|
||||
Returns a dictionary of metrics about the found issues.
|
||||
"""
|
||||
outdated_count = 0
|
||||
site_mismatch_count = 0
|
||||
slug_mismatch_count = 0
|
||||
null_config_count = 0
|
||||
error_count = 0
|
||||
total_providers = 0
|
||||
|
||||
# Filter by site if specified
|
||||
provider_configs = SAMLProviderConfig.objects.current_set()
|
||||
if site_id:
|
||||
provider_configs = provider_configs.filter(site_id=site_id)
|
||||
|
||||
self.stdout.write(self.style.SUCCESS("SAML Configuration Check Report"))
|
||||
self.stdout.write("=" * 50)
|
||||
self.stdout.write("")
|
||||
|
||||
for provider_config in provider_configs:
|
||||
if provider_config.saml_configuration:
|
||||
try:
|
||||
current_config = SAMLConfiguration.current(
|
||||
provider_config.site_id,
|
||||
provider_config.saml_configuration.slug
|
||||
)
|
||||
total_providers += 1
|
||||
provider_info = (
|
||||
f"Provider (id={provider_config.id}, name={provider_config.name}, "
|
||||
f"slug={provider_config.slug}, site_id={provider_config.site_id})"
|
||||
)
|
||||
|
||||
if current_config and current_config.id != provider_config.saml_configuration_id:
|
||||
if not provider_config.saml_configuration:
|
||||
self.stdout.write(
|
||||
f"[INFO] {provider_info} has no SAML configuration because "
|
||||
"a matching default was not found."
|
||||
)
|
||||
null_config_count += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
current_config = SAMLConfiguration.current(
|
||||
provider_config.saml_configuration.site_id,
|
||||
provider_config.saml_configuration.slug
|
||||
)
|
||||
|
||||
# Check for outdated configuration references
|
||||
if current_config:
|
||||
if current_config.id != provider_config.saml_configuration_id:
|
||||
self.stdout.write(
|
||||
f"Provider '{provider_config.slug}' (site {provider_config.site_id}) "
|
||||
f"has outdated config (ID: {provider_config.saml_configuration_id} -> {current_config.id})"
|
||||
f"[WARNING] {provider_info} "
|
||||
f"has outdated SAML config (id={provider_config.saml_configuration_id} which "
|
||||
f"should be updated to the current SAML config (id={current_config.id})."
|
||||
)
|
||||
outdated_count += 1
|
||||
|
||||
if not dry_run:
|
||||
provider_config.saml_configuration = current_config
|
||||
provider_config.save()
|
||||
updated_count += 1
|
||||
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
self.stderr.write(
|
||||
f"Error processing provider '{provider_config.slug}': {e}"
|
||||
if provider_config.saml_configuration.site_id != provider_config.site_id:
|
||||
config_site_id = provider_config.saml_configuration.site_id
|
||||
provider_site_id = provider_config.site_id
|
||||
self.stdout.write(
|
||||
f"[WARNING] {provider_info} "
|
||||
f"SAML config (id={provider_config.saml_configuration_id}, site_id={config_site_id}) "
|
||||
"does not match the provider's site_id."
|
||||
)
|
||||
error_count += 1
|
||||
site_mismatch_count += 1
|
||||
|
||||
style = self.style.SUCCESS
|
||||
if dry_run:
|
||||
msg = f"[DRY RUN] Would update {updated_count} provider configurations. {error_count} errors encountered."
|
||||
saml_configuration_slug = provider_config.saml_configuration.slug
|
||||
provider_config_slug = provider_config.slug
|
||||
|
||||
if saml_configuration_slug not in (provider_config_slug, 'default'):
|
||||
self.stdout.write(
|
||||
f"[WARNING] {provider_info} "
|
||||
f"SAML config (id={provider_config.saml_configuration_id}, slug='{saml_configuration_slug}') "
|
||||
"does not match the provider's slug."
|
||||
)
|
||||
slug_mismatch_count += 1
|
||||
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
self.stderr.write(f"[ERROR] Error processing {provider_info}: {e}")
|
||||
error_count += 1
|
||||
|
||||
metrics = {
|
||||
'total_providers': {'count': total_providers, 'requires_attention': False},
|
||||
'outdated_count': {'count': outdated_count, 'requires_attention': True},
|
||||
'site_mismatch_count': {'count': site_mismatch_count, 'requires_attention': True},
|
||||
'slug_mismatch_count': {'count': slug_mismatch_count, 'requires_attention': True},
|
||||
'null_config_count': {'count': null_config_count, 'requires_attention': False},
|
||||
'error_count': {'count': error_count, 'requires_attention': True},
|
||||
}
|
||||
|
||||
for key, metric_data in metrics.items():
|
||||
# .. custom_attribute_name: saml_management_command.{key}
|
||||
# .. custom_attribute_description: Records metrics from SAML configuration checks.
|
||||
set_custom_attribute(f'saml_management_command.{key}', metric_data['count'])
|
||||
|
||||
return metrics
|
||||
|
||||
def _report_check_summary(self, metrics):
|
||||
"""
|
||||
Print a summary of the check results and set the total_requiring_attention custom attribute.
|
||||
"""
|
||||
total_requiring_attention = sum(
|
||||
metric_data['count'] for metric_data in metrics.values()
|
||||
if metric_data['requires_attention']
|
||||
)
|
||||
|
||||
# .. custom_attribute_name: saml_management_command.total_requiring_attention
|
||||
# .. custom_attribute_description: The total number of configuration issues requiring attention.
|
||||
set_custom_attribute('saml_management_command.total_requiring_attention', total_requiring_attention)
|
||||
|
||||
self.stdout.write(self.style.SUCCESS("CHECK SUMMARY:"))
|
||||
self.stdout.write(f" Providers checked: {metrics['total_providers']['count']}")
|
||||
self.stdout.write(f" Null configs: {metrics['null_config_count']['count']}")
|
||||
|
||||
if total_requiring_attention > 0:
|
||||
self.stdout.write("\nIssues requiring attention:")
|
||||
self.stdout.write(f" Outdated: {metrics['outdated_count']['count']}")
|
||||
self.stdout.write(f" Site mismatches: {metrics['site_mismatch_count']['count']}")
|
||||
self.stdout.write(f" Slug mismatches: {metrics['slug_mismatch_count']['count']}")
|
||||
self.stdout.write(f" Errors: {metrics['error_count']['count']}")
|
||||
self.stdout.write(f"\nTotal issues requiring attention: {total_requiring_attention}")
|
||||
else:
|
||||
msg = f"Updated {updated_count} provider configurations. {error_count} errors encountered."
|
||||
self.stdout.write(style(msg))
|
||||
self.stdout.write(self.style.SUCCESS("\nNo configuration issues found!"))
|
||||
|
||||
@@ -8,7 +8,7 @@ import os
|
||||
from io import StringIO
|
||||
|
||||
from unittest import mock
|
||||
from ddt import ddt, data, unpack
|
||||
from ddt import ddt
|
||||
from django.contrib.sites.models import Site
|
||||
from django.core.management import call_command
|
||||
from django.core.management.base import CommandError
|
||||
@@ -18,8 +18,6 @@ from requests.models import Response
|
||||
from openedx.core.djangolib.testing.utils import CacheIsolationTestCase, skip_unless_lms
|
||||
from common.djangoapps.third_party_auth.tests.factories import SAMLConfigurationFactory, SAMLProviderConfigFactory
|
||||
|
||||
from common.djangoapps.third_party_auth.models import SAMLProviderConfig
|
||||
|
||||
|
||||
def mock_get(status_code=200):
|
||||
"""
|
||||
@@ -64,6 +62,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
|
||||
self.stdout = StringIO()
|
||||
self.site = Site.objects.get_current()
|
||||
self.other_site = Site.objects.create(domain='other.example.com', name='Other Site')
|
||||
|
||||
# We are creating SAMLConfiguration instance here so that there is always at-least one
|
||||
# disabled saml configuration instance, this is done to verify that disabled configurations are
|
||||
@@ -82,9 +81,9 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
metadata_source='https://www.testshib.org/metadata/testshib-providers.xml',
|
||||
)
|
||||
|
||||
def _setup_test_configs_for_fix_references(self):
|
||||
def _setup_test_configs_for_run_checks(self):
|
||||
"""
|
||||
Helper method to create SAML configurations for fix-references tests.
|
||||
Helper method to create SAML configurations for run-checks tests.
|
||||
|
||||
Returns tuple of (old_config, new_config, provider_config)
|
||||
|
||||
@@ -108,7 +107,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
entity_id='https://updated.example.com'
|
||||
)
|
||||
|
||||
# Create a provider config that references the old config for fix-references tests
|
||||
# Create a provider config that references the old config for run-checks tests
|
||||
test_provider_config = SAMLProviderConfigFactory.create(
|
||||
site=self.site,
|
||||
slug='test-provider',
|
||||
@@ -148,14 +147,10 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
|
||||
This test would fail with an error if ValueError is raised.
|
||||
"""
|
||||
# Call `saml` command without any argument so that it raises a CommandError
|
||||
with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--fix-references' option."):
|
||||
# Call `saml` command without any arguments so that it raises a CommandError
|
||||
with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--run-checks' option."):
|
||||
call_command("saml")
|
||||
|
||||
# Call `saml` command without any argument so that it raises a CommandError
|
||||
with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--fix-references' option."):
|
||||
call_command("saml", pull=False)
|
||||
|
||||
def test_no_saml_configuration(self):
|
||||
"""
|
||||
Test that management command completes without errors and logs correct information when no
|
||||
@@ -334,59 +329,144 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
assert expected in self.stdout.getvalue()
|
||||
|
||||
@data(
|
||||
(True, '[DRY RUN]', 'should not update provider configs'),
|
||||
(False, '', 'should create new provider config for new version')
|
||||
)
|
||||
@unpack
|
||||
def test_fix_references(self, dry_run, expected_output_marker, test_description):
|
||||
def _run_checks_command(self):
|
||||
"""
|
||||
Test the --fix-references command with and without --dry-run option.
|
||||
|
||||
Args:
|
||||
dry_run (bool): Whether to run with --dry-run flag
|
||||
expected_output_marker (str): Expected marker in output
|
||||
test_description (str): Description of what the test should do
|
||||
Helper method to run the --run-checks command and return output.
|
||||
"""
|
||||
old_config, new_config, test_provider_config = self._setup_test_configs_for_fix_references()
|
||||
new_config_id = new_config.id
|
||||
original_config_id = old_config.id
|
||||
|
||||
out = StringIO()
|
||||
if dry_run:
|
||||
call_command('saml', '--fix-references', '--dry-run', stdout=out)
|
||||
else:
|
||||
call_command('saml', '--fix-references', stdout=out)
|
||||
call_command('saml', '--run-checks', stdout=out)
|
||||
return out.getvalue()
|
||||
|
||||
output = out.getvalue()
|
||||
@mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute')
|
||||
def test_run_checks_outdated_configs(self, mock_set_custom_attribute):
|
||||
"""
|
||||
Test the --run-checks command identifies outdated configurations.
|
||||
"""
|
||||
old_config, new_config, test_provider_config = self._setup_test_configs_for_run_checks()
|
||||
|
||||
output = self._run_checks_command()
|
||||
|
||||
self.assertIn('[WARNING]', output)
|
||||
self.assertIn('test-provider', output)
|
||||
if expected_output_marker:
|
||||
self.assertIn(expected_output_marker, output)
|
||||
self.assertIn(
|
||||
f'id={old_config.id} which should be updated to the current SAML config (id={new_config.id})',
|
||||
output
|
||||
)
|
||||
self.assertIn('CHECK SUMMARY:', output)
|
||||
self.assertIn('Providers checked: 2', output)
|
||||
self.assertIn('Outdated: 1', output)
|
||||
|
||||
test_provider_config.refresh_from_db()
|
||||
# Check key observability calls
|
||||
expected_calls = [
|
||||
mock.call('saml_management_command.operation', 'run_checks'),
|
||||
mock.call('saml_management_command.total_providers', 2),
|
||||
mock.call('saml_management_command.outdated_count', 1),
|
||||
mock.call('saml_management_command.site_mismatch_count', 0),
|
||||
mock.call('saml_management_command.slug_mismatch_count', 1),
|
||||
mock.call('saml_management_command.null_config_count', 1),
|
||||
mock.call('saml_management_command.error_count', 0),
|
||||
mock.call('saml_management_command.total_requiring_attention', 2),
|
||||
]
|
||||
mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
|
||||
|
||||
if dry_run:
|
||||
# For dry run, ensure the provider config was NOT updated
|
||||
self.assertEqual(
|
||||
test_provider_config.saml_configuration_id,
|
||||
original_config_id,
|
||||
"Provider config should not be updated in dry run mode"
|
||||
)
|
||||
else:
|
||||
# For actual run, check that a new provider config was created
|
||||
new_provider = SAMLProviderConfig.objects.filter(
|
||||
site=self.site,
|
||||
slug='test-provider',
|
||||
saml_configuration_id=new_config_id
|
||||
).exclude(id=test_provider_config.id).first()
|
||||
@mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute')
|
||||
def test_run_checks_site_mismatches(self, mock_set_custom_attribute):
|
||||
"""
|
||||
Test the --run-checks command identifies site ID mismatches.
|
||||
"""
|
||||
config = SAMLConfigurationFactory.create(
|
||||
site=self.other_site,
|
||||
slug='test-config',
|
||||
entity_id='https://example.com'
|
||||
)
|
||||
|
||||
self.assertIsNotNone(new_provider, "New provider config should be created")
|
||||
self.assertEqual(new_provider.saml_configuration_id, new_config_id)
|
||||
SAMLProviderConfigFactory.create(
|
||||
site=self.site,
|
||||
slug='test-provider',
|
||||
saml_configuration=config
|
||||
)
|
||||
|
||||
# Original provider config should still reference the old config
|
||||
self.assertEqual(
|
||||
test_provider_config.saml_configuration_id,
|
||||
original_config_id,
|
||||
"Original provider config should still reference old config"
|
||||
)
|
||||
output = self._run_checks_command()
|
||||
|
||||
self.assertIn('[WARNING]', output)
|
||||
self.assertIn('test-provider', output)
|
||||
self.assertIn('does not match the provider\'s site_id', output)
|
||||
|
||||
# Check observability calls
|
||||
expected_calls = [
|
||||
mock.call('saml_management_command.operation', 'run_checks'),
|
||||
mock.call('saml_management_command.total_providers', 2),
|
||||
mock.call('saml_management_command.outdated_count', 0),
|
||||
mock.call('saml_management_command.site_mismatch_count', 1),
|
||||
mock.call('saml_management_command.slug_mismatch_count', 1),
|
||||
mock.call('saml_management_command.null_config_count', 1),
|
||||
mock.call('saml_management_command.error_count', 0),
|
||||
mock.call('saml_management_command.total_requiring_attention', 2),
|
||||
]
|
||||
mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
|
||||
|
||||
@mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute')
|
||||
def test_run_checks_slug_mismatches(self, mock_set_custom_attribute):
|
||||
"""
|
||||
Test the --run-checks command identifies slug mismatches.
|
||||
"""
|
||||
config = SAMLConfigurationFactory.create(
|
||||
site=self.site,
|
||||
slug='config-slug',
|
||||
entity_id='https://example.com'
|
||||
)
|
||||
|
||||
SAMLProviderConfigFactory.create(
|
||||
site=self.site,
|
||||
slug='provider-slug',
|
||||
saml_configuration=config
|
||||
)
|
||||
|
||||
output = self._run_checks_command()
|
||||
|
||||
self.assertIn('[WARNING]', output)
|
||||
self.assertIn('provider-slug', output)
|
||||
self.assertIn('does not match the provider\'s slug', output)
|
||||
|
||||
# Check observability calls
|
||||
expected_calls = [
|
||||
mock.call('saml_management_command.operation', 'run_checks'),
|
||||
mock.call('saml_management_command.total_providers', 2),
|
||||
mock.call('saml_management_command.outdated_count', 0),
|
||||
mock.call('saml_management_command.site_mismatch_count', 0),
|
||||
mock.call('saml_management_command.slug_mismatch_count', 1),
|
||||
mock.call('saml_management_command.null_config_count', 1),
|
||||
mock.call('saml_management_command.error_count', 0),
|
||||
mock.call('saml_management_command.total_requiring_attention', 1),
|
||||
]
|
||||
mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
|
||||
|
||||
@mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute')
|
||||
def test_run_checks_null_configurations(self, mock_set_custom_attribute):
|
||||
"""
|
||||
Test the --run-checks command identifies providers with null configurations.
|
||||
"""
|
||||
SAMLProviderConfigFactory.create(
|
||||
site=self.site,
|
||||
slug='null-provider',
|
||||
saml_configuration=None
|
||||
)
|
||||
|
||||
output = self._run_checks_command()
|
||||
|
||||
self.assertIn('[INFO]', output)
|
||||
self.assertIn('null-provider', output)
|
||||
self.assertIn('has no SAML configuration because a matching default was not found', output)
|
||||
|
||||
# Check observability calls
|
||||
expected_calls = [
|
||||
mock.call('saml_management_command.operation', 'run_checks'),
|
||||
mock.call('saml_management_command.total_providers', 2),
|
||||
mock.call('saml_management_command.outdated_count', 0),
|
||||
mock.call('saml_management_command.site_mismatch_count', 0),
|
||||
mock.call('saml_management_command.slug_mismatch_count', 0),
|
||||
mock.call('saml_management_command.null_config_count', 2),
|
||||
mock.call('saml_management_command.error_count', 0),
|
||||
mock.call('saml_management_command.total_requiring_attention', 0),
|
||||
]
|
||||
mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
|
||||
|
||||
@@ -153,9 +153,12 @@ class TestSAMLConfigurationSignalHandlers(TestCase):
|
||||
|
||||
current_provider = self._get_current_provider(provider_slug)
|
||||
|
||||
mock_set_custom_attribute.assert_any_call('saml_config_signal.enabled', True)
|
||||
mock_set_custom_attribute.assert_any_call('saml_config_signal.new_config_id', new_saml_config.id)
|
||||
mock_set_custom_attribute.assert_any_call('saml_config_signal.slug', signal_saml_slug)
|
||||
expected_calls = [
|
||||
call('saml_config_signal.enabled', True),
|
||||
call('saml_config_signal.new_config_id', new_saml_config.id),
|
||||
call('saml_config_signal.slug', signal_saml_slug),
|
||||
]
|
||||
mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
|
||||
|
||||
if is_provider_updated:
|
||||
mock_set_custom_attribute.assert_any_call('saml_config_signal.updated_count', 1)
|
||||
|
||||
Reference in New Issue
Block a user