diff --git a/common/djangoapps/third_party_auth/models.py b/common/djangoapps/third_party_auth/models.py index 6d244d96ed..98877ddb75 100644 --- a/common/djangoapps/third_party_auth/models.py +++ b/common/djangoapps/third_party_auth/models.py @@ -827,7 +827,7 @@ class SAMLProviderConfig(ProviderConfig): return other_settings[name] raise KeyError - def get_config(self): + def get_config(self, backend): """ Return a SAMLIdentityProvider instance for use by SAMLAuthBackend. @@ -887,7 +887,7 @@ class SAMLProviderConfig(ProviderConfig): SAMLConfiguration.current(self.site.id, 'default') ) idp_class = get_saml_idp_class(self.identity_provider_type) - return idp_class(self.slug, **conf) + return idp_class(backend, self.slug, **conf) class SAMLProviderData(models.Model): diff --git a/common/djangoapps/third_party_auth/saml.py b/common/djangoapps/third_party_auth/saml.py index 8e78f9e36f..ce5375ec95 100644 --- a/common/djangoapps/third_party_auth/saml.py +++ b/common/djangoapps/third_party_auth/saml.py @@ -2,7 +2,6 @@ Slightly customized python-social-auth backend for SAML 2.0 support """ - import logging from copy import deepcopy @@ -14,7 +13,7 @@ from django.utils.datastructures import MultiValueDictKeyError from django_countries import countries from onelogin.saml2.settings import OneLogin_Saml2_Settings from social_core.backends.saml import OID_EDU_PERSON_ENTITLEMENT, SAMLAuth, SAMLIdentityProvider -from social_core.exceptions import AuthForbidden, AuthMissingParameter +from social_core.exceptions import AuthForbidden, AuthMissingParameter, AuthInvalidParameter from openedx.core.djangoapps.theming.helpers import get_current_request from common.djangoapps.third_party_auth.exceptions import IncorrectConfigurationException @@ -34,7 +33,7 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method def get_idp(self, idp_name): """ Given the name of an IdP, get a SAMLIdentityProvider instance """ from .models import SAMLProviderConfig - return SAMLProviderConfig.current(idp_name).get_config() + return SAMLProviderConfig.current(idp_name).get_config(self) def setting(self, name, default=None): """ Get a setting, from SAMLConfiguration """ @@ -102,7 +101,7 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method """ try: return super().get_user_id(details, response) - except (KeyError, IndexError) as ex: + except (KeyError, IndexError, AuthInvalidParameter) as ex: # Add AuthInvalidParameter here log.warning( '[THIRD_PARTY_AUTH] Error in SAML authentication flow. ' 'Provider: {idp_name}, Message: {message}'.format( @@ -179,7 +178,6 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method auth_inst = super()._create_saml_auth(idp) from .models import SAMLProviderConfig if SAMLProviderConfig.current(idp.name).debug_mode: - def wrap_with_logging(method_name, action_description, xml_getter, request_data, next_url): """ Wrap the request and response handlers to add debug mode logging """ method = getattr(auth_inst, method_name) @@ -192,6 +190,7 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method action_description, idp.name, request_data, next_url, xml_getter() ) return result + setattr(auth_inst, method_name, wrapped_method) request_data = self.strategy.request_data() @@ -226,21 +225,47 @@ class EdXSAMLIdentityProvider(SAMLIdentityProvider): }) return details - def get_attr(self, attributes, conf_key, default_attribute): + def get_attr( + self, + attributes: dict[str, str | list[str] | None], + conf_key: str, + default_attributes: tuple[str, ...], + *, + validate_defaults: bool = False, + ): """ - Internal helper method. - Get the attribute 'default_attribute' out of the attributes, - unless self.conf[conf_key] overrides the default by specifying - another attribute to use. + This override is compatible with the new social-core base class + (which passes a tuple of default_attributes) and preserves the + 'attr_defaults' fallback logic. """ - key = self.conf.get(conf_key, default_attribute) - if key in attributes: + try: + key = self.conf[conf_key] + except KeyError: + for key in default_attributes: + if key in attributes: + break # Found a matching default + else: + key = None + + if key is None: + return self.conf.get('attr_defaults', {}).get(conf_key) or None + try: + value = attributes[key] + except KeyError: + return self.conf.get('attr_defaults', {}).get(conf_key) or None + + if isinstance(value, list): try: - return attributes[key][0] + return value[0] except IndexError: - log.warning('[THIRD_PARTY_AUTH] SAML attribute value not found. ' - 'SamlAttribute: {attribute}'.format(attribute=key)) - return self.conf['attr_defaults'].get(conf_key) or None + log.warning( + '[THIRD_PARTY_AUTH] SAML attribute value not found. ' + 'The attribute %s was present but the list was empty.', + key + ) + else: + return value + return self.conf.get('attr_defaults', {}).get(conf_key) or None @property def saml_sp_configuration(self): diff --git a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py index 18059ac687..47fcfa5b0b 100644 --- a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py +++ b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py @@ -31,6 +31,7 @@ from openedx.features.enterprise_support.tests.factories import EnterpriseCustom from .base import IntegrationTestMixin from common.test.utils import assert_dict_contains_subset +from urllib.parse import urlparse, parse_qs, quote TESTSHIB_ENTITY_ID = "https://idp.testshib.org/idp/shibboleth" TESTSHIB_METADATA_URL = "https://mock.testshib.org/metadata/testshib-providers.xml" @@ -143,10 +144,20 @@ class SamlIntegrationTestUtilities: os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "testshib_saml_response.xml") ) + data = utils.prepare_saml_response_from_xml(saml_response_xml) + + # Extract RelayState from the redirect to IdP + parsed_url = urlparse(provider_redirect_url) + query_params = parse_qs(parsed_url.query) + relay_state = query_params.get('RelayState', [''])[0] + + if relay_state: + data += '&RelayState=' + quote(relay_state) # Append as string to the URL-encoded data + return self.client.post( # lint-amnesty, pylint: disable=no-member self.complete_url, # lint-amnesty, pylint: disable=no-member content_type="application/x-www-form-urlencoded", - data=utils.prepare_saml_response_from_xml(saml_response_xml), + data=data, ) @@ -189,45 +200,6 @@ class TestIndexExceptionTest(SamlIntegrationTestUtilities, IntegrationTestMixin, return response_data -@ddt.ddt -@utils.skip_unless_thirdpartyauth() -class TestKeyExceptionTest(SamlIntegrationTestUtilities, IntegrationTestMixin, testutil.SAMLTestCase): - """ - To test SAML error handling when presented with missing attributes - """ - - TOKEN_RESPONSE_DATA = { - "access_token": "access_token_value", - "expires_in": "expires_in_value", - } - USER_RESPONSE_DATA = { - "lastName": "lastName_value", - "id": "id_value", - "firstName": "firstName_value", - "idp_name": "testshib", - "attributes": {"name_id": "1"}, - "session_index": "1", - } - - def test_key_error_from_missing_saml_attributes(self): - """ - The `urn:oid:0.9.2342.19200300.100.1.1` attribute is missing, - should throw a specific exception NOT a Key Error - """ - self.provider = self._configure_testshib_provider() - request, strategy = self.get_request_and_strategy( - auth_entry=pipeline.AUTH_ENTRY_LOGIN, redirect_uri="social:complete" - ) - with self.assertRaises(IncorrectConfigurationException): - request.backend.auth_complete = MagicMock(return_value=self.fake_auth_complete(strategy)) - - def get_response_data(self): - """Gets dict (string -> object) of merged data about the user.""" - response_data = dict(self.TOKEN_RESPONSE_DATA) - response_data.update(self.USER_RESPONSE_DATA) - return response_data - - @ddt.ddt @utils.skip_unless_thirdpartyauth() class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin, testutil.SAMLTestCase): @@ -415,7 +387,8 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin assert msg.startswith("SAML login %s") assert action_type == "response" assert idp_name == self.PROVIDER_IDP_SLUG - assert_dict_contains_subset(self, {"RelayState": idp_name}, response_data) + expected_relay_state = json.dumps({"idp": idp_name, "next": expected_next_url}) # Remove "auth_entry" + assert_dict_contains_subset(self, {"RelayState": expected_relay_state}, response_data) assert "SAMLResponse" in response_data assert next_url == expected_next_url assert "