diff --git a/common/djangoapps/course_modes/api/v1/tests/test_views.py b/common/djangoapps/course_modes/api/v1/tests/test_views.py index e499459ce0..e6d2ec1a2a 100644 --- a/common/djangoapps/course_modes/api/v1/tests/test_views.py +++ b/common/djangoapps/course_modes/api/v1/tests/test_views.py @@ -48,17 +48,23 @@ class CourseModesViewTestBase(AuthAndScopesTestMixin): mode_display_name='Verified', min_price=25, ) + # use these to make sure we don't fetch data for other courses + cls.other_course_key = CourseKey.from_string('course-v1:edX+DemoX+Other_Course') + cls.other_course = CourseOverviewFactory.create(id=cls.other_course_key) + cls.other_mode = CourseModeFactory.create( + course_id=cls.other_course_key, + mode_slug='other-audit', + mode_display_name='Other Audit', + min_price=0, + ) @classmethod def tearDownClass(cls): cls.course.delete() cls.audit_mode.delete() cls.verified_mode.delete() - - @classmethod - def tearDownClass(cls): - cls.audit_mode.delete() - cls.verified_mode.delete() + cls.other_course.delete() + cls.other_mode.delete() def setUp(self): super(CourseModesViewTestBase, self).setUp() @@ -73,7 +79,6 @@ class CourseModesViewTestBase(AuthAndScopesTestMixin): """ Required method to implement AuthAndScopesTestMixin. """ - # TODO pass @ddt.data(*product(JWT_AUTH_TYPES, (True, False))) @@ -125,10 +130,7 @@ class TestCourseModesListViews(CourseModesViewTestBase, APITestCase): response = self.client.get(url) assert status.HTTP_200_OK == response.status_code - actual_results = sorted( - [dict(item) for item in response.data], - key=lambda item: item['mode_slug'], - ) + actual_results = self._sorted_results(response) expected_results = [ { 'course_id': text_type(self.course_key), @@ -157,6 +159,36 @@ class TestCourseModesListViews(CourseModesViewTestBase, APITestCase): ] assert expected_results == actual_results + # Now test the "other" course + url = self.get_url(course_id=self.other_course_key) + + other_response = self.client.get(url) + + assert status.HTTP_200_OK == other_response.status_code + other_actual_results = self._sorted_results(other_response) + other_expected_results = [ + { + 'course_id': text_type(self.other_course_key), + 'mode_slug': 'other-audit', + 'mode_display_name': 'Other Audit', + 'min_price': 0, + 'currency': 'usd', + 'expiration_datetime': None, + 'expiration_datetime_is_explicit': False, + 'description': None, + 'sku': None, + 'bulk_sku': None, + }, + ] + assert other_expected_results == other_actual_results + + @staticmethod + def _sorted_results(response): + return sorted( + [dict(item) for item in response.data], + key=lambda item: item['mode_slug'], + ) + def test_post_course_mode_forbidden(self): self.client.login(username=self.other_student.username, password=self.user_password) url = self.get_url(course_id=self.course_key) diff --git a/common/djangoapps/course_modes/api/v1/views.py b/common/djangoapps/course_modes/api/v1/views.py index 4d674ed932..5806b156db 100644 --- a/common/djangoapps/course_modes/api/v1/views.py +++ b/common/djangoapps/course_modes/api/v1/views.py @@ -40,7 +40,6 @@ class CourseModesMixin(object): serializer_class = CourseModeSerializer pagination_class = None lookup_field = 'course_id' - queryset = CourseMode.objects.all() class CourseModesView(CourseModesMixin, ListCreateAPIView): @@ -87,7 +86,12 @@ class CourseModesView(CourseModesMixin, ListCreateAPIView): POST: If the request is successful, an HTTP 201 "Created" response is returned. """ - pass + def get_queryset(self): + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]} + if 'course_id' in filter_kwargs: + filter_kwargs['course_id'] = CourseKey.from_string(filter_kwargs['course_id']) + return CourseMode.objects.filter(**filter_kwargs) class CourseModesDetailView(CourseModesMixin, RetrieveUpdateDestroyAPIView): @@ -145,6 +149,7 @@ class CourseModesDetailView(CourseModesMixin, RetrieveUpdateDestroyAPIView): http_method_names = ['get', 'patch', 'delete', 'head', 'options'] parser_classes = (MergePatchParser,) multiple_lookup_fields = ('course_id', 'mode_slug') + queryset = CourseMode.objects.all() def get_object(self): queryset = self.get_queryset() @@ -152,7 +157,8 @@ class CourseModesDetailView(CourseModesMixin, RetrieveUpdateDestroyAPIView): for field in self.multiple_lookup_fields: query_filter[field] = self.kwargs[field] - query_filter['course_id'] = CourseKey.from_string(query_filter['course_id']) + if 'course_id' in query_filter: + query_filter['course_id'] = CourseKey.from_string(query_filter['course_id']) obj = get_object_or_404(queryset, **query_filter) self.check_object_permissions(self.request, obj)