fix: SAML provider config references to use current SAML configuration versions (#36954)

Introduces temporary rollout toggle ENABLE_SAML_CONFIG_SIGNAL_HANDLERS
which controls whether SAML configuration signal handlers are active.
When enabled (True), signal handlers will automatically update SAMLProviderConfig
references when the associated SAMLConfiguration is updated.
When disabled (False), SAMLProviderConfigs point to outdated SAMLConfiguration.

Warning: Disabling this toggle may result in SAMLProviderConfig instances
pointing to outdated SAMLConfiguration records.

Use the management command 'saml --fix-references' to fix outdated references.
This commit is contained in:
Krish Tyagi
2025-08-12 19:04:34 +05:30
committed by GitHub
parent 472801b774
commit 14cdbc855d
8 changed files with 383 additions and 7 deletions

View File

@@ -9,6 +9,9 @@ class ThirdPartyAuthConfig(AppConfig): # lint-amnesty, pylint: disable=missing-
verbose_name = "Third-party authentication"
def ready(self):
# Import signal handlers to register them
from .signals import handlers # noqa: F401 pylint: disable=unused-import
# To override the settings before loading social_django.
if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH', False):
self._enable_third_party_auth()

View File

@@ -8,6 +8,7 @@ import logging
from django.core.management.base import BaseCommand, CommandError
from common.djangoapps.third_party_auth.tasks import fetch_saml_metadata
from common.djangoapps.third_party_auth.models import SAMLProviderConfig, SAMLConfiguration
class Command(BaseCommand):
@@ -16,13 +17,41 @@ class Command(BaseCommand):
def add_arguments(self, parser):
parser.add_argument('--pull', action='store_true', help="Pull updated metadata from external IDPs")
parser.add_argument(
'--fix-references',
action='store_true',
help="Fix SAMLProviderConfig references to use current SAMLConfiguration versions"
)
parser.add_argument(
'--site-id',
type=int,
help='Only fix configurations for a specific site ID (to be used with --fix-references)'
)
parser.add_argument(
'--dry-run',
action='store_true',
help='Show what would be changed, but do not make any changes.'
)
def handle(self, *args, **options):
should_pull_saml_metadata = options.get('pull', False)
should_fix_references = options.get('fix_references', False)
dry_run = options.get('dry_run', False)
if not should_pull_saml_metadata:
raise CommandError("Command can only be used with '--pull' option.")
if not should_pull_saml_metadata and not should_fix_references:
raise CommandError("Command must be used with '--pull' or '--fix-references' option.")
if should_pull_saml_metadata:
self._handle_pull_metadata()
if should_fix_references:
self._handle_fix_references(options, dry_run=dry_run)
def _handle_pull_metadata(self):
"""
Handle the --pull option to fetch and update SAML metadata from external providers.
This sets up logging and calls the fetch_saml_metadata task.
"""
log_handler = logging.StreamHandler(self.stdout)
log_handler.setLevel(logging.DEBUG)
log = logging.getLogger('common.djangoapps.third_party_auth.tasks')
@@ -46,3 +75,46 @@ class Command(BaseCommand):
failures="\n\n".join(failure_messages)
)
)
def _handle_fix_references(self, options, dry_run=False):
"""Handle the --fix-references option for fixing outdated SAML configuration references."""
site_id = options.get('site_id')
updated_count = 0
error_count = 0
# Filter by site if specified
provider_configs = SAMLProviderConfig.objects.current_set()
if site_id:
provider_configs = provider_configs.filter(site_id=site_id)
for provider_config in provider_configs:
if provider_config.saml_configuration:
try:
current_config = SAMLConfiguration.current(
provider_config.site_id,
provider_config.saml_configuration.slug
)
if current_config and current_config.id != provider_config.saml_configuration_id:
self.stdout.write(
f"Provider '{provider_config.slug}' (site {provider_config.site_id}) "
f"has outdated config (ID: {provider_config.saml_configuration_id} -> {current_config.id})"
)
if not dry_run:
provider_config.saml_configuration = current_config
provider_config.save()
updated_count += 1
except Exception as e: # pylint: disable=broad-except
self.stderr.write(
f"Error processing provider '{provider_config.slug}': {e}"
)
error_count += 1
style = self.style.SUCCESS
if dry_run:
msg = f"[DRY RUN] Would update {updated_count} provider configurations. {error_count} errors encountered."
else:
msg = f"Updated {updated_count} provider configurations. {error_count} errors encountered."
self.stdout.write(style(msg))

View File

