feat: program_enrollments support for multiple SAML providers

This commit is contained in:
Zach Hancock
2021-07-09 15:57:42 -04:00
committed by Matt Hughes
parent b00f508b79
commit aa2bf9c063
7 changed files with 78 additions and 63 deletions

View File

@@ -27,7 +27,7 @@ from .reading import (
get_program_course_enrollment,
get_program_enrollment,
get_provider_slug,
get_saml_provider_for_organization,
get_saml_providers_for_organization,
get_users_by_external_keys,
get_users_by_external_keys_and_org_key,
is_course_staff_enrollment

View File

@@ -10,7 +10,6 @@ from organizations.models import Organization
from social_django.models import UserSocialAuth
from common.djangoapps.student.roles import CourseStaffRole
from common.djangoapps.third_party_auth.models import SAMLProviderConfig
from openedx.core.djangoapps.catalog.utils import get_programs
from ..constants import ProgramCourseEnrollmentRoles
@@ -18,7 +17,6 @@ from ..exceptions import (
BadOrganizationShortNameException,
ProgramDoesNotExistException,
ProgramHasNoAuthoringOrganizationException,
ProviderConfigurationException,
ProviderDoesNotExistException
)
from ..models import ProgramCourseEnrollment, ProgramEnrollment
@@ -404,18 +402,20 @@ def get_users_by_external_keys_and_org_key(external_user_keys, org_key):
Raises:
BadOrganizationShortNameException
ProviderDoesNotExistsException
ProviderConfigurationException
"""
saml_provider = get_saml_provider_by_org_key(org_key)
social_auth_uids = {
saml_provider.get_social_auth_uid(external_user_key)
for external_user_key in external_user_keys
}
social_auths = UserSocialAuth.objects.filter(uid__in=social_auth_uids)
found_users_by_external_keys = {
saml_provider.get_remote_id_from_social_auth(social_auth): social_auth.user
for social_auth in social_auths
}
saml_providers = get_saml_providers_by_org_key(org_key)
found_users_by_external_keys = dict()
for saml_provider in saml_providers:
social_auth_uids = {
saml_provider.get_social_auth_uid(external_user_key)
for external_user_key in external_user_keys
}
social_auths = UserSocialAuth.objects.filter(uid__in=social_auth_uids)
found_users_by_external_keys.update({
saml_provider.get_remote_id_from_social_auth(social_auth): social_auth.user
for social_auth in social_auths
})
# Default all external keys to None, because external keys
# without a User will not appear in `found_users_by_external_keys`.
users_by_external_keys = {key: None for key in external_user_keys}
@@ -444,7 +444,6 @@ def get_users_by_external_keys(program_uuid, external_user_keys):
ProgramHasNoAuthoringOrganizationException
BadOrganizationShortNameException
ProviderDoesNotExistsException
ProviderConfigurationException
"""
org_key = get_org_key_for_program(program_uuid)
return get_users_by_external_keys_and_org_key(external_user_keys, org_key)
@@ -477,14 +476,14 @@ def get_external_key_by_user_and_course(user, course_key):
return relevant_pce.program_enrollment.external_user_key
def get_saml_provider_by_org_key(org_key):
def get_saml_providers_by_org_key(org_key):
"""
Returns the SAML provider associated with the provided org_key
Returns a list of SAML providers associated with the provided org_key
Arguments:
org_key (str)
Returns: SAMLProvider
Returns: list[SAMLProvider]
Raises:
BadOrganizationShortNameException
@@ -493,7 +492,7 @@ def get_saml_provider_by_org_key(org_key):
organization = Organization.objects.get(short_name=org_key)
except Organization.DoesNotExist:
raise BadOrganizationShortNameException(org_key) # lint-amnesty, pylint: disable=raise-missing-from
return get_saml_provider_for_organization(organization)
return get_saml_providers_for_organization(organization)
def get_org_key_for_program(program_uuid):
@@ -520,26 +519,22 @@ def get_org_key_for_program(program_uuid):
return org_key
def get_saml_provider_for_organization(organization):
def get_saml_providers_for_organization(organization):
"""
Return currently configured SAML provider for the given Organization.
Return currently configured SAML provider(s) for the given Organization.
Arguments:
organization: Organization
Returns: SAMLProvider
Returns: list[SAMLProvider]
Raises:
ProviderDoesNotExistsException
ProviderConfigurationException
"""
try:
provider_config = organization.samlproviderconfig_set.current_set().get(enabled=True)
except SAMLProviderConfig.DoesNotExist:
provider_configs = organization.samlproviderconfig_set.current_set().filter(enabled=True)
if not provider_configs:
raise ProviderDoesNotExistException(organization) # lint-amnesty, pylint: disable=raise-missing-from
except SAMLProviderConfig.MultipleObjectsReturned:
raise ProviderConfigurationException(organization) # lint-amnesty, pylint: disable=raise-missing-from
return provider_config
return list(provider_configs)
def get_provider_slug(provider_config):

View File

@@ -22,7 +22,6 @@ from lms.djangoapps.program_enrollments.constants import ProgramEnrollmentStatus
from lms.djangoapps.program_enrollments.exceptions import (
OrganizationDoesNotExistException,
ProgramDoesNotExistException,
ProviderConfigurationException,
ProviderDoesNotExistException
)
from lms.djangoapps.program_enrollments.models import ProgramCourseEnrollment, ProgramEnrollment
@@ -568,10 +567,11 @@ class GetUsersByExternalKeysTests(CacheIsolationTestCase):
provider=provider.backend_name,
)
def test_happy_path(self):
def test_single_saml_provider(self):
"""
Test that get_users_by_external_keys returns the expected
mapping of external keys to users.
mapping of external keys to users when a single saml provider
is configured.
"""
organization = OrganizationFactory.create(short_name=self.organization_key)
provider = SAMLProviderConfigFactory.create(organization=organization)
@@ -588,6 +588,33 @@ class GetUsersByExternalKeysTests(CacheIsolationTestCase):
}
assert actual == expected
def test_multiple_saml_providers(self):
"""
Test that get_users_by_external_keys returns the expected
mapping of external keys to users when multiple saml providers
are configured.
"""
organization = OrganizationFactory.create(short_name=self.organization_key)
provider_1 = SAMLProviderConfigFactory.create(organization=organization)
provider_2 = SAMLProviderConfigFactory.create(
organization=organization,
slug='test-shib-2',
enabled=True
)
self.create_social_auth_entry(self.user_0, provider_1, 'ext-user-0')
self.create_social_auth_entry(self.user_1, provider_1, 'ext-user-1')
self.create_social_auth_entry(self.user_2, provider_2, 'ext-user-2')
requested_keys = {'ext-user-1', 'ext-user-2', 'ext-user-3'}
actual = get_users_by_external_keys(self.program_uuid, requested_keys)
# ext-user-0 not requested, ext-user-3 doesn't exist,
# ext-user-2 is authorized with secondary provider
expected = {
'ext-user-1': self.user_1,
'ext-user-2': self.user_2,
'ext-user-3': None,
}
assert actual == expected
def test_empty_request(self):
"""
Test that requesting no external keys does not cause an exception.
@@ -651,20 +678,6 @@ class GetUsersByExternalKeysTests(CacheIsolationTestCase):
)
get_users_by_external_keys(self.program_uuid, [])
def test_extra_saml_provider_enabled(self):
"""
If multiple enabled samlprovider records exist with the same organization
an exception is raised.
"""
organization = OrganizationFactory.create(short_name=self.organization_key)
SAMLProviderConfigFactory.create(organization=organization)
# create a second active config for the same organizationm, IS enabled
SAMLProviderConfigFactory.create(
organization=organization, slug='foox', enabled=True
)
with pytest.raises(ProviderConfigurationException):
get_users_by_external_keys(self.program_uuid, [])
@ddt.ddt
class IsCourseStaffEnrollmentTest(TestCase):

