Add get_queryset method to CourseModesView; cleanup tests again

This commit is contained in:
Alex Dusenbery
2019-04-05 12:05:13 -04:00
committed by Alex Dusenbery
parent 4c6ff94b54
commit d134a435d5
2 changed files with 51 additions and 13 deletions

View File

@@ -48,17 +48,23 @@ class CourseModesViewTestBase(AuthAndScopesTestMixin):
mode_display_name='Verified',
min_price=25,
)
# use these to make sure we don't fetch data for other courses
cls.other_course_key = CourseKey.from_string('course-v1:edX+DemoX+Other_Course')
cls.other_course = CourseOverviewFactory.create(id=cls.other_course_key)
cls.other_mode = CourseModeFactory.create(
course_id=cls.other_course_key,
mode_slug='other-audit',
mode_display_name='Other Audit',
min_price=0,
)
@classmethod
def tearDownClass(cls):
cls.course.delete()
cls.audit_mode.delete()
cls.verified_mode.delete()
@classmethod
def tearDownClass(cls):
cls.audit_mode.delete()
cls.verified_mode.delete()
cls.other_course.delete()
cls.other_mode.delete()
def setUp(self):
super(CourseModesViewTestBase, self).setUp()
@@ -73,7 +79,6 @@ class CourseModesViewTestBase(AuthAndScopesTestMixin):
"""
Required method to implement AuthAndScopesTestMixin.
"""
# TODO
pass
@ddt.data(*product(JWT_AUTH_TYPES, (True, False)))
@@ -125,10 +130,7 @@ class TestCourseModesListViews(CourseModesViewTestBase, APITestCase):
response = self.client.get(url)
assert status.HTTP_200_OK == response.status_code
actual_results = sorted(
[dict(item) for item in response.data],
key=lambda item: item['mode_slug'],
)
actual_results = self._sorted_results(response)
expected_results = [
{
'course_id': text_type(self.course_key),
@@ -157,6 +159,36 @@ class TestCourseModesListViews(CourseModesViewTestBase, APITestCase):
]
assert expected_results == actual_results
# Now test the "other" course
url = self.get_url(course_id=self.other_course_key)
other_response = self.client.get(url)
assert status.HTTP_200_OK == other_response.status_code
other_actual_results = self._sorted_results(other_response)
other_expected_results = [
{
'course_id': text_type(self.other_course_key),
'mode_slug': 'other-audit',
'mode_display_name': 'Other Audit',
'min_price': 0,
'currency': 'usd',
'expiration_datetime': None,
'expiration_datetime_is_explicit': False,
'description': None,
'sku': None,
'bulk_sku': None,
},
]
assert other_expected_results == other_actual_results
@staticmethod
def _sorted_results(response):
return sorted(
[dict(item) for item in response.data],
key=lambda item: item['mode_slug'],
)
def test_post_course_mode_forbidden(self):
self.client.login(username=self.other_student.username, password=self.user_password)
url = self.get_url(course_id=self.course_key)

View File

@@ -40,7 +40,6 @@ class CourseModesMixin(object):
serializer_class = CourseModeSerializer
pagination_class = None
lookup_field = 'course_id'
queryset = CourseMode.objects.all()
class CourseModesView(CourseModesMixin, ListCreateAPIView):
@@ -87,7 +86,12 @@ class CourseModesView(CourseModesMixin, ListCreateAPIView):
POST: If the request is successful, an HTTP 201 "Created" response is returned.
"""
pass
def get_queryset(self):
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
if 'course_id' in filter_kwargs:
filter_kwargs['course_id'] = CourseKey.from_string(filter_kwargs['course_id'])
return CourseMode.objects.filter(**filter_kwargs)
class CourseModesDetailView(CourseModesMixin, RetrieveUpdateDestroyAPIView):
@@ -145,6 +149,7 @@ class CourseModesDetailView(CourseModesMixin, RetrieveUpdateDestroyAPIView):
http_method_names = ['get', 'patch', 'delete', 'head', 'options']
parser_classes = (MergePatchParser,)
multiple_lookup_fields = ('course_id', 'mode_slug')
queryset = CourseMode.objects.all()
def get_object(self):
queryset = self.get_queryset()
@@ -152,7 +157,8 @@ class CourseModesDetailView(CourseModesMixin, RetrieveUpdateDestroyAPIView):
for field in self.multiple_lookup_fields:
query_filter[field] = self.kwargs[field]
query_filter['course_id'] = CourseKey.from_string(query_filter['course_id'])
if 'course_id' in query_filter:
query_filter['course_id'] = CourseKey.from_string(query_filter['course_id'])
obj = get_object_or_404(queryset, **query_filter)
self.check_object_permissions(self.request, obj)