Files
edx-platform/common/djangoapps/third_party_auth/provider.py
2016-09-26 12:43:00 -04:00

107 lines
4.4 KiB
Python

"""
Third-party auth provider configuration API.
"""
from django.contrib.sites.models import Site
from openedx.core.djangoapps.theming.helpers import get_current_request
from .models import (
OAuth2ProviderConfig, SAMLConfiguration, SAMLProviderConfig, LTIProviderConfig,
_PSA_OAUTH2_BACKENDS, _PSA_SAML_BACKENDS, _LTI_BACKENDS,
)
class Registry(object):
"""
API for querying third-party auth ProviderConfig objects.
Providers must subclass ProviderConfig in order to be usable in the registry.
"""
@classmethod
def _enabled_providers(cls):
"""
Helper method that returns a generator used to iterate over all providers
of the current site.
"""
for backend_name in _PSA_OAUTH2_BACKENDS:
provider = OAuth2ProviderConfig.current(backend_name)
if provider.enabled_for_current_site:
yield provider
if SAMLConfiguration.is_enabled(Site.objects.get_current(get_current_request())):
idp_slugs = SAMLProviderConfig.key_values('idp_slug', flat=True)
for idp_slug in idp_slugs:
provider = SAMLProviderConfig.current(idp_slug)
if provider.enabled_for_current_site and provider.backend_name in _PSA_SAML_BACKENDS:
yield provider
for consumer_key in LTIProviderConfig.key_values('lti_consumer_key', flat=True):
provider = LTIProviderConfig.current(consumer_key)
if provider.enabled_for_current_site and provider.backend_name in _LTI_BACKENDS:
yield provider
@classmethod
def enabled(cls):
"""Returns list of enabled providers."""
return sorted(cls._enabled_providers(), key=lambda provider: provider.name)
@classmethod
def displayed_for_login(cls):
"""Returns list of providers that can be used to initiate logins in the UI"""
return [provider for provider in cls.enabled() if provider.display_for_login]
@classmethod
def get(cls, provider_id):
"""Gets provider by provider_id string if enabled, else None."""
if '-' not in provider_id: # Check format - see models.py:ProviderConfig
raise ValueError("Invalid provider_id. Expect something like oa2-google")
try:
return next(provider for provider in cls._enabled_providers() if provider.provider_id == provider_id)
except StopIteration:
return None
@classmethod
def get_from_pipeline(cls, running_pipeline):
"""Gets the provider that is being used for the specified pipeline (or None).
Args:
running_pipeline: The python-social-auth pipeline being used to
authenticate a user.
Returns:
An instance of ProviderConfig or None.
"""
for enabled in cls._enabled_providers():
if enabled.is_active_for_pipeline(running_pipeline):
return enabled
@classmethod
def get_enabled_by_backend_name(cls, backend_name):
"""Generator returning all enabled providers that use the specified
backend on the current site.
Example:
>>> list(get_enabled_by_backend_name("tpa-saml"))
[<SAMLProviderConfig>, <SAMLProviderConfig>]
Args:
backend_name: The name of a python-social-auth backend used by
one or more providers.
Yields:
Instances of ProviderConfig.
"""
if backend_name in _PSA_OAUTH2_BACKENDS:
provider = OAuth2ProviderConfig.current(backend_name)
if provider.enabled_for_current_site:
yield provider
elif backend_name in _PSA_SAML_BACKENDS and SAMLConfiguration.is_enabled(
Site.objects.get_current(get_current_request())):
idp_names = SAMLProviderConfig.key_values('idp_slug', flat=True)
for idp_name in idp_names:
provider = SAMLProviderConfig.current(idp_name)
if provider.backend_name == backend_name and provider.enabled_for_current_site:
yield provider
elif backend_name in _LTI_BACKENDS:
for consumer_key in LTIProviderConfig.key_values('lti_consumer_key', flat=True):
provider = LTIProviderConfig.current(consumer_key)
if provider.backend_name == backend_name and provider.enabled_for_current_site:
yield provider