Merge pull request #17873 from edx/bexline/refactor_tpa_models

ENT-943 Refactoring third_party_auth models
This commit is contained in:
Brittney Exline
2018-04-05 15:02:56 -06:00
committed by GitHub
13 changed files with 97 additions and 23 deletions

View File

@@ -37,7 +37,7 @@ class OAuth2ProviderConfigAdmin(KeyedConfigurationModelAdmin):
def get_list_display(self, request):
""" Don't show every single field in the admin change list """
return (
'name', 'enabled', 'provider_slug', 'site', 'backend_name', 'secondary', 'skip_registration_form',
'name', 'enabled', 'slug', 'site', 'backend_name', 'secondary', 'skip_registration_form',
'skip_email_verification', 'change_date', 'changed_by', 'edit_link',
)

View File

@@ -54,7 +54,7 @@ class TpaAPITestCase(ThirdPartyAuthTestMixin, APITestCase):
testshib = self.configure_saml_provider(
name='TestShib',
enabled=True,
idp_slug=IDP_SLUG_TESTSHIB
slug=IDP_SLUG_TESTSHIB
)
# Create several users and link each user to Google and TestShib
@@ -75,7 +75,7 @@ class TpaAPITestCase(ThirdPartyAuthTestMixin, APITestCase):
UserSocialAuth.objects.create(
user=user,
provider=testshib.backend_name,
uid='{}:remote_{}'.format(testshib.idp_slug, username),
uid='{}:remote_{}'.format(testshib.slug, username),
)
# Create another user not linked to any providers:
UserFactory.create(username=CARL_USERNAME, password=PASSWORD)
@@ -238,7 +238,7 @@ class UserMappingViewAPITests(TpaAPITestCase):
self._verify_response(response, expect_code, expect_data)
def test_user_mappings_only_return_requested_idp_mapping_by_provider_id(self):
testshib2 = self.configure_saml_provider(name='TestShib2', enabled=True, idp_slug='testshib2')
testshib2 = self.configure_saml_provider(name='TestShib2', enabled=True, slug='testshib2')
username = 'testshib2user'
user = UserFactory.create(
username=username,
@@ -249,7 +249,7 @@ class UserMappingViewAPITests(TpaAPITestCase):
UserSocialAuth.objects.create(
user=user,
provider=testshib2.backend_name,
uid='{}:{}'.format(testshib2.idp_slug, username),
uid='{}:{}'.format(testshib2.slug, username),
)
url = reverse('third_party_auth_user_mapping_api', kwargs={'provider_id': PROVIDER_ID_TESTSHIB})

View File

@@ -153,6 +153,7 @@ class TestSAMLCommand(TestCase):
"site__domain": "second.testserver.fake",
"site__name": "testserver.fake",
"idp_slug": "second-test-shib",
"slug": "second-test-shib",
"entity_id": "https://idp.testshib.org/idp/another-shibboleth",
"metadata_source": "https://www.testshib.org/metadata/another-testshib-providers.xml",
}
@@ -168,6 +169,7 @@ class TestSAMLCommand(TestCase):
"site__domain": "third.testserver.fake",
"site__name": "testserver.fake",
"idp_slug": "third-test-shib",
"slug": "third-test-shib",
# Note: This entity id will not be present in returned response and will cause failed update.
"entity_id": "https://idp.testshib.org/idp/non-existent-shibboleth",
"metadata_source": "https://www.testshib.org/metadata/third/testshib-providers.xml",
@@ -189,6 +191,7 @@ class TestSAMLCommand(TestCase):
"site__domain": "fourth.testserver.fake",
"site__name": "testserver.fake",
"idp_slug": "fourth-test-shib",
"slug": "fourth-test-shib",
"automatic_refresh_enabled": False,
# Note: This invalid entity id will not be present in the refresh set
"entity_id": "https://idp.testshib.org/idp/fourth-shibboleth",
@@ -243,6 +246,7 @@ class TestSAMLCommand(TestCase):
saml_provider_config={
"site__domain": "third.testserver.fake",
"idp_slug": "third-test-shib",
"slug": "third-test-shib",
# Note: This entity id will not be present in returned response and will cause failed update.
"entity_id": "https://idp.testshib.org/idp/non-existent-shibboleth",
"metadata_source": "https://www.testshib.org/metadata/third/testshib-providers.xml",

View File

@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
"""
Custom migration script to add slug field to all ProviderConfig models.
"""
from __future__ import unicode_literals
from django.db import migrations, models
from django.utils.text import slugify
def fill_slug_field(apps, schema_editor):
"""
Fill in the slug field for each ProviderConfig class for backwards compatability.
"""
OAuth2ProviderConfig = apps.get_model('third_party_auth', 'OAuth2ProviderConfig')
SAMLProviderConfig = apps.get_model('third_party_auth', 'SAMLProviderConfig')
LTIProviderConfig = apps.get_model('third_party_auth', 'LTIProviderConfig')
for config in OAuth2ProviderConfig.objects.all():
config.slug = config.provider_slug
config.save()
for config in SAMLProviderConfig.objects.all():
config.slug = config.idp_slug
config.save()
for config in LTIProviderConfig.objects.all():
config.slug = slugify(config.lti_consumer_key)
config.save()
class Migration(migrations.Migration):
dependencies = [
('third_party_auth', '0018_auto_20180327_1631'),
]
operations = [
migrations.AddField(
model_name='ltiproviderconfig',
name='slug',
field=models.SlugField(default=b'default', help_text=b'A short string uniquely identifying this provider. Cannot contain spaces and should be a usable as a CSS class. Examples: "ubc", "mit-staging"', max_length=30),
),
migrations.AddField(
model_name='oauth2providerconfig',
name='slug',
field=models.SlugField(default=b'default', help_text=b'A short string uniquely identifying this provider. Cannot contain spaces and should be a usable as a CSS class. Examples: "ubc", "mit-staging"', max_length=30),
),
migrations.AddField(
model_name='samlproviderconfig',
name='slug',
field=models.SlugField(default=b'default', help_text=b'A short string uniquely identifying this provider. Cannot contain spaces and should be a usable as a CSS class. Examples: "ubc", "mit-staging"', max_length=30),
),
migrations.RunPython(fill_slug_field, reverse_code=migrations.RunPython.noop),
]

View File

@@ -88,6 +88,8 @@ class ProviderConfig(ConfigurationModel):
"""
Abstract Base Class for configuring a third_party_auth provider
"""
KEY_FIELDS = ('slug',)
icon_class = models.CharField(
max_length=50,
blank=True,
@@ -109,6 +111,12 @@ class ProviderConfig(ConfigurationModel):
),
)
name = models.CharField(max_length=50, blank=False, help_text="Name of this provider (shown to users)")
slug = models.SlugField(
max_length=30, db_index=True, default='default',
help_text=(
'A short string uniquely identifying this provider. '
'Cannot contain spaces and should be a usable as a CSS class. Examples: "ubc", "mit-staging"'
))
secondary = models.BooleanField(
default=False,
help_text=_(
@@ -312,7 +320,6 @@ class OAuth2ProviderConfig(ProviderConfig):
Also works for OAuth1 providers.
"""
prefix = 'oa2'
KEY_FIELDS = ('provider_slug', ) # Backend name is unique
backend_name = models.CharField(
max_length=50, blank=False, db_index=True,
help_text=(
@@ -514,7 +521,6 @@ class SAMLProviderConfig(ProviderConfig):
Configuration Entry for a SAML/Shibboleth provider.
"""
prefix = 'saml'
KEY_FIELDS = ('idp_slug', )
backend_name = models.CharField(
max_length=50, default='tpa-saml', blank=False,
help_text="Which python-social-auth provider backend to use. 'tpa-saml' is the standard edX SAML backend.")
@@ -603,26 +609,26 @@ class SAMLProviderConfig(ProviderConfig):
def get_url_params(self):
""" Get a dict of GET parameters to append to login links for this provider """
return {'idp': self.idp_slug}
return {'idp': self.slug}
def is_active_for_pipeline(self, pipeline):
""" Is this provider being used for the specified pipeline? """
return self.backend_name == pipeline['backend'] and self.idp_slug == pipeline['kwargs']['response']['idp_name']
return self.backend_name == pipeline['backend'] and self.slug == pipeline['kwargs']['response']['idp_name']
def match_social_auth(self, social_auth):
""" Is this provider being used for this UserSocialAuth entry? """
prefix = self.idp_slug + ":"
prefix = self.slug + ":"
return self.backend_name == social_auth.provider and social_auth.uid.startswith(prefix)
def get_remote_id_from_social_auth(self, social_auth):
""" Given a UserSocialAuth object, return the remote ID used by this provider. """
assert self.match_social_auth(social_auth)
# Remove the prefix from the UID
return social_auth.uid[len(self.idp_slug) + 1:]
return social_auth.uid[len(self.slug) + 1:]
def get_social_auth_uid(self, remote_id):
""" Get social auth uid from remote id by prepending idp_slug to the remote id """
return '{}:{}'.format(self.idp_slug, remote_id)
return '{}:{}'.format(self.slug, remote_id)
def get_config(self):
"""
@@ -648,7 +654,7 @@ class SAMLProviderConfig(ProviderConfig):
log.error(
'No SAMLProviderData found for provider "%s" with entity id "%s" and IdP slug "%s". '
'Run "manage.py saml pull" to fix or debug.',
self.name, self.entity_id, self.idp_slug
self.name, self.entity_id, self.slug
)
raise AuthNotConfigured(provider_name=self.name)
conf['x509cert'] = data.public_key
@@ -660,7 +666,7 @@ class SAMLProviderConfig(ProviderConfig):
SAMLConfiguration.current(self.site.id, 'default')
)
idp_class = get_saml_idp_class(self.identity_provider_type)
return idp_class(self.idp_slug, **conf)
return idp_class(self.slug, **conf)
class SAMLProviderData(models.Model):

View File

@@ -28,13 +28,13 @@ class Registry(object):
Helper method that returns a generator used to iterate over all providers
of the current site.
"""
oauth2_slugs = OAuth2ProviderConfig.key_values('provider_slug', flat=True)
oauth2_slugs = OAuth2ProviderConfig.key_values('slug', flat=True)
for oauth2_slug in oauth2_slugs:
provider = OAuth2ProviderConfig.current(oauth2_slug)
if provider.enabled_for_current_site and provider.backend_name in _PSA_OAUTH2_BACKENDS:
yield provider
if SAMLConfiguration.is_enabled(Site.objects.get_current(get_current_request()), 'default'):
idp_slugs = SAMLProviderConfig.key_values('idp_slug', flat=True)
idp_slugs = SAMLProviderConfig.key_values('slug', flat=True)
for idp_slug in idp_slugs:
provider = SAMLProviderConfig.current(idp_slug)
if provider.enabled_for_current_site and provider.backend_name in _PSA_SAML_BACKENDS:
@@ -112,14 +112,14 @@ class Registry(object):
Instances of ProviderConfig.
"""
if backend_name in _PSA_OAUTH2_BACKENDS:
oauth2_slugs = OAuth2ProviderConfig.key_values('provider_slug', flat=True)
oauth2_slugs = OAuth2ProviderConfig.key_values('slug', flat=True)
for oauth2_slug in oauth2_slugs:
provider = OAuth2ProviderConfig.current(oauth2_slug)
if provider.backend_name == backend_name and provider.enabled_for_current_site:
yield provider
elif backend_name in _PSA_SAML_BACKENDS and SAMLConfiguration.is_enabled(
Site.objects.get_current(get_current_request()), 'default'):
idp_names = SAMLProviderConfig.key_values('idp_slug', flat=True)
idp_names = SAMLProviderConfig.key_values('slug', flat=True)
for idp_name in idp_names:
provider = SAMLProviderConfig.current(idp_name)
if provider.backend_name == backend_name and provider.enabled_for_current_site:

View File

@@ -47,7 +47,7 @@ def fetch_saml_metadata():
"""
# First make a list of all the metadata XML URLs:
saml_providers = SAMLProviderConfig.key_values('idp_slug', flat=True)
saml_providers = SAMLProviderConfig.key_values('slug', flat=True)
num_total = len(saml_providers)
num_skipped = 0
url_map = {}

View File

@@ -25,12 +25,13 @@ class SAMLProviderConfigFactory(DjangoModelFactory):
"""
class Meta(object):
model = SAMLProviderConfig
django_get_or_create = ('idp_slug', 'metadata_source', "entity_id")
django_get_or_create = ('slug', 'metadata_source', "entity_id")
site = SubFactory(SiteFactory)
enabled = True
idp_slug = "test-shib"
slug = "test-shib"
name = "TestShib College"
entity_id = "https://idp.testshib.org/idp/shibboleth"

View File

@@ -92,6 +92,7 @@ class SamlIntegrationTestUtilities(object):
kwargs.setdefault('enabled', True)
kwargs.setdefault('visible', True)
kwargs.setdefault('idp_slug', self.PROVIDER_IDP_SLUG)
kwargs.setdefault('slug', self.PROVIDER_IDP_SLUG)
kwargs.setdefault('entity_id', TESTSHIB_ENTITY_ID)
kwargs.setdefault('metadata_source', TESTSHIB_METADATA_URL)
kwargs.setdefault('icon_class', 'fa-university')
@@ -200,6 +201,7 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
kwargs.setdefault('enabled', True)
kwargs.setdefault('visible', True)
kwargs.setdefault('idp_slug', self.PROVIDER_IDP_SLUG)
kwargs.setdefault('slug', self.PROVIDER_IDP_SLUG)
kwargs.setdefault('entity_id', TESTSHIB_ENTITY_ID)
kwargs.setdefault('metadata_source', TESTSHIB_METADATA_URL_WITH_CACHE_DURATION)
kwargs.setdefault('icon_class', 'fa-university')

View File

@@ -59,6 +59,7 @@ class RegistryTest(testutil.TestCase):
enabled=True,
name="Disallowed",
idp_slug="test",
slug="test",
backend_name="disallowed"
)
self.assertEqual(len(provider.Registry.enabled()), 0)
@@ -150,12 +151,12 @@ class RegistryTest(testutil.TestCase):
google_provider = self.configure_google_provider(enabled=True)
self.assertEqual(google_provider.id, provider.Registry.get(google_provider.provider_id).id)
def test_oauth2_provider_keyed_by_provider_slug(self):
def test_oauth2_provider_keyed_by_slug(self):
"""
Regression test to ensure that the Registry properly fetches OAuth2ProviderConfigs that have a provider_slug
Regression test to ensure that the Registry properly fetches OAuth2ProviderConfigs that have a slug
which doesn't match any of the possible backend_names.
"""
google_provider = self.configure_google_provider(enabled=True, provider_slug='custom_slug')
google_provider = self.configure_google_provider(enabled=True, slug='custom_slug')
self.assertIn(google_provider, provider.Registry._enabled_providers())
self.assertIn(google_provider, provider.Registry.get_enabled_by_backend_name('google-oauth2'))

View File

@@ -79,6 +79,7 @@ class ThirdPartyAuthTestMixin(object):
def configure_oauth_provider(**kwargs):
""" Update the settings for an OAuth2-based third party auth provider """
kwargs.setdefault('provider_slug', kwargs['backend_name'])
kwargs.setdefault('slug', kwargs['backend_name'])
obj = OAuth2ProviderConfig(**kwargs)
obj.save()
return obj

View File

@@ -11,6 +11,7 @@
"icon_image": null,
"backend_name": "google-oauth2",
"provider_slug": "google-oauth2",
"slug": "google-oauth2",
"key": "test",
"secret": "test",
"site": 2,
@@ -30,6 +31,7 @@
"icon_image": null,
"backend_name": "facebook",
"provider_slug": "facebook",
"slug": "facebook",
"key": "test",
"secret": "test",
"site": 2,
@@ -49,6 +51,7 @@
"icon_image": "test-icon.png",
"backend_name": "dummy",
"provider_slug": "dummy",
"slug": "dummy",
"key": "",
"secret": "",
"site": 2,

View File

@@ -500,6 +500,7 @@ class StudentAccountLoginAndRegistrationTest(ThirdPartyAuthTestMixin, UrlResetMi
kwargs.setdefault('enabled', True)
kwargs.setdefault('visible', True)
kwargs.setdefault('idp_slug', idp_slug)
kwargs.setdefault('slug', idp_slug)
kwargs.setdefault('entity_id', 'https://idp.testshib.org/idp/shibboleth')
kwargs.setdefault('metadata_source', 'https://mock.testshib.org/metadata/testshib-providers.xml')
kwargs.setdefault('icon_class', 'fa-university')