diff --git a/common/djangoapps/config_models/models.py b/common/djangoapps/config_models/models.py index 1eec190771..a95ebbb5bb 100644 --- a/common/djangoapps/config_models/models.py +++ b/common/djangoapps/config_models/models.py @@ -42,7 +42,7 @@ class ConfigurationModelManager(models.Manager): assert self.model.KEY_FIELDS != (), "Just use model.current() if there are no KEY_FIELDS" return self.get_queryset().extra( # pylint: disable=no-member where=["{table_name}.id IN ({subquery})".format( - table_name=self.model._meta.db_table, # pylint: disable=protected-access + table_name=self.model._meta.db_table, # pylint: disable=protected-access, no-member subquery=self._current_ids_subquery(), )], select={'is_active': 1}, # This annotation is used by the admin changelist. sqlite requires '1', not 'True' @@ -57,15 +57,15 @@ class ConfigurationModelManager(models.Manager): subquery = self._current_ids_subquery() return self.get_queryset().extra( # pylint: disable=no-member select={'is_active': "{table_name}.id IN ({subquery})".format( - table_name=self.model._meta.db_table, # pylint: disable=protected-access + table_name=self.model._meta.db_table, # pylint: disable=protected-access, no-member subquery=subquery, )} ) else: return self.get_queryset().extra( # pylint: disable=no-member select={'is_active': "{table_name}.id = {pk}".format( - table_name=self.model._meta.db_table, # pylint: disable=protected-access - pk=self.model.current().pk, + table_name=self.model._meta.db_table, # pylint: disable=protected-access, no-member + pk=self.model.current().pk, # pylint: disable=no-member )} ) @@ -145,9 +145,15 @@ class ConfigurationModel(models.Model): return current @classmethod - def is_enabled(cls): - """Returns True if this feature is configured as enabled, else False.""" - return cls.current().enabled + def is_enabled(cls, *key_fields): + """ + Returns True if this feature is configured as enabled, else False. + + Arguments: + key_fields: The positional arguments are the KEY_FIELDS used to identify the + configuration to be checked. + """ + return cls.current(*key_fields).enabled @classmethod def key_values_cache_key_name(cls, *key_fields): diff --git a/common/djangoapps/third_party_auth/admin.py b/common/djangoapps/third_party_auth/admin.py index 47b65b6d67..8ba67183c3 100644 --- a/common/djangoapps/third_party_auth/admin.py +++ b/common/djangoapps/third_party_auth/admin.py @@ -33,7 +33,7 @@ class OAuth2ProviderConfigAdmin(KeyedConfigurationModelAdmin): def get_list_display(self, request): """ Don't show every single field in the admin change list """ return ( - 'name', 'enabled', 'backend_name', 'secondary', 'skip_registration_form', + 'name', 'enabled', 'site', 'backend_name', 'secondary', 'skip_registration_form', 'skip_email_verification', 'change_date', 'changed_by', 'edit_link', ) @@ -52,7 +52,7 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin): def get_list_display(self, request): """ Don't show every single field in the admin change list """ return ( - 'name', 'enabled', 'backend_name', 'entity_id', 'metadata_source', + 'name', 'enabled', 'site', 'backend_name', 'entity_id', 'metadata_source', 'has_data', 'mode', 'change_date', 'changed_by', 'edit_link', ) @@ -86,13 +86,13 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin): admin.site.register(SAMLProviderConfig, SAMLProviderConfigAdmin) -class SAMLConfigurationAdmin(ConfigurationModelAdmin): +class SAMLConfigurationAdmin(KeyedConfigurationModelAdmin): """ Django Admin class for SAMLConfiguration """ def get_list_display(self, request): """ Shorten the public/private keys in the change view """ return ( - 'change_date', 'changed_by', 'enabled', 'entity_id', - 'org_info_str', 'key_summary', + 'site', 'change_date', 'changed_by', 'enabled', 'entity_id', + 'org_info_str', 'key_summary', 'edit_link', ) def key_summary(self, inst): @@ -136,6 +136,7 @@ class LTIProviderConfigAdmin(KeyedConfigurationModelAdmin): return ( 'name', 'enabled', + 'site', 'lti_consumer_key', 'lti_max_timestamp_age', 'change_date', diff --git a/common/djangoapps/third_party_auth/lti.py b/common/djangoapps/third_party_auth/lti.py index 841ed12496..53c4088408 100644 --- a/common/djangoapps/third_party_auth/lti.py +++ b/common/djangoapps/third_party_auth/lti.py @@ -198,9 +198,9 @@ class LTIAuthBackend(BaseAuth): """ from .models import LTIProviderConfig provider_config = LTIProviderConfig.current(lti_consumer_key) - if provider_config and provider_config.enabled: + if provider_config and provider_config.enabled_for_current_site: return ( - provider_config.enabled, + provider_config.enabled_for_current_site, provider_config.get_lti_consumer_secret(), provider_config.lti_max_timestamp_age, ) diff --git a/common/djangoapps/third_party_auth/migrations/0005_add_site_field.py b/common/djangoapps/third_party_auth/migrations/0005_add_site_field.py new file mode 100644 index 0000000000..b02d39568f --- /dev/null +++ b/common/djangoapps/third_party_auth/migrations/0005_add_site_field.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.conf import settings +from django.db import migrations, models + + +def fill_oauth2_slug(apps, schema_editor): + """ + Fill in the provider_slug to be the same as backend_name for backwards compatability. + """ + OAuth2ProviderConfig = apps.get_model('third_party_auth', 'OAuth2ProviderConfig') + for config in OAuth2ProviderConfig.objects.all(): + config.provider_slug = config.backend_name + config.save() + + +class Migration(migrations.Migration): + + dependencies = [ + ('sites', '0001_initial'), + ('third_party_auth', '0004_add_visible_field'), + ] + + operations = [ + migrations.AddField( + model_name='oauth2providerconfig', + name='provider_slug', + field=models.SlugField( + default='temp', + help_text=b'A short string uniquely identifying this provider. Cannot contain spaces and should be a usable as a CSS class. Examples: "ubc", "mit-staging"', + max_length=30 + ), + preserve_default=False, + ), + migrations.RunPython(fill_oauth2_slug, reverse_code=migrations.RunPython.noop), + migrations.AddField( + model_name='ltiproviderconfig', + name='site', + field=models.ForeignKey( + related_name='ltiproviderconfigs', + default=settings.SITE_ID, + to='sites.Site', + help_text='The Site that this provider configuration belongs to.' + ), + ), + migrations.AddField( + model_name='oauth2providerconfig', + name='site', + field=models.ForeignKey( + related_name='oauth2providerconfigs', + default=settings.SITE_ID, + to='sites.Site', + help_text='The Site that this provider configuration belongs to.' + ), + ), + migrations.AddField( + model_name='samlproviderconfig', + name='site', + field=models.ForeignKey( + related_name='samlproviderconfigs', + default=settings.SITE_ID, + to='sites.Site', + help_text='The Site that this provider configuration belongs to.' + ), + ), + migrations.AddField( + model_name='samlconfiguration', + name='site', + field=models.ForeignKey( + related_name='samlconfigurations', + default=settings.SITE_ID, + to='sites.Site', + help_text='The Site that this SAML configuration belongs to.' + ), + ), + ] diff --git a/common/djangoapps/third_party_auth/models.py b/common/djangoapps/third_party_auth/models.py index 4ab0357d95..7dd7d3feb1 100644 --- a/common/djangoapps/third_party_auth/models.py +++ b/common/djangoapps/third_party_auth/models.py @@ -7,6 +7,7 @@ from __future__ import absolute_import from config_models.models import ConfigurationModel, cache from django.conf import settings +from django.contrib.sites.models import Site from django.core.exceptions import ValidationError from django.db import models from django.utils import timezone @@ -22,6 +23,7 @@ from .lti import LTIAuthBackend, LTI_PARAMS_KEY from social.exceptions import SocialAuthBaseException from social.utils import module_member from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers +from openedx.core.djangoapps.theming.helpers import get_current_request log = logging.getLogger(__name__) @@ -106,6 +108,14 @@ class ProviderConfig(ConfigurationModel): 'in a separate list of "Institution" login providers.' ), ) + site = models.ForeignKey( + Site, + default=settings.SITE_ID, + related_name='%(class)ss', + help_text=_( + 'The Site that this provider configuration belongs to.' + ), + ) skip_registration_form = models.BooleanField( default=False, help_text=_( @@ -226,7 +236,14 @@ class ProviderConfig(ConfigurationModel): Determines whether the provider ought to be shown as an option with which to authenticate on the login screen, registration screen, and elsewhere. """ - return bool(self.enabled and self.accepts_logins and self.visible) + return bool(self.enabled_for_current_site and self.accepts_logins and self.visible) + + @property + def enabled_for_current_site(self): + """ + Determines if the provider is able to be used with the current site. + """ + return self.enabled and self.site == Site.objects.get_current(get_current_request()) class OAuth2ProviderConfig(ProviderConfig): @@ -235,7 +252,7 @@ class OAuth2ProviderConfig(ProviderConfig): Also works for OAuth1 providers. """ prefix = 'oa2' - KEY_FIELDS = ('backend_name', ) # Backend name is unique + KEY_FIELDS = ('provider_slug', ) # Backend name is unique backend_name = models.CharField( max_length=50, blank=False, db_index=True, help_text=( @@ -244,6 +261,12 @@ class OAuth2ProviderConfig(ProviderConfig): # To be precise, it's set by AUTHENTICATION_BACKENDS - which aws.py sets from THIRD_PARTY_AUTH_BACKENDS ) ) + provider_slug = models.SlugField( + max_length=30, db_index=True, + help_text=( + 'A short string uniquely identifying this provider. ' + 'Cannot contain spaces and should be a usable as a CSS class. Examples: "ubc", "mit-staging"' + )) key = models.TextField(blank=True, verbose_name="Client ID") secret = models.TextField( blank=True, @@ -406,6 +429,15 @@ class SAMLConfiguration(ConfigurationModel): Service Provider and allow users to authenticate via third party SAML Identity Providers (IdPs) """ + KEY_FIELDS = ('site_id', ) + site = models.ForeignKey( + Site, + default=settings.SITE_ID, + related_name='%(class)ss', + help_text=_( + 'The Site that this SAML configuration belongs to.' + ), + ) private_key = models.TextField( help_text=( 'To generate a key pair as two files, run ' diff --git a/common/djangoapps/third_party_auth/provider.py b/common/djangoapps/third_party_auth/provider.py index 426ac7fd32..fe72bc6561 100644 --- a/common/djangoapps/third_party_auth/provider.py +++ b/common/djangoapps/third_party_auth/provider.py @@ -1,6 +1,9 @@ """ Third-party auth provider configuration API. """ +from django.contrib.sites.models import Site +from openedx.core.djangoapps.theming.helpers import get_current_request + from .models import ( OAuth2ProviderConfig, SAMLConfiguration, SAMLProviderConfig, LTIProviderConfig, _PSA_OAUTH2_BACKENDS, _PSA_SAML_BACKENDS, _LTI_BACKENDS, @@ -15,20 +18,23 @@ class Registry(object): """ @classmethod def _enabled_providers(cls): - """ Helper method to iterate over all providers """ + """ + Helper method that returns a generator used to iterate over all providers + of the current site. + """ for backend_name in _PSA_OAUTH2_BACKENDS: provider = OAuth2ProviderConfig.current(backend_name) - if provider.enabled: + if provider.enabled_for_current_site: yield provider - if SAMLConfiguration.is_enabled(): + if SAMLConfiguration.is_enabled(Site.objects.get_current(get_current_request())): idp_slugs = SAMLProviderConfig.key_values('idp_slug', flat=True) for idp_slug in idp_slugs: provider = SAMLProviderConfig.current(idp_slug) - if provider.enabled and provider.backend_name in _PSA_SAML_BACKENDS: + if provider.enabled_for_current_site and provider.backend_name in _PSA_SAML_BACKENDS: yield provider for consumer_key in LTIProviderConfig.key_values('lti_consumer_key', flat=True): provider = LTIProviderConfig.current(consumer_key) - if provider.enabled and provider.backend_name in _LTI_BACKENDS: + if provider.enabled_for_current_site and provider.backend_name in _LTI_BACKENDS: yield provider @classmethod @@ -69,7 +75,7 @@ class Registry(object): @classmethod def get_enabled_by_backend_name(cls, backend_name): """Generator returning all enabled providers that use the specified - backend. + backend on the current site. Example: >>> list(get_enabled_by_backend_name("tpa-saml")) @@ -84,16 +90,17 @@ class Registry(object): """ if backend_name in _PSA_OAUTH2_BACKENDS: provider = OAuth2ProviderConfig.current(backend_name) - if provider.enabled: + if provider.enabled_for_current_site: yield provider - elif backend_name in _PSA_SAML_BACKENDS and SAMLConfiguration.is_enabled(): + elif backend_name in _PSA_SAML_BACKENDS and SAMLConfiguration.is_enabled( + Site.objects.get_current(get_current_request())): idp_names = SAMLProviderConfig.key_values('idp_slug', flat=True) for idp_name in idp_names: provider = SAMLProviderConfig.current(idp_name) - if provider.backend_name == backend_name and provider.enabled: + if provider.backend_name == backend_name and provider.enabled_for_current_site: yield provider elif backend_name in _LTI_BACKENDS: for consumer_key in LTIProviderConfig.key_values('lti_consumer_key', flat=True): provider = LTIProviderConfig.current(consumer_key) - if provider.backend_name == backend_name and provider.enabled: + if provider.backend_name == backend_name and provider.enabled_for_current_site: yield provider diff --git a/common/djangoapps/third_party_auth/saml.py b/common/djangoapps/third_party_auth/saml.py index a86db3ecbd..7a90a32028 100644 --- a/common/djangoapps/third_party_auth/saml.py +++ b/common/djangoapps/third_party_auth/saml.py @@ -2,8 +2,10 @@ Slightly customized python-social-auth backend for SAML 2.0 support """ import logging +from django.contrib.sites.models import Site from django.http import Http404 from django.utils.functional import cached_property +from openedx.core.djangoapps.theming.helpers import get_current_request from social.backends.saml import SAMLAuth, OID_EDU_PERSON_ENTITLEMENT from social.exceptions import AuthForbidden, AuthMissingParameter @@ -41,10 +43,12 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method if not self._config.enabled: log.error('SAML authentication is not enabled') raise Http404 + # TODO: remove this check once the fix is merged upstream: # https://github.com/omab/python-social-auth/pull/821 if 'idp' not in self.strategy.request_data(): raise AuthMissingParameter(self, 'idp') + return super(SAMLAuthBackend, self).auth_url() def _check_entitlements(self, idp, attributes): @@ -93,4 +97,4 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method @cached_property def _config(self): from .models import SAMLConfiguration - return SAMLConfiguration.current() + return SAMLConfiguration.current(Site.objects.get_current(get_current_request())) diff --git a/common/djangoapps/third_party_auth/strategy.py b/common/djangoapps/third_party_auth/strategy.py index f542254234..a0ced2f88b 100644 --- a/common/djangoapps/third_party_auth/strategy.py +++ b/common/djangoapps/third_party_auth/strategy.py @@ -26,7 +26,7 @@ class ConfigurationModelStrategy(DjangoStrategy): """ if isinstance(backend, OAuthAuth): provider_config = OAuth2ProviderConfig.current(backend.name) - if not provider_config.enabled: + if not provider_config.enabled_for_current_site: raise Exception("Can't fetch setting of a disabled backend/provider.") try: return provider_config.get_setting(name) diff --git a/common/djangoapps/third_party_auth/tasks.py b/common/djangoapps/third_party_auth/tasks.py index 71373b2f75..0c64859f2b 100644 --- a/common/djangoapps/third_party_auth/tasks.py +++ b/common/djangoapps/third_party_auth/tasks.py @@ -36,16 +36,13 @@ def fetch_saml_metadata(): num_failed: Number of providers that could not be updated num_total: Total number of providers whose metadata was fetched """ - if not SAMLConfiguration.is_enabled(): - return (0, 0, 0) # Nothing to do until SAML is enabled. - num_changed, num_failed = 0, 0 # First make a list of all the metadata XML URLs: url_map = {} for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True): config = SAMLProviderConfig.current(idp_slug) - if not config.enabled: + if not config.enabled or not SAMLConfiguration.is_enabled(config.site): continue url = config.metadata_source if url not in url_map: diff --git a/common/djangoapps/third_party_auth/tests/test_provider.py b/common/djangoapps/third_party_auth/tests/test_provider.py index f55b0f0f06..a9a7ec1c6b 100644 --- a/common/djangoapps/third_party_auth/tests/test_provider.py +++ b/common/djangoapps/third_party_auth/tests/test_provider.py @@ -1,10 +1,15 @@ """Unit tests for provider.py.""" +from django.contrib.sites.models import Site from mock import Mock, patch +from openedx.core.djangoapps.site_configuration.tests.test_util import with_site_configuration from third_party_auth import provider from third_party_auth.tests import testutil import unittest +SITE_DOMAIN_A = 'professionalx.example.com' +SITE_DOMAIN_B = 'somethingelse.example.com' + @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled') class RegistryTest(testutil.TestCase): @@ -84,6 +89,22 @@ class RegistryTest(testutil.TestCase): self.assertNotIn(no_log_in_provider.provider_id, provider_ids) self.assertIn(normal_provider.provider_id, provider_ids) + def test_provider_enabled_for_current_site(self): + """ + Verify that enabled_for_current_site returns True when the provider matches the current site. + """ + prov = self.configure_google_provider(visible=True, enabled=True, site=Site.objects.get_current()) + self.assertEqual(prov.enabled_for_current_site, True) + + @with_site_configuration(SITE_DOMAIN_A) + def test_provider_disabled_for_mismatching_site(self): + """ + Verify that enabled_for_current_site returns False when the provider is configured for a different site. + """ + site_b = Site.objects.get_or_create(domain=SITE_DOMAIN_B, name=SITE_DOMAIN_B)[0] + prov = self.configure_google_provider(visible=True, enabled=True, site=site_b) + self.assertEqual(prov.enabled_for_current_site, False) + def test_get_returns_enabled_provider(self): google_provider = self.configure_google_provider(enabled=True) self.assertEqual(google_provider.id, provider.Registry.get(google_provider.provider_id).id) diff --git a/common/djangoapps/third_party_auth/tests/testutil.py b/common/djangoapps/third_party_auth/tests/testutil.py index fa9ef60bf1..18d4f66a7f 100644 --- a/common/djangoapps/third_party_auth/tests/testutil.py +++ b/common/djangoapps/third_party_auth/tests/testutil.py @@ -7,6 +7,7 @@ Used by Django and non-Django tests; must not have Django deps. from contextlib import contextmanager from django.conf import settings from django.contrib.auth.models import User +from django.contrib.sites.models import Site from provider.oauth2.models import Client as OAuth2Client from provider import constants import django.test @@ -76,13 +77,17 @@ class ThirdPartyAuthTestMixin(object): @staticmethod def configure_oauth_provider(**kwargs): """ Update the settings for an OAuth2-based third party auth provider """ + kwargs.setdefault('provider_slug', kwargs['backend_name']) obj = OAuth2ProviderConfig(**kwargs) obj.save() return obj def configure_saml_provider(self, **kwargs): """ Update the settings for a SAML-based third party auth provider """ - self.assertTrue(SAMLConfiguration.is_enabled(), "SAML Provider Configuration only works if SAML is enabled.") + self.assertTrue( + SAMLConfiguration.is_enabled(Site.objects.get_current()), + "SAML Provider Configuration only works if SAML is enabled." + ) obj = SAMLProviderConfig(**kwargs) obj.save() return obj diff --git a/common/djangoapps/third_party_auth/views.py b/common/djangoapps/third_party_auth/views.py index 56d34dd178..caf0ab2b5e 100644 --- a/common/djangoapps/third_party_auth/views.py +++ b/common/djangoapps/third_party_auth/views.py @@ -36,7 +36,7 @@ def saml_metadata_view(request): Get the Service Provider metadata for this edx-platform instance. You must send this XML to any Shibboleth Identity Provider that you wish to use. """ - if not SAMLConfiguration.is_enabled(): + if not SAMLConfiguration.is_enabled(request.site): raise Http404 complete_url = reverse('social:complete', args=("tpa-saml", )) if settings.APPEND_SLASH and not complete_url.endswith('/'): diff --git a/common/test/db_fixtures/third_party_auth.json b/common/test/db_fixtures/third_party_auth.json index c414b0e53a..8bfc8855ac 100644 --- a/common/test/db_fixtures/third_party_auth.json +++ b/common/test/db_fixtures/third_party_auth.json @@ -10,8 +10,10 @@ "icon_class": "fa-google-plus", "icon_image": null, "backend_name": "google-oauth2", + "provider_slug": "google-oauth2", "key": "test", "secret": "test", + "site": 2, "other_settings": "{}", "visible": true } @@ -27,8 +29,10 @@ "icon_class": "fa-facebook", "icon_image": null, "backend_name": "facebook", + "provider_slug": "facebook", "key": "test", "secret": "test", + "site": 2, "other_settings": "{}", "visible": true } @@ -44,8 +48,10 @@ "icon_class": "", "icon_image": "test-icon.png", "backend_name": "dummy", + "provider_slug": "dummy", "key": "", "secret": "", + "site": 2, "other_settings": "{}", "visible": true } diff --git a/openedx/core/djangoapps/site_configuration/tests/test_util.py b/openedx/core/djangoapps/site_configuration/tests/test_util.py index ace9bccaf5..2ddb1e55fe 100644 --- a/openedx/core/djangoapps/site_configuration/tests/test_util.py +++ b/openedx/core/djangoapps/site_configuration/tests/test_util.py @@ -37,7 +37,8 @@ def with_site_configuration(domain="test.localhost", configuration=None): with patch('openedx.core.djangoapps.site_configuration.helpers.get_current_site_configuration', return_value=site_configuration): with patch('openedx.core.djangoapps.theming.helpers.get_current_site', return_value=site): - return func(*args, **kwargs) + with patch('django.contrib.sites.models.SiteManager.get_current', return_value=site): + return func(*args, **kwargs) return _decorated return _decorator @@ -63,4 +64,5 @@ def with_site_configuration_context(domain="test.localhost", configuration=None) with patch('openedx.core.djangoapps.site_configuration.helpers.get_current_site_configuration', return_value=site_configuration): with patch('openedx.core.djangoapps.theming.helpers.get_current_site', return_value=site): - yield + with patch('django.contrib.sites.models.SiteManager.get_current', return_value=site): + yield