fix: Update SAMLProviderConfig for site-specific configurations (#37294)
Fixes minor bugs in new SAMLProviderConfig signal handlers.
This commit is contained in:
@@ -37,9 +37,9 @@ def update_saml_provider_configs_on_configuration_change(sender, instance, creat
|
||||
# Find all existing SAMLProviderConfig instances (current_set) that should be
|
||||
# pointing to this slug but are pointing to an older version
|
||||
existing_providers = SAMLProviderConfig.objects.current_set().filter(
|
||||
site_id=instance.site_id,
|
||||
saml_configuration__site_id=instance.site_id,
|
||||
saml_configuration__slug=instance.slug
|
||||
).exclude(saml_configuration_id=instance.id)
|
||||
).exclude(saml_configuration_id=instance.id).exclude(saml_configuration_id__isnull=True)
|
||||
|
||||
updated_count = 0
|
||||
for provider_config in existing_providers:
|
||||
|
||||
@@ -6,7 +6,9 @@ import ddt
|
||||
from unittest import mock
|
||||
from unittest.mock import call
|
||||
from django.test import TestCase, override_settings
|
||||
from common.djangoapps.third_party_auth.tests.factories import SAMLConfigurationFactory
|
||||
from django.contrib.sites.models import Site
|
||||
from common.djangoapps.third_party_auth.tests.factories import SAMLConfigurationFactory, SAMLProviderConfigFactory
|
||||
from common.djangoapps.third_party_auth.models import SAMLProviderConfig
|
||||
|
||||
|
||||
@ddt.ddt
|
||||
@@ -21,97 +23,181 @@ class TestSAMLConfigurationSignalHandlers(TestCase):
|
||||
org_info_str='{"en-US": {"url": "http://test.com", "displayname": "Test", "name": "test"}}'
|
||||
)
|
||||
|
||||
@ddt.data(
|
||||
# Case 1: Tests behavior when SAML config signal handlers are disabled
|
||||
# Verifies that basic attributes are set but no provider updates are attempted
|
||||
{
|
||||
'enabled': False,
|
||||
'simulate_error': False,
|
||||
'description': 'handlers disabled',
|
||||
'expected_calls': [
|
||||
call('saml_config_signal.enabled', False),
|
||||
call('saml_config_signal.new_config_id', 'CONFIG_ID'),
|
||||
call('saml_config_signal.slug', 'test-config'),
|
||||
],
|
||||
'expected_call_count': 3,
|
||||
},
|
||||
# Case 2: Tests behavior when SAML config signal handlers are enabled
|
||||
# Verifies that attributes are set and provider updates are attempted successfully
|
||||
{
|
||||
'enabled': True,
|
||||
'simulate_error': False,
|
||||
'description': 'handlers enabled',
|
||||
'expected_calls': [
|
||||
call('saml_config_signal.enabled', True),
|
||||
call('saml_config_signal.new_config_id', 'CONFIG_ID'),
|
||||
call('saml_config_signal.slug', 'test-config'),
|
||||
call('saml_config_signal.updated_count', 0),
|
||||
],
|
||||
'expected_call_count': 4,
|
||||
},
|
||||
# Case 3: Tests error handling when signal handlers are enabled but encounter an exception
|
||||
# Verifies that error information is properly captured when provider updates fail
|
||||
{
|
||||
'enabled': True,
|
||||
'simulate_error': True,
|
||||
'description': 'handlers enabled with exception',
|
||||
'expected_calls': [
|
||||
call('saml_config_signal.enabled', True),
|
||||
call('saml_config_signal.new_config_id', 'CONFIG_ID'),
|
||||
call('saml_config_signal.slug', 'test-config'),
|
||||
],
|
||||
'expected_call_count': 4, # includes error_message call
|
||||
'error_message': 'Test error',
|
||||
},
|
||||
)
|
||||
@ddt.unpack
|
||||
self.site1 = Site.objects.get_or_create(domain='test-site1.com', name='Site 1')[0]
|
||||
self.site2 = Site.objects.get_or_create(domain='test-site2.com', name='Site 2')[0]
|
||||
|
||||
# Existing SAML config used by provider update tests
|
||||
self.existing_saml_config = SAMLConfigurationFactory(
|
||||
site=self.site1,
|
||||
slug='slug',
|
||||
entity_id='https://existing.example.com'
|
||||
)
|
||||
|
||||
@mock.patch('common.djangoapps.third_party_auth.signals.handlers.set_custom_attribute')
|
||||
def test_saml_config_signal_handlers(
|
||||
self, mock_set_custom_attribute, enabled, simulate_error,
|
||||
description, expected_calls, expected_call_count, error_message=None):
|
||||
def test_saml_config_signal_handlers_disabled(self, mock_set_custom_attribute):
|
||||
"""
|
||||
Test SAML configuration signal handlers under different conditions.
|
||||
Test behavior when SAML config signal handlers are disabled.
|
||||
|
||||
Verifies that basic attributes are set but no provider updates are attempted.
|
||||
"""
|
||||
with override_settings(ENABLE_SAML_CONFIG_SIGNAL_HANDLERS=enabled):
|
||||
if simulate_error:
|
||||
# Simulate an exception in the provider config update logic
|
||||
with mock.patch(
|
||||
'common.djangoapps.third_party_auth.models.SAMLProviderConfig.objects.current_set',
|
||||
side_effect=Exception(error_message)
|
||||
):
|
||||
self.saml_config.entity_id = 'https://updated.example.com'
|
||||
self.saml_config.save()
|
||||
else:
|
||||
with override_settings(ENABLE_SAML_CONFIG_SIGNAL_HANDLERS=False):
|
||||
self.saml_config.entity_id = 'https://updated.example.com'
|
||||
self.saml_config.save()
|
||||
|
||||
expected_calls = [
|
||||
call('saml_config_signal.enabled', False),
|
||||
call('saml_config_signal.new_config_id', self.saml_config.id),
|
||||
call('saml_config_signal.slug', 'test-config'),
|
||||
]
|
||||
|
||||
mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
|
||||
assert mock_set_custom_attribute.call_count == 3
|
||||
|
||||
@mock.patch('common.djangoapps.third_party_auth.signals.handlers.set_custom_attribute')
|
||||
def test_saml_config_signal_handlers_with_error(self, mock_set_custom_attribute):
|
||||
"""
|
||||
Test error handling when signal handlers encounter an exception.
|
||||
|
||||
Verifies that error information is properly captured when provider updates fail.
|
||||
"""
|
||||
error_message = "Test error"
|
||||
with override_settings(ENABLE_SAML_CONFIG_SIGNAL_HANDLERS=True):
|
||||
# Simulate an exception in the provider config update logic
|
||||
with mock.patch(
|
||||
'common.djangoapps.third_party_auth.models.SAMLProviderConfig.objects.current_set',
|
||||
side_effect=Exception(error_message)
|
||||
):
|
||||
self.saml_config.entity_id = 'https://updated.example.com'
|
||||
self.saml_config.save()
|
||||
|
||||
expected_calls_with_id = []
|
||||
for call_obj in expected_calls:
|
||||
args = list(call_obj[1])
|
||||
if args[1] == 'CONFIG_ID':
|
||||
args[1] = self.saml_config.id
|
||||
expected_calls_with_id.append(call(args[0], args[1]))
|
||||
expected_calls = [
|
||||
call('saml_config_signal.enabled', True),
|
||||
call('saml_config_signal.new_config_id', self.saml_config.id),
|
||||
call('saml_config_signal.slug', 'test-config'),
|
||||
]
|
||||
|
||||
# Verify expected calls were made
|
||||
mock_set_custom_attribute.assert_has_calls(expected_calls_with_id, any_order=False)
|
||||
mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
|
||||
assert mock_set_custom_attribute.call_count == 4
|
||||
|
||||
# Verify total call count
|
||||
assert mock_set_custom_attribute.call_count == expected_call_count, (
|
||||
f"Expected {expected_call_count} calls for {description}, "
|
||||
f"got {mock_set_custom_attribute.call_count}"
|
||||
# Verify error message was logged
|
||||
mock_set_custom_attribute.assert_any_call(
|
||||
'saml_config_signal.error_message',
|
||||
mock.ANY
|
||||
)
|
||||
error_calls = [
|
||||
call for call in mock_set_custom_attribute.mock_calls
|
||||
if call[1][0] == 'saml_config_signal.error_message'
|
||||
]
|
||||
assert error_message in error_calls[0][1][1], (
|
||||
f"Expected '{error_message}' in error message, "
|
||||
f"got: {error_calls[0][1][1]}"
|
||||
)
|
||||
|
||||
# If error is expected, verify error message was logged
|
||||
if error_message:
|
||||
mock_set_custom_attribute.assert_any_call(
|
||||
'saml_config_signal.error_message',
|
||||
mock.ANY
|
||||
)
|
||||
error_calls = [
|
||||
call for call in mock_set_custom_attribute.mock_calls
|
||||
if call[1][0] == 'saml_config_signal.error_message'
|
||||
]
|
||||
assert error_message in error_calls[0][1][1], (
|
||||
f"Expected '{error_message}' in error message for {description}, "
|
||||
f"got: {error_calls[0][1][1]}"
|
||||
)
|
||||
def _get_current_provider(self, slug):
|
||||
"""
|
||||
Helper to get current version of provider by slug.
|
||||
"""
|
||||
return SAMLProviderConfig.objects.current_set().get(slug=slug)
|
||||
|
||||
def _get_site(self, site_id):
|
||||
"""
|
||||
Helper to get site by ID (1 = site1, 2 = site2).
|
||||
"""
|
||||
if site_id == 1:
|
||||
return self.site1
|
||||
elif site_id == 2:
|
||||
return self.site2
|
||||
else:
|
||||
raise ValueError(f"Unexpected site_id: {site_id}.")
|
||||
|
||||
@ddt.data(
|
||||
# Args: provider_site_id, provider_slug, signal_saml_site_id, signal_saml_slug, is_provider_updated
|
||||
# All tests: provider's saml_configuration has site_id=1, slug='slug'
|
||||
# Signal matches provider's saml config and should update
|
||||
(1, 'slug', 1, 'slug', True), # Same site, same slug
|
||||
(2, 'slug', 1, 'slug', True), # Cross-site provider, matching saml config
|
||||
(1, 'provider-slug', 1, 'slug', True), # Different provider slug, matching saml config
|
||||
# Signal does not match provider's saml config and should not update
|
||||
(1, 'slug', 2, 'slug', False), # Different saml config site
|
||||
(2, 'slug', 2, 'slug', False), # Different saml config site (cross-site)
|
||||
(1, 'provider-slug', 1, 'provider-slug', False), # Different saml config slug
|
||||
(2, 'provider-slug', 1, 'provider-slug', False), # Different saml config slug (cross-site)
|
||||
)
|
||||
@ddt.unpack
|
||||
@mock.patch('common.djangoapps.third_party_auth.signals.handlers.set_custom_attribute')
|
||||
@override_settings(ENABLE_SAML_CONFIG_SIGNAL_HANDLERS=True)
|
||||
def test_saml_provider_config_updates(self, provider_site_id, provider_slug,
|
||||
signal_saml_site_id, signal_saml_slug, is_provider_updated,
|
||||
mock_set_custom_attribute):
|
||||
"""
|
||||
Test SAML provider config updates under different scenarios.
|
||||
|
||||
Tests that providers are updated only when the signal's SAML configuration
|
||||
matches the provider's existing SAML configuration (by site and slug).
|
||||
"""
|
||||
provider_site = self._get_site(provider_site_id)
|
||||
signal_saml_site = self._get_site(signal_saml_site_id)
|
||||
|
||||
provider = SAMLProviderConfigFactory(
|
||||
slug=provider_slug,
|
||||
site=provider_site,
|
||||
saml_configuration=self.existing_saml_config
|
||||
)
|
||||
original_config_id = provider.saml_configuration_id
|
||||
|
||||
new_saml_config = SAMLConfigurationFactory(
|
||||
site=signal_saml_site,
|
||||
slug=signal_saml_slug,
|
||||
entity_id='https://new.example.com'
|
||||
)
|
||||
|
||||
current_provider = self._get_current_provider(provider_slug)
|
||||
|
||||
mock_set_custom_attribute.assert_any_call('saml_config_signal.enabled', True)
|
||||
mock_set_custom_attribute.assert_any_call('saml_config_signal.new_config_id', new_saml_config.id)
|
||||
mock_set_custom_attribute.assert_any_call('saml_config_signal.slug', signal_saml_slug)
|
||||
|
||||
if is_provider_updated:
|
||||
mock_set_custom_attribute.assert_any_call('saml_config_signal.updated_count', 1)
|
||||
self.assertEqual(current_provider.saml_configuration_id, new_saml_config.id,
|
||||
"Provider should be updated when signal SAML config matches")
|
||||
else:
|
||||
mock_set_custom_attribute.assert_any_call('saml_config_signal.updated_count', 0)
|
||||
self.assertEqual(current_provider.saml_configuration_id, original_config_id,
|
||||
"Provider should NOT be updated when signal SAML config doesn't match")
|
||||
|
||||
@ddt.data(
|
||||
# Args: provider_site_id, provider_slug, signal_saml_site_id, signal_saml_slug
|
||||
# All tests: provider's saml config is None and should never be updated
|
||||
(1, 'slug', 1, 'default'),
|
||||
(1, 'default', 1, 'default'),
|
||||
(2, 'slug', 1, 'default'),
|
||||
)
|
||||
@ddt.unpack
|
||||
@override_settings(ENABLE_SAML_CONFIG_SIGNAL_HANDLERS=True)
|
||||
def test_saml_provider_with_null_config_not_updated(self, provider_site_id, provider_slug,
|
||||
signal_saml_site_id, signal_saml_slug):
|
||||
"""
|
||||
Test that providers with NULL SAML configuration are never updated by signal handler.
|
||||
|
||||
This is critical for fallback authentication scenarios where providers
|
||||
intentionally have no SAML configuration.
|
||||
"""
|
||||
provider_site = self._get_site(provider_site_id)
|
||||
signal_saml_site = self._get_site(signal_saml_site_id)
|
||||
|
||||
null_provider = SAMLProviderConfigFactory(
|
||||
slug=provider_slug,
|
||||
site=provider_site,
|
||||
saml_configuration=None
|
||||
)
|
||||
|
||||
new_saml_config = SAMLConfigurationFactory(
|
||||
site=signal_saml_site,
|
||||
slug=signal_saml_slug,
|
||||
entity_id='https://new.example.com'
|
||||
)
|
||||
|
||||
current_provider = self._get_current_provider(provider_slug)
|
||||
self.assertIsNone(current_provider.saml_configuration_id,
|
||||
"Provider with NULL SAML config should never be updated")
|
||||
|
||||
Reference in New Issue
Block a user