From 474ce9bea66cbd9502de454483c7e36bc8111d7a Mon Sep 17 00:00:00 2001 From: Matt Hughes Date: Wed, 13 Nov 2019 11:51:20 -0500 Subject: [PATCH] Management command for rewriting social auth uids This management command will help us support migrating an SSO provider from one configured user identifier to another, which may be necessary when previous choices of UID aren't as stable as needed. JIRA:EDUCATOR-4701 --- .../management/commands/migrate_saml_uids.py | 55 +++++++++++++++++ .../commands/tests/test_migrate_saml_uids.py | 59 +++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py create mode 100644 lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py diff --git a/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py b/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py new file mode 100644 index 0000000000..5e2b9008b8 --- /dev/null +++ b/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py @@ -0,0 +1,55 @@ +""" +Management command to re-write UIDs identifying learners in external organizations' systems + +Intented for use in production environments, to help support migration +of existing SSO learners into our most-recent program enrollment flow +without needing to manually re-link their account. +""" +from __future__ import absolute_import, unicode_literals + +import logging +from textwrap import dedent + +from django.contrib.auth import get_user_model +from django.core.management.base import BaseCommand + +log = logging.getLogger(__name__) + + +class Command(BaseCommand): + """ + Updates UserSocialAuth records to use UIDs provided in the supplied JSON file + + Example usage: + $ ./manage.py lms migrate_saml_uids.py \ + --uid-mapping=change@my.uid:4045A285AF596D8589C24841657CA3D8,me@too.uid:4045A285AF596D8589C24841657CA3D9 \ + --saml-provider-slug=default + """ + help = dedent(__doc__).strip() + + def add_arguments(self, parser): + parser.add_argument( + '--uid-mapping', + help='comma-separated list of email:uid mappings' + ) + parser.add_argument( + '--saml-provider-slug', + help='slug of SAMLProvider for which records should be updated' + ) + + def handle(self, *args, **options): + """ + Performs the re-writing + """ + User = get_user_model() + pairs = options['uid_mapping'].split(',') + slug = options['saml_provider_slug'] + + for pair in pairs: + uid_list = pair.split(':') + email = uid_list[0] + uid = uid_list[1] + user = User.objects.prefetch_related('social_auth').get(email=email) + auth = user.social_auth.filter(uid__startswith=slug)[0] + auth.uid = '{slug}:{uid}'.format(slug=slug, uid=uid) + auth.save() diff --git a/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py b/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py new file mode 100644 index 0000000000..677cd0ef55 --- /dev/null +++ b/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py @@ -0,0 +1,59 @@ +""" +Tests for the migrate_saml_uids management command. +""" +from __future__ import absolute_import + +from django.core.management import call_command +from django.test import TestCase + +from factory import LazyAttributeSequence, SubFactory +from factory.django import DjangoModelFactory +from lms.djangoapps.program_enrollments.management.commands import migrate_saml_uids +from social_django.models import UserSocialAuth +from student.tests.factories import UserFactory + + +class UserSocialAuthFactory(DjangoModelFactory): + """ + Factory for UserSocialAuth records. + """ + class Meta(object): + model = UserSocialAuth + user = SubFactory(UserFactory) + uid = LazyAttributeSequence(lambda o, n: '%s:%d' % (o.slug, n)) + + class Params(object): + slug = 'gatech' + + +class TestMigrateSamlUids(TestCase): + """ + Test migrate_saml_uids command. + """ + provider_slug = 'gatech' + + @classmethod + def setUpClass(cls): + super(TestMigrateSamlUids, cls).setUpClass() + cls.command = migrate_saml_uids.Command() + + def _format_email_uid_pair(self, email, uid): + return '{email}:{uid}'.format(email=email, uid=uid) + + def _format_slug_urn_pair(self, slug, urn): + return '{slug}:{urn}'.format(slug=slug, urn=urn) + + def test_single_mapping(self): + new_urn = '9001' + auth = UserSocialAuthFactory.create(slug=self.provider_slug) + email = auth.user.email + old_uid = auth.uid + call_command( + self.command, + uid_mapping=self._format_email_uid_pair(email, new_urn), + saml_provider_slug=self.provider_slug + ) + + auth.refresh_from_db() + assert auth.uid == self._format_slug_urn_pair(self.provider_slug, new_urn) + assert not auth.uid == old_uid