View File

@@ -53,15 +53,3 @@ class ProviderDoesNotExistException(Exception):
return 'Unable to find organization for short_name {}'.format(
self.organization.id
)
class ProviderConfigurationException(Exception):
def __init__(self, organization):
self.organization = organization
def __str__(self):
return (
'Multiple active SAML configurations found for organization={}. '
'Expected one.'
).format(self.organization.short_name)

View File

@@ -2412,6 +2412,25 @@ class EnrollmentDataResetViewTests(ProgramCacheMixin, APITestCase):
mock.call(self.reset_enrollments_cmd, ','.join(programs), force=True),
])
@override_settings(FEATURES=FEATURES_WITH_ENABLED)
@patch_call_command
def test_reset_with_multiple_idp(self, mock_call_command):
programs = [str(uuid4()), str(uuid4())]
self.set_org_in_catalog_cache(self.organization, programs)
provider_2 = SAMLProviderConfigFactory(
organization=self.organization,
slug='test-shib-2',
enabled=True,
)
response = self.request(self.organization.short_name)
assert response.status_code == status.HTTP_200_OK
mock_call_command.assert_has_calls([
mock.call(self.reset_users_cmd, self.provider.slug, force=True),
mock.call(self.reset_users_cmd, provider_2.slug, force=True),
mock.call(self.reset_enrollments_cmd, ','.join(programs), force=True),
])
@override_settings(FEATURES=FEATURES_WITH_ENABLED)
@patch_call_command
def test_reset_without_idp(self, mock_call_command):

View File

@@ -24,7 +24,7 @@ from lms.djangoapps.program_enrollments.api import (
fetch_program_enrollments,
fetch_program_enrollments_by_student,
get_provider_slug,
get_saml_provider_for_organization,
get_saml_providers_for_organization,
iter_program_course_grades,
write_program_course_enrollments,
write_program_enrollments
@@ -993,11 +993,13 @@ class EnrollmentDataResetView(APIView):
except Organization.DoesNotExist:
return Response(f'organization {org_key} not found', status.HTTP_404_NOT_FOUND)
providers = []
try:
provider = get_saml_provider_for_organization(organization)
providers = get_saml_providers_for_organization(organization)
except ProviderDoesNotExistException:
pass
else:
for provider in providers:
idp_slug = get_provider_slug(provider)
call_command('remove_social_auth_users', idp_slug, force=True)

View File

@@ -21,7 +21,6 @@ from lms.djangoapps.program_enrollments.api import (
)
from lms.djangoapps.program_enrollments.exceptions import (
BadOrganizationShortNameException,
ProviderConfigurationException,
ProviderDoesNotExistException
)
from lms.djangoapps.support.decorators import require_support_permission
@@ -228,7 +227,6 @@ class ProgramEnrollmentsInspectorView(View):
found_user = users_by_key.get(external_user_key)
except (
BadOrganizationShortNameException,
ProviderConfigurationException,
ProviderDoesNotExistException
):
# We cannot identify edX user from external_user_key and org_key pair