From fe4a2d0c4fd5ba6c4e0c18687a2d0b5e3229e87a Mon Sep 17 00:00:00 2001 From: Matt Hughes Date: Thu, 14 Nov 2019 17:18:25 -0500 Subject: [PATCH] Moved to taking a file rather than long blob-of-text argument --- .../management/commands/migrate_saml_uids.py | 16 ++++++++------- .../commands/tests/test_migrate_saml_uids.py | 20 ++++++++++++------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py b/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py index 5e2b9008b8..5cdea76fd2 100644 --- a/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py +++ b/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py @@ -7,7 +7,9 @@ without needing to manually re-link their account. """ from __future__ import absolute_import, unicode_literals +import json import logging +from io import open from textwrap import dedent from django.contrib.auth import get_user_model @@ -22,7 +24,7 @@ class Command(BaseCommand): Example usage: $ ./manage.py lms migrate_saml_uids.py \ - --uid-mapping=change@my.uid:4045A285AF596D8589C24841657CA3D8,me@too.uid:4045A285AF596D8589C24841657CA3D9 \ + --uid-mapping=path/to/file.json --saml-provider-slug=default """ help = dedent(__doc__).strip() @@ -30,7 +32,7 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( '--uid-mapping', - help='comma-separated list of email:uid mappings' + help='path to utf-8-encoded json file containing an array of objects with keys email and student_key' ) parser.add_argument( '--saml-provider-slug', @@ -42,13 +44,13 @@ class Command(BaseCommand): Performs the re-writing """ User = get_user_model() - pairs = options['uid_mapping'].split(',') + with open(options['uid_mapping'], 'r', encoding='utf-8') as f: + uid_mappings = json.load(f) slug = options['saml_provider_slug'] - for pair in pairs: - uid_list = pair.split(':') - email = uid_list[0] - uid = uid_list[1] + for pair in uid_mappings: + email = pair['email'] + uid = pair['student_key'] user = User.objects.prefetch_related('social_auth').get(email=email) auth = user.social_auth.filter(uid__startswith=slug)[0] auth.uid = '{slug}:{uid}'.format(slug=slug, uid=uid) diff --git a/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py b/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py index 677cd0ef55..809c7a24ab 100644 --- a/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py +++ b/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py @@ -9,9 +9,11 @@ from django.test import TestCase from factory import LazyAttributeSequence, SubFactory from factory.django import DjangoModelFactory from lms.djangoapps.program_enrollments.management.commands import migrate_saml_uids +from mock import mock_open, patch from social_django.models import UserSocialAuth from student.tests.factories import UserFactory +_COMMAND_PATH = 'lms.djangoapps.program_enrollments.management.commands.migrate_saml_uids' class UserSocialAuthFactory(DjangoModelFactory): """ @@ -37,8 +39,8 @@ class TestMigrateSamlUids(TestCase): super(TestMigrateSamlUids, cls).setUpClass() cls.command = migrate_saml_uids.Command() - def _format_email_uid_pair(self, email, uid): - return '{email}:{uid}'.format(email=email, uid=uid) + def _format_single_email_uid_pair_json(self, email, uid): + return '[{{"email":"{email}","student_key":"{new_urn}"}}]'.format(email=email, new_urn=uid) def _format_slug_urn_pair(self, slug, urn): return '{slug}:{urn}'.format(slug=slug, urn=urn) @@ -48,11 +50,15 @@ class TestMigrateSamlUids(TestCase): auth = UserSocialAuthFactory.create(slug=self.provider_slug) email = auth.user.email old_uid = auth.uid - call_command( - self.command, - uid_mapping=self._format_email_uid_pair(email, new_urn), - saml_provider_slug=self.provider_slug - ) + with patch( + _COMMAND_PATH + '.open', + mock_open(read_data=self._format_single_email_uid_pair_json(email, new_urn)) + ) as pat: + call_command( + self.command, + uid_mapping='./foo.json', + saml_provider_slug=self.provider_slug + ) auth.refresh_from_db() assert auth.uid == self._format_slug_urn_pair(self.provider_slug, new_urn)