feat: update saml management command (#37330)

The SAML management command has been refactored from
an auto-update tool into a comprehensive report-only audit tool.
The --fix-references option (together with --site-id and --dry-run)
is replaced by a new --run-checks option that provides
detailed reporting on SAML configuration issues without making
any automatic changes.
This commit is contained in:
Krish Tyagi
2025-09-17 18:32:55 +05:30
committed by GitHub
parent 68d68203a2
commit 1eb387b11b
3 changed files with 279 additions and 111 deletions

View File

@@ -6,6 +6,7 @@ Management commands for third_party_auth
import logging
from django.core.management.base import BaseCommand, CommandError
from edx_django_utils.monitoring import set_custom_attribute
from common.djangoapps.third_party_auth.tasks import fetch_saml_metadata
from common.djangoapps.third_party_auth.models import SAMLProviderConfig, SAMLConfiguration
@@ -18,34 +19,24 @@ class Command(BaseCommand):
def add_arguments(self, parser):
parser.add_argument('--pull', action='store_true', help="Pull updated metadata from external IDPs")
parser.add_argument(
'--fix-references',
'--run-checks',
action='store_true',
help="Fix SAMLProviderConfig references to use current SAMLConfiguration versions"
)
parser.add_argument(
'--site-id',
type=int,
help='Only fix configurations for a specific site ID (to be used with --fix-references)'
)
parser.add_argument(
'--dry-run',
action='store_true',
help='Show what would be changed, but do not make any changes.'
help="Run checks on SAMLProviderConfig configurations and report potential issues"
)
def handle(self, *args, **options):
should_pull_saml_metadata = options.get('pull', False)
should_fix_references = options.get('fix_references', False)
dry_run = options.get('dry_run', False)
if not should_pull_saml_metadata and not should_fix_references:
raise CommandError("Command must be used with '--pull' or '--fix-references' option.")
should_run_checks = options.get('run_checks', False)
if should_pull_saml_metadata:
self._handle_pull_metadata()
return
if should_fix_references:
self._handle_fix_references(options, dry_run=dry_run)
if should_run_checks:
self._handle_run_checks()
return
raise CommandError("Command must be used with '--pull' or '--run-checks' option.")
def _handle_pull_metadata(self):
"""
@@ -76,45 +67,139 @@ class Command(BaseCommand):
)
)
def _handle_fix_references(self, options, dry_run=False):
"""Handle the --fix-references option for fixing outdated SAML configuration references."""
site_id = options.get('site_id')
updated_count = 0
error_count = 0
def _handle_run_checks(self):
    """
    Entry point for the --run-checks option: report SAMLProviderConfig configuration issues.

    This is a report-only command. The kinds of potential problems it identifies are:
    - Outdated SAMLConfiguration references (provider pointing to old config version)
    - Site ID mismatches between SAMLProviderConfig and its SAMLConfiguration
    - Slug mismatches (except 'default' slugs)
    - SAMLProviderConfig objects with null SAMLConfiguration references (informational)

    The checks themselves are performed by ``_check_provider_configurations``; the metrics
    it returns are passed straight to ``_report_check_summary``. The current operation is
    recorded as a custom monitoring attribute before the checks start.
    """
    # Record which operation is running so that it shows up in monitoring.
    # .. custom_attribute_name: saml_management_command.operation
    # .. custom_attribute_description: Records current SAML operation ('run_checks').
    set_custom_attribute('saml_management_command.operation', 'run_checks')

    self._report_check_summary(self._check_provider_configurations())
def _check_provider_configurations(self):
"""
Check each provider configuration for potential issues.
Returns a dictionary of metrics about the found issues.
"""
outdated_count = 0
site_mismatch_count = 0
slug_mismatch_count = 0
null_config_count = 0
error_count = 0
total_providers = 0
# Filter by site if specified
provider_configs = SAMLProviderConfig.objects.current_set()
if site_id:
provider_configs = provider_configs.filter(site_id=site_id)
self.stdout.write(self.style.SUCCESS("SAML Configuration Check Report"))
self.stdout.write("=" * 50)
self.stdout.write("")
for provider_config in provider_configs:
if provider_config.saml_configuration:
try:
current_config = SAMLConfiguration.current(
provider_config.site_id,
provider_config.saml_configuration.slug
)
total_providers += 1
provider_info = (
f"Provider (id={provider_config.id}, name={provider_config.name}, "
f"slug={provider_config.slug}, site_id={provider_config.site_id})"
)
if current_config and current_config.id != provider_config.saml_configuration_id:
if not provider_config.saml_configuration:
self.stdout.write(
f"[INFO] {provider_info} has no SAML configuration because "
"a matching default was not found."
)
null_config_count += 1
continue
try:
current_config = SAMLConfiguration.current(
provider_config.saml_configuration.site_id,
provider_config.saml_configuration.slug
)
# Check for outdated configuration references
if current_config:
if current_config.id != provider_config.saml_configuration_id:
self.stdout.write(
f"Provider '{provider_config.slug}' (site {provider_config.site_id}) "
f"has outdated config (ID: {provider_config.saml_configuration_id} -> {current_config.id})"
f"[WARNING] {provider_info} "
f"has outdated SAML config (id={provider_config.saml_configuration_id} which "
f"should be updated to the current SAML config (id={current_config.id})."
)
outdated_count += 1
if not dry_run:
provider_config.saml_configuration = current_config
provider_config.save()
updated_count += 1
except Exception as e: # pylint: disable=broad-except
self.stderr.write(
f"Error processing provider '{provider_config.slug}': {e}"
if provider_config.saml_configuration.site_id != provider_config.site_id:
config_site_id = provider_config.saml_configuration.site_id
provider_site_id = provider_config.site_id
self.stdout.write(
f"[WARNING] {provider_info} "
f"SAML config (id={provider_config.saml_configuration_id}, site_id={config_site_id}) "
"does not match the provider's site_id."
)
error_count += 1
site_mismatch_count += 1
style = self.style.SUCCESS
if dry_run:
msg = f"[DRY RUN] Would update {updated_count} provider configurations. {error_count} errors encountered."
saml_configuration_slug = provider_config.saml_configuration.slug
provider_config_slug = provider_config.slug
if saml_configuration_slug not in (provider_config_slug, 'default'):
self.stdout.write(
f"[WARNING] {provider_info} "
f"SAML config (id={provider_config.saml_configuration_id}, slug='{saml_configuration_slug}') "
"does not match the provider's slug."
)
slug_mismatch_count += 1
except Exception as e: # pylint: disable=broad-except
self.stderr.write(f"[ERROR] Error processing {provider_info}: {e}")
error_count += 1
metrics = {
'total_providers': {'count': total_providers, 'requires_attention': False},
'outdated_count': {'count': outdated_count, 'requires_attention': True},
'site_mismatch_count': {'count': site_mismatch_count, 'requires_attention': True},
'slug_mismatch_count': {'count': slug_mismatch_count, 'requires_attention': True},
'null_config_count': {'count': null_config_count, 'requires_attention': False},
'error_count': {'count': error_count, 'requires_attention': True},
}
for key, metric_data in metrics.items():
# .. custom_attribute_name: saml_management_command.{key}
# .. custom_attribute_description: Records metrics from SAML configuration checks.
set_custom_attribute(f'saml_management_command.{key}', metric_data['count'])
return metrics
def _report_check_summary(self, metrics):
"""
Print a summary of the check results and set the total_requiring_attention custom attribute.
"""
total_requiring_attention = sum(
metric_data['count'] for metric_data in metrics.values()
if metric_data['requires_attention']
)
# .. custom_attribute_name: saml_management_command.total_requiring_attention
# .. custom_attribute_description: The total number of configuration issues requiring attention.
set_custom_attribute('saml_management_command.total_requiring_attention', total_requiring_attention)
self.stdout.write(self.style.SUCCESS("CHECK SUMMARY:"))
self.stdout.write(f" Providers checked: {metrics['total_providers']['count']}")
self.stdout.write(f" Null configs: {metrics['null_config_count']['count']}")
if total_requiring_attention > 0:
self.stdout.write("\nIssues requiring attention:")
self.stdout.write(f" Outdated: {metrics['outdated_count']['count']}")
self.stdout.write(f" Site mismatches: {metrics['site_mismatch_count']['count']}")
self.stdout.write(f" Slug mismatches: {metrics['slug_mismatch_count']['count']}")
self.stdout.write(f" Errors: {metrics['error_count']['count']}")
self.stdout.write(f"\nTotal issues requiring attention: {total_requiring_attention}")
else:
msg = f"Updated {updated_count} provider configurations. {error_count} errors encountered."
self.stdout.write(style(msg))
self.stdout.write(self.style.SUCCESS("\nNo configuration issues found!"))

View File

@@ -8,7 +8,7 @@ import os
from io import StringIO
from unittest import mock
from ddt import ddt, data, unpack
from ddt import ddt
from django.contrib.sites.models import Site
from django.core.management import call_command
from django.core.management.base import CommandError
@@ -18,8 +18,6 @@ from requests.models import Response
from openedx.core.djangolib.testing.utils import CacheIsolationTestCase, skip_unless_lms
from common.djangoapps.third_party_auth.tests.factories import SAMLConfigurationFactory, SAMLProviderConfigFactory
from common.djangoapps.third_party_auth.models import SAMLProviderConfig
def mock_get(status_code=200):
"""
@@ -64,6 +62,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
self.stdout = StringIO()
self.site = Site.objects.get_current()
self.other_site = Site.objects.create(domain='other.example.com', name='Other Site')
# We are creating a SAMLConfiguration instance here so that there is always at least one
# disabled SAML configuration instance; this is done to verify that disabled configurations are
@@ -82,9 +81,9 @@ class TestSAMLCommand(CacheIsolationTestCase):
metadata_source='https://www.testshib.org/metadata/testshib-providers.xml',
)
def _setup_test_configs_for_fix_references(self):
def _setup_test_configs_for_run_checks(self):
"""
Helper method to create SAML configurations for fix-references tests.
Helper method to create SAML configurations for run-checks tests.
Returns tuple of (old_config, new_config, provider_config)
@@ -108,7 +107,7 @@ class TestSAMLCommand(CacheIsolationTestCase):
entity_id='https://updated.example.com'
)
# Create a provider config that references the old config for fix-references tests
# Create a provider config that references the old config for run-checks tests
test_provider_config = SAMLProviderConfigFactory.create(
site=self.site,
slug='test-provider',
@@ -148,14 +147,10 @@ class TestSAMLCommand(CacheIsolationTestCase):
This test would fail with an error if ValueError is raised.
"""
# Call `saml` command without any argument so that it raises a CommandError
with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--fix-references' option."):
# Call `saml` command without any arguments so that it raises a CommandError
with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--run-checks' option."):
call_command("saml")
# Call `saml` command without any argument so that it raises a CommandError
with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--fix-references' option."):
call_command("saml", pull=False)
def test_no_saml_configuration(self):
"""
Test that management command completes without errors and logs correct information when no
@@ -334,59 +329,144 @@ class TestSAMLCommand(CacheIsolationTestCase):
call_command("saml", pull=True, stdout=self.stdout)
assert expected in self.stdout.getvalue()
@data(
(True, '[DRY RUN]', 'should not update provider configs'),
(False, '', 'should create new provider config for new version')
)
@unpack
def test_fix_references(self, dry_run, expected_output_marker, test_description):
def _run_checks_command(self):
"""
Test the --fix-references command with and without --dry-run option.
Args:
dry_run (bool): Whether to run with --dry-run flag
expected_output_marker (str): Expected marker in output
test_description (str): Description of what the test should do
Helper method to run the --run-checks command and return output.
"""
old_config, new_config, test_provider_config = self._setup_test_configs_for_fix_references()
new_config_id = new_config.id
original_config_id = old_config.id
out = StringIO()
if dry_run:
call_command('saml', '--fix-references', '--dry-run', stdout=out)
else:
call_command('saml', '--fix-references', stdout=out)
call_command('saml', '--run-checks', stdout=out)
return out.getvalue()
output = out.getvalue()
@mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute')
def test_run_checks_outdated_configs(self, mock_set_custom_attribute):
"""
Test the --run-checks command identifies outdated configurations.
"""
old_config, new_config, test_provider_config = self._setup_test_configs_for_run_checks()
output = self._run_checks_command()
self.assertIn('[WARNING]', output)
self.assertIn('test-provider', output)
if expected_output_marker:
self.assertIn(expected_output_marker, output)
self.assertIn(
f'id={old_config.id} which should be updated to the current SAML config (id={new_config.id})',
output
)
self.assertIn('CHECK SUMMARY:', output)
self.assertIn('Providers checked: 2', output)
self.assertIn('Outdated: 1', output)
test_provider_config.refresh_from_db()
# Check key observability calls
expected_calls = [
mock.call('saml_management_command.operation', 'run_checks'),
mock.call('saml_management_command.total_providers', 2),
mock.call('saml_management_command.outdated_count', 1),
mock.call('saml_management_command.site_mismatch_count', 0),
mock.call('saml_management_command.slug_mismatch_count', 1),
mock.call('saml_management_command.null_config_count', 1),
mock.call('saml_management_command.error_count', 0),
mock.call('saml_management_command.total_requiring_attention', 2),
]
mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
if dry_run:
# For dry run, ensure the provider config was NOT updated
self.assertEqual(
test_provider_config.saml_configuration_id,
original_config_id,
"Provider config should not be updated in dry run mode"
)
else:
# For actual run, check that a new provider config was created
new_provider = SAMLProviderConfig.objects.filter(
site=self.site,
slug='test-provider',
saml_configuration_id=new_config_id
).exclude(id=test_provider_config.id).first()
@mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute')
def test_run_checks_site_mismatches(self, mock_set_custom_attribute):
"""
Test the --run-checks command identifies site ID mismatches.
"""
config = SAMLConfigurationFactory.create(
site=self.other_site,
slug='test-config',
entity_id='https://example.com'
)
self.assertIsNotNone(new_provider, "New provider config should be created")
self.assertEqual(new_provider.saml_configuration_id, new_config_id)
SAMLProviderConfigFactory.create(
site=self.site,
slug='test-provider',
saml_configuration=config
)
# Original provider config should still reference the old config
self.assertEqual(
test_provider_config.saml_configuration_id,
original_config_id,
"Original provider config should still reference old config"
)
output = self._run_checks_command()
self.assertIn('[WARNING]', output)
self.assertIn('test-provider', output)
self.assertIn('does not match the provider\'s site_id', output)
# Check observability calls
expected_calls = [
mock.call('saml_management_command.operation', 'run_checks'),
mock.call('saml_management_command.total_providers', 2),
mock.call('saml_management_command.outdated_count', 0),
mock.call('saml_management_command.site_mismatch_count', 1),
mock.call('saml_management_command.slug_mismatch_count', 1),
mock.call('saml_management_command.null_config_count', 1),
mock.call('saml_management_command.error_count', 0),
mock.call('saml_management_command.total_requiring_attention', 2),
]
mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
@mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute')
def test_run_checks_slug_mismatches(self, mock_set_custom_attribute):
    """
    --run-checks reports a provider whose slug differs from its SAML configuration's slug.
    """
    mismatched_config = SAMLConfigurationFactory.create(
        site=self.site,
        slug='config-slug',
        entity_id='https://example.com',
    )
    SAMLProviderConfigFactory.create(
        site=self.site,
        slug='provider-slug',
        saml_configuration=mismatched_config,
    )

    report = self._run_checks_command()

    for expected_fragment in ('[WARNING]', 'provider-slug', "does not match the provider's slug"):
        self.assertIn(expected_fragment, report)

    # Observability: the attributes below must be recorded in this order.
    expected_attributes = (
        ('operation', 'run_checks'),
        ('total_providers', 2),
        ('outdated_count', 0),
        ('site_mismatch_count', 0),
        ('slug_mismatch_count', 1),
        ('null_config_count', 1),
        ('error_count', 0),
        ('total_requiring_attention', 1),
    )
    mock_set_custom_attribute.assert_has_calls(
        [mock.call(f'saml_management_command.{name}', value) for name, value in expected_attributes],
        any_order=False,
    )
@mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute')
def test_run_checks_null_configurations(self, mock_set_custom_attribute):
    """
    --run-checks reports, as informational, a provider that has no SAML configuration.
    """
    SAMLProviderConfigFactory.create(
        site=self.site,
        slug='null-provider',
        saml_configuration=None,
    )

    report = self._run_checks_command()

    expected_fragments = (
        '[INFO]',
        'null-provider',
        'has no SAML configuration because a matching default was not found',
    )
    for expected_fragment in expected_fragments:
        self.assertIn(expected_fragment, report)

    # Observability: the attributes below must be recorded in this order.
    expected_attributes = (
        ('operation', 'run_checks'),
        ('total_providers', 2),
        ('outdated_count', 0),
        ('site_mismatch_count', 0),
        ('slug_mismatch_count', 0),
        ('null_config_count', 2),
        ('error_count', 0),
        ('total_requiring_attention', 0),
    )
    mock_set_custom_attribute.assert_has_calls(
        [mock.call(f'saml_management_command.{name}', value) for name, value in expected_attributes],
        any_order=False,
    )

View File

@@ -153,9 +153,12 @@ class TestSAMLConfigurationSignalHandlers(TestCase):
current_provider = self._get_current_provider(provider_slug)
mock_set_custom_attribute.assert_any_call('saml_config_signal.enabled', True)
mock_set_custom_attribute.assert_any_call('saml_config_signal.new_config_id', new_saml_config.id)
mock_set_custom_attribute.assert_any_call('saml_config_signal.slug', signal_saml_slug)
expected_calls = [
call('saml_config_signal.enabled', True),
call('saml_config_signal.new_config_id', new_saml_config.id),
call('saml_config_signal.slug', signal_saml_slug),
]
mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
if is_provider_updated:
mock_set_custom_attribute.assert_any_call('saml_config_signal.updated_count', 1)