From e6536d0d0ef96babfbf0b64024e84754653e4f86 Mon Sep 17 00:00:00 2001 From: Tim McCormack Date: Tue, 18 Jan 2022 22:33:30 +0000 Subject: [PATCH] test: Stop sharing API client between users in unit tests Using the same Client or APIClient instance for multiple users, where one user has an active session and the other is making an Authorization header call, results in a Safe Sessions violation. By using separate clients for different test users, we avoid this violation, allowing `ENFORCE_SAFE_SESSIONS` to be enabled by default. --- .../bulk_user_retirement/tests/test_views.py | 8 +++-- .../discussion/rest_api/tests/test_views.py | 24 ++++++++------- .../djangoapps/credit/tests/test_views.py | 5 ++-- .../enrollments/tests/test_views.py | 29 ++++++++++++------- 4 files changed, 40 insertions(+), 26 deletions(-) diff --git a/lms/djangoapps/bulk_user_retirement/tests/test_views.py b/lms/djangoapps/bulk_user_retirement/tests/test_views.py index f1cae5adaa..ac28c00f47 100644 --- a/lms/djangoapps/bulk_user_retirement/tests/test_views.py +++ b/lms/djangoapps/bulk_user_retirement/tests/test_views.py @@ -13,21 +13,21 @@ class BulkUserRetirementViewTests(APITestCase): """ def setUp(self): super().setUp() - self.client = APIClient() + login_client = APIClient() self.user1 = UserFactory.create( username='testuser1', email='test1@example.com', password='test1_password', profile__name="Test User1" ) - self.client.login(username=self.user1.username, password='test1_password') + login_client.login(username=self.user1.username, password='test1_password') self.user2 = UserFactory.create( username='testuser2', email='test2@example.com', password='test2_password', profile__name="Test User2" ) - self.client.login(username=self.user2.username, password='test2_password') + login_client.login(username=self.user2.username, password='test2_password') self.user3 = UserFactory.create( username='testuser3', email='test3@example.com', @@ -47,6 +47,8 @@ class BulkUserRetirementViewTests(APITestCase): required=True ) self.pending_state = RetirementState.objects.get(state_name='PENDING') + # Use a separate client for retirement worker (don't mix cookie state) + self.client = APIClient() self.client.force_authenticate(user=self.user1) def test_gdpr_user_retirement_api(self): diff --git a/lms/djangoapps/discussion/rest_api/tests/test_views.py b/lms/djangoapps/discussion/rest_api/tests/test_views.py index 391c797a28..9a24c81c59 100644 --- a/lms/djangoapps/discussion/rest_api/tests/test_views.py +++ b/lms/djangoapps/discussion/rest_api/tests/test_views.py @@ -528,6 +528,7 @@ class RetireViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): self.retirement.save() self.superuser = SuperuserFactory() + self.superuser_client = APIClient() self.retired_username = get_retired_username_by_username(self.user.username) self.url = reverse("retire_discussion_user") @@ -555,7 +556,7 @@ class RetireViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): self.register_get_user_retire_response(self.user) headers = self.build_jwt_headers(self.superuser) data = {'username': self.user.username} - response = self.client.post(self.url, data, **headers) + response = self.superuser_client.post(self.url, data, **headers) self.assert_response_correct(response, 204, b"") def test_downstream_forums_error(self): @@ -565,7 +566,7 @@ class RetireViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): self.register_get_user_retire_response(self.user, status=500, body="Server error") headers = self.build_jwt_headers(self.superuser) data = {'username': self.user.username} - response = self.client.post(self.url, data, **headers) + response = self.superuser_client.post(self.url, data, **headers) self.assert_response_correct(response, 500, '"Server error"') def test_nonexistent_user(self): @@ -576,7 +577,7 @@ class RetireViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): self.retired_username = get_retired_username_by_username(nonexistent_username) data = {'username': nonexistent_username} headers = self.build_jwt_headers(self.superuser) - response = self.client.post(self.url, data, **headers) + response = self.superuser_client.post(self.url, data, **headers) self.assert_response_correct(response, 404, None) def test_not_authenticated(self): @@ -594,8 +595,9 @@ class ReplaceUsernamesViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): """Tests for ReplaceUsernamesView""" def setUp(self): super().setUp() - self.client_user = UserFactory() - self.client_user.username = "test_replace_username_service_worker" + self.worker = UserFactory() + self.worker.username = "test_replace_username_service_worker" + self.worker_client = APIClient() self.new_username = "test_username_replacement" self.url = reverse("replace_discussion_username") @@ -616,11 +618,11 @@ class ReplaceUsernamesViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): headers = {'HTTP_AUTHORIZATION': 'JWT ' + token} return headers - def call_api(self, user, data): + def call_api(self, user, client, data): """ Helper function to call API with data """ data = json.dumps(data) headers = self.build_jwt_headers(user) - return self.client.post(self.url, data, content_type='application/json', **headers) + return client.post(self.url, data, content_type='application/json', **headers) @ddt.data( [{}, {}], @@ -632,7 +634,7 @@ class ReplaceUsernamesViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): data = { "username_mappings": mapping_data } - response = self.call_api(self.client_user, data) + response = self.call_api(self.worker, self.worker_client, data) assert response.status_code == 400 def test_auth(self): @@ -650,11 +652,11 @@ class ReplaceUsernamesViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): # Test non-service worker random_user = UserFactory() - response = self.call_api(random_user, data) + response = self.call_api(random_user, APIClient(), data) assert response.status_code == 403 # Test service worker - response = self.call_api(self.client_user, data) + response = self.call_api(self.worker, self.worker_client, data) assert response.status_code == 200 def test_basic(self): @@ -669,7 +671,7 @@ class ReplaceUsernamesViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): 'successful_replacements': data["username_mappings"] } self.register_get_username_replacement_response(self.user) - response = self.call_api(self.client_user, data) + response = self.call_api(self.worker, self.worker_client, data) assert response.status_code == 200 assert response.data == expected_response diff --git a/openedx/core/djangoapps/credit/tests/test_views.py b/openedx/core/djangoapps/credit/tests/test_views.py index 90f2a30a27..deb3c8726a 100644 --- a/openedx/core/djangoapps/credit/tests/test_views.py +++ b/openedx/core/djangoapps/credit/tests/test_views.py @@ -171,6 +171,7 @@ class CreditCourseViewSetTests(AuthMixin, UserMixin, TestCase): def test_oauth(self): """ Verify the endpoint supports OAuth, and only allows authorization for staff users. """ + client = Client() # avoid mixing cookies user = UserFactory(is_staff=False) oauth_client = ApplicationFactory.create() access_token = AccessTokenFactory.create(user=user, application=oauth_client).token @@ -179,13 +180,13 @@ class CreditCourseViewSetTests(AuthMixin, UserMixin, TestCase): } # Non-staff users should not have access to the API - response = self.client.get(self.path, **headers) + response = client.get(self.path, **headers) assert response.status_code == 403 # Staff users should have access to the API user.is_staff = True user.save() - response = self.client.get(self.path, **headers) + response = client.get(self.path, **headers) assert response.status_code == 200 def assert_course_created(self, course_id, response): diff --git a/openedx/core/djangoapps/enrollments/tests/test_views.py b/openedx/core/djangoapps/enrollments/tests/test_views.py index 3232ced200..104e6e6bcf 100644 --- a/openedx/core/djangoapps/enrollments/tests/test_views.py +++ b/openedx/core/djangoapps/enrollments/tests/test_views.py @@ -1349,6 +1349,7 @@ class UnenrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase): """ Create a course and user, then log in. """ super().setUp() self.superuser = SuperuserFactory() + self.superuser_client = Client() # Pass emit_signals when creating the course so it would be cached # as a CourseOverview. Enrollments require a cached CourseOverview. self.first_org_course = CourseFactory.create(emit_signals=True, org="org", course="course", run="run") @@ -1404,7 +1405,7 @@ class UnenrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase): def test_deactivate_enrollments(self): self._assert_active() self._create_test_retirement(self.user) - response = self._submit_unenroll(self.superuser, self.user.username) + response = self._submit_unenroll(self.user.username) assert response.status_code == status.HTTP_200_OK data = json.loads(response.content.decode('utf-8')) # order doesn't matter so compare sets @@ -1413,18 +1414,18 @@ class UnenrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase): def test_deactivate_enrollments_no_retirement_status(self): self._assert_active() - response = self._submit_unenroll(self.superuser, self.user.username) + response = self._submit_unenroll(self.user.username) assert response.status_code == status.HTTP_404_NOT_FOUND def test_deactivate_enrollments_unauthorized(self): self._assert_active() - response = self._submit_unenroll(self.user, self.user.username) + response = self._submit_unenroll(self.user.username, submitting_user=self.user, client=self.client) assert response.status_code == status.HTTP_403_FORBIDDEN self._assert_active() def test_deactivate_enrollments_no_username(self): self._assert_active() - response = self._submit_unenroll(self.superuser, None) + response = self._submit_unenroll(None) assert response.status_code == status.HTTP_404_NOT_FOUND data = json.loads(response.content.decode('utf-8')) assert data == 'Username not specified.' @@ -1433,23 +1434,23 @@ class UnenrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase): def test_deactivate_enrollments_empty_username(self): self._assert_active() self._create_test_retirement(self.user) - response = self._submit_unenroll(self.superuser, "") + response = self._submit_unenroll("") assert response.status_code == status.HTTP_404_NOT_FOUND self._assert_active() def test_deactivate_enrollments_invalid_username(self): self._assert_active() self._create_test_retirement(self.user) - response = self._submit_unenroll(self.superuser, "a made up username") + response = self._submit_unenroll("a made up username") assert response.status_code == status.HTTP_404_NOT_FOUND self._assert_active() def test_deactivate_enrollments_called_twice(self): self._assert_active() self._create_test_retirement(self.user) - response = self._submit_unenroll(self.superuser, self.user.username) + response = self._submit_unenroll(self.user.username) assert response.status_code == status.HTTP_200_OK - response = self._submit_unenroll(self.superuser, self.user.username) + response = self._submit_unenroll(self.user.username) assert response.status_code == status.HTTP_204_NO_CONTENT assert response.content.decode('utf-8') == '' self._assert_inactive() @@ -1465,14 +1466,22 @@ class UnenrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase): _, is_active = CourseEnrollment.enrollment_mode_for_user(self.user, course.id) assert not is_active - def _submit_unenroll(self, submitting_user, unenrolling_username): + def _submit_unenroll(self, unenrolling_username, submitting_user=None, client=None): + """ Submit enrollment, by default as superuser. """ + # Provide both or neither of the overrides + assert (submitting_user is None) == (client is None) + + # Avoid mixing cookies between two users + client = client or self.superuser_client + submitting_user = submitting_user or self.superuser + data = {} if unenrolling_username is not None: data['username'] = unenrolling_username url = reverse('unenrollment') headers = self.build_jwt_headers(submitting_user) - return self.client.post(url, json.dumps(data), content_type='application/json', **headers) + return client.post(url, json.dumps(data), content_type='application/json', **headers) @ddt.ddt