diff --git a/common/djangoapps/third_party_auth/pipeline.py b/common/djangoapps/third_party_auth/pipeline.py index ec8c253556..3a73f70d9f 100644 --- a/common/djangoapps/third_party_auth/pipeline.py +++ b/common/djangoapps/third_party_auth/pipeline.py @@ -480,7 +480,7 @@ def parse_query_params(strategy, response, *args, **kwargs): """Reads whitelisted query params, transforms them into pipeline args.""" # If auth_entry is not in the session, we got here by a non-standard workflow. # We simply assume 'login' in that case. - auth_entry = strategy.request.session.get(AUTH_ENTRY_KEY, AUTH_ENTRY_LOGIN) + auth_entry = strategy.request.session.get(AUTH_ENTRY_KEY) or AUTH_ENTRY_LOGIN if auth_entry not in _AUTH_ENTRY_CHOICES: raise AuthEntryError(strategy.request.backend, 'auth_entry invalid') diff --git a/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py b/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py index a96fa8636b..e65932afb0 100644 --- a/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py +++ b/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py @@ -594,3 +594,64 @@ class SetIDVerificationStatusTestCase(TestCase): # Ensure a verification signal was sent assert mock_signal.call_count == 1 + + +class ParseQueryParamsPipelineTestCase(TestCase): + """Tests to ensure reading queryparams from the auth/login URL works as expected.""" + + def setUp(self): + super().setUp() + self.strategy = mock.MagicMock() + self.response = mock.MagicMock() + + def test_login_url_with_auth_entry(self): + """ + Parsing query params with auth entry results in dictionary with the auth entry. + """ + expected_query_params = { + "auth_entry": "login", + } + self.strategy.request.session = expected_query_params + + query_params = pipeline.parse_query_params(self.strategy, self.response) + + self.assertDictEqual(expected_query_params, query_params) + + def test_login_url_with_auth_entry_none(self): + """ + Parsing query params with auth entry equals to None results in dictionary with default auth entry. + """ + expected_query_params = { + "auth_entry": "login", + } + self.strategy.request.session = { + "auth_entry": None, + } + + query_params = pipeline.parse_query_params(self.strategy, self.response) + + self.assertDictEqual(expected_query_params, query_params) + + def test_login_url_without_auth_entry(self): + """ + Parsing query params without auth entry results in dictionary with default auth entry. + """ + expected_query_params = { + "auth_entry": "login", + } + self.strategy.request.session = {} + + query_params = pipeline.parse_query_params(self.strategy, self.response) + + self.assertDictEqual(expected_query_params, query_params) + + def test_login_url_invalid_auth_entry(self): + """ + Parsing query params with invalid auth entry results in AuthEntryError. + """ + self.strategy.request.session = { + "auth_entry": "not-valid", + } + + with self.assertRaises(pipeline.AuthEntryError): + pipeline.parse_query_params(self.strategy, self.response)