diff --git a/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py b/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py new file mode 100644 index 0000000000..5e2b9008b8 --- /dev/null +++ b/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py @@ -0,0 +1,55 @@ +""" +Management command to re-write UIDs identifying learners in external organizations' systems + +Intented for use in production environments, to help support migration +of existing SSO learners into our most-recent program enrollment flow +without needing to manually re-link their account. +""" +from __future__ import absolute_import, unicode_literals + +import logging +from textwrap import dedent + +from django.contrib.auth import get_user_model +from django.core.management.base import BaseCommand + +log = logging.getLogger(__name__) + + +class Command(BaseCommand): + """ + Updates UserSocialAuth records to use UIDs provided in the supplied JSON file + + Example usage: + $ ./manage.py lms migrate_saml_uids.py \ + --uid-mapping=change@my.uid:4045A285AF596D8589C24841657CA3D8,me@too.uid:4045A285AF596D8589C24841657CA3D9 \ + --saml-provider-slug=default + """ + help = dedent(__doc__).strip() + + def add_arguments(self, parser): + parser.add_argument( + '--uid-mapping', + help='comma-separated list of email:uid mappings' + ) + parser.add_argument( + '--saml-provider-slug', + help='slug of SAMLProvider for which records should be updated' + ) + + def handle(self, *args, **options): + """ + Performs the re-writing + """ + User = get_user_model() + pairs = options['uid_mapping'].split(',') + slug = options['saml_provider_slug'] + + for pair in pairs: + uid_list = pair.split(':') + email = uid_list[0] + uid = uid_list[1] + user = User.objects.prefetch_related('social_auth').get(email=email) + auth = user.social_auth.filter(uid__startswith=slug)[0] + auth.uid = '{slug}:{uid}'.format(slug=slug, uid=uid) + auth.save() diff --git a/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py b/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py new file mode 100644 index 0000000000..677cd0ef55 --- /dev/null +++ b/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py @@ -0,0 +1,59 @@ +""" +Tests for the migrate_saml_uids management command. +""" +from __future__ import absolute_import + +from django.core.management import call_command +from django.test import TestCase + +from factory import LazyAttributeSequence, SubFactory +from factory.django import DjangoModelFactory +from lms.djangoapps.program_enrollments.management.commands import migrate_saml_uids +from social_django.models import UserSocialAuth +from student.tests.factories import UserFactory + + +class UserSocialAuthFactory(DjangoModelFactory): + """ + Factory for UserSocialAuth records. + """ + class Meta(object): + model = UserSocialAuth + user = SubFactory(UserFactory) + uid = LazyAttributeSequence(lambda o, n: '%s:%d' % (o.slug, n)) + + class Params(object): + slug = 'gatech' + + +class TestMigrateSamlUids(TestCase): + """ + Test migrate_saml_uids command. + """ + provider_slug = 'gatech' + + @classmethod + def setUpClass(cls): + super(TestMigrateSamlUids, cls).setUpClass() + cls.command = migrate_saml_uids.Command() + + def _format_email_uid_pair(self, email, uid): + return '{email}:{uid}'.format(email=email, uid=uid) + + def _format_slug_urn_pair(self, slug, urn): + return '{slug}:{urn}'.format(slug=slug, urn=urn) + + def test_single_mapping(self): + new_urn = '9001' + auth = UserSocialAuthFactory.create(slug=self.provider_slug) + email = auth.user.email + old_uid = auth.uid + call_command( + self.command, + uid_mapping=self._format_email_uid_pair(email, new_urn), + saml_provider_slug=self.provider_slug + ) + + auth.refresh_from_db() + assert auth.uid == self._format_slug_urn_pair(self.provider_slug, new_urn) + assert not auth.uid == old_uid