diff --git a/common/djangoapps/third_party_auth/apps.py b/common/djangoapps/third_party_auth/apps.py index 9745523592..7f6b82cfde 100644 --- a/common/djangoapps/third_party_auth/apps.py +++ b/common/djangoapps/third_party_auth/apps.py @@ -9,6 +9,9 @@ class ThirdPartyAuthConfig(AppConfig): # lint-amnesty, pylint: disable=missing- verbose_name = "Third-party authentication" def ready(self): + # Import signal handlers to register them + from .signals import handlers # noqa: F401 pylint: disable=unused-import + # To override the settings before loading social_django. if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH', False): self._enable_third_party_auth() diff --git a/common/djangoapps/third_party_auth/management/commands/saml.py b/common/djangoapps/third_party_auth/management/commands/saml.py index bb44dc5d02..e3891cefa0 100644 --- a/common/djangoapps/third_party_auth/management/commands/saml.py +++ b/common/djangoapps/third_party_auth/management/commands/saml.py @@ -8,6 +8,7 @@ import logging from django.core.management.base import BaseCommand, CommandError from common.djangoapps.third_party_auth.tasks import fetch_saml_metadata +from common.djangoapps.third_party_auth.models import SAMLProviderConfig, SAMLConfiguration class Command(BaseCommand): @@ -16,13 +17,41 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument('--pull', action='store_true', help="Pull updated metadata from external IDPs") + parser.add_argument( + '--fix-references', + action='store_true', + help="Fix SAMLProviderConfig references to use current SAMLConfiguration versions" + ) + parser.add_argument( + '--site-id', + type=int, + help='Only fix configurations for a specific site ID (to be used with --fix-references)' + ) + parser.add_argument( + '--dry-run', + action='store_true', + help='Show what would be changed, but do not make any changes.' + ) def handle(self, *args, **options): should_pull_saml_metadata = options.get('pull', False) + should_fix_references = options.get('fix_references', False) + dry_run = options.get('dry_run', False) - if not should_pull_saml_metadata: - raise CommandError("Command can only be used with '--pull' option.") + if not should_pull_saml_metadata and not should_fix_references: + raise CommandError("Command must be used with '--pull' or '--fix-references' option.") + if should_pull_saml_metadata: + self._handle_pull_metadata() + + if should_fix_references: + self._handle_fix_references(options, dry_run=dry_run) + + def _handle_pull_metadata(self): + """ + Handle the --pull option to fetch and update SAML metadata from external providers. + This sets up logging and calls the fetch_saml_metadata task. + """ log_handler = logging.StreamHandler(self.stdout) log_handler.setLevel(logging.DEBUG) log = logging.getLogger('common.djangoapps.third_party_auth.tasks') @@ -46,3 +75,46 @@ class Command(BaseCommand): failures="\n\n".join(failure_messages) ) ) + + def _handle_fix_references(self, options, dry_run=False): + """Handle the --fix-references option for fixing outdated SAML configuration references.""" + site_id = options.get('site_id') + updated_count = 0 + error_count = 0 + + # Filter by site if specified + provider_configs = SAMLProviderConfig.objects.current_set() + if site_id: + provider_configs = provider_configs.filter(site_id=site_id) + + for provider_config in provider_configs: + if provider_config.saml_configuration: + try: + current_config = SAMLConfiguration.current( + provider_config.site_id, + provider_config.saml_configuration.slug + ) + + if current_config and current_config.id != provider_config.saml_configuration_id: + self.stdout.write( + f"Provider '{provider_config.slug}' (site {provider_config.site_id}) " + f"has outdated config (ID: {provider_config.saml_configuration_id} -> {current_config.id})" + ) + + if not dry_run: + provider_config.saml_configuration = current_config + provider_config.save() + updated_count += 1 + + except Exception as e: # pylint: disable=broad-except + self.stderr.write( + f"Error processing provider '{provider_config.slug}': {e}" + ) + error_count += 1 + + style = self.style.SUCCESS + if dry_run: + msg = f"[DRY RUN] Would update {updated_count} provider configurations. {error_count} errors encountered." + else: + msg = f"Updated {updated_count} provider configurations. {error_count} errors encountered." + self.stdout.write(style(msg)) diff --git a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py index a60ab6ca9b..168d88ae3b 100644 --- a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py +++ b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py @@ -8,6 +8,8 @@ import os from io import StringIO from unittest import mock +from ddt import ddt, data, unpack +from django.contrib.sites.models import Site from django.core.management import call_command from django.core.management.base import CommandError from requests import exceptions @@ -16,6 +18,8 @@ from requests.models import Response from openedx.core.djangolib.testing.utils import CacheIsolationTestCase, skip_unless_lms from common.djangoapps.third_party_auth.tests.factories import SAMLConfigurationFactory, SAMLProviderConfigFactory +from common.djangoapps.third_party_auth.models import SAMLProviderConfig + def mock_get(status_code=200): """ @@ -45,6 +49,7 @@ def mock_get(status_code=200): @skip_unless_lms +@ddt class TestSAMLCommand(CacheIsolationTestCase): """ Test django management command for fetching saml metadata. @@ -58,12 +63,17 @@ class TestSAMLCommand(CacheIsolationTestCase): super().setUp() self.stdout = StringIO() + self.site = Site.objects.get_current() # We are creating SAMLConfiguration instance here so that there is always at-least one # disabled saml configuration instance, this is done to verify that disabled configurations are # not processed. - SAMLConfigurationFactory.create(enabled=False, site__domain='testserver.fake', site__name='testserver.fake') - SAMLProviderConfigFactory.create( + self.saml_config = SAMLConfigurationFactory.create( + enabled=False, + site__domain='testserver.fake', + site__name='testserver.fake' + ) + self.provider_config = SAMLProviderConfigFactory.create( site__domain='testserver.fake', site__name='testserver.fake', slug='test-shib', @@ -72,6 +82,44 @@ class TestSAMLCommand(CacheIsolationTestCase): metadata_source='https://www.testshib.org/metadata/testshib-providers.xml', ) + def _setup_test_configs_for_fix_references(self): + """ + Helper method to create SAML configurations for fix-references tests. + + Returns tuple of (old_config, new_config, provider_config) + + Using a separate method keeps test data isolated. Including these configs in + setUp would create 3 provider configs for all tests, breaking tests that expect + specific provider counts or try to access non-existent test XML files. + """ + # Create a SAML config that will be outdated after the new config is created + old_config = SAMLConfigurationFactory.create( + enabled=False, + site=self.site, + slug='test-config', + entity_id='https://old.example.com' + ) + + # Create newer config with same slug + new_config = SAMLConfigurationFactory.create( + enabled=True, + site=self.site, + slug='test-config', + entity_id='https://updated.example.com' + ) + + # Create a provider config that references the old config for fix-references tests + test_provider_config = SAMLProviderConfigFactory.create( + site=self.site, + slug='test-provider', + name='Test Provider', + entity_id='https://test.provider/idp/shibboleth', + metadata_source='https://test.provider/metadata.xml', + saml_configuration=old_config + ) + + return old_config, new_config, test_provider_config + def __create_saml_configurations__(self, saml_config=None, saml_provider_config=None): """ Helper method to create SAMLConfiguration and AMLProviderConfig. @@ -101,11 +149,11 @@ class TestSAMLCommand(CacheIsolationTestCase): This test would fail with an error if ValueError is raised. """ # Call `saml` command without any argument so that it raises a CommandError - with self.assertRaisesMessage(CommandError, "Command can only be used with '--pull' option."): + with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--fix-references' option."): call_command("saml") # Call `saml` command without any argument so that it raises a CommandError - with self.assertRaisesMessage(CommandError, "Command can only be used with '--pull' option."): + with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--fix-references' option."): call_command("saml", pull=False) def test_no_saml_configuration(self): @@ -285,3 +333,60 @@ class TestSAMLCommand(CacheIsolationTestCase): with self.assertRaisesRegex(CommandError, "XMLSyntaxError:"): call_command("saml", pull=True, stdout=self.stdout) assert expected in self.stdout.getvalue() + + @data( + (True, '[DRY RUN]', 'should not update provider configs'), + (False, '', 'should create new provider config for new version') + ) + @unpack + def test_fix_references(self, dry_run, expected_output_marker, test_description): + """ + Test the --fix-references command with and without --dry-run option. + + Args: + dry_run (bool): Whether to run with --dry-run flag + expected_output_marker (str): Expected marker in output + test_description (str): Description of what the test should do + """ + old_config, new_config, test_provider_config = self._setup_test_configs_for_fix_references() + new_config_id = new_config.id + original_config_id = old_config.id + + out = StringIO() + if dry_run: + call_command('saml', '--fix-references', '--dry-run', stdout=out) + else: + call_command('saml', '--fix-references', stdout=out) + + output = out.getvalue() + + self.assertIn('test-provider', output) + if expected_output_marker: + self.assertIn(expected_output_marker, output) + + test_provider_config.refresh_from_db() + + if dry_run: + # For dry run, ensure the provider config was NOT updated + self.assertEqual( + test_provider_config.saml_configuration_id, + original_config_id, + "Provider config should not be updated in dry run mode" + ) + else: + # For actual run, check that a new provider config was created + new_provider = SAMLProviderConfig.objects.filter( + site=self.site, + slug='test-provider', + saml_configuration_id=new_config_id + ).exclude(id=test_provider_config.id).first() + + self.assertIsNotNone(new_provider, "New provider config should be created") + self.assertEqual(new_provider.saml_configuration_id, new_config_id) + + # Original provider config should still reference the old config + self.assertEqual( + test_provider_config.saml_configuration_id, + original_config_id, + "Original provider config should still reference old config" + ) diff --git a/common/djangoapps/third_party_auth/signals/__init__.py b/common/djangoapps/third_party_auth/signals/__init__.py new file mode 100644 index 0000000000..cf255a847f --- /dev/null +++ b/common/djangoapps/third_party_auth/signals/__init__.py @@ -0,0 +1 @@ +# Signal handlers for third_party_auth app diff --git a/common/djangoapps/third_party_auth/signals/handlers.py b/common/djangoapps/third_party_auth/signals/handlers.py new file mode 100644 index 0000000000..dc83a32162 --- /dev/null +++ b/common/djangoapps/third_party_auth/signals/handlers.py @@ -0,0 +1,57 @@ +""" +Signal handlers for third_party_auth app. +""" + +from django.db.models.signals import post_save +from django.dispatch import receiver +from edx_django_utils.monitoring import set_custom_attribute + +from common.djangoapps.third_party_auth.models import SAMLConfiguration, SAMLProviderConfig +from common.djangoapps.third_party_auth.toggles import ENABLE_SAML_CONFIG_SIGNAL_HANDLERS + + +@receiver(post_save, sender=SAMLConfiguration) +def update_saml_provider_configs_on_configuration_change(sender, instance, created, **kwargs): + """ + Signal handler to create a new SAMLProviderConfig when SAMLConfiguration is updated. + + When a SAMLConfiguration is updated and a new version is created, this handler + generates a corresponding SAMLProviderConfig that references the latest + configuration version, ensuring all providers remain aligned with the most + current settings. + """ + # .. custom_attribute_name: saml_config_signal.enabled + # .. custom_attribute_description: Tracks whether the SAML config signal handler is enabled. + set_custom_attribute('saml_config_signal.enabled', ENABLE_SAML_CONFIG_SIGNAL_HANDLERS.is_enabled()) + + # .. custom_attribute_name: saml_config_signal.new_config_id + # .. custom_attribute_description: Records the ID of the new SAML configuration instance. + set_custom_attribute('saml_config_signal.new_config_id', instance.id) + + # .. custom_attribute_name: saml_config_signal.slug + # .. custom_attribute_description: Records the slug of the SAML configuration instance. + set_custom_attribute('saml_config_signal.slug', instance.slug) + + if ENABLE_SAML_CONFIG_SIGNAL_HANDLERS.is_enabled(): + try: + # Find all existing SAMLProviderConfig instances (current_set) that should be + # pointing to this slug but are pointing to an older version + existing_providers = SAMLProviderConfig.objects.current_set().filter( + site_id=instance.site_id, + saml_configuration__slug=instance.slug + ).exclude(saml_configuration_id=instance.id) + + updated_count = 0 + for provider_config in existing_providers: + provider_config.saml_configuration = instance + provider_config.save() + updated_count += 1 + + # .. custom_attribute_name: saml_config_signal.updated_count + # .. custom_attribute_description: The number of SAMLProviderConfig records updated to point to the new configuration. + set_custom_attribute('saml_config_signal.updated_count', updated_count) + + except Exception as e: # pylint: disable=broad-except + # .. custom_attribute_name: saml_config_signal.error_message + # .. custom_attribute_description: Records any error message that occurs during SAML provider config updates. + set_custom_attribute('saml_config_signal.error_message', str(e)) diff --git a/common/djangoapps/third_party_auth/signals/tests/__init__.py b/common/djangoapps/third_party_auth/signals/tests/__init__.py new file mode 100644 index 0000000000..0145fcaed0 --- /dev/null +++ b/common/djangoapps/third_party_auth/signals/tests/__init__.py @@ -0,0 +1 @@ +# This file marks the directory as a Python package. diff --git a/common/djangoapps/third_party_auth/signals/tests/test_handlers.py b/common/djangoapps/third_party_auth/signals/tests/test_handlers.py new file mode 100644 index 0000000000..8c534ce06c --- /dev/null +++ b/common/djangoapps/third_party_auth/signals/tests/test_handlers.py @@ -0,0 +1,117 @@ +""" +Tests for SAML configuration signal handlers. +""" + +import ddt +from unittest import mock +from unittest.mock import call +from django.test import TestCase, override_settings +from common.djangoapps.third_party_auth.tests.factories import SAMLConfigurationFactory + + +@ddt.ddt +class TestSAMLConfigurationSignalHandlers(TestCase): + """ + Test effects of SAML configuration signal handlers. + """ + def setUp(self): + self.saml_config = SAMLConfigurationFactory( + slug='test-config', + entity_id='https://test.example.com', + org_info_str='{"en-US": {"url": "http://test.com", "displayname": "Test", "name": "test"}}' + ) + + @ddt.data( + # Case 1: Tests behavior when SAML config signal handlers are disabled + # Verifies that basic attributes are set but no provider updates are attempted + { + 'enabled': False, + 'simulate_error': False, + 'description': 'handlers disabled', + 'expected_calls': [ + call('saml_config_signal.enabled', False), + call('saml_config_signal.new_config_id', 'CONFIG_ID'), + call('saml_config_signal.slug', 'test-config'), + ], + 'expected_call_count': 3, + }, + # Case 2: Tests behavior when SAML config signal handlers are enabled + # Verifies that attributes are set and provider updates are attempted successfully + { + 'enabled': True, + 'simulate_error': False, + 'description': 'handlers enabled', + 'expected_calls': [ + call('saml_config_signal.enabled', True), + call('saml_config_signal.new_config_id', 'CONFIG_ID'), + call('saml_config_signal.slug', 'test-config'), + call('saml_config_signal.updated_count', 0), + ], + 'expected_call_count': 4, + }, + # Case 3: Tests error handling when signal handlers are enabled but encounter an exception + # Verifies that error information is properly captured when provider updates fail + { + 'enabled': True, + 'simulate_error': True, + 'description': 'handlers enabled with exception', + 'expected_calls': [ + call('saml_config_signal.enabled', True), + call('saml_config_signal.new_config_id', 'CONFIG_ID'), + call('saml_config_signal.slug', 'test-config'), + ], + 'expected_call_count': 4, # includes error_message call + 'error_message': 'Test error', + }, + ) + @ddt.unpack + @mock.patch('common.djangoapps.third_party_auth.signals.handlers.set_custom_attribute') + def test_saml_config_signal_handlers( + self, mock_set_custom_attribute, enabled, simulate_error, + description, expected_calls, expected_call_count, error_message=None): + """ + Test SAML configuration signal handlers under different conditions. + """ + with override_settings(ENABLE_SAML_CONFIG_SIGNAL_HANDLERS=enabled): + if simulate_error: + # Simulate an exception in the provider config update logic + with mock.patch( + 'common.djangoapps.third_party_auth.models.SAMLProviderConfig.objects.current_set', + side_effect=Exception(error_message) + ): + self.saml_config.entity_id = 'https://updated.example.com' + self.saml_config.save() + else: + self.saml_config.entity_id = 'https://updated.example.com' + self.saml_config.save() + + expected_calls_with_id = [] + for call_obj in expected_calls: + args = list(call_obj[1]) + if args[1] == 'CONFIG_ID': + args[1] = self.saml_config.id + expected_calls_with_id.append(call(args[0], args[1])) + + # Verify expected calls were made + mock_set_custom_attribute.assert_has_calls(expected_calls_with_id, any_order=False) + + # Verify total call count + assert mock_set_custom_attribute.call_count == expected_call_count, ( + f"Expected {expected_call_count} calls for {description}, " + f"got {mock_set_custom_attribute.call_count}" + ) + + # If error is expected, verify error message was logged + if error_message: + mock_set_custom_attribute.assert_any_call( + 'saml_config_signal.error_message', + mock.ANY + ) + error_calls = [ + call for call in mock_set_custom_attribute.mock_calls + if call[1][0] == 'saml_config_signal.error_message' + ] + assert error_message in error_calls[0][1][1], ( + f"Expected '{error_message}' in error message for {description}, " + f"got: {error_calls[0][1][1]}" + ) diff --git a/common/djangoapps/third_party_auth/toggles.py b/common/djangoapps/third_party_auth/toggles.py index 53c4edd295..d8f77f0b1c 100644 --- a/common/djangoapps/third_party_auth/toggles.py +++ b/common/djangoapps/third_party_auth/toggles.py @@ -2,7 +2,7 @@ Togglable settings for Third Party Auth """ -from edx_toggles.toggles import WaffleFlag +from edx_toggles.toggles import WaffleFlag, SettingToggle THIRD_PARTY_AUTH_NAMESPACE = 'thirdpartyauth' @@ -18,6 +18,26 @@ THIRD_PARTY_AUTH_NAMESPACE = 'thirdpartyauth' APPLE_USER_MIGRATION_FLAG = WaffleFlag(f'{THIRD_PARTY_AUTH_NAMESPACE}.apple_user_migration', __name__) +# .. toggle_name: ENABLE_SAML_CONFIG_SIGNAL_HANDLERS +# .. toggle_implementation: SettingToggle +# .. toggle_default: False +# .. toggle_description: Controls whether SAML configuration signal handlers are active. +# When enabled (True), signal handlers will automatically update SAMLProviderConfig +# references when the associated SAMLConfiguration is updated. +# When disabled (False), SAMLProviderConfigs point to outdated SAMLConfiguration. +# .. toggle_use_cases: temporary +# .. toggle_creation_date: 2025-07-03 +# .. toggle_target_removal_date: 2026-01-01 +# .. toggle_warning: Disabling this toggle may result in SAMLProviderConfig instances +# pointing to outdated SAMLConfiguration records. Use the management command +# 'saml --fix-references' to fix outdated references. +ENABLE_SAML_CONFIG_SIGNAL_HANDLERS = SettingToggle( + "ENABLE_SAML_CONFIG_SIGNAL_HANDLERS", + default=False, + module_name=__name__ +) + + def is_apple_user_migration_enabled(): """ Returns a boolean if Apple users migration is in process.