refactor: pyupgrade in profile_images, programs, safe_sessions (#26953)

This commit is contained in:
M. Zulqarnain
2021-03-22 17:51:13 +05:00
committed by GitHub
parent 1c19fbf3b3
commit 91d33611b1
35 changed files with 156 additions and 213 deletions

View File

@@ -3,9 +3,6 @@ Exceptions related to the handling of profile images.
"""
from six import text_type
class ImageValidationError(Exception):
"""
Exception to use when the system rejects a user-supplied source image.
@@ -15,4 +12,4 @@ class ImageValidationError(Exception):
"""
Translate the developer-facing exception message for API clients.
"""
return text_type(self)
return str(self)

View File

@@ -9,7 +9,6 @@ from contextlib import closing
from io import BytesIO
import piexif
import six
from django.conf import settings
from django.core.files.base import ContentFile
from django.utils.translation import ugettext as _
@@ -96,25 +95,25 @@ def validate_uploaded_image(uploaded_file):
if uploaded_file.size > settings.PROFILE_IMAGE_MAX_BYTES: # lint-amnesty, pylint: disable=no-else-raise
file_upload_too_large = _(
u'The file must be smaller than {image_max_size} in size.'
'The file must be smaller than {image_max_size} in size.'
).format(
image_max_size=_user_friendly_size(settings.PROFILE_IMAGE_MAX_BYTES)
)
raise ImageValidationError(file_upload_too_large)
elif uploaded_file.size < settings.PROFILE_IMAGE_MIN_BYTES:
file_upload_too_small = _(
u'The file must be at least {image_min_size} in size.'
'The file must be at least {image_min_size} in size.'
).format(
image_min_size=_user_friendly_size(settings.PROFILE_IMAGE_MIN_BYTES)
)
raise ImageValidationError(file_upload_too_small)
# check the file extension looks acceptable
filename = six.text_type(uploaded_file.name).lower()
filename = str(uploaded_file.name).lower()
filetype = [ft for ft in IMAGE_TYPES if any(filename.endswith(ext) for ext in IMAGE_TYPES[ft].extensions)]
if not filetype:
file_upload_bad_type = _(
u'The file must be one of the following types: {valid_file_types}.'
'The file must be one of the following types: {valid_file_types}.'
).format(valid_file_types=_get_valid_file_types())
raise ImageValidationError(file_upload_bad_type)
filetype = filetype[0]
@@ -122,8 +121,8 @@ def validate_uploaded_image(uploaded_file):
# check mimetype matches expected file type
if uploaded_file.content_type not in IMAGE_TYPES[filetype].mimetypes:
file_upload_bad_mimetype = _(
u'The Content-Type header for this file does not match '
u'the file data. The file may be corrupted.'
'The Content-Type header for this file does not match '
'the file data. The file may be corrupted.'
)
raise ImageValidationError(file_upload_bad_mimetype)
@@ -131,8 +130,8 @@ def validate_uploaded_image(uploaded_file):
headers = IMAGE_TYPES[filetype].magic
if binascii.hexlify(uploaded_file.read(len(headers[0]) // 2)).decode('utf-8') not in headers:
file_upload_bad_ext = _(
u'The file name extension for this file does not match '
u'the file data. The file may be corrupted.'
'The file name extension for this file does not match '
'the file data. The file may be corrupted.'
)
raise ImageValidationError(file_upload_bad_ext)
# avoid unexpected errors from subsequent modules expecting the fp to be at 0
@@ -245,4 +244,4 @@ def _user_friendly_size(size):
while size >= 1024 and i < len(units):
size //= 1024
i += 1
return u'{} {}'.format(size, units[i])
return '{} {}'.format(size, units[i])

View File

@@ -9,7 +9,6 @@ from tempfile import NamedTemporaryFile
from django.core.files.uploadedfile import UploadedFile
import piexif
from PIL import Image
from six.moves import range
@contextmanager

View File

@@ -1,9 +1,10 @@
"""
Test cases for image processing functions in the profile image package.
"""
from contextlib import closing
from unittest import mock
import pytest
from contextlib import closing
from itertools import product
import os
from tempfile import NamedTemporaryFile
@@ -12,10 +13,8 @@ from django.core.files.uploadedfile import UploadedFile
from django.test import TestCase
from django.test.utils import override_settings
import ddt
import mock
import piexif
from PIL import Image
from six import text_type
from openedx.core.djangolib.testing.utils import skip_unless_lms
from ..exceptions import ImageValidationError
@@ -37,7 +36,7 @@ class TestValidateUploadedImage(TestCase):
Test validate_uploaded_image
"""
FILE_UPLOAD_BAD_TYPE = (
u'The file must be one of the following types: {valid_file_types}.'.format(
'The file must be one of the following types: {valid_file_types}.'.format(
valid_file_types=_get_valid_file_types()
)
)
@@ -49,16 +48,16 @@ class TestValidateUploadedImage(TestCase):
if expected_failure_message is not None:
with pytest.raises(ImageValidationError) as ctx:
validate_uploaded_image(uploaded_file)
assert text_type(ctx.value) == expected_failure_message
assert str(ctx.value) == expected_failure_message
else:
validate_uploaded_image(uploaded_file)
assert uploaded_file.tell() == 0
@ddt.data(
(99, u"The file must be at least 100 bytes in size."),
(99, "The file must be at least 100 bytes in size."),
(100, ),
(1024, ),
(1025, u"The file must be smaller than 1 KB in size."),
(1025, "The file must be smaller than 1 KB in size."),
)
@ddt.unpack
@override_settings(PROFILE_IMAGE_MIN_BYTES=100, PROFILE_IMAGE_MAX_BYTES=1024)
@@ -93,8 +92,8 @@ class TestValidateUploadedImage(TestCase):
file data.
"""
file_upload_bad_ext = (
u'The file name extension for this file does not match '
u'the file data. The file may be corrupted.'
'The file name extension for this file does not match '
'the file data. The file may be corrupted.'
)
# make a bmp, try to fool the function into thinking it's a jpeg
with make_image_file(extension=".bmp") as bmp_file:
@@ -108,7 +107,7 @@ class TestValidateUploadedImage(TestCase):
)
with pytest.raises(ImageValidationError) as ctx:
validate_uploaded_image(uploaded_file)
assert text_type(ctx.value) == file_upload_bad_ext
assert str(ctx.value) == file_upload_bad_ext
def test_content_type(self):
"""
@@ -116,13 +115,13 @@ class TestValidateUploadedImage(TestCase):
extension do not match
"""
file_upload_bad_mimetype = (
u'The Content-Type header for this file does not match '
u'the file data. The file may be corrupted.'
'The Content-Type header for this file does not match '
'the file data. The file may be corrupted.'
)
with make_uploaded_file(extension=".jpeg", content_type="image/gif") as uploaded_file:
with pytest.raises(ImageValidationError) as ctx:
validate_uploaded_image(uploaded_file)
assert text_type(ctx.value) == file_upload_bad_mimetype
assert str(ctx.value) == file_upload_bad_mimetype
@ddt.ddt

View File

@@ -1,9 +1,11 @@
"""
Test cases for the HTTP endpoints of the profile image api.
"""
from contextlib import closing
from unittest import mock
from unittest.mock import patch
import pytest
from contextlib import closing
import datetime
from pytz import UTC
@@ -11,8 +13,6 @@ from django.urls import reverse
from django.http import HttpResponse
import ddt
import mock
from mock import patch
from PIL import Image
from rest_framework.test import APITestCase, APIClient
@@ -44,7 +44,7 @@ class ProfileImageEndpointMixin(UserSettingsEventTestMixin):
_view_name = None
def setUp(self):
super(ProfileImageEndpointMixin, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
self.user = UserFactory.create(password=TEST_PASSWORD)
# Ensure that parental controls don't apply to this user
self.user.profile.year_of_birth = 1980
@@ -62,7 +62,7 @@ class ProfileImageEndpointMixin(UserSettingsEventTestMixin):
self.reset_tracker()
def tearDown(self):
super(ProfileImageEndpointMixin, self).tearDown() # lint-amnesty, pylint: disable=super-with-arguments
super().tearDown()
for name in get_profile_image_names(self.user.username).values():
self.storage.delete(name)
@@ -212,7 +212,7 @@ class ProfileImageViewPostTestCase(ProfileImageEndpointMixin, APITestCase):
self.url,
data,
content_type=content_type,
HTTP_CONTENT_DISPOSITION='attachment;filename=filename{}'.format(extension),
HTTP_CONTENT_DISPOSITION=f'attachment;filename=filename{extension}',
)
self.check_response(response, 204)
self.check_images()
@@ -297,8 +297,8 @@ class ProfileImageViewPostTestCase(ProfileImageEndpointMixin, APITestCase):
response = self.client.post(self.url, {}, format='multipart')
self.check_response(
response, 400,
expected_developer_message=u"No file provided for profile image",
expected_user_message=u"No file provided for profile image",
expected_developer_message="No file provided for profile image",
expected_user_message="No file provided for profile image",
)
self.check_images(False)
self.check_has_profile_image(False)
@@ -313,8 +313,8 @@ class ProfileImageViewPostTestCase(ProfileImageEndpointMixin, APITestCase):
response = self.client.post(self.url, {'file': 'not a file'}, format='multipart')
self.check_response(
response, 400,
expected_developer_message=u"No file provided for profile image",
expected_user_message=u"No file provided for profile image",
expected_developer_message="No file provided for profile image",
expected_user_message="No file provided for profile image",
)
self.check_images(False)
self.check_has_profile_image(False)
@@ -329,13 +329,13 @@ class ProfileImageViewPostTestCase(ProfileImageEndpointMixin, APITestCase):
with make_image_file() as image_file:
with mock.patch(
'openedx.core.djangoapps.profile_images.views.validate_uploaded_image',
side_effect=ImageValidationError(u"test error message")
side_effect=ImageValidationError("test error message")
):
response = self.client.post(self.url, {'file': image_file}, format='multipart')
self.check_response(
response, 400,
expected_developer_message=u"test error message",
expected_user_message=u"test error message",
expected_developer_message="test error message",
expected_user_message="test error message",
)
self.check_images(False)
self.check_has_profile_image(False)
@@ -348,7 +348,7 @@ class ProfileImageViewPostTestCase(ProfileImageEndpointMixin, APITestCase):
Test that when upload validation fails, the proper HTTP response and
messages are returned.
"""
image_open.side_effect = [Exception(u"whoops"), None]
image_open.side_effect = [Exception("whoops"), None]
with make_image_file() as image_file:
with pytest.raises(Exception):
self.client.post(self.url, {'file': image_file}, format='multipart')
@@ -367,7 +367,7 @@ class ProfileImageViewDeleteTestCase(ProfileImageEndpointMixin, APITestCase):
_view_name = "accounts_profile_image_api"
def setUp(self):
super(ProfileImageViewDeleteTestCase, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
with make_image_file() as image_file:
create_profile_images(image_file, get_profile_image_names(self.user.username))
self.check_images()
@@ -447,7 +447,7 @@ class ProfileImageViewDeleteTestCase(ProfileImageEndpointMixin, APITestCase):
Test that when remove validation fails, the proper HTTP response and
messages are returned.
"""
user_profile_save.side_effect = [Exception(u"whoops"), None]
user_profile_save.side_effect = [Exception("whoops"), None]
with pytest.raises(Exception):
self.client.delete(self.url)
self.check_images(True) # thumbnails should remain intact.

View File

@@ -16,7 +16,6 @@ from rest_framework import permissions, status
from rest_framework.parsers import FormParser, MultiPartParser
from rest_framework.response import Response
from rest_framework.views import APIView
from six import text_type
from openedx.core.djangoapps.user_api.accounts.image_helpers import get_profile_image_names, set_has_profile_image
from openedx.core.djangoapps.user_api.errors import UserNotFound
@@ -30,8 +29,8 @@ from .images import IMAGE_TYPES, create_profile_images, remove_profile_images, v
log = logging.getLogger(__name__)
LOG_MESSAGE_CREATE = u'Generated and uploaded images %(image_names)s for user %(user_id)s'
LOG_MESSAGE_DELETE = u'Deleted images %(image_names)s for user %(user_id)s'
LOG_MESSAGE_CREATE = 'Generated and uploaded images %(image_names)s for user %(user_id)s'
LOG_MESSAGE_DELETE = 'Deleted images %(image_names)s for user %(user_id)s'
def _make_upload_dt():
@@ -133,8 +132,8 @@ class ProfileImageView(DeveloperErrorViewMixin, APIView):
if 'file' not in request.FILES:
return Response(
{
"developer_message": u"No file provided for profile image",
"user_message": _(u"No file provided for profile image"),
"developer_message": "No file provided for profile image",
"user_message": _("No file provided for profile image"),
},
status=status.HTTP_400_BAD_REQUEST
@@ -151,7 +150,7 @@ class ProfileImageView(DeveloperErrorViewMixin, APIView):
validate_uploaded_image(uploaded_file)
except ImageValidationError as error:
return Response(
{"developer_message": text_type(error), "user_message": error.user_message},
{"developer_message": str(error), "user_message": error.user_message},
status=status.HTTP_400_BAD_REQUEST,
)

View File

@@ -10,7 +10,7 @@ class ProgramsConfig(AppConfig):
"""
Default configuration for the "openedx.core.djangoapps.programs" Django application.
"""
name = u'openedx.core.djangoapps.programs'
name = 'openedx.core.djangoapps.programs'
def ready(self):
# noinspection PyUnresolvedReferences

View File

@@ -84,10 +84,10 @@ class Command(BaseCommand):
self._load_usernames(users=usernames)
if options.get('commit'):
logger.info(u'Enqueuing program certification tasks for %d candidates.', len(self.usernames))
logger.info('Enqueuing program certification tasks for %d candidates.', len(self.usernames))
else:
logger.info(
u'Found %d candidates. To enqueue program certification tasks, pass the -c or --commit flags.',
'Found %d candidates. To enqueue program certification tasks, pass the -c or --commit flags.',
len(self.usernames)
)
return
@@ -98,14 +98,14 @@ class Command(BaseCommand):
award_program_certificates.delay(username)
except: # pylint: disable=bare-except
failed += 1
logger.exception(u'Failed to enqueue task for user [%s]', username)
logger.exception('Failed to enqueue task for user [%s]', username)
else:
succeeded += 1
logger.debug(u'Successfully enqueued task for user [%s]', username)
logger.debug('Successfully enqueued task for user [%s]', username)
logger.info(
u'Done. Successfully enqueued tasks for %d candidates. '
u'Failed to enqueue tasks for %d candidates.',
'Done. Successfully enqueued tasks for %d candidates. '
'Failed to enqueue tasks for %d candidates.',
succeeded,
failed
)
@@ -117,7 +117,7 @@ class Command(BaseCommand):
programs.extend(get_programs(uuids=program_uuids))
else:
for site in Site.objects.all():
logger.info(u'Loading programs from the catalog for site %s.', site.domain)
logger.info('Loading programs from the catalog for site %s.', site.domain)
programs.extend(get_programs(site))
self.course_runs = self._flatten(programs)

View File

@@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models

View File

@@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
from django.db import migrations, models

View File

@@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
from django.db import migrations, models

View File

@@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
from django.db import migrations, models

View File

@@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
from django.db import migrations, models

View File

@@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
from django.db import migrations, models

View File

@@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
from django.db import migrations, models

View File

@@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
from django.db import migrations, models

View File

@@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
from django.db import migrations, models

View File

@@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
from django.db import migrations, models

View File

@@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
from django.db import migrations, models

View File

@@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
from django.db import migrations, models

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.26 on 2019-12-13 07:44

View File

@@ -1,7 +1,4 @@
"""Models providing Programs support for the LMS and Studio."""
import six
from config_models.models import ConfigurationModel
from django.db import models
from django.utils.translation import ugettext_lazy as _
@@ -15,7 +12,7 @@ class ProgramsApiConfig(ConfigurationModel):
.. no_pii:
"""
class Meta(object):
class Meta:
app_label = "programs"
marketing_path = models.CharField(
@@ -31,7 +28,7 @@ class CustomProgramsConfig(ConfigurationModel): # pylint: disable=model-missing
"""
Manages configuration for a run of the backpopulate_program_credentials management command.
"""
class Meta(object):
class Meta:
app_label = 'programs'
verbose_name = 'backpopulate_program_credentials argument'
@@ -42,4 +39,4 @@ class CustomProgramsConfig(ConfigurationModel): # pylint: disable=model-missing
)
def __str__(self):
return six.text_type(self.arguments)
return str(self.arguments)

View File

@@ -51,7 +51,7 @@ def handle_course_cert_awarded(sender, user, course_key, mode, status, **kwargs)
# schedule background task to process
LOGGER.debug(
u'handling COURSE_CERT_AWARDED: username=%s, course_key=%s, mode=%s, status=%s',
'handling COURSE_CERT_AWARDED: username=%s, course_key=%s, mode=%s, status=%s',
user,
course_key,
mode,
@@ -89,13 +89,13 @@ def handle_course_cert_changed(sender, user, course_key, mode, status, **kwargs)
verbose = kwargs.get('verbose', False)
if verbose:
msg = u"Starting handle_course_cert_changed with params: "\
u"sender [{sender}], "\
u"user [{username}], "\
u"course_key [{course_key}], "\
u"mode [{mode}], "\
u"status [{status}], "\
u"kwargs [{kw}]"\
msg = "Starting handle_course_cert_changed with params: "\
"sender [{sender}], "\
"user [{username}], "\
"course_key [{course_key}], "\
"mode [{mode}], "\
"status [{status}], "\
"kwargs [{kw}]"\
.format(
sender=sender,
username=getattr(user, 'username', None),
@@ -118,7 +118,7 @@ def handle_course_cert_changed(sender, user, course_key, mode, status, **kwargs)
if not is_learner_records_enabled_for_org(course_key.org):
if verbose:
LOGGER.info(
u"Skipping send cert: ENABLE_LEARNER_RECORDS False for org [{org}]".format(
"Skipping send cert: ENABLE_LEARNER_RECORDS False for org [{org}]".format(
org=course_key.org
)
)
@@ -126,7 +126,7 @@ def handle_course_cert_changed(sender, user, course_key, mode, status, **kwargs)
# schedule background task to process
LOGGER.debug(
u'handling COURSE_CERT_CHANGED: username=%s, course_key=%s, mode=%s, status=%s',
'handling COURSE_CERT_CHANGED: username=%s, course_key=%s, mode=%s, status=%s',
user,
course_key,
mode,
@@ -169,7 +169,7 @@ def handle_course_cert_revoked(sender, user, course_key, mode, status, **kwargs)
# schedule background task to process
LOGGER.info(
u'handling COURSE_CERT_REVOKED: username=%s, course_key=%s, mode=%s, status=%s',
'handling COURSE_CERT_REVOKED: username=%s, course_key=%s, mode=%s, status=%s',
user,
course_key,
mode,

View File

@@ -366,7 +366,7 @@ def award_course_certificate(self, username, course_run_key, certificate_availab
if certificate.mode in CourseMode.CERTIFICATE_RELEVANT_MODES:
try:
course_overview = CourseOverview.get_from_id(course_key)
except (CourseOverview.DoesNotExist, IOError):
except (CourseOverview.DoesNotExist, OSError):
LOGGER.exception(
f"Task award_course_certificate was called without course overview data for course {course_key}"
)

View File

@@ -6,7 +6,7 @@ import factory
class ProgressFactory(factory.Factory):
class Meta(object):
class Meta:
model = dict
uuid = factory.Faker('uuid4')

View File

@@ -4,7 +4,7 @@
from openedx.core.djangoapps.programs.models import ProgramsApiConfig
class ProgramsApiConfigMixin(object):
class ProgramsApiConfigMixin:
"""Utilities for working with Programs configuration during testing."""
DEFAULTS = {

View File

@@ -1,12 +1,11 @@
"""Tests for the backpopulate_program_credentials management command."""
from unittest import mock
import ddt
import mock
from django.core.management import call_command
from django.test import TestCase
from opaque_keys.edx.keys import CourseKey
from six.moves import range
from common.djangoapps.course_modes.models import CourseMode
from lms.djangoapps.certificates.api import MODES
@@ -42,7 +41,7 @@ class BackpopulateProgramCredentialsTests(CatalogIntegrationMixin, CredentialsAp
SAME_COURSE = 'same_course'
def setUp(self):
super(BackpopulateProgramCredentialsTests, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
self.alice = UserFactory()
self.bob = UserFactory()

View File

@@ -3,7 +3,8 @@ This module contains tests for programs-related signals and signal handlers.
"""
import datetime
import mock
from unittest import mock
from django.test import TestCase
from opaque_keys.edx.keys import CourseKey
@@ -102,7 +103,7 @@ class CertChangedReceiverTest(TestCase):
"""
def setUp(self):
super(CertChangedReceiverTest, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
self.user = UserFactory.create(username=TEST_USERNAME)
@property

View File

@@ -6,10 +6,11 @@ Tests for programs celery tasks.
import json
import logging
from datetime import datetime, timedelta
from unittest import mock
import pytest
import ddt
import httpretty
import mock
import pytz
from celery.exceptions import MaxRetriesExceededError
from django.conf import settings
@@ -125,7 +126,7 @@ class AwardProgramCertificatesTestCase(CatalogIntegrationMixin, CredentialsApiCo
"""
def setUp(self):
super(AwardProgramCertificatesTestCase, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
self.create_credentials_config()
self.student = UserFactory.create(username='test-student')
self.site = SiteFactory()
@@ -228,7 +229,7 @@ class AwardProgramCertificatesTestCase(CatalogIntegrationMixin, CredentialsApiCo
Checks that the task is aborted if any relevant api configs are
disabled.
"""
getattr(self, 'create_{}_config'.format(disabled_config_type))(**{disabled_config_attribute: False})
getattr(self, f'create_{disabled_config_type}_config')(**{disabled_config_attribute: False})
with mock.patch(TASKS_MODULE + '.LOGGER.warning') as mock_warning:
with pytest.raises(MaxRetriesExceededError):
tasks.award_program_certificates.delay(self.student.username).get()
@@ -348,7 +349,7 @@ class AwardProgramCertificatesTestCase(CatalogIntegrationMixin, CredentialsApiCo
assert mock_award_program_certificate.call_count == 3
mock_warning.assert_called_once_with(
u'Failed to award certificate for program {uuid} to user {username}.'.format(
'Failed to award certificate for program {uuid} to user {username}.'.format(
uuid=1,
username=self.student.username)
)
@@ -512,7 +513,7 @@ class AwardCourseCertificatesTestCase(CredentialsApiConfigMixin, TestCase):
"""
def setUp(self):
super(AwardCourseCertificatesTestCase, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
self.available_date = datetime.now(pytz.UTC) + timedelta(days=1)
self.course = CourseOverviewFactory.create(
@@ -663,7 +664,7 @@ class RevokeProgramCertificatesTestCase(CatalogIntegrationMixin, CredentialsApiC
"""
def setUp(self):
super(RevokeProgramCertificatesTestCase, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
self.student = UserFactory.create(username='test-student')
self.course_key = 'course-v1:testX+test101+2T2020'
@@ -724,7 +725,7 @@ class RevokeProgramCertificatesTestCase(CatalogIntegrationMixin, CredentialsApiC
Checks that the task is aborted if any relevant api configs are
disabled.
"""
getattr(self, 'create_{}_config'.format(disabled_config_type))(**{disabled_config_attribute: False})
getattr(self, f'create_{disabled_config_type}_config')(**{disabled_config_attribute: False})
with mock.patch(TASKS_MODULE + '.LOGGER.warning') as mock_warning:
with pytest.raises(MaxRetriesExceededError):
tasks.revoke_program_certificates.delay(self.student.username, self.course_key).get()
@@ -802,7 +803,7 @@ class RevokeProgramCertificatesTestCase(CatalogIntegrationMixin, CredentialsApiC
assert mock_revoke_program_certificate.call_count == 3
mock_warning.assert_called_once_with(
u'Failed to revoke certificate for program {uuid} of user {username}.'.format(
'Failed to revoke certificate for program {uuid} of user {username}.'.format(
uuid=1,
username=self.student.username)
)

View File

@@ -6,13 +6,11 @@ import json
import uuid
from collections import namedtuple
from copy import deepcopy
from unittest import mock
import ddt
from edx_toggles.toggles.testutils import override_waffle_switch
import httpretty
import mock
import six
from six.moves import range
from django.conf import settings
from django.test import TestCase
from django.test.utils import override_settings
@@ -69,7 +67,7 @@ class TestProgramProgressMeter(TestCase):
"""Tests of the program progress utility class."""
def setUp(self):
super(TestProgramProgressMeter, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
self.user = UserFactory()
self.site = SiteFactory()
@@ -651,8 +649,7 @@ class TestProgramProgressMeter(TestCase):
self._create_certificates(unknown['key'], status='unknown')
meter = ProgramProgressMeter(self.site, self.user)
six.assertCountEqual(
self,
self.assertCountEqual(
meter.completed_course_runs,
[
{'course_run_id': downloadable['key'], 'type': CourseMode.VERIFIED},
@@ -725,7 +722,7 @@ class TestProgramProgressMeter(TestCase):
program_data = meter.engaged_programs[0]
detail_fragment_url = reverse('program_details_fragment_view', kwargs={'program_uuid': program_data['uuid']})
path_id = detail_fragment_url.replace('/dashboard/', '')
expected_url = 'edxapp://enrolled_program_info?path_id={}'.format(path_id)
expected_url = f'edxapp://enrolled_program_info?path_id={path_id}'
assert program_data['detail_url'] == expected_url
@@ -746,7 +743,7 @@ def _create_course(self, course_price, course_run_count=1, make_entitlement=Fals
course.instructor_info = self.instructors
course = self.update_course(course, self.user.id)
run = CourseRunFactory(key=six.text_type(course.id), seats=[SeatFactory(price=course_price)])
run = CourseRunFactory(key=str(course.id), seats=[SeatFactory(price=course_price)])
course_runs.append(run)
entitlements = [EntitlementFactory()] if make_entitlement else []
@@ -775,14 +772,14 @@ class TestProgramDataExtender(ModuleStoreTestCase):
}
def setUp(self):
super(TestProgramDataExtender, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
self.course = ModuleStoreCourseFactory()
self.course.start = datetime.datetime.now(utc) - datetime.timedelta(days=1)
self.course.end = datetime.datetime.now(utc) + datetime.timedelta(days=1)
self.course = self.update_course(self.course, self.user.id)
self.course_run = CourseRunFactory(key=six.text_type(self.course.id))
self.course_run = CourseRunFactory(key=str(self.course.id))
self.catalog_course = CourseFactory(course_runs=[self.course_run])
self.program = ProgramFactory(courses=[self.catalog_course])
self.course_price = 100
@@ -1081,7 +1078,7 @@ class TestProgramDataExtender(ModuleStoreTestCase):
"""
course1 = _create_course(self, self.course_price, course_run_count=2, make_entitlement=True)
course2 = _create_course(self, self.course_price, course_run_count=2, make_entitlement=True)
expected_skus = set([course1['entitlements'][0]['sku'], course2['entitlements'][0]['sku']])
expected_skus = {course1['entitlements'][0]['sku'], course2['entitlements'][0]['sku']}
program = ProgramFactory(
courses=[course1, course2],
is_program_eligible_for_one_click_purchase=True,
@@ -1114,7 +1111,7 @@ class TestProgramDataExtender(ModuleStoreTestCase):
course1 = _create_course(self, self.course_price, course_run_count=2, make_entitlement=True)
course2 = _create_course(self, self.course_price, course_run_count=2, make_entitlement=True)
CourseEntitlementFactory(user=self.user, course_uuid=course1['uuid'], mode=CourseMode.VERIFIED)
expected_skus = set([course2['entitlements'][0]['sku']])
expected_skus = {course2['entitlements'][0]['sku']}
program = ProgramFactory(
courses=[course1, course2],
is_program_eligible_for_one_click_purchase=True,
@@ -1155,7 +1152,7 @@ class TestProgramDataExtender(ModuleStoreTestCase):
ProgramDataExtender(program_data, self.user).extend()
logger.check(
(LOGGER_NAME,
'WARNING', u'Failed to get course overview for course run key: {}'.format(course_run.get('key')))
'WARNING', 'Failed to get course overview for course run key: {}'.format(course_run.get('key')))
)
def test_entitlement_product_wrong_mode(self):
@@ -1185,7 +1182,7 @@ class TestProgramDataExtender(ModuleStoreTestCase):
# The above statement makes a verified entitlement for the course, which is an applicable seat type
# and the statement below makes a professional entitlement for the same course, which is not applicable
course2['entitlements'].append(EntitlementFactory(mode=CourseMode.PROFESSIONAL))
expected_skus = set([course1['course_runs'][0]['seats'][0]['sku'], course2['entitlements'][0]['sku']])
expected_skus = {course1['course_runs'][0]['seats'][0]['sku'], course2['entitlements'][0]['sku']}
program = ProgramFactory(
courses=[course1, course2],
is_program_eligible_for_one_click_purchase=True,
@@ -1202,7 +1199,7 @@ class TestProgramDataExtender(ModuleStoreTestCase):
"""
course1 = _create_course(self, self.course_price, make_entitlement=True)
course2 = _create_course(self, self.course_price)
expected_skus = set([course2['course_runs'][0]['seats'][0]['sku']])
expected_skus = {course2['course_runs'][0]['seats'][0]['sku']}
CourseEnrollmentFactory(user=self.user, course_id=course1['course_runs'][0]['key'], mode=CourseMode.VERIFIED)
program = ProgramFactory(
courses=[course1, course2],
@@ -1221,7 +1218,7 @@ class TestProgramDataExtender(ModuleStoreTestCase):
course1 = _create_course(self, self.course_price, course_run_count=2)
course2 = _create_course(self, self.course_price, course_run_count=2, make_entitlement=True)
CourseEnrollmentFactory(user=self.user, course_id=course1['course_runs'][0]['key'], mode=CourseMode.VERIFIED)
expected_skus = set([course2['entitlements'][0]['sku']])
expected_skus = {course2['entitlements'][0]['sku']}
program = ProgramFactory(
courses=[course1, course2],
is_program_eligible_for_one_click_purchase=True,
@@ -1236,7 +1233,7 @@ class TestProgramDataExtender(ModuleStoreTestCase):
Verify that correct course url is returned for mobile.
"""
data = ProgramDataExtender(self.program, self.user, mobile_only=True).extend()
expected_course_url = 'edxapp://enrolled_course_info?course_id={}'.format(self.course.id)
expected_course_url = f'edxapp://enrolled_course_info?course_id={self.course.id}'
self._assert_supplemented(data, course_url=expected_course_url)
@@ -1247,7 +1244,7 @@ class TestGetCertificates(TestCase):
Tests of the function used to get certificates associated with a program.
"""
def setUp(self):
super(TestGetCertificates, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
self.user = UserFactory()
self.program = ProgramFactory()
@@ -1351,7 +1348,7 @@ class TestGetCertificates(TestCase):
@skip_unless_lms
class TestProgramMarketingDataExtender(ModuleStoreTestCase):
"""Tests of the program data extender utility class."""
ECOMMERCE_CALCULATE_DISCOUNT_ENDPOINT = '{root}/api/v2/baskets/calculate/'.format(root=ECOMMERCE_URL_ROOT)
ECOMMERCE_CALCULATE_DISCOUNT_ENDPOINT = f'{ECOMMERCE_URL_ROOT}/api/v2/baskets/calculate/'
instructors = {
'instructors': [
{
@@ -1366,7 +1363,7 @@ class TestProgramMarketingDataExtender(ModuleStoreTestCase):
}
def setUp(self):
super(TestProgramMarketingDataExtender, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
# Ensure the E-Commerce service user exists
UserFactory(username=settings.ECOMMERCE_SERVICE_WORKER_USERNAME, is_staff=True)
@@ -1424,7 +1421,7 @@ class TestProgramMarketingDataExtender(ModuleStoreTestCase):
for __ in range(3):
course = ModuleStoreCourseFactory()
course = self.update_course(course, self.user.id)
course_runs.append(CourseRunFactory(key=six.text_type(course.id), seats=[]))
course_runs.append(CourseRunFactory(key=str(course.id), seats=[]))
program = ProgramFactory(courses=[CourseFactory(course_runs=course_runs)])
data = ProgramMarketingDataExtender(program, self.user).extend()
@@ -1492,7 +1489,7 @@ class TestProgramMarketingDataExtender(ModuleStoreTestCase):
content_type='application/json'
)
ProgramMarketingDataExtender(self.program, self.user).extend()
assert httpretty.last_request().querystring.get('is_anonymous')[0] == u'True' # lint-amnesty, pylint: disable=no-member, line-too-long
assert httpretty.last_request().querystring.get('is_anonymous')[0] == 'True' # lint-amnesty, pylint: disable=no-member, line-too-long
@httpretty.activate
def test_fetching_program_discounted_price_as_anonymous_user(self):

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""Helper functions for working with Programs."""
@@ -7,8 +6,8 @@ import logging
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from urllib.parse import urljoin, urlparse, urlunparse
import six
from dateutil.parser import parse
from django.conf import settings
from django.contrib.auth import get_user_model
@@ -20,7 +19,6 @@ from edx_rest_api_client.exceptions import SlumberBaseException
from opaque_keys.edx.keys import CourseKey
from pytz import utc
from requests.exceptions import ConnectionError, Timeout # lint-amnesty, pylint: disable=redefined-builtin
from six.moves.urllib.parse import urljoin, urlparse, urlunparse # pylint: disable=import-error
from common.djangoapps.course_modes.api import get_paid_modes_for_course
from common.djangoapps.course_modes.models import CourseMode
@@ -77,7 +75,7 @@ def attach_program_detail_url(programs, mobile_only=False):
if mobile_only:
detail_fragment_url = reverse('program_details_fragment_view', kwargs={'program_uuid': program['uuid']})
path_id = detail_fragment_url.replace('/dashboard/', '')
detail_url = 'edxapp://enrolled_program_info?path_id={path_id}'.format(path_id=path_id)
detail_url = f'edxapp://enrolled_program_info?path_id={path_id}'
else:
detail_url = reverse('program_details_view', kwargs={'program_uuid': program['uuid']})
@@ -86,7 +84,7 @@ def attach_program_detail_url(programs, mobile_only=False):
return programs
class ProgramProgressMeter(object):
class ProgramProgressMeter:
"""Utility for gauging a user's progress towards program completion.
Arguments:
@@ -110,7 +108,7 @@ class ProgramProgressMeter(object):
self.course_run_ids = []
for enrollment in self.enrollments:
# enrollment.course_id is really a CourseKey (╯ಠ_ಠ╯︵ ┻━┻
enrollment_id = six.text_type(enrollment.course_id)
enrollment_id = str(enrollment.course_id)
mode = enrollment.mode
if mode == CourseMode.NO_ID_PROFESSIONAL_MODE:
mode = CourseMode.PROFESSIONAL
@@ -153,7 +151,7 @@ class ProgramProgressMeter(object):
program_list.append(program)
# Sort programs by title for consistent presentation.
for program_list in six.itervalues(inverted_programs):
for program_list in inverted_programs.values():
program_list.sort(key=lambda p: p['title'])
return inverted_programs
@@ -436,7 +434,7 @@ class ProgramProgressMeter(object):
completed_runs, failed_runs = [], []
for certificate in course_run_certificates:
course_data = {
'course_run_id': six.text_type(certificate['course_key']),
'course_run_id': str(certificate['course_key']),
'type': self._certificate_mode_translation(certificate['type']),
}
@@ -463,7 +461,7 @@ class ProgramProgressMeter(object):
# pylint: disable=missing-docstring
class ProgramDataExtender(object):
class ProgramDataExtender:
"""
Utility for extending program data meant for the program detail page with
user-specific (e.g., CourseEnrollment) data.
@@ -509,7 +507,7 @@ class ProgramDataExtender(object):
try:
self.course_overview = CourseOverview.get_from_id(self.course_run_key)
except CourseOverview.DoesNotExist:
log.warning(u'Failed to get course overview for course run key: %s', course_run.get('key'))
log.warning('Failed to get course overview for course run key: %s', course_run.get('key'))
else:
self.enrollment_start = self.course_overview.enrollment_start or DEFAULT_ENROLLMENT_START_DATE
@@ -592,7 +590,7 @@ class ProgramDataExtender(object):
Returns:
A subset of the given list of course dicts
"""
course_uuids = set(course['uuid'] for course in courses)
course_uuids = {course['uuid'] for course in courses}
# Filter the entitlements' modes with a case-insensitive match against applicable seat_types
entitlements = self.user.courseentitlement_set.filter(
mode__in=self.data['applicable_seat_types'],
@@ -601,7 +599,7 @@ class ProgramDataExtender(object):
# Here we check the entitlements' expired_at_datetime property rather than filter by the expired_at attribute
# to ensure that the expiration status is as up to date as possible
entitlements = [e for e in entitlements if not e.expired_at_datetime]
courses_with_entitlements = set(six.text_type(entitlement.course_uuid) for entitlement in entitlements)
courses_with_entitlements = {str(entitlement.course_uuid) for entitlement in entitlements}
return [course for course in courses if course['uuid'] not in courses_with_entitlements]
def _filter_out_courses_with_enrollments(self, courses):
@@ -619,10 +617,10 @@ class ProgramDataExtender(object):
is_active=True,
mode__in=self.data['applicable_seat_types']
)
course_runs_with_enrollments = set(six.text_type(enrollment.course_id) for enrollment in enrollments)
course_runs_with_enrollments = {str(enrollment.course_id) for enrollment in enrollments}
courses_without_enrollments = []
for course in courses:
if all(six.text_type(run['key']) not in course_runs_with_enrollments for run in course['course_runs']):
if all(str(run['key']) not in course_runs_with_enrollments for run in course['course_runs']):
courses_without_enrollments.append(course)
return courses_without_enrollments
@@ -634,7 +632,7 @@ class ProgramDataExtender(object):
"""
if 'professional' in self.data['applicable_seat_types']:
self.data['applicable_seat_types'].append('no-id-professional')
applicable_seat_types = set(seat for seat in self.data['applicable_seat_types'] if seat != 'credit')
applicable_seat_types = {seat for seat in self.data['applicable_seat_types'] if seat != 'credit'}
is_learner_eligible_for_one_click_purchase = self.data['is_program_eligible_for_one_click_purchase']
bundle_uuid = self.data.get('uuid')
@@ -712,7 +710,7 @@ class ProgramDataExtender(object):
'variant': bundle_variant
})
except (ConnectionError, SlumberBaseException, Timeout):
log.exception(u'Failed to get discount price for following product SKUs: %s ', ', '.join(skus))
log.exception('Failed to get discount price for following product SKUs: %s ', ', '.join(skus))
self.data.update({
'discount_data': {'is_discounted': False}
})
@@ -791,7 +789,7 @@ class ProgramMarketingDataExtender(ProgramDataExtender):
user (User): The user whose enrollments to inspect.
"""
def __init__(self, program_data, user):
super(ProgramMarketingDataExtender, self).__init__(program_data, user) # lint-amnesty, pylint: disable=super-with-arguments
super().__init__(program_data, user)
# Aggregate list of instructors for the program keyed by name
self.instructors = []
@@ -843,7 +841,7 @@ class ProgramMarketingDataExtender(ProgramDataExtender):
def extend(self):
"""Execute extension handlers, returning the extended data."""
self.data.update(super(ProgramMarketingDataExtender, self).extend()) # lint-amnesty, pylint: disable=super-with-arguments
self.data.update(super().extend())
return self.data
@classmethod
@@ -869,7 +867,7 @@ class ProgramMarketingDataExtender(ProgramDataExtender):
def _attach_course_run_upgrade_url(self, run_mode):
if not self.user.is_anonymous:
super(ProgramMarketingDataExtender, self)._attach_course_run_upgrade_url(run_mode) # lint-amnesty, pylint: disable=super-with-arguments
super()._attach_course_run_upgrade_url(run_mode)
else:
run_mode['upgrade_url'] = None

View File

@@ -66,7 +66,6 @@ from contextlib import contextmanager
from hashlib import sha256
from logging import ERROR, getLogger
import six
from django.conf import settings
from django.contrib.auth import SESSION_KEY
from django.contrib.auth.views import redirect_to_login
@@ -78,8 +77,6 @@ from django.utils.deprecation import MiddlewareMixin
from django.utils.encoding import python_2_unicode_compatible
from edx_django_utils.monitoring import set_custom_attribute
from six import text_type # pylint: disable=ungrouped-imports
from openedx.core.lib.mobile_utils import is_request_from_mobile_app
log = getLogger(__name__)
@@ -90,19 +87,19 @@ class SafeCookieError(Exception):
An exception class for safe cookie related errors.
"""
def __init__(self, error_message):
super(SafeCookieError, self).__init__(error_message) # lint-amnesty, pylint: disable=super-with-arguments
super().__init__(error_message)
log.error(error_message)
@python_2_unicode_compatible
class SafeCookieData(object):
class SafeCookieData:
"""
Cookie data that cryptographically binds and timestamps the user
to the session id. It verifies the freshness of the cookie by
checking its creation date using settings.SESSION_COOKIE_AGE.
"""
CURRENT_VERSION = '1'
SEPARATOR = u"|"
SEPARATOR = "|"
def __init__(self, version, session_id, key_salt, signature):
"""
@@ -154,16 +151,16 @@ class SafeCookieData(object):
safe_cookie_string.
"""
try:
raw_cookie_components = six.text_type(safe_cookie_string).split(cls.SEPARATOR)
raw_cookie_components = str(safe_cookie_string).split(cls.SEPARATOR)
safe_cookie_data = SafeCookieData(*raw_cookie_components)
except TypeError:
raise SafeCookieError( # lint-amnesty, pylint: disable=raise-missing-from
u"SafeCookieData BWC parse error: {0!r}.".format(safe_cookie_string)
f"SafeCookieData BWC parse error: {safe_cookie_string!r}."
)
else:
if safe_cookie_data.version != cls.CURRENT_VERSION:
raise SafeCookieError(
u"SafeCookieData version {0!r} is not supported. Current version is {1}.".format(
"SafeCookieData version {!r} is not supported. Current version is {}.".format(
safe_cookie_data.version,
cls.CURRENT_VERSION,
))
@@ -194,12 +191,12 @@ class SafeCookieData(object):
unsigned_data = signing.loads(self.signature, salt=self.key_salt, max_age=settings.SESSION_COOKIE_AGE)
if unsigned_data == self._compute_digest(user_id):
return True
log.error(u"SafeCookieData '%r' is not bound to user '%s'.", six.text_type(self), user_id)
log.error("SafeCookieData '%r' is not bound to user '%s'.", str(self), user_id)
except signing.BadSignature as sig_error:
log.error(
u"SafeCookieData signature error for cookie data {0!r}: {1}".format( # pylint: disable=logging-format-interpolation
six.text_type(self),
text_type(sig_error),
"SafeCookieData signature error for cookie data {!r}: {}".format( # pylint: disable=logging-format-interpolation
str(self),
str(sig_error),
)
)
return False
@@ -210,8 +207,8 @@ class SafeCookieData(object):
"""
hash_func = sha256()
for data_item in [self.version, self.session_id, user_id]:
hash_func.update(six.b(six.text_type(data_item)))
hash_func.update(six.b('|'))
hash_func.update(str(data_item).encode())
hash_func.update(b'|')
return hash_func.hexdigest()
@staticmethod
@@ -224,10 +221,10 @@ class SafeCookieData(object):
# Compare against unicode(None) as well since the 'value'
# property of a cookie automatically serializes None to a
# string.
if not session_id or session_id == six.text_type(None):
if not session_id or session_id == str(None):
# The session ID should always be valid in the cookie.
raise SafeCookieError(
u"SafeCookieData not created due to invalid value for session_id '{}' for user_id '{}'.".format(
"SafeCookieData not created due to invalid value for session_id '{}' for user_id '{}'.".format(
session_id,
user_id,
))
@@ -238,7 +235,7 @@ class SafeCookieData(object):
# as some of the session requests are made as
# Anonymous users.
log.debug(
u"SafeCookieData received empty user_id '%s' for session_id '%s'.",
"SafeCookieData received empty user_id '%s' for session_id '%s'.",
user_id,
session_id,
)
@@ -287,7 +284,7 @@ class SafeSessionMiddleware(SessionMiddleware, MiddlewareMixin):
else:
request.COOKIES[settings.SESSION_COOKIE_NAME] = safe_cookie_data.session_id # Step 2
process_request_response = super(SafeSessionMiddleware, self).process_request(request) # Step 3 # lint-amnesty, pylint: disable=assignment-from-no-return, super-with-arguments
process_request_response = super().process_request(request) # Step 3 # lint-amnesty, pylint: disable=assignment-from-no-return, super-with-arguments
if process_request_response:
# The process_request pipeline has been short circuited so
# return the response.
@@ -326,7 +323,7 @@ class SafeSessionMiddleware(SessionMiddleware, MiddlewareMixin):
Step 4. Delete the cookie, if it's marked for deletion.
"""
response = super(SafeSessionMiddleware, self).process_response(request, response) # Step 1 # lint-amnesty, pylint: disable=super-with-arguments
response = super().process_response(request, response) # Step 1
if not _is_cookie_marked_for_deletion(request) and _is_cookie_present(response):
try:
@@ -441,7 +438,7 @@ class SafeSessionMiddleware(SessionMiddleware, MiddlewareMixin):
)
# Update the cookie's value with the safe_cookie_data.
cookies[settings.SESSION_COOKIE_NAME] = six.text_type(safe_cookie_data)
cookies[settings.SESSION_COOKIE_NAME] = str(safe_cookie_data)
def _mark_cookie_for_deletion(request):
@@ -488,14 +485,14 @@ def _delete_cookie(request, response):
# malicious gets directly dumped into the log.
cookie_header = request.META.get('HTTP_COOKIE', '')[:4096]
log.warning(
u"Malformed Cookie Header? First 4K, in Base64: %s",
b64encode(six.b(cookie_header))
"Malformed Cookie Header? First 4K, in Base64: %s",
b64encode(str(cookie_header).encode())
)
# Note, there is no request.user attribute at this point.
if hasattr(request, 'session') and hasattr(request.session, 'session_key'):
log.warning(
u"SafeCookieData deleted session cookie for session %s",
"SafeCookieData deleted session cookie for session %s",
request.session.session_key
)

View File

@@ -2,9 +2,9 @@
Unit tests for SafeSessionMiddleware
"""
from unittest.mock import patch
import ddt
import six
from crum import set_current_request
from django.conf import settings
from django.contrib.auth import SESSION_KEY
@@ -12,7 +12,6 @@ from django.contrib.auth.models import AnonymousUser
from django.http import HttpResponse, HttpResponseRedirect, SimpleCookie
from django.test import TestCase
from django.test.utils import override_settings
from mock import patch
from openedx.core.djangolib.testing.utils import get_mock_request
from common.djangoapps.student.tests.factories import UserFactory
@@ -27,7 +26,7 @@ class TestSafeSessionProcessRequest(TestSafeSessionsLogMixin, TestCase):
"""
def setUp(self):
super(TestSafeSessionProcessRequest, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
self.user = UserFactory.create()
self.addCleanup(set_current_request, None)
self.request = get_mock_request()
@@ -44,7 +43,7 @@ class TestSafeSessionProcessRequest(TestSafeSessionsLogMixin, TestCase):
Else, verifies a failed response with an HTTP redirect.
"""
if safe_cookie_data:
self.request.COOKIES[settings.SESSION_COOKIE_NAME] = six.text_type(safe_cookie_data)
self.request.COOKIES[settings.SESSION_COOKIE_NAME] = str(safe_cookie_data)
response = SafeSessionMiddleware().process_request(self.request)
if success:
assert response is None
@@ -127,7 +126,7 @@ class TestSafeSessionProcessResponse(TestSafeSessionsLogMixin, TestCase):
"""
def setUp(self):
super(TestSafeSessionProcessResponse, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
self.user = UserFactory.create()
self.addCleanup(set_current_request, None)
self.request = get_mock_request()
@@ -232,7 +231,7 @@ class TestSafeSessionMiddleware(TestSafeSessionsLogMixin, TestCase):
"""
def setUp(self):
super(TestSafeSessionMiddleware, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
self.user = UserFactory.create()
self.addCleanup(set_current_request, None)
self.request = get_mock_request()
@@ -258,7 +257,7 @@ class TestSafeSessionMiddleware(TestSafeSessionsLogMixin, TestCase):
session_id = self.client.session.session_key
safe_cookie_data = SafeCookieData.create(session_id, self.user.id)
self.request.COOKIES[settings.SESSION_COOKIE_NAME] = six.text_type(safe_cookie_data)
self.request.COOKIES[settings.SESSION_COOKIE_NAME] = str(safe_cookie_data)
with self.assert_not_logged():
response = SafeSessionMiddleware().process_request(self.request)

View File

@@ -6,12 +6,11 @@ Unit tests for SafeCookieData
import itertools
from time import time
from unittest.mock import patch
import pytest
import ddt
import six
from django.test import TestCase
from mock import patch
from six.moves import range # pylint: disable=ungrouped-imports
from ..middleware import SafeCookieData, SafeCookieError
from .test_utils import TestSafeSessionsLogMixin
@@ -24,7 +23,7 @@ class TestSafeCookieData(TestSafeSessionsLogMixin, TestCase):
"""
def setUp(self):
super(TestSafeCookieData, self).setUp() # lint-amnesty, pylint: disable=super-with-arguments
super().setUp()
self.session_id = 'test_session_id'
self.user_id = 'test_user_id'
self.safe_cookie_data = SafeCookieData.create(self.session_id, self.user_id)
@@ -51,7 +50,7 @@ class TestSafeCookieData(TestSafeSessionsLogMixin, TestCase):
assert safe_cookie_data_1.verify(user_id)
# serialize
serialized_value = six.text_type(safe_cookie_data_1)
serialized_value = str(safe_cookie_data_1)
# parse and verify
safe_cookie_data_2 = SafeCookieData.parse(serialized_value)
@@ -64,9 +63,9 @@ class TestSafeCookieData(TestSafeSessionsLogMixin, TestCase):
assert self.safe_cookie_data.version == SafeCookieData.CURRENT_VERSION
def test_serialize(self):
serialized_value = six.text_type(self.safe_cookie_data)
for field_value in six.itervalues(self.safe_cookie_data.__dict__):
assert six.text_type(field_value) in serialized_value
serialized_value = str(self.safe_cookie_data)
for field_value in self.safe_cookie_data.__dict__.values():
assert str(field_value) in serialized_value
#---- Test Parse ----#
@@ -78,7 +77,7 @@ class TestSafeCookieData(TestSafeSessionsLogMixin, TestCase):
)
def test_parse_success_serialized(self):
serialized_value = six.text_type(self.safe_cookie_data)
serialized_value = str(self.safe_cookie_data)
self.assert_cookie_data_equal(
SafeCookieData.parse(serialized_value),
self.safe_cookie_data,
@@ -92,7 +91,7 @@ class TestSafeCookieData(TestSafeSessionsLogMixin, TestCase):
@ddt.data(0, 2, -1, 'invalid_version')
def test_parse_invalid_version(self, version):
serialized_value = '{}|session_id|key_salt|signature'.format(version)
serialized_value = f'{version}|session_id|key_salt|signature'
with self.assert_logged(r"SafeCookieData version .* is not supported."):
with pytest.raises(SafeCookieError):
SafeCookieData.parse(serialized_value)

View File

@@ -4,11 +4,10 @@ Shared test utilities for Safe Sessions tests
from contextlib import contextmanager
from mock import patch
from unittest.mock import patch
class TestSafeSessionsLogMixin(object):
class TestSafeSessionsLogMixin:
"""
Test Mixin class with helpers for testing log method
calls in the safe sessions middleware.