diff --git a/common/djangoapps/third_party_auth/api/utils.py b/common/djangoapps/third_party_auth/api/utils.py new file mode 100644 index 0000000000..f0dbd7b5e9 --- /dev/null +++ b/common/djangoapps/third_party_auth/api/utils.py @@ -0,0 +1,30 @@ +""" +Shareable utilities for third party auth api functions +""" + + +def filter_user_social_auth_queryset_by_provider(query_set, provider): + """ + Filter a query set by the given TPA provider + + Params: + query_set: QuerySet[UserSocialAuth] + provider: common.djangoapps.third_party_auth.models.ProviderConfig + Returns: + QuerySet[UserSocialAuth] + """ + # Note: When using multi-IdP backend, the provider column isn't + # enough to identify a specific backend + filtered_query_set = query_set.filter(provider=provider.backend_name) + + # Test if the current provider has a slug which it appends to + # uids; these can be used to identify the backend more + # specifically than the provider's backend + fake_uid = 'uid' + uid = provider.get_social_auth_uid(fake_uid) + if uid != fake_uid: + # if yes, we add a filter for the slug on uid column + # carve off the fake_uid from the end, so we get just the prepended slug + filtered_query_set = filtered_query_set.filter(uid__startswith=uid[:-len(fake_uid)]) + + return filtered_query_set diff --git a/common/djangoapps/third_party_auth/api/views.py b/common/djangoapps/third_party_auth/api/views.py index a4f6b16233..b9d4705f0b 100644 --- a/common/djangoapps/third_party_auth/api/views.py +++ b/common/djangoapps/third_party_auth/api/views.py @@ -27,6 +27,7 @@ from third_party_auth import pipeline from third_party_auth.api import serializers from third_party_auth.api.permissions import TPA_PERMISSIONS from third_party_auth.provider import Registry +from common.djangoapps.third_party_auth.api.utils import filter_user_social_auth_queryset_by_provider class ProviderBaseThrottle(throttling.UserRateThrottle): @@ -349,16 +350,10 @@ class UserMappingView(ListAPIView): if not self.provider: raise Http404 - query_set = UserSocialAuth.objects.select_related('user').filter(provider=self.provider.backend_name) - - # build our query filters - # When using multi-IdP backend, we only retrieve the ones that are for current IdP. - # test if the current provider has a slug - uid = self.provider.get_social_auth_uid('uid') - if uid != 'uid': - # if yes, we add a filter for the slug on uid column - query_set = query_set.filter(uid__startswith=uid[:-3]) - + query_set = filter_user_social_auth_queryset_by_provider( + UserSocialAuth.objects.select_related('user'), + self.provider, + ) query = Q() usernames = self.request.query_params.getlist('username', None) diff --git a/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py b/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py index 6a97c091b2..2e83acf967 100644 --- a/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py +++ b/lms/djangoapps/program_enrollments/management/commands/tests/test_migrate_saml_uids.py @@ -3,33 +3,22 @@ Tests for the migrate_saml_uids management command. """ +from mock import mock_open, patch import six + from django.core.management import call_command from django.test import TestCase from factory import LazyAttributeSequence, SubFactory from factory.django import DjangoModelFactory -from mock import mock_open, patch from social_django.models import UserSocialAuth from lms.djangoapps.program_enrollments.management.commands import migrate_saml_uids +from lms.djangoapps.program_enrollments.management.commands.tests.utils import UserSocialAuthFactory from student.tests.factories import UserFactory _COMMAND_PATH = 'lms.djangoapps.program_enrollments.management.commands.migrate_saml_uids' -class UserSocialAuthFactory(DjangoModelFactory): - """ - Factory for UserSocialAuth records. - """ - class Meta(object): - model = UserSocialAuth - user = SubFactory(UserFactory) - uid = LazyAttributeSequence(lambda o, n: '%s:%d' % (o.slug, n)) - - class Params(object): - slug = 'gatech' - - class TestMigrateSamlUids(TestCase): """ Test migrate_saml_uids command. diff --git a/lms/djangoapps/program_enrollments/management/commands/tests/utils.py b/lms/djangoapps/program_enrollments/management/commands/tests/utils.py new file mode 100644 index 0000000000..911098bd0d --- /dev/null +++ b/lms/djangoapps/program_enrollments/management/commands/tests/utils.py @@ -0,0 +1,21 @@ +""" +Sharable utilities for testing program enrollments +""" + +from factory import LazyAttributeSequence, SubFactory +from factory.django import DjangoModelFactory +from social_django.models import UserSocialAuth +from student.tests.factories import UserFactory + + +class UserSocialAuthFactory(DjangoModelFactory): + """ + Factory for UserSocialAuth records. + """ + class Meta(object): + model = UserSocialAuth + user = SubFactory(UserFactory) + uid = LazyAttributeSequence(lambda o, n: '%s:%d' % (o.slug, n)) + + class Params(object): + slug = 'gatech' diff --git a/lms/djangoapps/verify_student/management/commands/backfill_sso_verifications_for_old_account_links.py b/lms/djangoapps/verify_student/management/commands/backfill_sso_verifications_for_old_account_links.py new file mode 100644 index 0000000000..71492302ce --- /dev/null +++ b/lms/djangoapps/verify_student/management/commands/backfill_sso_verifications_for_old_account_links.py @@ -0,0 +1,67 @@ +""" +Management command to backfill verification records for preexisting account links + +Meant to facilitate the alteration of a particular +third_party_auth_samlproviderconfig to flip on the +enable_sso_id_verification bit, which would ordinarily leave any +preexisting account links without the corresponding resultant ID +verification record. + +This also manually triggers the same signal which is sent on creation +of SSO IDV records. +""" + +from django.core.management.base import BaseCommand, CommandError + +from social_django.models import UserSocialAuth + +from common.djangoapps.third_party_auth.api.utils import filter_user_social_auth_queryset_by_provider +from lms.djangoapps.verify_student.models import SSOVerification +from third_party_auth.provider import Registry + + +class Command(BaseCommand): + """ + Management command to backfill verification records for preexisting account links + + Meant to facilitate the alteration of a particular + third_party_auth_samlproviderconfig to flip on the + enable_sso_id_verification bit, which would ordinarily leave any + preexisting account links without the corresponding resultant ID + verification record. + + Example usage: + $ ./manage.py lms backfill_sso_verifications_for_old_account_links --provider-slug=saml-gatech + """ + help = 'Backfills SSO verification records for the given SAML provider slug' + + def add_arguments(self, parser): + parser.add_argument( + '--provider-slug', + required=True, + ) + + def filter_user_social_auth_queryset_by_ssoverification_existence(self, query_set): + return query_set.filter(user__ssoverification__isnull=True) + + def handle(self, *args, **options): + provider_slug = options.get('provider_slug', None) + + try: + provider = Registry.get(provider_slug) + except ValueError as e: + raise CommandError('provider slug {slug} does not exist'.format(slug=provider_slug)) + + query_set = UserSocialAuth.objects.select_related('user__profile') + query_set = filter_user_social_auth_queryset_by_provider(query_set, provider) + query_set = self.filter_user_social_auth_queryset_by_ssoverification_existence(query_set) + for user_social_auth in query_set: + verification = SSOVerification.objects.create( + user=user_social_auth.user, + status="approved", + name=user_social_auth.user.profile.name, + identity_provider_type=provider.full_class_name, + identity_provider_slug=provider.slug, + ) + # Send a signal so users who have already passed their courses receive credit + verification.send_approval_signal(provider.slug) diff --git a/lms/djangoapps/verify_student/management/commands/tests/test_backfill_sso_verifications_for_old_account_links.py b/lms/djangoapps/verify_student/management/commands/tests/test_backfill_sso_verifications_for_old_account_links.py new file mode 100644 index 0000000000..666b9e6e1c --- /dev/null +++ b/lms/djangoapps/verify_student/management/commands/tests/test_backfill_sso_verifications_for_old_account_links.py @@ -0,0 +1,74 @@ +""" +Tests for management command backfill_sso_verifications_for_old_account_links +""" + +from mock import patch + +from django.core.management import call_command +from django.core.management.base import CommandError + +from lms.djangoapps.program_enrollments.management.commands.tests.utils import UserSocialAuthFactory +from lms.djangoapps.verify_student.models import SSOVerification +from lms.djangoapps.verify_student.tests.factories import SSOVerificationFactory +from third_party_auth.tests.testutil import TestCase + + +class TestBackfillSSOVerificationsCommand(TestCase): + """ + Tests for management command for backfilling SSO verification records + """ + slug = 'test' + + def setUp(self): + super(TestBackfillSSOVerificationsCommand, self).setUp() + self.enable_saml() + self.provider = self.configure_saml_provider( + name="Test", + slug=self.slug, + enabled=True, + enable_sso_id_verification=True, + ) + self.user_social_auth1 = UserSocialAuthFactory(slug=self.slug, provider=self.provider.backend_name) + self.user_social_auth1.save() + self.user1 = self.user_social_auth1.user + + def test_fails_without_required_param(self): + with self.assertRaises(CommandError): + call_command('backfill_sso_verifications_for_old_account_links') + + def test_fails_without_named_provider_config(self): + with self.assertRaises(CommandError): + call_command('backfill_sso_verifications_for_old_account_links', '--provider-slug', 'gatech') + + def test_sso_updated_single_user(self): + self.assertTrue(SSOVerification.objects.count() == 0) + call_command('backfill_sso_verifications_for_old_account_links', '--provider-slug', self.provider.provider_id) + self.assertTrue(SSOVerification.objects.count() > 0) + self.assertEqual(SSOVerification.objects.get().user.id, self.user1.id) + + def test_performance(self): + # TODO + #self.assertNumQueries(1) + call_command('backfill_sso_verifications_for_old_account_links', '--provider-slug', self.provider.provider_id) + #self.assertNumQueries(100) + + def test_signal_called(self): + with patch('openedx.core.djangoapps.signals.signals.LEARNER_NOW_VERIFIED.send_robust') as mock_signal: + call_command('backfill_sso_verifications_for_old_account_links', '--provider-slug', self.provider.provider_id) + self.assertEqual(mock_signal.call_count, 1) + + def test_fine_with_multiple_verification_records(self): + """ + Testing there are no issues with excluding learners with multiple sso verifications + """ + SSOVerificationFactory( + status='approved', + user=self.user1, + ) + SSOVerificationFactory( + status='approved', + user=self.user1, + ) + self.assertEqual(SSOVerification.objects.count(), 2) + call_command('backfill_sso_verifications_for_old_account_links', '--provider-slug', self.provider.provider_id) + self.assertEqual(SSOVerification.objects.count(), 2) diff --git a/lms/djangoapps/verify_student/tests/factories.py b/lms/djangoapps/verify_student/tests/factories.py index b84f48be9d..71a44deb72 100644 --- a/lms/djangoapps/verify_student/tests/factories.py +++ b/lms/djangoapps/verify_student/tests/factories.py @@ -9,7 +9,7 @@ from django.conf import settings from django.utils.timezone import now from factory.django import DjangoModelFactory -from lms.djangoapps.verify_student.models import SoftwareSecurePhotoVerification +from lms.djangoapps.verify_student.models import SSOVerification, SoftwareSecurePhotoVerification class SoftwareSecurePhotoVerificationFactory(DjangoModelFactory): @@ -22,3 +22,8 @@ class SoftwareSecurePhotoVerificationFactory(DjangoModelFactory): status = 'approved' if hasattr(settings, 'VERIFY_STUDENT'): expiry_date = now() + timedelta(days=settings.VERIFY_STUDENT["DAYS_GOOD_FOR"]) + + +class SSOVerificationFactory(DjangoModelFactory): + class Meta(): + model = SSOVerification