diff --git a/common/djangoapps/third_party_auth/samlproviderconfig/__init__.py b/common/djangoapps/third_party_auth/samlproviderconfig/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/common/djangoapps/third_party_auth/samlproviderconfig/serializers.py b/common/djangoapps/third_party_auth/samlproviderconfig/serializers.py new file mode 100644 index 0000000000..f9e7f6e618 --- /dev/null +++ b/common/djangoapps/third_party_auth/samlproviderconfig/serializers.py @@ -0,0 +1,13 @@ +""" +Serializer for SAMLProviderConfig +""" + +from rest_framework import serializers + +from third_party_auth.models import SAMLProviderConfig + + +class SAMLProviderConfigSerializer(serializers.ModelSerializer): + class Meta: + model = SAMLProviderConfig + fields = '__all__' diff --git a/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py b/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py new file mode 100644 index 0000000000..c025023dae --- /dev/null +++ b/common/djangoapps/third_party_auth/samlproviderconfig/tests/test_samlproviderconfig.py @@ -0,0 +1,90 @@ +import unittest +import copy +from uuid import uuid4 +from django.urls import reverse +from django.contrib.sites.models import Site +from django.contrib.auth.models import User +from django.utils.http import urlencode +from rest_framework import status +from rest_framework.test import APITestCase + +from enterprise.models import EnterpriseCustomerIdentityProvider, EnterpriseCustomer +from enterprise.constants import ENTERPRISE_ADMIN_ROLE +from third_party_auth.tests.samlutils import set_jwt_cookie +from third_party_auth.models import SAMLProviderConfig +from third_party_auth.tests import testutil + +SINGLE_PROVIDER_CONFIG = { + 'entity_id': 'id', + 'metadata_source': 'http://test.url', + 'name': 'name-of-config', + 'enabled': 'true', + 'slug': 'test-slug' +} + +SINGLE_PROVIDER_CONFIG_2 = copy.copy(SINGLE_PROVIDER_CONFIG) +SINGLE_PROVIDER_CONFIG_2['name'] = 'name-of-config-2' +SINGLE_PROVIDER_CONFIG_2['slug'] = 'test-slug-2' + +ENTERPRISE_ID = str(uuid4()) + + +@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled') +class SAMLProviderConfigTests(APITestCase): + """ + API Tests for SAMLProviderConfig REST endpoints + The skip annotation above exists because we currently cannot run this test in + the cms mode in CI builds, where the third_party_auth application is not loaded + """ + @classmethod + def setUpTestData(cls): + super(SAMLProviderConfigTests, cls).setUpTestData() + cls.user = User.objects.create_user(username='testuser', password='testpwd') + cls.site, _ = Site.objects.get_or_create(domain='example.com') + cls.enterprise_customer = EnterpriseCustomer.objects.create( + uuid=ENTERPRISE_ID, + name='test-ep', + slug='test-ep', + site=cls.site) + cls.samlproviderconfig, _ = SAMLProviderConfig.objects.get_or_create( + entity_id=SINGLE_PROVIDER_CONFIG['entity_id'], + metadata_source=SINGLE_PROVIDER_CONFIG['metadata_source'] + ) + cls.enterprisecustomeridp, _ = EnterpriseCustomerIdentityProvider.objects.get_or_create( + provider_id=cls.samlproviderconfig.id, + enterprise_customer_id=ENTERPRISE_ID + ) + + def setUp(self): + set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, ENTERPRISE_ID)]) + self.client.force_authenticate(user=self.user) + + def test_get_one_config_by_enterprise_uuid_found(self): + """ + GET auth/saml/v0/providerconfig/?enterprise_customer_uuid=id=id + """ + urlbase = reverse('samlproviderconfig-list') + query_kwargs = {'enterprise_customer_uuid': ENTERPRISE_ID} + url = '{}?{}'.format(urlbase, urlencode(query_kwargs)) + response = self.client.get(url, format='json') + self.assertEqual(response.status_code, status.HTTP_200_OK) + print(response.data) + results = response.data['results'] + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['entity_id'], SINGLE_PROVIDER_CONFIG['entity_id']) + self.assertEqual(results[0]['metadata_source'], SINGLE_PROVIDER_CONFIG['metadata_source']) + self.assertEqual(SAMLProviderConfig.objects.count(), 1) + + def test_create_one_config(self): + """ + POST auth/saml/v0/providerconfig/?enterprise_customer_uuid=id -d data + """ + query_kwargs = {'enterprise_customer_uuid': ENTERPRISE_ID} + url = '{}?{}'.format(reverse('samlproviderconfig-list'), urlencode(query_kwargs)) + data = SINGLE_PROVIDER_CONFIG_2 + orig_count = SAMLProviderConfig.objects.count() + response = self.client.post(url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(SAMLProviderConfig.objects.count(), orig_count + 1) + providerconfig = SAMLProviderConfig.objects.get(slug=SINGLE_PROVIDER_CONFIG_2['slug']) + self.assertEqual(providerconfig.name, 'name-of-config-2') diff --git a/common/djangoapps/third_party_auth/samlproviderconfig/urls.py b/common/djangoapps/third_party_auth/samlproviderconfig/urls.py new file mode 100644 index 0000000000..063bb5bfa3 --- /dev/null +++ b/common/djangoapps/third_party_auth/samlproviderconfig/urls.py @@ -0,0 +1,11 @@ +""" + Viewset for auth/saml/v0/providerconfig/ +""" + +from rest_framework import routers + +from .views import SAMLProviderConfigViewSet + +samlproviderconfig_router = routers.DefaultRouter() +samlproviderconfig_router.register(r'providerconfig', SAMLProviderConfigViewSet, basename="samlproviderconfig") +urlpatterns = samlproviderconfig_router.urls diff --git a/common/djangoapps/third_party_auth/samlproviderconfig/views.py b/common/djangoapps/third_party_auth/samlproviderconfig/views.py new file mode 100644 index 0000000000..24e8c2a244 --- /dev/null +++ b/common/djangoapps/third_party_auth/samlproviderconfig/views.py @@ -0,0 +1,57 @@ +""" + Viewset for auth/saml/v0/samlproviderconfig +""" + +from edx_rbac.mixins import PermissionRequiredMixin +from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication +from rest_framework import permissions, viewsets +from rest_framework.authentication import SessionAuthentication + +from enterprise.models import EnterpriseCustomerIdentityProvider +from openedx.features.enterprise_support.utils import fetch_enterprise_customer_by_id + +from ..models import SAMLProviderConfig +from .serializers import SAMLProviderConfigSerializer + + +class SAMLProviderMixin(object): + authentication_classes = [JwtAuthentication, SessionAuthentication] + permission_classes = [permissions.IsAuthenticated] + serializer_class = SAMLProviderConfigSerializer + + +class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, viewsets.ModelViewSet): + """ + A View to handle SAMLProviderConfig CRUD + + Usage: + [HttpVerb] /auth/saml/v0/providerconfig/?enterprise-id=uuid + + permission_required refers to the Django permission name defined + in enterprise.rules. + The associated rule will allow edx-rbac to check if the EnterpriseCustomer + returned by the get_permission_object method here, can be + accessed by the user making this request (request.user) + Access is only allowed if the user has the system role + of 'ENTERPRISE_ADMIN' which is defined in enterprise.constants + """ + permission_required = 'enterprise.can_access_admin_dashboard' + + def get_queryset(self): + """ + Find and return the matching providerid for the given enterprise uuid + """ + enterprise_customer_idp = EnterpriseCustomerIdentityProvider.objects.get( + enterprise_customer__uuid=self.requested_enterprise_uuid + ) + return SAMLProviderConfig.objects.filter(pk=enterprise_customer_idp.provider_id) + + @property + def requested_enterprise_uuid(self): + return self.request.query_params.get('enterprise_customer_uuid') + + def get_permission_object(self): + """ + Retrive an EnterpriseCustomer to do auth against + """ + return fetch_enterprise_customer_by_id(self.requested_enterprise_uuid) diff --git a/common/djangoapps/third_party_auth/samlproviderdata/__init__.py b/common/djangoapps/third_party_auth/samlproviderdata/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/common/djangoapps/third_party_auth/samlproviderdata/serializers.py b/common/djangoapps/third_party_auth/samlproviderdata/serializers.py new file mode 100644 index 0000000000..14e8c83070 --- /dev/null +++ b/common/djangoapps/third_party_auth/samlproviderdata/serializers.py @@ -0,0 +1,13 @@ +""" + Serializer for SAMLProviderData +""" + +from rest_framework import serializers + +from third_party_auth.models import SAMLProviderData + + +class SAMLProviderDataSerializer(serializers.ModelSerializer): + class Meta: + model = SAMLProviderData + fields = '__all__' diff --git a/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py b/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py new file mode 100644 index 0000000000..f5f33f1dff --- /dev/null +++ b/common/djangoapps/third_party_auth/samlproviderdata/tests/test_samlproviderdata.py @@ -0,0 +1,102 @@ +import unittest +import copy +import pytz +from uuid import uuid4 +from datetime import datetime +from django.contrib.sites.models import Site +from django.contrib.auth.models import User +from django.urls import reverse +from django.utils.http import urlencode +from rest_framework import status +from rest_framework.test import APITestCase + +from enterprise.models import EnterpriseCustomer, EnterpriseCustomerIdentityProvider +from enterprise.constants import ENTERPRISE_ADMIN_ROLE + +from third_party_auth.tests import testutil +from third_party_auth.models import SAMLProviderData, SAMLProviderConfig +from third_party_auth.tests.samlutils import set_jwt_cookie + +SINGLE_PROVIDER_CONFIG = { + 'entity_id': 'http://entity-id-1', + 'metadata_source': 'http://test.url', + 'name': 'name-of-config', + 'enabled': 'true', + 'slug': 'test-slug' +} + +# entity_id here matches that of the providerconfig, intentionally +# that allows this data entity to be found +SINGLE_DATA_CONFIG = { + 'entity_id': 'http://entity-id-1', + 'sso_url': 'http://test.url', + 'public_key': 'a-key0Aid98', + 'fetched_at': datetime.now(pytz.UTC).replace(microsecond=0) +} + +SINGLE_DATA_CONFIG_2 = copy.copy(SINGLE_DATA_CONFIG) +SINGLE_DATA_CONFIG_2['entity_id'] = 'http://entity-id-2' +SINGLE_DATA_CONFIG_2['sso_url'] = 'http://test2.url' + +ENTERPRISE_ID = str(uuid4()) + + +@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled') +class SAMLProviderDataTests(APITestCase): + """ + API Tests for SAMLProviderConfig REST endpoints + """ + @classmethod + def setUpTestData(cls): + super(SAMLProviderDataTests, cls).setUpTestData() + cls.user = User.objects.create_user(username='testuser', password='testpwd') + cls.site, _ = Site.objects.get_or_create(domain='example.com') + cls.enterprise_customer = EnterpriseCustomer.objects.create( + uuid=ENTERPRISE_ID, + name='test-ep', + slug='test-ep', + site=cls.site) + cls.samlproviderconfig, _ = SAMLProviderConfig.objects.get_or_create( + entity_id=SINGLE_PROVIDER_CONFIG['entity_id'], + metadata_source=SINGLE_PROVIDER_CONFIG['metadata_source'] + ) + # the entity_id here must match that of the samlproviderconfig + cls.samlproviderdata, _ = SAMLProviderData.objects.get_or_create( + entity_id=SINGLE_DATA_CONFIG['entity_id'], + sso_url=SINGLE_DATA_CONFIG['sso_url'], + fetched_at=SINGLE_DATA_CONFIG['fetched_at'] + ) + cls.enterprisecustomeridp, _ = EnterpriseCustomerIdentityProvider.objects.get_or_create( + provider_id=cls.samlproviderconfig.id, + enterprise_customer_id=ENTERPRISE_ID + ) + + def setUp(self): + # a cookie with roles: [{enterprise_admin_role: ent_id}] will be + # needed to rbac to authorize access for this view + set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, ENTERPRISE_ID)]) + self.client.force_authenticate(user=self.user) + + def test_get_one_providedata_success(self): + # GET auth/saml/v0/providerdata/?enterprise_customer_uuid=id + urlbase = reverse('samlproviderdata-list') + query_kwargs = {'enterprise_customer_uuid': ENTERPRISE_ID} + url = '{}?{}'.format(urlbase, urlencode(query_kwargs)) + response = self.client.get(url, format='json') + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_create_one_providerdata_success(self): + # POST auth/saml/v0/providerdata/?enterprise_customer_uuid -d data + urlbase = reverse('samlproviderdata-list') + query_kwargs = {'enterprise_customer_uuid': ENTERPRISE_ID} + url = '{}?{}'.format(urlbase, urlencode(query_kwargs)) + fetched_at = '2009-01-10 00:12:12' + data = SINGLE_DATA_CONFIG_2 + orig_count = SAMLProviderData.objects.count() + response = self.client.post(url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(SAMLProviderData.objects.count(), orig_count + 1) + self.assertEqual( + SAMLProviderData.objects.get(entity_id=SINGLE_DATA_CONFIG_2['entity_id']).sso_url, + SINGLE_DATA_CONFIG_2['sso_url'] + ) diff --git a/common/djangoapps/third_party_auth/samlproviderdata/urls.py b/common/djangoapps/third_party_auth/samlproviderdata/urls.py new file mode 100644 index 0000000000..2eaa3868fd --- /dev/null +++ b/common/djangoapps/third_party_auth/samlproviderdata/urls.py @@ -0,0 +1,11 @@ +""" + url mappings for auth/saml/v0/providerdata/ +""" + +from rest_framework import routers + +from .views import SAMLProviderDataViewSet + +samlproviderdata_router = routers.DefaultRouter() +samlproviderdata_router.register(r'providerdata', SAMLProviderDataViewSet, basename="samlproviderdata") +urlpatterns = samlproviderdata_router.urls diff --git a/common/djangoapps/third_party_auth/samlproviderdata/views.py b/common/djangoapps/third_party_auth/samlproviderdata/views.py new file mode 100644 index 0000000000..9bd28675f9 --- /dev/null +++ b/common/djangoapps/third_party_auth/samlproviderdata/views.py @@ -0,0 +1,54 @@ +""" + Viewset for auth/saml/v0/samlproviderdata +""" + +from edx_rbac.mixins import PermissionRequiredMixin +from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication +from rest_framework import permissions, viewsets +from rest_framework.authentication import SessionAuthentication + +from enterprise.models import EnterpriseCustomerIdentityProvider +from openedx.features.enterprise_support.utils import fetch_enterprise_customer_by_id + +from ..models import SAMLProviderConfig, SAMLProviderData +from .serializers import SAMLProviderDataSerializer + + +class SAMLProviderDataMixin(object): + authentication_classes = [JwtAuthentication, SessionAuthentication] + permission_classes = [permissions.IsAuthenticated] + serializer_class = SAMLProviderDataSerializer + + +class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, viewsets.ModelViewSet): + """ + A View to handle SAMLProviderData CRUD. + Uses the edx-rbac mixin PermissionRequiredMixin to apply enterprise authorization + + Usage: + [HttpVerb] /auth/saml/v0/providerdata/ + """ + permission_required = 'enterprise.can_access_admin_dashboard' + + def get_queryset(self): + """ + Find and return the matching providerid for the given enterprise uuid + Note: There is no direct association between samlproviderdata and enterprisecustomer. + So we make that association in code via samlproviderdata > samlproviderconfig ( via entity_id ) + then, we fetch enterprisecustomer via samlproviderconfig > enterprisecustomer ( via association table ) + """ + enterprise_customer_idp = EnterpriseCustomerIdentityProvider.objects.get( + enterprise_customer__uuid=self.requested_enterprise_uuid + ) + saml_provider = SAMLProviderConfig.objects.get(pk=enterprise_customer_idp.provider_id) + return SAMLProviderData.objects.filter(entity_id=saml_provider.entity_id) + + @property + def requested_enterprise_uuid(self): + return self.request.query_params.get('enterprise_customer_uuid') + + def get_permission_object(self): + """ + Retrive an EnterpriseCustomer to do auth against + """ + return fetch_enterprise_customer_by_id(self.requested_enterprise_uuid) diff --git a/common/djangoapps/third_party_auth/tests/samlutils.py b/common/djangoapps/third_party_auth/tests/samlutils.py new file mode 100644 index 0000000000..9b9fdbf110 --- /dev/null +++ b/common/djangoapps/third_party_auth/tests/samlutils.py @@ -0,0 +1,33 @@ +""" +Utility functions for use in SAMLProviderConfig, SAMLProviderData tests +""" + +from edx_rest_framework_extensions.auth.jwt.cookies import jwt_cookie_name +from edx_rest_framework_extensions.auth.jwt.tests.utils import ( + generate_jwt_token, + generate_unversioned_payload, +) + + +def _jwt_token_from_role_context_pairs(user, role_context_pairs): + """ + Generates a new JWT token with roles assigned from pairs of (role name, context). + """ + roles = [] + for role, context in role_context_pairs: + role_data = '{role}'.format(role=role) + if context is not None: + role_data += ':{context}'.format(context=context) + roles.append(role_data) + + payload = generate_unversioned_payload(user) + payload.update({'roles': roles}) + return generate_jwt_token(payload) + + +def set_jwt_cookie(client, user, role_context_pairs=None): + """ + Set jwt token in cookies + """ + jwt_token = _jwt_token_from_role_context_pairs(user, role_context_pairs or []) + client.cookies[jwt_cookie_name()] = jwt_token diff --git a/common/djangoapps/third_party_auth/urls.py b/common/djangoapps/third_party_auth/urls.py index 00175c18fe..fd743af060 100644 --- a/common/djangoapps/third_party_auth/urls.py +++ b/common/djangoapps/third_party_auth/urls.py @@ -1,6 +1,5 @@ """Url configuration for the auth module.""" - from django.conf.urls import include, url from .views import ( @@ -18,4 +17,6 @@ urlpatterns = [ url(r'^auth/login/(?Plti)/$', lti_login_and_complete_view), url(r'^auth/idp_redirect/(?P[\w-]+)', IdPRedirectView.as_view(), name="idp_redirect"), url(r'^auth/', include('social_django.urls', namespace='social')), + url(r'^auth/saml/v0/', include('third_party_auth.samlproviderconfig.urls')), + url(r'^auth/saml/v0/', include('third_party_auth.samlproviderdata.urls')) ] diff --git a/openedx/features/enterprise_support/utils.py b/openedx/features/enterprise_support/utils.py index b58d164925..36e24c4c1d 100644 --- a/openedx/features/enterprise_support/utils.py +++ b/openedx/features/enterprise_support/utils.py @@ -10,7 +10,7 @@ from django.conf import settings from django.urls import NoReverseMatch, reverse from django.utils.translation import ugettext as _ from edx_django_utils.cache import TieredCache, get_cache_key -from enterprise.models import EnterpriseCustomerUser +from enterprise.models import EnterpriseCustomerUser, EnterpriseCustomer from social_django.models import UserSocialAuth import third_party_auth @@ -342,3 +342,7 @@ def get_provider_login_url(request, provider_id, redirect_url=None): redirect_url=redirect_url if redirect_url else get_next_url_for_login_page(request) ) return provider_login_url + + +def fetch_enterprise_customer_by_id(enterprise_uuid): + return EnterpriseCustomer.objects.get(uuid=enterprise_uuid)