test: Stop sharing API client between users in unit tests
Using the same Client or APIClient instance for multiple users, where one user has an active session and the other is making an Authorization header call, results in a Safe Sessions violation. By using separate clients for different test users, we avoid this violation, allowing `ENFORCE_SAFE_SESSIONS` to be enabled by default.
This commit is contained in:
@@ -13,21 +13,21 @@ class BulkUserRetirementViewTests(APITestCase):
|
||||
"""
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.client = APIClient()
|
||||
login_client = APIClient()
|
||||
self.user1 = UserFactory.create(
|
||||
username='testuser1',
|
||||
email='test1@example.com',
|
||||
password='test1_password',
|
||||
profile__name="Test User1"
|
||||
)
|
||||
self.client.login(username=self.user1.username, password='test1_password')
|
||||
login_client.login(username=self.user1.username, password='test1_password')
|
||||
self.user2 = UserFactory.create(
|
||||
username='testuser2',
|
||||
email='test2@example.com',
|
||||
password='test2_password',
|
||||
profile__name="Test User2"
|
||||
)
|
||||
self.client.login(username=self.user2.username, password='test2_password')
|
||||
login_client.login(username=self.user2.username, password='test2_password')
|
||||
self.user3 = UserFactory.create(
|
||||
username='testuser3',
|
||||
email='test3@example.com',
|
||||
@@ -47,6 +47,8 @@ class BulkUserRetirementViewTests(APITestCase):
|
||||
required=True
|
||||
)
|
||||
self.pending_state = RetirementState.objects.get(state_name='PENDING')
|
||||
# Use a separate client for retirement worker (don't mix cookie state)
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user1)
|
||||
|
||||
def test_gdpr_user_retirement_api(self):
|
||||
|
||||
@@ -528,6 +528,7 @@ class RetireViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
|
||||
self.retirement.save()
|
||||
|
||||
self.superuser = SuperuserFactory()
|
||||
self.superuser_client = APIClient()
|
||||
self.retired_username = get_retired_username_by_username(self.user.username)
|
||||
self.url = reverse("retire_discussion_user")
|
||||
|
||||
@@ -555,7 +556,7 @@ class RetireViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
|
||||
self.register_get_user_retire_response(self.user)
|
||||
headers = self.build_jwt_headers(self.superuser)
|
||||
data = {'username': self.user.username}
|
||||
response = self.client.post(self.url, data, **headers)
|
||||
response = self.superuser_client.post(self.url, data, **headers)
|
||||
self.assert_response_correct(response, 204, b"")
|
||||
|
||||
def test_downstream_forums_error(self):
|
||||
@@ -565,7 +566,7 @@ class RetireViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
|
||||
self.register_get_user_retire_response(self.user, status=500, body="Server error")
|
||||
headers = self.build_jwt_headers(self.superuser)
|
||||
data = {'username': self.user.username}
|
||||
response = self.client.post(self.url, data, **headers)
|
||||
response = self.superuser_client.post(self.url, data, **headers)
|
||||
self.assert_response_correct(response, 500, '"Server error"')
|
||||
|
||||
def test_nonexistent_user(self):
|
||||
@@ -576,7 +577,7 @@ class RetireViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
|
||||
self.retired_username = get_retired_username_by_username(nonexistent_username)
|
||||
data = {'username': nonexistent_username}
|
||||
headers = self.build_jwt_headers(self.superuser)
|
||||
response = self.client.post(self.url, data, **headers)
|
||||
response = self.superuser_client.post(self.url, data, **headers)
|
||||
self.assert_response_correct(response, 404, None)
|
||||
|
||||
def test_not_authenticated(self):
|
||||
@@ -594,8 +595,9 @@ class ReplaceUsernamesViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
|
||||
"""Tests for ReplaceUsernamesView"""
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.client_user = UserFactory()
|
||||
self.client_user.username = "test_replace_username_service_worker"
|
||||
self.worker = UserFactory()
|
||||
self.worker.username = "test_replace_username_service_worker"
|
||||
self.worker_client = APIClient()
|
||||
self.new_username = "test_username_replacement"
|
||||
self.url = reverse("replace_discussion_username")
|
||||
|
||||
@@ -616,11 +618,11 @@ class ReplaceUsernamesViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
|
||||
headers = {'HTTP_AUTHORIZATION': 'JWT ' + token}
|
||||
return headers
|
||||
|
||||
def call_api(self, user, data):
|
||||
def call_api(self, user, client, data):
|
||||
""" Helper function to call API with data """
|
||||
data = json.dumps(data)
|
||||
headers = self.build_jwt_headers(user)
|
||||
return self.client.post(self.url, data, content_type='application/json', **headers)
|
||||
return client.post(self.url, data, content_type='application/json', **headers)
|
||||
|
||||
@ddt.data(
|
||||
[{}, {}],
|
||||
@@ -632,7 +634,7 @@ class ReplaceUsernamesViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
|
||||
data = {
|
||||
"username_mappings": mapping_data
|
||||
}
|
||||
response = self.call_api(self.client_user, data)
|
||||
response = self.call_api(self.worker, self.worker_client, data)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_auth(self):
|
||||
@@ -650,11 +652,11 @@ class ReplaceUsernamesViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
|
||||
|
||||
# Test non-service worker
|
||||
random_user = UserFactory()
|
||||
response = self.call_api(random_user, data)
|
||||
response = self.call_api(random_user, APIClient(), data)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Test service worker
|
||||
response = self.call_api(self.client_user, data)
|
||||
response = self.call_api(self.worker, self.worker_client, data)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_basic(self):
|
||||
@@ -669,7 +671,7 @@ class ReplaceUsernamesViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
|
||||
'successful_replacements': data["username_mappings"]
|
||||
}
|
||||
self.register_get_username_replacement_response(self.user)
|
||||
response = self.call_api(self.client_user, data)
|
||||
response = self.call_api(self.worker, self.worker_client, data)
|
||||
assert response.status_code == 200
|
||||
assert response.data == expected_response
|
||||
|
||||
|
||||
@@ -171,6 +171,7 @@ class CreditCourseViewSetTests(AuthMixin, UserMixin, TestCase):
|
||||
|
||||
def test_oauth(self):
|
||||
""" Verify the endpoint supports OAuth, and only allows authorization for staff users. """
|
||||
client = Client() # avoid mixing cookies
|
||||
user = UserFactory(is_staff=False)
|
||||
oauth_client = ApplicationFactory.create()
|
||||
access_token = AccessTokenFactory.create(user=user, application=oauth_client).token
|
||||
@@ -179,13 +180,13 @@ class CreditCourseViewSetTests(AuthMixin, UserMixin, TestCase):
|
||||
}
|
||||
|
||||
# Non-staff users should not have access to the API
|
||||
response = self.client.get(self.path, **headers)
|
||||
response = client.get(self.path, **headers)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Staff users should have access to the API
|
||||
user.is_staff = True
|
||||
user.save()
|
||||
response = self.client.get(self.path, **headers)
|
||||
response = client.get(self.path, **headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
def assert_course_created(self, course_id, response):
|
||||
|
||||
@@ -1349,6 +1349,7 @@ class UnenrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase):
|
||||
""" Create a course and user, then log in. """
|
||||
super().setUp()
|
||||
self.superuser = SuperuserFactory()
|
||||
self.superuser_client = Client()
|
||||
# Pass emit_signals when creating the course so it would be cached
|
||||
# as a CourseOverview. Enrollments require a cached CourseOverview.
|
||||
self.first_org_course = CourseFactory.create(emit_signals=True, org="org", course="course", run="run")
|
||||
@@ -1404,7 +1405,7 @@ class UnenrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase):
|
||||
def test_deactivate_enrollments(self):
|
||||
self._assert_active()
|
||||
self._create_test_retirement(self.user)
|
||||
response = self._submit_unenroll(self.superuser, self.user.username)
|
||||
response = self._submit_unenroll(self.user.username)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = json.loads(response.content.decode('utf-8'))
|
||||
# order doesn't matter so compare sets
|
||||
@@ -1413,18 +1414,18 @@ class UnenrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase):
|
||||
|
||||
def test_deactivate_enrollments_no_retirement_status(self):
|
||||
self._assert_active()
|
||||
response = self._submit_unenroll(self.superuser, self.user.username)
|
||||
response = self._submit_unenroll(self.user.username)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_deactivate_enrollments_unauthorized(self):
|
||||
self._assert_active()
|
||||
response = self._submit_unenroll(self.user, self.user.username)
|
||||
response = self._submit_unenroll(self.user.username, submitting_user=self.user, client=self.client)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
self._assert_active()
|
||||
|
||||
def test_deactivate_enrollments_no_username(self):
|
||||
self._assert_active()
|
||||
response = self._submit_unenroll(self.superuser, None)
|
||||
response = self._submit_unenroll(None)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
data = json.loads(response.content.decode('utf-8'))
|
||||
assert data == 'Username not specified.'
|
||||
@@ -1433,23 +1434,23 @@ class UnenrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase):
|
||||
def test_deactivate_enrollments_empty_username(self):
|
||||
self._assert_active()
|
||||
self._create_test_retirement(self.user)
|
||||
response = self._submit_unenroll(self.superuser, "")
|
||||
response = self._submit_unenroll("")
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
self._assert_active()
|
||||
|
||||
def test_deactivate_enrollments_invalid_username(self):
|
||||
self._assert_active()
|
||||
self._create_test_retirement(self.user)
|
||||
response = self._submit_unenroll(self.superuser, "a made up username")
|
||||
response = self._submit_unenroll("a made up username")
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
self._assert_active()
|
||||
|
||||
def test_deactivate_enrollments_called_twice(self):
|
||||
self._assert_active()
|
||||
self._create_test_retirement(self.user)
|
||||
response = self._submit_unenroll(self.superuser, self.user.username)
|
||||
response = self._submit_unenroll(self.user.username)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
response = self._submit_unenroll(self.superuser, self.user.username)
|
||||
response = self._submit_unenroll(self.user.username)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
assert response.content.decode('utf-8') == ''
|
||||
self._assert_inactive()
|
||||
@@ -1465,14 +1466,22 @@ class UnenrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase):
|
||||
_, is_active = CourseEnrollment.enrollment_mode_for_user(self.user, course.id)
|
||||
assert not is_active
|
||||
|
||||
def _submit_unenroll(self, submitting_user, unenrolling_username):
|
||||
def _submit_unenroll(self, unenrolling_username, submitting_user=None, client=None):
|
||||
""" Submit enrollment, by default as superuser. """
|
||||
# Provide both or neither of the overrides
|
||||
assert (submitting_user is None) == (client is None)
|
||||
|
||||
# Avoid mixing cookies between two users
|
||||
client = client or self.superuser_client
|
||||
submitting_user = submitting_user or self.superuser
|
||||
|
||||
data = {}
|
||||
if unenrolling_username is not None:
|
||||
data['username'] = unenrolling_username
|
||||
|
||||
url = reverse('unenrollment')
|
||||
headers = self.build_jwt_headers(submitting_user)
|
||||
return self.client.post(url, json.dumps(data), content_type='application/json', **headers)
|
||||
return client.post(url, json.dumps(data), content_type='application/json', **headers)
|
||||
|
||||
|
||||
@ddt.ddt
|
||||
|
||||
Reference in New Issue
Block a user