From ed45aee9dd46f082b064901b52c88127b1117170 Mon Sep 17 00:00:00 2001 From: Waheed Ahmed Date: Fri, 7 Jan 2022 11:06:07 +0500 Subject: [PATCH] refactor: separate user lookup queries for login (#29711) Using a single query to get a user using both username and email fields generates a massive `key_len` and causes DB overload. Separated these lookups into two separate queries. VAN-819 --- .../core/djangoapps/user_authn/views/login.py | 55 +++++++++++-------- .../user_authn/views/tests/test_login.py | 34 +++++------- 2 files changed, 46 insertions(+), 43 deletions(-) diff --git a/openedx/core/djangoapps/user_authn/views/login.py b/openedx/core/djangoapps/user_authn/views/login.py index 62982825a3..b3319e638c 100644 --- a/openedx/core/djangoapps/user_authn/views/login.py +++ b/openedx/core/djangoapps/user_authn/views/login.py @@ -15,7 +15,6 @@ from django.contrib import admin from django.contrib.auth import authenticate, get_user_model from django.contrib.auth import login as django_login from django.contrib.auth.decorators import login_required -from django.db.models import Q from django.http import HttpRequest, HttpResponse, HttpResponseForbidden from django.shortcuts import redirect from django.urls import reverse @@ -102,37 +101,51 @@ def _do_third_party_auth(request): raise AuthFailedError(message, error_code='third-party-auth-with-no-linked-account') # lint-amnesty, pylint: disable=raise-missing-from -def _get_user_by_email(request): +def _get_user_by_email(email): """ - Finds a user object in the database based on the given request, ignores all fields except for email. + Finds a user object in the database based on the given email, ignores all fields except for email. """ - if 'email' not in request.POST or 'password' not in request.POST: - raise AuthFailedError(_('There was an error receiving your login information. Please email us.')) - - email = request.POST['email'] - try: return USER_MODEL.objects.get(email=email) except USER_MODEL.DoesNotExist: - digest = hashlib.shake_128(email.encode('utf-8')).hexdigest(16) # pylint: disable=too-many-function-args - AUDIT_LOG.warning(f"Login failed - Unknown user email {digest}") + return None -def _get_user_by_email_or_username(request): +def _get_user_by_username(username): + """ + Finds a user object in the database based on the given username. + """ + try: + return USER_MODEL.objects.get(username=username) + except USER_MODEL.DoesNotExist: + return None + + +def _get_user_by_email_or_username(request, api_version): """ Finds a user object in the database based on the given request, ignores all fields except for email and username. """ - if 'email_or_username' not in request.POST or 'password' not in request.POST: + is_api_v2 = api_version != API_V1 + login_fields = ['email', 'password'] + if is_api_v2: + login_fields = ['email_or_username', 'password'] + + if any(f not in request.POST.keys() for f in login_fields): raise AuthFailedError(_('There was an error receiving your login information. Please email us.')) - email_or_username = request.POST.get('email_or_username', None) - try: - return USER_MODEL.objects.get( - Q(username=email_or_username) | Q(email=email_or_username) - ) - except USER_MODEL.DoesNotExist: + email_or_username = request.POST.get('email', None) + user = _get_user_by_email(email_or_username) + + if not user and is_api_v2: + # If user not found with email and API_V2, try username lookup + email_or_username = request.POST.get('email_or_username', None) + user = _get_user_by_username(email_or_username) + + if not user: digest = hashlib.shake_128(email_or_username.encode('utf-8')).hexdigest(16) # pylint: disable=too-many-function-args - AUDIT_LOG.warning(f"Login failed - Unknown user username/email {digest}") + AUDIT_LOG.warning(f"Login failed - Unknown user email or username {digest}") + + return user def _check_excessive_login_attempts(user): @@ -545,10 +558,8 @@ def login_user(request, api_version='v1'): # user successfully authenticated with a third party provider, but has no linked Open edX account response_content = e.get_response() return JsonResponse(response_content, status=403) - elif api_version == API_V1: - user = _get_user_by_email(request) else: - user = _get_user_by_email_or_username(request) + user = _get_user_by_email_or_username(request, api_version) _check_excessive_login_attempts(user) diff --git a/openedx/core/djangoapps/user_authn/views/tests/test_login.py b/openedx/core/djangoapps/user_authn/views/tests/test_login.py index c23633ee2d..2f1e00a278 100644 --- a/openedx/core/djangoapps/user_authn/views/tests/test_login.py +++ b/openedx/core/djangoapps/user_authn/views/tests/test_login.py @@ -1010,6 +1010,8 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin): def setUp(self): super().setUp() self.url = reverse("user_api_login_session", kwargs={'api_version': 'v1'}) + self.url_v2 = reverse("user_api_login_session", kwargs={'api_version': 'v2'}) + self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD) @ddt.data("get", "post") def test_auth_disabled(self, method): @@ -1066,9 +1068,6 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin): @ddt.data(True, False) @patch('openedx.core.djangoapps.user_authn.views.login.segment') def test_login(self, include_analytics, mock_segment): - # Create a test user - user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD) - data = { "email": self.EMAIL, "password": self.PASSWORD, @@ -1091,7 +1090,7 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin): self.assertHttpOK(response) # Verify events are called - expected_user_id = user.id + expected_user_id = self.user.id mock_segment.identify.assert_called_once_with( expected_user_id, {'username': self.USERNAME, 'email': self.EMAIL}, @@ -1104,19 +1103,14 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin): ) def test_login_with_username(self): - UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD) data = { "email_or_username": self.USERNAME, "password": self.PASSWORD, } - self.url = reverse("user_api_login_session", kwargs={'api_version': 'v2'}) - response = self.client.post(self.url, data) + response = self.client.post(self.url_v2, data) self.assertHttpOK(response) def test_session_cookie_expiry(self): - # Create a test user - UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD) - # Login and remember me data = { "email": self.EMAIL, @@ -1132,9 +1126,6 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin): assert expected_expiry.strftime('%d %b %Y') in cookie.get('expires').replace('-', ' ') def test_invalid_credentials(self): - # Create a test user - UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD) - # Invalid password response = self.client.post(self.url, { "email": self.EMAIL, @@ -1149,21 +1140,22 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin): }) self.assertHttpBadRequest(response) - def test_missing_login_params(self): - # Create a test user - UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD) - + @ddt.data(True, False) + def test_missing_login_params(self, is_api_v1): + email_field_name = "email" if is_api_v1 else "email_or_username" + url = self.url if is_api_v1 else self.url_v2 # Missing password - response = self.client.post(self.url, { - "email": self.EMAIL, + response = self.client.post(url, { + email_field_name: self.EMAIL, }) self.assertHttpBadRequest(response) # Missing email - response = self.client.post(self.url, { + response = self.client.post(url, { "password": self.PASSWORD, }) self.assertHttpBadRequest(response) # Missing both email and password - response = self.client.post(self.url, {}) + response = self.client.post(url, {}) + self.assertHttpBadRequest(response)