refactor: separate user lookup queries for login (#29711)
Using a single query to get a user using both username and email fields generates a massive `key_len` and causes DB overload. Separated these lookups into two separate queries. VAN-819
This commit is contained in:
@@ -15,7 +15,6 @@ from django.contrib import admin
|
||||
from django.contrib.auth import authenticate, get_user_model
|
||||
from django.contrib.auth import login as django_login
|
||||
from django.contrib.auth.decorators import login_required
|
||||
from django.db.models import Q
|
||||
from django.http import HttpRequest, HttpResponse, HttpResponseForbidden
|
||||
from django.shortcuts import redirect
|
||||
from django.urls import reverse
|
||||
@@ -102,37 +101,51 @@ def _do_third_party_auth(request):
|
||||
raise AuthFailedError(message, error_code='third-party-auth-with-no-linked-account') # lint-amnesty, pylint: disable=raise-missing-from
|
||||
|
||||
|
||||
def _get_user_by_email(request):
|
||||
def _get_user_by_email(email):
|
||||
"""
|
||||
Finds a user object in the database based on the given request, ignores all fields except for email.
|
||||
Finds a user object in the database based on the given email, ignores all fields except for email.
|
||||
"""
|
||||
if 'email' not in request.POST or 'password' not in request.POST:
|
||||
raise AuthFailedError(_('There was an error receiving your login information. Please email us.'))
|
||||
|
||||
email = request.POST['email']
|
||||
|
||||
try:
|
||||
return USER_MODEL.objects.get(email=email)
|
||||
except USER_MODEL.DoesNotExist:
|
||||
digest = hashlib.shake_128(email.encode('utf-8')).hexdigest(16) # pylint: disable=too-many-function-args
|
||||
AUDIT_LOG.warning(f"Login failed - Unknown user email {digest}")
|
||||
return None
|
||||
|
||||
|
||||
def _get_user_by_email_or_username(request):
|
||||
def _get_user_by_username(username):
|
||||
"""
|
||||
Finds a user object in the database based on the given username.
|
||||
"""
|
||||
try:
|
||||
return USER_MODEL.objects.get(username=username)
|
||||
except USER_MODEL.DoesNotExist:
|
||||
return None
|
||||
|
||||
|
||||
def _get_user_by_email_or_username(request, api_version):
|
||||
"""
|
||||
Finds a user object in the database based on the given request, ignores all fields except for email and username.
|
||||
"""
|
||||
if 'email_or_username' not in request.POST or 'password' not in request.POST:
|
||||
is_api_v2 = api_version != API_V1
|
||||
login_fields = ['email', 'password']
|
||||
if is_api_v2:
|
||||
login_fields = ['email_or_username', 'password']
|
||||
|
||||
if any(f not in request.POST.keys() for f in login_fields):
|
||||
raise AuthFailedError(_('There was an error receiving your login information. Please email us.'))
|
||||
|
||||
email_or_username = request.POST.get('email_or_username', None)
|
||||
try:
|
||||
return USER_MODEL.objects.get(
|
||||
Q(username=email_or_username) | Q(email=email_or_username)
|
||||
)
|
||||
except USER_MODEL.DoesNotExist:
|
||||
email_or_username = request.POST.get('email', None)
|
||||
user = _get_user_by_email(email_or_username)
|
||||
|
||||
if not user and is_api_v2:
|
||||
# If user not found with email and API_V2, try username lookup
|
||||
email_or_username = request.POST.get('email_or_username', None)
|
||||
user = _get_user_by_username(email_or_username)
|
||||
|
||||
if not user:
|
||||
digest = hashlib.shake_128(email_or_username.encode('utf-8')).hexdigest(16) # pylint: disable=too-many-function-args
|
||||
AUDIT_LOG.warning(f"Login failed - Unknown user username/email {digest}")
|
||||
AUDIT_LOG.warning(f"Login failed - Unknown user email or username {digest}")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def _check_excessive_login_attempts(user):
|
||||
@@ -545,10 +558,8 @@ def login_user(request, api_version='v1'):
|
||||
# user successfully authenticated with a third party provider, but has no linked Open edX account
|
||||
response_content = e.get_response()
|
||||
return JsonResponse(response_content, status=403)
|
||||
elif api_version == API_V1:
|
||||
user = _get_user_by_email(request)
|
||||
else:
|
||||
user = _get_user_by_email_or_username(request)
|
||||
user = _get_user_by_email_or_username(request, api_version)
|
||||
|
||||
_check_excessive_login_attempts(user)
|
||||
|
||||
|
||||
@@ -1010,6 +1010,8 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.url = reverse("user_api_login_session", kwargs={'api_version': 'v1'})
|
||||
self.url_v2 = reverse("user_api_login_session", kwargs={'api_version': 'v2'})
|
||||
self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
|
||||
|
||||
@ddt.data("get", "post")
|
||||
def test_auth_disabled(self, method):
|
||||
@@ -1066,9 +1068,6 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin):
|
||||
@ddt.data(True, False)
|
||||
@patch('openedx.core.djangoapps.user_authn.views.login.segment')
|
||||
def test_login(self, include_analytics, mock_segment):
|
||||
# Create a test user
|
||||
user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
|
||||
|
||||
data = {
|
||||
"email": self.EMAIL,
|
||||
"password": self.PASSWORD,
|
||||
@@ -1091,7 +1090,7 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin):
|
||||
self.assertHttpOK(response)
|
||||
|
||||
# Verify events are called
|
||||
expected_user_id = user.id
|
||||
expected_user_id = self.user.id
|
||||
mock_segment.identify.assert_called_once_with(
|
||||
expected_user_id,
|
||||
{'username': self.USERNAME, 'email': self.EMAIL},
|
||||
@@ -1104,19 +1103,14 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin):
|
||||
)
|
||||
|
||||
def test_login_with_username(self):
|
||||
UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
|
||||
data = {
|
||||
"email_or_username": self.USERNAME,
|
||||
"password": self.PASSWORD,
|
||||
}
|
||||
self.url = reverse("user_api_login_session", kwargs={'api_version': 'v2'})
|
||||
response = self.client.post(self.url, data)
|
||||
response = self.client.post(self.url_v2, data)
|
||||
self.assertHttpOK(response)
|
||||
|
||||
def test_session_cookie_expiry(self):
|
||||
# Create a test user
|
||||
UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
|
||||
|
||||
# Login and remember me
|
||||
data = {
|
||||
"email": self.EMAIL,
|
||||
@@ -1132,9 +1126,6 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin):
|
||||
assert expected_expiry.strftime('%d %b %Y') in cookie.get('expires').replace('-', ' ')
|
||||
|
||||
def test_invalid_credentials(self):
|
||||
# Create a test user
|
||||
UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
|
||||
|
||||
# Invalid password
|
||||
response = self.client.post(self.url, {
|
||||
"email": self.EMAIL,
|
||||
@@ -1149,21 +1140,22 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin):
|
||||
})
|
||||
self.assertHttpBadRequest(response)
|
||||
|
||||
def test_missing_login_params(self):
|
||||
# Create a test user
|
||||
UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
|
||||
|
||||
@ddt.data(True, False)
|
||||
def test_missing_login_params(self, is_api_v1):
|
||||
email_field_name = "email" if is_api_v1 else "email_or_username"
|
||||
url = self.url if is_api_v1 else self.url_v2
|
||||
# Missing password
|
||||
response = self.client.post(self.url, {
|
||||
"email": self.EMAIL,
|
||||
response = self.client.post(url, {
|
||||
email_field_name: self.EMAIL,
|
||||
})
|
||||
self.assertHttpBadRequest(response)
|
||||
|
||||
# Missing email
|
||||
response = self.client.post(self.url, {
|
||||
response = self.client.post(url, {
|
||||
"password": self.PASSWORD,
|
||||
})
|
||||
self.assertHttpBadRequest(response)
|
||||
|
||||
# Missing both email and password
|
||||
response = self.client.post(self.url, {})
|
||||
response = self.client.post(url, {})
|
||||
self.assertHttpBadRequest(response)
|
||||
|
||||
Reference in New Issue
Block a user