diff --git a/common/djangoapps/student/views.py b/common/djangoapps/student/views.py index bd2e9bc58d..19f576b56a 100644 --- a/common/djangoapps/student/views.py +++ b/common/djangoapps/student/views.py @@ -424,7 +424,7 @@ def register_user(request, extra_context=None): # selected provider. if third_party_auth.is_enabled() and pipeline.running(request): running_pipeline = pipeline.get(request) - current_provider = provider.Registry.get_by_backend_name(running_pipeline.get('backend')) + current_provider = provider.Registry.get_from_pipeline(running_pipeline) overrides = current_provider.get_register_form_data(running_pipeline.get('kwargs')) overrides['running_pipeline'] = running_pipeline overrides['selected_provider'] = current_provider.NAME @@ -952,10 +952,11 @@ def login_user(request, error=""): # pylint: disable-msg=too-many-statements,un running_pipeline = pipeline.get(request) username = running_pipeline['kwargs'].get('username') backend_name = running_pipeline['backend'] - requested_provider = provider.Registry.get_by_backend_name(backend_name) + third_party_uid = running_pipeline['kwargs']['uid'] + requested_provider = provider.Registry.get_from_pipeline(running_pipeline) try: - user = pipeline.get_authenticated_user(username, backend_name) + user = pipeline.get_authenticated_user(requested_provider, username, third_party_uid) third_party_auth_successful = True except User.DoesNotExist: AUDIT_LOG.warning( @@ -1509,7 +1510,7 @@ def create_account_with_params(request, params): provider_name = None if third_party_auth.is_enabled() and pipeline.running(request): running_pipeline = pipeline.get(request) - current_provider = provider.Registry.get_by_backend_name(running_pipeline.get('backend')) + current_provider = provider.Registry.get_from_pipeline(running_pipeline) provider_name = current_provider.NAME analytics.track( diff --git a/common/djangoapps/third_party_auth/pipeline.py b/common/djangoapps/third_party_auth/pipeline.py index 7c0ad27c08..e5c489a6fe 100644 --- a/common/djangoapps/third_party_auth/pipeline.py +++ b/common/djangoapps/third_party_auth/pipeline.py @@ -196,9 +196,11 @@ class ProviderUserState(object): lms/templates/dashboard.html. """ - def __init__(self, enabled_provider, user, state): + def __init__(self, enabled_provider, user, association_id=None): + # UserSocialAuth row ID + self.association_id = association_id # Boolean. Whether the user has an account associated with the provider - self.has_account = state + self.has_account = association_id is not None # provider.BaseProvider child. Callers must verify that the provider is # enabled. self.provider = enabled_provider @@ -215,7 +217,7 @@ def get(request): return request.session.get('partial_pipeline') -def get_authenticated_user(username, backend_name): +def get_authenticated_user(auth_provider, username, uid): """Gets a saved user authenticated by a particular backend. Between pipeline steps User objects are not saved. We need to reconstitute @@ -224,26 +226,26 @@ def get_authenticated_user(username, backend_name): authenticate(). Args: + auth_provider: the third_party_auth provider in use for the current pipeline. username: string. Username of user to get. - backend_name: string. The name of the third-party auth backend from - the running pipeline. + uid: string. The user ID according to the third party. Returns: User if user is found and has a social auth from the passed - backend_name. + provider. Raises: User.DoesNotExist: if no user matching user is found, or the matching user has no social auth associated with the given backend. AssertionError: if the user is not authenticated. """ - user = models.DjangoStorage.user.user_model().objects.get(username=username) - match = models.DjangoStorage.user.get_social_auth_for_user(user, provider=backend_name) + match = models.DjangoStorage.user.get_social_auth(provider=auth_provider.BACKEND_CLASS.name, uid=uid) - if not match: + if not match or match.user.username != username: raise User.DoesNotExist - user.backend = provider.Registry.get_by_backend_name(backend_name).get_authentication_backend() + user = match.user + user.backend = auth_provider.get_authentication_backend() return user @@ -257,10 +259,12 @@ def _get_enabled_provider_by_name(provider_name): return enabled_provider -def _get_url(view_name, backend_name, auth_entry=None, redirect_url=None): +def _get_url(view_name, backend_name, auth_entry=None, redirect_url=None, + extra_params=None, url_params=None): """Creates a URL to hook into social auth endpoints.""" - kwargs = {'backend': backend_name} - url = reverse(view_name, kwargs=kwargs) + url_params = url_params or {} + url_params['backend'] = backend_name + url = reverse(view_name, kwargs=url_params) query_params = OrderedDict() if auth_entry: @@ -269,6 +273,9 @@ def _get_url(view_name, backend_name, auth_entry=None, redirect_url=None): if redirect_url: query_params[AUTH_REDIRECT_KEY] = redirect_url + if extra_params: + query_params.update(extra_params) + return u"{url}?{params}".format( url=url, params=urllib.urlencode(query_params) @@ -288,29 +295,32 @@ def get_complete_url(backend_name): Raises: ValueError: if no provider is enabled with the given backend_name. """ - enabled_provider = provider.Registry.get_by_backend_name(backend_name) - - if not enabled_provider: + if not any(provider.Registry.get_enabled_by_backend_name(backend_name)): raise ValueError('Provider with backend %s not enabled' % backend_name) return _get_url('social:complete', backend_name) -def get_disconnect_url(provider_name): +def get_disconnect_url(provider_name, association_id): """Gets URL for the endpoint that starts the disconnect pipeline. Args: provider_name: string. Name of the provider.BaseProvider child you want to disconnect from. + association_id: int. Optional ID of a specific row in the UserSocialAuth + table to disconnect (useful if multiple providers use a common backend) Returns: String. URL that starts the disconnection pipeline. Raises: - ValueError: if no provider is enabled with the given backend_name. + ValueError: if no provider is enabled with the given name. """ - enabled_provider = _get_enabled_provider_by_name(provider_name) - return _get_url('social:disconnect', enabled_provider.BACKEND_CLASS.name) + backend_name = _get_enabled_provider_by_name(provider_name).BACKEND_CLASS.name + if association_id: + return _get_url('social:disconnect_individual', backend_name, url_params={'association_id': association_id}) + else: + return _get_url('social:disconnect', backend_name) def get_login_url(provider_name, auth_entry, redirect_url=None): @@ -340,6 +350,7 @@ def get_login_url(provider_name, auth_entry, redirect_url=None): enabled_provider.BACKEND_CLASS.name, auth_entry=auth_entry, redirect_url=redirect_url, + extra_params=enabled_provider.get_url_params(), ) @@ -355,7 +366,7 @@ def get_duplicate_provider(messages): unfortunately not in a reusable constant. Returns: - provider.BaseProvider child instance. The provider of the duplicate + string name of the python-social-auth backend that has the duplicate account, or None if there is no duplicate (and hence no error). """ social_auth_messages = [m for m in messages if m.message.endswith('is already in use.')] @@ -364,7 +375,8 @@ def get_duplicate_provider(messages): return assert len(social_auth_messages) == 1 - return provider.Registry.get_by_backend_name(social_auth_messages[0].extra_tags.split()[1]) + backend_name = social_auth_messages[0].extra_tags.split()[1] + return backend_name def get_provider_user_states(user): @@ -378,13 +390,16 @@ def get_provider_user_states(user): each enabled provider. """ states = [] - found_user_backends = [ - social_auth.provider for social_auth in models.DjangoStorage.user.get_social_auth_for_user(user) - ] + found_user_auths = list(models.DjangoStorage.user.get_social_auth_for_user(user)) for enabled_provider in provider.Registry.enabled(): + association_id = None + for auth in found_user_auths: + if enabled_provider.match_social_auth(auth): + association_id = auth.id + break states.append( - ProviderUserState(enabled_provider, user, enabled_provider.BACKEND_CLASS.name in found_user_backends) + ProviderUserState(enabled_provider, user, association_id) ) return states diff --git a/common/djangoapps/third_party_auth/provider.py b/common/djangoapps/third_party_auth/provider.py index 9f0809d42a..155748fa28 100644 --- a/common/djangoapps/third_party_auth/provider.py +++ b/common/djangoapps/third_party_auth/provider.py @@ -5,6 +5,8 @@ invoke the Django armature. """ from social.backends import google, linkedin, facebook +from social.backends.saml import OID_EDU_PERSON_PRINCIPAL_NAME +from .saml import SAMLAuthBackend _DEFAULT_ICON_CLASS = 'fa-signin' @@ -109,6 +111,21 @@ class BaseProvider(object): for key, value in cls.SETTINGS.iteritems(): setattr(settings, key, value) + @classmethod + def get_url_params(cls): + """ Get a dict of GET parameters to append to login links for this provider """ + return {} + + @classmethod + def is_active_for_pipeline(cls, pipeline): + """ Is this provider being used for the specified pipeline? """ + return cls.BACKEND_CLASS.name == pipeline['backend'] + + @classmethod + def match_social_auth(cls, social_auth): + """ Is this provider being used for this UserSocialAuth entry? """ + return cls.BACKEND_CLASS.name == social_auth.provider + class GoogleOauth2(BaseProvider): """Provider for Google's Oauth2 auth system.""" @@ -146,6 +163,78 @@ class FacebookOauth2(BaseProvider): } +class SAMLProviderMixin(object): + """ Base class for SAML/Shibboleth providers """ + BACKEND_CLASS = SAMLAuthBackend + ICON_CLASS = 'fa-university' + + @classmethod + def get_url_params(cls): + """ Get a dict of GET parameters to append to login links for this provider """ + return {'idp': cls.IDP["id"]} + + @classmethod + def is_active_for_pipeline(cls, pipeline): + """ Is this provider being used for the specified pipeline? """ + if cls.BACKEND_CLASS.name == pipeline['backend']: + idp_name = pipeline['kwargs']['response']['idp_name'] + return cls.IDP["id"] == idp_name + return False + + @classmethod + def match_social_auth(cls, social_auth): + """ Is this provider being used for this UserSocialAuth entry? """ + prefix = cls.IDP["id"] + ":" + return cls.BACKEND_CLASS.name == social_auth.provider and social_auth.uid.startswith(prefix) + + +class TestShibAProvider(SAMLProviderMixin, BaseProvider): + """ Provider for testshib.org public Shibboleth test server. """ + NAME = 'TestShib A' + IDP = { + "id": "testshiba", # Required slug + "entity_id": "https://idp.testshib.org/idp/shibboleth", + "url": "https://idp.testshib.org/idp/profile/SAML2/Redirect/SSO", + "attr_email": OID_EDU_PERSON_PRINCIPAL_NAME, + "x509cert": """ + MIIEDjCCAvagAwIBAgIBADANBgkqhkiG9w0BAQUFADBnMQswCQYDVQQGEwJVUzEV + MBMGA1UECBMMUGVubnN5bHZhbmlhMRMwEQYDVQQHEwpQaXR0c2J1cmdoMREwDwYD + VQQKEwhUZXN0U2hpYjEZMBcGA1UEAxMQaWRwLnRlc3RzaGliLm9yZzAeFw0wNjA4 + MzAyMTEyMjVaFw0xNjA4MjcyMTEyMjVaMGcxCzAJBgNVBAYTAlVTMRUwEwYDVQQI + EwxQZW5uc3lsdmFuaWExEzARBgNVBAcTClBpdHRzYnVyZ2gxETAPBgNVBAoTCFRl + c3RTaGliMRkwFwYDVQQDExBpZHAudGVzdHNoaWIub3JnMIIBIjANBgkqhkiG9w0B + AQEFAAOCAQ8AMIIBCgKCAQEArYkCGuTmJp9eAOSGHwRJo1SNatB5ZOKqDM9ysg7C + yVTDClcpu93gSP10nH4gkCZOlnESNgttg0r+MqL8tfJC6ybddEFB3YBo8PZajKSe + 3OQ01Ow3yT4I+Wdg1tsTpSge9gEz7SrC07EkYmHuPtd71CHiUaCWDv+xVfUQX0aT + NPFmDixzUjoYzbGDrtAyCqA8f9CN2txIfJnpHE6q6CmKcoLADS4UrNPlhHSzd614 + kR/JYiks0K4kbRqCQF0Dv0P5Di+rEfefC6glV8ysC8dB5/9nb0yh/ojRuJGmgMWH + gWk6h0ihjihqiu4jACovUZ7vVOCgSE5Ipn7OIwqd93zp2wIDAQABo4HEMIHBMB0G + A1UdDgQWBBSsBQ869nh83KqZr5jArr4/7b+QazCBkQYDVR0jBIGJMIGGgBSsBQ86 + 9nh83KqZr5jArr4/7b+Qa6FrpGkwZzELMAkGA1UEBhMCVVMxFTATBgNVBAgTDFBl + bm5zeWx2YW5pYTETMBEGA1UEBxMKUGl0dHNidXJnaDERMA8GA1UEChMIVGVzdFNo + aWIxGTAXBgNVBAMTEGlkcC50ZXN0c2hpYi5vcmeCAQAwDAYDVR0TBAUwAwEB/zAN + BgkqhkiG9w0BAQUFAAOCAQEAjR29PhrCbk8qLN5MFfSVk98t3CT9jHZoYxd8QMRL + I4j7iYQxXiGJTT1FXs1nd4Rha9un+LqTfeMMYqISdDDI6tv8iNpkOAvZZUosVkUo + 93pv1T0RPz35hcHHYq2yee59HJOco2bFlcsH8JBXRSRrJ3Q7Eut+z9uo80JdGNJ4 + /SJy5UorZ8KazGj16lfJhOBXldgrhppQBb0Nq6HKHguqmwRfJ+WkxemZXzhediAj + Geka8nz8JjwxpUjAiSWYKLtJhGEaTqCYxCCX2Dw+dOTqUzHOZ7WKv4JXPK5G/Uhr + 8K/qhmFT2nIQi538n6rVYLeWj8Bbnl+ev0peYzxFyF5sQA== + """ + } + + +class TestShibBProvider(SAMLProviderMixin, BaseProvider): + """ Provider for testshib.org public Shibboleth test server. """ + NAME = 'TestShib B' + IDP = { + "id": "testshibB", # Required slug + "entity_id": "https://idp.testshib.org/idp/shibboleth", + "url": "https://IDP.TESTSHIB.ORG/idp/profile/SAML2/Redirect/SSO", + "attr_email": OID_EDU_PERSON_PRINCIPAL_NAME, + "x509cert": TestShibAProvider.IDP["x509cert"], + } + + class Registry(object): """Singleton registry of third-party auth providers. @@ -211,13 +300,39 @@ class Registry(object): return cls._ENABLED.get(provider_name) @classmethod - def get_by_backend_name(cls, backend_name): - """Gets provider (or None) by backend name. + def get_from_pipeline(cls, running_pipeline): + """Gets the provider that is being used for the specified pipeline (or None). Args: - backend_name: string. The python-social-auth - backends.base.BaseAuth.name (for example, 'google-oauth2') to - try and get a provider for. + running_pipeline: The python-social-auth pipeline being used to + authenticate a user. + + Returns: + A provider class (a subclass of BaseProvider) or None. + + Raises: + RuntimeError: if the registry has not been configured. + """ + cls._check_configured() + for enabled in cls._ENABLED.values(): + if enabled.is_active_for_pipeline(running_pipeline): + return enabled + + @classmethod + def get_enabled_by_backend_name(cls, backend_name): + """Generator returning all enabled providers that use the specified + backend. + + Example: + >>> list(get_enabled_by_backend_name("tpa-saml")) + [TestShibAProvider, TestShibBProvider] + + Args: + backend_name: The name of a python-social-auth backend used by + one or more providers. + + Yields: + Provider classes (subclasses of BaseProvider). Raises: RuntimeError: if the registry has not been configured. @@ -225,7 +340,7 @@ class Registry(object): cls._check_configured() for enabled in cls._ENABLED.values(): if enabled.BACKEND_CLASS.name == backend_name: - return enabled + yield enabled @classmethod def _reset(cls): diff --git a/common/djangoapps/third_party_auth/saml.py b/common/djangoapps/third_party_auth/saml.py new file mode 100644 index 0000000000..78106f7080 --- /dev/null +++ b/common/djangoapps/third_party_auth/saml.py @@ -0,0 +1,21 @@ +""" +Slightly customized python-social-auth backend for SAML 2.0 support +""" + +from social.backends.saml import SAMLIdentityProvider, SAMLAuth + + +class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method + """ + Customized version of SAMLAuth that gets the list of IdPs from third_party_auth's list of + enabled providers. + """ + name = "tpa-saml" + + def get_idp(self, idp_name): + """ Given the name of an IdP, get a SAMLIdentityProvider instance """ + from .provider import Registry # Import here to avoid circular import + for provider in Registry.enabled(): + if issubclass(provider.BACKEND_CLASS, SAMLAuth) and provider.IDP["id"] == idp_name: + return SAMLIdentityProvider(idp_name, **provider.IDP) + raise KeyError("SAML IdP {} not found.".format(idp_name)) diff --git a/common/djangoapps/third_party_auth/tests/specs/base.py b/common/djangoapps/third_party_auth/tests/specs/base.py index 25e060c099..ea90c8d659 100644 --- a/common/djangoapps/third_party_auth/tests/specs/base.py +++ b/common/djangoapps/third_party_auth/tests/specs/base.py @@ -115,12 +115,12 @@ class IntegrationTest(testutil.TestCase, test.TestCase): """Asserts the user's account settings page context is in the expected state. If duplicate is True, we expect context['duplicate_provider'] to contain - the duplicate provider object. If linked is passed, we conditionally + the duplicate provider backend name. If linked is passed, we conditionally check that the provider is included in context['auth']['providers'] and its connected state is correct. """ if duplicate: - self.assertEqual(context['duplicate_provider'].NAME, self.PROVIDER_CLASS.NAME) + self.assertEqual(context['duplicate_provider'], self.PROVIDER_CLASS.BACKEND_CLASS.name) else: self.assertIsNone(context['duplicate_provider']) diff --git a/common/djangoapps/third_party_auth/tests/test_pipeline.py b/common/djangoapps/third_party_auth/tests/test_pipeline.py index 66c11d9043..462f24e4b2 100644 --- a/common/djangoapps/third_party_auth/tests/test_pipeline.py +++ b/common/djangoapps/third_party_auth/tests/test_pipeline.py @@ -38,5 +38,5 @@ class ProviderUserStateTestCase(testutil.TestCase): """Tests ProviderUserState behavior.""" def test_get_unlink_form_name(self): - state = pipeline.ProviderUserState(provider.GoogleOauth2, object(), False) + state = pipeline.ProviderUserState(provider.GoogleOauth2, object(), 1000) self.assertEqual(provider.GoogleOauth2.NAME + '_unlink_form', state.get_unlink_form_name()) diff --git a/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py b/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py index 8d1f3b7019..e6181fac61 100644 --- a/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py +++ b/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py @@ -41,16 +41,16 @@ class GetAuthenticatedUserTestCase(TestCase): def test_raises_does_not_exist_if_user_missing(self): with self.assertRaises(models.User.DoesNotExist): - pipeline.get_authenticated_user('new_' + self.user.username, 'backend') + pipeline.get_authenticated_user(self.enabled_provider, 'new_' + self.user.username, 'user@example.com') def test_raises_does_not_exist_if_user_found_but_no_association(self): backend_name = 'backend' self.assertIsNotNone(self.get_by_username(self.user.username)) - self.assertIsNone(provider.Registry.get_by_backend_name(backend_name)) + self.assertFalse(any(provider.Registry.get_enabled_by_backend_name(backend_name))) with self.assertRaises(models.User.DoesNotExist): - pipeline.get_authenticated_user(self.user.username, 'backend') + pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'user@example.com') def test_raises_does_not_exist_if_user_and_association_found_but_no_match(self): self.assertIsNotNone(self.get_by_username(self.user.username)) @@ -58,11 +58,11 @@ class GetAuthenticatedUserTestCase(TestCase): self.user, 'uid', 'other_' + self.enabled_provider.BACKEND_CLASS.name) with self.assertRaises(models.User.DoesNotExist): - pipeline.get_authenticated_user(self.user.username, self.enabled_provider.BACKEND_CLASS.name) + pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'uid') def test_returns_user_with_is_authenticated_and_backend_set_if_match(self): social_models.DjangoStorage.user.create_social_auth(self.user, 'uid', self.enabled_provider.BACKEND_CLASS.name) - user = pipeline.get_authenticated_user(self.user.username, self.enabled_provider.BACKEND_CLASS.name) + user = pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'uid') self.assertEqual(self.user, user) self.assertEqual(self.enabled_provider.get_authentication_backend(), user.backend) @@ -93,8 +93,9 @@ class GetProviderUserStatesTestCase(testutil.TestCase, test.TestCase): def test_states_for_enabled_providers_user_has_accounts_associated_with(self): provider.Registry.configure_once([provider.GoogleOauth2.NAME, provider.LinkedInOauth2.NAME]) - social_models.DjangoStorage.user.create_social_auth(self.user, 'uid', provider.GoogleOauth2.BACKEND_CLASS.name) - social_models.DjangoStorage.user.create_social_auth( + user_social_auth_google = social_models.DjangoStorage.user.create_social_auth( + self.user, 'uid', provider.GoogleOauth2.BACKEND_CLASS.name) + user_social_auth_linkedin = social_models.DjangoStorage.user.create_social_auth( self.user, 'uid', provider.LinkedInOauth2.BACKEND_CLASS.name) states = pipeline.get_provider_user_states(self.user) @@ -106,10 +107,12 @@ class GetProviderUserStatesTestCase(testutil.TestCase, test.TestCase): self.assertTrue(google_state.has_account) self.assertEqual(provider.GoogleOauth2, google_state.provider) self.assertEqual(self.user, google_state.user) + self.assertEqual(user_social_auth_google.id, google_state.association_id) self.assertTrue(linkedin_state.has_account) self.assertEqual(provider.LinkedInOauth2, linkedin_state.provider) self.assertEqual(self.user, linkedin_state.user) + self.assertEqual(user_social_auth_linkedin.id, linkedin_state.association_id) def test_states_for_enabled_providers_user_has_no_account_associated_with(self): provider.Registry.configure_once([provider.GoogleOauth2.NAME, provider.LinkedInOauth2.NAME]) @@ -155,13 +158,16 @@ class UrlFormationTestCase(TestCase): self.assertIsNone(provider.Registry.get(provider_name)) with self.assertRaises(ValueError): - pipeline.get_disconnect_url(provider_name) + pipeline.get_disconnect_url(provider_name, 1000) def test_disconnect_url_returns_expected_format(self): - disconnect_url = pipeline.get_disconnect_url(self.enabled_provider.NAME) - - self.assertTrue(disconnect_url.startswith('/auth/disconnect')) - self.assertIn(self.enabled_provider.BACKEND_CLASS.name, disconnect_url) + disconnect_url = pipeline.get_disconnect_url(self.enabled_provider.NAME, 1000) + disconnect_url = disconnect_url.rstrip('?') + self.assertEqual( + disconnect_url, + '/auth/disconnect/{backend}/{association_id}/'.format( + backend=self.enabled_provider.BACKEND_CLASS.name, association_id=1000) + ) def test_login_url_raises_value_error_if_provider_not_enabled(self): provider_name = 'not_enabled' diff --git a/common/djangoapps/third_party_auth/tests/test_provider.py b/common/djangoapps/third_party_auth/tests/test_provider.py index 20120d7329..a1de2943bd 100644 --- a/common/djangoapps/third_party_auth/tests/test_provider.py +++ b/common/djangoapps/third_party_auth/tests/test_provider.py @@ -1,5 +1,6 @@ """Unit tests for provider.py.""" +from mock import Mock from third_party_auth import provider from third_party_auth.tests import testutil @@ -67,16 +68,22 @@ class RegistryTest(testutil.TestCase): provider.Registry.configure_once([]) self.assertIsNone(provider.Registry.get(provider.LinkedInOauth2.NAME)) - def test_get_by_backend_name_raises_runtime_error_if_not_configured(self): - with self.assertRaisesRegexp(RuntimeError, '^.*not configured$'): - provider.Registry.get_by_backend_name('') - - def test_get_by_backend_name_returns_enabled_provider(self): - provider.Registry.configure_once([provider.GoogleOauth2.NAME]) - self.assertIs( - provider.GoogleOauth2, - provider.Registry.get_by_backend_name(provider.GoogleOauth2.BACKEND_CLASS.name)) - - def test_get_by_backend_name_returns_none_if_provider_not_enabled(self): + def test_get_from_pipeline_returns_none_if_provider_not_enabled(self): provider.Registry.configure_once([]) - self.assertIsNone(provider.Registry.get_by_backend_name(provider.GoogleOauth2.BACKEND_CLASS.name)) + self.assertIsNone(provider.Registry.get_from_pipeline(Mock())) + + def test_get_enabled_by_backend_name_raises_runtime_error_if_not_configured(self): + with self.assertRaisesRegexp(RuntimeError, '^.*not configured$'): + provider.Registry.get_enabled_by_backend_name('').next() + + def test_get_enabled_by_backend_name_returns_enabled_provider(self): + provider.Registry.configure_once([provider.GoogleOauth2.NAME]) + found = list(provider.Registry.get_enabled_by_backend_name(provider.GoogleOauth2.BACKEND_CLASS.name)) + self.assertEqual(found, [provider.GoogleOauth2]) + + def test_get_enabled_by_backend_name_returns_none_if_provider_not_enabled(self): + provider.Registry.configure_once([]) + self.assertEqual( + [], + list(provider.Registry.get_enabled_by_backend_name(provider.GoogleOauth2.BACKEND_CLASS.name)) + ) diff --git a/common/djangoapps/third_party_auth/urls.py b/common/djangoapps/third_party_auth/urls.py index b020e775b5..5d366b2da3 100644 --- a/common/djangoapps/third_party_auth/urls.py +++ b/common/djangoapps/third_party_auth/urls.py @@ -2,10 +2,11 @@ from django.conf.urls import include, patterns, url -from .views import inactive_user_view +from .views import inactive_user_view, saml_metadata_view urlpatterns = patterns( '', url(r'^auth/inactive', inactive_user_view), + url(r'^auth/saml/metadata.xml', saml_metadata_view), url(r'^auth/', include('social.apps.django_app.urls', namespace='social')), ) diff --git a/common/djangoapps/third_party_auth/views.py b/common/djangoapps/third_party_auth/views.py index 5ae69db526..8f0c6bc3ba 100644 --- a/common/djangoapps/third_party_auth/views.py +++ b/common/djangoapps/third_party_auth/views.py @@ -1,7 +1,11 @@ """ Extra views required for SSO """ +from django.conf import settings +from django.core.urlresolvers import reverse +from django.http import HttpResponse, HttpResponseServerError from django.shortcuts import redirect +from social.apps.django_app.utils import load_strategy, load_backend def inactive_user_view(request): @@ -13,3 +17,19 @@ def inactive_user_view(request): # in a course. Otherwise, just redirect them to the dashboard, which displays a message # about activating their account. return redirect(request.GET.get('next', 'dashboard')) + + +def saml_metadata_view(request): + """ + Get the Service Provider metadata for this edx-platform instance. + You must send this XML to any Shibboleth Identity Provider that you wish to use. + """ + complete_url = reverse('social:complete', args=("tpa-saml", )) + if settings.APPEND_SLASH and not complete_url.endswith('/'): + complete_url = complete_url + '/' # Required for consistency + saml_backend = load_backend(load_strategy(request), "tpa-saml", redirect_uri=complete_url) + metadata, errors = saml_backend.generate_metadata_xml() + + if not errors: + return HttpResponse(content=metadata, content_type='text/xml') + return HttpResponseServerError(content=', '.join(errors)) diff --git a/common/lib/safe_lxml/safe_lxml/etree.py b/common/lib/safe_lxml/safe_lxml/etree.py index 83052b22b6..97bc0b7547 100644 --- a/common/lib/safe_lxml/safe_lxml/etree.py +++ b/common/lib/safe_lxml/safe_lxml/etree.py @@ -9,7 +9,7 @@ For processing xml always prefer this over using lxml.etree directly. from lxml.etree import * # pylint: disable=wildcard-import, unused-wildcard-import from lxml.etree import XMLParser as _XMLParser -from lxml.etree import _ElementTree # pylint: disable=unused-import +from lxml.etree import _Element, _ElementTree # pylint: disable=unused-import, no-name-in-module # This should be imported after lxml.etree so that it overrides the following attributes. from defusedxml.lxml import parse, fromstring, XML diff --git a/common/lib/xmodule/xmodule/x_module.py b/common/lib/xmodule/xmodule/x_module.py index 5f6ed8f8a0..4cb4ccfd40 100644 --- a/common/lib/xmodule/xmodule/x_module.py +++ b/common/lib/xmodule/xmodule/x_module.py @@ -1754,6 +1754,7 @@ class CombinedSystem(object): integrate it into a larger whole. """ + context = context or {} if view_name in PREVIEW_VIEWS: block = self._get_student_block(block) diff --git a/lms/djangoapps/student_account/test/test_views.py b/lms/djangoapps/student_account/test/test_views.py index 5c46647dfa..80cebeae70 100644 --- a/lms/djangoapps/student_account/test/test_views.py +++ b/lms/djangoapps/student_account/test/test_views.py @@ -432,7 +432,7 @@ class AccountSettingsViewTest(TestCase): context['user_preferences_api_url'], reverse('preferences_api', kwargs={'username': self.user.username}) ) - self.assertEqual(context['duplicate_provider'].BACKEND_CLASS.name, 'facebook') + self.assertEqual(context['duplicate_provider'], 'facebook') self.assertEqual(context['auth']['providers'][0]['name'], 'Facebook') self.assertEqual(context['auth']['providers'][1]['name'], 'Google') diff --git a/lms/djangoapps/student_account/views.py b/lms/djangoapps/student_account/views.py index dc178285ff..284472c263 100644 --- a/lms/djangoapps/student_account/views.py +++ b/lms/djangoapps/student_account/views.py @@ -189,9 +189,7 @@ def _third_party_auth_context(request, redirect_to): running_pipeline = pipeline.get(request) if running_pipeline is not None: - current_provider = third_party_auth.provider.Registry.get_by_backend_name( - running_pipeline.get('backend') - ) + current_provider = third_party_auth.provider.Registry.get_from_pipeline(running_pipeline) context["currentProvider"] = current_provider.NAME context["finishAuthUrl"] = pipeline.get_complete_url(current_provider.BACKEND_CLASS.name) @@ -382,7 +380,7 @@ def account_settings_context(request): ), # If the user is connected, sending a POST request to this url removes the connection # information for this provider from their edX account. - 'disconnect_url': pipeline.get_disconnect_url(state.provider.NAME), + 'disconnect_url': pipeline.get_disconnect_url(state.provider.NAME, state.association_id), } for state in auth_states] return context diff --git a/lms/envs/aws.py b/lms/envs/aws.py index a4ae2dfb9b..07dd1474a6 100644 --- a/lms/envs/aws.py +++ b/lms/envs/aws.py @@ -541,6 +541,25 @@ THIRD_PARTY_AUTH = AUTH_TOKENS.get('THIRD_PARTY_AUTH', THIRD_PARTY_AUTH) # The reduced session expiry time during the third party login pipeline. (Value in seconds) SOCIAL_AUTH_PIPELINE_TIMEOUT = ENV_TOKENS.get('SOCIAL_AUTH_PIPELINE_TIMEOUT', 600) +##### SAML configuration for third_party_auth ##### + +if 'SOCIAL_AUTH_TPA_SAML_SP_ENTITY_ID' in ENV_TOKENS: + SOCIAL_AUTH_TPA_SAML_SP_ENTITY_ID = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_ENTITY_ID') + SOCIAL_AUTH_TPA_SAML_SP_NAMEID_FORMAT = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_NAMEID_FORMAT', 'unspecified') + SOCIAL_AUTH_TPA_SAML_SP_EXTRA = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_EXTRA', {}) + SOCIAL_AUTH_TPA_SAML_ORG_INFO = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_ORG_INFO') + SOCIAL_AUTH_TPA_SAML_TECHNICAL_CONTACT = ENV_TOKENS.get( + 'SOCIAL_AUTH_TPA_SAML_TECHNICAL_CONTACT', + {"givenName": "Technical Support", "emailAddress": TECH_SUPPORT_EMAIL} + ) + SOCIAL_AUTH_TPA_SAML_SUPPORT_CONTACT = ENV_TOKENS.get( + 'SOCIAL_AUTH_TPA_SAML_SUPPORT_CONTACT', + {"givenName": "Support", "emailAddress": TECH_SUPPORT_EMAIL} + ) + SOCIAL_AUTH_TPA_SAML_SECURITY_CONFIG = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SECURITY_CONFIG', {}) + SOCIAL_AUTH_TPA_SAML_SP_PUBLIC_CERT = AUTH_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_PUBLIC_CERT') + SOCIAL_AUTH_TPA_SAML_SP_PRIVATE_KEY = AUTH_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_PRIVATE_KEY') + ##### OAUTH2 Provider ############## if FEATURES.get('ENABLE_OAUTH2_PROVIDER'): OAUTH_OIDC_ISSUER = ENV_TOKENS['OAUTH_OIDC_ISSUER'] diff --git a/lms/templates/dashboard/_dashboard_third_party_error.html b/lms/templates/dashboard/_dashboard_third_party_error.html index 99ba0ae4fb..a7958b9481 100644 --- a/lms/templates/dashboard/_dashboard_third_party_error.html +++ b/lms/templates/dashboard/_dashboard_third_party_error.html @@ -5,7 +5,7 @@
${_("The {provider_name} account you selected is already linked to another {platform_name} account.").format(provider_name='{duplicate_provider}'.format(duplicate_provider=duplicate_provider.NAME), platform_name=platform_name)}
+${_("The {provider_name} account you selected is already linked to another {platform_name} account.").format(provider_name=duplicate_provider, platform_name=platform_name)}