diff --git a/common/djangoapps/third_party_auth/saml.py b/common/djangoapps/third_party_auth/saml.py index f31d48f004..1b67fe5efd 100644 --- a/common/djangoapps/third_party_auth/saml.py +++ b/common/djangoapps/third_party_auth/saml.py @@ -10,10 +10,11 @@ import requests from django.contrib.sites.models import Site from django.http import Http404 from django.utils.functional import cached_property +from django.utils.datastructures import MultiValueDictKeyError from django_countries import countries from onelogin.saml2.settings import OneLogin_Saml2_Settings from social_core.backends.saml import OID_EDU_PERSON_ENTITLEMENT, SAMLAuth, SAMLIdentityProvider -from social_core.exceptions import AuthForbidden +from social_core.exceptions import AuthForbidden, AuthMissingParameter from openedx.core.djangoapps.theming.helpers import get_current_request from common.djangoapps.third_party_auth.exceptions import IncorrectConfigurationException @@ -84,6 +85,17 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method else: return super().generate_saml_config() + def auth_complete(self, *args, **kwargs): + """ + Handle exceptions that happen during SAML authentication + """ + try: + return super().auth_complete(*args, **kwargs) + # We are seeing errors of MultiValueDictKeyError looking for the parameter 'RelayState'. + # We would like to have a more specific error to handle for observability purposes. + except MultiValueDictKeyError as e: + raise AuthMissingParameter(self.name, e.args[0]) from e + def get_user_id(self, details, response): """ Calling the parent function and handling the exception properly. diff --git a/common/djangoapps/third_party_auth/tests/test_saml.py b/common/djangoapps/third_party_auth/tests/test_saml.py index 9bb49aa52a..6b966a3e6e 100644 --- a/common/djangoapps/third_party_auth/tests/test_saml.py +++ b/common/djangoapps/third_party_auth/tests/test_saml.py @@ -5,7 +5,10 @@ Unit tests for third_party_auth SAML auth providers from unittest import mock -from common.djangoapps.third_party_auth.saml import EdXSAMLIdentityProvider, get_saml_idp_class +from django.utils.datastructures import MultiValueDictKeyError +from social_core.exceptions import AuthMissingParameter + +from common.djangoapps.third_party_auth.saml import EdXSAMLIdentityProvider, get_saml_idp_class, SAMLAuthBackend from common.djangoapps.third_party_auth.tests.data.saml_identity_provider_mock_data import ( expected_user_details, mock_attributes, @@ -32,3 +35,16 @@ class TestEdXSAMLIdentityProvider(SAMLTestCase): """ test get_attr and get_user_details of EdXSAMLIdentityProvider""" edx_saml_identity_provider = EdXSAMLIdentityProvider('demo', **mock_conf) assert edx_saml_identity_provider.get_user_details(mock_attributes) == expected_user_details + + +class TestSAMLAuthBackend(SAMLTestCase): + """ Tests for the SAML backend. """ + + @mock.patch('common.djangoapps.third_party_auth.saml.SAMLAuth.auth_complete') + def test_saml_auth_complete(self, super_auth_complete): + super_auth_complete.side_effect = MultiValueDictKeyError('RelayState') + backend = SAMLAuthBackend() + with self.assertRaises(AuthMissingParameter) as cm: + backend.auth_complete() + + assert cm.exception.parameter == 'RelayState'