Management command for rewriting social auth uids
This management command will help us support migrating an SSO provider from one configured user identifier to another, which may be necessary when previous choices of UID aren't as stable as needed. JIRA:EDUCATOR-4701
This commit is contained in:
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Management command to re-write UIDs identifying learners in external organizations' systems
|
||||
|
||||
Intented for use in production environments, to help support migration
|
||||
of existing SSO learners into our most-recent program enrollment flow
|
||||
without needing to manually re-link their account.
|
||||
"""
|
||||
from __future__ import absolute_import, unicode_literals
|
||||
|
||||
import logging
|
||||
from textwrap import dedent
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
"""
|
||||
Updates UserSocialAuth records to use UIDs provided in the supplied JSON file
|
||||
|
||||
Example usage:
|
||||
$ ./manage.py lms migrate_saml_uids.py \
|
||||
--uid-mapping=change@my.uid:4045A285AF596D8589C24841657CA3D8,me@too.uid:4045A285AF596D8589C24841657CA3D9 \
|
||||
--saml-provider-slug=default
|
||||
"""
|
||||
help = dedent(__doc__).strip()
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
'--uid-mapping',
|
||||
help='comma-separated list of email:uid mappings'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--saml-provider-slug',
|
||||
help='slug of SAMLProvider for which records should be updated'
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
"""
|
||||
Performs the re-writing
|
||||
"""
|
||||
User = get_user_model()
|
||||
pairs = options['uid_mapping'].split(',')
|
||||
slug = options['saml_provider_slug']
|
||||
|
||||
for pair in pairs:
|
||||
uid_list = pair.split(':')
|
||||
email = uid_list[0]
|
||||
uid = uid_list[1]
|
||||
user = User.objects.prefetch_related('social_auth').get(email=email)
|
||||
auth = user.social_auth.filter(uid__startswith=slug)[0]
|
||||
auth.uid = '{slug}:{uid}'.format(slug=slug, uid=uid)
|
||||
auth.save()
|
||||
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Tests for the migrate_saml_uids management command.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
from django.core.management import call_command
|
||||
from django.test import TestCase
|
||||
|
||||
from factory import LazyAttributeSequence, SubFactory
|
||||
from factory.django import DjangoModelFactory
|
||||
from lms.djangoapps.program_enrollments.management.commands import migrate_saml_uids
|
||||
from social_django.models import UserSocialAuth
|
||||
from student.tests.factories import UserFactory
|
||||
|
||||
|
||||
class UserSocialAuthFactory(DjangoModelFactory):
|
||||
"""
|
||||
Factory for UserSocialAuth records.
|
||||
"""
|
||||
class Meta(object):
|
||||
model = UserSocialAuth
|
||||
user = SubFactory(UserFactory)
|
||||
uid = LazyAttributeSequence(lambda o, n: '%s:%d' % (o.slug, n))
|
||||
|
||||
class Params(object):
|
||||
slug = 'gatech'
|
||||
|
||||
|
||||
class TestMigrateSamlUids(TestCase):
|
||||
"""
|
||||
Test migrate_saml_uids command.
|
||||
"""
|
||||
provider_slug = 'gatech'
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestMigrateSamlUids, cls).setUpClass()
|
||||
cls.command = migrate_saml_uids.Command()
|
||||
|
||||
def _format_email_uid_pair(self, email, uid):
|
||||
return '{email}:{uid}'.format(email=email, uid=uid)
|
||||
|
||||
def _format_slug_urn_pair(self, slug, urn):
|
||||
return '{slug}:{urn}'.format(slug=slug, urn=urn)
|
||||
|
||||
def test_single_mapping(self):
|
||||
new_urn = '9001'
|
||||
auth = UserSocialAuthFactory.create(slug=self.provider_slug)
|
||||
email = auth.user.email
|
||||
old_uid = auth.uid
|
||||
call_command(
|
||||
self.command,
|
||||
uid_mapping=self._format_email_uid_pair(email, new_urn),
|
||||
saml_provider_slug=self.provider_slug
|
||||
)
|
||||
|
||||
auth.refresh_from_db()
|
||||
assert auth.uid == self._format_slug_urn_pair(self.provider_slug, new_urn)
|
||||
assert not auth.uid == old_uid
|
||||
Reference in New Issue
Block a user