Implement Site settings for Third Party Auth providers
This commit is contained in:
@@ -42,7 +42,7 @@ class ConfigurationModelManager(models.Manager):
|
||||
assert self.model.KEY_FIELDS != (), "Just use model.current() if there are no KEY_FIELDS"
|
||||
return self.get_queryset().extra( # pylint: disable=no-member
|
||||
where=["{table_name}.id IN ({subquery})".format(
|
||||
table_name=self.model._meta.db_table, # pylint: disable=protected-access
|
||||
table_name=self.model._meta.db_table, # pylint: disable=protected-access, no-member
|
||||
subquery=self._current_ids_subquery(),
|
||||
)],
|
||||
select={'is_active': 1}, # This annotation is used by the admin changelist. sqlite requires '1', not 'True'
|
||||
@@ -57,15 +57,15 @@ class ConfigurationModelManager(models.Manager):
|
||||
subquery = self._current_ids_subquery()
|
||||
return self.get_queryset().extra( # pylint: disable=no-member
|
||||
select={'is_active': "{table_name}.id IN ({subquery})".format(
|
||||
table_name=self.model._meta.db_table, # pylint: disable=protected-access
|
||||
table_name=self.model._meta.db_table, # pylint: disable=protected-access, no-member
|
||||
subquery=subquery,
|
||||
)}
|
||||
)
|
||||
else:
|
||||
return self.get_queryset().extra( # pylint: disable=no-member
|
||||
select={'is_active': "{table_name}.id = {pk}".format(
|
||||
table_name=self.model._meta.db_table, # pylint: disable=protected-access
|
||||
pk=self.model.current().pk,
|
||||
table_name=self.model._meta.db_table, # pylint: disable=protected-access, no-member
|
||||
pk=self.model.current().pk, # pylint: disable=no-member
|
||||
)}
|
||||
)
|
||||
|
||||
@@ -145,9 +145,15 @@ class ConfigurationModel(models.Model):
|
||||
return current
|
||||
|
||||
@classmethod
|
||||
def is_enabled(cls):
|
||||
"""Returns True if this feature is configured as enabled, else False."""
|
||||
return cls.current().enabled
|
||||
def is_enabled(cls, *key_fields):
|
||||
"""
|
||||
Returns True if this feature is configured as enabled, else False.
|
||||
|
||||
Arguments:
|
||||
key_fields: The positional arguments are the KEY_FIELDS used to identify the
|
||||
configuration to be checked.
|
||||
"""
|
||||
return cls.current(*key_fields).enabled
|
||||
|
||||
@classmethod
|
||||
def key_values_cache_key_name(cls, *key_fields):
|
||||
|
||||
@@ -33,7 +33,7 @@ class OAuth2ProviderConfigAdmin(KeyedConfigurationModelAdmin):
|
||||
def get_list_display(self, request):
|
||||
""" Don't show every single field in the admin change list """
|
||||
return (
|
||||
'name', 'enabled', 'backend_name', 'secondary', 'skip_registration_form',
|
||||
'name', 'enabled', 'site', 'backend_name', 'secondary', 'skip_registration_form',
|
||||
'skip_email_verification', 'change_date', 'changed_by', 'edit_link',
|
||||
)
|
||||
|
||||
@@ -52,7 +52,7 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin):
|
||||
def get_list_display(self, request):
|
||||
""" Don't show every single field in the admin change list """
|
||||
return (
|
||||
'name', 'enabled', 'backend_name', 'entity_id', 'metadata_source',
|
||||
'name', 'enabled', 'site', 'backend_name', 'entity_id', 'metadata_source',
|
||||
'has_data', 'mode', 'change_date', 'changed_by', 'edit_link',
|
||||
)
|
||||
|
||||
@@ -86,13 +86,13 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin):
|
||||
admin.site.register(SAMLProviderConfig, SAMLProviderConfigAdmin)
|
||||
|
||||
|
||||
class SAMLConfigurationAdmin(ConfigurationModelAdmin):
|
||||
class SAMLConfigurationAdmin(KeyedConfigurationModelAdmin):
|
||||
""" Django Admin class for SAMLConfiguration """
|
||||
def get_list_display(self, request):
|
||||
""" Shorten the public/private keys in the change view """
|
||||
return (
|
||||
'change_date', 'changed_by', 'enabled', 'entity_id',
|
||||
'org_info_str', 'key_summary',
|
||||
'site', 'change_date', 'changed_by', 'enabled', 'entity_id',
|
||||
'org_info_str', 'key_summary', 'edit_link',
|
||||
)
|
||||
|
||||
def key_summary(self, inst):
|
||||
@@ -136,6 +136,7 @@ class LTIProviderConfigAdmin(KeyedConfigurationModelAdmin):
|
||||
return (
|
||||
'name',
|
||||
'enabled',
|
||||
'site',
|
||||
'lti_consumer_key',
|
||||
'lti_max_timestamp_age',
|
||||
'change_date',
|
||||
|
||||
@@ -198,9 +198,9 @@ class LTIAuthBackend(BaseAuth):
|
||||
"""
|
||||
from .models import LTIProviderConfig
|
||||
provider_config = LTIProviderConfig.current(lti_consumer_key)
|
||||
if provider_config and provider_config.enabled:
|
||||
if provider_config and provider_config.enabled_for_current_site:
|
||||
return (
|
||||
provider_config.enabled,
|
||||
provider_config.enabled_for_current_site,
|
||||
provider_config.get_lti_consumer_secret(),
|
||||
provider_config.lti_max_timestamp_age,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
def fill_oauth2_slug(apps, schema_editor):
|
||||
"""
|
||||
Fill in the provider_slug to be the same as backend_name for backwards compatability.
|
||||
"""
|
||||
OAuth2ProviderConfig = apps.get_model('third_party_auth', 'OAuth2ProviderConfig')
|
||||
for config in OAuth2ProviderConfig.objects.all():
|
||||
config.provider_slug = config.backend_name
|
||||
config.save()
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('sites', '0001_initial'),
|
||||
('third_party_auth', '0004_add_visible_field'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='oauth2providerconfig',
|
||||
name='provider_slug',
|
||||
field=models.SlugField(
|
||||
default='temp',
|
||||
help_text=b'A short string uniquely identifying this provider. Cannot contain spaces and should be a usable as a CSS class. Examples: "ubc", "mit-staging"',
|
||||
max_length=30
|
||||
),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.RunPython(fill_oauth2_slug, reverse_code=migrations.RunPython.noop),
|
||||
migrations.AddField(
|
||||
model_name='ltiproviderconfig',
|
||||
name='site',
|
||||
field=models.ForeignKey(
|
||||
related_name='ltiproviderconfigs',
|
||||
default=settings.SITE_ID,
|
||||
to='sites.Site',
|
||||
help_text='The Site that this provider configuration belongs to.'
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='oauth2providerconfig',
|
||||
name='site',
|
||||
field=models.ForeignKey(
|
||||
related_name='oauth2providerconfigs',
|
||||
default=settings.SITE_ID,
|
||||
to='sites.Site',
|
||||
help_text='The Site that this provider configuration belongs to.'
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='samlproviderconfig',
|
||||
name='site',
|
||||
field=models.ForeignKey(
|
||||
related_name='samlproviderconfigs',
|
||||
default=settings.SITE_ID,
|
||||
to='sites.Site',
|
||||
help_text='The Site that this provider configuration belongs to.'
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='samlconfiguration',
|
||||
name='site',
|
||||
field=models.ForeignKey(
|
||||
related_name='samlconfigurations',
|
||||
default=settings.SITE_ID,
|
||||
to='sites.Site',
|
||||
help_text='The Site that this SAML configuration belongs to.'
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -7,6 +7,7 @@ from __future__ import absolute_import
|
||||
|
||||
from config_models.models import ConfigurationModel, cache
|
||||
from django.conf import settings
|
||||
from django.contrib.sites.models import Site
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.utils import timezone
|
||||
@@ -22,6 +23,7 @@ from .lti import LTIAuthBackend, LTI_PARAMS_KEY
|
||||
from social.exceptions import SocialAuthBaseException
|
||||
from social.utils import module_member
|
||||
from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers
|
||||
from openedx.core.djangoapps.theming.helpers import get_current_request
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -106,6 +108,14 @@ class ProviderConfig(ConfigurationModel):
|
||||
'in a separate list of "Institution" login providers.'
|
||||
),
|
||||
)
|
||||
site = models.ForeignKey(
|
||||
Site,
|
||||
default=settings.SITE_ID,
|
||||
related_name='%(class)ss',
|
||||
help_text=_(
|
||||
'The Site that this provider configuration belongs to.'
|
||||
),
|
||||
)
|
||||
skip_registration_form = models.BooleanField(
|
||||
default=False,
|
||||
help_text=_(
|
||||
@@ -226,7 +236,14 @@ class ProviderConfig(ConfigurationModel):
|
||||
Determines whether the provider ought to be shown as an option with
|
||||
which to authenticate on the login screen, registration screen, and elsewhere.
|
||||
"""
|
||||
return bool(self.enabled and self.accepts_logins and self.visible)
|
||||
return bool(self.enabled_for_current_site and self.accepts_logins and self.visible)
|
||||
|
||||
@property
|
||||
def enabled_for_current_site(self):
|
||||
"""
|
||||
Determines if the provider is able to be used with the current site.
|
||||
"""
|
||||
return self.enabled and self.site == Site.objects.get_current(get_current_request())
|
||||
|
||||
|
||||
class OAuth2ProviderConfig(ProviderConfig):
|
||||
@@ -235,7 +252,7 @@ class OAuth2ProviderConfig(ProviderConfig):
|
||||
Also works for OAuth1 providers.
|
||||
"""
|
||||
prefix = 'oa2'
|
||||
KEY_FIELDS = ('backend_name', ) # Backend name is unique
|
||||
KEY_FIELDS = ('provider_slug', ) # Backend name is unique
|
||||
backend_name = models.CharField(
|
||||
max_length=50, blank=False, db_index=True,
|
||||
help_text=(
|
||||
@@ -244,6 +261,12 @@ class OAuth2ProviderConfig(ProviderConfig):
|
||||
# To be precise, it's set by AUTHENTICATION_BACKENDS - which aws.py sets from THIRD_PARTY_AUTH_BACKENDS
|
||||
)
|
||||
)
|
||||
provider_slug = models.SlugField(
|
||||
max_length=30, db_index=True,
|
||||
help_text=(
|
||||
'A short string uniquely identifying this provider. '
|
||||
'Cannot contain spaces and should be a usable as a CSS class. Examples: "ubc", "mit-staging"'
|
||||
))
|
||||
key = models.TextField(blank=True, verbose_name="Client ID")
|
||||
secret = models.TextField(
|
||||
blank=True,
|
||||
@@ -406,6 +429,15 @@ class SAMLConfiguration(ConfigurationModel):
|
||||
Service Provider and allow users to authenticate via third party SAML
|
||||
Identity Providers (IdPs)
|
||||
"""
|
||||
KEY_FIELDS = ('site_id', )
|
||||
site = models.ForeignKey(
|
||||
Site,
|
||||
default=settings.SITE_ID,
|
||||
related_name='%(class)ss',
|
||||
help_text=_(
|
||||
'The Site that this SAML configuration belongs to.'
|
||||
),
|
||||
)
|
||||
private_key = models.TextField(
|
||||
help_text=(
|
||||
'To generate a key pair as two files, run '
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
"""
|
||||
Third-party auth provider configuration API.
|
||||
"""
|
||||
from django.contrib.sites.models import Site
|
||||
from openedx.core.djangoapps.theming.helpers import get_current_request
|
||||
|
||||
from .models import (
|
||||
OAuth2ProviderConfig, SAMLConfiguration, SAMLProviderConfig, LTIProviderConfig,
|
||||
_PSA_OAUTH2_BACKENDS, _PSA_SAML_BACKENDS, _LTI_BACKENDS,
|
||||
@@ -15,20 +18,23 @@ class Registry(object):
|
||||
"""
|
||||
@classmethod
|
||||
def _enabled_providers(cls):
|
||||
""" Helper method to iterate over all providers """
|
||||
"""
|
||||
Helper method that returns a generator used to iterate over all providers
|
||||
of the current site.
|
||||
"""
|
||||
for backend_name in _PSA_OAUTH2_BACKENDS:
|
||||
provider = OAuth2ProviderConfig.current(backend_name)
|
||||
if provider.enabled:
|
||||
if provider.enabled_for_current_site:
|
||||
yield provider
|
||||
if SAMLConfiguration.is_enabled():
|
||||
if SAMLConfiguration.is_enabled(Site.objects.get_current(get_current_request())):
|
||||
idp_slugs = SAMLProviderConfig.key_values('idp_slug', flat=True)
|
||||
for idp_slug in idp_slugs:
|
||||
provider = SAMLProviderConfig.current(idp_slug)
|
||||
if provider.enabled and provider.backend_name in _PSA_SAML_BACKENDS:
|
||||
if provider.enabled_for_current_site and provider.backend_name in _PSA_SAML_BACKENDS:
|
||||
yield provider
|
||||
for consumer_key in LTIProviderConfig.key_values('lti_consumer_key', flat=True):
|
||||
provider = LTIProviderConfig.current(consumer_key)
|
||||
if provider.enabled and provider.backend_name in _LTI_BACKENDS:
|
||||
if provider.enabled_for_current_site and provider.backend_name in _LTI_BACKENDS:
|
||||
yield provider
|
||||
|
||||
@classmethod
|
||||
@@ -69,7 +75,7 @@ class Registry(object):
|
||||
@classmethod
|
||||
def get_enabled_by_backend_name(cls, backend_name):
|
||||
"""Generator returning all enabled providers that use the specified
|
||||
backend.
|
||||
backend on the current site.
|
||||
|
||||
Example:
|
||||
>>> list(get_enabled_by_backend_name("tpa-saml"))
|
||||
@@ -84,16 +90,17 @@ class Registry(object):
|
||||
"""
|
||||
if backend_name in _PSA_OAUTH2_BACKENDS:
|
||||
provider = OAuth2ProviderConfig.current(backend_name)
|
||||
if provider.enabled:
|
||||
if provider.enabled_for_current_site:
|
||||
yield provider
|
||||
elif backend_name in _PSA_SAML_BACKENDS and SAMLConfiguration.is_enabled():
|
||||
elif backend_name in _PSA_SAML_BACKENDS and SAMLConfiguration.is_enabled(
|
||||
Site.objects.get_current(get_current_request())):
|
||||
idp_names = SAMLProviderConfig.key_values('idp_slug', flat=True)
|
||||
for idp_name in idp_names:
|
||||
provider = SAMLProviderConfig.current(idp_name)
|
||||
if provider.backend_name == backend_name and provider.enabled:
|
||||
if provider.backend_name == backend_name and provider.enabled_for_current_site:
|
||||
yield provider
|
||||
elif backend_name in _LTI_BACKENDS:
|
||||
for consumer_key in LTIProviderConfig.key_values('lti_consumer_key', flat=True):
|
||||
provider = LTIProviderConfig.current(consumer_key)
|
||||
if provider.backend_name == backend_name and provider.enabled:
|
||||
if provider.backend_name == backend_name and provider.enabled_for_current_site:
|
||||
yield provider
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
Slightly customized python-social-auth backend for SAML 2.0 support
|
||||
"""
|
||||
import logging
|
||||
from django.contrib.sites.models import Site
|
||||
from django.http import Http404
|
||||
from django.utils.functional import cached_property
|
||||
from openedx.core.djangoapps.theming.helpers import get_current_request
|
||||
from social.backends.saml import SAMLAuth, OID_EDU_PERSON_ENTITLEMENT
|
||||
from social.exceptions import AuthForbidden, AuthMissingParameter
|
||||
|
||||
@@ -41,10 +43,12 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method
|
||||
if not self._config.enabled:
|
||||
log.error('SAML authentication is not enabled')
|
||||
raise Http404
|
||||
|
||||
# TODO: remove this check once the fix is merged upstream:
|
||||
# https://github.com/omab/python-social-auth/pull/821
|
||||
if 'idp' not in self.strategy.request_data():
|
||||
raise AuthMissingParameter(self, 'idp')
|
||||
|
||||
return super(SAMLAuthBackend, self).auth_url()
|
||||
|
||||
def _check_entitlements(self, idp, attributes):
|
||||
@@ -93,4 +97,4 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method
|
||||
@cached_property
|
||||
def _config(self):
|
||||
from .models import SAMLConfiguration
|
||||
return SAMLConfiguration.current()
|
||||
return SAMLConfiguration.current(Site.objects.get_current(get_current_request()))
|
||||
|
||||
@@ -26,7 +26,7 @@ class ConfigurationModelStrategy(DjangoStrategy):
|
||||
"""
|
||||
if isinstance(backend, OAuthAuth):
|
||||
provider_config = OAuth2ProviderConfig.current(backend.name)
|
||||
if not provider_config.enabled:
|
||||
if not provider_config.enabled_for_current_site:
|
||||
raise Exception("Can't fetch setting of a disabled backend/provider.")
|
||||
try:
|
||||
return provider_config.get_setting(name)
|
||||
|
||||
@@ -36,16 +36,13 @@ def fetch_saml_metadata():
|
||||
num_failed: Number of providers that could not be updated
|
||||
num_total: Total number of providers whose metadata was fetched
|
||||
"""
|
||||
if not SAMLConfiguration.is_enabled():
|
||||
return (0, 0, 0) # Nothing to do until SAML is enabled.
|
||||
|
||||
num_changed, num_failed = 0, 0
|
||||
|
||||
# First make a list of all the metadata XML URLs:
|
||||
url_map = {}
|
||||
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
|
||||
config = SAMLProviderConfig.current(idp_slug)
|
||||
if not config.enabled:
|
||||
if not config.enabled or not SAMLConfiguration.is_enabled(config.site):
|
||||
continue
|
||||
url = config.metadata_source
|
||||
if url not in url_map:
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
"""Unit tests for provider.py."""
|
||||
|
||||
from django.contrib.sites.models import Site
|
||||
from mock import Mock, patch
|
||||
from openedx.core.djangoapps.site_configuration.tests.test_util import with_site_configuration
|
||||
from third_party_auth import provider
|
||||
from third_party_auth.tests import testutil
|
||||
import unittest
|
||||
|
||||
SITE_DOMAIN_A = 'professionalx.example.com'
|
||||
SITE_DOMAIN_B = 'somethingelse.example.com'
|
||||
|
||||
|
||||
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled')
|
||||
class RegistryTest(testutil.TestCase):
|
||||
@@ -84,6 +89,22 @@ class RegistryTest(testutil.TestCase):
|
||||
self.assertNotIn(no_log_in_provider.provider_id, provider_ids)
|
||||
self.assertIn(normal_provider.provider_id, provider_ids)
|
||||
|
||||
def test_provider_enabled_for_current_site(self):
|
||||
"""
|
||||
Verify that enabled_for_current_site returns True when the provider matches the current site.
|
||||
"""
|
||||
prov = self.configure_google_provider(visible=True, enabled=True, site=Site.objects.get_current())
|
||||
self.assertEqual(prov.enabled_for_current_site, True)
|
||||
|
||||
@with_site_configuration(SITE_DOMAIN_A)
|
||||
def test_provider_disabled_for_mismatching_site(self):
|
||||
"""
|
||||
Verify that enabled_for_current_site returns False when the provider is configured for a different site.
|
||||
"""
|
||||
site_b = Site.objects.get_or_create(domain=SITE_DOMAIN_B, name=SITE_DOMAIN_B)[0]
|
||||
prov = self.configure_google_provider(visible=True, enabled=True, site=site_b)
|
||||
self.assertEqual(prov.enabled_for_current_site, False)
|
||||
|
||||
def test_get_returns_enabled_provider(self):
|
||||
google_provider = self.configure_google_provider(enabled=True)
|
||||
self.assertEqual(google_provider.id, provider.Registry.get(google_provider.provider_id).id)
|
||||
|
||||
@@ -7,6 +7,7 @@ Used by Django and non-Django tests; must not have Django deps.
|
||||
from contextlib import contextmanager
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import User
|
||||
from django.contrib.sites.models import Site
|
||||
from provider.oauth2.models import Client as OAuth2Client
|
||||
from provider import constants
|
||||
import django.test
|
||||
@@ -76,13 +77,17 @@ class ThirdPartyAuthTestMixin(object):
|
||||
@staticmethod
|
||||
def configure_oauth_provider(**kwargs):
|
||||
""" Update the settings for an OAuth2-based third party auth provider """
|
||||
kwargs.setdefault('provider_slug', kwargs['backend_name'])
|
||||
obj = OAuth2ProviderConfig(**kwargs)
|
||||
obj.save()
|
||||
return obj
|
||||
|
||||
def configure_saml_provider(self, **kwargs):
|
||||
""" Update the settings for a SAML-based third party auth provider """
|
||||
self.assertTrue(SAMLConfiguration.is_enabled(), "SAML Provider Configuration only works if SAML is enabled.")
|
||||
self.assertTrue(
|
||||
SAMLConfiguration.is_enabled(Site.objects.get_current()),
|
||||
"SAML Provider Configuration only works if SAML is enabled."
|
||||
)
|
||||
obj = SAMLProviderConfig(**kwargs)
|
||||
obj.save()
|
||||
return obj
|
||||
|
||||
@@ -36,7 +36,7 @@ def saml_metadata_view(request):
|
||||
Get the Service Provider metadata for this edx-platform instance.
|
||||
You must send this XML to any Shibboleth Identity Provider that you wish to use.
|
||||
"""
|
||||
if not SAMLConfiguration.is_enabled():
|
||||
if not SAMLConfiguration.is_enabled(request.site):
|
||||
raise Http404
|
||||
complete_url = reverse('social:complete', args=("tpa-saml", ))
|
||||
if settings.APPEND_SLASH and not complete_url.endswith('/'):
|
||||
|
||||
@@ -10,8 +10,10 @@
|
||||
"icon_class": "fa-google-plus",
|
||||
"icon_image": null,
|
||||
"backend_name": "google-oauth2",
|
||||
"provider_slug": "google-oauth2",
|
||||
"key": "test",
|
||||
"secret": "test",
|
||||
"site": 2,
|
||||
"other_settings": "{}",
|
||||
"visible": true
|
||||
}
|
||||
@@ -27,8 +29,10 @@
|
||||
"icon_class": "fa-facebook",
|
||||
"icon_image": null,
|
||||
"backend_name": "facebook",
|
||||
"provider_slug": "facebook",
|
||||
"key": "test",
|
||||
"secret": "test",
|
||||
"site": 2,
|
||||
"other_settings": "{}",
|
||||
"visible": true
|
||||
}
|
||||
@@ -44,8 +48,10 @@
|
||||
"icon_class": "",
|
||||
"icon_image": "test-icon.png",
|
||||
"backend_name": "dummy",
|
||||
"provider_slug": "dummy",
|
||||
"key": "",
|
||||
"secret": "",
|
||||
"site": 2,
|
||||
"other_settings": "{}",
|
||||
"visible": true
|
||||
}
|
||||
|
||||
@@ -37,7 +37,8 @@ def with_site_configuration(domain="test.localhost", configuration=None):
|
||||
with patch('openedx.core.djangoapps.site_configuration.helpers.get_current_site_configuration',
|
||||
return_value=site_configuration):
|
||||
with patch('openedx.core.djangoapps.theming.helpers.get_current_site', return_value=site):
|
||||
return func(*args, **kwargs)
|
||||
with patch('django.contrib.sites.models.SiteManager.get_current', return_value=site):
|
||||
return func(*args, **kwargs)
|
||||
return _decorated
|
||||
return _decorator
|
||||
|
||||
@@ -63,4 +64,5 @@ def with_site_configuration_context(domain="test.localhost", configuration=None)
|
||||
with patch('openedx.core.djangoapps.site_configuration.helpers.get_current_site_configuration',
|
||||
return_value=site_configuration):
|
||||
with patch('openedx.core.djangoapps.theming.helpers.get_current_site', return_value=site):
|
||||
yield
|
||||
with patch('django.contrib.sites.models.SiteManager.get_current', return_value=site):
|
||||
yield
|
||||
|
||||
Reference in New Issue
Block a user