diff --git a/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py b/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py index 5cdea76fd2..e6996e7fa0 100644 --- a/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py +++ b/lms/djangoapps/program_enrollments/management/commands/migrate_saml_uids.py @@ -48,10 +48,36 @@ class Command(BaseCommand): uid_mappings = json.load(f) slug = options['saml_provider_slug'] - for pair in uid_mappings: - email = pair['email'] - uid = pair['student_key'] - user = User.objects.prefetch_related('social_auth').get(email=email) - auth = user.social_auth.filter(uid__startswith=slug)[0] + email_map = { m['email']: {'uid': m['student_key'], 'updated': False} for m in uid_mappings } + user_queryset = User.objects.prefetch_related('social_auth').filter(social_auth__uid__startswith=slug + ':') + users = [u for u in user_queryset] + + missed = 0 + updated = 0 + for user in users: + email = user.email + try: + info_for_email = email_map[email] + except KeyError: + missed += 1 + continue + info_for_email['updated'] = True + uid = info_for_email['uid'] + auth = user.social_auth.filter(uid__startswith=slug + ':')[0] + # print something about the ones who have more than one social_auth from gatech auth.uid = '{slug}:{uid}'.format(slug=slug, uid=uid) auth.save() + updated += 1 + not_previously_linked = reduce(lambda count, mapping: count + (not email_map[mapping['email']]['updated']), uid_mappings, 0) + log.info( + 'Number of users with {slug} UserSocialAuth records for which there was no mapping in the provided file: {missed}'.format( + slug=slug, + missed=missed + )) + log.info( + 'Number of users identified in the mapping file without {slug} UserSocialAuth records: {not_previously_linked}'.format( + slug=slug, + not_previously_linked=not_previously_linked + )) + + diff --git a/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py b/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py index 809c7a24ab..2003677d32 100644 --- a/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py +++ b/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py @@ -39,8 +39,25 @@ class TestMigrateSamlUids(TestCase): super(TestMigrateSamlUids, cls).setUpClass() cls.command = migrate_saml_uids.Command() + def _format_email_uid_pair(self, email, uid): + return '{{"email":"{email}","student_key":"{new_urn}"}}'.format(email=email, new_urn=uid) + + def _format_single_email_uid_pair_json(self, email, uid): - return '[{{"email":"{email}","student_key":"{new_urn}"}}]'.format(email=email, new_urn=uid) + return '[{obj}]'.format( + obj=self._format_email_uid_pair(email, uid) + ) + + def _call_command(self, data): + with patch( + _COMMAND_PATH + '.open', + mock_open(read_data=data) + ) as _: + call_command( + self.command, + uid_mapping='./foo.json', + saml_provider_slug=self.provider_slug + ) def _format_slug_urn_pair(self, slug, urn): return '{slug}:{urn}'.format(slug=slug, urn=urn) @@ -50,16 +67,72 @@ class TestMigrateSamlUids(TestCase): auth = UserSocialAuthFactory.create(slug=self.provider_slug) email = auth.user.email old_uid = auth.uid - with patch( - _COMMAND_PATH + '.open', - mock_open(read_data=self._format_single_email_uid_pair_json(email, new_urn)) - ) as pat: - call_command( - self.command, - uid_mapping='./foo.json', - saml_provider_slug=self.provider_slug - ) + + self._call_command(self._format_single_email_uid_pair_json(email, new_urn)) auth.refresh_from_db() assert auth.uid == self._format_slug_urn_pair(self.provider_slug, new_urn) assert not auth.uid == old_uid + + def test_post_save_occurs(self): + """ + Test the signals downstream of this update are called with appropriate arguments + """ + auth = UserSocialAuthFactory.create(slug=self.provider_slug) + new_urn = '9001' + email = auth.user.email + + with patch('lms.djangoapps.program_enrollments.signals.matriculate_learner') as signal_handler_mock: + self._call_command(self._format_single_email_uid_pair_json(email, new_urn)) + assert signal_handler_mock.called + # first positional arg matches the user whose auth was updated + assert signal_handler_mock.call_args[0][0].id == auth.user.id + # second positional arg matches the urn we changed + assert signal_handler_mock.call_args[0][1] == self._format_slug_urn_pair(self.provider_slug, new_urn) + + def test_multiple_social_auth_records(self): + """ + Test we only alter one UserSocialAuth record if a learner has two + """ + auth1 = UserSocialAuthFactory.create(slug=self.provider_slug) + auth2 = UserSocialAuthFactory.create( + slug=self.provider_slug, + user=auth1.user + ) + new_urn = '9001' + email = auth1.user.email + + assert email == auth2.user.email + + self._call_command(self._format_single_email_uid_pair_json(email, new_urn)) + auths = UserSocialAuth.objects.filter( + user__email=email, + uid=self._format_slug_urn_pair(self.provider_slug, new_urn) + ) + assert auths.count() == 1 + + @patch(_COMMAND_PATH + '.log') + def test_learner_without_social_auth_records(self, mock_log): + user = UserFactory() + email = user.email + new_urn = '9001' + + mock_info = mock_log.info + + self._call_command(self._format_single_email_uid_pair_json(email, new_urn)) + mock_info.assert_any_call('Number of users identified in the mapping file without {slug} UserSocialAuth records: 1'.format( + slug=self.provider_slug + )) + + @patch(_COMMAND_PATH + '.log') + def test_learner_missed_by_mapping_file(self, mock_log): + auth = UserSocialAuthFactory() + email = auth.user.email + new_urn = '9001' + + mock_info = mock_log.info + + self._call_command(self._format_single_email_uid_pair_json('different' + email, new_urn)) + mock_info.assert_any_call('Number of users with {slug} UserSocialAuth records for which there was no mapping in the provided file: 1'.format( + slug=self.provider_slug + ))