diff --git a/common/djangoapps/entitlements/api/v1/permissions.py b/common/djangoapps/entitlements/api/v1/permissions.py new file mode 100644 index 0000000000..16361f3592 --- /dev/null +++ b/common/djangoapps/entitlements/api/v1/permissions.py @@ -0,0 +1,19 @@ +""" +This module provides a custom DRF Permission class for supporting SAFE_METHODS to Authenticated Users, but +requiring Superuser access for all other Request types on an API endpoint. +""" + +from rest_framework.permissions import BasePermission, SAFE_METHODS + + +class IsAdminOrAuthenticatedReadOnly(BasePermission): + """ + Method that will require staff access for all methods not + in the SAFE_METHODS list. For example GET requests will not + require a Staff or Admin user. + """ + def has_permission(self, request, view): + if request.method in SAFE_METHODS: + return request.user.is_authenticated + else: + return request.user.is_staff diff --git a/common/djangoapps/entitlements/api/v1/tests/test_views.py b/common/djangoapps/entitlements/api/v1/tests/test_views.py index ed5a4a2da3..3e791f1ca0 100644 --- a/common/djangoapps/entitlements/api/v1/tests/test_views.py +++ b/common/djangoapps/entitlements/api/v1/tests/test_views.py @@ -6,6 +6,7 @@ from django.conf import settings from django.core.urlresolvers import reverse from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.factories import CourseFactory + from student.tests.factories import CourseEnrollmentFactory, UserFactory, TEST_PASSWORD # Entitlements is not in CMS' INSTALLED_APPS so these imports will error during test collection @@ -42,11 +43,11 @@ class EntitlementViewSetTest(ModuleStoreTestCase): response = self.client.get(self.entitlements_list_url) assert response.status_code == 401 - def test_staff_user_required(self): + def test_staff_user_not_required_for_get(self): not_staff_user = UserFactory() - self.client.login(username=not_staff_user.username, password=UserFactory._DEFAULT_PASSWORD) + self.client.login(username=not_staff_user.username, password=TEST_PASSWORD) response = self.client.get(self.entitlements_list_url) - assert response.status_code == 403 + assert response.status_code == 200 def test_add_entitlement_with_missing_data(self): entitlement_data_missing_parts = self._get_data_set(self.user, str(uuid.uuid4())) @@ -60,6 +61,33 @@ class EntitlementViewSetTest(ModuleStoreTestCase): ) assert response.status_code == 400 + def test_staff_user_required_for_post(self): + not_staff_user = UserFactory() + self.client.login(username=not_staff_user.username, password=TEST_PASSWORD) + + course_uuid = uuid.uuid4() + entitlement_data = self._get_data_set(self.user, str(course_uuid)) + + response = self.client.post( + self.entitlements_list_url, + data=json.dumps(entitlement_data), + content_type='application/json', + ) + assert response.status_code == 403 + + def test_staff_user_required_for_delete(self): + not_staff_user = UserFactory() + self.client.login(username=not_staff_user.username, password=TEST_PASSWORD) + + course_entitlement = CourseEntitlementFactory() + url = reverse(self.ENTITLEMENTS_DETAILS_PATH, args=[str(course_entitlement.uuid)]) + + response = self.client.delete( + url, + content_type='application/json', + ) + assert response.status_code == 403 + def test_add_entitlement(self): course_uuid = uuid.uuid4() entitlement_data = self._get_data_set(self.user, str(course_uuid)) @@ -78,7 +106,21 @@ class EntitlementViewSetTest(ModuleStoreTestCase): ) assert results == CourseEntitlementSerializer(course_entitlement).data - def test_get_entitlements(self): + def test_non_staff_get_select_entitlements(self): + not_staff_user = UserFactory() + self.client.login(username=not_staff_user.username, password=TEST_PASSWORD) + CourseEntitlementFactory.create_batch(2) + entitlement = CourseEntitlementFactory.create(user=not_staff_user) + response = self.client.get( + self.entitlements_list_url, + content_type='application/json', + ) + assert response.status_code == 200 + + results = response.data.get('results', []) # pylint: disable=no-member + assert results == CourseEntitlementSerializer([entitlement], many=True).data + + def test_staff_get_all_entitlements(self): entitlements = CourseEntitlementFactory.create_batch(2) response = self.client.get( @@ -109,7 +151,6 @@ class EntitlementViewSetTest(ModuleStoreTestCase): entitlement = CourseEntitlementFactory() CourseEntitlementFactory.create_batch(2) - CourseEntitlementFactory() url = reverse(self.ENTITLEMENTS_DETAILS_PATH, args=[str(entitlement.uuid)]) response = self.client.get( diff --git a/common/djangoapps/entitlements/api/v1/views.py b/common/djangoapps/entitlements/api/v1/views.py index c03c443f30..97bdf5f962 100644 --- a/common/djangoapps/entitlements/api/v1/views.py +++ b/common/djangoapps/entitlements/api/v1/views.py @@ -4,26 +4,32 @@ from django.utils import timezone from django_filters.rest_framework import DjangoFilterBackend from edx_rest_framework_extensions.authentication import JwtAuthentication from rest_framework import permissions, viewsets -from rest_framework.authentication import SessionAuthentication from entitlements.api.v1.filters import CourseEntitlementFilter -from entitlements.models import CourseEntitlement +from entitlements.api.v1.permissions import IsAdminOrAuthenticatedReadOnly from entitlements.api.v1.serializers import CourseEntitlementSerializer +from entitlements.models import CourseEntitlement +from openedx.core.djangoapps.cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf from student.models import CourseEnrollment log = logging.getLogger(__name__) class EntitlementViewSet(viewsets.ModelViewSet): - authentication_classes = (JwtAuthentication, SessionAuthentication,) - permission_classes = (permissions.IsAuthenticated, permissions.IsAdminUser,) - queryset = CourseEntitlement.objects.all().select_related('user') + authentication_classes = (JwtAuthentication, SessionAuthenticationCrossDomainCsrf,) + permission_classes = (permissions.IsAuthenticated, IsAdminOrAuthenticatedReadOnly,) lookup_value_regex = '[0-9a-f-]+' lookup_field = 'uuid' serializer_class = CourseEntitlementSerializer filter_backends = (DjangoFilterBackend,) filter_class = CourseEntitlementFilter + def get_queryset(self): + user = self.request.user + if user.is_staff: + return CourseEntitlement.objects.all().select_related('user') + return CourseEntitlement.objects.filter(user=user).select_related('user') + def perform_destroy(self, instance): """ This method is an override and is called by the DELETE method diff --git a/common/djangoapps/entitlements/tests/factories.py b/common/djangoapps/entitlements/tests/factories.py index f22073136e..6daa8ccb31 100644 --- a/common/djangoapps/entitlements/tests/factories.py +++ b/common/djangoapps/entitlements/tests/factories.py @@ -1,18 +1,20 @@ import string -import uuid +from uuid import uuid4 import factory from factory.fuzzy import FuzzyChoice, FuzzyText from entitlements.models import CourseEntitlement from student.tests.factories import UserFactory +from course_modes.helpers import CourseMode class CourseEntitlementFactory(factory.django.DjangoModelFactory): class Meta(object): model = CourseEntitlement - course_uuid = uuid.uuid4() - mode = FuzzyChoice(['verified', 'profesional']) + uuid = factory.LazyFunction(uuid4) + course_uuid = factory.LazyFunction(uuid4) + mode = FuzzyChoice([CourseMode.VERIFIED, CourseMode.PROFESSIONAL]) user = factory.SubFactory(UserFactory) order_number = FuzzyText(prefix='TEXTX', chars=string.digits)