refactor: separate user lookup queries for login (#29711)

Using a single query to get a user using both username and email fields
generates a massive `key_len` and causes DB overload. Separated these
lookups into two separate queries.
VAN-819
This commit is contained in:
Waheed Ahmed
2022-01-07 11:06:07 +05:00
committed by GitHub
parent 8fa1d4d0b6
commit ed45aee9dd
2 changed files with 46 additions and 43 deletions

View File

@@ -15,7 +15,6 @@ from django.contrib import admin
from django.contrib.auth import authenticate, get_user_model
from django.contrib.auth import login as django_login
from django.contrib.auth.decorators import login_required
from django.db.models import Q
from django.http import HttpRequest, HttpResponse, HttpResponseForbidden
from django.shortcuts import redirect
from django.urls import reverse
@@ -102,37 +101,51 @@ def _do_third_party_auth(request):
raise AuthFailedError(message, error_code='third-party-auth-with-no-linked-account') # lint-amnesty, pylint: disable=raise-missing-from
def _get_user_by_email(request):
def _get_user_by_email(email):
"""
Finds a user object in the database based on the given request, ignores all fields except for email.
Finds a user object in the database based on the given email, ignores all fields except for email.
"""
if 'email' not in request.POST or 'password' not in request.POST:
raise AuthFailedError(_('There was an error receiving your login information. Please email us.'))
email = request.POST['email']
try:
return USER_MODEL.objects.get(email=email)
except USER_MODEL.DoesNotExist:
digest = hashlib.shake_128(email.encode('utf-8')).hexdigest(16) # pylint: disable=too-many-function-args
AUDIT_LOG.warning(f"Login failed - Unknown user email {digest}")
return None
def _get_user_by_email_or_username(request):
def _get_user_by_username(username):
"""
Finds a user object in the database based on the given username.
"""
try:
return USER_MODEL.objects.get(username=username)
except USER_MODEL.DoesNotExist:
return None
def _get_user_by_email_or_username(request, api_version):
"""
Finds a user object in the database based on the given request, ignores all fields except for email and username.
"""
if 'email_or_username' not in request.POST or 'password' not in request.POST:
is_api_v2 = api_version != API_V1
login_fields = ['email', 'password']
if is_api_v2:
login_fields = ['email_or_username', 'password']
if any(f not in request.POST.keys() for f in login_fields):
raise AuthFailedError(_('There was an error receiving your login information. Please email us.'))
email_or_username = request.POST.get('email_or_username', None)
try:
return USER_MODEL.objects.get(
Q(username=email_or_username) | Q(email=email_or_username)
)
except USER_MODEL.DoesNotExist:
email_or_username = request.POST.get('email', None)
user = _get_user_by_email(email_or_username)
if not user and is_api_v2:
# If user not found with email and API_V2, try username lookup
email_or_username = request.POST.get('email_or_username', None)
user = _get_user_by_username(email_or_username)
if not user:
digest = hashlib.shake_128(email_or_username.encode('utf-8')).hexdigest(16) # pylint: disable=too-many-function-args
AUDIT_LOG.warning(f"Login failed - Unknown user username/email {digest}")
AUDIT_LOG.warning(f"Login failed - Unknown user email or username {digest}")
return user
def _check_excessive_login_attempts(user):
@@ -545,10 +558,8 @@ def login_user(request, api_version='v1'):
# user successfully authenticated with a third party provider, but has no linked Open edX account
response_content = e.get_response()
return JsonResponse(response_content, status=403)
elif api_version == API_V1:
user = _get_user_by_email(request)
else:
user = _get_user_by_email_or_username(request)
user = _get_user_by_email_or_username(request, api_version)
_check_excessive_login_attempts(user)

View File

@@ -1010,6 +1010,8 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin):
def setUp(self):
super().setUp()
self.url = reverse("user_api_login_session", kwargs={'api_version': 'v1'})
self.url_v2 = reverse("user_api_login_session", kwargs={'api_version': 'v2'})
self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
@ddt.data("get", "post")
def test_auth_disabled(self, method):
@@ -1066,9 +1068,6 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin):
@ddt.data(True, False)
@patch('openedx.core.djangoapps.user_authn.views.login.segment')
def test_login(self, include_analytics, mock_segment):
# Create a test user
user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
data = {
"email": self.EMAIL,
"password": self.PASSWORD,
@@ -1091,7 +1090,7 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin):
self.assertHttpOK(response)
# Verify events are called
expected_user_id = user.id
expected_user_id = self.user.id
mock_segment.identify.assert_called_once_with(
expected_user_id,
{'username': self.USERNAME, 'email': self.EMAIL},
@@ -1104,19 +1103,14 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin):
)
def test_login_with_username(self):
UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
data = {
"email_or_username": self.USERNAME,
"password": self.PASSWORD,
}
self.url = reverse("user_api_login_session", kwargs={'api_version': 'v2'})
response = self.client.post(self.url, data)
response = self.client.post(self.url_v2, data)
self.assertHttpOK(response)
def test_session_cookie_expiry(self):
# Create a test user
UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
# Login and remember me
data = {
"email": self.EMAIL,
@@ -1132,9 +1126,6 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin):
assert expected_expiry.strftime('%d %b %Y') in cookie.get('expires').replace('-', ' ')
def test_invalid_credentials(self):
# Create a test user
UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
# Invalid password
response = self.client.post(self.url, {
"email": self.EMAIL,
@@ -1149,21 +1140,22 @@ class LoginSessionViewTest(ApiTestCase, OpenEdxEventsTestMixin):
})
self.assertHttpBadRequest(response)
def test_missing_login_params(self):
# Create a test user
UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
@ddt.data(True, False)
def test_missing_login_params(self, is_api_v1):
email_field_name = "email" if is_api_v1 else "email_or_username"
url = self.url if is_api_v1 else self.url_v2
# Missing password
response = self.client.post(self.url, {
"email": self.EMAIL,
response = self.client.post(url, {
email_field_name: self.EMAIL,
})
self.assertHttpBadRequest(response)
# Missing email
response = self.client.post(self.url, {
response = self.client.post(url, {
"password": self.PASSWORD,
})
self.assertHttpBadRequest(response)
# Missing both email and password
response = self.client.post(self.url, {})
response = self.client.post(url, {})
self.assertHttpBadRequest(response)