Asynchronous metadata fetching using celery beat - PR 8518
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -91,6 +91,9 @@ logs
|
||||
chromedriver.log
|
||||
ghostdriver.log
|
||||
|
||||
### Celery artifacts ###
|
||||
celerybeat-schedule
|
||||
|
||||
### Unknown artifacts
|
||||
database.sqlite
|
||||
courseware/static/js/mathjax/*
|
||||
|
||||
@@ -7,6 +7,7 @@ from django.contrib import admin
|
||||
|
||||
from config_models.admin import ConfigurationModelAdmin, KeyedConfigurationModelAdmin
|
||||
from .models import OAuth2ProviderConfig, SAMLProviderConfig, SAMLConfiguration, SAMLProviderData
|
||||
from .tasks import fetch_saml_metadata
|
||||
|
||||
admin.site.register(OAuth2ProviderConfig, KeyedConfigurationModelAdmin)
|
||||
|
||||
@@ -29,6 +30,17 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin):
|
||||
has_data.short_description = u'Metadata Ready'
|
||||
has_data.boolean = True
|
||||
|
||||
def save_model(self, request, obj, form, change):
|
||||
"""
|
||||
Post save: Queue an asynchronous metadata fetch to update SAMLProviderData.
|
||||
We only want to do this for manual edits done using the admin interface.
|
||||
|
||||
Note: This only works if the celery worker and the app worker are using the
|
||||
same 'configuration' cache.
|
||||
"""
|
||||
super(SAMLProviderConfigAdmin, self).save_model(request, obj, form, change)
|
||||
fetch_saml_metadata.apply_async((), countdown=2)
|
||||
|
||||
admin.site.register(SAMLProviderConfig, SAMLProviderConfigAdmin)
|
||||
|
||||
|
||||
@@ -54,7 +66,7 @@ admin.site.register(SAMLConfiguration, SAMLConfigurationAdmin)
|
||||
|
||||
|
||||
class SAMLProviderDataAdmin(admin.ModelAdmin):
|
||||
""" Django Admin class for SAMLProviderData """
|
||||
""" Django Admin class for SAMLProviderData (Read Only) """
|
||||
list_display = ('entity_id', 'is_valid', 'fetched_at', 'expires_at', 'sso_url')
|
||||
readonly_fields = ('is_valid', )
|
||||
|
||||
|
||||
@@ -2,20 +2,10 @@
|
||||
"""
|
||||
Management commands for third_party_auth
|
||||
"""
|
||||
import datetime
|
||||
import dateutil.parser
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from lxml import etree
|
||||
import requests
|
||||
from onelogin.saml2.utils import OneLogin_Saml2_Utils
|
||||
from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData
|
||||
|
||||
#pylint: disable=superfluous-parens,no-member
|
||||
|
||||
|
||||
class MetadataParseError(Exception):
|
||||
""" An error occurred while parsing the SAML metadata from an IdP """
|
||||
pass
|
||||
import logging
|
||||
from third_party_auth.models import SAMLConfiguration
|
||||
from third_party_auth.tasks import fetch_saml_metadata
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
@@ -27,120 +17,21 @@ class Command(BaseCommand):
|
||||
raise CommandError("saml requires one argument: pull")
|
||||
|
||||
if not SAMLConfiguration.is_enabled():
|
||||
self.stdout.write("Warning: SAML support is disabled via SAMLConfiguration.\n")
|
||||
raise CommandError("SAML support is disabled via SAMLConfiguration.")
|
||||
|
||||
subcommand = args[0]
|
||||
|
||||
if subcommand == "pull":
|
||||
self.cmd_pull()
|
||||
log_handler = logging.StreamHandler(self.stdout)
|
||||
log_handler.setLevel(logging.DEBUG)
|
||||
log = logging.getLogger('third_party_auth.tasks')
|
||||
log.propagate = False
|
||||
log.addHandler(log_handler)
|
||||
num_changed, num_failed, num_total = fetch_saml_metadata()
|
||||
self.stdout.write(
|
||||
"\nDone. Fetched {num_total} total. {num_changed} were updated and {num_failed} failed.\n".format(
|
||||
num_changed=num_changed, num_failed=num_failed, num_total=num_total
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise CommandError("Unknown argment: {}".format(subcommand))
|
||||
|
||||
@staticmethod
|
||||
def tag_name(tag_name):
|
||||
""" Get the namespaced-qualified name for an XML tag """
|
||||
return '{urn:oasis:names:tc:SAML:2.0:metadata}' + tag_name
|
||||
|
||||
def cmd_pull(self):
|
||||
""" Fetch the metadata for each provider and update the DB """
|
||||
# First make a list of all the metadata XML URLs:
|
||||
url_map = {}
|
||||
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
|
||||
config = SAMLProviderConfig.current(idp_slug)
|
||||
if not config.enabled:
|
||||
continue
|
||||
url = config.metadata_source
|
||||
if url not in url_map:
|
||||
url_map[url] = []
|
||||
if config.entity_id not in url_map[url]:
|
||||
url_map[url].append(config.entity_id)
|
||||
# Now fetch the metadata:
|
||||
for url, entity_ids in url_map.items():
|
||||
try:
|
||||
self.stdout.write("\n→ Fetching {}\n".format(url))
|
||||
if not url.lower().startswith('https'):
|
||||
self.stdout.write("→ WARNING: This URL is not secure! It should use HTTPS.\n")
|
||||
response = requests.get(url, verify=True) # May raise HTTPError or SSLError or ConnectionError
|
||||
response.raise_for_status() # May raise an HTTPError
|
||||
|
||||
try:
|
||||
parser = etree.XMLParser(remove_comments=True)
|
||||
xml = etree.fromstring(response.text, parser)
|
||||
except etree.XMLSyntaxError:
|
||||
raise
|
||||
# TODO: Can use OneLogin_Saml2_Utils to validate signed XML if anyone is using that
|
||||
|
||||
for entity_id in entity_ids:
|
||||
self.stdout.write("→ Processing IdP with entityID {}\n".format(entity_id))
|
||||
public_key, sso_url, expires_at = self._parse_metadata_xml(xml, entity_id)
|
||||
self._update_data(entity_id, public_key, sso_url, expires_at)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self.stderr.write(u"→ ERROR: {}\n\n".format(err.message))
|
||||
|
||||
@classmethod
|
||||
def _parse_metadata_xml(cls, xml, entity_id):
|
||||
"""
|
||||
Given an XML document containing SAML 2.0 metadata, parse it and return a tuple of
|
||||
(public_key, sso_url, expires_at) for the specified entityID.
|
||||
|
||||
Raises MetadataParseError if anything is wrong.
|
||||
"""
|
||||
if xml.tag == cls.tag_name('EntityDescriptor'):
|
||||
entity_desc = xml
|
||||
else:
|
||||
if xml.tag != cls.tag_name('EntitiesDescriptor'):
|
||||
raise MetadataParseError("Expected root element to be <EntitiesDescriptor>, not {}".format(xml.tag))
|
||||
entity_desc = xml.find(".//{}[@entityID='{}']".format(cls.tag_name('EntityDescriptor'), entity_id))
|
||||
if not entity_desc:
|
||||
raise MetadataParseError("Can't find EntityDescriptor for entityID {}".format(entity_id))
|
||||
|
||||
expires_at = None
|
||||
if "validUntil" in xml.attrib:
|
||||
expires_at = dateutil.parser.parse(xml.attrib["validUntil"])
|
||||
if "cacheDuration" in xml.attrib:
|
||||
cache_expires = OneLogin_Saml2_Utils.parse_duration(xml.attrib["cacheDuration"])
|
||||
if expires_at is None or cache_expires < expires_at:
|
||||
expires_at = cache_expires
|
||||
|
||||
sso_desc = entity_desc.find(cls.tag_name("IDPSSODescriptor"))
|
||||
if not sso_desc:
|
||||
raise MetadataParseError("IDPSSODescriptor missing")
|
||||
if 'urn:oasis:names:tc:SAML:2.0:protocol' not in sso_desc.get("protocolSupportEnumeration"):
|
||||
raise MetadataParseError("This IdP does not support SAML 2.0")
|
||||
|
||||
# Now we just need to get the public_key and sso_url
|
||||
public_key = sso_desc.findtext("./{}//{}".format(
|
||||
cls.tag_name("KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate"
|
||||
))
|
||||
if not public_key:
|
||||
raise MetadataParseError("Public Key missing. Expected an <X509Certificate>")
|
||||
public_key = public_key.replace(" ", "")
|
||||
binding_elements = sso_desc.iterfind("./{}".format(cls.tag_name("SingleSignOnService")))
|
||||
sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements}
|
||||
try:
|
||||
# The only binding supported by python-saml and python-social-auth is HTTP-Redirect:
|
||||
sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect']
|
||||
except KeyError:
|
||||
raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.")
|
||||
return public_key, sso_url, expires_at
|
||||
|
||||
def _update_data(self, entity_id, public_key, sso_url, expires_at):
|
||||
"""
|
||||
Update/Create the SAMLProviderData for the given entity ID.
|
||||
"""
|
||||
data_obj = SAMLProviderData.current(entity_id)
|
||||
fetched_at = datetime.datetime.now()
|
||||
if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url):
|
||||
data_obj.expires_at = expires_at
|
||||
data_obj.fetched_at = fetched_at
|
||||
data_obj.save()
|
||||
self.stdout.write("→ Updated existing SAMLProviderData. Nothing has changed.\n")
|
||||
else:
|
||||
SAMLProviderData.objects.create(
|
||||
entity_id=entity_id,
|
||||
fetched_at=fetched_at,
|
||||
expires_at=expires_at,
|
||||
sso_url=sso_url,
|
||||
public_key=public_key,
|
||||
)
|
||||
self.stdout.write("→ Created new record for SAMLProviderData\n")
|
||||
|
||||
157
common/djangoapps/third_party_auth/tasks.py
Normal file
157
common/djangoapps/third_party_auth/tasks.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Code to manage fetching and storing the metadata of IdPs.
|
||||
"""
|
||||
#pylint: disable=no-member
|
||||
from celery.task import task # pylint: disable=import-error,no-name-in-module
|
||||
import datetime
|
||||
import dateutil.parser
|
||||
import logging
|
||||
from lxml import etree
|
||||
import requests
|
||||
from onelogin.saml2.utils import OneLogin_Saml2_Utils
|
||||
from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
SAML_XML_NS = 'urn:oasis:names:tc:SAML:2.0:metadata' # The SAML Metadata XML namespace
|
||||
|
||||
|
||||
class MetadataParseError(Exception):
|
||||
""" An error occurred while parsing the SAML metadata from an IdP """
|
||||
pass
|
||||
|
||||
|
||||
@task(name='third_party_auth.fetch_saml_metadata')
|
||||
def fetch_saml_metadata():
|
||||
"""
|
||||
Fetch and store/update the metadata of all IdPs
|
||||
|
||||
This task should be run on a daily basis.
|
||||
It's OK to run this whether or not SAML is enabled.
|
||||
|
||||
Return value:
|
||||
tuple(num_changed, num_failed, num_total)
|
||||
num_changed: Number of providers that are either new or whose metadata has changed
|
||||
num_failed: Number of providers that could not be updated
|
||||
num_total: Total number of providers whose metadata was fetched
|
||||
"""
|
||||
if not SAMLConfiguration.is_enabled():
|
||||
return (0, 0, 0) # Nothing to do until SAML is enabled.
|
||||
|
||||
num_changed, num_failed = 0, 0
|
||||
|
||||
# First make a list of all the metadata XML URLs:
|
||||
url_map = {}
|
||||
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
|
||||
config = SAMLProviderConfig.current(idp_slug)
|
||||
if not config.enabled:
|
||||
continue
|
||||
url = config.metadata_source
|
||||
if url not in url_map:
|
||||
url_map[url] = []
|
||||
if config.entity_id not in url_map[url]:
|
||||
url_map[url].append(config.entity_id)
|
||||
# Now fetch the metadata:
|
||||
for url, entity_ids in url_map.items():
|
||||
try:
|
||||
log.info("Fetching %s", url)
|
||||
if not url.lower().startswith('https'):
|
||||
log.warning("This SAML metadata URL is not secure! It should use HTTPS. (%s)", url)
|
||||
response = requests.get(url, verify=True) # May raise HTTPError or SSLError or ConnectionError
|
||||
response.raise_for_status() # May raise an HTTPError
|
||||
|
||||
try:
|
||||
parser = etree.XMLParser(remove_comments=True)
|
||||
xml = etree.fromstring(response.text, parser)
|
||||
except etree.XMLSyntaxError:
|
||||
raise
|
||||
# TODO: Can use OneLogin_Saml2_Utils to validate signed XML if anyone is using that
|
||||
|
||||
for entity_id in entity_ids:
|
||||
log.info(u"Processing IdP with entityID %s", entity_id)
|
||||
public_key, sso_url, expires_at = _parse_metadata_xml(xml, entity_id)
|
||||
changed = _update_data(entity_id, public_key, sso_url, expires_at)
|
||||
if changed:
|
||||
log.info(u"→ Created new record for SAMLProviderData")
|
||||
num_changed += 1
|
||||
else:
|
||||
log.info(u"→ Updated existing SAMLProviderData. Nothing has changed.")
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
log.exception(err.message)
|
||||
num_failed += 1
|
||||
return (num_changed, num_failed, len(url_map))
|
||||
|
||||
|
||||
def _parse_metadata_xml(xml, entity_id):
|
||||
"""
|
||||
Given an XML document containing SAML 2.0 metadata, parse it and return a tuple of
|
||||
(public_key, sso_url, expires_at) for the specified entityID.
|
||||
|
||||
Raises MetadataParseError if anything is wrong.
|
||||
"""
|
||||
if xml.tag == etree.QName(SAML_XML_NS, 'EntityDescriptor'):
|
||||
entity_desc = xml
|
||||
else:
|
||||
if xml.tag != etree.QName(SAML_XML_NS, 'EntitiesDescriptor'):
|
||||
raise MetadataParseError("Expected root element to be <EntitiesDescriptor>, not {}".format(xml.tag))
|
||||
entity_desc = xml.find(
|
||||
".//{}[@entityID='{}']".format(etree.QName(SAML_XML_NS, 'EntityDescriptor'), entity_id)
|
||||
)
|
||||
if not entity_desc:
|
||||
raise MetadataParseError("Can't find EntityDescriptor for entityID {}".format(entity_id))
|
||||
|
||||
expires_at = None
|
||||
if "validUntil" in xml.attrib:
|
||||
expires_at = dateutil.parser.parse(xml.attrib["validUntil"])
|
||||
if "cacheDuration" in xml.attrib:
|
||||
cache_expires = OneLogin_Saml2_Utils.parse_duration(xml.attrib["cacheDuration"])
|
||||
if expires_at is None or cache_expires < expires_at:
|
||||
expires_at = cache_expires
|
||||
|
||||
sso_desc = entity_desc.find(etree.QName(SAML_XML_NS, "IDPSSODescriptor"))
|
||||
if not sso_desc:
|
||||
raise MetadataParseError("IDPSSODescriptor missing")
|
||||
if 'urn:oasis:names:tc:SAML:2.0:protocol' not in sso_desc.get("protocolSupportEnumeration"):
|
||||
raise MetadataParseError("This IdP does not support SAML 2.0")
|
||||
|
||||
# Now we just need to get the public_key and sso_url
|
||||
public_key = sso_desc.findtext("./{}//{}".format(
|
||||
etree.QName(SAML_XML_NS, "KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate"
|
||||
))
|
||||
if not public_key:
|
||||
raise MetadataParseError("Public Key missing. Expected an <X509Certificate>")
|
||||
public_key = public_key.replace(" ", "")
|
||||
binding_elements = sso_desc.iterfind("./{}".format(etree.QName(SAML_XML_NS, "SingleSignOnService")))
|
||||
sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements}
|
||||
try:
|
||||
# The only binding supported by python-saml and python-social-auth is HTTP-Redirect:
|
||||
sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect']
|
||||
except KeyError:
|
||||
raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.")
|
||||
return public_key, sso_url, expires_at
|
||||
|
||||
|
||||
def _update_data(entity_id, public_key, sso_url, expires_at):
|
||||
"""
|
||||
Update/Create the SAMLProviderData for the given entity ID.
|
||||
Return value:
|
||||
False if nothing has changed and existing data's "fetched at" timestamp is just updated.
|
||||
True if a new record was created. (Either this is a new provider or something changed.)
|
||||
"""
|
||||
data_obj = SAMLProviderData.current(entity_id)
|
||||
fetched_at = datetime.datetime.now()
|
||||
if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url):
|
||||
data_obj.expires_at = expires_at
|
||||
data_obj.fetched_at = fetched_at
|
||||
data_obj.save()
|
||||
return False
|
||||
else:
|
||||
SAMLProviderData.objects.create(
|
||||
entity_id=entity_id,
|
||||
fetched_at=fetched_at,
|
||||
expires_at=expires_at,
|
||||
sso_url=sso_url,
|
||||
public_key=public_key,
|
||||
)
|
||||
return True
|
||||
@@ -1,12 +1,11 @@
|
||||
"""
|
||||
Third_party_auth integration tests using a mock version of the TestShib provider
|
||||
"""
|
||||
from django.core.management import call_command
|
||||
from django.core.urlresolvers import reverse
|
||||
import httpretty
|
||||
from mock import patch
|
||||
import StringIO
|
||||
from student.tests.factories import UserFactory
|
||||
from third_party_auth.tasks import fetch_saml_metadata
|
||||
from third_party_auth.tests import testutil
|
||||
import unittest
|
||||
|
||||
@@ -209,15 +208,11 @@ class TestShibIntegrationTest(testutil.SAMLTestCase):
|
||||
self.configure_saml_provider(**kwargs)
|
||||
|
||||
if fetch_metadata:
|
||||
stdout = StringIO.StringIO()
|
||||
stderr = StringIO.StringIO()
|
||||
self.assertTrue(httpretty.is_enabled())
|
||||
call_command('saml', 'pull', stdout=stdout, stderr=stderr)
|
||||
stdout = stdout.getvalue().decode('utf-8')
|
||||
stderr = stderr.getvalue().decode('utf-8')
|
||||
self.assertEqual(stderr, '')
|
||||
self.assertIn(u'Fetching {}'.format(TESTSHIB_METADATA_URL), stdout)
|
||||
self.assertIn(u'Created new record for SAMLProviderData', stdout)
|
||||
num_changed, num_failed, num_total = fetch_saml_metadata()
|
||||
self.assertEqual(num_failed, 0)
|
||||
self.assertEqual(num_changed, 1)
|
||||
self.assertEqual(num_total, 1)
|
||||
|
||||
def _fake_testshib_login_and_return(self):
|
||||
""" Mocked: the user logs in to TestShib and then gets redirected back """
|
||||
|
||||
@@ -8,7 +8,7 @@ import unittest
|
||||
from .testutil import AUTH_FEATURE_ENABLED, SAMLTestCase
|
||||
|
||||
# Define some XML namespaces:
|
||||
SAML_XML_NS = 'urn:oasis:names:tc:SAML:2.0:metadata'
|
||||
from third_party_auth.tasks import SAML_XML_NS
|
||||
XMLDSIG_XML_NS = 'http://www.w3.org/2000/09/xmldsig#'
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ Common traits:
|
||||
# and throws spurious errors. Therefore, we disable invalid-name checking.
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from .common import *
|
||||
@@ -107,6 +108,7 @@ CELERY_QUEUES = {
|
||||
if os.environ.get('QUEUE') == 'high_mem':
|
||||
CELERYD_MAX_TASKS_PER_CHILD = 1
|
||||
|
||||
CELERYBEAT_SCHEDULE = {} # For scheduling tasks, entries can be added to this dict
|
||||
|
||||
########################## NON-SECURE ENV CONFIG ##############################
|
||||
# Things like server locations, ports, etc.
|
||||
@@ -552,6 +554,12 @@ if FEATURES.get('ENABLE_THIRD_PARTY_AUTH'):
|
||||
# third_party_auth config moved to ConfigurationModels. This is for data migration only:
|
||||
THIRD_PARTY_AUTH_OLD_CONFIG = AUTH_TOKENS.get('THIRD_PARTY_AUTH', None)
|
||||
|
||||
if ENV_TOKENS.get('THIRD_PARTY_AUTH_SAML_FETCH_PERIOD_HOURS', 24) is not None:
|
||||
CELERYBEAT_SCHEDULE['refresh-saml-metadata'] = {
|
||||
'task': 'third_party_auth.fetch_saml_metadata',
|
||||
'schedule': datetime.timedelta(hours=ENV_TOKENS.get('THIRD_PARTY_AUTH_SAML_FETCH_PERIOD_HOURS', 24)),
|
||||
}
|
||||
|
||||
##### OAUTH2 Provider ##############
|
||||
if FEATURES.get('ENABLE_OAUTH2_PROVIDER'):
|
||||
OAUTH_OIDC_ISSUER = ENV_TOKENS['OAUTH_OIDC_ISSUER']
|
||||
|
||||
@@ -109,7 +109,7 @@ def celery(options):
|
||||
Runs Celery workers.
|
||||
"""
|
||||
settings = getattr(options, 'settings', 'dev_with_worker')
|
||||
run_process(django_cmd('lms', settings, 'celery', 'worker', '--loglevel=INFO', '--pythonpath=.'))
|
||||
run_process(django_cmd('lms', settings, 'celery', 'worker', '--beat', '--loglevel=INFO', '--pythonpath=.'))
|
||||
|
||||
|
||||
@task
|
||||
@@ -142,7 +142,7 @@ def run_all_servers(options):
|
||||
run_multi_processes([
|
||||
django_cmd('lms', settings_lms, 'runserver', '--traceback', '--pythonpath=.', "0.0.0.0:{}".format(DEFAULT_PORT['lms'])),
|
||||
django_cmd('studio', settings_cms, 'runserver', '--traceback', '--pythonpath=.', "0.0.0.0:{}".format(DEFAULT_PORT['studio'])),
|
||||
django_cmd('lms', worker_settings, 'celery', 'worker', '--loglevel=INFO', '--pythonpath=.')
|
||||
django_cmd('lms', worker_settings, 'celery', 'worker', '--beat', '--loglevel=INFO', '--pythonpath=.')
|
||||
])
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user