@@ -8,6 +8,8 @@ import os
from io import StringIO
from unittest import mock
from ddt import ddt, data, unpack
from django.contrib.sites.models import Site
from django.core.management import call_command
from django.core.management.base import CommandError
from requests import exceptions
@@ -16,6 +18,8 @@ from requests.models import Response
from openedx.core.djangolib.testing.utils import CacheIsolationTestCase, skip_unless_lms
from common.djangoapps.third_party_auth.tests.factories import SAMLConfigurationFactory, SAMLProviderConfigFactory
from common.djangoapps.third_party_auth.models import SAMLProviderConfig
def mock_get(status_code=200):
"""
@@ -45,6 +49,7 @@ def mock_get(status_code=200):
@skip_unless_lms
@ddt
class TestSAMLCommand(CacheIsolationTestCase):
"""
Test django management command for fetching saml metadata.
@@ -58,12 +63,17 @@ class TestSAMLCommand(CacheIsolationTestCase):
super().setUp()
self.stdout = StringIO()
self.site = Site.objects.get_current()
# We are creating SAMLConfiguration instance here so that there is always at-least one
# disabled saml configuration instance, this is done to verify that disabled configurations are
# not processed.
SAMLConfigurationFactory.create(enabled=False, site__domain='testserver.fake', site__name='testserver.fake')
SAMLProviderConfigFactory.create(
self.saml_config = SAMLConfigurationFactory.create(
enabled=False,
site__domain='testserver.fake',
site__name='testserver.fake'
)
self.provider_config = SAMLProviderConfigFactory.create(
site__domain='testserver.fake',
site__name='testserver.fake',
slug='test-shib',
@@ -72,6 +82,44 @@ class TestSAMLCommand(CacheIsolationTestCase):
metadata_source='https://www.testshib.org/metadata/testshib-providers.xml',
)
def _setup_test_configs_for_fix_references(self):
"""
Helper method to create SAML configurations for fix-references tests.
Returns tuple of (old_config, new_config, provider_config)
Using a separate method keeps test data isolated. Including these configs in
setUp would create 3 provider configs for all tests, breaking tests that expect
specific provider counts or try to access non-existent test XML files.
"""
# Create a SAML config that will be outdated after the new config is created
old_config = SAMLConfigurationFactory.create(
enabled=False,
site=self.site,
slug='test-config',
entity_id='https://old.example.com'
)
# Create newer config with same slug
new_config = SAMLConfigurationFactory.create(
enabled=True,
site=self.site,
slug='test-config',
entity_id='https://updated.example.com'
)
# Create a provider config that references the old config for fix-references tests
test_provider_config = SAMLProviderConfigFactory.create(
site=self.site,
slug='test-provider',
name='Test Provider',
entity_id='https://test.provider/idp/shibboleth',
metadata_source='https://test.provider/metadata.xml',
saml_configuration=old_config
)
return old_config, new_config, test_provider_config
def __create_saml_configurations__(self, saml_config=None, saml_provider_config=None):
"""
Helper method to create SAMLConfiguration and AMLProviderConfig.
@@ -101,11 +149,11 @@ class TestSAMLCommand(CacheIsolationTestCase):
This test would fail with an error if ValueError is raised.
"""
# Call `saml` command without any argument so that it raises a CommandError
with self.assertRaisesMessage(CommandError, "Command can only be used with '--pull' option."):
with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--fix-references' option."):
call_command("saml")
# Call `saml` command without any argument so that it raises a CommandError
with self.assertRaisesMessage(CommandError, "Command can only be used with '--pull' option."):
with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--fix-references' option."):
call_command("saml", pull=False)
def test_no_saml_configuration(self):
@@ -285,3 +333,60 @@ class TestSAMLCommand(CacheIsolationTestCase):
with self.assertRaisesRegex(CommandError, "XMLSyntaxError:"):
call_command("saml", pull=True, stdout=self.stdout)
assert expected in self.stdout.getvalue()
@data(
(True, '[DRY RUN]', 'should not update provider configs'),
(False, '', 'should create new provider config for new version')
)
@unpack
def test_fix_references(self, dry_run, expected_output_marker, test_description):
"""
Test the --fix-references command with and without --dry-run option.
Args:
dry_run (bool): Whether to run with --dry-run flag
expected_output_marker (str): Expected marker in output
test_description (str): Description of what the test should do
"""
old_config, new_config, test_provider_config = self._setup_test_configs_for_fix_references()
new_config_id = new_config.id
original_config_id = old_config.id
out = StringIO()
if dry_run:
call_command('saml', '--fix-references', '--dry-run', stdout=out)
else:
call_command('saml', '--fix-references', stdout=out)
output = out.getvalue()
self.assertIn('test-provider', output)
if expected_output_marker:
self.assertIn(expected_output_marker, output)
test_provider_config.refresh_from_db()
if dry_run:
# For dry run, ensure the provider config was NOT updated
self.assertEqual(
test_provider_config.saml_configuration_id,
original_config_id,
"Provider config should not be updated in dry run mode"
)
else:
# For actual run, check that a new provider config was created
new_provider = SAMLProviderConfig.objects.filter(
site=self.site,
slug='test-provider',
saml_configuration_id=new_config_id
).exclude(id=test_provider_config.id).first()
self.assertIsNotNone(new_provider, "New provider config should be created")
self.assertEqual(new_provider.saml_configuration_id, new_config_id)
# Original provider config should still reference the old config
self.assertEqual(
test_provider_config.saml_configuration_id,
original_config_id,
"Original provider config should still reference old config"
)

View File

@@ -0,0 +1 @@
# Signal handlers for third_party_auth app

View File

@@ -0,0 +1,57 @@
"""
Signal handlers for third_party_auth app.
"""
from django.db.models.signals import post_save
from django.dispatch import receiver
from edx_django_utils.monitoring import set_custom_attribute
from common.djangoapps.third_party_auth.models import SAMLConfiguration, SAMLProviderConfig
from common.djangoapps.third_party_auth.toggles import ENABLE_SAML_CONFIG_SIGNAL_HANDLERS
@receiver(post_save, sender=SAMLConfiguration)
def update_saml_provider_configs_on_configuration_change(sender, instance, created, **kwargs):
"""
Signal handler to create a new SAMLProviderConfig when SAMLConfiguration is updated.
When a SAMLConfiguration is updated and a new version is created, this handler
generates a corresponding SAMLProviderConfig that references the latest
configuration version, ensuring all providers remain aligned with the most
current settings.
"""
# .. custom_attribute_name: saml_config_signal.enabled
# .. custom_attribute_description: Tracks whether the SAML config signal handler is enabled.
set_custom_attribute('saml_config_signal.enabled', ENABLE_SAML_CONFIG_SIGNAL_HANDLERS.is_enabled())
# .. custom_attribute_name: saml_config_signal.new_config_id
# .. custom_attribute_description: Records the ID of the new SAML configuration instance.
set_custom_attribute('saml_config_signal.new_config_id', instance.id)
# .. custom_attribute_name: saml_config_signal.slug
# .. custom_attribute_description: Records the slug of the SAML configuration instance.
set_custom_attribute('saml_config_signal.slug', instance.slug)
if ENABLE_SAML_CONFIG_SIGNAL_HANDLERS.is_enabled():
try:
# Find all existing SAMLProviderConfig instances (current_set) that should be
# pointing to this slug but are pointing to an older version
existing_providers = SAMLProviderConfig.objects.current_set().filter(
site_id=instance.site_id,
saml_configuration__slug=instance.slug
).exclude(saml_configuration_id=instance.id)
updated_count = 0
for provider_config in existing_providers:
provider_config.saml_configuration = instance
provider_config.save()
updated_count += 1
# .. custom_attribute_name: saml_config_signal.updated_count
# .. custom_attribute_description: The number of SAMLProviderConfig records updated to point to the new configuration.
set_custom_attribute('saml_config_signal.updated_count', updated_count)
except Exception as e: # pylint: disable=broad-except
# .. custom_attribute_name: saml_config_signal.error_message
# .. custom_attribute_description: Records any error message that occurs during SAML provider config updates.
set_custom_attribute('saml_config_signal.error_message', str(e))

View File

@@ -0,0 +1 @@
# This file marks the directory as a Python package.

View File

@@ -0,0 +1,117 @@
"""
Tests for SAML configuration signal handlers.
"""
import ddt
from unittest import mock
from unittest.mock import call
from django.test import TestCase, override_settings
from common.djangoapps.third_party_auth.tests.factories import SAMLConfigurationFactory
@ddt.ddt
class TestSAMLConfigurationSignalHandlers(TestCase):
"""
Test effects of SAML configuration signal handlers.
"""
def setUp(self):
self.saml_config = SAMLConfigurationFactory(
slug='test-config',
entity_id='https://test.example.com',
org_info_str='{"en-US": {"url": "http://test.com", "displayname": "Test", "name": "test"}}'
)
@ddt.data(
# Case 1: Tests behavior when SAML config signal handlers are disabled
# Verifies that basic attributes are set but no provider updates are attempted
{
'enabled': False,
'simulate_error': False,
'description': 'handlers disabled',
'expected_calls': [
call('saml_config_signal.enabled', False),
call('saml_config_signal.new_config_id', 'CONFIG_ID'),
call('saml_config_signal.slug', 'test-config'),
],
'expected_call_count': 3,
},
# Case 2: Tests behavior when SAML config signal handlers are enabled
# Verifies that attributes are set and provider updates are attempted successfully
{
'enabled': True,
'simulate_error': False,
'description': 'handlers enabled',
'expected_calls': [
call('saml_config_signal.enabled', True),
call('saml_config_signal.new_config_id', 'CONFIG_ID'),
call('saml_config_signal.slug', 'test-config'),
call('saml_config_signal.updated_count', 0),
],
'expected_call_count': 4,
},
# Case 3: Tests error handling when signal handlers are enabled but encounter an exception
# Verifies that error information is properly captured when provider updates fail
{
'enabled': True,
'simulate_error': True,
'description': 'handlers enabled with exception',
'expected_calls': [
call('saml_config_signal.enabled', True),
call('saml_config_signal.new_config_id', 'CONFIG_ID'),
call('saml_config_signal.slug', 'test-config'),
],
'expected_call_count': 4, # includes error_message call
'error_message': 'Test error',
},
)
@ddt.unpack
@mock.patch('common.djangoapps.third_party_auth.signals.handlers.set_custom_attribute')
def test_saml_config_signal_handlers(
self, mock_set_custom_attribute, enabled, simulate_error,
description, expected_calls, expected_call_count, error_message=None):
"""
Test SAML configuration signal handlers under different conditions.
"""
with override_settings(ENABLE_SAML_CONFIG_SIGNAL_HANDLERS=enabled):
if simulate_error:
# Simulate an exception in the provider config update logic
with mock.patch(
'common.djangoapps.third_party_auth.models.SAMLProviderConfig.objects.current_set',
side_effect=Exception(error_message)
):
self.saml_config.entity_id = 'https://updated.example.com'
self.saml_config.save()
else:
self.saml_config.entity_id = 'https://updated.example.com'
self.saml_config.save()
expected_calls_with_id = []
for call_obj in expected_calls:
args = list(call_obj[1])
if args[1] == 'CONFIG_ID':
args[1] = self.saml_config.id
expected_calls_with_id.append(call(args[0], args[1]))
# Verify expected calls were made
mock_set_custom_attribute.assert_has_calls(expected_calls_with_id, any_order=False)
# Verify total call count
assert mock_set_custom_attribute.call_count == expected_call_count, (
f"Expected {expected_call_count} calls for {description}, "
f"got {mock_set_custom_attribute.call_count}"
)
# If error is expected, verify error message was logged
if error_message:
mock_set_custom_attribute.assert_any_call(
'saml_config_signal.error_message',
mock.ANY
)
error_calls = [
call for call in mock_set_custom_attribute.mock_calls
if call[1][0] == 'saml_config_signal.error_message'
]
assert error_message in error_calls[0][1][1], (
f"Expected '{error_message}' in error message for {description}, "
f"got: {error_calls[0][1][1]}"
)

View File

@@ -2,7 +2,7 @@
Togglable settings for Third Party Auth
"""
from edx_toggles.toggles import WaffleFlag
from edx_toggles.toggles import WaffleFlag, SettingToggle
THIRD_PARTY_AUTH_NAMESPACE = 'thirdpartyauth'
@@ -18,6 +18,26 @@ THIRD_PARTY_AUTH_NAMESPACE = 'thirdpartyauth'
APPLE_USER_MIGRATION_FLAG = WaffleFlag(f'{THIRD_PARTY_AUTH_NAMESPACE}.apple_user_migration', __name__)
# .. toggle_name: ENABLE_SAML_CONFIG_SIGNAL_HANDLERS
# .. toggle_implementation: SettingToggle
# .. toggle_default: False
# .. toggle_description: Controls whether SAML configuration signal handlers are active.
# When enabled (True), signal handlers will automatically update SAMLProviderConfig
# references when the associated SAMLConfiguration is updated.
# When disabled (False), SAMLProviderConfigs point to outdated SAMLConfiguration.
# .. toggle_use_cases: temporary
# .. toggle_creation_date: 2025-07-03
# .. toggle_target_removal_date: 2026-01-01
# .. toggle_warning: Disabling this toggle may result in SAMLProviderConfig instances
# pointing to outdated SAMLConfiguration records. Use the management command
# 'saml --fix-references' to fix outdated references.
ENABLE_SAML_CONFIG_SIGNAL_HANDLERS = SettingToggle(
"ENABLE_SAML_CONFIG_SIGNAL_HANDLERS",
default=False,
module_name=__name__
)
def is_apple_user_migration_enabled():
"""
Returns a boolean if Apple users migration is in process.