From 2fd6add524ad472f1a8a761ed174289f9e018cc6 Mon Sep 17 00:00:00 2001 From: Ben Patterson Date: Mon, 28 Sep 2015 17:12:43 -0400 Subject: [PATCH] Revert "Merge DRF 3.1 in to master" --- cms/envs/common.py | 5 +- .../cors_csrf/tests/test_authentication.py | 4 +- common/djangoapps/enrollment/data.py | 29 +- common/djangoapps/enrollment/serializers.py | 77 ++++-- .../djangoapps/enrollment/tests/test_views.py | 2 +- common/djangoapps/request_cache/__init__.py | 43 --- common/djangoapps/request_cache/tests.py | 20 -- .../test/acceptance/tests/lms/test_teams.py | 2 +- lms/djangoapps/commerce/api/v1/serializers.py | 40 +-- lms/djangoapps/commerce/api/v1/views.py | 14 +- .../course_structure_api/v0/serializers.py | 10 +- .../course_structure_api/v0/tests.py | 83 ++---- .../course_structure_api/v0/views.py | 7 +- lms/djangoapps/courseware/grades.py | 11 +- lms/djangoapps/discussion_api/api.py | 4 +- lms/djangoapps/discussion_api/pagination.py | 37 +-- lms/djangoapps/discussion_api/serializers.py | 144 ++++------ .../discussion_api/tests/test_api.py | 7 +- .../discussion_api/tests/test_serializers.py | 13 +- .../discussion_api/tests/test_views.py | 6 +- lms/djangoapps/discussion_api/views.py | 3 +- .../django_comment_client/base/tests.py | 3 +- .../django_comment_client/forum/tests.py | 3 - .../django_comment_client/tests/group_id.py | 2 - .../social_facebook/courses/views.py | 5 +- .../social_facebook/friends/views.py | 2 +- .../social_facebook/groups/views.py | 6 +- .../preferences/serializers.py | 2 +- .../social_facebook/preferences/views.py | 4 +- .../mobile_api/social_facebook/utils.py | 2 +- .../mobile_api/users/serializers.py | 8 +- lms/djangoapps/mobile_api/users/views.py | 8 - lms/djangoapps/notifier_api/serializers.py | 6 +- lms/djangoapps/notifier_api/views.py | 24 +- lms/djangoapps/teams/models.py | 1 + lms/djangoapps/teams/search_indexes.py | 11 +- lms/djangoapps/teams/serializers.py | 94 +++---- lms/djangoapps/teams/tests/factories.py | 7 - .../teams/tests/test_serializers.py | 87 +++--- lms/djangoapps/teams/tests/test_views.py | 62 +---- lms/djangoapps/teams/views.py | 178 ++++-------- lms/envs/common.py | 14 - lms/startup.py | 9 +- .../content/course_structures/api/v0/api.py | 2 +- .../course_structures/api/v0/serializers.py | 43 +-- openedx/core/djangoapps/credit/serializers.py | 21 -- .../djangoapps/credit/tests/test_views.py | 5 +- openedx/core/djangoapps/credit/views.py | 30 +- .../profile_images/tests/test_views.py | 38 --- .../core/djangoapps/user_api/accounts/api.py | 2 +- .../user_api/accounts/serializers.py | 56 ++-- .../user_api/accounts/tests/test_api.py | 5 +- .../user_api/accounts/tests/test_views.py | 21 +- .../djangoapps/user_api/preferences/api.py | 17 +- .../user_api/preferences/tests/test_api.py | 2 +- .../core/djangoapps/user_api/serializers.py | 17 +- openedx/core/lib/api/authentication.py | 9 +- openedx/core/lib/api/fields.py | 36 ++- openedx/core/lib/api/mixins.py | 33 --- openedx/core/lib/api/paginators.py | 25 -- openedx/core/lib/api/serializers.py | 35 ++- .../core/lib/api/tests/test_authentication.py | 256 +++--------------- openedx/core/lib/api/view_utils.py | 21 -- requirements/edx/base.txt | 2 +- requirements/edx/github.txt | 7 +- 65 files changed, 500 insertions(+), 1282 deletions(-) delete mode 100644 common/djangoapps/request_cache/tests.py delete mode 100644 openedx/core/lib/api/mixins.py diff --git a/cms/envs/common.py b/cms/envs/common.py index e2ae72f2dd..b1002e1222 100644 --- a/cms/envs/common.py +++ b/cms/envs/common.py @@ -44,10 +44,7 @@ from lms.envs.common import ( PROFILE_IMAGE_SECRET_KEY, PROFILE_IMAGE_MIN_BYTES, PROFILE_IMAGE_MAX_BYTES, # The following setting is included as it is used to check whether to # display credit eligibility table on the CMS or not. - ENABLE_CREDIT_ELIGIBILITY, YOUTUBE_API_KEY, - - # Django REST framework configuration - REST_FRAMEWORK, + ENABLE_CREDIT_ELIGIBILITY, YOUTUBE_API_KEY ) from path import Path as path from warnings import simplefilter diff --git a/common/djangoapps/cors_csrf/tests/test_authentication.py b/common/djangoapps/cors_csrf/tests/test_authentication.py index ea4587021a..7f2d78fd8a 100644 --- a/common/djangoapps/cors_csrf/tests/test_authentication.py +++ b/common/djangoapps/cors_csrf/tests/test_authentication.py @@ -6,7 +6,7 @@ from django.test.utils import override_settings from django.test.client import RequestFactory from django.conf import settings -from rest_framework.exceptions import PermissionDenied +from rest_framework.exceptions import AuthenticationFailed from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf @@ -24,7 +24,7 @@ class CrossDomainAuthTest(TestCase): def test_perform_csrf_referer_check(self): request = self._fake_request() - with self.assertRaisesRegexp(PermissionDenied, 'CSRF'): + with self.assertRaisesRegexp(AuthenticationFailed, 'CSRF'): self.auth.enforce_csrf(request) @patch.dict(settings.FEATURES, { diff --git a/common/djangoapps/enrollment/data.py b/common/djangoapps/enrollment/data.py index 68880e1a3e..33f78e3e4a 100644 --- a/common/djangoapps/enrollment/data.py +++ b/common/djangoapps/enrollment/data.py @@ -11,7 +11,7 @@ from enrollment.errors import ( CourseNotFoundError, CourseEnrollmentClosedError, CourseEnrollmentFullError, CourseEnrollmentExistsError, UserNotFoundError, InvalidEnrollmentAttribute ) -from enrollment.serializers import CourseEnrollmentSerializer, CourseSerializer +from enrollment.serializers import CourseEnrollmentSerializer, CourseField from openedx.core.djangoapps.content.course_overviews.models import CourseOverview from student.models import ( CourseEnrollment, NonExistentCourseError, EnrollmentClosedError, @@ -35,30 +35,9 @@ def get_course_enrollments(user_id): """ qset = CourseEnrollment.objects.filter( - user__username=user_id, - is_active=True + user__username=user_id, is_active=True ).order_by('created') - - enrollments = CourseEnrollmentSerializer(qset, many=True).data - - # Find deleted courses and filter them out of the results - deleted = [] - valid = [] - for enrollment in enrollments: - if enrollment.get("course_details") is not None: - valid.append(enrollment) - else: - deleted.append(enrollment) - - if deleted: - log.warning( - ( - u"Course enrollments for user %s reference " - u"courses that do not exist (this can occur if a course is deleted)." - ), user_id, - ) - - return valid + return CourseEnrollmentSerializer(qset).data def get_course_enrollment(username, course_id): @@ -292,4 +271,4 @@ def get_course_enrollment_info(course_id, include_expired=False): log.warning(msg) raise CourseNotFoundError(msg) else: - return CourseSerializer(course, include_expired=include_expired).data + return CourseField().to_native(course, include_expired=include_expired) diff --git a/common/djangoapps/enrollment/serializers.py b/common/djangoapps/enrollment/serializers.py index 994f97e5fc..ca5c4bfd93 100644 --- a/common/djangoapps/enrollment/serializers.py +++ b/common/djangoapps/enrollment/serializers.py @@ -30,36 +30,32 @@ class StringListField(serializers.CharField): return [int(item) for item in items] -class CourseSerializer(serializers.Serializer): # pylint: disable=abstract-method - """ - Serialize a course descriptor and related information. +class CourseField(serializers.RelatedField): + """Read-Only representation of course enrollment information. + + Aggregates course information from the CourseDescriptor as well as the Course Modes configured + for enrolling in the course. + """ - course_id = serializers.CharField(source="id") - enrollment_start = serializers.DateTimeField(format=None) - enrollment_end = serializers.DateTimeField(format=None) - course_start = serializers.DateTimeField(source="start", format=None) - course_end = serializers.DateTimeField(source="end", format=None) - invite_only = serializers.BooleanField(source="invitation_only") - course_modes = serializers.SerializerMethodField() + def to_native(self, course, **kwargs): + course_modes = ModeSerializer( + CourseMode.modes_for_course( + course.id, + include_expired=kwargs.get('include_expired', False), + only_selectable=False + ) + ).data - def __init__(self, *args, **kwargs): - self.include_expired = kwargs.pop("include_expired", False) - super(CourseSerializer, self).__init__(*args, **kwargs) - - def get_course_modes(self, obj): - """ - Retrieve course modes associated with the course. - """ - course_modes = CourseMode.modes_for_course( - obj.id, - include_expired=self.include_expired, - only_selectable=False - ) - return [ - ModeSerializer(mode).data - for mode in course_modes - ] + return { + 'course_id': unicode(course.id), + 'enrollment_start': course.enrollment_start, + 'enrollment_end': course.enrollment_end, + 'course_start': course.start, + 'course_end': course.end, + 'invite_only': course.invitation_only, + 'course_modes': course_modes, + } class CourseEnrollmentSerializer(serializers.ModelSerializer): @@ -69,9 +65,34 @@ class CourseEnrollmentSerializer(serializers.ModelSerializer): the Course Descriptor and course modes, to give a complete representation of course enrollment. """ - course_details = CourseSerializer(source="course_overview") + course_details = serializers.SerializerMethodField('get_course_details') user = serializers.SerializerMethodField('get_username') + @property + def data(self): + serialized_data = super(CourseEnrollmentSerializer, self).data + + # filter the results with empty courses 'course_details' + if isinstance(serialized_data, dict): + if serialized_data.get('course_details') is None: + return None + + return serialized_data + + return [enrollment for enrollment in serialized_data if enrollment.get('course_details')] + + def get_course_details(self, model): + if model.course is None: + msg = u"Course '{0}' does not exist (maybe deleted), in which User (user_id: '{1}') is enrolled.".format( + model.course_id, + model.user.id + ) + log.warning(msg) + return None + + field = CourseField() + return field.to_native(model.course) + def get_username(self, model): """Retrieves the username from the associated model.""" return model.username diff --git a/common/djangoapps/enrollment/tests/test_views.py b/common/djangoapps/enrollment/tests/test_views.py index 20052bff34..bd5a3ffbc2 100644 --- a/common/djangoapps/enrollment/tests/test_views.py +++ b/common/djangoapps/enrollment/tests/test_views.py @@ -1038,7 +1038,7 @@ class EnrollmentCrossDomainTest(ModuleStoreTestCase): @cross_domain_config def test_cross_domain_missing_csrf(self, *args): # pylint: disable=unused-argument resp = self._cross_domain_post('invalid_csrf_token') - self.assertEqual(resp.status_code, 403) + self.assertEqual(resp.status_code, 401) def _get_csrf_cookie(self): """Retrieve the cross-domain CSRF cookie. """ diff --git a/common/djangoapps/request_cache/__init__.py b/common/djangoapps/request_cache/__init__.py index aedc785cad..7d8bf77ef6 100644 --- a/common/djangoapps/request_cache/__init__.py +++ b/common/djangoapps/request_cache/__init__.py @@ -5,18 +5,10 @@ This module requires that :class:`request_cache.middleware.RequestCache` is installed in order to clear the cache after each request. """ -import logging -from urlparse import urlparse - -from django.conf import settings -from django.test.client import RequestFactory from request_cache import middleware -log = logging.getLogger(__name__) - - def get_cache(name): """ Return the request cache named ``name``. @@ -34,38 +26,3 @@ def get_request(): Return the current request. """ return middleware.RequestCache.get_current_request() - - -def get_request_or_stub(): - """ - Return the current request or a stub request. - - If called outside the context of a request, construct a fake - request that can be used to build an absolute URI. - - This is useful in cases where we need to pass in a request object - but don't have an active request (for example, in test cases). - """ - request = get_request() - - if request is None: - log.warning( - "Could not retrieve the current request. " - "A stub request will be created instead using settings.SITE_NAME. " - "This should be used *only* in test cases, never in production!" - ) - - # The settings SITE_NAME may contain a port number, so we need to - # parse the full URL. - full_url = "http://{site_name}".format(site_name=settings.SITE_NAME) - parsed_url = urlparse(full_url) - - # Construct the fake request. This can be used to construct absolute - # URIs to other paths. - return RequestFactory( - SERVER_NAME=parsed_url.hostname, - SERVER_PORT=parsed_url.port or 80, - ).get("/") - - else: - return request diff --git a/common/djangoapps/request_cache/tests.py b/common/djangoapps/request_cache/tests.py deleted file mode 100644 index 3883fe873b..0000000000 --- a/common/djangoapps/request_cache/tests.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -Tests for the request cache. -""" -from django.conf import settings -from django.test import TestCase - -from request_cache import get_request_or_stub - - -class TestRequestCache(TestCase): - """ - Tests for the request cache. - """ - - def test_get_request_or_stub(self): - # Outside the context of the request, we should still get a request - # that allows us to build an absolute URI. - stub = get_request_or_stub() - expected_url = "http://{site_name}/foobar".format(site_name=settings.SITE_NAME) - self.assertEqual(stub.build_absolute_uri("foobar"), expected_url) diff --git a/common/test/acceptance/tests/lms/test_teams.py b/common/test/acceptance/tests/lms/test_teams.py index 33e4c0f0a7..40d6dc963d 100644 --- a/common/test/acceptance/tests/lms/test_teams.py +++ b/common/test/acceptance/tests/lms/test_teams.py @@ -406,7 +406,7 @@ class BrowseTopicsTest(TeamsTabBase): ) create_team_page.submit_form() team_page = TeamPage(self.browser, self.course_id) - self.assertTrue(team_page.is_browser_on_page()) + self.assertTrue(team_page.is_browser_on_page) team_page.click_all_topics() self.assertTrue(self.topics_page.is_browser_on_page()) self.topics_page.wait_for_ajax() diff --git a/lms/djangoapps/commerce/api/v1/serializers.py b/lms/djangoapps/commerce/api/v1/serializers.py index 775867fe85..a904fa7238 100644 --- a/lms/djangoapps/commerce/api/v1/serializers.py +++ b/lms/djangoapps/commerce/api/v1/serializers.py @@ -18,12 +18,7 @@ class CourseModeSerializer(serializers.ModelSerializer): """ CourseMode serializer. """ name = serializers.CharField(source='mode_slug') price = serializers.IntegerField(source='min_price') - expires = serializers.DateTimeField( - source='expiration_datetime', - required=False, - allow_null=True, - format=None - ) + expires = serializers.DateTimeField(source='expiration_datetime', required=False, blank=True) def get_identity(self, data): try: @@ -61,8 +56,8 @@ class CourseSerializer(serializers.Serializer): """ Course serializer. """ id = serializers.CharField(validators=[validate_course_id]) # pylint: disable=invalid-name name = serializers.CharField(read_only=True) - verification_deadline = serializers.DateTimeField(format=None, allow_null=True, required=False) - modes = CourseModeSerializer(many=True) + verification_deadline = serializers.DateTimeField(blank=True) + modes = CourseModeSerializer(many=True, allow_add_remove=True) def validate(self, attrs): """ Ensure the verification deadline occurs AFTER the course mode enrollment deadlines. """ @@ -73,7 +68,7 @@ class CourseSerializer(serializers.Serializer): # Find the earliest upgrade deadline for mode in attrs['modes']: - expires = mode.get("expiration_datetime") + expires = mode.expiration_datetime if expires: # If we don't already have an upgrade_deadline value, use datetime.max so that we can actually # complete the comparison. @@ -87,28 +82,9 @@ class CourseSerializer(serializers.Serializer): return attrs - def create(self, validated_data): - """Create course modes for a course. """ - course = Course( - validated_data["id"], - self._new_course_mode_models(validated_data["modes"]), - verification_deadline=validated_data["verification_deadline"] - ) - course.save() - return course + def restore_object(self, attrs, instance=None): + if instance is None: + return Course(attrs['id'], attrs['modes'], attrs['verification_deadline']) - def update(self, instance, validated_data): - """Update course modes for an existing course. """ - validated_data["modes"] = self._new_course_mode_models(validated_data["modes"]) - - instance.update(validated_data) - instance.save() + instance.update(attrs) return instance - - @staticmethod - def _new_course_mode_models(modes_data): - """Convert validated course mode data to CourseMode objects. """ - return [ - CourseMode(**modes_dict) - for modes_dict in modes_data - ] diff --git a/lms/djangoapps/commerce/api/v1/views.py b/lms/djangoapps/commerce/api/v1/views.py index 1c1b32ccad..34b58c7915 100644 --- a/lms/djangoapps/commerce/api/v1/views.py +++ b/lms/djangoapps/commerce/api/v1/views.py @@ -2,8 +2,7 @@ import logging from django.http import Http404 -from rest_framework.authentication import SessionAuthentication -from rest_framework_oauth.authentication import OAuth2Authentication +from rest_framework.authentication import OAuth2Authentication, SessionAuthentication from rest_framework.generics import RetrieveUpdateAPIView, ListAPIView from rest_framework.permissions import IsAuthenticated @@ -11,7 +10,6 @@ from commerce.api.v1.models import Course from commerce.api.v1.permissions import ApiKeyOrModelPermission from commerce.api.v1.serializers import CourseSerializer from course_modes.models import CourseMode -from openedx.core.lib.api.mixins import PutAsCreateMixin log = logging.getLogger(__name__) @@ -21,13 +19,12 @@ class CourseListView(ListAPIView): authentication_classes = (OAuth2Authentication, SessionAuthentication,) permission_classes = (IsAuthenticated,) serializer_class = CourseSerializer - pagination_class = None def get_queryset(self): - return list(Course.iterator()) + return Course.iterator() -class CourseRetrieveUpdateView(PutAsCreateMixin, RetrieveUpdateAPIView): +class CourseRetrieveUpdateView(RetrieveUpdateAPIView): """ Retrieve, update, or create courses/modes. """ lookup_field = 'id' lookup_url_kwarg = 'course_id' @@ -36,11 +33,6 @@ class CourseRetrieveUpdateView(PutAsCreateMixin, RetrieveUpdateAPIView): permission_classes = (ApiKeyOrModelPermission,) serializer_class = CourseSerializer - # Django Rest Framework v3 requires that we provide a queryset. - # Note that we're overriding `get_object()` below to return a `Course` - # rather than a CourseMode, so this isn't really used. - queryset = CourseMode.objects.all() - def get_object(self, queryset=None): course_id = self.kwargs.get(self.lookup_url_kwarg) course = Course.get(course_id) diff --git a/lms/djangoapps/course_structure_api/v0/serializers.py b/lms/djangoapps/course_structure_api/v0/serializers.py index bd079544e3..ad95b8e985 100644 --- a/lms/djangoapps/course_structure_api/v0/serializers.py +++ b/lms/djangoapps/course_structure_api/v0/serializers.py @@ -11,11 +11,11 @@ class CourseSerializer(serializers.Serializer): id = serializers.CharField() # pylint: disable=invalid-name name = serializers.CharField(source='display_name') category = serializers.CharField() - org = serializers.SerializerMethodField() - run = serializers.SerializerMethodField() - course = serializers.SerializerMethodField() - uri = serializers.SerializerMethodField() - image_url = serializers.SerializerMethodField() + org = serializers.SerializerMethodField('get_org') + run = serializers.SerializerMethodField('get_run') + course = serializers.SerializerMethodField('get_course') + uri = serializers.SerializerMethodField('get_uri') + image_url = serializers.SerializerMethodField('get_image_url') start = serializers.DateTimeField() end = serializers.DateTimeField() diff --git a/lms/djangoapps/course_structure_api/v0/tests.py b/lms/djangoapps/course_structure_api/v0/tests.py index 08224e521a..c74f76f601 100644 --- a/lms/djangoapps/course_structure_api/v0/tests.py +++ b/lms/djangoapps/course_structure_api/v0/tests.py @@ -36,23 +36,6 @@ class CourseViewTestsMixin(object): """ view = None - raw_grader = [ - { - "min_count": 24, - "weight": 0.2, - "type": "Homework", - "drop_count": 0, - "short_label": "HW" - }, - { - "min_count": 4, - "weight": 0.8, - "type": "Exam", - "drop_count": 0, - "short_label": "Exam" - } - ] - def setUp(self): super(CourseViewTestsMixin, self).setUp() self.create_user_and_access_token() @@ -68,7 +51,22 @@ class CourseViewTestsMixin(object): @classmethod def create_course_data(cls): cls.invalid_course_id = 'foo/bar/baz' - cls.course = CourseFactory.create(display_name='An Introduction to API Testing', raw_grader=cls.raw_grader) + cls.course = CourseFactory.create(display_name='An Introduction to API Testing', raw_grader=[ + { + "min_count": 24, + "weight": 0.2, + "type": "Homework", + "drop_count": 0, + "short_label": "HW" + }, + { + "min_count": 4, + "weight": 0.8, + "type": "Exam", + "drop_count": 0, + "short_label": "Exam" + } + ]) cls.course_id = unicode(cls.course.id) with cls.store.bulk_operations(cls.course.id, emit_signals=False): cls.sequential = ItemFactory.create( @@ -410,55 +408,6 @@ class CourseGradingPolicyTests(CourseDetailTestMixin, CourseViewTestsMixin, Shar self.assertListEqual(response.data, expected) -class CourseGradingPolicyMissingFieldsTests(CourseDetailTestMixin, CourseViewTestsMixin, SharedModuleStoreTestCase): - view = 'course_structure_api:v0:grading_policy' - - # Update the raw grader to have missing keys - raw_grader = [ - { - "min_count": 24, - "weight": 0.2, - "type": "Homework", - "drop_count": 0, - "short_label": "HW" - }, - { - # Deleted "min_count" key - "weight": 0.8, - "type": "Exam", - "drop_count": 0, - "short_label": "Exam" - } - ] - - @classmethod - def setUpClass(cls): - super(CourseGradingPolicyMissingFieldsTests, cls).setUpClass() - cls.create_course_data() - - def test_get(self): - """ - The view should return grading policy for a course. - """ - response = super(CourseGradingPolicyMissingFieldsTests, self).test_get() - - expected = [ - { - "count": 24, - "weight": 0.2, - "assignment_type": "Homework", - "dropped": 0 - }, - { - "count": None, - "weight": 0.8, - "assignment_type": "Exam", - "dropped": 0 - } - ] - self.assertListEqual(response.data, expected) - - ##################################################################################### # # The following Mixins/Classes collectively test the CourseBlocksAndNavigation view. diff --git a/lms/djangoapps/course_structure_api/v0/views.py b/lms/djangoapps/course_structure_api/v0/views.py index 02986c3934..347ba45c24 100644 --- a/lms/djangoapps/course_structure_api/v0/views.py +++ b/lms/djangoapps/course_structure_api/v0/views.py @@ -6,8 +6,7 @@ import logging from django.conf import settings from django.http import Http404 -from rest_framework.authentication import SessionAuthentication -from rest_framework_oauth.authentication import OAuth2Authentication +from rest_framework.authentication import OAuth2Authentication, SessionAuthentication from rest_framework.exceptions import AuthenticationFailed, ParseError from rest_framework.generics import RetrieveAPIView, ListAPIView from rest_framework.permissions import IsAuthenticated @@ -22,6 +21,7 @@ from courseware.access import has_access from courseware.model_data import FieldDataCache from courseware.module_render import get_module_for_descriptor from openedx.core.lib.api.view_utils import view_course_access, view_auth_classes +from openedx.core.lib.api.serializers import PaginationSerializer from openedx.core.djangoapps.content.course_structures.api.v0 import api, errors from student.roles import CourseInstructorRole, CourseStaffRole from util.module_utils import get_dynamic_descriptor_children @@ -157,6 +157,9 @@ class CourseList(CourseViewMixin, ListAPIView): * end: The course end date. If course end date is not specified, the value is null. """ + paginate_by = 10 + paginate_by_param = 'page_size' + pagination_serializer_class = PaginationSerializer serializer_class = serializers.CourseSerializer def get_queryset(self): diff --git a/lms/djangoapps/courseware/grades.py b/lms/djangoapps/courseware/grades.py index 008e0644e4..5b1a9741d5 100644 --- a/lms/djangoapps/courseware/grades.py +++ b/lms/djangoapps/courseware/grades.py @@ -25,6 +25,7 @@ from xmodule.modulestore.django import modulestore from xmodule.modulestore.exceptions import ItemNotFoundError from .models import StudentModule from .module_render import get_module_for_descriptor +from submissions import api as sub_api # installed from the edx-submissions repository from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey from openedx.core.djangoapps.signals.signals import GRADES_UPDATED @@ -348,13 +349,8 @@ def _grade(student, request, course, keep_raw_scores, field_data_cache, scores_c # Dict of item_ids -> (earned, possible) point tuples. This *only* grabs # scores that were registered with the submissions API, which for the moment # means only openassessment (edx-ora2) - # We need to import this here to avoid a circular dependency of the form: - # XBlock --> submissions --> Django Rest Framework error strings --> - # Django translation --> ... --> courseware --> submissions - from submissions import api as sub_api # installed from the edx-submissions repository submissions_scores = sub_api.get_scores(course.id.to_deprecated_string(), anonymous_id_for_user(student, course.id)) max_scores_cache = MaxScoresCache.create_for_course(course) - # For the moment, we have to get scorable_locations from field_data_cache # and not from scores_client, because scores_client is ignorant of things # in the submissions API. As a further refactoring step, submissions should @@ -569,12 +565,7 @@ def _progress_summary(student, request, course, field_data_cache=None, scores_cl course_module = getattr(course_module, '_x_module', course_module) - # We need to import this here to avoid a circular dependency of the form: - # XBlock --> submissions --> Django Rest Framework error strings --> - # Django translation --> ... --> courseware --> submissions - from submissions import api as sub_api # installed from the edx-submissions repository submissions_scores = sub_api.get_scores(course.id.to_deprecated_string(), anonymous_id_for_user(student, course.id)) - max_scores_cache = MaxScoresCache.create_for_course(course) # For the moment, we have to get scorable_locations from field_data_cache # and not from scores_client, because scores_client is ignorant of things diff --git a/lms/djangoapps/discussion_api/api.py b/lms/djangoapps/discussion_api/api.py index 86d4d4ca8b..d8153ff486 100644 --- a/lms/djangoapps/discussion_api/api.py +++ b/lms/djangoapps/discussion_api/api.py @@ -560,7 +560,7 @@ def create_thread(request, thread_data): if not (serializer.is_valid() and actions_form.is_valid()): raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items())) serializer.save() - cc_thread = serializer.instance + cc_thread = serializer.object thread_created.send(sender=None, user=user, post=cc_thread) api_thread = serializer.data _do_extra_actions(api_thread, cc_thread, thread_data.keys(), actions_form, context) @@ -606,7 +606,7 @@ def create_comment(request, comment_data): if not (serializer.is_valid() and actions_form.is_valid()): raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items())) serializer.save() - cc_comment = serializer.instance + cc_comment = serializer.object comment_created.send(sender=None, user=request.user, post=cc_comment) api_comment = serializer.data _do_extra_actions(api_comment, cc_comment, comment_data.keys(), actions_form, context) diff --git a/lms/djangoapps/discussion_api/pagination.py b/lms/djangoapps/discussion_api/pagination.py index e818c5da19..5f14b430e8 100644 --- a/lms/djangoapps/discussion_api/pagination.py +++ b/lms/djangoapps/discussion_api/pagination.py @@ -1,7 +1,16 @@ """ Discussion API pagination support """ -from rest_framework.utils.urls import replace_query_param +from rest_framework.pagination import BasePaginationSerializer, NextPageField, PreviousPageField + + +class _PaginationSerializer(BasePaginationSerializer): + """ + A pagination serializer without the count field, because the Comments + Service does not return result counts + """ + next = NextPageField(source="*") + previous = PreviousPageField(source="*") class _Page(object): @@ -43,25 +52,7 @@ def get_paginated_data(request, results, page_num, per_page): previous: The URL for the previous page results: The results on this page """ - # Note: Previous versions of this function used Django Rest Framework's - # paginated serializer. With the upgrade to DRF 3.1, paginated serializers - # have been removed. We *could* use DRF's paginator classes, but there are - # some slight differences between how DRF does pagination and how we're doing - # pagination here. (For example, we respond with a next_url param even if - # there is only one result on the current page.) To maintain backwards - # compatability, we simulate the behavior that DRF used to provide. - page = _Page(results, page_num, per_page) - next_url, previous_url = None, None - base_url = request.build_absolute_uri() - - if page.has_next(): - next_url = replace_query_param(base_url, "page", page.next_page_number()) - - if page.has_previous(): - previous_url = replace_query_param(base_url, "page", page.previous_page_number()) - - return { - "next": next_url, - "previous": previous_url, - "results": results, - } + return _PaginationSerializer( + instance=_Page(results, page_num, per_page), + context={"request": request} + ).data diff --git a/lms/djangoapps/discussion_api/serializers.py b/lms/djangoapps/discussion_api/serializers.py index e65716ea25..56dc9495bb 100644 --- a/lms/djangoapps/discussion_api/serializers.py +++ b/lms/djangoapps/discussion_api/serializers.py @@ -28,6 +28,7 @@ from lms.lib.comment_client.thread import Thread from lms.lib.comment_client.user import User as CommentClientUser from lms.lib.comment_client.utils import CommentClientRequestError from openedx.core.djangoapps.course_groups.cohorts import get_cohort_names +from openedx.core.lib.api.fields import NonEmptyCharField def get_context(course, request, thread=None): @@ -65,43 +66,36 @@ def get_context(course, request, thread=None): } -def validate_not_blank(value): - """ - Validate that a value is not an empty string or whitespace. - - Raises: ValidationError - """ - if not value.strip(): - raise ValidationError("This field may not be blank.") - - class _ContentSerializer(serializers.Serializer): """A base class for thread and comment serializers.""" - id = serializers.CharField(read_only=True) # pylint: disable=invalid-name - author = serializers.SerializerMethodField() - author_label = serializers.SerializerMethodField() + id_ = serializers.CharField(read_only=True) + author = serializers.SerializerMethodField("get_author") + author_label = serializers.SerializerMethodField("get_author_label") created_at = serializers.CharField(read_only=True) updated_at = serializers.CharField(read_only=True) - raw_body = serializers.CharField(source="body", validators=[validate_not_blank]) - rendered_body = serializers.SerializerMethodField() - abuse_flagged = serializers.SerializerMethodField() - voted = serializers.SerializerMethodField() - vote_count = serializers.SerializerMethodField() - editable_fields = serializers.SerializerMethodField() + raw_body = NonEmptyCharField(source="body") + rendered_body = serializers.SerializerMethodField("get_rendered_body") + abuse_flagged = serializers.SerializerMethodField("get_abuse_flagged") + voted = serializers.SerializerMethodField("get_voted") + vote_count = serializers.SerializerMethodField("get_vote_count") + editable_fields = serializers.SerializerMethodField("get_editable_fields") non_updatable_fields = set() def __init__(self, *args, **kwargs): super(_ContentSerializer, self).__init__(*args, **kwargs) + # id is an invalid class attribute name, so we must declare a different + # name above and modify it here + self.fields["id"] = self.fields.pop("id_") for field in self.non_updatable_fields: setattr(self, "validate_{}".format(field), self._validate_non_updatable) - def _validate_non_updatable(self, value): + def _validate_non_updatable(self, attrs, _source): """Ensure that a field is not edited in an update operation.""" - if self.instance: + if self.object: raise ValidationError("This field is not allowed in an update.") - return value + return attrs def _is_user_privileged(self, user_id): """ @@ -137,11 +131,7 @@ class _ContentSerializer(serializers.Serializer): def get_author_label(self, obj): """Returns the role label for the content author.""" - if self._is_anonymous(obj) or obj["user_id"] is None: - return None - else: - user_id = int(obj["user_id"]) - return self._get_user_label(user_id) + return None if self._is_anonymous(obj) else self._get_user_label(int(obj["user_id"])) def get_rendered_body(self, obj): """Returns the rendered body content.""" @@ -152,7 +142,7 @@ class _ContentSerializer(serializers.Serializer): Returns a boolean indicating whether the requester has flagged the content as abusive. """ - return self.context["cc_requester"]["id"] in obj.get("abuse_flaggers", []) + return self.context["cc_requester"]["id"] in obj["abuse_flaggers"] def get_voted(self, obj): """ @@ -163,7 +153,7 @@ class _ContentSerializer(serializers.Serializer): def get_vote_count(self, obj): """Returns the number of votes for the content.""" - return obj.get("votes", {}).get("up_count", 0) + return obj["votes"]["up_count"] def get_editable_fields(self, obj): """Return the list of the fields the requester can edit""" @@ -179,28 +169,28 @@ class ThreadSerializer(_ContentSerializer): at introspection and Thread's __getattr__. """ course_id = serializers.CharField() - topic_id = serializers.CharField(source="commentable_id", validators=[validate_not_blank]) - group_id = serializers.IntegerField(required=False, allow_null=True) - group_name = serializers.SerializerMethodField() - type = serializers.ChoiceField( + topic_id = NonEmptyCharField(source="commentable_id") + group_id = serializers.IntegerField(required=False) + group_name = serializers.SerializerMethodField("get_group_name") + type_ = serializers.ChoiceField( source="thread_type", choices=[(val, val) for val in ["discussion", "question"]] ) - title = serializers.CharField(validators=[validate_not_blank]) - pinned = serializers.SerializerMethodField(read_only=True) + title = NonEmptyCharField() + pinned = serializers.BooleanField(read_only=True) closed = serializers.BooleanField(read_only=True) - following = serializers.SerializerMethodField() + following = serializers.SerializerMethodField("get_following") comment_count = serializers.IntegerField(source="comments_count", read_only=True) unread_comment_count = serializers.IntegerField(source="unread_comments_count", read_only=True) - comment_list_url = serializers.SerializerMethodField() - endorsed_comment_list_url = serializers.SerializerMethodField() - non_endorsed_comment_list_url = serializers.SerializerMethodField() + comment_list_url = serializers.SerializerMethodField("get_comment_list_url") + endorsed_comment_list_url = serializers.SerializerMethodField("get_endorsed_comment_list_url") + non_endorsed_comment_list_url = serializers.SerializerMethodField("get_non_endorsed_comment_list_url") read = serializers.BooleanField(read_only=True) has_endorsed = serializers.BooleanField(read_only=True, source="endorsed") response_count = serializers.IntegerField(source="resp_total", read_only=True) non_updatable_fields = NON_UPDATABLE_THREAD_FIELDS - + # TODO: https://openedx.atlassian.net/browse/MA-1359 def __init__(self, *args, **kwargs): remove_fields = kwargs.pop('remove_fields', None) @@ -212,13 +202,6 @@ class ThreadSerializer(_ContentSerializer): # not have the pinned field set if self.object and self.object.get("pinned") is None: self.object["pinned"] = False - - def get_pinned(self, obj): - """ - Compensate for the fact that some threads in the comments service do - not have the pinned field set. - """ - return bool(obj["pinned"]) if remove_fields: # for multiple fields in a list @@ -262,16 +245,13 @@ class ThreadSerializer(_ContentSerializer): """Returns the URL to retrieve the thread's non-endorsed comments.""" return self.get_comment_list_url(obj, endorsed=False) - def create(self, validated_data): - thread = Thread(user_id=self.context["cc_requester"]["id"], **validated_data) - thread.save() - return thread - - def update(self, instance, validated_data): - for key, val in validated_data.items(): - instance[key] = val - instance.save() - return instance + def restore_object(self, attrs, instance=None): + if instance: + for key, val in attrs.items(): + instance[key] = val + return instance + else: + return Thread(user_id=self.context["cc_requester"]["id"], **attrs) class CommentSerializer(_ContentSerializer): @@ -283,12 +263,12 @@ class CommentSerializer(_ContentSerializer): at introspection and Comment's __getattr__. """ thread_id = serializers.CharField() - parent_id = serializers.CharField(required=False, allow_null=True) + parent_id = serializers.CharField(required=False) endorsed = serializers.BooleanField(required=False) - endorsed_by = serializers.SerializerMethodField() - endorsed_by_label = serializers.SerializerMethodField() - endorsed_at = serializers.SerializerMethodField() - children = serializers.SerializerMethodField() + endorsed_by = serializers.SerializerMethodField("get_endorsed_by") + endorsed_by_label = serializers.SerializerMethodField("get_endorsed_by_label") + endorsed_at = serializers.SerializerMethodField("get_endorsed_at") + children = serializers.SerializerMethodField("get_children") non_updatable_fields = NON_UPDATABLE_COMMENT_FIELDS @@ -331,17 +311,6 @@ class CommentSerializer(_ContentSerializer): for child in obj.get("children", []) ] - def to_representation(self, data): - data = super(CommentSerializer, self).to_representation(data) - - # Django Rest Framework v3 no longer includes None values - # in the representation. To maintain the previous behavior, - # we do this manually instead. - if 'parent_id' not in data: - data["parent_id"] = None - - return data - def validate(self, attrs): """ Ensure that parent_id identifies a comment that is actually in the @@ -363,23 +332,18 @@ class CommentSerializer(_ContentSerializer): raise ValidationError({"parent_id": ["Comment level is too deep."]}) return attrs - def create(self, validated_data): - comment = Comment( + def restore_object(self, attrs, instance=None): + if instance: + for key, val in attrs.items(): + instance[key] = val + # TODO: The comments service doesn't populate the endorsement + # field on comment creation, so we only provide + # endorsement_user_id on update + if key == "endorsed": + instance["endorsement_user_id"] = self.context["cc_requester"]["id"] + return instance + return Comment( course_id=self.context["thread"]["course_id"], user_id=self.context["cc_requester"]["id"], - **validated_data + **attrs ) - comment.save() - return comment - - def update(self, instance, validated_data): - for key, val in validated_data.items(): - instance[key] = val - # TODO: The comments service doesn't populate the endorsement - # field on comment creation, so we only provide - # endorsement_user_id on update - if key == "endorsed": - instance["endorsement_user_id"] = self.context["cc_requester"]["id"] - - instance.save() - return instance diff --git a/lms/djangoapps/discussion_api/tests/test_api.py b/lms/djangoapps/discussion_api/tests/test_api.py index e61a1202b6..a2df7abd36 100644 --- a/lms/djangoapps/discussion_api/tests/test_api.py +++ b/lms/djangoapps/discussion_api/tests/test_api.py @@ -1497,9 +1497,8 @@ class CreateThreadTest( self.assertEqual(actual_post_data["group_id"], [str(cohort.id)]) else: self.assertNotIn("group_id", actual_post_data) - except ValidationError as ex: - if not expected_error: - self.fail("Unexpected validation error: {}".format(ex)) + except ValidationError: + self.assertTrue(expected_error) def test_following(self): self.register_post_thread_response({"id": "test_id"}) @@ -2240,7 +2239,7 @@ class UpdateThreadTest( update_thread(self.request, "test_thread", {"raw_body": ""}) self.assertEqual( assertion.exception.message_dict, - {"raw_body": ["This field may not be blank."]} + {"raw_body": ["This field is required."]} ) diff --git a/lms/djangoapps/discussion_api/tests/test_serializers.py b/lms/djangoapps/discussion_api/tests/test_serializers.py index 365c7fbdca..985c94fcc5 100644 --- a/lms/djangoapps/discussion_api/tests/test_serializers.py +++ b/lms/djangoapps/discussion_api/tests/test_serializers.py @@ -523,10 +523,9 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi data = self.minimal_data.copy() data.update({field: value for field in ["topic_id", "title", "raw_body"]}) serializer = ThreadSerializer(data=data, context=get_context(self.course, self.request)) - self.assertFalse(serializer.is_valid()) self.assertEqual( serializer.errors, - {field: ["This field may not be blank."] for field in ["topic_id", "title", "raw_body"]} + {field: ["This field is required."] for field in ["topic_id", "title", "raw_body"]} ) def test_create_type(self): @@ -593,10 +592,9 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi partial=True, context=get_context(self.course, self.request) ) - self.assertFalse(serializer.is_valid()) self.assertEqual( serializer.errors, - {field: ["This field may not be blank."] for field in ["topic_id", "title", "raw_body"]} + {field: ["This field is required."] for field in ["topic_id", "title", "raw_body"]} ) def test_update_course_id(self): @@ -606,7 +604,6 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi partial=True, context=get_context(self.course, self.request) ) - self.assertFalse(serializer.is_valid()) self.assertEqual( serializer.errors, {"course_id": ["This field is not allowed in an update."]} @@ -772,7 +769,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul data["parent_id"] = None serializer = CommentSerializer(data=data, context=context) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {"non_field_errors": ["Comment level is too deep."]}) + self.assertEqual(serializer.errors, {"parent_id": ["Comment level is too deep."]}) def test_create_missing_field(self): for field in self.minimal_data: @@ -858,10 +855,9 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul partial=True, context=get_context(self.course, self.request) ) - self.assertFalse(serializer.is_valid()) self.assertEqual( serializer.errors, - {"raw_body": ["This field may not be blank."]} + {"raw_body": ["This field is required."]} ) @ddt.data("thread_id", "parent_id") @@ -872,7 +868,6 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul partial=True, context=get_context(self.course, self.request) ) - self.assertFalse(serializer.is_valid()) self.assertEqual( serializer.errors, {field: ["This field is not allowed in an update."]} diff --git a/lms/djangoapps/discussion_api/tests/test_views.py b/lms/djangoapps/discussion_api/tests/test_views.py index 0809df43e2..ce1df0a85a 100644 --- a/lms/djangoapps/discussion_api/tests/test_views.py +++ b/lms/djangoapps/discussion_api/tests/test_views.py @@ -571,7 +571,7 @@ class ThreadViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTest content_type="application/json" ) expected_response_data = { - "field_errors": {"title": {"developer_message": "This field may not be blank."}} + "field_errors": {"title": {"developer_message": "This field is required."}} } self.assertEqual(response.status_code, 400) response_data = json.loads(response.content) @@ -690,7 +690,7 @@ class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): 200, { "results": expected_comments, - "next": "http://testserver/api/discussion/v1/comments/?page=2&thread_id={}".format( + "next": "http://testserver/api/discussion/v1/comments/?thread_id={}&page=2".format( self.thread_id ), "previous": None, @@ -946,7 +946,7 @@ class CommentViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTes content_type="application/json" ) expected_response_data = { - "field_errors": {"raw_body": {"developer_message": "This field may not be blank."}} + "field_errors": {"raw_body": {"developer_message": "This field is required."}} } self.assertEqual(response.status_code, 400) response_data = json.loads(response.content) diff --git a/lms/djangoapps/discussion_api/views.py b/lms/djangoapps/discussion_api/views.py index 116fc32301..2072e7b5db 100644 --- a/lms/djangoapps/discussion_api/views.py +++ b/lms/djangoapps/discussion_api/views.py @@ -3,8 +3,7 @@ Discussion API views """ from django.core.exceptions import ValidationError -from rest_framework.authentication import SessionAuthentication -from rest_framework_oauth.authentication import OAuth2Authentication +from rest_framework.authentication import OAuth2Authentication, SessionAuthentication from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView diff --git a/lms/djangoapps/django_comment_client/base/tests.py b/lms/djangoapps/django_comment_client/base/tests.py index 3e0dfd8129..568b866ab4 100644 --- a/lms/djangoapps/django_comment_client/base/tests.py +++ b/lms/djangoapps/django_comment_client/base/tests.py @@ -32,6 +32,8 @@ from xmodule.modulestore.tests.factories import check_mongo_calls from xmodule.modulestore.django import modulestore from xmodule.modulestore import ModuleStoreEnum +from teams.tests.factories import CourseTeamFactory + log = logging.getLogger(__name__) @@ -1288,7 +1290,6 @@ class TeamsPermissionsTestCase(UrlResetMixin, ModuleStoreTestCase, MockRequestSe topic_id='topic_id', discussion_topic_id=self.team_commentable_id ) - self.team.add_user(self.student_in_team) # Dummy commentable ID not linked to a team diff --git a/lms/djangoapps/django_comment_client/forum/tests.py b/lms/djangoapps/django_comment_client/forum/tests.py index eabeba998a..54b5bb3114 100644 --- a/lms/djangoapps/django_comment_client/forum/tests.py +++ b/lms/djangoapps/django_comment_client/forum/tests.py @@ -717,7 +717,6 @@ class InlineDiscussionContextTestCase(ModuleStoreTestCase): topic_id='topic_id', discussion_topic_id=self.discussion_topic_id ) - self.team.add_user(self.user) # pylint: disable=no-member def test_context_can_be_standalone(self, mock_request): @@ -1094,9 +1093,7 @@ class InlineDiscussionTestCase(ModuleStoreTestCase): course_id=self.course.id, discussion_topic_id=self.discussion1.discussion_id ) - team.add_user(self.student) # pylint: disable=no-member - response = self.send_request(mock_request) self.assertEqual(mock_request.call_args[1]['params']['context'], ThreadContext.STANDALONE) self.verify_response(response) diff --git a/lms/djangoapps/django_comment_client/tests/group_id.py b/lms/djangoapps/django_comment_client/tests/group_id.py index 374dc534cd..2773b5ee99 100644 --- a/lms/djangoapps/django_comment_client/tests/group_id.py +++ b/lms/djangoapps/django_comment_client/tests/group_id.py @@ -149,8 +149,6 @@ class NonCohortedTopicGroupIdTestMixin(GroupIdAssertionMixin): def test_team_discussion_id_not_cohorted(self, mock_request): team = CourseTeamFactory(course_id=self.course.id) - team.add_user(self.student) # pylint: disable=no-member self.call_view(mock_request, team.discussion_topic_id, self.student, None) - self._assert_comments_service_called_without_group_id(mock_request) diff --git a/lms/djangoapps/mobile_api/social_facebook/courses/views.py b/lms/djangoapps/mobile_api/social_facebook/courses/views.py index b4f1a23aeb..584e879db4 100644 --- a/lms/djangoapps/mobile_api/social_facebook/courses/views.py +++ b/lms/djangoapps/mobile_api/social_facebook/courses/views.py @@ -31,7 +31,7 @@ class CoursesWithFriends(generics.ListAPIView): serializer_class = serializers.CoursesWithFriendsSerializer def list(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.GET) + serializer = self.get_serializer(data=request.GET, files=request.FILES) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -61,5 +61,4 @@ class CoursesWithFriends(generics.ListAPIView): and is_mobile_available_for_user(self.request.user, enrollment.course) ] - serializer = CourseEnrollmentSerializer(courses, context={'request': request}, many=True) - return Response(serializer.data) + return Response(CourseEnrollmentSerializer(courses, context={'request': request}).data) diff --git a/lms/djangoapps/mobile_api/social_facebook/friends/views.py b/lms/djangoapps/mobile_api/social_facebook/friends/views.py index 7e1904491d..07a8c862fe 100644 --- a/lms/djangoapps/mobile_api/social_facebook/friends/views.py +++ b/lms/djangoapps/mobile_api/social_facebook/friends/views.py @@ -40,7 +40,7 @@ class FriendsInCourse(generics.ListAPIView): serializer_class = serializers.FriendsInCourseSerializer def list(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.GET) + serializer = self.get_serializer(data=request.GET, files=request.FILES) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/lms/djangoapps/mobile_api/social_facebook/groups/views.py b/lms/djangoapps/mobile_api/social_facebook/groups/views.py index 620a482776..463e466aad 100644 --- a/lms/djangoapps/mobile_api/social_facebook/groups/views.py +++ b/lms/djangoapps/mobile_api/social_facebook/groups/views.py @@ -45,7 +45,7 @@ class Groups(generics.CreateAPIView, mixins.DestroyModelMixin): serializer_class = serializers.GroupSerializer def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.DATA) + serializer = self.get_serializer(data=request.DATA, files=request.FILES) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) try: @@ -106,12 +106,12 @@ class GroupsMembers(generics.CreateAPIView, mixins.DestroyModelMixin): serializer_class = serializers.GroupsMembersSerializer def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.DATA) + serializer = self.get_serializer(data=request.DATA, files=request.FILES) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) graph = facebook_graph_api() url = settings.FACEBOOK_API_VERSION + '/' + kwargs['group_id'] + "/members" - member_ids = serializer.data['member_ids'].split(',') + member_ids = serializer.object['member_ids'].split(',') response = {} for member_id in member_ids: try: diff --git a/lms/djangoapps/mobile_api/social_facebook/preferences/serializers.py b/lms/djangoapps/mobile_api/social_facebook/preferences/serializers.py index 1d51f87eda..12eb03f0cb 100644 --- a/lms/djangoapps/mobile_api/social_facebook/preferences/serializers.py +++ b/lms/djangoapps/mobile_api/social_facebook/preferences/serializers.py @@ -8,4 +8,4 @@ class UserSharingSerializar(serializers.Serializer): """ Serializes user social settings """ - share_with_facebook_friends = serializers.BooleanField(required=True) + share_with_facebook_friends = serializers.BooleanField(required=True, default=False) diff --git a/lms/djangoapps/mobile_api/social_facebook/preferences/views.py b/lms/djangoapps/mobile_api/social_facebook/preferences/views.py index 7495ba30f2..c48901e67d 100644 --- a/lms/djangoapps/mobile_api/social_facebook/preferences/views.py +++ b/lms/djangoapps/mobile_api/social_facebook/preferences/views.py @@ -39,9 +39,9 @@ class UserSharing(generics.ListCreateAPIView): serializer_class = serializers.UserSharingSerializar def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.DATA) + serializer = self.get_serializer(data=request.DATA, files=request.FILES) if serializer.is_valid(): - value = serializer.data['share_with_facebook_friends'] + value = serializer.object['share_with_facebook_friends'] set_user_preference(request.user, "share_with_facebook_friends", value) return self.get(request, *args, **kwargs) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/lms/djangoapps/mobile_api/social_facebook/utils.py b/lms/djangoapps/mobile_api/social_facebook/utils.py index 654dd97edd..daa379edc6 100644 --- a/lms/djangoapps/mobile_api/social_facebook/utils.py +++ b/lms/djangoapps/mobile_api/social_facebook/utils.py @@ -40,7 +40,7 @@ def get_friends_from_facebook(serializer): the error message. """ try: - graph = facebook.GraphAPI(serializer.data['oauth_token']) + graph = facebook.GraphAPI(serializer.object['oauth_token']) friends = graph.request(settings.FACEBOOK_API_VERSION + "/me/friends") return get_pagination(friends) except facebook.GraphAPIError, ex: diff --git a/lms/djangoapps/mobile_api/users/serializers.py b/lms/djangoapps/mobile_api/users/serializers.py index cb2d0a2782..2be28e1ae3 100644 --- a/lms/djangoapps/mobile_api/users/serializers.py +++ b/lms/djangoapps/mobile_api/users/serializers.py @@ -15,7 +15,7 @@ from xmodule.course_module import DEFAULT_START_DATE class CourseOverviewField(serializers.RelatedField): """Custom field to wrap a CourseDescriptor object. Read-only.""" - def to_representation(self, course_overview): + def to_native(self, course_overview): course_id = unicode(course_overview.id) request = self.context.get('request', None) if request: @@ -77,8 +77,8 @@ class CourseEnrollmentSerializer(serializers.ModelSerializer): """ Serializes CourseEnrollment models """ - course = CourseOverviewField(source="course_overview", read_only=True) - certificate = serializers.SerializerMethodField() + course = CourseOverviewField(source="course_overview") + certificate = serializers.SerializerMethodField('get_certificate') def get_certificate(self, model): """Returns the information about the user's certificate in the course.""" @@ -100,7 +100,7 @@ class UserSerializer(serializers.HyperlinkedModelSerializer): """ Serializes User models """ - name = serializers.ReadOnlyField(source='profile.name') + name = serializers.Field(source='profile.name') course_enrollments = serializers.HyperlinkedIdentityField( view_name='courseenrollment-detail', lookup_field='username' diff --git a/lms/djangoapps/mobile_api/users/views.py b/lms/djangoapps/mobile_api/users/views.py index 95a5e6eb59..c73084195b 100644 --- a/lms/djangoapps/mobile_api/users/views.py +++ b/lms/djangoapps/mobile_api/users/views.py @@ -251,14 +251,6 @@ class UserCourseEnrollmentsList(generics.ListAPIView): serializer_class = CourseEnrollmentSerializer lookup_field = 'username' - # In Django Rest Framework v3, there is a default pagination - # class that transmutes the response data into a dictionary - # with pagination information. The original response data (a list) - # is stored in a "results" value of the dictionary. - # For backwards compatibility with the existing API, we disable - # the default behavior by setting the pagination_class to None. - pagination_class = None - def get_queryset(self): enrollments = self.queryset.filter( user__username=self.kwargs['username'], diff --git a/lms/djangoapps/notifier_api/serializers.py b/lms/djangoapps/notifier_api/serializers.py index 719cc9c6f1..57832475d6 100644 --- a/lms/djangoapps/notifier_api/serializers.py +++ b/lms/djangoapps/notifier_api/serializers.py @@ -23,9 +23,9 @@ class NotifierUserSerializer(serializers.ModelSerializer): * course_groups * roles__permissions """ - name = serializers.SerializerMethodField() - preferences = serializers.SerializerMethodField() - course_info = serializers.SerializerMethodField() + name = serializers.SerializerMethodField("get_name") + preferences = serializers.SerializerMethodField("get_preferences") + course_info = serializers.SerializerMethodField("get_course_info") def get_name(self, user): return user.profile.name diff --git a/lms/djangoapps/notifier_api/views.py b/lms/djangoapps/notifier_api/views.py index 44e1c78037..b8e05ebb16 100644 --- a/lms/djangoapps/notifier_api/views.py +++ b/lms/djangoapps/notifier_api/views.py @@ -1,32 +1,11 @@ from django.contrib.auth.models import User from rest_framework.viewsets import ReadOnlyModelViewSet -from rest_framework.response import Response -from rest_framework import pagination from notification_prefs import NOTIFICATION_PREF_KEY from notifier_api.serializers import NotifierUserSerializer from openedx.core.lib.api.permissions import ApiKeyHeaderPermission -class NotifierPaginator(pagination.PageNumberPagination): - """ - Paginator for the notifier API. - """ - page_size = 10 - page_size_query_param = "page_size" - - def get_paginated_response(self, data): - """ - Construct a response with pagination information. - """ - return Response({ - 'next': self.get_next_link(), - 'previous': self.get_previous_link(), - 'count': self.page.paginator.count, - 'results': data - }) - - class NotifierUsersViewSet(ReadOnlyModelViewSet): """ An endpoint that the notifier can use to retrieve users who have enabled @@ -35,7 +14,8 @@ class NotifierUsersViewSet(ReadOnlyModelViewSet): """ permission_classes = (ApiKeyHeaderPermission,) serializer_class = NotifierUserSerializer - pagination_class = NotifierPaginator + paginate_by = 10 + paginate_by_param = "page_size" # See NotifierUserSerializer for notes about related tables queryset = User.objects.filter( diff --git a/lms/djangoapps/teams/models.py b/lms/djangoapps/teams/models.py index 18d38f4492..cbcb822554 100644 --- a/lms/djangoapps/teams/models.py +++ b/lms/djangoapps/teams/models.py @@ -3,6 +3,7 @@ from datetime import datetime from uuid import uuid4 import pytz +from datetime import datetime from model_utils import FieldTracker from django.core.exceptions import ObjectDoesNotExist diff --git a/lms/djangoapps/teams/search_indexes.py b/lms/djangoapps/teams/search_indexes.py index 6dfc070e81..588a2c7892 100644 --- a/lms/djangoapps/teams/search_indexes.py +++ b/lms/djangoapps/teams/search_indexes.py @@ -10,7 +10,6 @@ from django.utils import translation from functools import wraps from search.search_engine_base import SearchEngine -from request_cache import get_request_or_stub from .errors import ElasticSearchConnectionError from .serializers import CourseTeamSerializer, CourseTeam @@ -48,15 +47,7 @@ class CourseTeamIndexer(object): Returns serialized object with additional search fields. """ - # Django Rest Framework v3.1 requires that we pass the request to the serializer - # so it can construct hyperlinks. To avoid changing the interface of this object, - # we retrieve the request from the request cache. - context = { - "request": get_request_or_stub() - } - - serialized_course_team = CourseTeamSerializer(self.course_team, context=context).data - + serialized_course_team = CourseTeamSerializer(self.course_team).data # Save the primary key so we can load the full objects easily after we search serialized_course_team['pk'] = self.course_team.pk # Don't save the membership relations in elasticsearch diff --git a/lms/djangoapps/teams/serializers.py b/lms/djangoapps/teams/serializers.py index 8befd944bc..f51ad9e61c 100644 --- a/lms/djangoapps/teams/serializers.py +++ b/lms/djangoapps/teams/serializers.py @@ -4,44 +4,15 @@ from django.contrib.auth.models import User from django.db.models import Count from django.conf import settings -from django_countries import countries from rest_framework import serializers -from openedx.core.lib.api.serializers import CollapsedReferenceSerializer +from openedx.core.lib.api.serializers import CollapsedReferenceSerializer, PaginationSerializer from openedx.core.lib.api.fields import ExpandableField from openedx.core.djangoapps.user_api.accounts.serializers import UserReadOnlySerializer from .models import CourseTeam, CourseTeamMembership -class CountryField(serializers.Field): - """ - Field to serialize a country code. - """ - - COUNTRY_CODES = dict(countries).keys() - - def to_representation(self, obj): - """ - Represent the country as a 2-character unicode identifier. - """ - return unicode(obj) - - def to_internal_value(self, data): - """ - Check that the code is a valid country code. - - We leave the data in its original format so that the Django model's - CountryField can convert it to the internal representation used - by the django-countries library. - """ - if data and data not in self.COUNTRY_CODES: - raise serializers.ValidationError( - u"{code} is not a valid country code".format(code=data) - ) - return data - - class UserMembershipSerializer(serializers.ModelSerializer): """Serializes CourseTeamMemberships with only user and date_joined @@ -72,7 +43,6 @@ class CourseTeamSerializer(serializers.ModelSerializer): """Serializes a CourseTeam with membership information.""" id = serializers.CharField(source='team_id', read_only=True) # pylint: disable=invalid-name membership = UserMembershipSerializer(many=True, read_only=True) - country = CountryField() class Meta(object): """Defines meta information for the ModelSerializer.""" @@ -96,8 +66,6 @@ class CourseTeamSerializer(serializers.ModelSerializer): class CourseTeamCreationSerializer(serializers.ModelSerializer): """Deserializes a CourseTeam for creation.""" - country = CountryField(required=False) - class Meta(object): """Defines meta information for the ModelSerializer.""" model = CourseTeam @@ -110,17 +78,16 @@ class CourseTeamCreationSerializer(serializers.ModelSerializer): "language", ) - def create(self, validated_data): - team = CourseTeam.create( - name=validated_data.get("name", ''), - course_id=validated_data.get("course_id"), - description=validated_data.get("description", ''), - topic_id=validated_data.get("topic_id", ''), - country=validated_data.get("country", ''), - language=validated_data.get("language", ''), + def restore_object(self, attrs, instance=None): + """Restores a CourseTeam instance from the given attrs.""" + return CourseTeam.create( + name=attrs.get("name", ''), + course_id=attrs.get("course_id"), + description=attrs.get("description", ''), + topic_id=attrs.get("topic_id", ''), + country=attrs.get("country", ''), + language=attrs.get("language", ''), ) - team.save() - return team class CourseTeamSerializerWithoutMembership(CourseTeamSerializer): @@ -167,6 +134,13 @@ class MembershipSerializer(serializers.ModelSerializer): read_only_fields = ("date_joined", "last_activity_at") +class PaginatedMembershipSerializer(PaginationSerializer): + """Serializes team memberships with support for pagination.""" + class Meta(object): + """Defines meta information for the PaginatedMembershipSerializer.""" + object_serializer_class = MembershipSerializer + + class BaseTopicSerializer(serializers.Serializer): """Serializes a topic without team_count.""" description = serializers.CharField() @@ -181,7 +155,7 @@ class TopicSerializer(BaseTopicSerializer): model to get the count. Requires that `context` is provided with a valid course_id in order to filter teams within the course. """ - team_count = serializers.SerializerMethodField() + team_count = serializers.SerializerMethodField('get_team_count') def get_team_count(self, topic): """Get the number of teams associated with this topic""" @@ -192,25 +166,31 @@ class TopicSerializer(BaseTopicSerializer): return CourseTeam.objects.filter(course_id=self.context['course_id'], topic_id=topic['id']).count() -class BulkTeamCountTopicListSerializer(serializers.ListSerializer): # pylint: disable=abstract-method +class PaginatedTopicSerializer(PaginationSerializer): """ - List serializer for efficiently serializing a set of topics. + Serializes a set of topics, adding the team_count field to each topic individually, if team_count + is not already present in the topic data. Requires that `context` is provided with a valid course_id in + order to filter teams within the course. """ - - def to_representation(self, obj): - """Adds team_count to each topic. """ - data = super(BulkTeamCountTopicListSerializer, self).to_representation(obj) - add_team_count(data, self.context["course_id"]) - return data + class Meta(object): + """Defines meta information for the PaginatedTopicSerializer.""" + object_serializer_class = TopicSerializer -class BulkTeamCountTopicSerializer(BaseTopicSerializer): # pylint: disable=abstract-method +class BulkTeamCountPaginatedTopicSerializer(PaginationSerializer): """ - Serializes a set of topics, adding the team_count field to each topic as a bulk operation. - Requires that `context` is provided with a valid course_id in order to filter teams within the course. + Serializes a set of topics, adding the team_count field to each topic as a bulk operation per page + (only on the page being returned). Requires that `context` is provided with a valid course_id in + order to filter teams within the course. """ - class Meta: # pylint: disable=missing-docstring,old-style-class - list_serializer_class = BulkTeamCountTopicListSerializer + class Meta(object): + """Defines meta information for the BulkTeamCountPaginatedTopicSerializer.""" + object_serializer_class = BaseTopicSerializer + + def __init__(self, *args, **kwargs): + """Adds team_count to each topic on the current page.""" + super(BulkTeamCountPaginatedTopicSerializer, self).__init__(*args, **kwargs) + add_team_count(self.data['results'], self.context['course_id']) def add_team_count(topics, course_id): diff --git a/lms/djangoapps/teams/tests/factories.py b/lms/djangoapps/teams/tests/factories.py index 557e1a1c21..ee58e1ad69 100644 --- a/lms/djangoapps/teams/tests/factories.py +++ b/lms/djangoapps/teams/tests/factories.py @@ -34,10 +34,3 @@ class CourseTeamMembershipFactory(DjangoModelFactory): class Meta(object): # pylint: disable=missing-docstring model = CourseTeamMembership last_activity_at = LAST_ACTIVITY_AT - - @classmethod - def _create(cls, model_class, *args, **kwargs): - """Create the team membership. """ - obj = model_class(*args, **kwargs) - obj.save() - return obj diff --git a/lms/djangoapps/teams/tests/test_serializers.py b/lms/djangoapps/teams/tests/test_serializers.py index c30334ad2e..123b2f793b 100644 --- a/lms/djangoapps/teams/tests/test_serializers.py +++ b/lms/djangoapps/teams/tests/test_serializers.py @@ -11,9 +11,12 @@ from xmodule.modulestore.tests.factories import CourseFactory from lms.djangoapps.teams.tests.factories import CourseTeamFactory, CourseTeamMembershipFactory from lms.djangoapps.teams.serializers import ( - BulkTeamCountTopicSerializer, + BaseTopicSerializer, + PaginatedTopicSerializer, + BulkTeamCountPaginatedTopicSerializer, TopicSerializer, MembershipSerializer, + add_team_count ) @@ -70,6 +73,21 @@ class MembershipSerializerTestCase(SerializerTestCase): self.assertNotIn('membership', data['team']) +class BaseTopicSerializerTestCase(SerializerTestCase): + """ + Tests for the `BaseTopicSerializer`, which should not serialize team count + data. + """ + def test_team_count_not_included(self): + """Verifies that the `BaseTopicSerializer` does not include team count""" + with self.assertNumQueries(0): + serializer = BaseTopicSerializer(self.course.teams_topics[0]) + self.assertEqual( + serializer.data, + {u'name': u'Tøpic', u'description': u'The bést topic!', u'id': u'0'} + ) + + class TopicSerializerTestCase(SerializerTestCase): """ Tests for the `TopicSerializer`, which should serialize team count data for @@ -119,7 +137,7 @@ class TopicSerializerTestCase(SerializerTestCase): ) -class BaseTopicSerializerTestCase(SerializerTestCase): +class BasePaginatedTopicSerializerTestCase(SerializerTestCase): """ Base class for testing the two paginated topic serializers. """ @@ -173,15 +191,13 @@ class BaseTopicSerializerTestCase(SerializerTestCase): self.assert_serializer_output([], num_teams_per_topic=0, num_queries=0) -class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase): +class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializerTestCase): """ - Tests for the `BulkTeamCountTopicSerializer`, which should serialize team_count + Tests for the `BulkTeamCountPaginatedTopicSerializer`, which should serialize team_count data for many topics with constant time SQL queries. """ __test__ = True - serializer = BulkTeamCountTopicSerializer - - NUM_TOPICS = 6 + serializer = BulkTeamCountPaginatedTopicSerializer def test_topics_with_no_team_counts(self): """ @@ -206,13 +222,13 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase): one SQL query. """ teams_per_topic = 10 - topics = self.setup_topics(num_topics=self.NUM_TOPICS, teams_per_topic=teams_per_topic) - self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=1) + topics = self.setup_topics(num_topics=self.PAGE_SIZE + 1, teams_per_topic=teams_per_topic) + self.assert_serializer_output(topics[:self.PAGE_SIZE], num_teams_per_topic=teams_per_topic, num_queries=1) def test_scoped_within_course(self): """Verify that team counts are scoped within a course.""" teams_per_topic = 10 - first_course_topics = self.setup_topics(num_topics=self.NUM_TOPICS, teams_per_topic=teams_per_topic) + first_course_topics = self.setup_topics(num_topics=self.PAGE_SIZE, teams_per_topic=teams_per_topic) duplicate_topic = first_course_topics[0] second_course = CourseFactory.create( teams_configuration={ @@ -223,44 +239,27 @@ class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase): CourseTeamFactory.create(course_id=second_course.id, topic_id=duplicate_topic[u'id']) self.assert_serializer_output(first_course_topics, num_teams_per_topic=teams_per_topic, num_queries=1) - def _merge_dicts(self, first, second): - """Convenience method to merge two dicts in a single expression""" - result = first.copy() - result.update(second) - return result - def setup_topics(self, num_topics=5, teams_per_topic=0): - """ - Helper method to set up topics on the course. Returns a list of - created topics. - """ - self.course.teams_configuration['topics'] = [] - topics = [ - {u'name': u'Tøpic {}'.format(i), u'description': u'The bést topic! {}'.format(i), u'id': unicode(i)} - for i in xrange(num_topics) - ] - for i in xrange(num_topics): - topic_id = unicode(i) - self.course.teams_configuration['topics'].append(topics[i]) - for _ in xrange(teams_per_topic): - CourseTeamFactory.create(course_id=self.course.id, topic_id=topic_id) - return topics +class PaginatedTopicSerializerTestCase(BasePaginatedTopicSerializerTestCase): + """ + Tests for the `PaginatedTopicSerializer`, which will add team_count information per topic if not present. + """ + __test__ = True + serializer = PaginatedTopicSerializer - def assert_serializer_output(self, topics, num_teams_per_topic, num_queries): + def test_topics_with_team_counts(self): """ - Verify that the serializer produced the expected topics. + Verify that we serialize topics with team_count, making one SQL query per topic. """ - with self.assertNumQueries(num_queries): - serializer = self.serializer(topics, context={'course_id': self.course.id}, many=True) - self.assertEqual( - serializer.data, - [self._merge_dicts(topic, {u'team_count': num_teams_per_topic}) for topic in topics] - ) + teams_per_topic = 2 + topics = self.setup_topics(teams_per_topic=teams_per_topic) + self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=5) - def test_no_topics(self): + def test_topics_with_team_counts_prepopulated(self): """ - Verify that we return no results and make no SQL queries for a page - with no topics. + Verify that if team_count is pre-populated, there are no additional SQL queries. """ - self.course.teams_configuration['topics'] = [] - self.assert_serializer_output([], num_teams_per_topic=0, num_queries=0) + teams_per_topic = 8 + topics = self.setup_topics(teams_per_topic=teams_per_topic) + add_team_count(topics, self.course.id) + self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=0) diff --git a/lms/djangoapps/teams/tests/test_views.py b/lms/djangoapps/teams/tests/test_views.py index b84e63416f..cc7fcc23ed 100644 --- a/lms/djangoapps/teams/tests/test_views.py +++ b/lms/djangoapps/teams/tests/test_views.py @@ -1,31 +1,32 @@ # -*- coding: utf-8 -*- """Tests for the teams API at the HTTP request level.""" import json -from datetime import datetime - import pytz +from datetime import datetime from dateutil import parser import ddt from elasticsearch.exceptions import ConnectionError from mock import patch from search.search_engine_base import SearchEngine + from django.core.urlresolvers import reverse from django.conf import settings from django.db.models.signals import post_save from django.utils import translation from nose.plugins.attrib import attr from rest_framework.test import APITestCase, APIClient -from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase -from xmodule.modulestore.tests.factories import CourseFactory from courseware.tests.factories import StaffFactory from common.test.utils import skip_signal from student.tests.factories import UserFactory, AdminFactory, CourseEnrollmentFactory from student.models import CourseEnrollment from util.testing import EventTestMixin +from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase +from xmodule.modulestore.tests.factories import CourseFactory from .factories import CourseTeamFactory, LAST_ACTIVITY_AT from ..models import CourseTeamMembership from ..search_indexes import CourseTeamIndexer, CourseTeam, course_team_post_save_callback + from django_comment_common.models import Role, FORUM_ROLE_COMMUNITY_TA from django_comment_common.utils import seed_permissions_roles @@ -35,23 +36,11 @@ class TestDashboard(SharedModuleStoreTestCase): """Tests for the Teams dashboard.""" test_password = "test" - NUM_TOPICS = 10 - @classmethod def setUpClass(cls): super(TestDashboard, cls).setUpClass() cls.course = CourseFactory.create( - teams_configuration={ - "max_team_size": 10, - "topics": [ - { - "name": "Topic {}".format(topic_id), - "id": topic_id, - "description": "Description for topic {}".format(topic_id) - } - for topic_id in range(cls.NUM_TOPICS) - ] - } + teams_configuration={"max_team_size": 10, "topics": [{"name": "foo", "id": 0, "description": "test topic"}]} ) def setUp(self): @@ -108,30 +97,6 @@ class TestDashboard(SharedModuleStoreTestCase): response = self.client.get(teams_url) self.assertEqual(404, response.status_code) - def test_query_counts(self): - # Enroll in the course and log in - CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) - self.client.login(username=self.user.username, password=self.test_password) - - # Check the query count on the dashboard With no teams - with self.assertNumQueries(15): - self.client.get(self.teams_url) - - # Create some teams - for topic_id in range(self.NUM_TOPICS): - team = CourseTeamFactory.create( - name=u"Team for topic {}".format(topic_id), - course_id=self.course.id, - topic_id=topic_id, - ) - - # Add the user to the last team - team.add_user(self.user) - - # Check the query count on the dashboard again - with self.assertNumQueries(19): - self.client.get(self.teams_url) - def test_bad_course_id(self): """ Verifies expected behavior when course_id does not reference an existing course or is invalid. @@ -287,9 +252,6 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase): self.users[user], course.id, check_access=True ) - # Django Rest Framework v3 requires us to pass a request to serializers - # that have URL fields. Since we're invoking this code outside the context - # of a request, we need to simulate that there's a request. self.solar_team.add_user(self.users['student_enrolled']) self.nuclear_team.add_user(self.users['student_enrolled_both_courses_other_team']) self.another_team.add_user(self.users['student_enrolled_both_courses_other_team']) @@ -349,17 +311,7 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase): response = func(url, data=data, content_type=content_type) else: response = func(url, data=data) - - self.assertEqual( - expected_status, - response.status_code, - msg="Expected status {expected} but got {actual}: {content}".format( - expected=expected_status, - actual=response.status_code, - content=response.content, - ) - ) - + self.assertEqual(expected_status, response.status_code) if expected_status == 200: return json.loads(response.content) else: diff --git a/lms/djangoapps/teams/views.py b/lms/djangoapps/teams/views.py index 1b983e3e9c..6a6af4e5d9 100644 --- a/lms/djangoapps/teams/views.py +++ b/lms/djangoapps/teams/views.py @@ -5,14 +5,19 @@ import logging from django.shortcuts import get_object_or_404, render_to_response from django.http import Http404 from django.conf import settings +from django.core.paginator import Paginator +from django.views.generic.base import View from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.views import APIView -from rest_framework.authentication import SessionAuthentication -from rest_framework_oauth.authentication import OAuth2Authentication +from rest_framework.authentication import ( + SessionAuthentication, + OAuth2Authentication +) from rest_framework import status from rest_framework import permissions +from django.db.models import Count from django.db.models.signals import post_save from django.dispatch import receiver from django.contrib.auth.models import User @@ -27,7 +32,8 @@ from openedx.core.lib.api.view_utils import ( build_api_error, ExpandableFieldViewMixin ) -from openedx.core.lib.api.paginators import paginate_search_results, DefaultPagination +from openedx.core.lib.api.serializers import PaginationSerializer +from openedx.core.lib.api.paginators import paginate_search_results from xmodule.modulestore.django import modulestore from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey @@ -43,8 +49,10 @@ from .serializers import ( CourseTeamSerializer, CourseTeamCreationSerializer, TopicSerializer, - BulkTeamCountTopicSerializer, + PaginatedTopicSerializer, + BulkTeamCountPaginatedTopicSerializer, MembershipSerializer, + PaginatedMembershipSerializer, add_team_count ) from .search_indexes import CourseTeamIndexer @@ -77,42 +85,7 @@ def team_post_save_callback(sender, instance, **kwargs): # pylint: disable=unus ) -class TeamAPIPagination(DefaultPagination): - """ - Pagination format used by the teams API. - """ - page_size_query_param = "page_size" - - def get_paginated_response(self, data): - """ - Annotate the response with pagination information. - """ - response = super(TeamAPIPagination, self).get_paginated_response(data) - - # Add the current page to the response. - # It may make sense to eventually move this field into the default - # implementation, but for now, teams is the only API that uses this. - response.data["current_page"] = self.page.number - - # This field can be derived from other fields in the response, - # so it may make sense to have the JavaScript client calculate it - # instead of including it in the response. - response.data["start"] = (self.page.number - 1) * self.get_page_size(self.request) - - return response - - -class TopicsPagination(TeamAPIPagination): - """Paginate topics. """ - page_size = TOPICS_PER_PAGE - - -class MembershipPagination(TeamAPIPagination): - """Paginate memberships. """ - page_size = TEAM_MEMBERSHIPS_PER_PAGE - - -class TeamsDashboardView(GenericAPIView): +class TeamsDashboardView(View): """ View methods related to the teams dashboard. """ @@ -134,38 +107,29 @@ class TeamsDashboardView(GenericAPIView): not has_access(request.user, 'staff', course, course.id): raise Http404 - user = request.user - # Even though sorting is done outside of the serializer, sort_order needs to be passed # to the serializer so that the paginated results indicate how they were sorted. sort_order = 'name' topics = get_alphabetical_topics(course) - - # Paginate and serialize topic data + topics_page = Paginator(topics, TOPICS_PER_PAGE).page(1) # BulkTeamCountPaginatedTopicSerializer will add team counts to the topics in a single # bulk operation per page. - topics_data = self._serialize_and_paginate( - TopicsPagination, - topics, - request, - BulkTeamCountTopicSerializer, - {'course_id': course.id}, + topics_serializer = BulkTeamCountPaginatedTopicSerializer( + instance=topics_page, + context={'course_id': course.id, 'sort_order': sort_order} ) - topics_data["sort_order"] = sort_order + user = request.user - # Paginate and serialize team membership data. - team_memberships = CourseTeamMembership.get_memberships(user.username, [course.id]) - memberships_data = self._serialize_and_paginate( - MembershipPagination, - team_memberships, - request, - MembershipSerializer, - {'expand': ('team', 'user',)} + team_memberships = CourseTeamMembership.get_memberships(request.user.username, [course.id]) + team_memberships_page = Paginator(team_memberships, TEAM_MEMBERSHIPS_PER_PAGE).page(1) + team_memberships_serializer = PaginatedMembershipSerializer( + instance=team_memberships_page, + context={'expand': ('team', 'user'), 'request': request}, ) context = { "course": course, - "topics": topics_data, + "topics": topics_serializer.data, # It is necessary to pass both privileged and staff because only privileged users can # administer discussion threads, but both privileged and staff users are allowed to create # multiple teams (since they are not automatically added to teams upon creation). @@ -173,7 +137,7 @@ class TeamsDashboardView(GenericAPIView): "username": user.username, "privileged": has_discussion_privileges(user, course_key), "staff": bool(has_access(user, 'staff', course_key)), - "team_memberships_data": memberships_data, + "team_memberships_data": team_memberships_serializer.data, }, "topic_url": reverse( 'topics_detail', kwargs={'topic_id': 'topic_id', 'course_id': str(course_id)}, request=request @@ -190,39 +154,6 @@ class TeamsDashboardView(GenericAPIView): } return render_to_response("teams/teams.html", context) - def _serialize_and_paginate(self, pagination_cls, queryset, request, serializer_cls, serializer_ctx): - """ - Serialize and paginate objects in a queryset. - - Arguments: - pagination_cls (pagination.Paginator class): Django Rest Framework Paginator subclass. - queryset (QuerySet): Django queryset to serialize/paginate. - serializer_cls (serializers.Serializer class): Django Rest Framework Serializer subclass. - serializer_ctx (dict): Context dictionary to pass to the serializer - - Returns: dict - - """ - # Django Rest Framework v3 requires that we pass the request - # into the serializer's context if the serialize contains - # hyperlink fields. - serializer_ctx["request"] = request - - # Instantiate the paginator and use it to paginate the queryset - paginator = pagination_cls() - page = paginator.paginate_queryset(queryset, request) - - # Serialize the page - serializer = serializer_cls(page, context=serializer_ctx, many=True) - - # Use the paginator to construct the response data - # This will use the pagination subclass for the view to add additional - # fields to the response. - # For example, if the input data is a list, the output data would - # be a dictionary with keys "count", "next", "previous", and "results" - # (where "results" is set to the value of the original list) - return paginator.get_paginated_response(serializer.data).data - def has_team_api_access(user, course_key, access_username=None): """Returns True if the user has access to the Team API for the course @@ -376,8 +307,11 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView): # OAuth2Authentication must come first to return a 401 for unauthenticated users authentication_classes = (OAuth2Authentication, SessionAuthentication) permission_classes = (permissions.IsAuthenticated,) + + paginate_by = 10 + paginate_by_param = 'page_size' + pagination_serializer_class = PaginationSerializer serializer_class = CourseTeamSerializer - pagination_class = TeamAPIPagination def get(self, request): """GET /api/team/v0/teams/""" @@ -443,18 +377,15 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView): paginated_results = paginate_search_results( CourseTeam, search_results, - self.paginator.get_page_size(request), + self.get_paginate_by(), self.get_page() ) + serializer = self.get_pagination_serializer(paginated_results) emit_team_event('edx.team.searched', course_key, { "number_of_results": search_results['total'], "search_text": text_search, "topic_id": topic_id, }) - - page = self.paginate_queryset(paginated_results) - serializer = self.get_serializer(page, many=True) - order_by_input = None else: queryset = CourseTeam.objects.filter(**result_filter) order_by_input = request.QUERY_PARAMS.get('order_by', 'name') @@ -476,12 +407,10 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView): }, status=status.HTTP_400_BAD_REQUEST) page = self.paginate_queryset(queryset) - serializer = self.get_serializer(page, many=True) + serializer = self.get_pagination_serializer(page) + serializer.context.update({'sort_order': order_by_input}) # pylint: disable=maybe-no-member - response = self.get_paginated_response(serializer.data) - if order_by_input is not None: - response.data['sort_order'] = order_by_input - return response + return Response(serializer.data) # pylint: disable=maybe-no-member def post(self, request): """POST /api/team/v0/teams/""" @@ -544,16 +473,14 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView): 'add_method': 'added_on_create' } ) - - data = CourseTeamSerializer(team, context={"request": request}).data - return Response(data) + return Response(CourseTeamSerializer(team).data) def get_page(self): """ Returns page number specified in args, params, or defaults to 1. """ # This code is taken from within the GenericAPIView#paginate_queryset method. # We need need access to the page outside of that method for our paginate_search_results method - page_kwarg = self.kwargs.get(self.paginator.page_query_param) - page_query_param = self.request.QUERY_PARAMS.get(self.paginator.page_query_param) + page_kwarg = self.kwargs.get(self.page_kwarg) + page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) return page_kwarg or page_query_param or 1 @@ -765,7 +692,9 @@ class TopicListView(GenericAPIView): authentication_classes = (OAuth2Authentication, SessionAuthentication) permission_classes = (permissions.IsAuthenticated,) - pagination_class = TopicsPagination + + paginate_by = TOPICS_PER_PAGE + paginate_by_param = 'page_size' def get(self, request): """GET /api/team/v0/topics/?course_id={course_id}""" @@ -811,20 +740,18 @@ class TopicListView(GenericAPIView): add_team_count(topics, course_id) topics.sort(key=lambda t: t['team_count'], reverse=True) page = self.paginate_queryset(topics) - serializer = TopicSerializer( - page, - context={'course_id': course_id}, - many=True, - ) + # Since team_count has already been added to all the topics, use PaginatedTopicSerializer. + # Even though sorting is done outside of the serializer, sort_order needs to be passed + # to the serializer so that the paginated results indicate how they were sorted. + serializer = PaginatedTopicSerializer(page, context={'course_id': course_id, 'sort_order': ordering}) else: page = self.paginate_queryset(topics) # Use the serializer that adds team_count in a bulk operation per page. - serializer = BulkTeamCountTopicSerializer(page, context={'course_id': course_id}, many=True) + serializer = BulkTeamCountPaginatedTopicSerializer( + page, context={'course_id': course_id, 'sort_order': ordering} + ) - response = self.get_paginated_response(serializer.data) - response.data['sort_order'] = ordering - - return response + return Response(serializer.data) def get_alphabetical_topics(course_module): @@ -1033,8 +960,13 @@ class MembershipListView(ExpandableFieldViewMixin, GenericAPIView): authentication_classes = (OAuth2Authentication, SessionAuthentication) permission_classes = (permissions.IsAuthenticated,) + serializer_class = MembershipSerializer + paginate_by = 10 + paginate_by_param = 'page_size' + pagination_serializer_class = PaginationSerializer + def get(self, request): """GET /api/team/v0/team_membership""" specified_username_or_team = False @@ -1091,8 +1023,8 @@ class MembershipListView(ExpandableFieldViewMixin, GenericAPIView): queryset = CourseTeamMembership.get_memberships(username, course_keys, team_id) page = self.paginate_queryset(queryset) - serializer = self.get_serializer(page, many=True) - return self.get_paginated_response(serializer.data) + serializer = self.get_pagination_serializer(page) + return Response(serializer.data) # pylint: disable=maybe-no-member def post(self, request): """POST /api/team/v0/team_membership""" diff --git a/lms/envs/common.py b/lms/envs/common.py index 4ddb70cf3f..3040639818 100644 --- a/lms/envs/common.py +++ b/lms/envs/common.py @@ -1866,12 +1866,6 @@ INSTALLED_APPS = ( 'provider.oauth2', 'oauth2_provider', - # We don't use this directly (since we use OAuth2), but we need to install it anyway. - # When a user is deleted, Django queries all tables with a FK to the auth_user table, - # and since django-rest-framework-oauth imports this, it will try to access tables - # defined by oauth_provider. If those tables don't exist, an error can occur. - 'oauth_provider', - 'auth_exchange', # For the wiki @@ -1987,14 +1981,6 @@ INSTALLED_APPS = ( CSRF_COOKIE_AGE = 60 * 60 * 24 * 7 * 52 -######################### Django Rest Framework ######################## - -REST_FRAMEWORK = { - 'DEFAULT_PAGINATION_CLASS': 'openedx.core.lib.api.paginators.DefaultPagination', - 'PAGE_SIZE': 10, -} - - ######################### MARKETING SITE ############################### EDXMKTG_LOGGED_IN_COOKIE_NAME = 'edxloggedin' EDXMKTG_USER_INFO_COOKIE_NAME = 'edx-user-info' diff --git a/lms/startup.py b/lms/startup.py index 4db5f6c649..6869f90828 100644 --- a/lms/startup.py +++ b/lms/startup.py @@ -9,12 +9,16 @@ from django.conf import settings # Force settings to run so that the python path is modified settings.INSTALLED_APPS # pylint: disable=pointless-statement +from instructor.services import InstructorService + from openedx.core.lib.django_startup import autostartup import edxmako import logging from monkey_patch import django_utils_translation import analytics +from edx_proctoring.runtime import set_runtime_service +from openedx.core.djangoapps.credit.services import CreditService log = logging.getLogger(__name__) @@ -47,11 +51,6 @@ def run(): # right now edx_proctoring is dependent on the openedx.core.djangoapps.credit # as well as the instructor dashboard (for deleting student attempts) if settings.FEATURES.get('ENABLE_PROCTORED_EXAMS'): - # Import these here to avoid circular dependencies of the form: - # edx-platform app --> DRF --> django translation --> edx-platform app - from edx_proctoring.runtime import set_runtime_service - from instructor.services import InstructorService - from openedx.core.djangoapps.credit.services import CreditService set_runtime_service('credit', CreditService()) set_runtime_service('instructor', InstructorService()) diff --git a/openedx/core/djangoapps/content/course_structures/api/v0/api.py b/openedx/core/djangoapps/content/course_structures/api/v0/api.py index 016b613782..ac46603eff 100644 --- a/openedx/core/djangoapps/content/course_structures/api/v0/api.py +++ b/openedx/core/djangoapps/content/course_structures/api/v0/api.py @@ -124,4 +124,4 @@ def course_grading_policy(course_key): final grade. """ course = _retrieve_course(course_key) - return GradingPolicySerializer(course.raw_grader, many=True).data + return GradingPolicySerializer(course.raw_grader).data diff --git a/openedx/core/djangoapps/content/course_structures/api/v0/serializers.py b/openedx/core/djangoapps/content/course_structures/api/v0/serializers.py index 881be80437..f440e6cd4a 100644 --- a/openedx/core/djangoapps/content/course_structures/api/v0/serializers.py +++ b/openedx/core/djangoapps/content/course_structures/api/v0/serializers.py @@ -1,8 +1,6 @@ """ API Serializers """ -from collections import defaultdict - from rest_framework import serializers @@ -13,58 +11,23 @@ class GradingPolicySerializer(serializers.Serializer): dropped = serializers.IntegerField(source='drop_count') weight = serializers.FloatField() - def to_representation(self, obj): - """ - Return a representation of the grading policy. - """ - # Backwards compatibility with the behavior of DRF v2. - # When the grader dictionary was missing keys, DRF v2 would default to None; - # DRF v3 unhelpfully raises an exception. - return dict( - super(GradingPolicySerializer, self).to_representation( - defaultdict(lambda: None, obj) - ) - ) - # pylint: disable=invalid-name class BlockSerializer(serializers.Serializer): """ Serializer for course structure block. """ id = serializers.CharField(source='usage_key') type = serializers.CharField(source='block_type') - parent = serializers.CharField(required=False) + parent = serializers.CharField(source='parent') display_name = serializers.CharField() graded = serializers.BooleanField(default=False) format = serializers.CharField() children = serializers.CharField() - def to_representation(self, obj): - """ - Return a representation of the block. - - NOTE: this method maintains backwards compatibility with the behavior - of Django Rest Framework v2. - """ - data = super(BlockSerializer, self).to_representation(obj) - - # Backwards compatibility with the behavior of DRF v2 - # Include a NULL value for "parent" in the representation - # (instead of excluding the key entirely) - if obj.get("parent") is None: - data["parent"] = None - - # Backwards compatibility with the behavior of DRF v2 - # Leave the children list as a list instead of serializing - # it to a string. - data["children"] = obj["children"] - - return data - class CourseStructureSerializer(serializers.Serializer): """ Serializer for course structure. """ - root = serializers.CharField() - blocks = serializers.SerializerMethodField() + root = serializers.CharField(source='root') + blocks = serializers.SerializerMethodField('get_blocks') def get_blocks(self, structure): """ Serialize the individual blocks. """ diff --git a/openedx/core/djangoapps/credit/serializers.py b/openedx/core/djangoapps/credit/serializers.py index 4e43a139fc..d808233848 100644 --- a/openedx/core/djangoapps/credit/serializers.py +++ b/openedx/core/djangoapps/credit/serializers.py @@ -2,33 +2,12 @@ from rest_framework import serializers -from opaque_keys.edx.keys import CourseKey -from opaque_keys import InvalidKeyError from openedx.core.djangoapps.credit.models import CreditCourse -class CourseKeyField(serializers.Field): - """ - Serializer field for a model CourseKey field. - """ - - def to_representation(self, data): - """Convert a course key to unicode. """ - return unicode(data) - - def to_internal_value(self, data): - """Convert unicode to a course key. """ - try: - return CourseKey.from_string(data) - except InvalidKeyError as ex: - raise serializers.ValidationError("Invalid course key: {msg}".format(msg=ex.msg)) - - class CreditCourseSerializer(serializers.ModelSerializer): """ CreditCourse Serializer """ - course_key = CourseKeyField() - class Meta(object): # pylint: disable=missing-docstring model = CreditCourse exclude = ('id',) diff --git a/openedx/core/djangoapps/credit/tests/test_views.py b/openedx/core/djangoapps/credit/tests/test_views.py index fc95c762b6..e0ae94152b 100644 --- a/openedx/core/djangoapps/credit/tests/test_views.py +++ b/openedx/core/djangoapps/credit/tests/test_views.py @@ -393,7 +393,10 @@ class CreditCourseViewSetTests(TestCase): # POSTs without a CSRF token should fail. response = client.post(self.path, data=json.dumps(data), content_type=JSON) - self.assertEqual(response.status_code, 403) + + # NOTE (CCB): Ordinarily we would expect a 403; however, since the CSRF validation and session authentication + # fail, DRF considers the request to be unauthenticated. + self.assertEqual(response.status_code, 401) self.assertIn('CSRF', response.content) # Retrieve a CSRF token diff --git a/openedx/core/djangoapps/credit/views.py b/openedx/core/djangoapps/credit/views.py index 21bec85ad4..5c179fa242 100644 --- a/openedx/core/djangoapps/credit/views.py +++ b/openedx/core/djangoapps/credit/views.py @@ -18,9 +18,8 @@ from django.views.decorators.http import require_POST, require_GET from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey import pytz -from rest_framework import viewsets, mixins, permissions -from rest_framework.authentication import SessionAuthentication -from rest_framework_oauth.authentication import OAuth2Authentication +from rest_framework import viewsets, mixins, permissions, authentication + from util.json_request import JsonResponse from util.date_utils import from_timestamp from openedx.core.djangoapps.credit import api @@ -378,28 +377,17 @@ class CreditCourseViewSet(mixins.CreateModelMixin, mixins.UpdateModelMixin, view lookup_value_regex = settings.COURSE_KEY_REGEX queryset = CreditCourse.objects.all() serializer_class = CreditCourseSerializer - authentication_classes = (OAuth2Authentication, SessionAuthentication,) + authentication_classes = (authentication.OAuth2Authentication, authentication.SessionAuthentication,) permission_classes = (permissions.IsAuthenticated, permissions.IsAdminUser) - # In Django Rest Framework v3, there is a default pagination - # class that transmutes the response data into a dictionary - # with pagination information. The original response data (a list) - # is stored in a "results" value of the dictionary. - # For backwards compatibility with the existing API, we disable - # the default behavior by setting the pagination_class to None. - pagination_class = None - # This CSRF exemption only applies when authenticating without SessionAuthentication. # SessionAuthentication will enforce CSRF protection. @method_decorator(csrf_exempt) def dispatch(self, request, *args, **kwargs): + # Convert the course ID/key from a string to an actual CourseKey object. + course_id = kwargs.get(self.lookup_field, None) + + if course_id: + kwargs[self.lookup_field] = CourseKey.from_string(course_id) + return super(CreditCourseViewSet, self).dispatch(request, *args, **kwargs) - - def get_object(self): - # Convert the serialized course key into a CourseKey instance - # so we can look up the object. - course_key = self.kwargs.get(self.lookup_field) - if course_key is not None: - self.kwargs[self.lookup_field] = CourseKey.from_string(course_key) - - return super(CreditCourseViewSet, self).get_object() diff --git a/openedx/core/djangoapps/profile_images/tests/test_views.py b/openedx/core/djangoapps/profile_images/tests/test_views.py index 4e93a1ad25..d1d2901b4c 100644 --- a/openedx/core/djangoapps/profile_images/tests/test_views.py +++ b/openedx/core/djangoapps/profile_images/tests/test_views.py @@ -30,40 +30,6 @@ TEST_UPLOAD_DT = datetime.datetime(2002, 1, 9, 15, 43, 01, tzinfo=UTC) TEST_UPLOAD_DT2 = datetime.datetime(2003, 1, 9, 15, 43, 01, tzinfo=UTC) -class PatchedClient(APIClient): - """ - Patch DRF's APIClient to avoid a unicode error on file upload. - - Famous last words: This is a *temporary* fix that we should be - able to remove once we upgrade Django past 1.4. - """ - - def request(self, *args, **kwargs): - """Construct an API request. """ - # DRF's default test client implementation uses `six.text_type()` - # to convert the CONTENT_TYPE to `unicode`. In Django 1.4, - # this causes a `UnicodeDecodeError` when Django parses a multipart - # upload. - # - # This is the DRF code we're working around: - # https://github.com/tomchristie/django-rest-framework/blob/3.1.3/rest_framework/compat.py#L227 - # - # ... and this is the Django code that raises the exception: - # - # https://github.com/django/django/blob/1.4.22/django/http/multipartparser.py#L435 - # - # Django unhelpfully swallows the exception, so to the application code - # it appears as though the user didn't send any file data. - # - # This appears to be an issue only with requests constructed in the test - # suite, not with the upload code used in production. - # - if isinstance(kwargs.get("CONTENT_TYPE"), basestring): - kwargs["CONTENT_TYPE"] = str(kwargs["CONTENT_TYPE"]) - - return super(PatchedClient, self).request(*args, **kwargs) - - class ProfileImageEndpointTestCase(UserSettingsEventTestMixin, APITestCase): """ Base class / shared infrastructure for tests of profile_image "upload" and @@ -145,10 +111,6 @@ class ProfileImageUploadTestCase(ProfileImageEndpointTestCase): """ _view_name = "profile_image_upload" - # Use the patched version of the API client to workaround a unicode issue - # with DRF 3.1 and Django 1.4. Remove this after we upgrade Django past 1.4! - client_class = PatchedClient - def check_upload_event_emitted(self, old=None, new=TEST_UPLOAD_DT): """ Make sure we emit a UserProfile event corresponding to the diff --git a/openedx/core/djangoapps/user_api/accounts/api.py b/openedx/core/djangoapps/user_api/accounts/api.py index e7e716a3b5..532c139a7a 100644 --- a/openedx/core/djangoapps/user_api/accounts/api.py +++ b/openedx/core/djangoapps/user_api/accounts/api.py @@ -183,7 +183,7 @@ def update_account_settings(requesting_user, update, username=None): serializer.save() if "language_proficiencies" in update: - new_language_proficiencies = update["language_proficiencies"] + new_language_proficiencies = legacy_profile_serializer.data["language_proficiencies"] emit_setting_changed_event( user=existing_user, db_table=existing_user_profile.language_proficiencies.model._meta.db_table, diff --git a/openedx/core/djangoapps/user_api/accounts/serializers.py b/openedx/core/djangoapps/user_api/accounts/serializers.py index df4d37da2c..180f8ea208 100644 --- a/openedx/core/djangoapps/user_api/accounts/serializers.py +++ b/openedx/core/djangoapps/user_api/accounts/serializers.py @@ -53,7 +53,7 @@ class UserReadOnlySerializer(serializers.Serializer): super(UserReadOnlySerializer, self).__init__(*args, **kwargs) - def to_representation(self, user): + def to_native(self, user): """ Overwrite to_native to handle custom logic since we are serializing two models as one here :param user: User object @@ -152,8 +152,8 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea Class that serializes the portion of UserProfile model needed for account information. """ profile_image = serializers.SerializerMethodField("_get_profile_image") - requires_parental_consent = serializers.SerializerMethodField() - language_proficiencies = LanguageProficiencySerializer(many=True, required=False) + requires_parental_consent = serializers.SerializerMethodField("get_requires_parental_consent") + language_proficiencies = LanguageProficiencySerializer(many=True, allow_add_remove=True, required=False) class Meta(object): # pylint: disable=missing-docstring model = UserProfile @@ -165,21 +165,25 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea read_only_fields = () explicit_read_only_fields = ("profile_image", "requires_parental_consent") - def validate_name(self, new_name): + def validate_name(self, attrs, source): """ Enforce minimum length for name. """ - if len(new_name) < NAME_MIN_LENGTH: - raise serializers.ValidationError( - "The name field must be at least {} characters long.".format(NAME_MIN_LENGTH) - ) - return new_name + if source in attrs: + new_name = attrs[source].strip() + if len(new_name) < NAME_MIN_LENGTH: + raise serializers.ValidationError( + "The name field must be at least {} characters long.".format(NAME_MIN_LENGTH) + ) + attrs[source] = new_name - def validate_language_proficiencies(self, value): + return attrs + + def validate_language_proficiencies(self, attrs, source): """ Enforce all languages are unique. """ - language_proficiencies = [language for language in value] - unique_language_proficiencies = set(language["code"] for language in language_proficiencies) + language_proficiencies = [language for language in attrs.get(source, [])] + unique_language_proficiencies = set(language.code for language in language_proficiencies) if len(language_proficiencies) != len(unique_language_proficiencies): raise serializers.ValidationError("The language_proficiencies field must consist of unique languages") - return value + return attrs def transform_gender(self, user_profile, value): """ Converts empty string to None, to indicate not set. Replaced by to_representation in version 3. """ @@ -226,29 +230,3 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea call the method with a single argument, the user_profile object. """ return AccountLegacyProfileSerializer.get_profile_image(user_profile, user_profile.user) - - def update(self, instance, validated_data): - """ - Update the profile, including nested fields. - """ - language_proficiencies = validated_data.pop("language_proficiencies", None) - - # Update all fields on the user profile that are writeable, - # except for "language_proficiencies", which we'll update separately - update_fields = set(self.get_writeable_fields()) - set(["language_proficiencies"]) - for field_name in update_fields: - default = getattr(instance, field_name) - field_value = validated_data.get(field_name, default) - setattr(instance, field_name, field_value) - - instance.save() - - # Now update the related language proficiency - if language_proficiencies is not None: - instance.language_proficiencies.all().delete() - instance.language_proficiencies.bulk_create([ - LanguageProficiency(user_profile=instance, code=language["code"]) - for language in language_proficiencies - ]) - - return instance diff --git a/openedx/core/djangoapps/user_api/accounts/tests/test_api.py b/openedx/core/djangoapps/user_api/accounts/tests/test_api.py index d3201eaa39..2cad671c49 100644 --- a/openedx/core/djangoapps/user_api/accounts/tests/test_api.py +++ b/openedx/core/djangoapps/user_api/accounts/tests/test_api.py @@ -164,10 +164,7 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase): field_errors = context_manager.exception.field_errors self.assertEqual(3, len(field_errors)) self.assertEqual("This field is not editable via this API", field_errors["username"]["developer_message"]) - self.assertIn( - "Value \'undecided\' is not valid for field \'gender\'", - field_errors["gender"]["developer_message"] - ) + self.assertIn("Select a valid choice", field_errors["gender"]["developer_message"]) self.assertIn("Valid e-mail address required.", field_errors["email"]["developer_message"]) @patch('django.core.mail.send_mail') diff --git a/openedx/core/djangoapps/user_api/accounts/tests/test_views.py b/openedx/core/djangoapps/user_api/accounts/tests/test_views.py index f7dc0bd879..8d021d7dd2 100644 --- a/openedx/core/djangoapps/user_api/accounts/tests/test_views.py +++ b/openedx/core/djangoapps/user_api/accounts/tests/test_views.py @@ -359,19 +359,16 @@ class TestAccountAPI(UserAPITestCase): self.assertEqual(404, response.status_code) @ddt.data( - ("gender", "f", "not a gender", u'"not a gender" is not a valid choice.'), - ("level_of_education", "none", u"ȻħȺɍłɇs", u'"ȻħȺɍłɇs" is not a valid choice.'), - ("country", "GB", "XY", u'"XY" is not a valid choice.'), - ("year_of_birth", 2009, "not_an_int", u"A valid integer is required."), - ("name", "bob", "z" * 256, u"Ensure this field has no more than 255 characters."), + ("gender", "f", "not a gender", u"Select a valid choice. not a gender is not one of the available choices."), + ("level_of_education", "none", u"ȻħȺɍłɇs", u"Select a valid choice. ȻħȺɍłɇs is not one of the available choices."), + ("country", "GB", "XY", u"Select a valid choice. XY is not one of the available choices."), + ("year_of_birth", 2009, "not_an_int", u"Enter a whole number."), + ("name", "bob", "z" * 256, u"Ensure this value has at most 255 characters (it has 256)."), ("name", u"ȻħȺɍłɇs", "z ", u"The name field must be at least 2 characters long."), ("goals", "Smell the roses"), ("mailing_address", "Sesame Street"), # Note that we store the raw data, so it is up to client to escape the HTML. - ( - "bio", u"Lacrosse-playing superhero 壓是進界推日不復女", - "z" * 3001, u"Ensure this field has no more than 3000 characters." - ), + ("bio", u"Lacrosse-playing superhero 壓是進界推日不復女", "z" * 3001, u"Ensure this value has at most 3000 characters (it has 3001)."), # Note that email is tested below, as it is not immediately updated. # Note that language_proficiencies is tested below as there are multiple error and success conditions. ) @@ -571,10 +568,10 @@ class TestAccountAPI(UserAPITestCase): self.assertItemsEqual(response.data["language_proficiencies"], proficiencies) @ddt.data( - (u"not_a_list", {u'non_field_errors': [u'Expected a list of items but got type "unicode".']}), - ([u"not_a_JSON_object"], [{u'non_field_errors': [u'Invalid data. Expected a dictionary, but got unicode.']}]), + (u"not_a_list", [{u'non_field_errors': [u'Expected a list of items.']}]), + ([u"not_a_JSON_object"], [{u'non_field_errors': [u'Invalid data']}]), ([{}], [{"code": [u"This field is required."]}]), - ([{u"code": u"invalid_language_code"}], [{'code': [u'"invalid_language_code" is not a valid choice.']}]), + ([{u"code": u"invalid_language_code"}], [{'code': [u'Select a valid choice. invalid_language_code is not one of the available choices.']}]), ([{u"code": u"kw"}, {u"code": u"el"}, {u"code": u"kw"}], [u'The language_proficiencies field must consist of unique languages']), ) @ddt.unpack diff --git a/openedx/core/djangoapps/user_api/preferences/api.py b/openedx/core/djangoapps/user_api/preferences/api.py index 58a3f915d3..8c2a3860ac 100644 --- a/openedx/core/djangoapps/user_api/preferences/api.py +++ b/openedx/core/djangoapps/user_api/preferences/api.py @@ -9,10 +9,9 @@ from django.conf import settings from django.core.exceptions import ObjectDoesNotExist from django.db import IntegrityError from django.utils.translation import ugettext as _ +from student.models import User, UserProfile from django.utils.translation import ugettext_noop -from student.models import User, UserProfile -from request_cache import get_request_or_stub from ..errors import ( UserAPIInternalError, UserAPIRequestError, UserNotFound, UserNotAuthorized, PreferenceValidationError, PreferenceUpdateError @@ -69,17 +68,7 @@ def get_user_preferences(requesting_user, username=None): UserAPIInternalError: the operation failed due to an unexpected error. """ existing_user = _get_user(requesting_user, username, allow_staff=True) - - # Django Rest Framework V3 uses the current request to version - # hyperlinked URLS, so we need to retrieve the request and pass - # it in the serializer's context (otherwise we get an AssertionError). - # We're retrieving the request from the cache rather than passing it in - # as an argument because this is an implementation detail of how we're - # serializing data, which we want to encapsulate in the API call. - context = { - "request": get_request_or_stub() - } - user_serializer = UserSerializer(existing_user, context=context) + user_serializer = UserSerializer(existing_user) return user_serializer.data["preferences"] @@ -367,7 +356,7 @@ def validate_user_preference_serializer(serializer, preference_key, preference_v developer_message = u"Value '{preference_value}' not valid for preference '{preference_key}': {error}".format( preference_key=preference_key, preference_value=preference_value, error=serializer.errors ) - if "key" in serializer.errors: + if serializer.errors["key"]: user_message = _(u"Invalid user preference key '{preference_key}'.").format( preference_key=preference_key ) diff --git a/openedx/core/djangoapps/user_api/preferences/tests/test_api.py b/openedx/core/djangoapps/user_api/preferences/tests/test_api.py index ef95c7b2db..bef0e79a67 100644 --- a/openedx/core/djangoapps/user_api/preferences/tests/test_api.py +++ b/openedx/core/djangoapps/user_api/preferences/tests/test_api.py @@ -403,7 +403,7 @@ def get_expected_validation_developer_message(preference_key, preference_value): preference_key=preference_key, preference_value=preference_value, error={ - "key": [u"Ensure this field has no more than 255 characters."] + "key": [u"Ensure this value has at most 255 characters (it has 256)."] } ) diff --git a/openedx/core/djangoapps/user_api/serializers.py b/openedx/core/djangoapps/user_api/serializers.py index 2bf15b5ad8..e95dedee00 100644 --- a/openedx/core/djangoapps/user_api/serializers.py +++ b/openedx/core/djangoapps/user_api/serializers.py @@ -6,8 +6,8 @@ from .models import UserPreference class UserSerializer(serializers.HyperlinkedModelSerializer): - name = serializers.SerializerMethodField() - preferences = serializers.SerializerMethodField() + name = serializers.SerializerMethodField("get_name") + preferences = serializers.SerializerMethodField("get_preferences") def get_name(self, user): profile = UserProfile.objects.get(user=user) @@ -32,10 +32,9 @@ class UserPreferenceSerializer(serializers.HyperlinkedModelSerializer): class RawUserPreferenceSerializer(serializers.ModelSerializer): + """Serializer that generates a raw representation of a user preference. """ - Serializer that generates a raw representation of a user preference. - """ - user = serializers.PrimaryKeyRelatedField(queryset=User.objects.all()) + user = serializers.PrimaryKeyRelatedField() class Meta(object): # pylint: disable=missing-docstring model = UserPreference @@ -58,11 +57,3 @@ class ReadOnlyFieldsSerializerMixin(object): cls.Meta.read_only_fields tuple. """ return getattr(cls.Meta, 'read_only_fields', '') + getattr(cls.Meta, 'explicit_read_only_fields', '') - - @classmethod - def get_writeable_fields(cls): - """ - Return all fields on this serializer that are writeable. - """ - all_fields = getattr(cls.Meta, 'fields', tuple()) - return tuple(set(all_fields) - set(cls.get_read_only_fields())) diff --git a/openedx/core/lib/api/authentication.py b/openedx/core/lib/api/authentication.py index e6e50284f1..92f46861c4 100644 --- a/openedx/core/lib/api/authentication.py +++ b/openedx/core/lib/api/authentication.py @@ -1,11 +1,10 @@ """ Common Authentication Handlers used across projects. """ -from rest_framework.authentication import SessionAuthentication -from rest_framework_oauth.authentication import OAuth2Authentication +from rest_framework import authentication from rest_framework.exceptions import AuthenticationFailed -from rest_framework_oauth.compat import oauth2_provider, provider_now +from rest_framework.compat import oauth2_provider, provider_now -class SessionAuthenticationAllowInactiveUser(SessionAuthentication): +class SessionAuthenticationAllowInactiveUser(authentication.SessionAuthentication): """Ensure that the user is logged in, but do not require the account to be active. We use this in the special case that a user has created an account, @@ -52,7 +51,7 @@ class SessionAuthenticationAllowInactiveUser(SessionAuthentication): return (user, None) -class OAuth2AuthenticationAllowInactiveUser(OAuth2Authentication): +class OAuth2AuthenticationAllowInactiveUser(authentication.OAuth2Authentication): """ This is a temporary workaround while the is_active field on the user is coupled with whether or not the user has verified ownership of their claimed email address. diff --git a/openedx/core/lib/api/fields.py b/openedx/core/lib/api/fields.py index 7fc9a836b4..0a1d609997 100644 --- a/openedx/core/lib/api/fields.py +++ b/openedx/core/lib/api/fields.py @@ -1,5 +1,7 @@ """Fields useful for edX API implementations.""" -from rest_framework.serializers import Field +from django.core.exceptions import ValidationError + +from rest_framework.serializers import CharField, Field class ExpandableField(Field): @@ -16,19 +18,25 @@ class ExpandableField(Field): self.expanded = kwargs.pop('expanded_serializer') super(ExpandableField, self).__init__(**kwargs) - def to_representation(self, obj): - """ - Return a representation of the field that is either expanded or collapsed. - """ - should_expand = self.field_name in self.context.get("expand", []) - field = self.expanded if should_expand else self.collapsed + def field_to_native(self, obj, field_name): + """Converts obj to a native representation, using the expanded serializer if the context requires it.""" + if 'expand' in self.context and field_name in self.context['expand']: + self.expanded.initialize(self, field_name) + return self.expanded.field_to_native(obj, field_name) + else: + self.collapsed.initialize(self, field_name) + return self.collapsed.field_to_native(obj, field_name) - # Avoid double-binding the field, otherwise we'll get - # an error about the source kwarg being redundant. - if field.source is None: - field.bind(self.field_name, self) - if should_expand: - self.expanded.context["expand"] = set(field.context.get("expand", [])) +class NonEmptyCharField(CharField): + """ + A field that enforces non-emptiness even for partial updates. - return field.to_representation(obj) + This is necessary because prior to version 3, DRF skips validation for empty + values. Thus, CharField's min_length and RegexField cannot be used to + enforce this constraint. + """ + def validate(self, value): + super(NonEmptyCharField, self).validate(value) + if not value.strip(): + raise ValidationError(self.error_messages["required"]) diff --git a/openedx/core/lib/api/mixins.py b/openedx/core/lib/api/mixins.py deleted file mode 100644 index 909fb8b765..0000000000 --- a/openedx/core/lib/api/mixins.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Django Rest Framework view mixins. -""" -from django.core.exceptions import ValidationError -from django.http import Http404 -from rest_framework import status -from rest_framework.mixins import CreateModelMixin -from rest_framework.response import Response - - -class PutAsCreateMixin(CreateModelMixin): - """ - Backwards compatibility with Django Rest Framework v2, which allowed - creation of a new resource using PUT. - """ - - def update(self, request, *args, **kwargs): - """ - Create/update course modes for a course. - """ - # First, try to update the existing instance - try: - try: - return super(PutAsCreateMixin, self).update(request, *args, **kwargs) - except Http404: - # If no instance exists yet, create it. - # This is backwards-compatible with the behavior of DRF v2. - return super(PutAsCreateMixin, self).create(request, *args, **kwargs) - - # Backwards compatibility with DRF v2 behavior, which would catch model-level - # validation errors and return a 400 - except ValidationError as err: - return Response(err.messages, status=status.HTTP_400_BAD_REQUEST) diff --git a/openedx/core/lib/api/paginators.py b/openedx/core/lib/api/paginators.py index 373845bc19..ae4b32d37e 100644 --- a/openedx/core/lib/api/paginators.py +++ b/openedx/core/lib/api/paginators.py @@ -3,31 +3,6 @@ from django.http import Http404 from django.core.paginator import Paginator, InvalidPage -from rest_framework.response import Response -from rest_framework import pagination - - -class DefaultPagination(pagination.PageNumberPagination): - """ - Default paginator for APIs in edx-platform. - - This is configured in settings to be automatically used - by any subclass of Django Rest Framework's generic API views. - """ - page_size_query_param = "page_size" - - def get_paginated_response(self, data): - """ - Annotate the response with pagination information. - """ - return Response({ - 'next': self.get_next_link(), - 'previous': self.get_previous_link(), - 'count': self.page.paginator.count, - 'num_pages': self.page.paginator.num_pages, - 'results': data - }) - def paginate_search_results(object_class, search_results, page_size, page): """ diff --git a/openedx/core/lib/api/serializers.py b/openedx/core/lib/api/serializers.py index b13bbc2d83..03ec001137 100644 --- a/openedx/core/lib/api/serializers.py +++ b/openedx/core/lib/api/serializers.py @@ -1,8 +1,32 @@ -""" -Serializers to be used in APIs. -""" +from rest_framework import pagination, serializers -from rest_framework import serializers + +class PaginationSerializer(pagination.PaginationSerializer): + """ + Custom PaginationSerializer for openedx. + + Adds the following fields: + - num_pages: total number of pages + - current_page: the current page being returned + - start: the index of the first page item within the overall collection + """ + start_page = 1 # django Paginator.page objects have 1-based indexes + num_pages = serializers.Field(source='paginator.num_pages') + current_page = serializers.SerializerMethodField('get_current_page') + start = serializers.SerializerMethodField('get_start') + sort_order = serializers.SerializerMethodField('get_sort_order') + + def get_current_page(self, page): + """Get the current page""" + return page.number + + def get_start(self, page): + """Get the index of the first page item within the overall collection""" + return (self.get_current_page(page) - self.start_page) * page.paginator.per_page + + def get_sort_order(self, page): # pylint: disable=unused-argument + """Get the order by which this collection was sorted""" + return self.context.get('sort_order') class CollapsedReferenceSerializer(serializers.HyperlinkedModelSerializer): @@ -30,10 +54,9 @@ class CollapsedReferenceSerializer(serializers.HyperlinkedModelSerializer): super(CollapsedReferenceSerializer, self).__init__(*args, **kwargs) - self.fields[id_source] = serializers.CharField(read_only=True) + self.fields[id_source] = serializers.CharField(read_only=True, source=id_source) self.fields['url'].view_name = view_name self.fields['url'].lookup_field = lookup_field - self.fields['url'].lookup_url_kwarg = lookup_field class Meta(object): """Defines meta information for the ModelSerializer. diff --git a/openedx/core/lib/api/tests/test_authentication.py b/openedx/core/lib/api/tests/test_authentication.py index 1fb5ff5b31..fe9a1dbd47 100644 --- a/openedx/core/lib/api/tests/test_authentication.py +++ b/openedx/core/lib/api/tests/test_authentication.py @@ -1,235 +1,63 @@ -""" -Tests for OAuth2. This module is copied from django-rest-framework-oauth (tests/test_authentication.py) -and updated to use our subclass of OAuth2Authentication. -""" - -from __future__ import unicode_literals -import datetime - -from django.conf.urls import patterns, url, include -from django.contrib.auth.models import User -from django.http import HttpResponse -from django.test import TestCase -from django.utils import unittest -from django.utils.http import urlencode - -from rest_framework import status -from rest_framework.permissions import IsAuthenticated -from rest_framework_oauth import permissions -from rest_framework_oauth.compat import oauth2_provider, oauth2_provider_scope -from rest_framework.test import APIRequestFactory, APIClient -from rest_framework.views import APIView +"""Tests for util.authentication module.""" +from mock import patch +from django.conf import settings +from rest_framework import permissions +from rest_framework.compat import patterns, url +from rest_framework.tests import test_authentication from provider import scope, constants +from unittest import skipUnless from ..authentication import OAuth2AuthenticationAllowInactiveUser -factory = APIRequestFactory() # pylint: disable=invalid-name - -class MockView(APIView): # pylint: disable=missing-docstring - permission_classes = (IsAuthenticated,) - - def get(self, request): # pylint: disable=missing-docstring,unused-argument - return HttpResponse({'a': 1, 'b': 2, 'c': 3}) - - def post(self, request): # pylint: disable=missing-docstring,unused-argument - return HttpResponse({'a': 1, 'b': 2, 'c': 3}) - - def put(self, request): # pylint: disable=missing-docstring,unused-argument - return HttpResponse({'a': 1, 'b': 2, 'c': 3}) - - -# This is the a change we've made from the django-rest-framework-oauth version -# of these tests. We're subclassing our custom OAuth2AuthenticationAllowInactiveUser -# instead of OAuth2Authentication. -class OAuth2AuthenticationDebug(OAuth2AuthenticationAllowInactiveUser): # pylint: disable=missing-docstring +class OAuth2AuthAllowInactiveUserDebug(OAuth2AuthenticationAllowInactiveUser): + """ + A debug class analogous to the OAuth2AuthenticationDebug class that tests + the OAuth2 flow with the access token sent in a query param.""" allow_query_params_token = True -urlpatterns = patterns( - '', - url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), - url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationAllowInactiveUser])), - url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])), - url( - r'^oauth2-with-scope-test/$', - MockView.as_view( - authentication_classes=[OAuth2AuthenticationAllowInactiveUser], - permission_classes=[permissions.TokenHasReadWriteScope] +# The following patch overrides the URL patterns for the MockView class used in +# rest_framework.tests.test_authentication so that the corresponding AllowInactiveUser +# classes are tested instead. +@skipUnless(settings.FEATURES.get('ENABLE_OAUTH2_PROVIDER'), 'OAuth2 not enabled') +@patch.object( + test_authentication, + 'urlpatterns', + patterns( + '', + url( + r'^oauth2-test/$', + test_authentication.MockView.as_view(authentication_classes=[OAuth2AuthenticationAllowInactiveUser]) + ), + url( + r'^oauth2-test-debug/$', + test_authentication.MockView.as_view(authentication_classes=[OAuth2AuthAllowInactiveUserDebug]) + ), + url( + r'^oauth2-with-scope-test/$', + test_authentication.MockView.as_view( + authentication_classes=[OAuth2AuthenticationAllowInactiveUser], + permission_classes=[permissions.TokenHasReadWriteScope] + ) ) - ), + ) ) - - -class OAuth2Tests(TestCase): - """OAuth 2.0 authentication""" - urls = 'openedx.core.lib.api.tests.test_authentication' - +class OAuth2AuthenticationAllowInactiveUserTestCase(test_authentication.OAuth2Tests): + """ + Tests the OAuth2AuthenticationAllowInactiveUser class by running all the existing tests in + OAuth2Tests but with the is_active flag on the user set to False. + """ def setUp(self): - self.csrf_client = APIClient(enforce_csrf_checks=True) - self.username = 'john' - self.email = 'lennon@thebeatles.com' - self.password = 'password' - self.user = User.objects.create_user(self.username, self.email, self.password) + super(OAuth2AuthenticationAllowInactiveUserTestCase, self).setUp() - self.CLIENT_ID = 'client_key' # pylint: disable=invalid-name - self.CLIENT_SECRET = 'client_secret' # pylint: disable=invalid-name - self.ACCESS_TOKEN = "access_token" # pylint: disable=invalid-name - self.REFRESH_TOKEN = "refresh_token" # pylint: disable=invalid-name - - self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create( - client_id=self.CLIENT_ID, - client_secret=self.CLIENT_SECRET, - redirect_uri='', - client_type=0, - name='example', - user=None, - ) - - self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create( - token=self.ACCESS_TOKEN, - client=self.oauth2_client, - user=self.user, - ) - self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create( - user=self.user, - access_token=self.access_token, - client=self.oauth2_client - ) - - # This is the a change we've made from the django-rest-framework-oauth version - # of these tests. + # set the user's is_active flag to False. self.user.is_active = False self.user.save() - # This is the a change we've made from the django-rest-framework-oauth version - # of these tests. # Override the SCOPE_NAME_DICT setting for tests for oauth2-with-scope-test. This is # needed to support READ and WRITE scopes as they currently aren't supported by the # edx-auth2-provider, and their scope values collide with other scopes defined in the # edx-auth2-provider. scope.SCOPE_NAME_DICT = {'read': constants.READ, 'write': constants.WRITE} - - def _create_authorization_header(self, token=None): # pylint: disable=missing-docstring - return "Bearer {0}".format(token or self.access_token.token) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_with_wrong_authorization_header_token_type_failing(self): - """Ensure that a wrong token type lead to the correct HTTP error status code""" - auth = "Wrong token-type-obviously" - response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_with_wrong_authorization_header_token_format_failing(self): - """Ensure that a wrong token format lead to the correct HTTP error status code""" - auth = "Bearer wrong token format" - response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_with_wrong_authorization_header_token_failing(self): - """Ensure that a wrong token lead to the correct HTTP error status code""" - auth = "Bearer wrong-token" - response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_with_wrong_authorization_header_token_missing(self): - """Ensure that a missing token lead to the correct HTTP error status code""" - auth = "Bearer" - response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_passing_auth(self): - """Ensure GETing form over OAuth with correct client credentials succeed""" - auth = self._create_authorization_header() - response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_passing_auth_url_transport(self): - """Ensure GETing form over OAuth with correct client credentials in form data succeed""" - response = self.csrf_client.post( - '/oauth2-test/', - data={'access_token': self.access_token.token} - ) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_passing_auth_url_transport(self): - """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True""" - query = urlencode({'access_token': self.access_token.token}) - response = self.csrf_client.get('/oauth2-test-debug/?%s' % query) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_failing_auth_url_transport(self): - """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False""" - query = urlencode({'access_token': self.access_token.token}) - response = self.csrf_client.get('/oauth2-test/?%s' % query) - self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_passing_auth(self): - """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" - auth = self._create_authorization_header() - response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_token_removed_failing_auth(self): - """Ensure POSTing when there is no OAuth access token in db fails""" - self.access_token.delete() - auth = self._create_authorization_header() - response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_with_refresh_token_failing_auth(self): - """Ensure POSTing with refresh token instead of access token fails""" - auth = self._create_authorization_header(token=self.refresh_token.token) - response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_with_expired_access_token_failing_auth(self): - """Ensure POSTing with expired access token fails with an 'Invalid token' error""" - self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late - self.access_token.save() - auth = self._create_authorization_header() - response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - self.assertIn('Invalid token', response.content) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_with_invalid_scope_failing_auth(self): - """Ensure POSTing with a readonly scope instead of a write scope fails""" - read_only_access_token = self.access_token - read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] - read_only_access_token.save() - auth = self._create_authorization_header(token=read_only_access_token.token) - response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) - response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_with_valid_scope_passing_auth(self): - """Ensure POSTing with a write scope succeed""" - read_write_access_token = self.access_token - read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] - read_write_access_token.save() - auth = self._create_authorization_header(token=read_write_access_token.token) - response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) diff --git a/openedx/core/lib/api/view_utils.py b/openedx/core/lib/api/view_utils.py index 77b0684e63..2a4aa6d372 100644 --- a/openedx/core/lib/api/view_utils.py +++ b/openedx/core/lib/api/view_utils.py @@ -9,7 +9,6 @@ from django.utils.translation import ugettext as _ from rest_framework import status, response from rest_framework.exceptions import APIException from rest_framework.permissions import IsAuthenticated -from rest_framework.request import clone_request from rest_framework.response import Response from rest_framework.mixins import RetrieveModelMixin, UpdateModelMixin from rest_framework.generics import GenericAPIView @@ -194,23 +193,3 @@ class RetrievePatchAPIView(RetrieveModelMixin, UpdateModelMixin, GenericAPIView) add_serializer_errors(serializer, patch, field_errors) return field_errors - - def get_object_or_none(self): - """ - Retrieve an object or return None if the object can't be found. - - NOTE: This replaces functionality that was removed in Django Rest Framework v3.1. - """ - try: - return self.get_object() - except Http404: - if self.request.method == 'PUT': - # For PUT-as-create operation, we need to ensure that we have - # relevant permissions, as if this was a POST request. This - # will either raise a PermissionDenied exception, or simply - # return None. - self.check_permissions(clone_request(self.request, 'POST')) - else: - # PATCH requests where the object does not exist should still - # return a 404 response. - raise diff --git a/requirements/edx/base.txt b/requirements/edx/base.txt index c0463faf5e..22f63923cd 100644 --- a/requirements/edx/base.txt +++ b/requirements/edx/base.txt @@ -28,7 +28,7 @@ django-ses==0.7.0 django-simple-history==1.6.3 django-storages==1.1.5 django-method-override==0.1.0 -djangorestframework>=3.1,<3.2 +djangorestframework==2.3.14 django==1.4.22 elasticsearch==0.4.5 facebook-sdk==0.4.0 diff --git a/requirements/edx/github.txt b/requirements/edx/github.txt index a40bb836e9..82f1e313d7 100644 --- a/requirements/edx/github.txt +++ b/requirements/edx/github.txt @@ -12,7 +12,6 @@ git+https://github.com/edx/django-staticfiles.git@031bdeaea85798b8c284e2a09977df -e git+https://github.com/edx/django-pipeline.git@88ec8a011e481918fdc9d2682d4017c835acd8be#egg=django-pipeline -e git+https://github.com/edx/django-wiki.git@cd0b2b31997afccde519fe5b3365e61a9edb143f#egg=django-wiki -e git+https://github.com/edx/django-oauth2-provider.git@0.2.7-fork-edx-5#egg=django-oauth2-provider --e git+https://github.com/edx/django-rest-framework-oauth.git@f0b503fda8c254a38f97fef802ded4f5fe367f7a#egg=djangorestframework-oauth -e git+https://github.com/edx/MongoDBProxy.git@25b99097615bda06bd7cdfe5669ed80dc2a7fed0#egg=mongodb_proxy git+https://github.com/edx/nltk.git@2.0.6#egg=nltk==2.0.6 -e git+https://github.com/dementrock/pystache_custom.git@776973740bdaad83a3b029f96e415a7d1e8bec2f#egg=pystache_custom-dev @@ -41,13 +40,13 @@ git+https://github.com/edx/rfc6266.git@v0.0.5-edx#egg=rfc6266==0.0.5-edx -e git+https://github.com/edx/event-tracking.git@0.2.0#egg=event-tracking -e git+https://github.com/edx-solutions/django-splash.git@7579d052afcf474ece1239153cffe1c89935bc4f#egg=django-splash -e git+https://github.com/edx/acid-block.git@e46f9cda8a03e121a00c7e347084d142d22ebfb7#egg=acid-xblock --e git+https://github.com/edx/edx-ora2.git@release-2015-09-16T15.28#egg=edx-ora2 --e git+https://github.com/edx/edx-submissions.git@0.1.0#egg=edx-submissions +-e git+https://github.com/edx/edx-ora2.git@release-2015-08-25T16.16#egg=edx-ora2 +-e git+https://github.com/edx/edx-submissions.git@9538ee8a971d04dc1cb05e88f6aa0c36b224455c#egg=edx-submissions -e git+https://github.com/edx/opaque-keys.git@27dc382ea587483b1e3889a3d19cbd90b9023a06#egg=opaque-keys git+https://github.com/edx/ease.git@release-2015-07-14#egg=ease==0.1.3 git+https://github.com/edx/i18n-tools.git@v0.1.3#egg=i18n-tools==v0.1.3 git+https://github.com/edx/edx-oauth2-provider.git@0.5.7#egg=oauth2-provider==0.5.7 --e git+https://github.com/edx/edx-val.git@0.0.6#egg=edx-val +-e git+https://github.com/edx/edx-val.git@v0.0.5#egg=edx-val -e git+https://github.com/pmitros/RecommenderXBlock.git@518234bc354edbfc2651b9e534ddb54f96080779#egg=recommender-xblock -e git+https://github.com/edx/edx-search.git@release-2015-09-11a#egg=edx-search -e git+https://github.com/edx/edx-milestones.git@9b44a37edc3d63a23823c21a63cdd53ef47a7aa4#egg=edx-milestones