fix: SAML provider config references to use current SAML configuration versions (#36954)
Introduces temporary rollout toggle ENABLE_SAML_CONFIG_SIGNAL_HANDLERS which controls whether SAML configuration signal handlers are active. When enabled (True), signal handlers will automatically update SAMLProviderConfig references when the associated SAMLConfiguration is updated. When disabled (False), SAMLProviderConfigs point to outdated SAMLConfiguration. Warning: Disabling this toggle may result in SAMLProviderConfig instances pointing to outdated SAMLConfiguration records. Use the management command 'saml --fix-references' to fix outdated references.
This commit is contained in:
@@ -9,6 +9,9 @@ class ThirdPartyAuthConfig(AppConfig): # lint-amnesty, pylint: disable=missing-
|
||||
verbose_name = "Third-party authentication"
|
||||
|
||||
def ready(self):
|
||||
# Import signal handlers to register them
|
||||
from .signals import handlers # noqa: F401 pylint: disable=unused-import
|
||||
|
||||
# To override the settings before loading social_django.
|
||||
if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH', False):
|
||||
self._enable_third_party_auth()
|
||||
|
||||
@@ -8,6 +8,7 @@ import logging
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
|
||||
from common.djangoapps.third_party_auth.tasks import fetch_saml_metadata
|
||||
from common.djangoapps.third_party_auth.models import SAMLProviderConfig, SAMLConfiguration
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
@@ -16,13 +17,41 @@ class Command(BaseCommand):
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument('--pull', action='store_true', help="Pull updated metadata from external IDPs")
|
||||
parser.add_argument(
|
||||
'--fix-references',
|
||||
action='store_true',
|
||||
help="Fix SAMLProviderConfig references to use current SAMLConfiguration versions"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--site-id',
|
||||
type=int,
|
||||
help='Only fix configurations for a specific site ID (to be used with --fix-references)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dry-run',
|
||||
action='store_true',
|
||||
help='Show what would be changed, but do not make any changes.'
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
should_pull_saml_metadata = options.get('pull', False)
|
||||
should_fix_references = options.get('fix_references', False)
|
||||
dry_run = options.get('dry_run', False)
|
||||
|
||||
if not should_pull_saml_metadata:
|
||||
raise CommandError("Command can only be used with '--pull' option.")
|
||||
if not should_pull_saml_metadata and not should_fix_references:
|
||||
raise CommandError("Command must be used with '--pull' or '--fix-references' option.")
|
||||
|
||||
if should_pull_saml_metadata:
|
||||
self._handle_pull_metadata()
|
||||
|
||||
if should_fix_references:
|
||||
self._handle_fix_references(options, dry_run=dry_run)
|
||||
|
||||
def _handle_pull_metadata(self):
|
||||
"""
|
||||
Handle the --pull option to fetch and update SAML metadata from external providers.
|
||||
This sets up logging and calls the fetch_saml_metadata task.
|
||||
"""
|
||||
log_handler = logging.StreamHandler(self.stdout)
|
||||
log_handler.setLevel(logging.DEBUG)
|
||||
log = logging.getLogger('common.djangoapps.third_party_auth.tasks')
|
||||
@@ -46,3 +75,46 @@ class Command(BaseCommand):
|
||||
failures="\n\n".join(failure_messages)
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_fix_references(self, options, dry_run=False):
|
||||
"""Handle the --fix-references option for fixing outdated SAML configuration references."""
|
||||
site_id = options.get('site_id')
|
||||
updated_count = 0
|
||||
error_count = 0
|
||||
|
||||
# Filter by site if specified
|
||||
provider_configs = SAMLProviderConfig.objects.current_set()
|
||||
if site_id:
|
||||
provider_configs = provider_configs.filter(site_id=site_id)
|
||||
|
||||
for provider_config in provider_configs:
|
||||
if provider_config.saml_configuration:
|
||||
try:
|
||||
current_config = SAMLConfiguration.current(
|
||||
provider_config.site_id,
|
||||
provider_config.saml_configuration.slug
|
||||
)
|
||||
|
||||
if current_config and current_config.id != provider_config.saml_configuration_id:
|
||||
self.stdout.write(
|
||||
f"Provider '{provider_config.slug}' (site {provider_config.site_id}) "
|
||||
f"has outdated config (ID: {provider_config.saml_configuration_id} -> {current_config.id})"
|
||||
)
|
||||
|
||||
if not dry_run:
|
||||
provider_config.saml_configuration = current_config
|
||||
provider_config.save()
|
||||
updated_count += 1
|
||||
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
self.stderr.write(
|
||||
f"Error processing provider '{provider_config.slug}': {e}"
|
||||
)
|
||||
error_count += 1
|
||||
|
||||
style = self.style.SUCCESS
|
||||
if dry_run:
|
||||
msg = f"[DRY RUN] Would update {updated_count} provider configurations. {error_count} errors encountered."
|
||||
else:
|
||||
msg = f"Updated {updated_count} provider configurations. {error_count} errors encountered."
|
||||
self.stdout.write(style(msg))
|
||||
|
||||
@@ -8,6 +8,8 @@ import os
|
||||
from io import StringIO
|
||||
|
||||
from unittest import mock
|
||||
from ddt import ddt, data, unpack
|
||||
from django.contrib.sites.models import Site
|
||||
from django.core.management import call_command
|
||||
from django.core.management.base import CommandError
|
||||
from requests import exceptions
|
||||
@@ -16,6 +18,8 @@ from requests.models import Response
|
||||
from openedx.core.djangolib.testing.utils import CacheIsolationTestCase, skip_unless_lms
|
||||
from common.djangoapps.third_party_auth.tests.factories import SAMLConfigurationFactory, SAMLProviderConfigFactory
|
||||
|
||||
from common.djangoapps.third_party_auth.models import SAMLProviderConfig
|
||||
|
||||
|
||||
def mock_get(status_code=200):
|
||||
"""
|
||||
@@ -45,6 +49,7 @@ def mock_get(status_code=200):
|
||||
|
||||
|
||||
@skip_unless_lms
|
||||
@ddt
|
||||
class TestSAMLCommand(CacheIsolationTestCase):
|
||||
"""
|
||||
Test django management command for fetching saml metadata.
|
||||
@@ -58,12 +63,17 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
super().setUp()
|
||||
|
||||
self.stdout = StringIO()
|
||||
self.site = Site.objects.get_current()
|
||||
|
||||
# We are creating SAMLConfiguration instance here so that there is always at-least one
|
||||
# disabled saml configuration instance, this is done to verify that disabled configurations are
|
||||
# not processed.
|
||||
SAMLConfigurationFactory.create(enabled=False, site__domain='testserver.fake', site__name='testserver.fake')
|
||||
SAMLProviderConfigFactory.create(
|
||||
self.saml_config = SAMLConfigurationFactory.create(
|
||||
enabled=False,
|
||||
site__domain='testserver.fake',
|
||||
site__name='testserver.fake'
|
||||
)
|
||||
self.provider_config = SAMLProviderConfigFactory.create(
|
||||
site__domain='testserver.fake',
|
||||
site__name='testserver.fake',
|
||||
slug='test-shib',
|
||||
@@ -72,6 +82,44 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
metadata_source='https://www.testshib.org/metadata/testshib-providers.xml',
|
||||
)
|
||||
|
||||
def _setup_test_configs_for_fix_references(self):
|
||||
"""
|
||||
Helper method to create SAML configurations for fix-references tests.
|
||||
|
||||
Returns tuple of (old_config, new_config, provider_config)
|
||||
|
||||
Using a separate method keeps test data isolated. Including these configs in
|
||||
setUp would create 3 provider configs for all tests, breaking tests that expect
|
||||
specific provider counts or try to access non-existent test XML files.
|
||||
"""
|
||||
# Create a SAML config that will be outdated after the new config is created
|
||||
old_config = SAMLConfigurationFactory.create(
|
||||
enabled=False,
|
||||
site=self.site,
|
||||
slug='test-config',
|
||||
entity_id='https://old.example.com'
|
||||
)
|
||||
|
||||
# Create newer config with same slug
|
||||
new_config = SAMLConfigurationFactory.create(
|
||||
enabled=True,
|
||||
site=self.site,
|
||||
slug='test-config',
|
||||
entity_id='https://updated.example.com'
|
||||
)
|
||||
|
||||
# Create a provider config that references the old config for fix-references tests
|
||||
test_provider_config = SAMLProviderConfigFactory.create(
|
||||
site=self.site,
|
||||
slug='test-provider',
|
||||
name='Test Provider',
|
||||
entity_id='https://test.provider/idp/shibboleth',
|
||||
metadata_source='https://test.provider/metadata.xml',
|
||||
saml_configuration=old_config
|
||||
)
|
||||
|
||||
return old_config, new_config, test_provider_config
|
||||
|
||||
def __create_saml_configurations__(self, saml_config=None, saml_provider_config=None):
|
||||
"""
|
||||
Helper method to create SAMLConfiguration and AMLProviderConfig.
|
||||
@@ -101,11 +149,11 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
This test would fail with an error if ValueError is raised.
|
||||
"""
|
||||
# Call `saml` command without any argument so that it raises a CommandError
|
||||
with self.assertRaisesMessage(CommandError, "Command can only be used with '--pull' option."):
|
||||
with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--fix-references' option."):
|
||||
call_command("saml")
|
||||
|
||||
# Call `saml` command without any argument so that it raises a CommandError
|
||||
with self.assertRaisesMessage(CommandError, "Command can only be used with '--pull' option."):
|
||||
with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--fix-references' option."):
|
||||
call_command("saml", pull=False)
|
||||
|
||||
def test_no_saml_configuration(self):
|
||||
@@ -285,3 +333,60 @@ class TestSAMLCommand(CacheIsolationTestCase):
|
||||
with self.assertRaisesRegex(CommandError, "XMLSyntaxError:"):
|
||||
call_command("saml", pull=True, stdout=self.stdout)
|
||||
assert expected in self.stdout.getvalue()
|
||||
|
||||
@data(
|
||||
(True, '[DRY RUN]', 'should not update provider configs'),
|
||||
(False, '', 'should create new provider config for new version')
|
||||
)
|
||||
@unpack
|
||||
def test_fix_references(self, dry_run, expected_output_marker, test_description):
|
||||
"""
|
||||
Test the --fix-references command with and without --dry-run option.
|
||||
|
||||
Args:
|
||||
dry_run (bool): Whether to run with --dry-run flag
|
||||
expected_output_marker (str): Expected marker in output
|
||||
test_description (str): Description of what the test should do
|
||||
"""
|
||||
old_config, new_config, test_provider_config = self._setup_test_configs_for_fix_references()
|
||||
new_config_id = new_config.id
|
||||
original_config_id = old_config.id
|
||||
|
||||
out = StringIO()
|
||||
if dry_run:
|
||||
call_command('saml', '--fix-references', '--dry-run', stdout=out)
|
||||
else:
|
||||
call_command('saml', '--fix-references', stdout=out)
|
||||
|
||||
output = out.getvalue()
|
||||
|
||||
self.assertIn('test-provider', output)
|
||||
if expected_output_marker:
|
||||
self.assertIn(expected_output_marker, output)
|
||||
|
||||
test_provider_config.refresh_from_db()
|
||||
|
||||
if dry_run:
|
||||
# For dry run, ensure the provider config was NOT updated
|
||||
self.assertEqual(
|
||||
test_provider_config.saml_configuration_id,
|
||||
original_config_id,
|
||||
"Provider config should not be updated in dry run mode"
|
||||
)
|
||||
else:
|
||||
# For actual run, check that a new provider config was created
|
||||
new_provider = SAMLProviderConfig.objects.filter(
|
||||
site=self.site,
|
||||
slug='test-provider',
|
||||
saml_configuration_id=new_config_id
|
||||
).exclude(id=test_provider_config.id).first()
|
||||
|
||||
self.assertIsNotNone(new_provider, "New provider config should be created")
|
||||
self.assertEqual(new_provider.saml_configuration_id, new_config_id)
|
||||
|
||||
# Original provider config should still reference the old config
|
||||
self.assertEqual(
|
||||
test_provider_config.saml_configuration_id,
|
||||
original_config_id,
|
||||
"Original provider config should still reference old config"
|
||||
)
|
||||
|
||||
1
common/djangoapps/third_party_auth/signals/__init__.py
Normal file
1
common/djangoapps/third_party_auth/signals/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Signal handlers for third_party_auth app
|
||||
57
common/djangoapps/third_party_auth/signals/handlers.py
Normal file
57
common/djangoapps/third_party_auth/signals/handlers.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Signal handlers for third_party_auth app.
|
||||
"""
|
||||
|
||||
from django.db.models.signals import post_save
|
||||
from django.dispatch import receiver
|
||||
from edx_django_utils.monitoring import set_custom_attribute
|
||||
|
||||
from common.djangoapps.third_party_auth.models import SAMLConfiguration, SAMLProviderConfig
|
||||
from common.djangoapps.third_party_auth.toggles import ENABLE_SAML_CONFIG_SIGNAL_HANDLERS
|
||||
|
||||
|
||||
@receiver(post_save, sender=SAMLConfiguration)
|
||||
def update_saml_provider_configs_on_configuration_change(sender, instance, created, **kwargs):
|
||||
"""
|
||||
Signal handler to create a new SAMLProviderConfig when SAMLConfiguration is updated.
|
||||
|
||||
When a SAMLConfiguration is updated and a new version is created, this handler
|
||||
generates a corresponding SAMLProviderConfig that references the latest
|
||||
configuration version, ensuring all providers remain aligned with the most
|
||||
current settings.
|
||||
"""
|
||||
# .. custom_attribute_name: saml_config_signal.enabled
|
||||
# .. custom_attribute_description: Tracks whether the SAML config signal handler is enabled.
|
||||
set_custom_attribute('saml_config_signal.enabled', ENABLE_SAML_CONFIG_SIGNAL_HANDLERS.is_enabled())
|
||||
|
||||
# .. custom_attribute_name: saml_config_signal.new_config_id
|
||||
# .. custom_attribute_description: Records the ID of the new SAML configuration instance.
|
||||
set_custom_attribute('saml_config_signal.new_config_id', instance.id)
|
||||
|
||||
# .. custom_attribute_name: saml_config_signal.slug
|
||||
# .. custom_attribute_description: Records the slug of the SAML configuration instance.
|
||||
set_custom_attribute('saml_config_signal.slug', instance.slug)
|
||||
|
||||
if ENABLE_SAML_CONFIG_SIGNAL_HANDLERS.is_enabled():
|
||||
try:
|
||||
# Find all existing SAMLProviderConfig instances (current_set) that should be
|
||||
# pointing to this slug but are pointing to an older version
|
||||
existing_providers = SAMLProviderConfig.objects.current_set().filter(
|
||||
site_id=instance.site_id,
|
||||
saml_configuration__slug=instance.slug
|
||||
).exclude(saml_configuration_id=instance.id)
|
||||
|
||||
updated_count = 0
|
||||
for provider_config in existing_providers:
|
||||
provider_config.saml_configuration = instance
|
||||
provider_config.save()
|
||||
updated_count += 1
|
||||
|
||||
# .. custom_attribute_name: saml_config_signal.updated_count
|
||||
# .. custom_attribute_description: The number of SAMLProviderConfig records updated to point to the new configuration.
|
||||
set_custom_attribute('saml_config_signal.updated_count', updated_count)
|
||||
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
# .. custom_attribute_name: saml_config_signal.error_message
|
||||
# .. custom_attribute_description: Records any error message that occurs during SAML provider config updates.
|
||||
set_custom_attribute('saml_config_signal.error_message', str(e))
|
||||
@@ -0,0 +1 @@
|
||||
# This file marks the directory as a Python package.
|
||||
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Tests for SAML configuration signal handlers.
|
||||
"""
|
||||
|
||||
import ddt
|
||||
from unittest import mock
|
||||
from unittest.mock import call
|
||||
from django.test import TestCase, override_settings
|
||||
from common.djangoapps.third_party_auth.tests.factories import SAMLConfigurationFactory
|
||||
|
||||
|
||||
@ddt.ddt
|
||||
class TestSAMLConfigurationSignalHandlers(TestCase):
|
||||
"""
|
||||
Test effects of SAML configuration signal handlers.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.saml_config = SAMLConfigurationFactory(
|
||||
slug='test-config',
|
||||
entity_id='https://test.example.com',
|
||||
org_info_str='{"en-US": {"url": "http://test.com", "displayname": "Test", "name": "test"}}'
|
||||
)
|
||||
|
||||
@ddt.data(
|
||||
# Case 1: Tests behavior when SAML config signal handlers are disabled
|
||||
# Verifies that basic attributes are set but no provider updates are attempted
|
||||
{
|
||||
'enabled': False,
|
||||
'simulate_error': False,
|
||||
'description': 'handlers disabled',
|
||||
'expected_calls': [
|
||||
call('saml_config_signal.enabled', False),
|
||||
call('saml_config_signal.new_config_id', 'CONFIG_ID'),
|
||||
call('saml_config_signal.slug', 'test-config'),
|
||||
],
|
||||
'expected_call_count': 3,
|
||||
},
|
||||
# Case 2: Tests behavior when SAML config signal handlers are enabled
|
||||
# Verifies that attributes are set and provider updates are attempted successfully
|
||||
{
|
||||
'enabled': True,
|
||||
'simulate_error': False,
|
||||
'description': 'handlers enabled',
|
||||
'expected_calls': [
|
||||
call('saml_config_signal.enabled', True),
|
||||
call('saml_config_signal.new_config_id', 'CONFIG_ID'),
|
||||
call('saml_config_signal.slug', 'test-config'),
|
||||
call('saml_config_signal.updated_count', 0),
|
||||
],
|
||||
'expected_call_count': 4,
|
||||
},
|
||||
# Case 3: Tests error handling when signal handlers are enabled but encounter an exception
|
||||
# Verifies that error information is properly captured when provider updates fail
|
||||
{
|
||||
'enabled': True,
|
||||
'simulate_error': True,
|
||||
'description': 'handlers enabled with exception',
|
||||
'expected_calls': [
|
||||
call('saml_config_signal.enabled', True),
|
||||
call('saml_config_signal.new_config_id', 'CONFIG_ID'),
|
||||
call('saml_config_signal.slug', 'test-config'),
|
||||
],
|
||||
'expected_call_count': 4, # includes error_message call
|
||||
'error_message': 'Test error',
|
||||
},
|
||||
)
|
||||
@ddt.unpack
|
||||
@mock.patch('common.djangoapps.third_party_auth.signals.handlers.set_custom_attribute')
|
||||
def test_saml_config_signal_handlers(
|
||||
self, mock_set_custom_attribute, enabled, simulate_error,
|
||||
description, expected_calls, expected_call_count, error_message=None):
|
||||
"""
|
||||
Test SAML configuration signal handlers under different conditions.
|
||||
"""
|
||||
with override_settings(ENABLE_SAML_CONFIG_SIGNAL_HANDLERS=enabled):
|
||||
if simulate_error:
|
||||
# Simulate an exception in the provider config update logic
|
||||
with mock.patch(
|
||||
'common.djangoapps.third_party_auth.models.SAMLProviderConfig.objects.current_set',
|
||||
side_effect=Exception(error_message)
|
||||
):
|
||||
self.saml_config.entity_id = 'https://updated.example.com'
|
||||
self.saml_config.save()
|
||||
else:
|
||||
self.saml_config.entity_id = 'https://updated.example.com'
|
||||
self.saml_config.save()
|
||||
|
||||
expected_calls_with_id = []
|
||||
for call_obj in expected_calls:
|
||||
args = list(call_obj[1])
|
||||
if args[1] == 'CONFIG_ID':
|
||||
args[1] = self.saml_config.id
|
||||
expected_calls_with_id.append(call(args[0], args[1]))
|
||||
|
||||
# Verify expected calls were made
|
||||
mock_set_custom_attribute.assert_has_calls(expected_calls_with_id, any_order=False)
|
||||
|
||||
# Verify total call count
|
||||
assert mock_set_custom_attribute.call_count == expected_call_count, (
|
||||
f"Expected {expected_call_count} calls for {description}, "
|
||||
f"got {mock_set_custom_attribute.call_count}"
|
||||
)
|
||||
|
||||
# If error is expected, verify error message was logged
|
||||
if error_message:
|
||||
mock_set_custom_attribute.assert_any_call(
|
||||
'saml_config_signal.error_message',
|
||||
mock.ANY
|
||||
)
|
||||
error_calls = [
|
||||
call for call in mock_set_custom_attribute.mock_calls
|
||||
if call[1][0] == 'saml_config_signal.error_message'
|
||||
]
|
||||
assert error_message in error_calls[0][1][1], (
|
||||
f"Expected '{error_message}' in error message for {description}, "
|
||||
f"got: {error_calls[0][1][1]}"
|
||||
)
|
||||
@@ -2,7 +2,7 @@
|
||||
Togglable settings for Third Party Auth
|
||||
"""
|
||||
|
||||
from edx_toggles.toggles import WaffleFlag
|
||||
from edx_toggles.toggles import WaffleFlag, SettingToggle
|
||||
|
||||
THIRD_PARTY_AUTH_NAMESPACE = 'thirdpartyauth'
|
||||
|
||||
@@ -18,6 +18,26 @@ THIRD_PARTY_AUTH_NAMESPACE = 'thirdpartyauth'
|
||||
APPLE_USER_MIGRATION_FLAG = WaffleFlag(f'{THIRD_PARTY_AUTH_NAMESPACE}.apple_user_migration', __name__)
|
||||
|
||||
|
||||
# .. toggle_name: ENABLE_SAML_CONFIG_SIGNAL_HANDLERS
|
||||
# .. toggle_implementation: SettingToggle
|
||||
# .. toggle_default: False
|
||||
# .. toggle_description: Controls whether SAML configuration signal handlers are active.
|
||||
# When enabled (True), signal handlers will automatically update SAMLProviderConfig
|
||||
# references when the associated SAMLConfiguration is updated.
|
||||
# When disabled (False), SAMLProviderConfigs point to outdated SAMLConfiguration.
|
||||
# .. toggle_use_cases: temporary
|
||||
# .. toggle_creation_date: 2025-07-03
|
||||
# .. toggle_target_removal_date: 2026-01-01
|
||||
# .. toggle_warning: Disabling this toggle may result in SAMLProviderConfig instances
|
||||
# pointing to outdated SAMLConfiguration records. Use the management command
|
||||
# 'saml --fix-references' to fix outdated references.
|
||||
ENABLE_SAML_CONFIG_SIGNAL_HANDLERS = SettingToggle(
|
||||
"ENABLE_SAML_CONFIG_SIGNAL_HANDLERS",
|
||||
default=False,
|
||||
module_name=__name__
|
||||
)
|
||||
|
||||
|
||||
def is_apple_user_migration_enabled():
|
||||
"""
|
||||
Returns a boolean if Apple users migration is in process.
|
||||
|
||||
Reference in New Issue
Block a user