Merge pull request #9929 from edx/bbeggs/merge-DRF-3.1
Merge DRF 3.1 in to master
This commit is contained in:
@@ -44,7 +44,10 @@ from lms.envs.common import (
|
||||
PROFILE_IMAGE_SECRET_KEY, PROFILE_IMAGE_MIN_BYTES, PROFILE_IMAGE_MAX_BYTES,
|
||||
# The following setting is included as it is used to check whether to
|
||||
# display credit eligibility table on the CMS or not.
|
||||
ENABLE_CREDIT_ELIGIBILITY, YOUTUBE_API_KEY
|
||||
ENABLE_CREDIT_ELIGIBILITY, YOUTUBE_API_KEY,
|
||||
|
||||
# Django REST framework configuration
|
||||
REST_FRAMEWORK,
|
||||
)
|
||||
from path import Path as path
|
||||
from warnings import simplefilter
|
||||
|
||||
@@ -6,7 +6,7 @@ from django.test.utils import override_settings
|
||||
from django.test.client import RequestFactory
|
||||
from django.conf import settings
|
||||
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
|
||||
from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
|
||||
|
||||
@@ -24,7 +24,7 @@ class CrossDomainAuthTest(TestCase):
|
||||
|
||||
def test_perform_csrf_referer_check(self):
|
||||
request = self._fake_request()
|
||||
with self.assertRaisesRegexp(AuthenticationFailed, 'CSRF'):
|
||||
with self.assertRaisesRegexp(PermissionDenied, 'CSRF'):
|
||||
self.auth.enforce_csrf(request)
|
||||
|
||||
@patch.dict(settings.FEATURES, {
|
||||
|
||||
@@ -11,7 +11,7 @@ from enrollment.errors import (
|
||||
CourseNotFoundError, CourseEnrollmentClosedError, CourseEnrollmentFullError,
|
||||
CourseEnrollmentExistsError, UserNotFoundError, InvalidEnrollmentAttribute
|
||||
)
|
||||
from enrollment.serializers import CourseEnrollmentSerializer, CourseField
|
||||
from enrollment.serializers import CourseEnrollmentSerializer, CourseSerializer
|
||||
from openedx.core.djangoapps.content.course_overviews.models import CourseOverview
|
||||
from student.models import (
|
||||
CourseEnrollment, NonExistentCourseError, EnrollmentClosedError,
|
||||
@@ -35,9 +35,30 @@ def get_course_enrollments(user_id):
|
||||
|
||||
"""
|
||||
qset = CourseEnrollment.objects.filter(
|
||||
user__username=user_id, is_active=True
|
||||
user__username=user_id,
|
||||
is_active=True
|
||||
).order_by('created')
|
||||
return CourseEnrollmentSerializer(qset).data
|
||||
|
||||
enrollments = CourseEnrollmentSerializer(qset, many=True).data
|
||||
|
||||
# Find deleted courses and filter them out of the results
|
||||
deleted = []
|
||||
valid = []
|
||||
for enrollment in enrollments:
|
||||
if enrollment.get("course_details") is not None:
|
||||
valid.append(enrollment)
|
||||
else:
|
||||
deleted.append(enrollment)
|
||||
|
||||
if deleted:
|
||||
log.warning(
|
||||
(
|
||||
u"Course enrollments for user %s reference "
|
||||
u"courses that do not exist (this can occur if a course is deleted)."
|
||||
), user_id,
|
||||
)
|
||||
|
||||
return valid
|
||||
|
||||
|
||||
def get_course_enrollment(username, course_id):
|
||||
@@ -271,4 +292,4 @@ def get_course_enrollment_info(course_id, include_expired=False):
|
||||
log.warning(msg)
|
||||
raise CourseNotFoundError(msg)
|
||||
else:
|
||||
return CourseField().to_native(course, include_expired=include_expired)
|
||||
return CourseSerializer(course, include_expired=include_expired).data
|
||||
|
||||
@@ -30,32 +30,36 @@ class StringListField(serializers.CharField):
|
||||
return [int(item) for item in items]
|
||||
|
||||
|
||||
class CourseField(serializers.RelatedField):
|
||||
"""Read-Only representation of course enrollment information.
|
||||
|
||||
Aggregates course information from the CourseDescriptor as well as the Course Modes configured
|
||||
for enrolling in the course.
|
||||
|
||||
class CourseSerializer(serializers.Serializer): # pylint: disable=abstract-method
|
||||
"""
|
||||
Serialize a course descriptor and related information.
|
||||
"""
|
||||
|
||||
def to_native(self, course, **kwargs):
|
||||
course_modes = ModeSerializer(
|
||||
CourseMode.modes_for_course(
|
||||
course.id,
|
||||
include_expired=kwargs.get('include_expired', False),
|
||||
only_selectable=False
|
||||
)
|
||||
).data
|
||||
course_id = serializers.CharField(source="id")
|
||||
enrollment_start = serializers.DateTimeField(format=None)
|
||||
enrollment_end = serializers.DateTimeField(format=None)
|
||||
course_start = serializers.DateTimeField(source="start", format=None)
|
||||
course_end = serializers.DateTimeField(source="end", format=None)
|
||||
invite_only = serializers.BooleanField(source="invitation_only")
|
||||
course_modes = serializers.SerializerMethodField()
|
||||
|
||||
return {
|
||||
'course_id': unicode(course.id),
|
||||
'enrollment_start': course.enrollment_start,
|
||||
'enrollment_end': course.enrollment_end,
|
||||
'course_start': course.start,
|
||||
'course_end': course.end,
|
||||
'invite_only': course.invitation_only,
|
||||
'course_modes': course_modes,
|
||||
}
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.include_expired = kwargs.pop("include_expired", False)
|
||||
super(CourseSerializer, self).__init__(*args, **kwargs)
|
||||
|
||||
def get_course_modes(self, obj):
|
||||
"""
|
||||
Retrieve course modes associated with the course.
|
||||
"""
|
||||
course_modes = CourseMode.modes_for_course(
|
||||
obj.id,
|
||||
include_expired=self.include_expired,
|
||||
only_selectable=False
|
||||
)
|
||||
return [
|
||||
ModeSerializer(mode).data
|
||||
for mode in course_modes
|
||||
]
|
||||
|
||||
|
||||
class CourseEnrollmentSerializer(serializers.ModelSerializer):
|
||||
@@ -65,34 +69,9 @@ class CourseEnrollmentSerializer(serializers.ModelSerializer):
|
||||
the Course Descriptor and course modes, to give a complete representation of course enrollment.
|
||||
|
||||
"""
|
||||
course_details = serializers.SerializerMethodField('get_course_details')
|
||||
course_details = CourseSerializer(source="course_overview")
|
||||
user = serializers.SerializerMethodField('get_username')
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
serialized_data = super(CourseEnrollmentSerializer, self).data
|
||||
|
||||
# filter the results with empty courses 'course_details'
|
||||
if isinstance(serialized_data, dict):
|
||||
if serialized_data.get('course_details') is None:
|
||||
return None
|
||||
|
||||
return serialized_data
|
||||
|
||||
return [enrollment for enrollment in serialized_data if enrollment.get('course_details')]
|
||||
|
||||
def get_course_details(self, model):
|
||||
if model.course is None:
|
||||
msg = u"Course '{0}' does not exist (maybe deleted), in which User (user_id: '{1}') is enrolled.".format(
|
||||
model.course_id,
|
||||
model.user.id
|
||||
)
|
||||
log.warning(msg)
|
||||
return None
|
||||
|
||||
field = CourseField()
|
||||
return field.to_native(model.course)
|
||||
|
||||
def get_username(self, model):
|
||||
"""Retrieves the username from the associated model."""
|
||||
return model.username
|
||||
|
||||
@@ -1038,7 +1038,7 @@ class EnrollmentCrossDomainTest(ModuleStoreTestCase):
|
||||
@cross_domain_config
|
||||
def test_cross_domain_missing_csrf(self, *args): # pylint: disable=unused-argument
|
||||
resp = self._cross_domain_post('invalid_csrf_token')
|
||||
self.assertEqual(resp.status_code, 401)
|
||||
self.assertEqual(resp.status_code, 403)
|
||||
|
||||
def _get_csrf_cookie(self):
|
||||
"""Retrieve the cross-domain CSRF cookie. """
|
||||
|
||||
@@ -5,10 +5,18 @@ This module requires that :class:`request_cache.middleware.RequestCache`
|
||||
is installed in order to clear the cache after each request.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from urlparse import urlparse
|
||||
|
||||
from django.conf import settings
|
||||
from django.test.client import RequestFactory
|
||||
|
||||
from request_cache import middleware
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_cache(name):
|
||||
"""
|
||||
Return the request cache named ``name``.
|
||||
@@ -26,3 +34,38 @@ def get_request():
|
||||
Return the current request.
|
||||
"""
|
||||
return middleware.RequestCache.get_current_request()
|
||||
|
||||
|
||||
def get_request_or_stub():
|
||||
"""
|
||||
Return the current request or a stub request.
|
||||
|
||||
If called outside the context of a request, construct a fake
|
||||
request that can be used to build an absolute URI.
|
||||
|
||||
This is useful in cases where we need to pass in a request object
|
||||
but don't have an active request (for example, in test cases).
|
||||
"""
|
||||
request = get_request()
|
||||
|
||||
if request is None:
|
||||
log.warning(
|
||||
"Could not retrieve the current request. "
|
||||
"A stub request will be created instead using settings.SITE_NAME. "
|
||||
"This should be used *only* in test cases, never in production!"
|
||||
)
|
||||
|
||||
# The settings SITE_NAME may contain a port number, so we need to
|
||||
# parse the full URL.
|
||||
full_url = "http://{site_name}".format(site_name=settings.SITE_NAME)
|
||||
parsed_url = urlparse(full_url)
|
||||
|
||||
# Construct the fake request. This can be used to construct absolute
|
||||
# URIs to other paths.
|
||||
return RequestFactory(
|
||||
SERVER_NAME=parsed_url.hostname,
|
||||
SERVER_PORT=parsed_url.port or 80,
|
||||
).get("/")
|
||||
|
||||
else:
|
||||
return request
|
||||
|
||||
20
common/djangoapps/request_cache/tests.py
Normal file
20
common/djangoapps/request_cache/tests.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Tests for the request cache.
|
||||
"""
|
||||
from django.conf import settings
|
||||
from django.test import TestCase
|
||||
|
||||
from request_cache import get_request_or_stub
|
||||
|
||||
|
||||
class TestRequestCache(TestCase):
|
||||
"""
|
||||
Tests for the request cache.
|
||||
"""
|
||||
|
||||
def test_get_request_or_stub(self):
|
||||
# Outside the context of the request, we should still get a request
|
||||
# that allows us to build an absolute URI.
|
||||
stub = get_request_or_stub()
|
||||
expected_url = "http://{site_name}/foobar".format(site_name=settings.SITE_NAME)
|
||||
self.assertEqual(stub.build_absolute_uri("foobar"), expected_url)
|
||||
@@ -406,7 +406,7 @@ class BrowseTopicsTest(TeamsTabBase):
|
||||
)
|
||||
create_team_page.submit_form()
|
||||
team_page = TeamPage(self.browser, self.course_id)
|
||||
self.assertTrue(team_page.is_browser_on_page)
|
||||
self.assertTrue(team_page.is_browser_on_page())
|
||||
team_page.click_all_topics()
|
||||
self.assertTrue(self.topics_page.is_browser_on_page())
|
||||
self.topics_page.wait_for_ajax()
|
||||
|
||||
@@ -18,7 +18,12 @@ class CourseModeSerializer(serializers.ModelSerializer):
|
||||
""" CourseMode serializer. """
|
||||
name = serializers.CharField(source='mode_slug')
|
||||
price = serializers.IntegerField(source='min_price')
|
||||
expires = serializers.DateTimeField(source='expiration_datetime', required=False, blank=True)
|
||||
expires = serializers.DateTimeField(
|
||||
source='expiration_datetime',
|
||||
required=False,
|
||||
allow_null=True,
|
||||
format=None
|
||||
)
|
||||
|
||||
def get_identity(self, data):
|
||||
try:
|
||||
@@ -56,8 +61,8 @@ class CourseSerializer(serializers.Serializer):
|
||||
""" Course serializer. """
|
||||
id = serializers.CharField(validators=[validate_course_id]) # pylint: disable=invalid-name
|
||||
name = serializers.CharField(read_only=True)
|
||||
verification_deadline = serializers.DateTimeField(blank=True)
|
||||
modes = CourseModeSerializer(many=True, allow_add_remove=True)
|
||||
verification_deadline = serializers.DateTimeField(format=None, allow_null=True, required=False)
|
||||
modes = CourseModeSerializer(many=True)
|
||||
|
||||
def validate(self, attrs):
|
||||
""" Ensure the verification deadline occurs AFTER the course mode enrollment deadlines. """
|
||||
@@ -68,7 +73,7 @@ class CourseSerializer(serializers.Serializer):
|
||||
|
||||
# Find the earliest upgrade deadline
|
||||
for mode in attrs['modes']:
|
||||
expires = mode.expiration_datetime
|
||||
expires = mode.get("expiration_datetime")
|
||||
if expires:
|
||||
# If we don't already have an upgrade_deadline value, use datetime.max so that we can actually
|
||||
# complete the comparison.
|
||||
@@ -82,9 +87,28 @@ class CourseSerializer(serializers.Serializer):
|
||||
|
||||
return attrs
|
||||
|
||||
def restore_object(self, attrs, instance=None):
|
||||
if instance is None:
|
||||
return Course(attrs['id'], attrs['modes'], attrs['verification_deadline'])
|
||||
def create(self, validated_data):
|
||||
"""Create course modes for a course. """
|
||||
course = Course(
|
||||
validated_data["id"],
|
||||
self._new_course_mode_models(validated_data["modes"]),
|
||||
verification_deadline=validated_data["verification_deadline"]
|
||||
)
|
||||
course.save()
|
||||
return course
|
||||
|
||||
instance.update(attrs)
|
||||
def update(self, instance, validated_data):
|
||||
"""Update course modes for an existing course. """
|
||||
validated_data["modes"] = self._new_course_mode_models(validated_data["modes"])
|
||||
|
||||
instance.update(validated_data)
|
||||
instance.save()
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
def _new_course_mode_models(modes_data):
|
||||
"""Convert validated course mode data to CourseMode objects. """
|
||||
return [
|
||||
CourseMode(**modes_dict)
|
||||
for modes_dict in modes_data
|
||||
]
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
import logging
|
||||
|
||||
from django.http import Http404
|
||||
from rest_framework.authentication import OAuth2Authentication, SessionAuthentication
|
||||
from rest_framework.authentication import SessionAuthentication
|
||||
from rest_framework_oauth.authentication import OAuth2Authentication
|
||||
from rest_framework.generics import RetrieveUpdateAPIView, ListAPIView
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
|
||||
@@ -10,6 +11,7 @@ from commerce.api.v1.models import Course
|
||||
from commerce.api.v1.permissions import ApiKeyOrModelPermission
|
||||
from commerce.api.v1.serializers import CourseSerializer
|
||||
from course_modes.models import CourseMode
|
||||
from openedx.core.lib.api.mixins import PutAsCreateMixin
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -19,12 +21,13 @@ class CourseListView(ListAPIView):
|
||||
authentication_classes = (OAuth2Authentication, SessionAuthentication,)
|
||||
permission_classes = (IsAuthenticated,)
|
||||
serializer_class = CourseSerializer
|
||||
pagination_class = None
|
||||
|
||||
def get_queryset(self):
|
||||
return Course.iterator()
|
||||
return list(Course.iterator())
|
||||
|
||||
|
||||
class CourseRetrieveUpdateView(RetrieveUpdateAPIView):
|
||||
class CourseRetrieveUpdateView(PutAsCreateMixin, RetrieveUpdateAPIView):
|
||||
""" Retrieve, update, or create courses/modes. """
|
||||
lookup_field = 'id'
|
||||
lookup_url_kwarg = 'course_id'
|
||||
@@ -33,6 +36,11 @@ class CourseRetrieveUpdateView(RetrieveUpdateAPIView):
|
||||
permission_classes = (ApiKeyOrModelPermission,)
|
||||
serializer_class = CourseSerializer
|
||||
|
||||
# Django Rest Framework v3 requires that we provide a queryset.
|
||||
# Note that we're overriding `get_object()` below to return a `Course`
|
||||
# rather than a CourseMode, so this isn't really used.
|
||||
queryset = CourseMode.objects.all()
|
||||
|
||||
def get_object(self, queryset=None):
|
||||
course_id = self.kwargs.get(self.lookup_url_kwarg)
|
||||
course = Course.get(course_id)
|
||||
|
||||
@@ -11,11 +11,11 @@ class CourseSerializer(serializers.Serializer):
|
||||
id = serializers.CharField() # pylint: disable=invalid-name
|
||||
name = serializers.CharField(source='display_name')
|
||||
category = serializers.CharField()
|
||||
org = serializers.SerializerMethodField('get_org')
|
||||
run = serializers.SerializerMethodField('get_run')
|
||||
course = serializers.SerializerMethodField('get_course')
|
||||
uri = serializers.SerializerMethodField('get_uri')
|
||||
image_url = serializers.SerializerMethodField('get_image_url')
|
||||
org = serializers.SerializerMethodField()
|
||||
run = serializers.SerializerMethodField()
|
||||
course = serializers.SerializerMethodField()
|
||||
uri = serializers.SerializerMethodField()
|
||||
image_url = serializers.SerializerMethodField()
|
||||
start = serializers.DateTimeField()
|
||||
end = serializers.DateTimeField()
|
||||
|
||||
|
||||
@@ -36,6 +36,23 @@ class CourseViewTestsMixin(object):
|
||||
"""
|
||||
view = None
|
||||
|
||||
raw_grader = [
|
||||
{
|
||||
"min_count": 24,
|
||||
"weight": 0.2,
|
||||
"type": "Homework",
|
||||
"drop_count": 0,
|
||||
"short_label": "HW"
|
||||
},
|
||||
{
|
||||
"min_count": 4,
|
||||
"weight": 0.8,
|
||||
"type": "Exam",
|
||||
"drop_count": 0,
|
||||
"short_label": "Exam"
|
||||
}
|
||||
]
|
||||
|
||||
def setUp(self):
|
||||
super(CourseViewTestsMixin, self).setUp()
|
||||
self.create_user_and_access_token()
|
||||
@@ -51,22 +68,7 @@ class CourseViewTestsMixin(object):
|
||||
@classmethod
|
||||
def create_course_data(cls):
|
||||
cls.invalid_course_id = 'foo/bar/baz'
|
||||
cls.course = CourseFactory.create(display_name='An Introduction to API Testing', raw_grader=[
|
||||
{
|
||||
"min_count": 24,
|
||||
"weight": 0.2,
|
||||
"type": "Homework",
|
||||
"drop_count": 0,
|
||||
"short_label": "HW"
|
||||
},
|
||||
{
|
||||
"min_count": 4,
|
||||
"weight": 0.8,
|
||||
"type": "Exam",
|
||||
"drop_count": 0,
|
||||
"short_label": "Exam"
|
||||
}
|
||||
])
|
||||
cls.course = CourseFactory.create(display_name='An Introduction to API Testing', raw_grader=cls.raw_grader)
|
||||
cls.course_id = unicode(cls.course.id)
|
||||
with cls.store.bulk_operations(cls.course.id, emit_signals=False):
|
||||
cls.sequential = ItemFactory.create(
|
||||
@@ -408,6 +410,55 @@ class CourseGradingPolicyTests(CourseDetailTestMixin, CourseViewTestsMixin, Shar
|
||||
self.assertListEqual(response.data, expected)
|
||||
|
||||
|
||||
class CourseGradingPolicyMissingFieldsTests(CourseDetailTestMixin, CourseViewTestsMixin, SharedModuleStoreTestCase):
|
||||
view = 'course_structure_api:v0:grading_policy'
|
||||
|
||||
# Update the raw grader to have missing keys
|
||||
raw_grader = [
|
||||
{
|
||||
"min_count": 24,
|
||||
"weight": 0.2,
|
||||
"type": "Homework",
|
||||
"drop_count": 0,
|
||||
"short_label": "HW"
|
||||
},
|
||||
{
|
||||
# Deleted "min_count" key
|
||||
"weight": 0.8,
|
||||
"type": "Exam",
|
||||
"drop_count": 0,
|
||||
"short_label": "Exam"
|
||||
}
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(CourseGradingPolicyMissingFieldsTests, cls).setUpClass()
|
||||
cls.create_course_data()
|
||||
|
||||
def test_get(self):
|
||||
"""
|
||||
The view should return grading policy for a course.
|
||||
"""
|
||||
response = super(CourseGradingPolicyMissingFieldsTests, self).test_get()
|
||||
|
||||
expected = [
|
||||
{
|
||||
"count": 24,
|
||||
"weight": 0.2,
|
||||
"assignment_type": "Homework",
|
||||
"dropped": 0
|
||||
},
|
||||
{
|
||||
"count": None,
|
||||
"weight": 0.8,
|
||||
"assignment_type": "Exam",
|
||||
"dropped": 0
|
||||
}
|
||||
]
|
||||
self.assertListEqual(response.data, expected)
|
||||
|
||||
|
||||
#####################################################################################
|
||||
#
|
||||
# The following Mixins/Classes collectively test the CourseBlocksAndNavigation view.
|
||||
|
||||
@@ -6,7 +6,8 @@ import logging
|
||||
|
||||
from django.conf import settings
|
||||
from django.http import Http404
|
||||
from rest_framework.authentication import OAuth2Authentication, SessionAuthentication
|
||||
from rest_framework.authentication import SessionAuthentication
|
||||
from rest_framework_oauth.authentication import OAuth2Authentication
|
||||
from rest_framework.exceptions import AuthenticationFailed, ParseError
|
||||
from rest_framework.generics import RetrieveAPIView, ListAPIView
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
@@ -21,7 +22,6 @@ from courseware.access import has_access
|
||||
from courseware.model_data import FieldDataCache
|
||||
from courseware.module_render import get_module_for_descriptor
|
||||
from openedx.core.lib.api.view_utils import view_course_access, view_auth_classes
|
||||
from openedx.core.lib.api.serializers import PaginationSerializer
|
||||
from openedx.core.djangoapps.content.course_structures.api.v0 import api, errors
|
||||
from student.roles import CourseInstructorRole, CourseStaffRole
|
||||
from util.module_utils import get_dynamic_descriptor_children
|
||||
@@ -157,9 +157,6 @@ class CourseList(CourseViewMixin, ListAPIView):
|
||||
* end: The course end date. If course end date is not specified, the
|
||||
value is null.
|
||||
"""
|
||||
paginate_by = 10
|
||||
paginate_by_param = 'page_size'
|
||||
pagination_serializer_class = PaginationSerializer
|
||||
serializer_class = serializers.CourseSerializer
|
||||
|
||||
def get_queryset(self):
|
||||
|
||||
@@ -25,7 +25,6 @@ from xmodule.modulestore.django import modulestore
|
||||
from xmodule.modulestore.exceptions import ItemNotFoundError
|
||||
from .models import StudentModule
|
||||
from .module_render import get_module_for_descriptor
|
||||
from submissions import api as sub_api # installed from the edx-submissions repository
|
||||
from opaque_keys import InvalidKeyError
|
||||
from opaque_keys.edx.keys import CourseKey
|
||||
from openedx.core.djangoapps.signals.signals import GRADES_UPDATED
|
||||
@@ -349,8 +348,13 @@ def _grade(student, request, course, keep_raw_scores, field_data_cache, scores_c
|
||||
# Dict of item_ids -> (earned, possible) point tuples. This *only* grabs
|
||||
# scores that were registered with the submissions API, which for the moment
|
||||
# means only openassessment (edx-ora2)
|
||||
# We need to import this here to avoid a circular dependency of the form:
|
||||
# XBlock --> submissions --> Django Rest Framework error strings -->
|
||||
# Django translation --> ... --> courseware --> submissions
|
||||
from submissions import api as sub_api # installed from the edx-submissions repository
|
||||
submissions_scores = sub_api.get_scores(course.id.to_deprecated_string(), anonymous_id_for_user(student, course.id))
|
||||
max_scores_cache = MaxScoresCache.create_for_course(course)
|
||||
|
||||
# For the moment, we have to get scorable_locations from field_data_cache
|
||||
# and not from scores_client, because scores_client is ignorant of things
|
||||
# in the submissions API. As a further refactoring step, submissions should
|
||||
@@ -565,7 +569,12 @@ def _progress_summary(student, request, course, field_data_cache=None, scores_cl
|
||||
|
||||
course_module = getattr(course_module, '_x_module', course_module)
|
||||
|
||||
# We need to import this here to avoid a circular dependency of the form:
|
||||
# XBlock --> submissions --> Django Rest Framework error strings -->
|
||||
# Django translation --> ... --> courseware --> submissions
|
||||
from submissions import api as sub_api # installed from the edx-submissions repository
|
||||
submissions_scores = sub_api.get_scores(course.id.to_deprecated_string(), anonymous_id_for_user(student, course.id))
|
||||
|
||||
max_scores_cache = MaxScoresCache.create_for_course(course)
|
||||
# For the moment, we have to get scorable_locations from field_data_cache
|
||||
# and not from scores_client, because scores_client is ignorant of things
|
||||
|
||||
@@ -560,7 +560,7 @@ def create_thread(request, thread_data):
|
||||
if not (serializer.is_valid() and actions_form.is_valid()):
|
||||
raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items()))
|
||||
serializer.save()
|
||||
cc_thread = serializer.object
|
||||
cc_thread = serializer.instance
|
||||
thread_created.send(sender=None, user=user, post=cc_thread)
|
||||
api_thread = serializer.data
|
||||
_do_extra_actions(api_thread, cc_thread, thread_data.keys(), actions_form, context)
|
||||
@@ -606,7 +606,7 @@ def create_comment(request, comment_data):
|
||||
if not (serializer.is_valid() and actions_form.is_valid()):
|
||||
raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items()))
|
||||
serializer.save()
|
||||
cc_comment = serializer.object
|
||||
cc_comment = serializer.instance
|
||||
comment_created.send(sender=None, user=request.user, post=cc_comment)
|
||||
api_comment = serializer.data
|
||||
_do_extra_actions(api_comment, cc_comment, comment_data.keys(), actions_form, context)
|
||||
|
||||
@@ -1,16 +1,7 @@
|
||||
"""
|
||||
Discussion API pagination support
|
||||
"""
|
||||
from rest_framework.pagination import BasePaginationSerializer, NextPageField, PreviousPageField
|
||||
|
||||
|
||||
class _PaginationSerializer(BasePaginationSerializer):
|
||||
"""
|
||||
A pagination serializer without the count field, because the Comments
|
||||
Service does not return result counts
|
||||
"""
|
||||
next = NextPageField(source="*")
|
||||
previous = PreviousPageField(source="*")
|
||||
from rest_framework.utils.urls import replace_query_param
|
||||
|
||||
|
||||
class _Page(object):
|
||||
@@ -52,7 +43,25 @@ def get_paginated_data(request, results, page_num, per_page):
|
||||
previous: The URL for the previous page
|
||||
results: The results on this page
|
||||
"""
|
||||
return _PaginationSerializer(
|
||||
instance=_Page(results, page_num, per_page),
|
||||
context={"request": request}
|
||||
).data
|
||||
# Note: Previous versions of this function used Django Rest Framework's
|
||||
# paginated serializer. With the upgrade to DRF 3.1, paginated serializers
|
||||
# have been removed. We *could* use DRF's paginator classes, but there are
|
||||
# some slight differences between how DRF does pagination and how we're doing
|
||||
# pagination here. (For example, we respond with a next_url param even if
|
||||
# there is only one result on the current page.) To maintain backwards
|
||||
# compatability, we simulate the behavior that DRF used to provide.
|
||||
page = _Page(results, page_num, per_page)
|
||||
next_url, previous_url = None, None
|
||||
base_url = request.build_absolute_uri()
|
||||
|
||||
if page.has_next():
|
||||
next_url = replace_query_param(base_url, "page", page.next_page_number())
|
||||
|
||||
if page.has_previous():
|
||||
previous_url = replace_query_param(base_url, "page", page.previous_page_number())
|
||||
|
||||
return {
|
||||
"next": next_url,
|
||||
"previous": previous_url,
|
||||
"results": results,
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ from lms.lib.comment_client.thread import Thread
|
||||
from lms.lib.comment_client.user import User as CommentClientUser
|
||||
from lms.lib.comment_client.utils import CommentClientRequestError
|
||||
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_names
|
||||
from openedx.core.lib.api.fields import NonEmptyCharField
|
||||
|
||||
|
||||
def get_context(course, request, thread=None):
|
||||
@@ -66,36 +65,43 @@ def get_context(course, request, thread=None):
|
||||
}
|
||||
|
||||
|
||||
def validate_not_blank(value):
|
||||
"""
|
||||
Validate that a value is not an empty string or whitespace.
|
||||
|
||||
Raises: ValidationError
|
||||
"""
|
||||
if not value.strip():
|
||||
raise ValidationError("This field may not be blank.")
|
||||
|
||||
|
||||
class _ContentSerializer(serializers.Serializer):
|
||||
"""A base class for thread and comment serializers."""
|
||||
id_ = serializers.CharField(read_only=True)
|
||||
author = serializers.SerializerMethodField("get_author")
|
||||
author_label = serializers.SerializerMethodField("get_author_label")
|
||||
id = serializers.CharField(read_only=True) # pylint: disable=invalid-name
|
||||
author = serializers.SerializerMethodField()
|
||||
author_label = serializers.SerializerMethodField()
|
||||
created_at = serializers.CharField(read_only=True)
|
||||
updated_at = serializers.CharField(read_only=True)
|
||||
raw_body = NonEmptyCharField(source="body")
|
||||
rendered_body = serializers.SerializerMethodField("get_rendered_body")
|
||||
abuse_flagged = serializers.SerializerMethodField("get_abuse_flagged")
|
||||
voted = serializers.SerializerMethodField("get_voted")
|
||||
vote_count = serializers.SerializerMethodField("get_vote_count")
|
||||
editable_fields = serializers.SerializerMethodField("get_editable_fields")
|
||||
raw_body = serializers.CharField(source="body", validators=[validate_not_blank])
|
||||
rendered_body = serializers.SerializerMethodField()
|
||||
abuse_flagged = serializers.SerializerMethodField()
|
||||
voted = serializers.SerializerMethodField()
|
||||
vote_count = serializers.SerializerMethodField()
|
||||
editable_fields = serializers.SerializerMethodField()
|
||||
|
||||
non_updatable_fields = set()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(_ContentSerializer, self).__init__(*args, **kwargs)
|
||||
# id is an invalid class attribute name, so we must declare a different
|
||||
# name above and modify it here
|
||||
self.fields["id"] = self.fields.pop("id_")
|
||||
|
||||
for field in self.non_updatable_fields:
|
||||
setattr(self, "validate_{}".format(field), self._validate_non_updatable)
|
||||
|
||||
def _validate_non_updatable(self, attrs, _source):
|
||||
def _validate_non_updatable(self, value):
|
||||
"""Ensure that a field is not edited in an update operation."""
|
||||
if self.object:
|
||||
if self.instance:
|
||||
raise ValidationError("This field is not allowed in an update.")
|
||||
return attrs
|
||||
return value
|
||||
|
||||
def _is_user_privileged(self, user_id):
|
||||
"""
|
||||
@@ -131,7 +137,11 @@ class _ContentSerializer(serializers.Serializer):
|
||||
|
||||
def get_author_label(self, obj):
|
||||
"""Returns the role label for the content author."""
|
||||
return None if self._is_anonymous(obj) else self._get_user_label(int(obj["user_id"]))
|
||||
if self._is_anonymous(obj) or obj["user_id"] is None:
|
||||
return None
|
||||
else:
|
||||
user_id = int(obj["user_id"])
|
||||
return self._get_user_label(user_id)
|
||||
|
||||
def get_rendered_body(self, obj):
|
||||
"""Returns the rendered body content."""
|
||||
@@ -142,7 +152,7 @@ class _ContentSerializer(serializers.Serializer):
|
||||
Returns a boolean indicating whether the requester has flagged the
|
||||
content as abusive.
|
||||
"""
|
||||
return self.context["cc_requester"]["id"] in obj["abuse_flaggers"]
|
||||
return self.context["cc_requester"]["id"] in obj.get("abuse_flaggers", [])
|
||||
|
||||
def get_voted(self, obj):
|
||||
"""
|
||||
@@ -153,7 +163,7 @@ class _ContentSerializer(serializers.Serializer):
|
||||
|
||||
def get_vote_count(self, obj):
|
||||
"""Returns the number of votes for the content."""
|
||||
return obj["votes"]["up_count"]
|
||||
return obj.get("votes", {}).get("up_count", 0)
|
||||
|
||||
def get_editable_fields(self, obj):
|
||||
"""Return the list of the fields the requester can edit"""
|
||||
@@ -169,28 +179,28 @@ class ThreadSerializer(_ContentSerializer):
|
||||
at introspection and Thread's __getattr__.
|
||||
"""
|
||||
course_id = serializers.CharField()
|
||||
topic_id = NonEmptyCharField(source="commentable_id")
|
||||
group_id = serializers.IntegerField(required=False)
|
||||
group_name = serializers.SerializerMethodField("get_group_name")
|
||||
type_ = serializers.ChoiceField(
|
||||
topic_id = serializers.CharField(source="commentable_id", validators=[validate_not_blank])
|
||||
group_id = serializers.IntegerField(required=False, allow_null=True)
|
||||
group_name = serializers.SerializerMethodField()
|
||||
type = serializers.ChoiceField(
|
||||
source="thread_type",
|
||||
choices=[(val, val) for val in ["discussion", "question"]]
|
||||
)
|
||||
title = NonEmptyCharField()
|
||||
pinned = serializers.BooleanField(read_only=True)
|
||||
title = serializers.CharField(validators=[validate_not_blank])
|
||||
pinned = serializers.SerializerMethodField(read_only=True)
|
||||
closed = serializers.BooleanField(read_only=True)
|
||||
following = serializers.SerializerMethodField("get_following")
|
||||
following = serializers.SerializerMethodField()
|
||||
comment_count = serializers.IntegerField(source="comments_count", read_only=True)
|
||||
unread_comment_count = serializers.IntegerField(source="unread_comments_count", read_only=True)
|
||||
comment_list_url = serializers.SerializerMethodField("get_comment_list_url")
|
||||
endorsed_comment_list_url = serializers.SerializerMethodField("get_endorsed_comment_list_url")
|
||||
non_endorsed_comment_list_url = serializers.SerializerMethodField("get_non_endorsed_comment_list_url")
|
||||
comment_list_url = serializers.SerializerMethodField()
|
||||
endorsed_comment_list_url = serializers.SerializerMethodField()
|
||||
non_endorsed_comment_list_url = serializers.SerializerMethodField()
|
||||
read = serializers.BooleanField(read_only=True)
|
||||
has_endorsed = serializers.BooleanField(read_only=True, source="endorsed")
|
||||
response_count = serializers.IntegerField(source="resp_total", read_only=True)
|
||||
|
||||
non_updatable_fields = NON_UPDATABLE_THREAD_FIELDS
|
||||
|
||||
|
||||
# TODO: https://openedx.atlassian.net/browse/MA-1359
|
||||
def __init__(self, *args, **kwargs):
|
||||
remove_fields = kwargs.pop('remove_fields', None)
|
||||
@@ -202,6 +212,13 @@ class ThreadSerializer(_ContentSerializer):
|
||||
# not have the pinned field set
|
||||
if self.object and self.object.get("pinned") is None:
|
||||
self.object["pinned"] = False
|
||||
|
||||
def get_pinned(self, obj):
|
||||
"""
|
||||
Compensate for the fact that some threads in the comments service do
|
||||
not have the pinned field set.
|
||||
"""
|
||||
return bool(obj["pinned"])
|
||||
|
||||
if remove_fields:
|
||||
# for multiple fields in a list
|
||||
@@ -245,13 +262,16 @@ class ThreadSerializer(_ContentSerializer):
|
||||
"""Returns the URL to retrieve the thread's non-endorsed comments."""
|
||||
return self.get_comment_list_url(obj, endorsed=False)
|
||||
|
||||
def restore_object(self, attrs, instance=None):
|
||||
if instance:
|
||||
for key, val in attrs.items():
|
||||
instance[key] = val
|
||||
return instance
|
||||
else:
|
||||
return Thread(user_id=self.context["cc_requester"]["id"], **attrs)
|
||||
def create(self, validated_data):
|
||||
thread = Thread(user_id=self.context["cc_requester"]["id"], **validated_data)
|
||||
thread.save()
|
||||
return thread
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
for key, val in validated_data.items():
|
||||
instance[key] = val
|
||||
instance.save()
|
||||
return instance
|
||||
|
||||
|
||||
class CommentSerializer(_ContentSerializer):
|
||||
@@ -263,12 +283,12 @@ class CommentSerializer(_ContentSerializer):
|
||||
at introspection and Comment's __getattr__.
|
||||
"""
|
||||
thread_id = serializers.CharField()
|
||||
parent_id = serializers.CharField(required=False)
|
||||
parent_id = serializers.CharField(required=False, allow_null=True)
|
||||
endorsed = serializers.BooleanField(required=False)
|
||||
endorsed_by = serializers.SerializerMethodField("get_endorsed_by")
|
||||
endorsed_by_label = serializers.SerializerMethodField("get_endorsed_by_label")
|
||||
endorsed_at = serializers.SerializerMethodField("get_endorsed_at")
|
||||
children = serializers.SerializerMethodField("get_children")
|
||||
endorsed_by = serializers.SerializerMethodField()
|
||||
endorsed_by_label = serializers.SerializerMethodField()
|
||||
endorsed_at = serializers.SerializerMethodField()
|
||||
children = serializers.SerializerMethodField()
|
||||
|
||||
non_updatable_fields = NON_UPDATABLE_COMMENT_FIELDS
|
||||
|
||||
@@ -311,6 +331,17 @@ class CommentSerializer(_ContentSerializer):
|
||||
for child in obj.get("children", [])
|
||||
]
|
||||
|
||||
def to_representation(self, data):
|
||||
data = super(CommentSerializer, self).to_representation(data)
|
||||
|
||||
# Django Rest Framework v3 no longer includes None values
|
||||
# in the representation. To maintain the previous behavior,
|
||||
# we do this manually instead.
|
||||
if 'parent_id' not in data:
|
||||
data["parent_id"] = None
|
||||
|
||||
return data
|
||||
|
||||
def validate(self, attrs):
|
||||
"""
|
||||
Ensure that parent_id identifies a comment that is actually in the
|
||||
@@ -332,18 +363,23 @@ class CommentSerializer(_ContentSerializer):
|
||||
raise ValidationError({"parent_id": ["Comment level is too deep."]})
|
||||
return attrs
|
||||
|
||||
def restore_object(self, attrs, instance=None):
|
||||
if instance:
|
||||
for key, val in attrs.items():
|
||||
instance[key] = val
|
||||
# TODO: The comments service doesn't populate the endorsement
|
||||
# field on comment creation, so we only provide
|
||||
# endorsement_user_id on update
|
||||
if key == "endorsed":
|
||||
instance["endorsement_user_id"] = self.context["cc_requester"]["id"]
|
||||
return instance
|
||||
return Comment(
|
||||
def create(self, validated_data):
|
||||
comment = Comment(
|
||||
course_id=self.context["thread"]["course_id"],
|
||||
user_id=self.context["cc_requester"]["id"],
|
||||
**attrs
|
||||
**validated_data
|
||||
)
|
||||
comment.save()
|
||||
return comment
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
for key, val in validated_data.items():
|
||||
instance[key] = val
|
||||
# TODO: The comments service doesn't populate the endorsement
|
||||
# field on comment creation, so we only provide
|
||||
# endorsement_user_id on update
|
||||
if key == "endorsed":
|
||||
instance["endorsement_user_id"] = self.context["cc_requester"]["id"]
|
||||
|
||||
instance.save()
|
||||
return instance
|
||||
|
||||
@@ -1497,8 +1497,9 @@ class CreateThreadTest(
|
||||
self.assertEqual(actual_post_data["group_id"], [str(cohort.id)])
|
||||
else:
|
||||
self.assertNotIn("group_id", actual_post_data)
|
||||
except ValidationError:
|
||||
self.assertTrue(expected_error)
|
||||
except ValidationError as ex:
|
||||
if not expected_error:
|
||||
self.fail("Unexpected validation error: {}".format(ex))
|
||||
|
||||
def test_following(self):
|
||||
self.register_post_thread_response({"id": "test_id"})
|
||||
@@ -2239,7 +2240,7 @@ class UpdateThreadTest(
|
||||
update_thread(self.request, "test_thread", {"raw_body": ""})
|
||||
self.assertEqual(
|
||||
assertion.exception.message_dict,
|
||||
{"raw_body": ["This field is required."]}
|
||||
{"raw_body": ["This field may not be blank."]}
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -523,9 +523,10 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
|
||||
data = self.minimal_data.copy()
|
||||
data.update({field: value for field in ["topic_id", "title", "raw_body"]})
|
||||
serializer = ThreadSerializer(data=data, context=get_context(self.course, self.request))
|
||||
self.assertFalse(serializer.is_valid())
|
||||
self.assertEqual(
|
||||
serializer.errors,
|
||||
{field: ["This field is required."] for field in ["topic_id", "title", "raw_body"]}
|
||||
{field: ["This field may not be blank."] for field in ["topic_id", "title", "raw_body"]}
|
||||
)
|
||||
|
||||
def test_create_type(self):
|
||||
@@ -592,9 +593,10 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
|
||||
partial=True,
|
||||
context=get_context(self.course, self.request)
|
||||
)
|
||||
self.assertFalse(serializer.is_valid())
|
||||
self.assertEqual(
|
||||
serializer.errors,
|
||||
{field: ["This field is required."] for field in ["topic_id", "title", "raw_body"]}
|
||||
{field: ["This field may not be blank."] for field in ["topic_id", "title", "raw_body"]}
|
||||
)
|
||||
|
||||
def test_update_course_id(self):
|
||||
@@ -604,6 +606,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
|
||||
partial=True,
|
||||
context=get_context(self.course, self.request)
|
||||
)
|
||||
self.assertFalse(serializer.is_valid())
|
||||
self.assertEqual(
|
||||
serializer.errors,
|
||||
{"course_id": ["This field is not allowed in an update."]}
|
||||
@@ -769,7 +772,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
|
||||
data["parent_id"] = None
|
||||
serializer = CommentSerializer(data=data, context=context)
|
||||
self.assertFalse(serializer.is_valid())
|
||||
self.assertEqual(serializer.errors, {"parent_id": ["Comment level is too deep."]})
|
||||
self.assertEqual(serializer.errors, {"non_field_errors": ["Comment level is too deep."]})
|
||||
|
||||
def test_create_missing_field(self):
|
||||
for field in self.minimal_data:
|
||||
@@ -855,9 +858,10 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
|
||||
partial=True,
|
||||
context=get_context(self.course, self.request)
|
||||
)
|
||||
self.assertFalse(serializer.is_valid())
|
||||
self.assertEqual(
|
||||
serializer.errors,
|
||||
{"raw_body": ["This field is required."]}
|
||||
{"raw_body": ["This field may not be blank."]}
|
||||
)
|
||||
|
||||
@ddt.data("thread_id", "parent_id")
|
||||
@@ -868,6 +872,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
|
||||
partial=True,
|
||||
context=get_context(self.course, self.request)
|
||||
)
|
||||
self.assertFalse(serializer.is_valid())
|
||||
self.assertEqual(
|
||||
serializer.errors,
|
||||
{field: ["This field is not allowed in an update."]}
|
||||
|
||||
@@ -571,7 +571,7 @@ class ThreadViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTest
|
||||
content_type="application/json"
|
||||
)
|
||||
expected_response_data = {
|
||||
"field_errors": {"title": {"developer_message": "This field is required."}}
|
||||
"field_errors": {"title": {"developer_message": "This field may not be blank."}}
|
||||
}
|
||||
self.assertEqual(response.status_code, 400)
|
||||
response_data = json.loads(response.content)
|
||||
@@ -690,7 +690,7 @@ class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
|
||||
200,
|
||||
{
|
||||
"results": expected_comments,
|
||||
"next": "http://testserver/api/discussion/v1/comments/?thread_id={}&page=2".format(
|
||||
"next": "http://testserver/api/discussion/v1/comments/?page=2&thread_id={}".format(
|
||||
self.thread_id
|
||||
),
|
||||
"previous": None,
|
||||
@@ -946,7 +946,7 @@ class CommentViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTes
|
||||
content_type="application/json"
|
||||
)
|
||||
expected_response_data = {
|
||||
"field_errors": {"raw_body": {"developer_message": "This field is required."}}
|
||||
"field_errors": {"raw_body": {"developer_message": "This field may not be blank."}}
|
||||
}
|
||||
self.assertEqual(response.status_code, 400)
|
||||
response_data = json.loads(response.content)
|
||||
|
||||
@@ -3,7 +3,8 @@ Discussion API views
|
||||
"""
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
from rest_framework.authentication import OAuth2Authentication, SessionAuthentication
|
||||
from rest_framework.authentication import SessionAuthentication
|
||||
from rest_framework_oauth.authentication import OAuth2Authentication
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
@@ -32,8 +32,6 @@ from xmodule.modulestore.tests.factories import check_mongo_calls
|
||||
from xmodule.modulestore.django import modulestore
|
||||
from xmodule.modulestore import ModuleStoreEnum
|
||||
|
||||
from teams.tests.factories import CourseTeamFactory
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -1290,6 +1288,7 @@ class TeamsPermissionsTestCase(UrlResetMixin, ModuleStoreTestCase, MockRequestSe
|
||||
topic_id='topic_id',
|
||||
discussion_topic_id=self.team_commentable_id
|
||||
)
|
||||
|
||||
self.team.add_user(self.student_in_team)
|
||||
|
||||
# Dummy commentable ID not linked to a team
|
||||
|
||||
@@ -717,6 +717,7 @@ class InlineDiscussionContextTestCase(ModuleStoreTestCase):
|
||||
topic_id='topic_id',
|
||||
discussion_topic_id=self.discussion_topic_id
|
||||
)
|
||||
|
||||
self.team.add_user(self.user) # pylint: disable=no-member
|
||||
|
||||
def test_context_can_be_standalone(self, mock_request):
|
||||
@@ -1093,7 +1094,9 @@ class InlineDiscussionTestCase(ModuleStoreTestCase):
|
||||
course_id=self.course.id,
|
||||
discussion_topic_id=self.discussion1.discussion_id
|
||||
)
|
||||
|
||||
team.add_user(self.student) # pylint: disable=no-member
|
||||
|
||||
response = self.send_request(mock_request)
|
||||
self.assertEqual(mock_request.call_args[1]['params']['context'], ThreadContext.STANDALONE)
|
||||
self.verify_response(response)
|
||||
|
||||
@@ -149,6 +149,8 @@ class NonCohortedTopicGroupIdTestMixin(GroupIdAssertionMixin):
|
||||
|
||||
def test_team_discussion_id_not_cohorted(self, mock_request):
|
||||
team = CourseTeamFactory(course_id=self.course.id)
|
||||
|
||||
team.add_user(self.student) # pylint: disable=no-member
|
||||
self.call_view(mock_request, team.discussion_topic_id, self.student, None)
|
||||
|
||||
self._assert_comments_service_called_without_group_id(mock_request)
|
||||
|
||||
@@ -31,7 +31,7 @@ class CoursesWithFriends(generics.ListAPIView):
|
||||
serializer_class = serializers.CoursesWithFriendsSerializer
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.GET, files=request.FILES)
|
||||
serializer = self.get_serializer(data=request.GET)
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
@@ -61,4 +61,5 @@ class CoursesWithFriends(generics.ListAPIView):
|
||||
and is_mobile_available_for_user(self.request.user, enrollment.course)
|
||||
]
|
||||
|
||||
return Response(CourseEnrollmentSerializer(courses, context={'request': request}).data)
|
||||
serializer = CourseEnrollmentSerializer(courses, context={'request': request}, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
@@ -40,7 +40,7 @@ class FriendsInCourse(generics.ListAPIView):
|
||||
serializer_class = serializers.FriendsInCourseSerializer
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.GET, files=request.FILES)
|
||||
serializer = self.get_serializer(data=request.GET)
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ class Groups(generics.CreateAPIView, mixins.DestroyModelMixin):
|
||||
serializer_class = serializers.GroupSerializer
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
|
||||
serializer = self.get_serializer(data=request.DATA)
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
try:
|
||||
@@ -106,12 +106,12 @@ class GroupsMembers(generics.CreateAPIView, mixins.DestroyModelMixin):
|
||||
serializer_class = serializers.GroupsMembersSerializer
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
|
||||
serializer = self.get_serializer(data=request.DATA)
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
graph = facebook_graph_api()
|
||||
url = settings.FACEBOOK_API_VERSION + '/' + kwargs['group_id'] + "/members"
|
||||
member_ids = serializer.object['member_ids'].split(',')
|
||||
member_ids = serializer.data['member_ids'].split(',')
|
||||
response = {}
|
||||
for member_id in member_ids:
|
||||
try:
|
||||
|
||||
@@ -8,4 +8,4 @@ class UserSharingSerializar(serializers.Serializer):
|
||||
"""
|
||||
Serializes user social settings
|
||||
"""
|
||||
share_with_facebook_friends = serializers.BooleanField(required=True, default=False)
|
||||
share_with_facebook_friends = serializers.BooleanField(required=True)
|
||||
|
||||
@@ -39,9 +39,9 @@ class UserSharing(generics.ListCreateAPIView):
|
||||
serializer_class = serializers.UserSharingSerializar
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
|
||||
serializer = self.get_serializer(data=request.DATA)
|
||||
if serializer.is_valid():
|
||||
value = serializer.object['share_with_facebook_friends']
|
||||
value = serializer.data['share_with_facebook_friends']
|
||||
set_user_preference(request.user, "share_with_facebook_friends", value)
|
||||
return self.get(request, *args, **kwargs)
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
@@ -40,7 +40,7 @@ def get_friends_from_facebook(serializer):
|
||||
the error message.
|
||||
"""
|
||||
try:
|
||||
graph = facebook.GraphAPI(serializer.object['oauth_token'])
|
||||
graph = facebook.GraphAPI(serializer.data['oauth_token'])
|
||||
friends = graph.request(settings.FACEBOOK_API_VERSION + "/me/friends")
|
||||
return get_pagination(friends)
|
||||
except facebook.GraphAPIError, ex:
|
||||
|
||||
@@ -15,7 +15,7 @@ from xmodule.course_module import DEFAULT_START_DATE
|
||||
class CourseOverviewField(serializers.RelatedField):
|
||||
"""Custom field to wrap a CourseDescriptor object. Read-only."""
|
||||
|
||||
def to_native(self, course_overview):
|
||||
def to_representation(self, course_overview):
|
||||
course_id = unicode(course_overview.id)
|
||||
request = self.context.get('request', None)
|
||||
if request:
|
||||
@@ -77,8 +77,8 @@ class CourseEnrollmentSerializer(serializers.ModelSerializer):
|
||||
"""
|
||||
Serializes CourseEnrollment models
|
||||
"""
|
||||
course = CourseOverviewField(source="course_overview")
|
||||
certificate = serializers.SerializerMethodField('get_certificate')
|
||||
course = CourseOverviewField(source="course_overview", read_only=True)
|
||||
certificate = serializers.SerializerMethodField()
|
||||
|
||||
def get_certificate(self, model):
|
||||
"""Returns the information about the user's certificate in the course."""
|
||||
@@ -100,7 +100,7 @@ class UserSerializer(serializers.HyperlinkedModelSerializer):
|
||||
"""
|
||||
Serializes User models
|
||||
"""
|
||||
name = serializers.Field(source='profile.name')
|
||||
name = serializers.ReadOnlyField(source='profile.name')
|
||||
course_enrollments = serializers.HyperlinkedIdentityField(
|
||||
view_name='courseenrollment-detail',
|
||||
lookup_field='username'
|
||||
|
||||
@@ -251,6 +251,14 @@ class UserCourseEnrollmentsList(generics.ListAPIView):
|
||||
serializer_class = CourseEnrollmentSerializer
|
||||
lookup_field = 'username'
|
||||
|
||||
# In Django Rest Framework v3, there is a default pagination
|
||||
# class that transmutes the response data into a dictionary
|
||||
# with pagination information. The original response data (a list)
|
||||
# is stored in a "results" value of the dictionary.
|
||||
# For backwards compatibility with the existing API, we disable
|
||||
# the default behavior by setting the pagination_class to None.
|
||||
pagination_class = None
|
||||
|
||||
def get_queryset(self):
|
||||
enrollments = self.queryset.filter(
|
||||
user__username=self.kwargs['username'],
|
||||
|
||||
@@ -23,9 +23,9 @@ class NotifierUserSerializer(serializers.ModelSerializer):
|
||||
* course_groups
|
||||
* roles__permissions
|
||||
"""
|
||||
name = serializers.SerializerMethodField("get_name")
|
||||
preferences = serializers.SerializerMethodField("get_preferences")
|
||||
course_info = serializers.SerializerMethodField("get_course_info")
|
||||
name = serializers.SerializerMethodField()
|
||||
preferences = serializers.SerializerMethodField()
|
||||
course_info = serializers.SerializerMethodField()
|
||||
|
||||
def get_name(self, user):
|
||||
return user.profile.name
|
||||
|
||||
@@ -1,11 +1,32 @@
|
||||
from django.contrib.auth.models import User
|
||||
from rest_framework.viewsets import ReadOnlyModelViewSet
|
||||
from rest_framework.response import Response
|
||||
from rest_framework import pagination
|
||||
|
||||
from notification_prefs import NOTIFICATION_PREF_KEY
|
||||
from notifier_api.serializers import NotifierUserSerializer
|
||||
from openedx.core.lib.api.permissions import ApiKeyHeaderPermission
|
||||
|
||||
|
||||
class NotifierPaginator(pagination.PageNumberPagination):
|
||||
"""
|
||||
Paginator for the notifier API.
|
||||
"""
|
||||
page_size = 10
|
||||
page_size_query_param = "page_size"
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
"""
|
||||
Construct a response with pagination information.
|
||||
"""
|
||||
return Response({
|
||||
'next': self.get_next_link(),
|
||||
'previous': self.get_previous_link(),
|
||||
'count': self.page.paginator.count,
|
||||
'results': data
|
||||
})
|
||||
|
||||
|
||||
class NotifierUsersViewSet(ReadOnlyModelViewSet):
|
||||
"""
|
||||
An endpoint that the notifier can use to retrieve users who have enabled
|
||||
@@ -14,8 +35,7 @@ class NotifierUsersViewSet(ReadOnlyModelViewSet):
|
||||
"""
|
||||
permission_classes = (ApiKeyHeaderPermission,)
|
||||
serializer_class = NotifierUserSerializer
|
||||
paginate_by = 10
|
||||
paginate_by_param = "page_size"
|
||||
pagination_class = NotifierPaginator
|
||||
|
||||
# See NotifierUserSerializer for notes about related tables
|
||||
queryset = User.objects.filter(
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
import pytz
|
||||
from datetime import datetime
|
||||
from model_utils import FieldTracker
|
||||
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
|
||||
@@ -10,6 +10,7 @@ from django.utils import translation
|
||||
from functools import wraps
|
||||
|
||||
from search.search_engine_base import SearchEngine
|
||||
from request_cache import get_request_or_stub
|
||||
|
||||
from .errors import ElasticSearchConnectionError
|
||||
from .serializers import CourseTeamSerializer, CourseTeam
|
||||
@@ -47,7 +48,15 @@ class CourseTeamIndexer(object):
|
||||
|
||||
Returns serialized object with additional search fields.
|
||||
"""
|
||||
serialized_course_team = CourseTeamSerializer(self.course_team).data
|
||||
# Django Rest Framework v3.1 requires that we pass the request to the serializer
|
||||
# so it can construct hyperlinks. To avoid changing the interface of this object,
|
||||
# we retrieve the request from the request cache.
|
||||
context = {
|
||||
"request": get_request_or_stub()
|
||||
}
|
||||
|
||||
serialized_course_team = CourseTeamSerializer(self.course_team, context=context).data
|
||||
|
||||
# Save the primary key so we can load the full objects easily after we search
|
||||
serialized_course_team['pk'] = self.course_team.pk
|
||||
# Don't save the membership relations in elasticsearch
|
||||
|
||||
@@ -4,15 +4,44 @@ from django.contrib.auth.models import User
|
||||
from django.db.models import Count
|
||||
from django.conf import settings
|
||||
|
||||
from django_countries import countries
|
||||
from rest_framework import serializers
|
||||
|
||||
from openedx.core.lib.api.serializers import CollapsedReferenceSerializer, PaginationSerializer
|
||||
from openedx.core.lib.api.serializers import CollapsedReferenceSerializer
|
||||
from openedx.core.lib.api.fields import ExpandableField
|
||||
from openedx.core.djangoapps.user_api.accounts.serializers import UserReadOnlySerializer
|
||||
|
||||
from .models import CourseTeam, CourseTeamMembership
|
||||
|
||||
|
||||
class CountryField(serializers.Field):
|
||||
"""
|
||||
Field to serialize a country code.
|
||||
"""
|
||||
|
||||
COUNTRY_CODES = dict(countries).keys()
|
||||
|
||||
def to_representation(self, obj):
|
||||
"""
|
||||
Represent the country as a 2-character unicode identifier.
|
||||
"""
|
||||
return unicode(obj)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
"""
|
||||
Check that the code is a valid country code.
|
||||
|
||||
We leave the data in its original format so that the Django model's
|
||||
CountryField can convert it to the internal representation used
|
||||
by the django-countries library.
|
||||
"""
|
||||
if data and data not in self.COUNTRY_CODES:
|
||||
raise serializers.ValidationError(
|
||||
u"{code} is not a valid country code".format(code=data)
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class UserMembershipSerializer(serializers.ModelSerializer):
|
||||
"""Serializes CourseTeamMemberships with only user and date_joined
|
||||
|
||||
@@ -43,6 +72,7 @@ class CourseTeamSerializer(serializers.ModelSerializer):
|
||||
"""Serializes a CourseTeam with membership information."""
|
||||
id = serializers.CharField(source='team_id', read_only=True) # pylint: disable=invalid-name
|
||||
membership = UserMembershipSerializer(many=True, read_only=True)
|
||||
country = CountryField()
|
||||
|
||||
class Meta(object):
|
||||
"""Defines meta information for the ModelSerializer."""
|
||||
@@ -66,6 +96,8 @@ class CourseTeamSerializer(serializers.ModelSerializer):
|
||||
class CourseTeamCreationSerializer(serializers.ModelSerializer):
|
||||
"""Deserializes a CourseTeam for creation."""
|
||||
|
||||
country = CountryField(required=False)
|
||||
|
||||
class Meta(object):
|
||||
"""Defines meta information for the ModelSerializer."""
|
||||
model = CourseTeam
|
||||
@@ -78,16 +110,17 @@ class CourseTeamCreationSerializer(serializers.ModelSerializer):
|
||||
"language",
|
||||
)
|
||||
|
||||
def restore_object(self, attrs, instance=None):
|
||||
"""Restores a CourseTeam instance from the given attrs."""
|
||||
return CourseTeam.create(
|
||||
name=attrs.get("name", ''),
|
||||
course_id=attrs.get("course_id"),
|
||||
description=attrs.get("description", ''),
|
||||
topic_id=attrs.get("topic_id", ''),
|
||||
country=attrs.get("country", ''),
|
||||
language=attrs.get("language", ''),
|
||||
def create(self, validated_data):
|
||||
team = CourseTeam.create(
|
||||
name=validated_data.get("name", ''),
|
||||
course_id=validated_data.get("course_id"),
|
||||
description=validated_data.get("description", ''),
|
||||
topic_id=validated_data.get("topic_id", ''),
|
||||
country=validated_data.get("country", ''),
|
||||
language=validated_data.get("language", ''),
|
||||
)
|
||||
team.save()
|
||||
return team
|
||||
|
||||
|
||||
class CourseTeamSerializerWithoutMembership(CourseTeamSerializer):
|
||||
@@ -134,13 +167,6 @@ class MembershipSerializer(serializers.ModelSerializer):
|
||||
read_only_fields = ("date_joined", "last_activity_at")
|
||||
|
||||
|
||||
class PaginatedMembershipSerializer(PaginationSerializer):
|
||||
"""Serializes team memberships with support for pagination."""
|
||||
class Meta(object):
|
||||
"""Defines meta information for the PaginatedMembershipSerializer."""
|
||||
object_serializer_class = MembershipSerializer
|
||||
|
||||
|
||||
class BaseTopicSerializer(serializers.Serializer):
|
||||
"""Serializes a topic without team_count."""
|
||||
description = serializers.CharField()
|
||||
@@ -155,7 +181,7 @@ class TopicSerializer(BaseTopicSerializer):
|
||||
model to get the count. Requires that `context` is provided with a valid course_id
|
||||
in order to filter teams within the course.
|
||||
"""
|
||||
team_count = serializers.SerializerMethodField('get_team_count')
|
||||
team_count = serializers.SerializerMethodField()
|
||||
|
||||
def get_team_count(self, topic):
|
||||
"""Get the number of teams associated with this topic"""
|
||||
@@ -166,31 +192,25 @@ class TopicSerializer(BaseTopicSerializer):
|
||||
return CourseTeam.objects.filter(course_id=self.context['course_id'], topic_id=topic['id']).count()
|
||||
|
||||
|
||||
class PaginatedTopicSerializer(PaginationSerializer):
|
||||
class BulkTeamCountTopicListSerializer(serializers.ListSerializer): # pylint: disable=abstract-method
|
||||
"""
|
||||
Serializes a set of topics, adding the team_count field to each topic individually, if team_count
|
||||
is not already present in the topic data. Requires that `context` is provided with a valid course_id in
|
||||
order to filter teams within the course.
|
||||
List serializer for efficiently serializing a set of topics.
|
||||
"""
|
||||
class Meta(object):
|
||||
"""Defines meta information for the PaginatedTopicSerializer."""
|
||||
object_serializer_class = TopicSerializer
|
||||
|
||||
def to_representation(self, obj):
|
||||
"""Adds team_count to each topic. """
|
||||
data = super(BulkTeamCountTopicListSerializer, self).to_representation(obj)
|
||||
add_team_count(data, self.context["course_id"])
|
||||
return data
|
||||
|
||||
|
||||
class BulkTeamCountPaginatedTopicSerializer(PaginationSerializer):
|
||||
class BulkTeamCountTopicSerializer(BaseTopicSerializer): # pylint: disable=abstract-method
|
||||
"""
|
||||
Serializes a set of topics, adding the team_count field to each topic as a bulk operation per page
|
||||
(only on the page being returned). Requires that `context` is provided with a valid course_id in
|
||||
order to filter teams within the course.
|
||||
Serializes a set of topics, adding the team_count field to each topic as a bulk operation.
|
||||
Requires that `context` is provided with a valid course_id in order to filter teams within the course.
|
||||
"""
|
||||
class Meta(object):
|
||||
"""Defines meta information for the BulkTeamCountPaginatedTopicSerializer."""
|
||||
object_serializer_class = BaseTopicSerializer
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Adds team_count to each topic on the current page."""
|
||||
super(BulkTeamCountPaginatedTopicSerializer, self).__init__(*args, **kwargs)
|
||||
add_team_count(self.data['results'], self.context['course_id'])
|
||||
class Meta: # pylint: disable=missing-docstring,old-style-class
|
||||
list_serializer_class = BulkTeamCountTopicListSerializer
|
||||
|
||||
|
||||
def add_team_count(topics, course_id):
|
||||
|
||||
@@ -34,3 +34,10 @@ class CourseTeamMembershipFactory(DjangoModelFactory):
|
||||
class Meta(object): # pylint: disable=missing-docstring
|
||||
model = CourseTeamMembership
|
||||
last_activity_at = LAST_ACTIVITY_AT
|
||||
|
||||
@classmethod
|
||||
def _create(cls, model_class, *args, **kwargs):
|
||||
"""Create the team membership. """
|
||||
obj = model_class(*args, **kwargs)
|
||||
obj.save()
|
||||
return obj
|
||||
|
||||
@@ -11,12 +11,9 @@ from xmodule.modulestore.tests.factories import CourseFactory
|
||||
|
||||
from lms.djangoapps.teams.tests.factories import CourseTeamFactory, CourseTeamMembershipFactory
|
||||
from lms.djangoapps.teams.serializers import (
|
||||
BaseTopicSerializer,
|
||||
PaginatedTopicSerializer,
|
||||
BulkTeamCountPaginatedTopicSerializer,
|
||||
BulkTeamCountTopicSerializer,
|
||||
TopicSerializer,
|
||||
MembershipSerializer,
|
||||
add_team_count
|
||||
)
|
||||
|
||||
|
||||
@@ -73,21 +70,6 @@ class MembershipSerializerTestCase(SerializerTestCase):
|
||||
self.assertNotIn('membership', data['team'])
|
||||
|
||||
|
||||
class BaseTopicSerializerTestCase(SerializerTestCase):
|
||||
"""
|
||||
Tests for the `BaseTopicSerializer`, which should not serialize team count
|
||||
data.
|
||||
"""
|
||||
def test_team_count_not_included(self):
|
||||
"""Verifies that the `BaseTopicSerializer` does not include team count"""
|
||||
with self.assertNumQueries(0):
|
||||
serializer = BaseTopicSerializer(self.course.teams_topics[0])
|
||||
self.assertEqual(
|
||||
serializer.data,
|
||||
{u'name': u'Tøpic', u'description': u'The bést topic!', u'id': u'0'}
|
||||
)
|
||||
|
||||
|
||||
class TopicSerializerTestCase(SerializerTestCase):
|
||||
"""
|
||||
Tests for the `TopicSerializer`, which should serialize team count data for
|
||||
@@ -137,7 +119,7 @@ class TopicSerializerTestCase(SerializerTestCase):
|
||||
)
|
||||
|
||||
|
||||
class BasePaginatedTopicSerializerTestCase(SerializerTestCase):
|
||||
class BaseTopicSerializerTestCase(SerializerTestCase):
|
||||
"""
|
||||
Base class for testing the two paginated topic serializers.
|
||||
"""
|
||||
@@ -191,13 +173,15 @@ class BasePaginatedTopicSerializerTestCase(SerializerTestCase):
|
||||
self.assert_serializer_output([], num_teams_per_topic=0, num_queries=0)
|
||||
|
||||
|
||||
class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializerTestCase):
|
||||
class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
|
||||
"""
|
||||
Tests for the `BulkTeamCountPaginatedTopicSerializer`, which should serialize team_count
|
||||
Tests for the `BulkTeamCountTopicSerializer`, which should serialize team_count
|
||||
data for many topics with constant time SQL queries.
|
||||
"""
|
||||
__test__ = True
|
||||
serializer = BulkTeamCountPaginatedTopicSerializer
|
||||
serializer = BulkTeamCountTopicSerializer
|
||||
|
||||
NUM_TOPICS = 6
|
||||
|
||||
def test_topics_with_no_team_counts(self):
|
||||
"""
|
||||
@@ -222,13 +206,13 @@ class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializer
|
||||
one SQL query.
|
||||
"""
|
||||
teams_per_topic = 10
|
||||
topics = self.setup_topics(num_topics=self.PAGE_SIZE + 1, teams_per_topic=teams_per_topic)
|
||||
self.assert_serializer_output(topics[:self.PAGE_SIZE], num_teams_per_topic=teams_per_topic, num_queries=1)
|
||||
topics = self.setup_topics(num_topics=self.NUM_TOPICS, teams_per_topic=teams_per_topic)
|
||||
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=1)
|
||||
|
||||
def test_scoped_within_course(self):
|
||||
"""Verify that team counts are scoped within a course."""
|
||||
teams_per_topic = 10
|
||||
first_course_topics = self.setup_topics(num_topics=self.PAGE_SIZE, teams_per_topic=teams_per_topic)
|
||||
first_course_topics = self.setup_topics(num_topics=self.NUM_TOPICS, teams_per_topic=teams_per_topic)
|
||||
duplicate_topic = first_course_topics[0]
|
||||
second_course = CourseFactory.create(
|
||||
teams_configuration={
|
||||
@@ -239,27 +223,44 @@ class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializer
|
||||
CourseTeamFactory.create(course_id=second_course.id, topic_id=duplicate_topic[u'id'])
|
||||
self.assert_serializer_output(first_course_topics, num_teams_per_topic=teams_per_topic, num_queries=1)
|
||||
|
||||
def _merge_dicts(self, first, second):
|
||||
"""Convenience method to merge two dicts in a single expression"""
|
||||
result = first.copy()
|
||||
result.update(second)
|
||||
return result
|
||||
|
||||
class PaginatedTopicSerializerTestCase(BasePaginatedTopicSerializerTestCase):
|
||||
"""
|
||||
Tests for the `PaginatedTopicSerializer`, which will add team_count information per topic if not present.
|
||||
"""
|
||||
__test__ = True
|
||||
serializer = PaginatedTopicSerializer
|
||||
def setup_topics(self, num_topics=5, teams_per_topic=0):
|
||||
"""
|
||||
Helper method to set up topics on the course. Returns a list of
|
||||
created topics.
|
||||
"""
|
||||
self.course.teams_configuration['topics'] = []
|
||||
topics = [
|
||||
{u'name': u'Tøpic {}'.format(i), u'description': u'The bést topic! {}'.format(i), u'id': unicode(i)}
|
||||
for i in xrange(num_topics)
|
||||
]
|
||||
for i in xrange(num_topics):
|
||||
topic_id = unicode(i)
|
||||
self.course.teams_configuration['topics'].append(topics[i])
|
||||
for _ in xrange(teams_per_topic):
|
||||
CourseTeamFactory.create(course_id=self.course.id, topic_id=topic_id)
|
||||
return topics
|
||||
|
||||
def test_topics_with_team_counts(self):
|
||||
def assert_serializer_output(self, topics, num_teams_per_topic, num_queries):
|
||||
"""
|
||||
Verify that we serialize topics with team_count, making one SQL query per topic.
|
||||
Verify that the serializer produced the expected topics.
|
||||
"""
|
||||
teams_per_topic = 2
|
||||
topics = self.setup_topics(teams_per_topic=teams_per_topic)
|
||||
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=5)
|
||||
with self.assertNumQueries(num_queries):
|
||||
serializer = self.serializer(topics, context={'course_id': self.course.id}, many=True)
|
||||
self.assertEqual(
|
||||
serializer.data,
|
||||
[self._merge_dicts(topic, {u'team_count': num_teams_per_topic}) for topic in topics]
|
||||
)
|
||||
|
||||
def test_topics_with_team_counts_prepopulated(self):
|
||||
def test_no_topics(self):
|
||||
"""
|
||||
Verify that if team_count is pre-populated, there are no additional SQL queries.
|
||||
Verify that we return no results and make no SQL queries for a page
|
||||
with no topics.
|
||||
"""
|
||||
teams_per_topic = 8
|
||||
topics = self.setup_topics(teams_per_topic=teams_per_topic)
|
||||
add_team_count(topics, self.course.id)
|
||||
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=0)
|
||||
self.course.teams_configuration['topics'] = []
|
||||
self.assert_serializer_output([], num_teams_per_topic=0, num_queries=0)
|
||||
|
||||
@@ -1,32 +1,31 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tests for the teams API at the HTTP request level."""
|
||||
import json
|
||||
import pytz
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from dateutil import parser
|
||||
import ddt
|
||||
from elasticsearch.exceptions import ConnectionError
|
||||
from mock import patch
|
||||
from search.search_engine_base import SearchEngine
|
||||
|
||||
from django.core.urlresolvers import reverse
|
||||
from django.conf import settings
|
||||
from django.db.models.signals import post_save
|
||||
from django.utils import translation
|
||||
from nose.plugins.attrib import attr
|
||||
from rest_framework.test import APITestCase, APIClient
|
||||
from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
|
||||
from xmodule.modulestore.tests.factories import CourseFactory
|
||||
|
||||
from courseware.tests.factories import StaffFactory
|
||||
from common.test.utils import skip_signal
|
||||
from student.tests.factories import UserFactory, AdminFactory, CourseEnrollmentFactory
|
||||
from student.models import CourseEnrollment
|
||||
from util.testing import EventTestMixin
|
||||
from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
|
||||
from xmodule.modulestore.tests.factories import CourseFactory
|
||||
from .factories import CourseTeamFactory, LAST_ACTIVITY_AT
|
||||
from ..models import CourseTeamMembership
|
||||
from ..search_indexes import CourseTeamIndexer, CourseTeam, course_team_post_save_callback
|
||||
|
||||
from django_comment_common.models import Role, FORUM_ROLE_COMMUNITY_TA
|
||||
from django_comment_common.utils import seed_permissions_roles
|
||||
|
||||
@@ -36,11 +35,23 @@ class TestDashboard(SharedModuleStoreTestCase):
|
||||
"""Tests for the Teams dashboard."""
|
||||
test_password = "test"
|
||||
|
||||
NUM_TOPICS = 10
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestDashboard, cls).setUpClass()
|
||||
cls.course = CourseFactory.create(
|
||||
teams_configuration={"max_team_size": 10, "topics": [{"name": "foo", "id": 0, "description": "test topic"}]}
|
||||
teams_configuration={
|
||||
"max_team_size": 10,
|
||||
"topics": [
|
||||
{
|
||||
"name": "Topic {}".format(topic_id),
|
||||
"id": topic_id,
|
||||
"description": "Description for topic {}".format(topic_id)
|
||||
}
|
||||
for topic_id in range(cls.NUM_TOPICS)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
@@ -97,6 +108,30 @@ class TestDashboard(SharedModuleStoreTestCase):
|
||||
response = self.client.get(teams_url)
|
||||
self.assertEqual(404, response.status_code)
|
||||
|
||||
def test_query_counts(self):
|
||||
# Enroll in the course and log in
|
||||
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
|
||||
self.client.login(username=self.user.username, password=self.test_password)
|
||||
|
||||
# Check the query count on the dashboard With no teams
|
||||
with self.assertNumQueries(15):
|
||||
self.client.get(self.teams_url)
|
||||
|
||||
# Create some teams
|
||||
for topic_id in range(self.NUM_TOPICS):
|
||||
team = CourseTeamFactory.create(
|
||||
name=u"Team for topic {}".format(topic_id),
|
||||
course_id=self.course.id,
|
||||
topic_id=topic_id,
|
||||
)
|
||||
|
||||
# Add the user to the last team
|
||||
team.add_user(self.user)
|
||||
|
||||
# Check the query count on the dashboard again
|
||||
with self.assertNumQueries(19):
|
||||
self.client.get(self.teams_url)
|
||||
|
||||
def test_bad_course_id(self):
|
||||
"""
|
||||
Verifies expected behavior when course_id does not reference an existing course or is invalid.
|
||||
@@ -252,6 +287,9 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase):
|
||||
self.users[user], course.id, check_access=True
|
||||
)
|
||||
|
||||
# Django Rest Framework v3 requires us to pass a request to serializers
|
||||
# that have URL fields. Since we're invoking this code outside the context
|
||||
# of a request, we need to simulate that there's a request.
|
||||
self.solar_team.add_user(self.users['student_enrolled'])
|
||||
self.nuclear_team.add_user(self.users['student_enrolled_both_courses_other_team'])
|
||||
self.another_team.add_user(self.users['student_enrolled_both_courses_other_team'])
|
||||
@@ -311,7 +349,17 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase):
|
||||
response = func(url, data=data, content_type=content_type)
|
||||
else:
|
||||
response = func(url, data=data)
|
||||
self.assertEqual(expected_status, response.status_code)
|
||||
|
||||
self.assertEqual(
|
||||
expected_status,
|
||||
response.status_code,
|
||||
msg="Expected status {expected} but got {actual}: {content}".format(
|
||||
expected=expected_status,
|
||||
actual=response.status_code,
|
||||
content=response.content,
|
||||
)
|
||||
)
|
||||
|
||||
if expected_status == 200:
|
||||
return json.loads(response.content)
|
||||
else:
|
||||
|
||||
@@ -5,19 +5,14 @@ import logging
|
||||
from django.shortcuts import get_object_or_404, render_to_response
|
||||
from django.http import Http404
|
||||
from django.conf import settings
|
||||
from django.core.paginator import Paginator
|
||||
from django.views.generic.base import View
|
||||
from rest_framework.generics import GenericAPIView
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.reverse import reverse
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.authentication import (
|
||||
SessionAuthentication,
|
||||
OAuth2Authentication
|
||||
)
|
||||
from rest_framework.authentication import SessionAuthentication
|
||||
from rest_framework_oauth.authentication import OAuth2Authentication
|
||||
from rest_framework import status
|
||||
from rest_framework import permissions
|
||||
from django.db.models import Count
|
||||
from django.db.models.signals import post_save
|
||||
from django.dispatch import receiver
|
||||
from django.contrib.auth.models import User
|
||||
@@ -32,8 +27,7 @@ from openedx.core.lib.api.view_utils import (
|
||||
build_api_error,
|
||||
ExpandableFieldViewMixin
|
||||
)
|
||||
from openedx.core.lib.api.serializers import PaginationSerializer
|
||||
from openedx.core.lib.api.paginators import paginate_search_results
|
||||
from openedx.core.lib.api.paginators import paginate_search_results, DefaultPagination
|
||||
from xmodule.modulestore.django import modulestore
|
||||
from opaque_keys import InvalidKeyError
|
||||
from opaque_keys.edx.keys import CourseKey
|
||||
@@ -49,10 +43,8 @@ from .serializers import (
|
||||
CourseTeamSerializer,
|
||||
CourseTeamCreationSerializer,
|
||||
TopicSerializer,
|
||||
PaginatedTopicSerializer,
|
||||
BulkTeamCountPaginatedTopicSerializer,
|
||||
BulkTeamCountTopicSerializer,
|
||||
MembershipSerializer,
|
||||
PaginatedMembershipSerializer,
|
||||
add_team_count
|
||||
)
|
||||
from .search_indexes import CourseTeamIndexer
|
||||
@@ -85,7 +77,42 @@ def team_post_save_callback(sender, instance, **kwargs): # pylint: disable=unus
|
||||
)
|
||||
|
||||
|
||||
class TeamsDashboardView(View):
|
||||
class TeamAPIPagination(DefaultPagination):
|
||||
"""
|
||||
Pagination format used by the teams API.
|
||||
"""
|
||||
page_size_query_param = "page_size"
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
"""
|
||||
Annotate the response with pagination information.
|
||||
"""
|
||||
response = super(TeamAPIPagination, self).get_paginated_response(data)
|
||||
|
||||
# Add the current page to the response.
|
||||
# It may make sense to eventually move this field into the default
|
||||
# implementation, but for now, teams is the only API that uses this.
|
||||
response.data["current_page"] = self.page.number
|
||||
|
||||
# This field can be derived from other fields in the response,
|
||||
# so it may make sense to have the JavaScript client calculate it
|
||||
# instead of including it in the response.
|
||||
response.data["start"] = (self.page.number - 1) * self.get_page_size(self.request)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class TopicsPagination(TeamAPIPagination):
|
||||
"""Paginate topics. """
|
||||
page_size = TOPICS_PER_PAGE
|
||||
|
||||
|
||||
class MembershipPagination(TeamAPIPagination):
|
||||
"""Paginate memberships. """
|
||||
page_size = TEAM_MEMBERSHIPS_PER_PAGE
|
||||
|
||||
|
||||
class TeamsDashboardView(GenericAPIView):
|
||||
"""
|
||||
View methods related to the teams dashboard.
|
||||
"""
|
||||
@@ -107,29 +134,38 @@ class TeamsDashboardView(View):
|
||||
not has_access(request.user, 'staff', course, course.id):
|
||||
raise Http404
|
||||
|
||||
user = request.user
|
||||
|
||||
# Even though sorting is done outside of the serializer, sort_order needs to be passed
|
||||
# to the serializer so that the paginated results indicate how they were sorted.
|
||||
sort_order = 'name'
|
||||
topics = get_alphabetical_topics(course)
|
||||
topics_page = Paginator(topics, TOPICS_PER_PAGE).page(1)
|
||||
|
||||
# Paginate and serialize topic data
|
||||
# BulkTeamCountPaginatedTopicSerializer will add team counts to the topics in a single
|
||||
# bulk operation per page.
|
||||
topics_serializer = BulkTeamCountPaginatedTopicSerializer(
|
||||
instance=topics_page,
|
||||
context={'course_id': course.id, 'sort_order': sort_order}
|
||||
topics_data = self._serialize_and_paginate(
|
||||
TopicsPagination,
|
||||
topics,
|
||||
request,
|
||||
BulkTeamCountTopicSerializer,
|
||||
{'course_id': course.id},
|
||||
)
|
||||
user = request.user
|
||||
topics_data["sort_order"] = sort_order
|
||||
|
||||
team_memberships = CourseTeamMembership.get_memberships(request.user.username, [course.id])
|
||||
team_memberships_page = Paginator(team_memberships, TEAM_MEMBERSHIPS_PER_PAGE).page(1)
|
||||
team_memberships_serializer = PaginatedMembershipSerializer(
|
||||
instance=team_memberships_page,
|
||||
context={'expand': ('team', 'user'), 'request': request},
|
||||
# Paginate and serialize team membership data.
|
||||
team_memberships = CourseTeamMembership.get_memberships(user.username, [course.id])
|
||||
memberships_data = self._serialize_and_paginate(
|
||||
MembershipPagination,
|
||||
team_memberships,
|
||||
request,
|
||||
MembershipSerializer,
|
||||
{'expand': ('team', 'user',)}
|
||||
)
|
||||
|
||||
context = {
|
||||
"course": course,
|
||||
"topics": topics_serializer.data,
|
||||
"topics": topics_data,
|
||||
# It is necessary to pass both privileged and staff because only privileged users can
|
||||
# administer discussion threads, but both privileged and staff users are allowed to create
|
||||
# multiple teams (since they are not automatically added to teams upon creation).
|
||||
@@ -137,7 +173,7 @@ class TeamsDashboardView(View):
|
||||
"username": user.username,
|
||||
"privileged": has_discussion_privileges(user, course_key),
|
||||
"staff": bool(has_access(user, 'staff', course_key)),
|
||||
"team_memberships_data": team_memberships_serializer.data,
|
||||
"team_memberships_data": memberships_data,
|
||||
},
|
||||
"topic_url": reverse(
|
||||
'topics_detail', kwargs={'topic_id': 'topic_id', 'course_id': str(course_id)}, request=request
|
||||
@@ -154,6 +190,39 @@ class TeamsDashboardView(View):
|
||||
}
|
||||
return render_to_response("teams/teams.html", context)
|
||||
|
||||
def _serialize_and_paginate(self, pagination_cls, queryset, request, serializer_cls, serializer_ctx):
|
||||
"""
|
||||
Serialize and paginate objects in a queryset.
|
||||
|
||||
Arguments:
|
||||
pagination_cls (pagination.Paginator class): Django Rest Framework Paginator subclass.
|
||||
queryset (QuerySet): Django queryset to serialize/paginate.
|
||||
serializer_cls (serializers.Serializer class): Django Rest Framework Serializer subclass.
|
||||
serializer_ctx (dict): Context dictionary to pass to the serializer
|
||||
|
||||
Returns: dict
|
||||
|
||||
"""
|
||||
# Django Rest Framework v3 requires that we pass the request
|
||||
# into the serializer's context if the serialize contains
|
||||
# hyperlink fields.
|
||||
serializer_ctx["request"] = request
|
||||
|
||||
# Instantiate the paginator and use it to paginate the queryset
|
||||
paginator = pagination_cls()
|
||||
page = paginator.paginate_queryset(queryset, request)
|
||||
|
||||
# Serialize the page
|
||||
serializer = serializer_cls(page, context=serializer_ctx, many=True)
|
||||
|
||||
# Use the paginator to construct the response data
|
||||
# This will use the pagination subclass for the view to add additional
|
||||
# fields to the response.
|
||||
# For example, if the input data is a list, the output data would
|
||||
# be a dictionary with keys "count", "next", "previous", and "results"
|
||||
# (where "results" is set to the value of the original list)
|
||||
return paginator.get_paginated_response(serializer.data).data
|
||||
|
||||
|
||||
def has_team_api_access(user, course_key, access_username=None):
|
||||
"""Returns True if the user has access to the Team API for the course
|
||||
@@ -307,11 +376,8 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView):
|
||||
# OAuth2Authentication must come first to return a 401 for unauthenticated users
|
||||
authentication_classes = (OAuth2Authentication, SessionAuthentication)
|
||||
permission_classes = (permissions.IsAuthenticated,)
|
||||
|
||||
paginate_by = 10
|
||||
paginate_by_param = 'page_size'
|
||||
pagination_serializer_class = PaginationSerializer
|
||||
serializer_class = CourseTeamSerializer
|
||||
pagination_class = TeamAPIPagination
|
||||
|
||||
def get(self, request):
|
||||
"""GET /api/team/v0/teams/"""
|
||||
@@ -377,15 +443,18 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView):
|
||||
paginated_results = paginate_search_results(
|
||||
CourseTeam,
|
||||
search_results,
|
||||
self.get_paginate_by(),
|
||||
self.paginator.get_page_size(request),
|
||||
self.get_page()
|
||||
)
|
||||
serializer = self.get_pagination_serializer(paginated_results)
|
||||
emit_team_event('edx.team.searched', course_key, {
|
||||
"number_of_results": search_results['total'],
|
||||
"search_text": text_search,
|
||||
"topic_id": topic_id,
|
||||
})
|
||||
|
||||
page = self.paginate_queryset(paginated_results)
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
order_by_input = None
|
||||
else:
|
||||
queryset = CourseTeam.objects.filter(**result_filter)
|
||||
order_by_input = request.QUERY_PARAMS.get('order_by', 'name')
|
||||
@@ -407,10 +476,12 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView):
|
||||
}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
page = self.paginate_queryset(queryset)
|
||||
serializer = self.get_pagination_serializer(page)
|
||||
serializer.context.update({'sort_order': order_by_input}) # pylint: disable=maybe-no-member
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
|
||||
return Response(serializer.data) # pylint: disable=maybe-no-member
|
||||
response = self.get_paginated_response(serializer.data)
|
||||
if order_by_input is not None:
|
||||
response.data['sort_order'] = order_by_input
|
||||
return response
|
||||
|
||||
def post(self, request):
|
||||
"""POST /api/team/v0/teams/"""
|
||||
@@ -473,14 +544,16 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView):
|
||||
'add_method': 'added_on_create'
|
||||
}
|
||||
)
|
||||
return Response(CourseTeamSerializer(team).data)
|
||||
|
||||
data = CourseTeamSerializer(team, context={"request": request}).data
|
||||
return Response(data)
|
||||
|
||||
def get_page(self):
|
||||
""" Returns page number specified in args, params, or defaults to 1. """
|
||||
# This code is taken from within the GenericAPIView#paginate_queryset method.
|
||||
# We need need access to the page outside of that method for our paginate_search_results method
|
||||
page_kwarg = self.kwargs.get(self.page_kwarg)
|
||||
page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
|
||||
page_kwarg = self.kwargs.get(self.paginator.page_query_param)
|
||||
page_query_param = self.request.QUERY_PARAMS.get(self.paginator.page_query_param)
|
||||
return page_kwarg or page_query_param or 1
|
||||
|
||||
|
||||
@@ -692,9 +765,7 @@ class TopicListView(GenericAPIView):
|
||||
|
||||
authentication_classes = (OAuth2Authentication, SessionAuthentication)
|
||||
permission_classes = (permissions.IsAuthenticated,)
|
||||
|
||||
paginate_by = TOPICS_PER_PAGE
|
||||
paginate_by_param = 'page_size'
|
||||
pagination_class = TopicsPagination
|
||||
|
||||
def get(self, request):
|
||||
"""GET /api/team/v0/topics/?course_id={course_id}"""
|
||||
@@ -740,18 +811,20 @@ class TopicListView(GenericAPIView):
|
||||
add_team_count(topics, course_id)
|
||||
topics.sort(key=lambda t: t['team_count'], reverse=True)
|
||||
page = self.paginate_queryset(topics)
|
||||
# Since team_count has already been added to all the topics, use PaginatedTopicSerializer.
|
||||
# Even though sorting is done outside of the serializer, sort_order needs to be passed
|
||||
# to the serializer so that the paginated results indicate how they were sorted.
|
||||
serializer = PaginatedTopicSerializer(page, context={'course_id': course_id, 'sort_order': ordering})
|
||||
serializer = TopicSerializer(
|
||||
page,
|
||||
context={'course_id': course_id},
|
||||
many=True,
|
||||
)
|
||||
else:
|
||||
page = self.paginate_queryset(topics)
|
||||
# Use the serializer that adds team_count in a bulk operation per page.
|
||||
serializer = BulkTeamCountPaginatedTopicSerializer(
|
||||
page, context={'course_id': course_id, 'sort_order': ordering}
|
||||
)
|
||||
serializer = BulkTeamCountTopicSerializer(page, context={'course_id': course_id}, many=True)
|
||||
|
||||
return Response(serializer.data)
|
||||
response = self.get_paginated_response(serializer.data)
|
||||
response.data['sort_order'] = ordering
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def get_alphabetical_topics(course_module):
|
||||
@@ -960,13 +1033,8 @@ class MembershipListView(ExpandableFieldViewMixin, GenericAPIView):
|
||||
|
||||
authentication_classes = (OAuth2Authentication, SessionAuthentication)
|
||||
permission_classes = (permissions.IsAuthenticated,)
|
||||
|
||||
serializer_class = MembershipSerializer
|
||||
|
||||
paginate_by = 10
|
||||
paginate_by_param = 'page_size'
|
||||
pagination_serializer_class = PaginationSerializer
|
||||
|
||||
def get(self, request):
|
||||
"""GET /api/team/v0/team_membership"""
|
||||
specified_username_or_team = False
|
||||
@@ -1023,8 +1091,8 @@ class MembershipListView(ExpandableFieldViewMixin, GenericAPIView):
|
||||
|
||||
queryset = CourseTeamMembership.get_memberships(username, course_keys, team_id)
|
||||
page = self.paginate_queryset(queryset)
|
||||
serializer = self.get_pagination_serializer(page)
|
||||
return Response(serializer.data) # pylint: disable=maybe-no-member
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
def post(self, request):
|
||||
"""POST /api/team/v0/team_membership"""
|
||||
|
||||
@@ -1866,6 +1866,12 @@ INSTALLED_APPS = (
|
||||
'provider.oauth2',
|
||||
'oauth2_provider',
|
||||
|
||||
# We don't use this directly (since we use OAuth2), but we need to install it anyway.
|
||||
# When a user is deleted, Django queries all tables with a FK to the auth_user table,
|
||||
# and since django-rest-framework-oauth imports this, it will try to access tables
|
||||
# defined by oauth_provider. If those tables don't exist, an error can occur.
|
||||
'oauth_provider',
|
||||
|
||||
'auth_exchange',
|
||||
|
||||
# For the wiki
|
||||
@@ -1981,6 +1987,14 @@ INSTALLED_APPS = (
|
||||
CSRF_COOKIE_AGE = 60 * 60 * 24 * 7 * 52
|
||||
|
||||
|
||||
######################### Django Rest Framework ########################
|
||||
|
||||
REST_FRAMEWORK = {
|
||||
'DEFAULT_PAGINATION_CLASS': 'openedx.core.lib.api.paginators.DefaultPagination',
|
||||
'PAGE_SIZE': 10,
|
||||
}
|
||||
|
||||
|
||||
######################### MARKETING SITE ###############################
|
||||
EDXMKTG_LOGGED_IN_COOKIE_NAME = 'edxloggedin'
|
||||
EDXMKTG_USER_INFO_COOKIE_NAME = 'edx-user-info'
|
||||
|
||||
@@ -9,16 +9,12 @@ from django.conf import settings
|
||||
# Force settings to run so that the python path is modified
|
||||
settings.INSTALLED_APPS # pylint: disable=pointless-statement
|
||||
|
||||
from instructor.services import InstructorService
|
||||
|
||||
from openedx.core.lib.django_startup import autostartup
|
||||
import edxmako
|
||||
import logging
|
||||
from monkey_patch import django_utils_translation
|
||||
import analytics
|
||||
|
||||
from edx_proctoring.runtime import set_runtime_service
|
||||
from openedx.core.djangoapps.credit.services import CreditService
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,6 +47,11 @@ def run():
|
||||
# right now edx_proctoring is dependent on the openedx.core.djangoapps.credit
|
||||
# as well as the instructor dashboard (for deleting student attempts)
|
||||
if settings.FEATURES.get('ENABLE_PROCTORED_EXAMS'):
|
||||
# Import these here to avoid circular dependencies of the form:
|
||||
# edx-platform app --> DRF --> django translation --> edx-platform app
|
||||
from edx_proctoring.runtime import set_runtime_service
|
||||
from instructor.services import InstructorService
|
||||
from openedx.core.djangoapps.credit.services import CreditService
|
||||
set_runtime_service('credit', CreditService())
|
||||
set_runtime_service('instructor', InstructorService())
|
||||
|
||||
|
||||
@@ -124,4 +124,4 @@ def course_grading_policy(course_key):
|
||||
final grade.
|
||||
"""
|
||||
course = _retrieve_course(course_key)
|
||||
return GradingPolicySerializer(course.raw_grader).data
|
||||
return GradingPolicySerializer(course.raw_grader, many=True).data
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""
|
||||
API Serializers
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
@@ -11,23 +13,58 @@ class GradingPolicySerializer(serializers.Serializer):
|
||||
dropped = serializers.IntegerField(source='drop_count')
|
||||
weight = serializers.FloatField()
|
||||
|
||||
def to_representation(self, obj):
|
||||
"""
|
||||
Return a representation of the grading policy.
|
||||
"""
|
||||
# Backwards compatibility with the behavior of DRF v2.
|
||||
# When the grader dictionary was missing keys, DRF v2 would default to None;
|
||||
# DRF v3 unhelpfully raises an exception.
|
||||
return dict(
|
||||
super(GradingPolicySerializer, self).to_representation(
|
||||
defaultdict(lambda: None, obj)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
class BlockSerializer(serializers.Serializer):
|
||||
""" Serializer for course structure block. """
|
||||
id = serializers.CharField(source='usage_key')
|
||||
type = serializers.CharField(source='block_type')
|
||||
parent = serializers.CharField(source='parent')
|
||||
parent = serializers.CharField(required=False)
|
||||
display_name = serializers.CharField()
|
||||
graded = serializers.BooleanField(default=False)
|
||||
format = serializers.CharField()
|
||||
children = serializers.CharField()
|
||||
|
||||
def to_representation(self, obj):
|
||||
"""
|
||||
Return a representation of the block.
|
||||
|
||||
NOTE: this method maintains backwards compatibility with the behavior
|
||||
of Django Rest Framework v2.
|
||||
"""
|
||||
data = super(BlockSerializer, self).to_representation(obj)
|
||||
|
||||
# Backwards compatibility with the behavior of DRF v2
|
||||
# Include a NULL value for "parent" in the representation
|
||||
# (instead of excluding the key entirely)
|
||||
if obj.get("parent") is None:
|
||||
data["parent"] = None
|
||||
|
||||
# Backwards compatibility with the behavior of DRF v2
|
||||
# Leave the children list as a list instead of serializing
|
||||
# it to a string.
|
||||
data["children"] = obj["children"]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class CourseStructureSerializer(serializers.Serializer):
|
||||
""" Serializer for course structure. """
|
||||
root = serializers.CharField(source='root')
|
||||
blocks = serializers.SerializerMethodField('get_blocks')
|
||||
root = serializers.CharField()
|
||||
blocks = serializers.SerializerMethodField()
|
||||
|
||||
def get_blocks(self, structure):
|
||||
""" Serialize the individual blocks. """
|
||||
|
||||
@@ -2,12 +2,33 @@
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from opaque_keys.edx.keys import CourseKey
|
||||
from opaque_keys import InvalidKeyError
|
||||
from openedx.core.djangoapps.credit.models import CreditCourse
|
||||
|
||||
|
||||
class CourseKeyField(serializers.Field):
|
||||
"""
|
||||
Serializer field for a model CourseKey field.
|
||||
"""
|
||||
|
||||
def to_representation(self, data):
|
||||
"""Convert a course key to unicode. """
|
||||
return unicode(data)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
"""Convert unicode to a course key. """
|
||||
try:
|
||||
return CourseKey.from_string(data)
|
||||
except InvalidKeyError as ex:
|
||||
raise serializers.ValidationError("Invalid course key: {msg}".format(msg=ex.msg))
|
||||
|
||||
|
||||
class CreditCourseSerializer(serializers.ModelSerializer):
|
||||
""" CreditCourse Serializer """
|
||||
|
||||
course_key = CourseKeyField()
|
||||
|
||||
class Meta(object): # pylint: disable=missing-docstring
|
||||
model = CreditCourse
|
||||
exclude = ('id',)
|
||||
|
||||
@@ -393,10 +393,7 @@ class CreditCourseViewSetTests(TestCase):
|
||||
|
||||
# POSTs without a CSRF token should fail.
|
||||
response = client.post(self.path, data=json.dumps(data), content_type=JSON)
|
||||
|
||||
# NOTE (CCB): Ordinarily we would expect a 403; however, since the CSRF validation and session authentication
|
||||
# fail, DRF considers the request to be unauthenticated.
|
||||
self.assertEqual(response.status_code, 401)
|
||||
self.assertEqual(response.status_code, 403)
|
||||
self.assertIn('CSRF', response.content)
|
||||
|
||||
# Retrieve a CSRF token
|
||||
|
||||
@@ -18,8 +18,9 @@ from django.views.decorators.http import require_POST, require_GET
|
||||
from opaque_keys import InvalidKeyError
|
||||
from opaque_keys.edx.keys import CourseKey
|
||||
import pytz
|
||||
from rest_framework import viewsets, mixins, permissions, authentication
|
||||
|
||||
from rest_framework import viewsets, mixins, permissions
|
||||
from rest_framework.authentication import SessionAuthentication
|
||||
from rest_framework_oauth.authentication import OAuth2Authentication
|
||||
from util.json_request import JsonResponse
|
||||
from util.date_utils import from_timestamp
|
||||
from openedx.core.djangoapps.credit import api
|
||||
@@ -377,17 +378,28 @@ class CreditCourseViewSet(mixins.CreateModelMixin, mixins.UpdateModelMixin, view
|
||||
lookup_value_regex = settings.COURSE_KEY_REGEX
|
||||
queryset = CreditCourse.objects.all()
|
||||
serializer_class = CreditCourseSerializer
|
||||
authentication_classes = (authentication.OAuth2Authentication, authentication.SessionAuthentication,)
|
||||
authentication_classes = (OAuth2Authentication, SessionAuthentication,)
|
||||
permission_classes = (permissions.IsAuthenticated, permissions.IsAdminUser)
|
||||
|
||||
# In Django Rest Framework v3, there is a default pagination
|
||||
# class that transmutes the response data into a dictionary
|
||||
# with pagination information. The original response data (a list)
|
||||
# is stored in a "results" value of the dictionary.
|
||||
# For backwards compatibility with the existing API, we disable
|
||||
# the default behavior by setting the pagination_class to None.
|
||||
pagination_class = None
|
||||
|
||||
# This CSRF exemption only applies when authenticating without SessionAuthentication.
|
||||
# SessionAuthentication will enforce CSRF protection.
|
||||
@method_decorator(csrf_exempt)
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
# Convert the course ID/key from a string to an actual CourseKey object.
|
||||
course_id = kwargs.get(self.lookup_field, None)
|
||||
|
||||
if course_id:
|
||||
kwargs[self.lookup_field] = CourseKey.from_string(course_id)
|
||||
|
||||
return super(CreditCourseViewSet, self).dispatch(request, *args, **kwargs)
|
||||
|
||||
def get_object(self):
|
||||
# Convert the serialized course key into a CourseKey instance
|
||||
# so we can look up the object.
|
||||
course_key = self.kwargs.get(self.lookup_field)
|
||||
if course_key is not None:
|
||||
self.kwargs[self.lookup_field] = CourseKey.from_string(course_key)
|
||||
|
||||
return super(CreditCourseViewSet, self).get_object()
|
||||
|
||||
@@ -30,6 +30,40 @@ TEST_UPLOAD_DT = datetime.datetime(2002, 1, 9, 15, 43, 01, tzinfo=UTC)
|
||||
TEST_UPLOAD_DT2 = datetime.datetime(2003, 1, 9, 15, 43, 01, tzinfo=UTC)
|
||||
|
||||
|
||||
class PatchedClient(APIClient):
|
||||
"""
|
||||
Patch DRF's APIClient to avoid a unicode error on file upload.
|
||||
|
||||
Famous last words: This is a *temporary* fix that we should be
|
||||
able to remove once we upgrade Django past 1.4.
|
||||
"""
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
"""Construct an API request. """
|
||||
# DRF's default test client implementation uses `six.text_type()`
|
||||
# to convert the CONTENT_TYPE to `unicode`. In Django 1.4,
|
||||
# this causes a `UnicodeDecodeError` when Django parses a multipart
|
||||
# upload.
|
||||
#
|
||||
# This is the DRF code we're working around:
|
||||
# https://github.com/tomchristie/django-rest-framework/blob/3.1.3/rest_framework/compat.py#L227
|
||||
#
|
||||
# ... and this is the Django code that raises the exception:
|
||||
#
|
||||
# https://github.com/django/django/blob/1.4.22/django/http/multipartparser.py#L435
|
||||
#
|
||||
# Django unhelpfully swallows the exception, so to the application code
|
||||
# it appears as though the user didn't send any file data.
|
||||
#
|
||||
# This appears to be an issue only with requests constructed in the test
|
||||
# suite, not with the upload code used in production.
|
||||
#
|
||||
if isinstance(kwargs.get("CONTENT_TYPE"), basestring):
|
||||
kwargs["CONTENT_TYPE"] = str(kwargs["CONTENT_TYPE"])
|
||||
|
||||
return super(PatchedClient, self).request(*args, **kwargs)
|
||||
|
||||
|
||||
class ProfileImageEndpointTestCase(UserSettingsEventTestMixin, APITestCase):
|
||||
"""
|
||||
Base class / shared infrastructure for tests of profile_image "upload" and
|
||||
@@ -111,6 +145,10 @@ class ProfileImageUploadTestCase(ProfileImageEndpointTestCase):
|
||||
"""
|
||||
_view_name = "profile_image_upload"
|
||||
|
||||
# Use the patched version of the API client to workaround a unicode issue
|
||||
# with DRF 3.1 and Django 1.4. Remove this after we upgrade Django past 1.4!
|
||||
client_class = PatchedClient
|
||||
|
||||
def check_upload_event_emitted(self, old=None, new=TEST_UPLOAD_DT):
|
||||
"""
|
||||
Make sure we emit a UserProfile event corresponding to the
|
||||
|
||||
@@ -183,7 +183,7 @@ def update_account_settings(requesting_user, update, username=None):
|
||||
serializer.save()
|
||||
|
||||
if "language_proficiencies" in update:
|
||||
new_language_proficiencies = legacy_profile_serializer.data["language_proficiencies"]
|
||||
new_language_proficiencies = update["language_proficiencies"]
|
||||
emit_setting_changed_event(
|
||||
user=existing_user,
|
||||
db_table=existing_user_profile.language_proficiencies.model._meta.db_table,
|
||||
|
||||
@@ -53,7 +53,7 @@ class UserReadOnlySerializer(serializers.Serializer):
|
||||
|
||||
super(UserReadOnlySerializer, self).__init__(*args, **kwargs)
|
||||
|
||||
def to_native(self, user):
|
||||
def to_representation(self, user):
|
||||
"""
|
||||
Overwrite to_native to handle custom logic since we are serializing two models as one here
|
||||
:param user: User object
|
||||
@@ -152,8 +152,8 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea
|
||||
Class that serializes the portion of UserProfile model needed for account information.
|
||||
"""
|
||||
profile_image = serializers.SerializerMethodField("_get_profile_image")
|
||||
requires_parental_consent = serializers.SerializerMethodField("get_requires_parental_consent")
|
||||
language_proficiencies = LanguageProficiencySerializer(many=True, allow_add_remove=True, required=False)
|
||||
requires_parental_consent = serializers.SerializerMethodField()
|
||||
language_proficiencies = LanguageProficiencySerializer(many=True, required=False)
|
||||
|
||||
class Meta(object): # pylint: disable=missing-docstring
|
||||
model = UserProfile
|
||||
@@ -165,25 +165,21 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea
|
||||
read_only_fields = ()
|
||||
explicit_read_only_fields = ("profile_image", "requires_parental_consent")
|
||||
|
||||
def validate_name(self, attrs, source):
|
||||
def validate_name(self, new_name):
|
||||
""" Enforce minimum length for name. """
|
||||
if source in attrs:
|
||||
new_name = attrs[source].strip()
|
||||
if len(new_name) < NAME_MIN_LENGTH:
|
||||
raise serializers.ValidationError(
|
||||
"The name field must be at least {} characters long.".format(NAME_MIN_LENGTH)
|
||||
)
|
||||
attrs[source] = new_name
|
||||
if len(new_name) < NAME_MIN_LENGTH:
|
||||
raise serializers.ValidationError(
|
||||
"The name field must be at least {} characters long.".format(NAME_MIN_LENGTH)
|
||||
)
|
||||
return new_name
|
||||
|
||||
return attrs
|
||||
|
||||
def validate_language_proficiencies(self, attrs, source):
|
||||
def validate_language_proficiencies(self, value):
|
||||
""" Enforce all languages are unique. """
|
||||
language_proficiencies = [language for language in attrs.get(source, [])]
|
||||
unique_language_proficiencies = set(language.code for language in language_proficiencies)
|
||||
language_proficiencies = [language for language in value]
|
||||
unique_language_proficiencies = set(language["code"] for language in language_proficiencies)
|
||||
if len(language_proficiencies) != len(unique_language_proficiencies):
|
||||
raise serializers.ValidationError("The language_proficiencies field must consist of unique languages")
|
||||
return attrs
|
||||
return value
|
||||
|
||||
def transform_gender(self, user_profile, value):
|
||||
""" Converts empty string to None, to indicate not set. Replaced by to_representation in version 3. """
|
||||
@@ -230,3 +226,29 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea
|
||||
call the method with a single argument, the user_profile object.
|
||||
"""
|
||||
return AccountLegacyProfileSerializer.get_profile_image(user_profile, user_profile.user)
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
"""
|
||||
Update the profile, including nested fields.
|
||||
"""
|
||||
language_proficiencies = validated_data.pop("language_proficiencies", None)
|
||||
|
||||
# Update all fields on the user profile that are writeable,
|
||||
# except for "language_proficiencies", which we'll update separately
|
||||
update_fields = set(self.get_writeable_fields()) - set(["language_proficiencies"])
|
||||
for field_name in update_fields:
|
||||
default = getattr(instance, field_name)
|
||||
field_value = validated_data.get(field_name, default)
|
||||
setattr(instance, field_name, field_value)
|
||||
|
||||
instance.save()
|
||||
|
||||
# Now update the related language proficiency
|
||||
if language_proficiencies is not None:
|
||||
instance.language_proficiencies.all().delete()
|
||||
instance.language_proficiencies.bulk_create([
|
||||
LanguageProficiency(user_profile=instance, code=language["code"])
|
||||
for language in language_proficiencies
|
||||
])
|
||||
|
||||
return instance
|
||||
|
||||
@@ -164,7 +164,10 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
|
||||
field_errors = context_manager.exception.field_errors
|
||||
self.assertEqual(3, len(field_errors))
|
||||
self.assertEqual("This field is not editable via this API", field_errors["username"]["developer_message"])
|
||||
self.assertIn("Select a valid choice", field_errors["gender"]["developer_message"])
|
||||
self.assertIn(
|
||||
"Value \'undecided\' is not valid for field \'gender\'",
|
||||
field_errors["gender"]["developer_message"]
|
||||
)
|
||||
self.assertIn("Valid e-mail address required.", field_errors["email"]["developer_message"])
|
||||
|
||||
@patch('django.core.mail.send_mail')
|
||||
|
||||
@@ -359,16 +359,19 @@ class TestAccountAPI(UserAPITestCase):
|
||||
self.assertEqual(404, response.status_code)
|
||||
|
||||
@ddt.data(
|
||||
("gender", "f", "not a gender", u"Select a valid choice. not a gender is not one of the available choices."),
|
||||
("level_of_education", "none", u"ȻħȺɍłɇs", u"Select a valid choice. ȻħȺɍłɇs is not one of the available choices."),
|
||||
("country", "GB", "XY", u"Select a valid choice. XY is not one of the available choices."),
|
||||
("year_of_birth", 2009, "not_an_int", u"Enter a whole number."),
|
||||
("name", "bob", "z" * 256, u"Ensure this value has at most 255 characters (it has 256)."),
|
||||
("gender", "f", "not a gender", u'"not a gender" is not a valid choice.'),
|
||||
("level_of_education", "none", u"ȻħȺɍłɇs", u'"ȻħȺɍłɇs" is not a valid choice.'),
|
||||
("country", "GB", "XY", u'"XY" is not a valid choice.'),
|
||||
("year_of_birth", 2009, "not_an_int", u"A valid integer is required."),
|
||||
("name", "bob", "z" * 256, u"Ensure this field has no more than 255 characters."),
|
||||
("name", u"ȻħȺɍłɇs", "z ", u"The name field must be at least 2 characters long."),
|
||||
("goals", "Smell the roses"),
|
||||
("mailing_address", "Sesame Street"),
|
||||
# Note that we store the raw data, so it is up to client to escape the HTML.
|
||||
("bio", u"<html>Lacrosse-playing superhero 壓是進界推日不復女</html>", "z" * 3001, u"Ensure this value has at most 3000 characters (it has 3001)."),
|
||||
(
|
||||
"bio", u"<html>Lacrosse-playing superhero 壓是進界推日不復女</html>",
|
||||
"z" * 3001, u"Ensure this field has no more than 3000 characters."
|
||||
),
|
||||
# Note that email is tested below, as it is not immediately updated.
|
||||
# Note that language_proficiencies is tested below as there are multiple error and success conditions.
|
||||
)
|
||||
@@ -568,10 +571,10 @@ class TestAccountAPI(UserAPITestCase):
|
||||
self.assertItemsEqual(response.data["language_proficiencies"], proficiencies)
|
||||
|
||||
@ddt.data(
|
||||
(u"not_a_list", [{u'non_field_errors': [u'Expected a list of items.']}]),
|
||||
([u"not_a_JSON_object"], [{u'non_field_errors': [u'Invalid data']}]),
|
||||
(u"not_a_list", {u'non_field_errors': [u'Expected a list of items but got type "unicode".']}),
|
||||
([u"not_a_JSON_object"], [{u'non_field_errors': [u'Invalid data. Expected a dictionary, but got unicode.']}]),
|
||||
([{}], [{"code": [u"This field is required."]}]),
|
||||
([{u"code": u"invalid_language_code"}], [{'code': [u'Select a valid choice. invalid_language_code is not one of the available choices.']}]),
|
||||
([{u"code": u"invalid_language_code"}], [{'code': [u'"invalid_language_code" is not a valid choice.']}]),
|
||||
([{u"code": u"kw"}, {u"code": u"el"}, {u"code": u"kw"}], [u'The language_proficiencies field must consist of unique languages']),
|
||||
)
|
||||
@ddt.unpack
|
||||
|
||||
@@ -9,9 +9,10 @@ from django.conf import settings
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import IntegrityError
|
||||
from django.utils.translation import ugettext as _
|
||||
from student.models import User, UserProfile
|
||||
from django.utils.translation import ugettext_noop
|
||||
|
||||
from student.models import User, UserProfile
|
||||
from request_cache import get_request_or_stub
|
||||
from ..errors import (
|
||||
UserAPIInternalError, UserAPIRequestError, UserNotFound, UserNotAuthorized,
|
||||
PreferenceValidationError, PreferenceUpdateError
|
||||
@@ -68,7 +69,17 @@ def get_user_preferences(requesting_user, username=None):
|
||||
UserAPIInternalError: the operation failed due to an unexpected error.
|
||||
"""
|
||||
existing_user = _get_user(requesting_user, username, allow_staff=True)
|
||||
user_serializer = UserSerializer(existing_user)
|
||||
|
||||
# Django Rest Framework V3 uses the current request to version
|
||||
# hyperlinked URLS, so we need to retrieve the request and pass
|
||||
# it in the serializer's context (otherwise we get an AssertionError).
|
||||
# We're retrieving the request from the cache rather than passing it in
|
||||
# as an argument because this is an implementation detail of how we're
|
||||
# serializing data, which we want to encapsulate in the API call.
|
||||
context = {
|
||||
"request": get_request_or_stub()
|
||||
}
|
||||
user_serializer = UserSerializer(existing_user, context=context)
|
||||
return user_serializer.data["preferences"]
|
||||
|
||||
|
||||
@@ -356,7 +367,7 @@ def validate_user_preference_serializer(serializer, preference_key, preference_v
|
||||
developer_message = u"Value '{preference_value}' not valid for preference '{preference_key}': {error}".format(
|
||||
preference_key=preference_key, preference_value=preference_value, error=serializer.errors
|
||||
)
|
||||
if serializer.errors["key"]:
|
||||
if "key" in serializer.errors:
|
||||
user_message = _(u"Invalid user preference key '{preference_key}'.").format(
|
||||
preference_key=preference_key
|
||||
)
|
||||
|
||||
@@ -403,7 +403,7 @@ def get_expected_validation_developer_message(preference_key, preference_value):
|
||||
preference_key=preference_key,
|
||||
preference_value=preference_value,
|
||||
error={
|
||||
"key": [u"Ensure this value has at most 255 characters (it has 256)."]
|
||||
"key": [u"Ensure this field has no more than 255 characters."]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ from .models import UserPreference
|
||||
|
||||
|
||||
class UserSerializer(serializers.HyperlinkedModelSerializer):
|
||||
name = serializers.SerializerMethodField("get_name")
|
||||
preferences = serializers.SerializerMethodField("get_preferences")
|
||||
name = serializers.SerializerMethodField()
|
||||
preferences = serializers.SerializerMethodField()
|
||||
|
||||
def get_name(self, user):
|
||||
profile = UserProfile.objects.get(user=user)
|
||||
@@ -32,9 +32,10 @@ class UserPreferenceSerializer(serializers.HyperlinkedModelSerializer):
|
||||
|
||||
|
||||
class RawUserPreferenceSerializer(serializers.ModelSerializer):
|
||||
"""Serializer that generates a raw representation of a user preference.
|
||||
"""
|
||||
user = serializers.PrimaryKeyRelatedField()
|
||||
Serializer that generates a raw representation of a user preference.
|
||||
"""
|
||||
user = serializers.PrimaryKeyRelatedField(queryset=User.objects.all())
|
||||
|
||||
class Meta(object): # pylint: disable=missing-docstring
|
||||
model = UserPreference
|
||||
@@ -57,3 +58,11 @@ class ReadOnlyFieldsSerializerMixin(object):
|
||||
cls.Meta.read_only_fields tuple.
|
||||
"""
|
||||
return getattr(cls.Meta, 'read_only_fields', '') + getattr(cls.Meta, 'explicit_read_only_fields', '')
|
||||
|
||||
@classmethod
|
||||
def get_writeable_fields(cls):
|
||||
"""
|
||||
Return all fields on this serializer that are writeable.
|
||||
"""
|
||||
all_fields = getattr(cls.Meta, 'fields', tuple())
|
||||
return tuple(set(all_fields) - set(cls.get_read_only_fields()))
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
""" Common Authentication Handlers used across projects. """
|
||||
from rest_framework import authentication
|
||||
from rest_framework.authentication import SessionAuthentication
|
||||
from rest_framework_oauth.authentication import OAuth2Authentication
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
from rest_framework.compat import oauth2_provider, provider_now
|
||||
from rest_framework_oauth.compat import oauth2_provider, provider_now
|
||||
|
||||
|
||||
class SessionAuthenticationAllowInactiveUser(authentication.SessionAuthentication):
|
||||
class SessionAuthenticationAllowInactiveUser(SessionAuthentication):
|
||||
"""Ensure that the user is logged in, but do not require the account to be active.
|
||||
|
||||
We use this in the special case that a user has created an account,
|
||||
@@ -51,7 +52,7 @@ class SessionAuthenticationAllowInactiveUser(authentication.SessionAuthenticatio
|
||||
return (user, None)
|
||||
|
||||
|
||||
class OAuth2AuthenticationAllowInactiveUser(authentication.OAuth2Authentication):
|
||||
class OAuth2AuthenticationAllowInactiveUser(OAuth2Authentication):
|
||||
"""
|
||||
This is a temporary workaround while the is_active field on the user is coupled
|
||||
with whether or not the user has verified ownership of their claimed email address.
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Fields useful for edX API implementations."""
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
from rest_framework.serializers import CharField, Field
|
||||
from rest_framework.serializers import Field
|
||||
|
||||
|
||||
class ExpandableField(Field):
|
||||
@@ -18,25 +16,19 @@ class ExpandableField(Field):
|
||||
self.expanded = kwargs.pop('expanded_serializer')
|
||||
super(ExpandableField, self).__init__(**kwargs)
|
||||
|
||||
def field_to_native(self, obj, field_name):
|
||||
"""Converts obj to a native representation, using the expanded serializer if the context requires it."""
|
||||
if 'expand' in self.context and field_name in self.context['expand']:
|
||||
self.expanded.initialize(self, field_name)
|
||||
return self.expanded.field_to_native(obj, field_name)
|
||||
else:
|
||||
self.collapsed.initialize(self, field_name)
|
||||
return self.collapsed.field_to_native(obj, field_name)
|
||||
def to_representation(self, obj):
|
||||
"""
|
||||
Return a representation of the field that is either expanded or collapsed.
|
||||
"""
|
||||
should_expand = self.field_name in self.context.get("expand", [])
|
||||
field = self.expanded if should_expand else self.collapsed
|
||||
|
||||
# Avoid double-binding the field, otherwise we'll get
|
||||
# an error about the source kwarg being redundant.
|
||||
if field.source is None:
|
||||
field.bind(self.field_name, self)
|
||||
|
||||
class NonEmptyCharField(CharField):
|
||||
"""
|
||||
A field that enforces non-emptiness even for partial updates.
|
||||
if should_expand:
|
||||
self.expanded.context["expand"] = set(field.context.get("expand", []))
|
||||
|
||||
This is necessary because prior to version 3, DRF skips validation for empty
|
||||
values. Thus, CharField's min_length and RegexField cannot be used to
|
||||
enforce this constraint.
|
||||
"""
|
||||
def validate(self, value):
|
||||
super(NonEmptyCharField, self).validate(value)
|
||||
if not value.strip():
|
||||
raise ValidationError(self.error_messages["required"])
|
||||
return field.to_representation(obj)
|
||||
|
||||
33
openedx/core/lib/api/mixins.py
Normal file
33
openedx/core/lib/api/mixins.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
Django Rest Framework view mixins.
|
||||
"""
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.http import Http404
|
||||
from rest_framework import status
|
||||
from rest_framework.mixins import CreateModelMixin
|
||||
from rest_framework.response import Response
|
||||
|
||||
|
||||
class PutAsCreateMixin(CreateModelMixin):
|
||||
"""
|
||||
Backwards compatibility with Django Rest Framework v2, which allowed
|
||||
creation of a new resource using PUT.
|
||||
"""
|
||||
|
||||
def update(self, request, *args, **kwargs):
|
||||
"""
|
||||
Create/update course modes for a course.
|
||||
"""
|
||||
# First, try to update the existing instance
|
||||
try:
|
||||
try:
|
||||
return super(PutAsCreateMixin, self).update(request, *args, **kwargs)
|
||||
except Http404:
|
||||
# If no instance exists yet, create it.
|
||||
# This is backwards-compatible with the behavior of DRF v2.
|
||||
return super(PutAsCreateMixin, self).create(request, *args, **kwargs)
|
||||
|
||||
# Backwards compatibility with DRF v2 behavior, which would catch model-level
|
||||
# validation errors and return a 400
|
||||
except ValidationError as err:
|
||||
return Response(err.messages, status=status.HTTP_400_BAD_REQUEST)
|
||||
@@ -3,6 +3,31 @@
|
||||
from django.http import Http404
|
||||
from django.core.paginator import Paginator, InvalidPage
|
||||
|
||||
from rest_framework.response import Response
|
||||
from rest_framework import pagination
|
||||
|
||||
|
||||
class DefaultPagination(pagination.PageNumberPagination):
|
||||
"""
|
||||
Default paginator for APIs in edx-platform.
|
||||
|
||||
This is configured in settings to be automatically used
|
||||
by any subclass of Django Rest Framework's generic API views.
|
||||
"""
|
||||
page_size_query_param = "page_size"
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
"""
|
||||
Annotate the response with pagination information.
|
||||
"""
|
||||
return Response({
|
||||
'next': self.get_next_link(),
|
||||
'previous': self.get_previous_link(),
|
||||
'count': self.page.paginator.count,
|
||||
'num_pages': self.page.paginator.num_pages,
|
||||
'results': data
|
||||
})
|
||||
|
||||
|
||||
def paginate_search_results(object_class, search_results, page_size, page):
|
||||
"""
|
||||
|
||||
@@ -1,32 +1,8 @@
|
||||
from rest_framework import pagination, serializers
|
||||
"""
|
||||
Serializers to be used in APIs.
|
||||
"""
|
||||
|
||||
|
||||
class PaginationSerializer(pagination.PaginationSerializer):
|
||||
"""
|
||||
Custom PaginationSerializer for openedx.
|
||||
|
||||
Adds the following fields:
|
||||
- num_pages: total number of pages
|
||||
- current_page: the current page being returned
|
||||
- start: the index of the first page item within the overall collection
|
||||
"""
|
||||
start_page = 1 # django Paginator.page objects have 1-based indexes
|
||||
num_pages = serializers.Field(source='paginator.num_pages')
|
||||
current_page = serializers.SerializerMethodField('get_current_page')
|
||||
start = serializers.SerializerMethodField('get_start')
|
||||
sort_order = serializers.SerializerMethodField('get_sort_order')
|
||||
|
||||
def get_current_page(self, page):
|
||||
"""Get the current page"""
|
||||
return page.number
|
||||
|
||||
def get_start(self, page):
|
||||
"""Get the index of the first page item within the overall collection"""
|
||||
return (self.get_current_page(page) - self.start_page) * page.paginator.per_page
|
||||
|
||||
def get_sort_order(self, page): # pylint: disable=unused-argument
|
||||
"""Get the order by which this collection was sorted"""
|
||||
return self.context.get('sort_order')
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
class CollapsedReferenceSerializer(serializers.HyperlinkedModelSerializer):
|
||||
@@ -54,9 +30,10 @@ class CollapsedReferenceSerializer(serializers.HyperlinkedModelSerializer):
|
||||
|
||||
super(CollapsedReferenceSerializer, self).__init__(*args, **kwargs)
|
||||
|
||||
self.fields[id_source] = serializers.CharField(read_only=True, source=id_source)
|
||||
self.fields[id_source] = serializers.CharField(read_only=True)
|
||||
self.fields['url'].view_name = view_name
|
||||
self.fields['url'].lookup_field = lookup_field
|
||||
self.fields['url'].lookup_url_kwarg = lookup_field
|
||||
|
||||
class Meta(object):
|
||||
"""Defines meta information for the ModelSerializer.
|
||||
|
||||
@@ -1,63 +1,235 @@
|
||||
"""Tests for util.authentication module."""
|
||||
"""
|
||||
Tests for OAuth2. This module is copied from django-rest-framework-oauth (tests/test_authentication.py)
|
||||
and updated to use our subclass of OAuth2Authentication.
|
||||
"""
|
||||
|
||||
from __future__ import unicode_literals
|
||||
import datetime
|
||||
|
||||
from django.conf.urls import patterns, url, include
|
||||
from django.contrib.auth.models import User
|
||||
from django.http import HttpResponse
|
||||
from django.test import TestCase
|
||||
from django.utils import unittest
|
||||
from django.utils.http import urlencode
|
||||
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework_oauth import permissions
|
||||
from rest_framework_oauth.compat import oauth2_provider, oauth2_provider_scope
|
||||
from rest_framework.test import APIRequestFactory, APIClient
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from mock import patch
|
||||
from django.conf import settings
|
||||
from rest_framework import permissions
|
||||
from rest_framework.compat import patterns, url
|
||||
from rest_framework.tests import test_authentication
|
||||
from provider import scope, constants
|
||||
from unittest import skipUnless
|
||||
|
||||
from ..authentication import OAuth2AuthenticationAllowInactiveUser
|
||||
|
||||
factory = APIRequestFactory() # pylint: disable=invalid-name
|
||||
|
||||
class OAuth2AuthAllowInactiveUserDebug(OAuth2AuthenticationAllowInactiveUser):
|
||||
"""
|
||||
A debug class analogous to the OAuth2AuthenticationDebug class that tests
|
||||
the OAuth2 flow with the access token sent in a query param."""
|
||||
|
||||
class MockView(APIView): # pylint: disable=missing-docstring
|
||||
permission_classes = (IsAuthenticated,)
|
||||
|
||||
def get(self, request): # pylint: disable=missing-docstring,unused-argument
|
||||
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
|
||||
|
||||
def post(self, request): # pylint: disable=missing-docstring,unused-argument
|
||||
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
|
||||
|
||||
def put(self, request): # pylint: disable=missing-docstring,unused-argument
|
||||
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
|
||||
|
||||
|
||||
# This is the a change we've made from the django-rest-framework-oauth version
|
||||
# of these tests. We're subclassing our custom OAuth2AuthenticationAllowInactiveUser
|
||||
# instead of OAuth2Authentication.
|
||||
class OAuth2AuthenticationDebug(OAuth2AuthenticationAllowInactiveUser): # pylint: disable=missing-docstring
|
||||
allow_query_params_token = True
|
||||
|
||||
|
||||
# The following patch overrides the URL patterns for the MockView class used in
|
||||
# rest_framework.tests.test_authentication so that the corresponding AllowInactiveUser
|
||||
# classes are tested instead.
|
||||
@skipUnless(settings.FEATURES.get('ENABLE_OAUTH2_PROVIDER'), 'OAuth2 not enabled')
|
||||
@patch.object(
|
||||
test_authentication,
|
||||
'urlpatterns',
|
||||
patterns(
|
||||
'',
|
||||
url(
|
||||
r'^oauth2-test/$',
|
||||
test_authentication.MockView.as_view(authentication_classes=[OAuth2AuthenticationAllowInactiveUser])
|
||||
),
|
||||
url(
|
||||
r'^oauth2-test-debug/$',
|
||||
test_authentication.MockView.as_view(authentication_classes=[OAuth2AuthAllowInactiveUserDebug])
|
||||
),
|
||||
url(
|
||||
r'^oauth2-with-scope-test/$',
|
||||
test_authentication.MockView.as_view(
|
||||
authentication_classes=[OAuth2AuthenticationAllowInactiveUser],
|
||||
permission_classes=[permissions.TokenHasReadWriteScope]
|
||||
)
|
||||
urlpatterns = patterns(
|
||||
'',
|
||||
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
|
||||
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationAllowInactiveUser])),
|
||||
url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
|
||||
url(
|
||||
r'^oauth2-with-scope-test/$',
|
||||
MockView.as_view(
|
||||
authentication_classes=[OAuth2AuthenticationAllowInactiveUser],
|
||||
permission_classes=[permissions.TokenHasReadWriteScope]
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
class OAuth2AuthenticationAllowInactiveUserTestCase(test_authentication.OAuth2Tests):
|
||||
"""
|
||||
Tests the OAuth2AuthenticationAllowInactiveUser class by running all the existing tests in
|
||||
OAuth2Tests but with the is_active flag on the user set to False.
|
||||
"""
|
||||
def setUp(self):
|
||||
super(OAuth2AuthenticationAllowInactiveUserTestCase, self).setUp()
|
||||
|
||||
# set the user's is_active flag to False.
|
||||
|
||||
class OAuth2Tests(TestCase):
|
||||
"""OAuth 2.0 authentication"""
|
||||
urls = 'openedx.core.lib.api.tests.test_authentication'
|
||||
|
||||
def setUp(self):
|
||||
self.csrf_client = APIClient(enforce_csrf_checks=True)
|
||||
self.username = 'john'
|
||||
self.email = 'lennon@thebeatles.com'
|
||||
self.password = 'password'
|
||||
self.user = User.objects.create_user(self.username, self.email, self.password)
|
||||
|
||||
self.CLIENT_ID = 'client_key' # pylint: disable=invalid-name
|
||||
self.CLIENT_SECRET = 'client_secret' # pylint: disable=invalid-name
|
||||
self.ACCESS_TOKEN = "access_token" # pylint: disable=invalid-name
|
||||
self.REFRESH_TOKEN = "refresh_token" # pylint: disable=invalid-name
|
||||
|
||||
self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create(
|
||||
client_id=self.CLIENT_ID,
|
||||
client_secret=self.CLIENT_SECRET,
|
||||
redirect_uri='',
|
||||
client_type=0,
|
||||
name='example',
|
||||
user=None,
|
||||
)
|
||||
|
||||
self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create(
|
||||
token=self.ACCESS_TOKEN,
|
||||
client=self.oauth2_client,
|
||||
user=self.user,
|
||||
)
|
||||
self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create(
|
||||
user=self.user,
|
||||
access_token=self.access_token,
|
||||
client=self.oauth2_client
|
||||
)
|
||||
|
||||
# This is the a change we've made from the django-rest-framework-oauth version
|
||||
# of these tests.
|
||||
self.user.is_active = False
|
||||
self.user.save()
|
||||
|
||||
# This is the a change we've made from the django-rest-framework-oauth version
|
||||
# of these tests.
|
||||
# Override the SCOPE_NAME_DICT setting for tests for oauth2-with-scope-test. This is
|
||||
# needed to support READ and WRITE scopes as they currently aren't supported by the
|
||||
# edx-auth2-provider, and their scope values collide with other scopes defined in the
|
||||
# edx-auth2-provider.
|
||||
scope.SCOPE_NAME_DICT = {'read': constants.READ, 'write': constants.WRITE}
|
||||
|
||||
def _create_authorization_header(self, token=None): # pylint: disable=missing-docstring
|
||||
return "Bearer {0}".format(token or self.access_token.token)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_get_form_with_wrong_authorization_header_token_type_failing(self):
|
||||
"""Ensure that a wrong token type lead to the correct HTTP error status code"""
|
||||
auth = "Wrong token-type-obviously"
|
||||
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, 401)
|
||||
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, 401)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_get_form_with_wrong_authorization_header_token_format_failing(self):
|
||||
"""Ensure that a wrong token format lead to the correct HTTP error status code"""
|
||||
auth = "Bearer wrong token format"
|
||||
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, 401)
|
||||
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, 401)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_get_form_with_wrong_authorization_header_token_failing(self):
|
||||
"""Ensure that a wrong token lead to the correct HTTP error status code"""
|
||||
auth = "Bearer wrong-token"
|
||||
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, 401)
|
||||
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, 401)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_get_form_with_wrong_authorization_header_token_missing(self):
|
||||
"""Ensure that a missing token lead to the correct HTTP error status code"""
|
||||
auth = "Bearer"
|
||||
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, 401)
|
||||
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, 401)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_get_form_passing_auth(self):
|
||||
"""Ensure GETing form over OAuth with correct client credentials succeed"""
|
||||
auth = self._create_authorization_header()
|
||||
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_post_form_passing_auth_url_transport(self):
|
||||
"""Ensure GETing form over OAuth with correct client credentials in form data succeed"""
|
||||
response = self.csrf_client.post(
|
||||
'/oauth2-test/',
|
||||
data={'access_token': self.access_token.token}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_get_form_passing_auth_url_transport(self):
|
||||
"""Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True"""
|
||||
query = urlencode({'access_token': self.access_token.token})
|
||||
response = self.csrf_client.get('/oauth2-test-debug/?%s' % query)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_get_form_failing_auth_url_transport(self):
|
||||
"""Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False"""
|
||||
query = urlencode({'access_token': self.access_token.token})
|
||||
response = self.csrf_client.get('/oauth2-test/?%s' % query)
|
||||
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_post_form_passing_auth(self):
|
||||
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
|
||||
auth = self._create_authorization_header()
|
||||
response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_post_form_token_removed_failing_auth(self):
|
||||
"""Ensure POSTing when there is no OAuth access token in db fails"""
|
||||
self.access_token.delete()
|
||||
auth = self._create_authorization_header()
|
||||
response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_post_form_with_refresh_token_failing_auth(self):
|
||||
"""Ensure POSTing with refresh token instead of access token fails"""
|
||||
auth = self._create_authorization_header(token=self.refresh_token.token)
|
||||
response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_post_form_with_expired_access_token_failing_auth(self):
|
||||
"""Ensure POSTing with expired access token fails with an 'Invalid token' error"""
|
||||
self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
|
||||
self.access_token.save()
|
||||
auth = self._create_authorization_header()
|
||||
response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
|
||||
self.assertIn('Invalid token', response.content)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_post_form_with_invalid_scope_failing_auth(self):
|
||||
"""Ensure POSTing with a readonly scope instead of a write scope fails"""
|
||||
read_only_access_token = self.access_token
|
||||
read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
|
||||
read_only_access_token.save()
|
||||
auth = self._create_authorization_header(token=read_only_access_token.token)
|
||||
response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_post_form_with_valid_scope_passing_auth(self):
|
||||
"""Ensure POSTing with a write scope succeed"""
|
||||
read_write_access_token = self.access_token
|
||||
read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
|
||||
read_write_access_token.save()
|
||||
auth = self._create_authorization_header(token=read_write_access_token.token)
|
||||
response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@@ -9,6 +9,7 @@ from django.utils.translation import ugettext as _
|
||||
from rest_framework import status, response
|
||||
from rest_framework.exceptions import APIException
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.mixins import RetrieveModelMixin, UpdateModelMixin
|
||||
from rest_framework.generics import GenericAPIView
|
||||
@@ -193,3 +194,23 @@ class RetrievePatchAPIView(RetrieveModelMixin, UpdateModelMixin, GenericAPIView)
|
||||
add_serializer_errors(serializer, patch, field_errors)
|
||||
|
||||
return field_errors
|
||||
|
||||
def get_object_or_none(self):
|
||||
"""
|
||||
Retrieve an object or return None if the object can't be found.
|
||||
|
||||
NOTE: This replaces functionality that was removed in Django Rest Framework v3.1.
|
||||
"""
|
||||
try:
|
||||
return self.get_object()
|
||||
except Http404:
|
||||
if self.request.method == 'PUT':
|
||||
# For PUT-as-create operation, we need to ensure that we have
|
||||
# relevant permissions, as if this was a POST request. This
|
||||
# will either raise a PermissionDenied exception, or simply
|
||||
# return None.
|
||||
self.check_permissions(clone_request(self.request, 'POST'))
|
||||
else:
|
||||
# PATCH requests where the object does not exist should still
|
||||
# return a 404 response.
|
||||
raise
|
||||
|
||||
@@ -28,7 +28,7 @@ django-ses==0.7.0
|
||||
django-simple-history==1.6.3
|
||||
django-storages==1.1.5
|
||||
django-method-override==0.1.0
|
||||
djangorestframework==2.3.14
|
||||
djangorestframework>=3.1,<3.2
|
||||
django==1.4.22
|
||||
elasticsearch==0.4.5
|
||||
facebook-sdk==0.4.0
|
||||
|
||||
@@ -12,6 +12,7 @@ git+https://github.com/edx/django-staticfiles.git@031bdeaea85798b8c284e2a09977df
|
||||
-e git+https://github.com/edx/django-pipeline.git@88ec8a011e481918fdc9d2682d4017c835acd8be#egg=django-pipeline
|
||||
-e git+https://github.com/edx/django-wiki.git@cd0b2b31997afccde519fe5b3365e61a9edb143f#egg=django-wiki
|
||||
-e git+https://github.com/edx/django-oauth2-provider.git@0.2.7-fork-edx-5#egg=django-oauth2-provider
|
||||
-e git+https://github.com/edx/django-rest-framework-oauth.git@f0b503fda8c254a38f97fef802ded4f5fe367f7a#egg=djangorestframework-oauth
|
||||
-e git+https://github.com/edx/MongoDBProxy.git@25b99097615bda06bd7cdfe5669ed80dc2a7fed0#egg=mongodb_proxy
|
||||
git+https://github.com/edx/nltk.git@2.0.6#egg=nltk==2.0.6
|
||||
-e git+https://github.com/dementrock/pystache_custom.git@776973740bdaad83a3b029f96e415a7d1e8bec2f#egg=pystache_custom-dev
|
||||
@@ -40,13 +41,13 @@ git+https://github.com/edx/rfc6266.git@v0.0.5-edx#egg=rfc6266==0.0.5-edx
|
||||
-e git+https://github.com/edx/event-tracking.git@0.2.0#egg=event-tracking
|
||||
-e git+https://github.com/edx-solutions/django-splash.git@7579d052afcf474ece1239153cffe1c89935bc4f#egg=django-splash
|
||||
-e git+https://github.com/edx/acid-block.git@e46f9cda8a03e121a00c7e347084d142d22ebfb7#egg=acid-xblock
|
||||
-e git+https://github.com/edx/edx-ora2.git@release-2015-08-25T16.16#egg=edx-ora2
|
||||
-e git+https://github.com/edx/edx-submissions.git@9538ee8a971d04dc1cb05e88f6aa0c36b224455c#egg=edx-submissions
|
||||
-e git+https://github.com/edx/edx-ora2.git@release-2015-09-16T15.28#egg=edx-ora2
|
||||
-e git+https://github.com/edx/edx-submissions.git@0.1.0#egg=edx-submissions
|
||||
-e git+https://github.com/edx/opaque-keys.git@27dc382ea587483b1e3889a3d19cbd90b9023a06#egg=opaque-keys
|
||||
git+https://github.com/edx/ease.git@release-2015-07-14#egg=ease==0.1.3
|
||||
git+https://github.com/edx/i18n-tools.git@v0.1.3#egg=i18n-tools==v0.1.3
|
||||
git+https://github.com/edx/edx-oauth2-provider.git@0.5.7#egg=oauth2-provider==0.5.7
|
||||
-e git+https://github.com/edx/edx-val.git@v0.0.5#egg=edx-val
|
||||
-e git+https://github.com/edx/edx-val.git@0.0.6#egg=edx-val
|
||||
-e git+https://github.com/pmitros/RecommenderXBlock.git@518234bc354edbfc2651b9e534ddb54f96080779#egg=recommender-xblock
|
||||
-e git+https://github.com/edx/edx-search.git@release-2015-09-11a#egg=edx-search
|
||||
-e git+https://github.com/edx/edx-milestones.git@9b44a37edc3d63a23823c21a63cdd53ef47a7aa4#egg=edx-milestones
|
||||
|
||||
Reference in New Issue
Block a user