diff --git a/cms/envs/common.py b/cms/envs/common.py index b1002e1222..e2ae72f2dd 100644 --- a/cms/envs/common.py +++ b/cms/envs/common.py @@ -44,7 +44,10 @@ from lms.envs.common import ( PROFILE_IMAGE_SECRET_KEY, PROFILE_IMAGE_MIN_BYTES, PROFILE_IMAGE_MAX_BYTES, # The following setting is included as it is used to check whether to # display credit eligibility table on the CMS or not. - ENABLE_CREDIT_ELIGIBILITY, YOUTUBE_API_KEY + ENABLE_CREDIT_ELIGIBILITY, YOUTUBE_API_KEY, + + # Django REST framework configuration + REST_FRAMEWORK, ) from path import Path as path from warnings import simplefilter diff --git a/common/djangoapps/cors_csrf/tests/test_authentication.py b/common/djangoapps/cors_csrf/tests/test_authentication.py index 7f2d78fd8a..ea4587021a 100644 --- a/common/djangoapps/cors_csrf/tests/test_authentication.py +++ b/common/djangoapps/cors_csrf/tests/test_authentication.py @@ -6,7 +6,7 @@ from django.test.utils import override_settings from django.test.client import RequestFactory from django.conf import settings -from rest_framework.exceptions import AuthenticationFailed +from rest_framework.exceptions import PermissionDenied from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf @@ -24,7 +24,7 @@ class CrossDomainAuthTest(TestCase): def test_perform_csrf_referer_check(self): request = self._fake_request() - with self.assertRaisesRegexp(AuthenticationFailed, 'CSRF'): + with self.assertRaisesRegexp(PermissionDenied, 'CSRF'): self.auth.enforce_csrf(request) @patch.dict(settings.FEATURES, { diff --git a/common/djangoapps/enrollment/data.py b/common/djangoapps/enrollment/data.py index 33f78e3e4a..68880e1a3e 100644 --- a/common/djangoapps/enrollment/data.py +++ b/common/djangoapps/enrollment/data.py @@ -11,7 +11,7 @@ from enrollment.errors import ( CourseNotFoundError, CourseEnrollmentClosedError, CourseEnrollmentFullError, CourseEnrollmentExistsError, UserNotFoundError, InvalidEnrollmentAttribute ) -from enrollment.serializers import CourseEnrollmentSerializer, CourseField +from enrollment.serializers import CourseEnrollmentSerializer, CourseSerializer from openedx.core.djangoapps.content.course_overviews.models import CourseOverview from student.models import ( CourseEnrollment, NonExistentCourseError, EnrollmentClosedError, @@ -35,9 +35,30 @@ def get_course_enrollments(user_id): """ qset = CourseEnrollment.objects.filter( - user__username=user_id, is_active=True + user__username=user_id, + is_active=True ).order_by('created') - return CourseEnrollmentSerializer(qset).data + + enrollments = CourseEnrollmentSerializer(qset, many=True).data + + # Find deleted courses and filter them out of the results + deleted = [] + valid = [] + for enrollment in enrollments: + if enrollment.get("course_details") is not None: + valid.append(enrollment) + else: + deleted.append(enrollment) + + if deleted: + log.warning( + ( + u"Course enrollments for user %s reference " + u"courses that do not exist (this can occur if a course is deleted)." + ), user_id, + ) + + return valid def get_course_enrollment(username, course_id): @@ -271,4 +292,4 @@ def get_course_enrollment_info(course_id, include_expired=False): log.warning(msg) raise CourseNotFoundError(msg) else: - return CourseField().to_native(course, include_expired=include_expired) + return CourseSerializer(course, include_expired=include_expired).data diff --git a/common/djangoapps/enrollment/serializers.py b/common/djangoapps/enrollment/serializers.py index ca5c4bfd93..994f97e5fc 100644 --- a/common/djangoapps/enrollment/serializers.py +++ b/common/djangoapps/enrollment/serializers.py @@ -30,32 +30,36 @@ class StringListField(serializers.CharField): return [int(item) for item in items] -class CourseField(serializers.RelatedField): - """Read-Only representation of course enrollment information. - - Aggregates course information from the CourseDescriptor as well as the Course Modes configured - for enrolling in the course. - +class CourseSerializer(serializers.Serializer): # pylint: disable=abstract-method + """ + Serialize a course descriptor and related information. """ - def to_native(self, course, **kwargs): - course_modes = ModeSerializer( - CourseMode.modes_for_course( - course.id, - include_expired=kwargs.get('include_expired', False), - only_selectable=False - ) - ).data + course_id = serializers.CharField(source="id") + enrollment_start = serializers.DateTimeField(format=None) + enrollment_end = serializers.DateTimeField(format=None) + course_start = serializers.DateTimeField(source="start", format=None) + course_end = serializers.DateTimeField(source="end", format=None) + invite_only = serializers.BooleanField(source="invitation_only") + course_modes = serializers.SerializerMethodField() - return { - 'course_id': unicode(course.id), - 'enrollment_start': course.enrollment_start, - 'enrollment_end': course.enrollment_end, - 'course_start': course.start, - 'course_end': course.end, - 'invite_only': course.invitation_only, - 'course_modes': course_modes, - } + def __init__(self, *args, **kwargs): + self.include_expired = kwargs.pop("include_expired", False) + super(CourseSerializer, self).__init__(*args, **kwargs) + + def get_course_modes(self, obj): + """ + Retrieve course modes associated with the course. + """ + course_modes = CourseMode.modes_for_course( + obj.id, + include_expired=self.include_expired, + only_selectable=False + ) + return [ + ModeSerializer(mode).data + for mode in course_modes + ] class CourseEnrollmentSerializer(serializers.ModelSerializer): @@ -65,34 +69,9 @@ class CourseEnrollmentSerializer(serializers.ModelSerializer): the Course Descriptor and course modes, to give a complete representation of course enrollment. """ - course_details = serializers.SerializerMethodField('get_course_details') + course_details = CourseSerializer(source="course_overview") user = serializers.SerializerMethodField('get_username') - @property - def data(self): - serialized_data = super(CourseEnrollmentSerializer, self).data - - # filter the results with empty courses 'course_details' - if isinstance(serialized_data, dict): - if serialized_data.get('course_details') is None: - return None - - return serialized_data - - return [enrollment for enrollment in serialized_data if enrollment.get('course_details')] - - def get_course_details(self, model): - if model.course is None: - msg = u"Course '{0}' does not exist (maybe deleted), in which User (user_id: '{1}') is enrolled.".format( - model.course_id, - model.user.id - ) - log.warning(msg) - return None - - field = CourseField() - return field.to_native(model.course) - def get_username(self, model): """Retrieves the username from the associated model.""" return model.username diff --git a/common/djangoapps/enrollment/tests/test_views.py b/common/djangoapps/enrollment/tests/test_views.py index bd5a3ffbc2..20052bff34 100644 --- a/common/djangoapps/enrollment/tests/test_views.py +++ b/common/djangoapps/enrollment/tests/test_views.py @@ -1038,7 +1038,7 @@ class EnrollmentCrossDomainTest(ModuleStoreTestCase): @cross_domain_config def test_cross_domain_missing_csrf(self, *args): # pylint: disable=unused-argument resp = self._cross_domain_post('invalid_csrf_token') - self.assertEqual(resp.status_code, 401) + self.assertEqual(resp.status_code, 403) def _get_csrf_cookie(self): """Retrieve the cross-domain CSRF cookie. """ diff --git a/common/djangoapps/request_cache/__init__.py b/common/djangoapps/request_cache/__init__.py index 7d8bf77ef6..aedc785cad 100644 --- a/common/djangoapps/request_cache/__init__.py +++ b/common/djangoapps/request_cache/__init__.py @@ -5,10 +5,18 @@ This module requires that :class:`request_cache.middleware.RequestCache` is installed in order to clear the cache after each request. """ +import logging +from urlparse import urlparse + +from django.conf import settings +from django.test.client import RequestFactory from request_cache import middleware +log = logging.getLogger(__name__) + + def get_cache(name): """ Return the request cache named ``name``. @@ -26,3 +34,38 @@ def get_request(): Return the current request. """ return middleware.RequestCache.get_current_request() + + +def get_request_or_stub(): + """ + Return the current request or a stub request. + + If called outside the context of a request, construct a fake + request that can be used to build an absolute URI. + + This is useful in cases where we need to pass in a request object + but don't have an active request (for example, in test cases). + """ + request = get_request() + + if request is None: + log.warning( + "Could not retrieve the current request. " + "A stub request will be created instead using settings.SITE_NAME. " + "This should be used *only* in test cases, never in production!" + ) + + # The settings SITE_NAME may contain a port number, so we need to + # parse the full URL. + full_url = "http://{site_name}".format(site_name=settings.SITE_NAME) + parsed_url = urlparse(full_url) + + # Construct the fake request. This can be used to construct absolute + # URIs to other paths. + return RequestFactory( + SERVER_NAME=parsed_url.hostname, + SERVER_PORT=parsed_url.port or 80, + ).get("/") + + else: + return request diff --git a/common/djangoapps/request_cache/tests.py b/common/djangoapps/request_cache/tests.py new file mode 100644 index 0000000000..3883fe873b --- /dev/null +++ b/common/djangoapps/request_cache/tests.py @@ -0,0 +1,20 @@ +""" +Tests for the request cache. +""" +from django.conf import settings +from django.test import TestCase + +from request_cache import get_request_or_stub + + +class TestRequestCache(TestCase): + """ + Tests for the request cache. + """ + + def test_get_request_or_stub(self): + # Outside the context of the request, we should still get a request + # that allows us to build an absolute URI. + stub = get_request_or_stub() + expected_url = "http://{site_name}/foobar".format(site_name=settings.SITE_NAME) + self.assertEqual(stub.build_absolute_uri("foobar"), expected_url) diff --git a/common/test/acceptance/tests/lms/test_teams.py b/common/test/acceptance/tests/lms/test_teams.py index 40d6dc963d..33e4c0f0a7 100644 --- a/common/test/acceptance/tests/lms/test_teams.py +++ b/common/test/acceptance/tests/lms/test_teams.py @@ -406,7 +406,7 @@ class BrowseTopicsTest(TeamsTabBase): ) create_team_page.submit_form() team_page = TeamPage(self.browser, self.course_id) - self.assertTrue(team_page.is_browser_on_page) + self.assertTrue(team_page.is_browser_on_page()) team_page.click_all_topics() self.assertTrue(self.topics_page.is_browser_on_page()) self.topics_page.wait_for_ajax() diff --git a/lms/djangoapps/commerce/api/v1/serializers.py b/lms/djangoapps/commerce/api/v1/serializers.py index a904fa7238..775867fe85 100644 --- a/lms/djangoapps/commerce/api/v1/serializers.py +++ b/lms/djangoapps/commerce/api/v1/serializers.py @@ -18,7 +18,12 @@ class CourseModeSerializer(serializers.ModelSerializer): """ CourseMode serializer. """ name = serializers.CharField(source='mode_slug') price = serializers.IntegerField(source='min_price') - expires = serializers.DateTimeField(source='expiration_datetime', required=False, blank=True) + expires = serializers.DateTimeField( + source='expiration_datetime', + required=False, + allow_null=True, + format=None + ) def get_identity(self, data): try: @@ -56,8 +61,8 @@ class CourseSerializer(serializers.Serializer): """ Course serializer. """ id = serializers.CharField(validators=[validate_course_id]) # pylint: disable=invalid-name name = serializers.CharField(read_only=True) - verification_deadline = serializers.DateTimeField(blank=True) - modes = CourseModeSerializer(many=True, allow_add_remove=True) + verification_deadline = serializers.DateTimeField(format=None, allow_null=True, required=False) + modes = CourseModeSerializer(many=True) def validate(self, attrs): """ Ensure the verification deadline occurs AFTER the course mode enrollment deadlines. """ @@ -68,7 +73,7 @@ class CourseSerializer(serializers.Serializer): # Find the earliest upgrade deadline for mode in attrs['modes']: - expires = mode.expiration_datetime + expires = mode.get("expiration_datetime") if expires: # If we don't already have an upgrade_deadline value, use datetime.max so that we can actually # complete the comparison. @@ -82,9 +87,28 @@ class CourseSerializer(serializers.Serializer): return attrs - def restore_object(self, attrs, instance=None): - if instance is None: - return Course(attrs['id'], attrs['modes'], attrs['verification_deadline']) + def create(self, validated_data): + """Create course modes for a course. """ + course = Course( + validated_data["id"], + self._new_course_mode_models(validated_data["modes"]), + verification_deadline=validated_data["verification_deadline"] + ) + course.save() + return course - instance.update(attrs) + def update(self, instance, validated_data): + """Update course modes for an existing course. """ + validated_data["modes"] = self._new_course_mode_models(validated_data["modes"]) + + instance.update(validated_data) + instance.save() return instance + + @staticmethod + def _new_course_mode_models(modes_data): + """Convert validated course mode data to CourseMode objects. """ + return [ + CourseMode(**modes_dict) + for modes_dict in modes_data + ] diff --git a/lms/djangoapps/commerce/api/v1/views.py b/lms/djangoapps/commerce/api/v1/views.py index 34b58c7915..1c1b32ccad 100644 --- a/lms/djangoapps/commerce/api/v1/views.py +++ b/lms/djangoapps/commerce/api/v1/views.py @@ -2,7 +2,8 @@ import logging from django.http import Http404 -from rest_framework.authentication import OAuth2Authentication, SessionAuthentication +from rest_framework.authentication import SessionAuthentication +from rest_framework_oauth.authentication import OAuth2Authentication from rest_framework.generics import RetrieveUpdateAPIView, ListAPIView from rest_framework.permissions import IsAuthenticated @@ -10,6 +11,7 @@ from commerce.api.v1.models import Course from commerce.api.v1.permissions import ApiKeyOrModelPermission from commerce.api.v1.serializers import CourseSerializer from course_modes.models import CourseMode +from openedx.core.lib.api.mixins import PutAsCreateMixin log = logging.getLogger(__name__) @@ -19,12 +21,13 @@ class CourseListView(ListAPIView): authentication_classes = (OAuth2Authentication, SessionAuthentication,) permission_classes = (IsAuthenticated,) serializer_class = CourseSerializer + pagination_class = None def get_queryset(self): - return Course.iterator() + return list(Course.iterator()) -class CourseRetrieveUpdateView(RetrieveUpdateAPIView): +class CourseRetrieveUpdateView(PutAsCreateMixin, RetrieveUpdateAPIView): """ Retrieve, update, or create courses/modes. """ lookup_field = 'id' lookup_url_kwarg = 'course_id' @@ -33,6 +36,11 @@ class CourseRetrieveUpdateView(RetrieveUpdateAPIView): permission_classes = (ApiKeyOrModelPermission,) serializer_class = CourseSerializer + # Django Rest Framework v3 requires that we provide a queryset. + # Note that we're overriding `get_object()` below to return a `Course` + # rather than a CourseMode, so this isn't really used. + queryset = CourseMode.objects.all() + def get_object(self, queryset=None): course_id = self.kwargs.get(self.lookup_url_kwarg) course = Course.get(course_id) diff --git a/lms/djangoapps/course_structure_api/v0/serializers.py b/lms/djangoapps/course_structure_api/v0/serializers.py index ad95b8e985..bd079544e3 100644 --- a/lms/djangoapps/course_structure_api/v0/serializers.py +++ b/lms/djangoapps/course_structure_api/v0/serializers.py @@ -11,11 +11,11 @@ class CourseSerializer(serializers.Serializer): id = serializers.CharField() # pylint: disable=invalid-name name = serializers.CharField(source='display_name') category = serializers.CharField() - org = serializers.SerializerMethodField('get_org') - run = serializers.SerializerMethodField('get_run') - course = serializers.SerializerMethodField('get_course') - uri = serializers.SerializerMethodField('get_uri') - image_url = serializers.SerializerMethodField('get_image_url') + org = serializers.SerializerMethodField() + run = serializers.SerializerMethodField() + course = serializers.SerializerMethodField() + uri = serializers.SerializerMethodField() + image_url = serializers.SerializerMethodField() start = serializers.DateTimeField() end = serializers.DateTimeField() diff --git a/lms/djangoapps/course_structure_api/v0/tests.py b/lms/djangoapps/course_structure_api/v0/tests.py index c74f76f601..08224e521a 100644 --- a/lms/djangoapps/course_structure_api/v0/tests.py +++ b/lms/djangoapps/course_structure_api/v0/tests.py @@ -36,6 +36,23 @@ class CourseViewTestsMixin(object): """ view = None + raw_grader = [ + { + "min_count": 24, + "weight": 0.2, + "type": "Homework", + "drop_count": 0, + "short_label": "HW" + }, + { + "min_count": 4, + "weight": 0.8, + "type": "Exam", + "drop_count": 0, + "short_label": "Exam" + } + ] + def setUp(self): super(CourseViewTestsMixin, self).setUp() self.create_user_and_access_token() @@ -51,22 +68,7 @@ class CourseViewTestsMixin(object): @classmethod def create_course_data(cls): cls.invalid_course_id = 'foo/bar/baz' - cls.course = CourseFactory.create(display_name='An Introduction to API Testing', raw_grader=[ - { - "min_count": 24, - "weight": 0.2, - "type": "Homework", - "drop_count": 0, - "short_label": "HW" - }, - { - "min_count": 4, - "weight": 0.8, - "type": "Exam", - "drop_count": 0, - "short_label": "Exam" - } - ]) + cls.course = CourseFactory.create(display_name='An Introduction to API Testing', raw_grader=cls.raw_grader) cls.course_id = unicode(cls.course.id) with cls.store.bulk_operations(cls.course.id, emit_signals=False): cls.sequential = ItemFactory.create( @@ -408,6 +410,55 @@ class CourseGradingPolicyTests(CourseDetailTestMixin, CourseViewTestsMixin, Shar self.assertListEqual(response.data, expected) +class CourseGradingPolicyMissingFieldsTests(CourseDetailTestMixin, CourseViewTestsMixin, SharedModuleStoreTestCase): + view = 'course_structure_api:v0:grading_policy' + + # Update the raw grader to have missing keys + raw_grader = [ + { + "min_count": 24, + "weight": 0.2, + "type": "Homework", + "drop_count": 0, + "short_label": "HW" + }, + { + # Deleted "min_count" key + "weight": 0.8, + "type": "Exam", + "drop_count": 0, + "short_label": "Exam" + } + ] + + @classmethod + def setUpClass(cls): + super(CourseGradingPolicyMissingFieldsTests, cls).setUpClass() + cls.create_course_data() + + def test_get(self): + """ + The view should return grading policy for a course. + """ + response = super(CourseGradingPolicyMissingFieldsTests, self).test_get() + + expected = [ + { + "count": 24, + "weight": 0.2, + "assignment_type": "Homework", + "dropped": 0 + }, + { + "count": None, + "weight": 0.8, + "assignment_type": "Exam", + "dropped": 0 + } + ] + self.assertListEqual(response.data, expected) + + ##################################################################################### # # The following Mixins/Classes collectively test the CourseBlocksAndNavigation view. diff --git a/lms/djangoapps/course_structure_api/v0/views.py b/lms/djangoapps/course_structure_api/v0/views.py index 347ba45c24..02986c3934 100644 --- a/lms/djangoapps/course_structure_api/v0/views.py +++ b/lms/djangoapps/course_structure_api/v0/views.py @@ -6,7 +6,8 @@ import logging from django.conf import settings from django.http import Http404 -from rest_framework.authentication import OAuth2Authentication, SessionAuthentication +from rest_framework.authentication import SessionAuthentication +from rest_framework_oauth.authentication import OAuth2Authentication from rest_framework.exceptions import AuthenticationFailed, ParseError from rest_framework.generics import RetrieveAPIView, ListAPIView from rest_framework.permissions import IsAuthenticated @@ -21,7 +22,6 @@ from courseware.access import has_access from courseware.model_data import FieldDataCache from courseware.module_render import get_module_for_descriptor from openedx.core.lib.api.view_utils import view_course_access, view_auth_classes -from openedx.core.lib.api.serializers import PaginationSerializer from openedx.core.djangoapps.content.course_structures.api.v0 import api, errors from student.roles import CourseInstructorRole, CourseStaffRole from util.module_utils import get_dynamic_descriptor_children @@ -157,9 +157,6 @@ class CourseList(CourseViewMixin, ListAPIView): * end: The course end date. If course end date is not specified, the value is null. """ - paginate_by = 10 - paginate_by_param = 'page_size' - pagination_serializer_class = PaginationSerializer serializer_class = serializers.CourseSerializer def get_queryset(self): diff --git a/lms/djangoapps/courseware/grades.py b/lms/djangoapps/courseware/grades.py index 5b1a9741d5..008e0644e4 100644 --- a/lms/djangoapps/courseware/grades.py +++ b/lms/djangoapps/courseware/grades.py @@ -25,7 +25,6 @@ from xmodule.modulestore.django import modulestore from xmodule.modulestore.exceptions import ItemNotFoundError from .models import StudentModule from .module_render import get_module_for_descriptor -from submissions import api as sub_api # installed from the edx-submissions repository from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey from openedx.core.djangoapps.signals.signals import GRADES_UPDATED @@ -349,8 +348,13 @@ def _grade(student, request, course, keep_raw_scores, field_data_cache, scores_c # Dict of item_ids -> (earned, possible) point tuples. This *only* grabs # scores that were registered with the submissions API, which for the moment # means only openassessment (edx-ora2) + # We need to import this here to avoid a circular dependency of the form: + # XBlock --> submissions --> Django Rest Framework error strings --> + # Django translation --> ... --> courseware --> submissions + from submissions import api as sub_api # installed from the edx-submissions repository submissions_scores = sub_api.get_scores(course.id.to_deprecated_string(), anonymous_id_for_user(student, course.id)) max_scores_cache = MaxScoresCache.create_for_course(course) + # For the moment, we have to get scorable_locations from field_data_cache # and not from scores_client, because scores_client is ignorant of things # in the submissions API. As a further refactoring step, submissions should @@ -565,7 +569,12 @@ def _progress_summary(student, request, course, field_data_cache=None, scores_cl course_module = getattr(course_module, '_x_module', course_module) + # We need to import this here to avoid a circular dependency of the form: + # XBlock --> submissions --> Django Rest Framework error strings --> + # Django translation --> ... --> courseware --> submissions + from submissions import api as sub_api # installed from the edx-submissions repository submissions_scores = sub_api.get_scores(course.id.to_deprecated_string(), anonymous_id_for_user(student, course.id)) + max_scores_cache = MaxScoresCache.create_for_course(course) # For the moment, we have to get scorable_locations from field_data_cache # and not from scores_client, because scores_client is ignorant of things diff --git a/lms/djangoapps/discussion_api/api.py b/lms/djangoapps/discussion_api/api.py index d8153ff486..86d4d4ca8b 100644 --- a/lms/djangoapps/discussion_api/api.py +++ b/lms/djangoapps/discussion_api/api.py @@ -560,7 +560,7 @@ def create_thread(request, thread_data): if not (serializer.is_valid() and actions_form.is_valid()): raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items())) serializer.save() - cc_thread = serializer.object + cc_thread = serializer.instance thread_created.send(sender=None, user=user, post=cc_thread) api_thread = serializer.data _do_extra_actions(api_thread, cc_thread, thread_data.keys(), actions_form, context) @@ -606,7 +606,7 @@ def create_comment(request, comment_data): if not (serializer.is_valid() and actions_form.is_valid()): raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items())) serializer.save() - cc_comment = serializer.object + cc_comment = serializer.instance comment_created.send(sender=None, user=request.user, post=cc_comment) api_comment = serializer.data _do_extra_actions(api_comment, cc_comment, comment_data.keys(), actions_form, context) diff --git a/lms/djangoapps/discussion_api/pagination.py b/lms/djangoapps/discussion_api/pagination.py index 5f14b430e8..e818c5da19 100644 --- a/lms/djangoapps/discussion_api/pagination.py +++ b/lms/djangoapps/discussion_api/pagination.py @@ -1,16 +1,7 @@ """ Discussion API pagination support """ -from rest_framework.pagination import BasePaginationSerializer, NextPageField, PreviousPageField - - -class _PaginationSerializer(BasePaginationSerializer): - """ - A pagination serializer without the count field, because the Comments - Service does not return result counts - """ - next = NextPageField(source="*") - previous = PreviousPageField(source="*") +from rest_framework.utils.urls import replace_query_param class _Page(object): @@ -52,7 +43,25 @@ def get_paginated_data(request, results, page_num, per_page): previous: The URL for the previous page results: The results on this page """ - return _PaginationSerializer( - instance=_Page(results, page_num, per_page), - context={"request": request} - ).data + # Note: Previous versions of this function used Django Rest Framework's + # paginated serializer. With the upgrade to DRF 3.1, paginated serializers + # have been removed. We *could* use DRF's paginator classes, but there are + # some slight differences between how DRF does pagination and how we're doing + # pagination here. (For example, we respond with a next_url param even if + # there is only one result on the current page.) To maintain backwards + # compatability, we simulate the behavior that DRF used to provide. + page = _Page(results, page_num, per_page) + next_url, previous_url = None, None + base_url = request.build_absolute_uri() + + if page.has_next(): + next_url = replace_query_param(base_url, "page", page.next_page_number()) + + if page.has_previous(): + previous_url = replace_query_param(base_url, "page", page.previous_page_number()) + + return { + "next": next_url, + "previous": previous_url, + "results": results, + } diff --git a/lms/djangoapps/discussion_api/serializers.py b/lms/djangoapps/discussion_api/serializers.py index 56dc9495bb..e65716ea25 100644 --- a/lms/djangoapps/discussion_api/serializers.py +++ b/lms/djangoapps/discussion_api/serializers.py @@ -28,7 +28,6 @@ from lms.lib.comment_client.thread import Thread from lms.lib.comment_client.user import User as CommentClientUser from lms.lib.comment_client.utils import CommentClientRequestError from openedx.core.djangoapps.course_groups.cohorts import get_cohort_names -from openedx.core.lib.api.fields import NonEmptyCharField def get_context(course, request, thread=None): @@ -66,36 +65,43 @@ def get_context(course, request, thread=None): } +def validate_not_blank(value): + """ + Validate that a value is not an empty string or whitespace. + + Raises: ValidationError + """ + if not value.strip(): + raise ValidationError("This field may not be blank.") + + class _ContentSerializer(serializers.Serializer): """A base class for thread and comment serializers.""" - id_ = serializers.CharField(read_only=True) - author = serializers.SerializerMethodField("get_author") - author_label = serializers.SerializerMethodField("get_author_label") + id = serializers.CharField(read_only=True) # pylint: disable=invalid-name + author = serializers.SerializerMethodField() + author_label = serializers.SerializerMethodField() created_at = serializers.CharField(read_only=True) updated_at = serializers.CharField(read_only=True) - raw_body = NonEmptyCharField(source="body") - rendered_body = serializers.SerializerMethodField("get_rendered_body") - abuse_flagged = serializers.SerializerMethodField("get_abuse_flagged") - voted = serializers.SerializerMethodField("get_voted") - vote_count = serializers.SerializerMethodField("get_vote_count") - editable_fields = serializers.SerializerMethodField("get_editable_fields") + raw_body = serializers.CharField(source="body", validators=[validate_not_blank]) + rendered_body = serializers.SerializerMethodField() + abuse_flagged = serializers.SerializerMethodField() + voted = serializers.SerializerMethodField() + vote_count = serializers.SerializerMethodField() + editable_fields = serializers.SerializerMethodField() non_updatable_fields = set() def __init__(self, *args, **kwargs): super(_ContentSerializer, self).__init__(*args, **kwargs) - # id is an invalid class attribute name, so we must declare a different - # name above and modify it here - self.fields["id"] = self.fields.pop("id_") for field in self.non_updatable_fields: setattr(self, "validate_{}".format(field), self._validate_non_updatable) - def _validate_non_updatable(self, attrs, _source): + def _validate_non_updatable(self, value): """Ensure that a field is not edited in an update operation.""" - if self.object: + if self.instance: raise ValidationError("This field is not allowed in an update.") - return attrs + return value def _is_user_privileged(self, user_id): """ @@ -131,7 +137,11 @@ class _ContentSerializer(serializers.Serializer): def get_author_label(self, obj): """Returns the role label for the content author.""" - return None if self._is_anonymous(obj) else self._get_user_label(int(obj["user_id"])) + if self._is_anonymous(obj) or obj["user_id"] is None: + return None + else: + user_id = int(obj["user_id"]) + return self._get_user_label(user_id) def get_rendered_body(self, obj): """Returns the rendered body content.""" @@ -142,7 +152,7 @@ class _ContentSerializer(serializers.Serializer): Returns a boolean indicating whether the requester has flagged the content as abusive. """ - return self.context["cc_requester"]["id"] in obj["abuse_flaggers"] + return self.context["cc_requester"]["id"] in obj.get("abuse_flaggers", []) def get_voted(self, obj): """ @@ -153,7 +163,7 @@ class _ContentSerializer(serializers.Serializer): def get_vote_count(self, obj): """Returns the number of votes for the content.""" - return obj["votes"]["up_count"] + return obj.get("votes", {}).get("up_count", 0) def get_editable_fields(self, obj): """Return the list of the fields the requester can edit""" @@ -169,28 +179,28 @@ class ThreadSerializer(_ContentSerializer): at introspection and Thread's __getattr__. """ course_id = serializers.CharField() - topic_id = NonEmptyCharField(source="commentable_id") - group_id = serializers.IntegerField(required=False) - group_name = serializers.SerializerMethodField("get_group_name") - type_ = serializers.ChoiceField( + topic_id = serializers.CharField(source="commentable_id", validators=[validate_not_blank]) + group_id = serializers.IntegerField(required=False, allow_null=True) + group_name = serializers.SerializerMethodField() + type = serializers.ChoiceField( source="thread_type", choices=[(val, val) for val in ["discussion", "question"]] ) - title = NonEmptyCharField() - pinned = serializers.BooleanField(read_only=True) + title = serializers.CharField(validators=[validate_not_blank]) + pinned = serializers.SerializerMethodField(read_only=True) closed = serializers.BooleanField(read_only=True) - following = serializers.SerializerMethodField("get_following") + following = serializers.SerializerMethodField() comment_count = serializers.IntegerField(source="comments_count", read_only=True) unread_comment_count = serializers.IntegerField(source="unread_comments_count", read_only=True) - comment_list_url = serializers.SerializerMethodField("get_comment_list_url") - endorsed_comment_list_url = serializers.SerializerMethodField("get_endorsed_comment_list_url") - non_endorsed_comment_list_url = serializers.SerializerMethodField("get_non_endorsed_comment_list_url") + comment_list_url = serializers.SerializerMethodField() + endorsed_comment_list_url = serializers.SerializerMethodField() + non_endorsed_comment_list_url = serializers.SerializerMethodField() read = serializers.BooleanField(read_only=True) has_endorsed = serializers.BooleanField(read_only=True, source="endorsed") response_count = serializers.IntegerField(source="resp_total", read_only=True) non_updatable_fields = NON_UPDATABLE_THREAD_FIELDS - + # TODO: https://openedx.atlassian.net/browse/MA-1359 def __init__(self, *args, **kwargs): remove_fields = kwargs.pop('remove_fields', None) @@ -202,6 +212,13 @@ class ThreadSerializer(_ContentSerializer): # not have the pinned field set if self.object and self.object.get("pinned") is None: self.object["pinned"] = False + + def get_pinned(self, obj): + """ + Compensate for the fact that some threads in the comments service do + not have the pinned field set. + """ + return bool(obj["pinned"]) if remove_fields: # for multiple fields in a list @@ -245,13 +262,16 @@ class ThreadSerializer(_ContentSerializer): """Returns the URL to retrieve the thread's non-endorsed comments.""" return self.get_comment_list_url(obj, endorsed=False) - def restore_object(self, attrs, instance=None): - if instance: - for key, val in attrs.items(): - instance[key] = val - return instance - else: - return Thread(user_id=self.context["cc_requester"]["id"], **attrs) + def create(self, validated_data): + thread = Thread(user_id=self.context["cc_requester"]["id"], **validated_data) + thread.save() + return thread + + def update(self, instance, validated_data): + for key, val in validated_data.items(): + instance[key] = val + instance.save() + return instance class CommentSerializer(_ContentSerializer): @@ -263,12 +283,12 @@ class CommentSerializer(_ContentSerializer): at introspection and Comment's __getattr__. """ thread_id = serializers.CharField() - parent_id = serializers.CharField(required=False) + parent_id = serializers.CharField(required=False, allow_null=True) endorsed = serializers.BooleanField(required=False) - endorsed_by = serializers.SerializerMethodField("get_endorsed_by") - endorsed_by_label = serializers.SerializerMethodField("get_endorsed_by_label") - endorsed_at = serializers.SerializerMethodField("get_endorsed_at") - children = serializers.SerializerMethodField("get_children") + endorsed_by = serializers.SerializerMethodField() + endorsed_by_label = serializers.SerializerMethodField() + endorsed_at = serializers.SerializerMethodField() + children = serializers.SerializerMethodField() non_updatable_fields = NON_UPDATABLE_COMMENT_FIELDS @@ -311,6 +331,17 @@ class CommentSerializer(_ContentSerializer): for child in obj.get("children", []) ] + def to_representation(self, data): + data = super(CommentSerializer, self).to_representation(data) + + # Django Rest Framework v3 no longer includes None values + # in the representation. To maintain the previous behavior, + # we do this manually instead. + if 'parent_id' not in data: + data["parent_id"] = None + + return data + def validate(self, attrs): """ Ensure that parent_id identifies a comment that is actually in the @@ -332,18 +363,23 @@ class CommentSerializer(_ContentSerializer): raise ValidationError({"parent_id": ["Comment level is too deep."]}) return attrs - def restore_object(self, attrs, instance=None): - if instance: - for key, val in attrs.items(): - instance[key] = val - # TODO: The comments service doesn't populate the endorsement - # field on comment creation, so we only provide - # endorsement_user_id on update - if key == "endorsed": - instance["endorsement_user_id"] = self.context["cc_requester"]["id"] - return instance - return Comment( + def create(self, validated_data): + comment = Comment( course_id=self.context["thread"]["course_id"], user_id=self.context["cc_requester"]["id"], - **attrs + **validated_data ) + comment.save() + return comment + + def update(self, instance, validated_data): + for key, val in validated_data.items(): + instance[key] = val + # TODO: The comments service doesn't populate the endorsement + # field on comment creation, so we only provide + # endorsement_user_id on update + if key == "endorsed": + instance["endorsement_user_id"] = self.context["cc_requester"]["id"] + + instance.save() + return instance diff --git a/lms/djangoapps/discussion_api/tests/test_api.py b/lms/djangoapps/discussion_api/tests/test_api.py index a2df7abd36..e61a1202b6 100644 --- a/lms/djangoapps/discussion_api/tests/test_api.py +++ b/lms/djangoapps/discussion_api/tests/test_api.py @@ -1497,8 +1497,9 @@ class CreateThreadTest( self.assertEqual(actual_post_data["group_id"], [str(cohort.id)]) else: self.assertNotIn("group_id", actual_post_data) - except ValidationError: - self.assertTrue(expected_error) + except ValidationError as ex: + if not expected_error: + self.fail("Unexpected validation error: {}".format(ex)) def test_following(self): self.register_post_thread_response({"id": "test_id"}) @@ -2239,7 +2240,7 @@ class UpdateThreadTest( update_thread(self.request, "test_thread", {"raw_body": ""}) self.assertEqual( assertion.exception.message_dict, - {"raw_body": ["This field is required."]} + {"raw_body": ["This field may not be blank."]} ) diff --git a/lms/djangoapps/discussion_api/tests/test_serializers.py b/lms/djangoapps/discussion_api/tests/test_serializers.py index 985c94fcc5..365c7fbdca 100644 --- a/lms/djangoapps/discussion_api/tests/test_serializers.py +++ b/lms/djangoapps/discussion_api/tests/test_serializers.py @@ -523,9 +523,10 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi data = self.minimal_data.copy() data.update({field: value for field in ["topic_id", "title", "raw_body"]}) serializer = ThreadSerializer(data=data, context=get_context(self.course, self.request)) + self.assertFalse(serializer.is_valid()) self.assertEqual( serializer.errors, - {field: ["This field is required."] for field in ["topic_id", "title", "raw_body"]} + {field: ["This field may not be blank."] for field in ["topic_id", "title", "raw_body"]} ) def test_create_type(self): @@ -592,9 +593,10 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi partial=True, context=get_context(self.course, self.request) ) + self.assertFalse(serializer.is_valid()) self.assertEqual( serializer.errors, - {field: ["This field is required."] for field in ["topic_id", "title", "raw_body"]} + {field: ["This field may not be blank."] for field in ["topic_id", "title", "raw_body"]} ) def test_update_course_id(self): @@ -604,6 +606,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi partial=True, context=get_context(self.course, self.request) ) + self.assertFalse(serializer.is_valid()) self.assertEqual( serializer.errors, {"course_id": ["This field is not allowed in an update."]} @@ -769,7 +772,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul data["parent_id"] = None serializer = CommentSerializer(data=data, context=context) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {"parent_id": ["Comment level is too deep."]}) + self.assertEqual(serializer.errors, {"non_field_errors": ["Comment level is too deep."]}) def test_create_missing_field(self): for field in self.minimal_data: @@ -855,9 +858,10 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul partial=True, context=get_context(self.course, self.request) ) + self.assertFalse(serializer.is_valid()) self.assertEqual( serializer.errors, - {"raw_body": ["This field is required."]} + {"raw_body": ["This field may not be blank."]} ) @ddt.data("thread_id", "parent_id") @@ -868,6 +872,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul partial=True, context=get_context(self.course, self.request) ) + self.assertFalse(serializer.is_valid()) self.assertEqual( serializer.errors, {field: ["This field is not allowed in an update."]} diff --git a/lms/djangoapps/discussion_api/tests/test_views.py b/lms/djangoapps/discussion_api/tests/test_views.py index ce1df0a85a..0809df43e2 100644 --- a/lms/djangoapps/discussion_api/tests/test_views.py +++ b/lms/djangoapps/discussion_api/tests/test_views.py @@ -571,7 +571,7 @@ class ThreadViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTest content_type="application/json" ) expected_response_data = { - "field_errors": {"title": {"developer_message": "This field is required."}} + "field_errors": {"title": {"developer_message": "This field may not be blank."}} } self.assertEqual(response.status_code, 400) response_data = json.loads(response.content) @@ -690,7 +690,7 @@ class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): 200, { "results": expected_comments, - "next": "http://testserver/api/discussion/v1/comments/?thread_id={}&page=2".format( + "next": "http://testserver/api/discussion/v1/comments/?page=2&thread_id={}".format( self.thread_id ), "previous": None, @@ -946,7 +946,7 @@ class CommentViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTes content_type="application/json" ) expected_response_data = { - "field_errors": {"raw_body": {"developer_message": "This field is required."}} + "field_errors": {"raw_body": {"developer_message": "This field may not be blank."}} } self.assertEqual(response.status_code, 400) response_data = json.loads(response.content) diff --git a/lms/djangoapps/discussion_api/views.py b/lms/djangoapps/discussion_api/views.py index 2072e7b5db..116fc32301 100644 --- a/lms/djangoapps/discussion_api/views.py +++ b/lms/djangoapps/discussion_api/views.py @@ -3,7 +3,8 @@ Discussion API views """ from django.core.exceptions import ValidationError -from rest_framework.authentication import OAuth2Authentication, SessionAuthentication +from rest_framework.authentication import SessionAuthentication +from rest_framework_oauth.authentication import OAuth2Authentication from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView diff --git a/lms/djangoapps/django_comment_client/base/tests.py b/lms/djangoapps/django_comment_client/base/tests.py index 568b866ab4..3e0dfd8129 100644 --- a/lms/djangoapps/django_comment_client/base/tests.py +++ b/lms/djangoapps/django_comment_client/base/tests.py @@ -32,8 +32,6 @@ from xmodule.modulestore.tests.factories import check_mongo_calls from xmodule.modulestore.django import modulestore from xmodule.modulestore import ModuleStoreEnum -from teams.tests.factories import CourseTeamFactory - log = logging.getLogger(__name__) @@ -1290,6 +1288,7 @@ class TeamsPermissionsTestCase(UrlResetMixin, ModuleStoreTestCase, MockRequestSe topic_id='topic_id', discussion_topic_id=self.team_commentable_id ) + self.team.add_user(self.student_in_team) # Dummy commentable ID not linked to a team diff --git a/lms/djangoapps/django_comment_client/forum/tests.py b/lms/djangoapps/django_comment_client/forum/tests.py index 54b5bb3114..eabeba998a 100644 --- a/lms/djangoapps/django_comment_client/forum/tests.py +++ b/lms/djangoapps/django_comment_client/forum/tests.py @@ -717,6 +717,7 @@ class InlineDiscussionContextTestCase(ModuleStoreTestCase): topic_id='topic_id', discussion_topic_id=self.discussion_topic_id ) + self.team.add_user(self.user) # pylint: disable=no-member def test_context_can_be_standalone(self, mock_request): @@ -1093,7 +1094,9 @@ class InlineDiscussionTestCase(ModuleStoreTestCase): course_id=self.course.id, discussion_topic_id=self.discussion1.discussion_id ) + team.add_user(self.student) # pylint: disable=no-member + response = self.send_request(mock_request) self.assertEqual(mock_request.call_args[1]['params']['context'], ThreadContext.STANDALONE) self.verify_response(response) diff --git a/lms/djangoapps/django_comment_client/tests/group_id.py b/lms/djangoapps/django_comment_client/tests/group_id.py index 2773b5ee99..374dc534cd 100644 --- a/lms/djangoapps/django_comment_client/tests/group_id.py +++ b/lms/djangoapps/django_comment_client/tests/group_id.py @@ -149,6 +149,8 @@ class NonCohortedTopicGroupIdTestMixin(GroupIdAssertionMixin): def test_team_discussion_id_not_cohorted(self, mock_request): team = CourseTeamFactory(course_id=self.course.id) + team.add_user(self.student) # pylint: disable=no-member self.call_view(mock_request, team.discussion_topic_id, self.student, None) + self._assert_comments_service_called_without_group_id(mock_request) diff --git a/lms/djangoapps/mobile_api/social_facebook/courses/views.py b/lms/djangoapps/mobile_api/social_facebook/courses/views.py index 584e879db4..b4f1a23aeb 100644 --- a/lms/djangoapps/mobile_api/social_facebook/courses/views.py +++ b/lms/djangoapps/mobile_api/social_facebook/courses/views.py @@ -31,7 +31,7 @@ class CoursesWithFriends(generics.ListAPIView): serializer_class = serializers.CoursesWithFriendsSerializer def list(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.GET, files=request.FILES) + serializer = self.get_serializer(data=request.GET) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -61,4 +61,5 @@ class CoursesWithFriends(generics.ListAPIView): and is_mobile_available_for_user(self.request.user, enrollment.course) ] - return Response(CourseEnrollmentSerializer(courses, context={'request': request}).data) + serializer = CourseEnrollmentSerializer(courses, context={'request': request}, many=True) + return Response(serializer.data) diff --git a/lms/djangoapps/mobile_api/social_facebook/friends/views.py b/lms/djangoapps/mobile_api/social_facebook/friends/views.py index 07a8c862fe..7e1904491d 100644 --- a/lms/djangoapps/mobile_api/social_facebook/friends/views.py +++ b/lms/djangoapps/mobile_api/social_facebook/friends/views.py @@ -40,7 +40,7 @@ class FriendsInCourse(generics.ListAPIView): serializer_class = serializers.FriendsInCourseSerializer def list(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.GET, files=request.FILES) + serializer = self.get_serializer(data=request.GET) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/lms/djangoapps/mobile_api/social_facebook/groups/views.py b/lms/djangoapps/mobile_api/social_facebook/groups/views.py index 463e466aad..620a482776 100644 --- a/lms/djangoapps/mobile_api/social_facebook/groups/views.py +++ b/lms/djangoapps/mobile_api/social_facebook/groups/views.py @@ -45,7 +45,7 @@ class Groups(generics.CreateAPIView, mixins.DestroyModelMixin): serializer_class = serializers.GroupSerializer def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.DATA, files=request.FILES) + serializer = self.get_serializer(data=request.DATA) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) try: @@ -106,12 +106,12 @@ class GroupsMembers(generics.CreateAPIView, mixins.DestroyModelMixin): serializer_class = serializers.GroupsMembersSerializer def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.DATA, files=request.FILES) + serializer = self.get_serializer(data=request.DATA) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) graph = facebook_graph_api() url = settings.FACEBOOK_API_VERSION + '/' + kwargs['group_id'] + "/members" - member_ids = serializer.object['member_ids'].split(',') + member_ids = serializer.data['member_ids'].split(',') response = {} for member_id in member_ids: try: diff --git a/lms/djangoapps/mobile_api/social_facebook/preferences/serializers.py b/lms/djangoapps/mobile_api/social_facebook/preferences/serializers.py index 12eb03f0cb..1d51f87eda 100644 --- a/lms/djangoapps/mobile_api/social_facebook/preferences/serializers.py +++ b/lms/djangoapps/mobile_api/social_facebook/preferences/serializers.py @@ -8,4 +8,4 @@ class UserSharingSerializar(serializers.Serializer): """ Serializes user social settings """ - share_with_facebook_friends = serializers.BooleanField(required=True, default=False) + share_with_facebook_friends = serializers.BooleanField(required=True) diff --git a/lms/djangoapps/mobile_api/social_facebook/preferences/views.py b/lms/djangoapps/mobile_api/social_facebook/preferences/views.py index c48901e67d..7495ba30f2 100644 --- a/lms/djangoapps/mobile_api/social_facebook/preferences/views.py +++ b/lms/djangoapps/mobile_api/social_facebook/preferences/views.py @@ -39,9 +39,9 @@ class UserSharing(generics.ListCreateAPIView): serializer_class = serializers.UserSharingSerializar def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.DATA, files=request.FILES) + serializer = self.get_serializer(data=request.DATA) if serializer.is_valid(): - value = serializer.object['share_with_facebook_friends'] + value = serializer.data['share_with_facebook_friends'] set_user_preference(request.user, "share_with_facebook_friends", value) return self.get(request, *args, **kwargs) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/lms/djangoapps/mobile_api/social_facebook/utils.py b/lms/djangoapps/mobile_api/social_facebook/utils.py index daa379edc6..654dd97edd 100644 --- a/lms/djangoapps/mobile_api/social_facebook/utils.py +++ b/lms/djangoapps/mobile_api/social_facebook/utils.py @@ -40,7 +40,7 @@ def get_friends_from_facebook(serializer): the error message. """ try: - graph = facebook.GraphAPI(serializer.object['oauth_token']) + graph = facebook.GraphAPI(serializer.data['oauth_token']) friends = graph.request(settings.FACEBOOK_API_VERSION + "/me/friends") return get_pagination(friends) except facebook.GraphAPIError, ex: diff --git a/lms/djangoapps/mobile_api/users/serializers.py b/lms/djangoapps/mobile_api/users/serializers.py index 2be28e1ae3..cb2d0a2782 100644 --- a/lms/djangoapps/mobile_api/users/serializers.py +++ b/lms/djangoapps/mobile_api/users/serializers.py @@ -15,7 +15,7 @@ from xmodule.course_module import DEFAULT_START_DATE class CourseOverviewField(serializers.RelatedField): """Custom field to wrap a CourseDescriptor object. Read-only.""" - def to_native(self, course_overview): + def to_representation(self, course_overview): course_id = unicode(course_overview.id) request = self.context.get('request', None) if request: @@ -77,8 +77,8 @@ class CourseEnrollmentSerializer(serializers.ModelSerializer): """ Serializes CourseEnrollment models """ - course = CourseOverviewField(source="course_overview") - certificate = serializers.SerializerMethodField('get_certificate') + course = CourseOverviewField(source="course_overview", read_only=True) + certificate = serializers.SerializerMethodField() def get_certificate(self, model): """Returns the information about the user's certificate in the course.""" @@ -100,7 +100,7 @@ class UserSerializer(serializers.HyperlinkedModelSerializer): """ Serializes User models """ - name = serializers.Field(source='profile.name') + name = serializers.ReadOnlyField(source='profile.name') course_enrollments = serializers.HyperlinkedIdentityField( view_name='courseenrollment-detail', lookup_field='username' diff --git a/lms/djangoapps/mobile_api/users/views.py b/lms/djangoapps/mobile_api/users/views.py index c73084195b..95a5e6eb59 100644 --- a/lms/djangoapps/mobile_api/users/views.py +++ b/lms/djangoapps/mobile_api/users/views.py @@ -251,6 +251,14 @@ class UserCourseEnrollmentsList(generics.ListAPIView): serializer_class = CourseEnrollmentSerializer lookup_field = 'username' + # In Django Rest Framework v3, there is a default pagination + # class that transmutes the response data into a dictionary + # with pagination information. The original response data (a list) + # is stored in a "results" value of the dictionary. + # For backwards compatibility with the existing API, we disable + # the default behavior by setting the pagination_class to None. + pagination_class = None + def get_queryset(self): enrollments = self.queryset.filter( user__username=self.kwargs['username'], diff --git a/lms/djangoapps/notifier_api/serializers.py b/lms/djangoapps/notifier_api/serializers.py index 57832475d6..719cc9c6f1 100644 --- a/lms/djangoapps/notifier_api/serializers.py +++ b/lms/djangoapps/notifier_api/serializers.py @@ -23,9 +23,9 @@ class NotifierUserSerializer(serializers.ModelSerializer): * course_groups * roles__permissions """ - name = serializers.SerializerMethodField("get_name") - preferences = serializers.SerializerMethodField("get_preferences") - course_info = serializers.SerializerMethodField("get_course_info") + name = serializers.SerializerMethodField() + preferences = serializers.SerializerMethodField() + course_info = serializers.SerializerMethodField() def get_name(self, user): return user.profile.name diff --git a/lms/djangoapps/notifier_api/views.py b/lms/djangoapps/notifier_api/views.py index b8e05ebb16..44e1c78037 100644 --- a/lms/djangoapps/notifier_api/views.py +++ b/lms/djangoapps/notifier_api/views.py @@ -1,11 +1,32 @@ from django.contrib.auth.models import User from rest_framework.viewsets import ReadOnlyModelViewSet +from rest_framework.response import Response +from rest_framework import pagination from notification_prefs import NOTIFICATION_PREF_KEY from notifier_api.serializers import NotifierUserSerializer from openedx.core.lib.api.permissions import ApiKeyHeaderPermission +class NotifierPaginator(pagination.PageNumberPagination): + """ + Paginator for the notifier API. + """ + page_size = 10 + page_size_query_param = "page_size" + + def get_paginated_response(self, data): + """ + Construct a response with pagination information. + """ + return Response({ + 'next': self.get_next_link(), + 'previous': self.get_previous_link(), + 'count': self.page.paginator.count, + 'results': data + }) + + class NotifierUsersViewSet(ReadOnlyModelViewSet): """ An endpoint that the notifier can use to retrieve users who have enabled @@ -14,8 +35,7 @@ class NotifierUsersViewSet(ReadOnlyModelViewSet): """ permission_classes = (ApiKeyHeaderPermission,) serializer_class = NotifierUserSerializer - paginate_by = 10 - paginate_by_param = "page_size" + pagination_class = NotifierPaginator # See NotifierUserSerializer for notes about related tables queryset = User.objects.filter( diff --git a/lms/djangoapps/teams/models.py b/lms/djangoapps/teams/models.py index cbcb822554..18d38f4492 100644 --- a/lms/djangoapps/teams/models.py +++ b/lms/djangoapps/teams/models.py @@ -3,7 +3,6 @@ from datetime import datetime from uuid import uuid4 import pytz -from datetime import datetime from model_utils import FieldTracker from django.core.exceptions import ObjectDoesNotExist diff --git a/lms/djangoapps/teams/search_indexes.py b/lms/djangoapps/teams/search_indexes.py index 588a2c7892..6dfc070e81 100644 --- a/lms/djangoapps/teams/search_indexes.py +++ b/lms/djangoapps/teams/search_indexes.py @@ -10,6 +10,7 @@ from django.utils import translation from functools import wraps from search.search_engine_base import SearchEngine +from request_cache import get_request_or_stub from .errors import ElasticSearchConnectionError from .serializers import CourseTeamSerializer, CourseTeam @@ -47,7 +48,15 @@ class CourseTeamIndexer(object): Returns serialized object with additional search fields. """ - serialized_course_team = CourseTeamSerializer(self.course_team).data + # Django Rest Framework v3.1 requires that we pass the request to the serializer + # so it can construct hyperlinks. To avoid changing the interface of this object, + # we retrieve the request from the request cache. + context = { + "request": get_request_or_stub() + } + + serialized_course_team = CourseTeamSerializer(self.course_team, context=context).data + # Save the primary key so we can load the full objects easily after we search serialized_course_team['pk'] = self.course_team.pk # Don't save the membership relations in elasticsearch diff --git a/lms/djangoapps/teams/serializers.py b/lms/djangoapps/teams/serializers.py index f51ad9e61c..8befd944bc 100644 --- a/lms/djangoapps/teams/serializers.py +++ b/lms/djangoapps/teams/serializers.py @@ -4,15 +4,44 @@ from django.contrib.auth.models import User from django.db.models import Count from django.conf import settings +from django_countries import countries from rest_framework import serializers -from openedx.core.lib.api.serializers import CollapsedReferenceSerializer, PaginationSerializer +from openedx.core.lib.api.serializers import CollapsedReferenceSerializer from openedx.core.lib.api.fields import ExpandableField from openedx.core.djangoapps.user_api.accounts.serializers import UserReadOnlySerializer from .models import CourseTeam, CourseTeamMembership +class CountryField(serializers.Field): + """ + Field to serialize a country code. + """ + + COUNTRY_CODES = dict(countries).keys() + + def to_representation(self, obj): + """ + Represent the country as a 2-character unicode identifier. + """ + return unicode(obj) + + def to_internal_value(self, data): + """ + Check that the code is a valid country code. + + We leave the data in its original format so that the Django model's + CountryField can convert it to the internal representation used + by the django-countries library. + """ + if data and data not in self.COUNTRY_CODES: + raise serializers.ValidationError( + u"{code} is not a valid country code".format(code=data) + ) + return data + + class UserMembershipSerializer(serializers.ModelSerializer): """Serializes CourseTeamMemberships with only user and date_joined @@ -43,6 +72,7 @@ class CourseTeamSerializer(serializers.ModelSerializer): """Serializes a CourseTeam with membership information.""" id = serializers.CharField(source='team_id', read_only=True) # pylint: disable=invalid-name membership = UserMembershipSerializer(many=True, read_only=True) + country = CountryField() class Meta(object): """Defines meta information for the ModelSerializer.""" @@ -66,6 +96,8 @@ class CourseTeamSerializer(serializers.ModelSerializer): class CourseTeamCreationSerializer(serializers.ModelSerializer): """Deserializes a CourseTeam for creation.""" + country = CountryField(required=False) + class Meta(object): """Defines meta information for the ModelSerializer.""" model = CourseTeam @@ -78,16 +110,17 @@ class CourseTeamCreationSerializer(serializers.ModelSerializer): "language", ) - def restore_object(self, attrs, instance=None): - """Restores a CourseTeam instance from the given attrs.""" - return CourseTeam.create( - name=attrs.get("name", ''), - course_id=attrs.get("course_id"), - description=attrs.get("description", ''), - topic_id=attrs.get("topic_id", ''), - country=attrs.get("country", ''), - language=attrs.get("language", ''), + def create(self, validated_data): + team = CourseTeam.create( + name=validated_data.get("name", ''), + course_id=validated_data.get("course_id"), + description=validated_data.get("description", ''), + topic_id=validated_data.get("topic_id", ''), + country=validated_data.get("country", ''), + language=validated_data.get("language", ''), ) + team.save() + return team class CourseTeamSerializerWithoutMembership(CourseTeamSerializer): @@ -134,13 +167,6 @@ class MembershipSerializer(serializers.ModelSerializer): read_only_fields = ("date_joined", "last_activity_at") -class PaginatedMembershipSerializer(PaginationSerializer): - """Serializes team memberships with support for pagination.""" - class Meta(object): - """Defines meta information for the PaginatedMembershipSerializer.""" - object_serializer_class = MembershipSerializer - - class BaseTopicSerializer(serializers.Serializer): """Serializes a topic without team_count.""" description = serializers.CharField() @@ -155,7 +181,7 @@ class TopicSerializer(BaseTopicSerializer): model to get the count. Requires that `context` is provided with a valid course_id in order to filter teams within the course. """ - team_count = serializers.SerializerMethodField('get_team_count') + team_count = serializers.SerializerMethodField() def get_team_count(self, topic): """Get the number of teams associated with this topic""" @@ -166,31 +192,25 @@ class TopicSerializer(BaseTopicSerializer): return CourseTeam.objects.filter(course_id=self.context['course_id'], topic_id=topic['id']).count() -class PaginatedTopicSerializer(PaginationSerializer): +class BulkTeamCountTopicListSerializer(serializers.ListSerializer): # pylint: disable=abstract-method """ - Serializes a set of topics, adding the team_count field to each topic individually, if team_count - is not already present in the topic data. Requires that `context` is provided with a valid course_id in - order to filter teams within the course. + List serializer for efficiently serializing a set of topics. """ - class Meta(object): - """Defines meta information for the PaginatedTopicSerializer.""" - object_serializer_class = TopicSerializer + + def to_representation(self, obj): + """Adds team_count to each topic. """ + data = super(BulkTeamCountTopicListSerializer, self).to_representation(obj) + add_team_count(data, self.context["course_id"]) + return data -class BulkTeamCountPaginatedTopicSerializer(PaginationSerializer): +class BulkTeamCountTopicSerializer(BaseTopicSerializer): # pylint: disable=abstract-method """ - Serializes a set of topics, adding the team_count field to each topic as a bulk operation per page - (only on the page being returned). Requires that `context` is provided with a valid course_id in - order to filter teams within the course. + Serializes a set of topics, adding the team_count field to each topic as a bulk operation. + Requires that `context` is provided with a valid course_id in order to filter teams within the course. """ - class Meta(object): - """Defines meta information for the BulkTeamCountPaginatedTopicSerializer.""" - object_serializer_class = BaseTopicSerializer - - def __init__(self, *args, **kwargs): - """Adds team_count to each topic on the current page.""" - super(BulkTeamCountPaginatedTopicSerializer, self).__init__(*args, **kwargs) - add_team_count(self.data['results'], self.context['course_id']) + class Meta: # pylint: disable=missing-docstring,old-style-class + list_serializer_class = BulkTeamCountTopicListSerializer def add_team_count(topics, course_id): diff --git a/lms/djangoapps/teams/tests/factories.py b/lms/djangoapps/teams/tests/factories.py index ee58e1ad69..557e1a1c21 100644 --- a/lms/djangoapps/teams/tests/factories.py +++ b/lms/djangoapps/teams/tests/factories.py @@ -34,3 +34,10 @@ class CourseTeamMembershipFactory(DjangoModelFactory): class Meta(object): # pylint: disable=missing-docstring model = CourseTeamMembership last_activity_at = LAST_ACTIVITY_AT + + @classmethod + def _create(cls, model_class, *args, **kwargs): + """Create the team membership. """ + obj = model_class(*args, **kwargs) + obj.save() + return obj diff --git a/lms/djangoapps/teams/tests/test_serializers.py b/lms/djangoapps/teams/tests/test_serializers.py index 123b2f793b..c30334ad2e 100644 --- a/lms/djangoapps/teams/tests/test_serializers.py +++ b/lms/djangoapps/teams/tests/test_serializers.py @@ -11,12 +11,9 @@ from xmodule.modulestore.tests.factories import CourseFactory from lms.djangoapps.teams.tests.factories import CourseTeamFactory, CourseTeamMembershipFactory from lms.djangoapps.teams.serializers import ( - BaseTopicSerializer, - PaginatedTopicSerializer, - BulkTeamCountPaginatedTopicSerializer, + BulkTeamCountTopicSerializer, TopicSerializer, MembershipSerializer, - add_team_count ) @@ -73,21 +70,6 @@ class MembershipSerializerTestCase(SerializerTestCase): self.assertNotIn('membership', data['team']) -class BaseTopicSerializerTestCase(SerializerTestCase): - """ - Tests for the `BaseTopicSerializer`, which should not serialize team count - data. - """ - def test_team_count_not_included(self): - """Verifies that the `BaseTopicSerializer` does not include team count""" - with self.assertNumQueries(0): - serializer = BaseTopicSerializer(self.course.teams_topics[0]) - self.assertEqual( - serializer.data, - {u'name': u'Tøpic', u'description': u'The bést topic!', u'id': u'0'} - ) - - class TopicSerializerTestCase(SerializerTestCase): """ Tests for the `TopicSerializer`, which should serialize team count data for @@ -137,7 +119,7 @@ class TopicSerializerTestCase(SerializerTestCase): ) -class BasePaginatedTopicSerializerTestCase(SerializerTestCase): +class BaseTopicSerializerTestCase(SerializerTestCase): """ Base class for testing the two paginated topic serializers. """ @@ -191,13 +173,15 @@ class BasePaginatedTopicSerializerTestCase(SerializerTestCase): self.assert_serializer_output([], num_teams_per_topic=0, num_queries=0) -class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializerTestCase): +class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase): """ - Tests for the `BulkTeamCountPaginatedTopicSerializer`, which should serialize team_count + Tests for the `BulkTeamCountTopicSerializer`, which should serialize team_count data for many topics with constant time SQL queries. """ __test__ = True - serializer = BulkTeamCountPaginatedTopicSerializer + serializer = BulkTeamCountTopicSerializer + + NUM_TOPICS = 6 def test_topics_with_no_team_counts(self): """ @@ -222,13 +206,13 @@ class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializer one SQL query. """ teams_per_topic = 10 - topics = self.setup_topics(num_topics=self.PAGE_SIZE + 1, teams_per_topic=teams_per_topic) - self.assert_serializer_output(topics[:self.PAGE_SIZE], num_teams_per_topic=teams_per_topic, num_queries=1) + topics = self.setup_topics(num_topics=self.NUM_TOPICS, teams_per_topic=teams_per_topic) + self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=1) def test_scoped_within_course(self): """Verify that team counts are scoped within a course.""" teams_per_topic = 10 - first_course_topics = self.setup_topics(num_topics=self.PAGE_SIZE, teams_per_topic=teams_per_topic) + first_course_topics = self.setup_topics(num_topics=self.NUM_TOPICS, teams_per_topic=teams_per_topic) duplicate_topic = first_course_topics[0] second_course = CourseFactory.create( teams_configuration={ @@ -239,27 +223,44 @@ class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializer CourseTeamFactory.create(course_id=second_course.id, topic_id=duplicate_topic[u'id']) self.assert_serializer_output(first_course_topics, num_teams_per_topic=teams_per_topic, num_queries=1) + def _merge_dicts(self, first, second): + """Convenience method to merge two dicts in a single expression""" + result = first.copy() + result.update(second) + return result -class PaginatedTopicSerializerTestCase(BasePaginatedTopicSerializerTestCase): - """ - Tests for the `PaginatedTopicSerializer`, which will add team_count information per topic if not present. - """ - __test__ = True - serializer = PaginatedTopicSerializer + def setup_topics(self, num_topics=5, teams_per_topic=0): + """ + Helper method to set up topics on the course. Returns a list of + created topics. + """ + self.course.teams_configuration['topics'] = [] + topics = [ + {u'name': u'Tøpic {}'.format(i), u'description': u'The bést topic! {}'.format(i), u'id': unicode(i)} + for i in xrange(num_topics) + ] + for i in xrange(num_topics): + topic_id = unicode(i) + self.course.teams_configuration['topics'].append(topics[i]) + for _ in xrange(teams_per_topic): + CourseTeamFactory.create(course_id=self.course.id, topic_id=topic_id) + return topics - def test_topics_with_team_counts(self): + def assert_serializer_output(self, topics, num_teams_per_topic, num_queries): """ - Verify that we serialize topics with team_count, making one SQL query per topic. + Verify that the serializer produced the expected topics. """ - teams_per_topic = 2 - topics = self.setup_topics(teams_per_topic=teams_per_topic) - self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=5) + with self.assertNumQueries(num_queries): + serializer = self.serializer(topics, context={'course_id': self.course.id}, many=True) + self.assertEqual( + serializer.data, + [self._merge_dicts(topic, {u'team_count': num_teams_per_topic}) for topic in topics] + ) - def test_topics_with_team_counts_prepopulated(self): + def test_no_topics(self): """ - Verify that if team_count is pre-populated, there are no additional SQL queries. + Verify that we return no results and make no SQL queries for a page + with no topics. """ - teams_per_topic = 8 - topics = self.setup_topics(teams_per_topic=teams_per_topic) - add_team_count(topics, self.course.id) - self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=0) + self.course.teams_configuration['topics'] = [] + self.assert_serializer_output([], num_teams_per_topic=0, num_queries=0) diff --git a/lms/djangoapps/teams/tests/test_views.py b/lms/djangoapps/teams/tests/test_views.py index cc7fcc23ed..b84e63416f 100644 --- a/lms/djangoapps/teams/tests/test_views.py +++ b/lms/djangoapps/teams/tests/test_views.py @@ -1,32 +1,31 @@ # -*- coding: utf-8 -*- """Tests for the teams API at the HTTP request level.""" import json -import pytz from datetime import datetime + +import pytz from dateutil import parser import ddt from elasticsearch.exceptions import ConnectionError from mock import patch from search.search_engine_base import SearchEngine - from django.core.urlresolvers import reverse from django.conf import settings from django.db.models.signals import post_save from django.utils import translation from nose.plugins.attrib import attr from rest_framework.test import APITestCase, APIClient +from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase +from xmodule.modulestore.tests.factories import CourseFactory from courseware.tests.factories import StaffFactory from common.test.utils import skip_signal from student.tests.factories import UserFactory, AdminFactory, CourseEnrollmentFactory from student.models import CourseEnrollment from util.testing import EventTestMixin -from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase -from xmodule.modulestore.tests.factories import CourseFactory from .factories import CourseTeamFactory, LAST_ACTIVITY_AT from ..models import CourseTeamMembership from ..search_indexes import CourseTeamIndexer, CourseTeam, course_team_post_save_callback - from django_comment_common.models import Role, FORUM_ROLE_COMMUNITY_TA from django_comment_common.utils import seed_permissions_roles @@ -36,11 +35,23 @@ class TestDashboard(SharedModuleStoreTestCase): """Tests for the Teams dashboard.""" test_password = "test" + NUM_TOPICS = 10 + @classmethod def setUpClass(cls): super(TestDashboard, cls).setUpClass() cls.course = CourseFactory.create( - teams_configuration={"max_team_size": 10, "topics": [{"name": "foo", "id": 0, "description": "test topic"}]} + teams_configuration={ + "max_team_size": 10, + "topics": [ + { + "name": "Topic {}".format(topic_id), + "id": topic_id, + "description": "Description for topic {}".format(topic_id) + } + for topic_id in range(cls.NUM_TOPICS) + ] + } ) def setUp(self): @@ -97,6 +108,30 @@ class TestDashboard(SharedModuleStoreTestCase): response = self.client.get(teams_url) self.assertEqual(404, response.status_code) + def test_query_counts(self): + # Enroll in the course and log in + CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) + self.client.login(username=self.user.username, password=self.test_password) + + # Check the query count on the dashboard With no teams + with self.assertNumQueries(15): + self.client.get(self.teams_url) + + # Create some teams + for topic_id in range(self.NUM_TOPICS): + team = CourseTeamFactory.create( + name=u"Team for topic {}".format(topic_id), + course_id=self.course.id, + topic_id=topic_id, + ) + + # Add the user to the last team + team.add_user(self.user) + + # Check the query count on the dashboard again + with self.assertNumQueries(19): + self.client.get(self.teams_url) + def test_bad_course_id(self): """ Verifies expected behavior when course_id does not reference an existing course or is invalid. @@ -252,6 +287,9 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase): self.users[user], course.id, check_access=True ) + # Django Rest Framework v3 requires us to pass a request to serializers + # that have URL fields. Since we're invoking this code outside the context + # of a request, we need to simulate that there's a request. self.solar_team.add_user(self.users['student_enrolled']) self.nuclear_team.add_user(self.users['student_enrolled_both_courses_other_team']) self.another_team.add_user(self.users['student_enrolled_both_courses_other_team']) @@ -311,7 +349,17 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase): response = func(url, data=data, content_type=content_type) else: response = func(url, data=data) - self.assertEqual(expected_status, response.status_code) + + self.assertEqual( + expected_status, + response.status_code, + msg="Expected status {expected} but got {actual}: {content}".format( + expected=expected_status, + actual=response.status_code, + content=response.content, + ) + ) + if expected_status == 200: return json.loads(response.content) else: diff --git a/lms/djangoapps/teams/views.py b/lms/djangoapps/teams/views.py index 6a6af4e5d9..1b983e3e9c 100644 --- a/lms/djangoapps/teams/views.py +++ b/lms/djangoapps/teams/views.py @@ -5,19 +5,14 @@ import logging from django.shortcuts import get_object_or_404, render_to_response from django.http import Http404 from django.conf import settings -from django.core.paginator import Paginator -from django.views.generic.base import View from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.views import APIView -from rest_framework.authentication import ( - SessionAuthentication, - OAuth2Authentication -) +from rest_framework.authentication import SessionAuthentication +from rest_framework_oauth.authentication import OAuth2Authentication from rest_framework import status from rest_framework import permissions -from django.db.models import Count from django.db.models.signals import post_save from django.dispatch import receiver from django.contrib.auth.models import User @@ -32,8 +27,7 @@ from openedx.core.lib.api.view_utils import ( build_api_error, ExpandableFieldViewMixin ) -from openedx.core.lib.api.serializers import PaginationSerializer -from openedx.core.lib.api.paginators import paginate_search_results +from openedx.core.lib.api.paginators import paginate_search_results, DefaultPagination from xmodule.modulestore.django import modulestore from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey @@ -49,10 +43,8 @@ from .serializers import ( CourseTeamSerializer, CourseTeamCreationSerializer, TopicSerializer, - PaginatedTopicSerializer, - BulkTeamCountPaginatedTopicSerializer, + BulkTeamCountTopicSerializer, MembershipSerializer, - PaginatedMembershipSerializer, add_team_count ) from .search_indexes import CourseTeamIndexer @@ -85,7 +77,42 @@ def team_post_save_callback(sender, instance, **kwargs): # pylint: disable=unus ) -class TeamsDashboardView(View): +class TeamAPIPagination(DefaultPagination): + """ + Pagination format used by the teams API. + """ + page_size_query_param = "page_size" + + def get_paginated_response(self, data): + """ + Annotate the response with pagination information. + """ + response = super(TeamAPIPagination, self).get_paginated_response(data) + + # Add the current page to the response. + # It may make sense to eventually move this field into the default + # implementation, but for now, teams is the only API that uses this. + response.data["current_page"] = self.page.number + + # This field can be derived from other fields in the response, + # so it may make sense to have the JavaScript client calculate it + # instead of including it in the response. + response.data["start"] = (self.page.number - 1) * self.get_page_size(self.request) + + return response + + +class TopicsPagination(TeamAPIPagination): + """Paginate topics. """ + page_size = TOPICS_PER_PAGE + + +class MembershipPagination(TeamAPIPagination): + """Paginate memberships. """ + page_size = TEAM_MEMBERSHIPS_PER_PAGE + + +class TeamsDashboardView(GenericAPIView): """ View methods related to the teams dashboard. """ @@ -107,29 +134,38 @@ class TeamsDashboardView(View): not has_access(request.user, 'staff', course, course.id): raise Http404 + user = request.user + # Even though sorting is done outside of the serializer, sort_order needs to be passed # to the serializer so that the paginated results indicate how they were sorted. sort_order = 'name' topics = get_alphabetical_topics(course) - topics_page = Paginator(topics, TOPICS_PER_PAGE).page(1) + + # Paginate and serialize topic data # BulkTeamCountPaginatedTopicSerializer will add team counts to the topics in a single # bulk operation per page. - topics_serializer = BulkTeamCountPaginatedTopicSerializer( - instance=topics_page, - context={'course_id': course.id, 'sort_order': sort_order} + topics_data = self._serialize_and_paginate( + TopicsPagination, + topics, + request, + BulkTeamCountTopicSerializer, + {'course_id': course.id}, ) - user = request.user + topics_data["sort_order"] = sort_order - team_memberships = CourseTeamMembership.get_memberships(request.user.username, [course.id]) - team_memberships_page = Paginator(team_memberships, TEAM_MEMBERSHIPS_PER_PAGE).page(1) - team_memberships_serializer = PaginatedMembershipSerializer( - instance=team_memberships_page, - context={'expand': ('team', 'user'), 'request': request}, + # Paginate and serialize team membership data. + team_memberships = CourseTeamMembership.get_memberships(user.username, [course.id]) + memberships_data = self._serialize_and_paginate( + MembershipPagination, + team_memberships, + request, + MembershipSerializer, + {'expand': ('team', 'user',)} ) context = { "course": course, - "topics": topics_serializer.data, + "topics": topics_data, # It is necessary to pass both privileged and staff because only privileged users can # administer discussion threads, but both privileged and staff users are allowed to create # multiple teams (since they are not automatically added to teams upon creation). @@ -137,7 +173,7 @@ class TeamsDashboardView(View): "username": user.username, "privileged": has_discussion_privileges(user, course_key), "staff": bool(has_access(user, 'staff', course_key)), - "team_memberships_data": team_memberships_serializer.data, + "team_memberships_data": memberships_data, }, "topic_url": reverse( 'topics_detail', kwargs={'topic_id': 'topic_id', 'course_id': str(course_id)}, request=request @@ -154,6 +190,39 @@ class TeamsDashboardView(View): } return render_to_response("teams/teams.html", context) + def _serialize_and_paginate(self, pagination_cls, queryset, request, serializer_cls, serializer_ctx): + """ + Serialize and paginate objects in a queryset. + + Arguments: + pagination_cls (pagination.Paginator class): Django Rest Framework Paginator subclass. + queryset (QuerySet): Django queryset to serialize/paginate. + serializer_cls (serializers.Serializer class): Django Rest Framework Serializer subclass. + serializer_ctx (dict): Context dictionary to pass to the serializer + + Returns: dict + + """ + # Django Rest Framework v3 requires that we pass the request + # into the serializer's context if the serialize contains + # hyperlink fields. + serializer_ctx["request"] = request + + # Instantiate the paginator and use it to paginate the queryset + paginator = pagination_cls() + page = paginator.paginate_queryset(queryset, request) + + # Serialize the page + serializer = serializer_cls(page, context=serializer_ctx, many=True) + + # Use the paginator to construct the response data + # This will use the pagination subclass for the view to add additional + # fields to the response. + # For example, if the input data is a list, the output data would + # be a dictionary with keys "count", "next", "previous", and "results" + # (where "results" is set to the value of the original list) + return paginator.get_paginated_response(serializer.data).data + def has_team_api_access(user, course_key, access_username=None): """Returns True if the user has access to the Team API for the course @@ -307,11 +376,8 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView): # OAuth2Authentication must come first to return a 401 for unauthenticated users authentication_classes = (OAuth2Authentication, SessionAuthentication) permission_classes = (permissions.IsAuthenticated,) - - paginate_by = 10 - paginate_by_param = 'page_size' - pagination_serializer_class = PaginationSerializer serializer_class = CourseTeamSerializer + pagination_class = TeamAPIPagination def get(self, request): """GET /api/team/v0/teams/""" @@ -377,15 +443,18 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView): paginated_results = paginate_search_results( CourseTeam, search_results, - self.get_paginate_by(), + self.paginator.get_page_size(request), self.get_page() ) - serializer = self.get_pagination_serializer(paginated_results) emit_team_event('edx.team.searched', course_key, { "number_of_results": search_results['total'], "search_text": text_search, "topic_id": topic_id, }) + + page = self.paginate_queryset(paginated_results) + serializer = self.get_serializer(page, many=True) + order_by_input = None else: queryset = CourseTeam.objects.filter(**result_filter) order_by_input = request.QUERY_PARAMS.get('order_by', 'name') @@ -407,10 +476,12 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView): }, status=status.HTTP_400_BAD_REQUEST) page = self.paginate_queryset(queryset) - serializer = self.get_pagination_serializer(page) - serializer.context.update({'sort_order': order_by_input}) # pylint: disable=maybe-no-member + serializer = self.get_serializer(page, many=True) - return Response(serializer.data) # pylint: disable=maybe-no-member + response = self.get_paginated_response(serializer.data) + if order_by_input is not None: + response.data['sort_order'] = order_by_input + return response def post(self, request): """POST /api/team/v0/teams/""" @@ -473,14 +544,16 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView): 'add_method': 'added_on_create' } ) - return Response(CourseTeamSerializer(team).data) + + data = CourseTeamSerializer(team, context={"request": request}).data + return Response(data) def get_page(self): """ Returns page number specified in args, params, or defaults to 1. """ # This code is taken from within the GenericAPIView#paginate_queryset method. # We need need access to the page outside of that method for our paginate_search_results method - page_kwarg = self.kwargs.get(self.page_kwarg) - page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) + page_kwarg = self.kwargs.get(self.paginator.page_query_param) + page_query_param = self.request.QUERY_PARAMS.get(self.paginator.page_query_param) return page_kwarg or page_query_param or 1 @@ -692,9 +765,7 @@ class TopicListView(GenericAPIView): authentication_classes = (OAuth2Authentication, SessionAuthentication) permission_classes = (permissions.IsAuthenticated,) - - paginate_by = TOPICS_PER_PAGE - paginate_by_param = 'page_size' + pagination_class = TopicsPagination def get(self, request): """GET /api/team/v0/topics/?course_id={course_id}""" @@ -740,18 +811,20 @@ class TopicListView(GenericAPIView): add_team_count(topics, course_id) topics.sort(key=lambda t: t['team_count'], reverse=True) page = self.paginate_queryset(topics) - # Since team_count has already been added to all the topics, use PaginatedTopicSerializer. - # Even though sorting is done outside of the serializer, sort_order needs to be passed - # to the serializer so that the paginated results indicate how they were sorted. - serializer = PaginatedTopicSerializer(page, context={'course_id': course_id, 'sort_order': ordering}) + serializer = TopicSerializer( + page, + context={'course_id': course_id}, + many=True, + ) else: page = self.paginate_queryset(topics) # Use the serializer that adds team_count in a bulk operation per page. - serializer = BulkTeamCountPaginatedTopicSerializer( - page, context={'course_id': course_id, 'sort_order': ordering} - ) + serializer = BulkTeamCountTopicSerializer(page, context={'course_id': course_id}, many=True) - return Response(serializer.data) + response = self.get_paginated_response(serializer.data) + response.data['sort_order'] = ordering + + return response def get_alphabetical_topics(course_module): @@ -960,13 +1033,8 @@ class MembershipListView(ExpandableFieldViewMixin, GenericAPIView): authentication_classes = (OAuth2Authentication, SessionAuthentication) permission_classes = (permissions.IsAuthenticated,) - serializer_class = MembershipSerializer - paginate_by = 10 - paginate_by_param = 'page_size' - pagination_serializer_class = PaginationSerializer - def get(self, request): """GET /api/team/v0/team_membership""" specified_username_or_team = False @@ -1023,8 +1091,8 @@ class MembershipListView(ExpandableFieldViewMixin, GenericAPIView): queryset = CourseTeamMembership.get_memberships(username, course_keys, team_id) page = self.paginate_queryset(queryset) - serializer = self.get_pagination_serializer(page) - return Response(serializer.data) # pylint: disable=maybe-no-member + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) def post(self, request): """POST /api/team/v0/team_membership""" diff --git a/lms/envs/common.py b/lms/envs/common.py index 3040639818..4ddb70cf3f 100644 --- a/lms/envs/common.py +++ b/lms/envs/common.py @@ -1866,6 +1866,12 @@ INSTALLED_APPS = ( 'provider.oauth2', 'oauth2_provider', + # We don't use this directly (since we use OAuth2), but we need to install it anyway. + # When a user is deleted, Django queries all tables with a FK to the auth_user table, + # and since django-rest-framework-oauth imports this, it will try to access tables + # defined by oauth_provider. If those tables don't exist, an error can occur. + 'oauth_provider', + 'auth_exchange', # For the wiki @@ -1981,6 +1987,14 @@ INSTALLED_APPS = ( CSRF_COOKIE_AGE = 60 * 60 * 24 * 7 * 52 +######################### Django Rest Framework ######################## + +REST_FRAMEWORK = { + 'DEFAULT_PAGINATION_CLASS': 'openedx.core.lib.api.paginators.DefaultPagination', + 'PAGE_SIZE': 10, +} + + ######################### MARKETING SITE ############################### EDXMKTG_LOGGED_IN_COOKIE_NAME = 'edxloggedin' EDXMKTG_USER_INFO_COOKIE_NAME = 'edx-user-info' diff --git a/lms/startup.py b/lms/startup.py index 6869f90828..4db5f6c649 100644 --- a/lms/startup.py +++ b/lms/startup.py @@ -9,16 +9,12 @@ from django.conf import settings # Force settings to run so that the python path is modified settings.INSTALLED_APPS # pylint: disable=pointless-statement -from instructor.services import InstructorService - from openedx.core.lib.django_startup import autostartup import edxmako import logging from monkey_patch import django_utils_translation import analytics -from edx_proctoring.runtime import set_runtime_service -from openedx.core.djangoapps.credit.services import CreditService log = logging.getLogger(__name__) @@ -51,6 +47,11 @@ def run(): # right now edx_proctoring is dependent on the openedx.core.djangoapps.credit # as well as the instructor dashboard (for deleting student attempts) if settings.FEATURES.get('ENABLE_PROCTORED_EXAMS'): + # Import these here to avoid circular dependencies of the form: + # edx-platform app --> DRF --> django translation --> edx-platform app + from edx_proctoring.runtime import set_runtime_service + from instructor.services import InstructorService + from openedx.core.djangoapps.credit.services import CreditService set_runtime_service('credit', CreditService()) set_runtime_service('instructor', InstructorService()) diff --git a/openedx/core/djangoapps/content/course_structures/api/v0/api.py b/openedx/core/djangoapps/content/course_structures/api/v0/api.py index ac46603eff..016b613782 100644 --- a/openedx/core/djangoapps/content/course_structures/api/v0/api.py +++ b/openedx/core/djangoapps/content/course_structures/api/v0/api.py @@ -124,4 +124,4 @@ def course_grading_policy(course_key): final grade. """ course = _retrieve_course(course_key) - return GradingPolicySerializer(course.raw_grader).data + return GradingPolicySerializer(course.raw_grader, many=True).data diff --git a/openedx/core/djangoapps/content/course_structures/api/v0/serializers.py b/openedx/core/djangoapps/content/course_structures/api/v0/serializers.py index f440e6cd4a..881be80437 100644 --- a/openedx/core/djangoapps/content/course_structures/api/v0/serializers.py +++ b/openedx/core/djangoapps/content/course_structures/api/v0/serializers.py @@ -1,6 +1,8 @@ """ API Serializers """ +from collections import defaultdict + from rest_framework import serializers @@ -11,23 +13,58 @@ class GradingPolicySerializer(serializers.Serializer): dropped = serializers.IntegerField(source='drop_count') weight = serializers.FloatField() + def to_representation(self, obj): + """ + Return a representation of the grading policy. + """ + # Backwards compatibility with the behavior of DRF v2. + # When the grader dictionary was missing keys, DRF v2 would default to None; + # DRF v3 unhelpfully raises an exception. + return dict( + super(GradingPolicySerializer, self).to_representation( + defaultdict(lambda: None, obj) + ) + ) + # pylint: disable=invalid-name class BlockSerializer(serializers.Serializer): """ Serializer for course structure block. """ id = serializers.CharField(source='usage_key') type = serializers.CharField(source='block_type') - parent = serializers.CharField(source='parent') + parent = serializers.CharField(required=False) display_name = serializers.CharField() graded = serializers.BooleanField(default=False) format = serializers.CharField() children = serializers.CharField() + def to_representation(self, obj): + """ + Return a representation of the block. + + NOTE: this method maintains backwards compatibility with the behavior + of Django Rest Framework v2. + """ + data = super(BlockSerializer, self).to_representation(obj) + + # Backwards compatibility with the behavior of DRF v2 + # Include a NULL value for "parent" in the representation + # (instead of excluding the key entirely) + if obj.get("parent") is None: + data["parent"] = None + + # Backwards compatibility with the behavior of DRF v2 + # Leave the children list as a list instead of serializing + # it to a string. + data["children"] = obj["children"] + + return data + class CourseStructureSerializer(serializers.Serializer): """ Serializer for course structure. """ - root = serializers.CharField(source='root') - blocks = serializers.SerializerMethodField('get_blocks') + root = serializers.CharField() + blocks = serializers.SerializerMethodField() def get_blocks(self, structure): """ Serialize the individual blocks. """ diff --git a/openedx/core/djangoapps/credit/serializers.py b/openedx/core/djangoapps/credit/serializers.py index d808233848..4e43a139fc 100644 --- a/openedx/core/djangoapps/credit/serializers.py +++ b/openedx/core/djangoapps/credit/serializers.py @@ -2,12 +2,33 @@ from rest_framework import serializers +from opaque_keys.edx.keys import CourseKey +from opaque_keys import InvalidKeyError from openedx.core.djangoapps.credit.models import CreditCourse +class CourseKeyField(serializers.Field): + """ + Serializer field for a model CourseKey field. + """ + + def to_representation(self, data): + """Convert a course key to unicode. """ + return unicode(data) + + def to_internal_value(self, data): + """Convert unicode to a course key. """ + try: + return CourseKey.from_string(data) + except InvalidKeyError as ex: + raise serializers.ValidationError("Invalid course key: {msg}".format(msg=ex.msg)) + + class CreditCourseSerializer(serializers.ModelSerializer): """ CreditCourse Serializer """ + course_key = CourseKeyField() + class Meta(object): # pylint: disable=missing-docstring model = CreditCourse exclude = ('id',) diff --git a/openedx/core/djangoapps/credit/tests/test_views.py b/openedx/core/djangoapps/credit/tests/test_views.py index e0ae94152b..fc95c762b6 100644 --- a/openedx/core/djangoapps/credit/tests/test_views.py +++ b/openedx/core/djangoapps/credit/tests/test_views.py @@ -393,10 +393,7 @@ class CreditCourseViewSetTests(TestCase): # POSTs without a CSRF token should fail. response = client.post(self.path, data=json.dumps(data), content_type=JSON) - - # NOTE (CCB): Ordinarily we would expect a 403; however, since the CSRF validation and session authentication - # fail, DRF considers the request to be unauthenticated. - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, 403) self.assertIn('CSRF', response.content) # Retrieve a CSRF token diff --git a/openedx/core/djangoapps/credit/views.py b/openedx/core/djangoapps/credit/views.py index 5c179fa242..21bec85ad4 100644 --- a/openedx/core/djangoapps/credit/views.py +++ b/openedx/core/djangoapps/credit/views.py @@ -18,8 +18,9 @@ from django.views.decorators.http import require_POST, require_GET from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey import pytz -from rest_framework import viewsets, mixins, permissions, authentication - +from rest_framework import viewsets, mixins, permissions +from rest_framework.authentication import SessionAuthentication +from rest_framework_oauth.authentication import OAuth2Authentication from util.json_request import JsonResponse from util.date_utils import from_timestamp from openedx.core.djangoapps.credit import api @@ -377,17 +378,28 @@ class CreditCourseViewSet(mixins.CreateModelMixin, mixins.UpdateModelMixin, view lookup_value_regex = settings.COURSE_KEY_REGEX queryset = CreditCourse.objects.all() serializer_class = CreditCourseSerializer - authentication_classes = (authentication.OAuth2Authentication, authentication.SessionAuthentication,) + authentication_classes = (OAuth2Authentication, SessionAuthentication,) permission_classes = (permissions.IsAuthenticated, permissions.IsAdminUser) + # In Django Rest Framework v3, there is a default pagination + # class that transmutes the response data into a dictionary + # with pagination information. The original response data (a list) + # is stored in a "results" value of the dictionary. + # For backwards compatibility with the existing API, we disable + # the default behavior by setting the pagination_class to None. + pagination_class = None + # This CSRF exemption only applies when authenticating without SessionAuthentication. # SessionAuthentication will enforce CSRF protection. @method_decorator(csrf_exempt) def dispatch(self, request, *args, **kwargs): - # Convert the course ID/key from a string to an actual CourseKey object. - course_id = kwargs.get(self.lookup_field, None) - - if course_id: - kwargs[self.lookup_field] = CourseKey.from_string(course_id) - return super(CreditCourseViewSet, self).dispatch(request, *args, **kwargs) + + def get_object(self): + # Convert the serialized course key into a CourseKey instance + # so we can look up the object. + course_key = self.kwargs.get(self.lookup_field) + if course_key is not None: + self.kwargs[self.lookup_field] = CourseKey.from_string(course_key) + + return super(CreditCourseViewSet, self).get_object() diff --git a/openedx/core/djangoapps/profile_images/tests/test_views.py b/openedx/core/djangoapps/profile_images/tests/test_views.py index d1d2901b4c..4e93a1ad25 100644 --- a/openedx/core/djangoapps/profile_images/tests/test_views.py +++ b/openedx/core/djangoapps/profile_images/tests/test_views.py @@ -30,6 +30,40 @@ TEST_UPLOAD_DT = datetime.datetime(2002, 1, 9, 15, 43, 01, tzinfo=UTC) TEST_UPLOAD_DT2 = datetime.datetime(2003, 1, 9, 15, 43, 01, tzinfo=UTC) +class PatchedClient(APIClient): + """ + Patch DRF's APIClient to avoid a unicode error on file upload. + + Famous last words: This is a *temporary* fix that we should be + able to remove once we upgrade Django past 1.4. + """ + + def request(self, *args, **kwargs): + """Construct an API request. """ + # DRF's default test client implementation uses `six.text_type()` + # to convert the CONTENT_TYPE to `unicode`. In Django 1.4, + # this causes a `UnicodeDecodeError` when Django parses a multipart + # upload. + # + # This is the DRF code we're working around: + # https://github.com/tomchristie/django-rest-framework/blob/3.1.3/rest_framework/compat.py#L227 + # + # ... and this is the Django code that raises the exception: + # + # https://github.com/django/django/blob/1.4.22/django/http/multipartparser.py#L435 + # + # Django unhelpfully swallows the exception, so to the application code + # it appears as though the user didn't send any file data. + # + # This appears to be an issue only with requests constructed in the test + # suite, not with the upload code used in production. + # + if isinstance(kwargs.get("CONTENT_TYPE"), basestring): + kwargs["CONTENT_TYPE"] = str(kwargs["CONTENT_TYPE"]) + + return super(PatchedClient, self).request(*args, **kwargs) + + class ProfileImageEndpointTestCase(UserSettingsEventTestMixin, APITestCase): """ Base class / shared infrastructure for tests of profile_image "upload" and @@ -111,6 +145,10 @@ class ProfileImageUploadTestCase(ProfileImageEndpointTestCase): """ _view_name = "profile_image_upload" + # Use the patched version of the API client to workaround a unicode issue + # with DRF 3.1 and Django 1.4. Remove this after we upgrade Django past 1.4! + client_class = PatchedClient + def check_upload_event_emitted(self, old=None, new=TEST_UPLOAD_DT): """ Make sure we emit a UserProfile event corresponding to the diff --git a/openedx/core/djangoapps/user_api/accounts/api.py b/openedx/core/djangoapps/user_api/accounts/api.py index 532c139a7a..e7e716a3b5 100644 --- a/openedx/core/djangoapps/user_api/accounts/api.py +++ b/openedx/core/djangoapps/user_api/accounts/api.py @@ -183,7 +183,7 @@ def update_account_settings(requesting_user, update, username=None): serializer.save() if "language_proficiencies" in update: - new_language_proficiencies = legacy_profile_serializer.data["language_proficiencies"] + new_language_proficiencies = update["language_proficiencies"] emit_setting_changed_event( user=existing_user, db_table=existing_user_profile.language_proficiencies.model._meta.db_table, diff --git a/openedx/core/djangoapps/user_api/accounts/serializers.py b/openedx/core/djangoapps/user_api/accounts/serializers.py index 180f8ea208..df4d37da2c 100644 --- a/openedx/core/djangoapps/user_api/accounts/serializers.py +++ b/openedx/core/djangoapps/user_api/accounts/serializers.py @@ -53,7 +53,7 @@ class UserReadOnlySerializer(serializers.Serializer): super(UserReadOnlySerializer, self).__init__(*args, **kwargs) - def to_native(self, user): + def to_representation(self, user): """ Overwrite to_native to handle custom logic since we are serializing two models as one here :param user: User object @@ -152,8 +152,8 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea Class that serializes the portion of UserProfile model needed for account information. """ profile_image = serializers.SerializerMethodField("_get_profile_image") - requires_parental_consent = serializers.SerializerMethodField("get_requires_parental_consent") - language_proficiencies = LanguageProficiencySerializer(many=True, allow_add_remove=True, required=False) + requires_parental_consent = serializers.SerializerMethodField() + language_proficiencies = LanguageProficiencySerializer(many=True, required=False) class Meta(object): # pylint: disable=missing-docstring model = UserProfile @@ -165,25 +165,21 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea read_only_fields = () explicit_read_only_fields = ("profile_image", "requires_parental_consent") - def validate_name(self, attrs, source): + def validate_name(self, new_name): """ Enforce minimum length for name. """ - if source in attrs: - new_name = attrs[source].strip() - if len(new_name) < NAME_MIN_LENGTH: - raise serializers.ValidationError( - "The name field must be at least {} characters long.".format(NAME_MIN_LENGTH) - ) - attrs[source] = new_name + if len(new_name) < NAME_MIN_LENGTH: + raise serializers.ValidationError( + "The name field must be at least {} characters long.".format(NAME_MIN_LENGTH) + ) + return new_name - return attrs - - def validate_language_proficiencies(self, attrs, source): + def validate_language_proficiencies(self, value): """ Enforce all languages are unique. """ - language_proficiencies = [language for language in attrs.get(source, [])] - unique_language_proficiencies = set(language.code for language in language_proficiencies) + language_proficiencies = [language for language in value] + unique_language_proficiencies = set(language["code"] for language in language_proficiencies) if len(language_proficiencies) != len(unique_language_proficiencies): raise serializers.ValidationError("The language_proficiencies field must consist of unique languages") - return attrs + return value def transform_gender(self, user_profile, value): """ Converts empty string to None, to indicate not set. Replaced by to_representation in version 3. """ @@ -230,3 +226,29 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea call the method with a single argument, the user_profile object. """ return AccountLegacyProfileSerializer.get_profile_image(user_profile, user_profile.user) + + def update(self, instance, validated_data): + """ + Update the profile, including nested fields. + """ + language_proficiencies = validated_data.pop("language_proficiencies", None) + + # Update all fields on the user profile that are writeable, + # except for "language_proficiencies", which we'll update separately + update_fields = set(self.get_writeable_fields()) - set(["language_proficiencies"]) + for field_name in update_fields: + default = getattr(instance, field_name) + field_value = validated_data.get(field_name, default) + setattr(instance, field_name, field_value) + + instance.save() + + # Now update the related language proficiency + if language_proficiencies is not None: + instance.language_proficiencies.all().delete() + instance.language_proficiencies.bulk_create([ + LanguageProficiency(user_profile=instance, code=language["code"]) + for language in language_proficiencies + ]) + + return instance diff --git a/openedx/core/djangoapps/user_api/accounts/tests/test_api.py b/openedx/core/djangoapps/user_api/accounts/tests/test_api.py index 2cad671c49..d3201eaa39 100644 --- a/openedx/core/djangoapps/user_api/accounts/tests/test_api.py +++ b/openedx/core/djangoapps/user_api/accounts/tests/test_api.py @@ -164,7 +164,10 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase): field_errors = context_manager.exception.field_errors self.assertEqual(3, len(field_errors)) self.assertEqual("This field is not editable via this API", field_errors["username"]["developer_message"]) - self.assertIn("Select a valid choice", field_errors["gender"]["developer_message"]) + self.assertIn( + "Value \'undecided\' is not valid for field \'gender\'", + field_errors["gender"]["developer_message"] + ) self.assertIn("Valid e-mail address required.", field_errors["email"]["developer_message"]) @patch('django.core.mail.send_mail') diff --git a/openedx/core/djangoapps/user_api/accounts/tests/test_views.py b/openedx/core/djangoapps/user_api/accounts/tests/test_views.py index 8d021d7dd2..f7dc0bd879 100644 --- a/openedx/core/djangoapps/user_api/accounts/tests/test_views.py +++ b/openedx/core/djangoapps/user_api/accounts/tests/test_views.py @@ -359,16 +359,19 @@ class TestAccountAPI(UserAPITestCase): self.assertEqual(404, response.status_code) @ddt.data( - ("gender", "f", "not a gender", u"Select a valid choice. not a gender is not one of the available choices."), - ("level_of_education", "none", u"ȻħȺɍłɇs", u"Select a valid choice. ȻħȺɍłɇs is not one of the available choices."), - ("country", "GB", "XY", u"Select a valid choice. XY is not one of the available choices."), - ("year_of_birth", 2009, "not_an_int", u"Enter a whole number."), - ("name", "bob", "z" * 256, u"Ensure this value has at most 255 characters (it has 256)."), + ("gender", "f", "not a gender", u'"not a gender" is not a valid choice.'), + ("level_of_education", "none", u"ȻħȺɍłɇs", u'"ȻħȺɍłɇs" is not a valid choice.'), + ("country", "GB", "XY", u'"XY" is not a valid choice.'), + ("year_of_birth", 2009, "not_an_int", u"A valid integer is required."), + ("name", "bob", "z" * 256, u"Ensure this field has no more than 255 characters."), ("name", u"ȻħȺɍłɇs", "z ", u"The name field must be at least 2 characters long."), ("goals", "Smell the roses"), ("mailing_address", "Sesame Street"), # Note that we store the raw data, so it is up to client to escape the HTML. - ("bio", u"Lacrosse-playing superhero 壓是進界推日不復女", "z" * 3001, u"Ensure this value has at most 3000 characters (it has 3001)."), + ( + "bio", u"Lacrosse-playing superhero 壓是進界推日不復女", + "z" * 3001, u"Ensure this field has no more than 3000 characters." + ), # Note that email is tested below, as it is not immediately updated. # Note that language_proficiencies is tested below as there are multiple error and success conditions. ) @@ -568,10 +571,10 @@ class TestAccountAPI(UserAPITestCase): self.assertItemsEqual(response.data["language_proficiencies"], proficiencies) @ddt.data( - (u"not_a_list", [{u'non_field_errors': [u'Expected a list of items.']}]), - ([u"not_a_JSON_object"], [{u'non_field_errors': [u'Invalid data']}]), + (u"not_a_list", {u'non_field_errors': [u'Expected a list of items but got type "unicode".']}), + ([u"not_a_JSON_object"], [{u'non_field_errors': [u'Invalid data. Expected a dictionary, but got unicode.']}]), ([{}], [{"code": [u"This field is required."]}]), - ([{u"code": u"invalid_language_code"}], [{'code': [u'Select a valid choice. invalid_language_code is not one of the available choices.']}]), + ([{u"code": u"invalid_language_code"}], [{'code': [u'"invalid_language_code" is not a valid choice.']}]), ([{u"code": u"kw"}, {u"code": u"el"}, {u"code": u"kw"}], [u'The language_proficiencies field must consist of unique languages']), ) @ddt.unpack diff --git a/openedx/core/djangoapps/user_api/preferences/api.py b/openedx/core/djangoapps/user_api/preferences/api.py index 8c2a3860ac..58a3f915d3 100644 --- a/openedx/core/djangoapps/user_api/preferences/api.py +++ b/openedx/core/djangoapps/user_api/preferences/api.py @@ -9,9 +9,10 @@ from django.conf import settings from django.core.exceptions import ObjectDoesNotExist from django.db import IntegrityError from django.utils.translation import ugettext as _ -from student.models import User, UserProfile from django.utils.translation import ugettext_noop +from student.models import User, UserProfile +from request_cache import get_request_or_stub from ..errors import ( UserAPIInternalError, UserAPIRequestError, UserNotFound, UserNotAuthorized, PreferenceValidationError, PreferenceUpdateError @@ -68,7 +69,17 @@ def get_user_preferences(requesting_user, username=None): UserAPIInternalError: the operation failed due to an unexpected error. """ existing_user = _get_user(requesting_user, username, allow_staff=True) - user_serializer = UserSerializer(existing_user) + + # Django Rest Framework V3 uses the current request to version + # hyperlinked URLS, so we need to retrieve the request and pass + # it in the serializer's context (otherwise we get an AssertionError). + # We're retrieving the request from the cache rather than passing it in + # as an argument because this is an implementation detail of how we're + # serializing data, which we want to encapsulate in the API call. + context = { + "request": get_request_or_stub() + } + user_serializer = UserSerializer(existing_user, context=context) return user_serializer.data["preferences"] @@ -356,7 +367,7 @@ def validate_user_preference_serializer(serializer, preference_key, preference_v developer_message = u"Value '{preference_value}' not valid for preference '{preference_key}': {error}".format( preference_key=preference_key, preference_value=preference_value, error=serializer.errors ) - if serializer.errors["key"]: + if "key" in serializer.errors: user_message = _(u"Invalid user preference key '{preference_key}'.").format( preference_key=preference_key ) diff --git a/openedx/core/djangoapps/user_api/preferences/tests/test_api.py b/openedx/core/djangoapps/user_api/preferences/tests/test_api.py index bef0e79a67..ef95c7b2db 100644 --- a/openedx/core/djangoapps/user_api/preferences/tests/test_api.py +++ b/openedx/core/djangoapps/user_api/preferences/tests/test_api.py @@ -403,7 +403,7 @@ def get_expected_validation_developer_message(preference_key, preference_value): preference_key=preference_key, preference_value=preference_value, error={ - "key": [u"Ensure this value has at most 255 characters (it has 256)."] + "key": [u"Ensure this field has no more than 255 characters."] } ) diff --git a/openedx/core/djangoapps/user_api/serializers.py b/openedx/core/djangoapps/user_api/serializers.py index e95dedee00..2bf15b5ad8 100644 --- a/openedx/core/djangoapps/user_api/serializers.py +++ b/openedx/core/djangoapps/user_api/serializers.py @@ -6,8 +6,8 @@ from .models import UserPreference class UserSerializer(serializers.HyperlinkedModelSerializer): - name = serializers.SerializerMethodField("get_name") - preferences = serializers.SerializerMethodField("get_preferences") + name = serializers.SerializerMethodField() + preferences = serializers.SerializerMethodField() def get_name(self, user): profile = UserProfile.objects.get(user=user) @@ -32,9 +32,10 @@ class UserPreferenceSerializer(serializers.HyperlinkedModelSerializer): class RawUserPreferenceSerializer(serializers.ModelSerializer): - """Serializer that generates a raw representation of a user preference. """ - user = serializers.PrimaryKeyRelatedField() + Serializer that generates a raw representation of a user preference. + """ + user = serializers.PrimaryKeyRelatedField(queryset=User.objects.all()) class Meta(object): # pylint: disable=missing-docstring model = UserPreference @@ -57,3 +58,11 @@ class ReadOnlyFieldsSerializerMixin(object): cls.Meta.read_only_fields tuple. """ return getattr(cls.Meta, 'read_only_fields', '') + getattr(cls.Meta, 'explicit_read_only_fields', '') + + @classmethod + def get_writeable_fields(cls): + """ + Return all fields on this serializer that are writeable. + """ + all_fields = getattr(cls.Meta, 'fields', tuple()) + return tuple(set(all_fields) - set(cls.get_read_only_fields())) diff --git a/openedx/core/lib/api/authentication.py b/openedx/core/lib/api/authentication.py index 92f46861c4..e6e50284f1 100644 --- a/openedx/core/lib/api/authentication.py +++ b/openedx/core/lib/api/authentication.py @@ -1,10 +1,11 @@ """ Common Authentication Handlers used across projects. """ -from rest_framework import authentication +from rest_framework.authentication import SessionAuthentication +from rest_framework_oauth.authentication import OAuth2Authentication from rest_framework.exceptions import AuthenticationFailed -from rest_framework.compat import oauth2_provider, provider_now +from rest_framework_oauth.compat import oauth2_provider, provider_now -class SessionAuthenticationAllowInactiveUser(authentication.SessionAuthentication): +class SessionAuthenticationAllowInactiveUser(SessionAuthentication): """Ensure that the user is logged in, but do not require the account to be active. We use this in the special case that a user has created an account, @@ -51,7 +52,7 @@ class SessionAuthenticationAllowInactiveUser(authentication.SessionAuthenticatio return (user, None) -class OAuth2AuthenticationAllowInactiveUser(authentication.OAuth2Authentication): +class OAuth2AuthenticationAllowInactiveUser(OAuth2Authentication): """ This is a temporary workaround while the is_active field on the user is coupled with whether or not the user has verified ownership of their claimed email address. diff --git a/openedx/core/lib/api/fields.py b/openedx/core/lib/api/fields.py index 0a1d609997..7fc9a836b4 100644 --- a/openedx/core/lib/api/fields.py +++ b/openedx/core/lib/api/fields.py @@ -1,7 +1,5 @@ """Fields useful for edX API implementations.""" -from django.core.exceptions import ValidationError - -from rest_framework.serializers import CharField, Field +from rest_framework.serializers import Field class ExpandableField(Field): @@ -18,25 +16,19 @@ class ExpandableField(Field): self.expanded = kwargs.pop('expanded_serializer') super(ExpandableField, self).__init__(**kwargs) - def field_to_native(self, obj, field_name): - """Converts obj to a native representation, using the expanded serializer if the context requires it.""" - if 'expand' in self.context and field_name in self.context['expand']: - self.expanded.initialize(self, field_name) - return self.expanded.field_to_native(obj, field_name) - else: - self.collapsed.initialize(self, field_name) - return self.collapsed.field_to_native(obj, field_name) + def to_representation(self, obj): + """ + Return a representation of the field that is either expanded or collapsed. + """ + should_expand = self.field_name in self.context.get("expand", []) + field = self.expanded if should_expand else self.collapsed + # Avoid double-binding the field, otherwise we'll get + # an error about the source kwarg being redundant. + if field.source is None: + field.bind(self.field_name, self) -class NonEmptyCharField(CharField): - """ - A field that enforces non-emptiness even for partial updates. + if should_expand: + self.expanded.context["expand"] = set(field.context.get("expand", [])) - This is necessary because prior to version 3, DRF skips validation for empty - values. Thus, CharField's min_length and RegexField cannot be used to - enforce this constraint. - """ - def validate(self, value): - super(NonEmptyCharField, self).validate(value) - if not value.strip(): - raise ValidationError(self.error_messages["required"]) + return field.to_representation(obj) diff --git a/openedx/core/lib/api/mixins.py b/openedx/core/lib/api/mixins.py new file mode 100644 index 0000000000..909fb8b765 --- /dev/null +++ b/openedx/core/lib/api/mixins.py @@ -0,0 +1,33 @@ +""" +Django Rest Framework view mixins. +""" +from django.core.exceptions import ValidationError +from django.http import Http404 +from rest_framework import status +from rest_framework.mixins import CreateModelMixin +from rest_framework.response import Response + + +class PutAsCreateMixin(CreateModelMixin): + """ + Backwards compatibility with Django Rest Framework v2, which allowed + creation of a new resource using PUT. + """ + + def update(self, request, *args, **kwargs): + """ + Create/update course modes for a course. + """ + # First, try to update the existing instance + try: + try: + return super(PutAsCreateMixin, self).update(request, *args, **kwargs) + except Http404: + # If no instance exists yet, create it. + # This is backwards-compatible with the behavior of DRF v2. + return super(PutAsCreateMixin, self).create(request, *args, **kwargs) + + # Backwards compatibility with DRF v2 behavior, which would catch model-level + # validation errors and return a 400 + except ValidationError as err: + return Response(err.messages, status=status.HTTP_400_BAD_REQUEST) diff --git a/openedx/core/lib/api/paginators.py b/openedx/core/lib/api/paginators.py index ae4b32d37e..373845bc19 100644 --- a/openedx/core/lib/api/paginators.py +++ b/openedx/core/lib/api/paginators.py @@ -3,6 +3,31 @@ from django.http import Http404 from django.core.paginator import Paginator, InvalidPage +from rest_framework.response import Response +from rest_framework import pagination + + +class DefaultPagination(pagination.PageNumberPagination): + """ + Default paginator for APIs in edx-platform. + + This is configured in settings to be automatically used + by any subclass of Django Rest Framework's generic API views. + """ + page_size_query_param = "page_size" + + def get_paginated_response(self, data): + """ + Annotate the response with pagination information. + """ + return Response({ + 'next': self.get_next_link(), + 'previous': self.get_previous_link(), + 'count': self.page.paginator.count, + 'num_pages': self.page.paginator.num_pages, + 'results': data + }) + def paginate_search_results(object_class, search_results, page_size, page): """ diff --git a/openedx/core/lib/api/serializers.py b/openedx/core/lib/api/serializers.py index 03ec001137..b13bbc2d83 100644 --- a/openedx/core/lib/api/serializers.py +++ b/openedx/core/lib/api/serializers.py @@ -1,32 +1,8 @@ -from rest_framework import pagination, serializers +""" +Serializers to be used in APIs. +""" - -class PaginationSerializer(pagination.PaginationSerializer): - """ - Custom PaginationSerializer for openedx. - - Adds the following fields: - - num_pages: total number of pages - - current_page: the current page being returned - - start: the index of the first page item within the overall collection - """ - start_page = 1 # django Paginator.page objects have 1-based indexes - num_pages = serializers.Field(source='paginator.num_pages') - current_page = serializers.SerializerMethodField('get_current_page') - start = serializers.SerializerMethodField('get_start') - sort_order = serializers.SerializerMethodField('get_sort_order') - - def get_current_page(self, page): - """Get the current page""" - return page.number - - def get_start(self, page): - """Get the index of the first page item within the overall collection""" - return (self.get_current_page(page) - self.start_page) * page.paginator.per_page - - def get_sort_order(self, page): # pylint: disable=unused-argument - """Get the order by which this collection was sorted""" - return self.context.get('sort_order') +from rest_framework import serializers class CollapsedReferenceSerializer(serializers.HyperlinkedModelSerializer): @@ -54,9 +30,10 @@ class CollapsedReferenceSerializer(serializers.HyperlinkedModelSerializer): super(CollapsedReferenceSerializer, self).__init__(*args, **kwargs) - self.fields[id_source] = serializers.CharField(read_only=True, source=id_source) + self.fields[id_source] = serializers.CharField(read_only=True) self.fields['url'].view_name = view_name self.fields['url'].lookup_field = lookup_field + self.fields['url'].lookup_url_kwarg = lookup_field class Meta(object): """Defines meta information for the ModelSerializer. diff --git a/openedx/core/lib/api/tests/test_authentication.py b/openedx/core/lib/api/tests/test_authentication.py index fe9a1dbd47..1fb5ff5b31 100644 --- a/openedx/core/lib/api/tests/test_authentication.py +++ b/openedx/core/lib/api/tests/test_authentication.py @@ -1,63 +1,235 @@ -"""Tests for util.authentication module.""" +""" +Tests for OAuth2. This module is copied from django-rest-framework-oauth (tests/test_authentication.py) +and updated to use our subclass of OAuth2Authentication. +""" + +from __future__ import unicode_literals +import datetime + +from django.conf.urls import patterns, url, include +from django.contrib.auth.models import User +from django.http import HttpResponse +from django.test import TestCase +from django.utils import unittest +from django.utils.http import urlencode + +from rest_framework import status +from rest_framework.permissions import IsAuthenticated +from rest_framework_oauth import permissions +from rest_framework_oauth.compat import oauth2_provider, oauth2_provider_scope +from rest_framework.test import APIRequestFactory, APIClient +from rest_framework.views import APIView -from mock import patch -from django.conf import settings -from rest_framework import permissions -from rest_framework.compat import patterns, url -from rest_framework.tests import test_authentication from provider import scope, constants -from unittest import skipUnless from ..authentication import OAuth2AuthenticationAllowInactiveUser +factory = APIRequestFactory() # pylint: disable=invalid-name -class OAuth2AuthAllowInactiveUserDebug(OAuth2AuthenticationAllowInactiveUser): - """ - A debug class analogous to the OAuth2AuthenticationDebug class that tests - the OAuth2 flow with the access token sent in a query param.""" + +class MockView(APIView): # pylint: disable=missing-docstring + permission_classes = (IsAuthenticated,) + + def get(self, request): # pylint: disable=missing-docstring,unused-argument + return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + + def post(self, request): # pylint: disable=missing-docstring,unused-argument + return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + + def put(self, request): # pylint: disable=missing-docstring,unused-argument + return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + + +# This is the a change we've made from the django-rest-framework-oauth version +# of these tests. We're subclassing our custom OAuth2AuthenticationAllowInactiveUser +# instead of OAuth2Authentication. +class OAuth2AuthenticationDebug(OAuth2AuthenticationAllowInactiveUser): # pylint: disable=missing-docstring allow_query_params_token = True -# The following patch overrides the URL patterns for the MockView class used in -# rest_framework.tests.test_authentication so that the corresponding AllowInactiveUser -# classes are tested instead. -@skipUnless(settings.FEATURES.get('ENABLE_OAUTH2_PROVIDER'), 'OAuth2 not enabled') -@patch.object( - test_authentication, - 'urlpatterns', - patterns( - '', - url( - r'^oauth2-test/$', - test_authentication.MockView.as_view(authentication_classes=[OAuth2AuthenticationAllowInactiveUser]) - ), - url( - r'^oauth2-test-debug/$', - test_authentication.MockView.as_view(authentication_classes=[OAuth2AuthAllowInactiveUserDebug]) - ), - url( - r'^oauth2-with-scope-test/$', - test_authentication.MockView.as_view( - authentication_classes=[OAuth2AuthenticationAllowInactiveUser], - permission_classes=[permissions.TokenHasReadWriteScope] - ) +urlpatterns = patterns( + '', + url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), + url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationAllowInactiveUser])), + url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])), + url( + r'^oauth2-with-scope-test/$', + MockView.as_view( + authentication_classes=[OAuth2AuthenticationAllowInactiveUser], + permission_classes=[permissions.TokenHasReadWriteScope] ) - ) + ), ) -class OAuth2AuthenticationAllowInactiveUserTestCase(test_authentication.OAuth2Tests): - """ - Tests the OAuth2AuthenticationAllowInactiveUser class by running all the existing tests in - OAuth2Tests but with the is_active flag on the user set to False. - """ - def setUp(self): - super(OAuth2AuthenticationAllowInactiveUserTestCase, self).setUp() - # set the user's is_active flag to False. + +class OAuth2Tests(TestCase): + """OAuth 2.0 authentication""" + urls = 'openedx.core.lib.api.tests.test_authentication' + + def setUp(self): + self.csrf_client = APIClient(enforce_csrf_checks=True) + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + self.CLIENT_ID = 'client_key' # pylint: disable=invalid-name + self.CLIENT_SECRET = 'client_secret' # pylint: disable=invalid-name + self.ACCESS_TOKEN = "access_token" # pylint: disable=invalid-name + self.REFRESH_TOKEN = "refresh_token" # pylint: disable=invalid-name + + self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create( + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + redirect_uri='', + client_type=0, + name='example', + user=None, + ) + + self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create( + token=self.ACCESS_TOKEN, + client=self.oauth2_client, + user=self.user, + ) + self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create( + user=self.user, + access_token=self.access_token, + client=self.oauth2_client + ) + + # This is the a change we've made from the django-rest-framework-oauth version + # of these tests. self.user.is_active = False self.user.save() + # This is the a change we've made from the django-rest-framework-oauth version + # of these tests. # Override the SCOPE_NAME_DICT setting for tests for oauth2-with-scope-test. This is # needed to support READ and WRITE scopes as they currently aren't supported by the # edx-auth2-provider, and their scope values collide with other scopes defined in the # edx-auth2-provider. scope.SCOPE_NAME_DICT = {'read': constants.READ, 'write': constants.WRITE} + + def _create_authorization_header(self, token=None): # pylint: disable=missing-docstring + return "Bearer {0}".format(token or self.access_token.token) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_type_failing(self): + """Ensure that a wrong token type lead to the correct HTTP error status code""" + auth = "Wrong token-type-obviously" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_format_failing(self): + """Ensure that a wrong token format lead to the correct HTTP error status code""" + auth = "Bearer wrong token format" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_failing(self): + """Ensure that a wrong token lead to the correct HTTP error status code""" + auth = "Bearer wrong-token" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_missing(self): + """Ensure that a missing token lead to the correct HTTP error status code""" + auth = "Bearer" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_passing_auth(self): + """Ensure GETing form over OAuth with correct client credentials succeed""" + auth = self._create_authorization_header() + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_passing_auth_url_transport(self): + """Ensure GETing form over OAuth with correct client credentials in form data succeed""" + response = self.csrf_client.post( + '/oauth2-test/', + data={'access_token': self.access_token.token} + ) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_passing_auth_url_transport(self): + """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True""" + query = urlencode({'access_token': self.access_token.token}) + response = self.csrf_client.get('/oauth2-test-debug/?%s' % query) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_failing_auth_url_transport(self): + """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False""" + query = urlencode({'access_token': self.access_token.token}) + response = self.csrf_client.get('/oauth2-test/?%s' % query) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_passing_auth(self): + """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_token_removed_failing_auth(self): + """Ensure POSTing when there is no OAuth access token in db fails""" + self.access_token.delete() + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_refresh_token_failing_auth(self): + """Ensure POSTing with refresh token instead of access token fails""" + auth = self._create_authorization_header(token=self.refresh_token.token) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_expired_access_token_failing_auth(self): + """Ensure POSTing with expired access token fails with an 'Invalid token' error""" + self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late + self.access_token.save() + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + self.assertIn('Invalid token', response.content) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_invalid_scope_failing_auth(self): + """Ensure POSTing with a readonly scope instead of a write scope fails""" + read_only_access_token = self.access_token + read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] + read_only_access_token.save() + auth = self._create_authorization_header(token=read_only_access_token.token) + response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_valid_scope_passing_auth(self): + """Ensure POSTing with a write scope succeed""" + read_write_access_token = self.access_token + read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] + read_write_access_token.save() + auth = self._create_authorization_header(token=read_write_access_token.token) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) diff --git a/openedx/core/lib/api/view_utils.py b/openedx/core/lib/api/view_utils.py index 2a4aa6d372..77b0684e63 100644 --- a/openedx/core/lib/api/view_utils.py +++ b/openedx/core/lib/api/view_utils.py @@ -9,6 +9,7 @@ from django.utils.translation import ugettext as _ from rest_framework import status, response from rest_framework.exceptions import APIException from rest_framework.permissions import IsAuthenticated +from rest_framework.request import clone_request from rest_framework.response import Response from rest_framework.mixins import RetrieveModelMixin, UpdateModelMixin from rest_framework.generics import GenericAPIView @@ -193,3 +194,23 @@ class RetrievePatchAPIView(RetrieveModelMixin, UpdateModelMixin, GenericAPIView) add_serializer_errors(serializer, patch, field_errors) return field_errors + + def get_object_or_none(self): + """ + Retrieve an object or return None if the object can't be found. + + NOTE: This replaces functionality that was removed in Django Rest Framework v3.1. + """ + try: + return self.get_object() + except Http404: + if self.request.method == 'PUT': + # For PUT-as-create operation, we need to ensure that we have + # relevant permissions, as if this was a POST request. This + # will either raise a PermissionDenied exception, or simply + # return None. + self.check_permissions(clone_request(self.request, 'POST')) + else: + # PATCH requests where the object does not exist should still + # return a 404 response. + raise diff --git a/requirements/edx/base.txt b/requirements/edx/base.txt index 22f63923cd..c0463faf5e 100644 --- a/requirements/edx/base.txt +++ b/requirements/edx/base.txt @@ -28,7 +28,7 @@ django-ses==0.7.0 django-simple-history==1.6.3 django-storages==1.1.5 django-method-override==0.1.0 -djangorestframework==2.3.14 +djangorestframework>=3.1,<3.2 django==1.4.22 elasticsearch==0.4.5 facebook-sdk==0.4.0 diff --git a/requirements/edx/github.txt b/requirements/edx/github.txt index 82f1e313d7..a40bb836e9 100644 --- a/requirements/edx/github.txt +++ b/requirements/edx/github.txt @@ -12,6 +12,7 @@ git+https://github.com/edx/django-staticfiles.git@031bdeaea85798b8c284e2a09977df -e git+https://github.com/edx/django-pipeline.git@88ec8a011e481918fdc9d2682d4017c835acd8be#egg=django-pipeline -e git+https://github.com/edx/django-wiki.git@cd0b2b31997afccde519fe5b3365e61a9edb143f#egg=django-wiki -e git+https://github.com/edx/django-oauth2-provider.git@0.2.7-fork-edx-5#egg=django-oauth2-provider +-e git+https://github.com/edx/django-rest-framework-oauth.git@f0b503fda8c254a38f97fef802ded4f5fe367f7a#egg=djangorestframework-oauth -e git+https://github.com/edx/MongoDBProxy.git@25b99097615bda06bd7cdfe5669ed80dc2a7fed0#egg=mongodb_proxy git+https://github.com/edx/nltk.git@2.0.6#egg=nltk==2.0.6 -e git+https://github.com/dementrock/pystache_custom.git@776973740bdaad83a3b029f96e415a7d1e8bec2f#egg=pystache_custom-dev @@ -40,13 +41,13 @@ git+https://github.com/edx/rfc6266.git@v0.0.5-edx#egg=rfc6266==0.0.5-edx -e git+https://github.com/edx/event-tracking.git@0.2.0#egg=event-tracking -e git+https://github.com/edx-solutions/django-splash.git@7579d052afcf474ece1239153cffe1c89935bc4f#egg=django-splash -e git+https://github.com/edx/acid-block.git@e46f9cda8a03e121a00c7e347084d142d22ebfb7#egg=acid-xblock --e git+https://github.com/edx/edx-ora2.git@release-2015-08-25T16.16#egg=edx-ora2 --e git+https://github.com/edx/edx-submissions.git@9538ee8a971d04dc1cb05e88f6aa0c36b224455c#egg=edx-submissions +-e git+https://github.com/edx/edx-ora2.git@release-2015-09-16T15.28#egg=edx-ora2 +-e git+https://github.com/edx/edx-submissions.git@0.1.0#egg=edx-submissions -e git+https://github.com/edx/opaque-keys.git@27dc382ea587483b1e3889a3d19cbd90b9023a06#egg=opaque-keys git+https://github.com/edx/ease.git@release-2015-07-14#egg=ease==0.1.3 git+https://github.com/edx/i18n-tools.git@v0.1.3#egg=i18n-tools==v0.1.3 git+https://github.com/edx/edx-oauth2-provider.git@0.5.7#egg=oauth2-provider==0.5.7 --e git+https://github.com/edx/edx-val.git@v0.0.5#egg=edx-val +-e git+https://github.com/edx/edx-val.git@0.0.6#egg=edx-val -e git+https://github.com/pmitros/RecommenderXBlock.git@518234bc354edbfc2651b9e534ddb54f96080779#egg=recommender-xblock -e git+https://github.com/edx/edx-search.git@release-2015-09-11a#egg=edx-search -e git+https://github.com/edx/edx-milestones.git@9b44a37edc3d63a23823c21a63cdd53ef47a7aa4#egg=edx-milestones