replaced unittest assertions pytest assertions (#26571)

This commit is contained in:
Aarif
2021-02-22 20:03:33 +05:00
committed by GitHub
parent c63b838f71
commit 4ef7d63211
11 changed files with 122 additions and 222 deletions

View File

@@ -1169,23 +1169,9 @@ class RegistrationViewTestV1(ThirdPartyAuthTestMixin, UserAPITestCase):
# Verify that all fields render in the correct order
form_desc = json.loads(response.content.decode('utf-8'))
field_names = [field["name"] for field in form_desc["fields"]]
assert field_names == [
"email",
"name",
"username",
"password",
"favorite_movie",
"favorite_editor",
"city",
"state",
"country",
"gender",
"year_of_birth",
"level_of_education",
"mailing_address",
"goals",
"honor_code",
]
assert field_names == ['email', 'name', 'username', 'password', 'favorite_movie', 'favorite_editor',
'city', 'state', 'country', 'gender', 'year_of_birth', 'level_of_education',
'mailing_address', 'goals', 'honor_code']
@override_settings(
REGISTRATION_EXTRA_FIELDS={
@@ -1232,21 +1218,8 @@ class RegistrationViewTestV1(ThirdPartyAuthTestMixin, UserAPITestCase):
# Verify that all fields render in the correct order
form_desc = json.loads(response.content.decode('utf-8'))
field_names = [field["name"] for field in form_desc["fields"]]
assert field_names == [
"name",
"username",
"email",
"password",
"city",
"state",
"country",
"gender",
"year_of_birth",
"level_of_education",
"mailing_address",
"goals",
"honor_code",
]
assert field_names == ['name', 'username', 'email', 'password', 'city', 'state', 'country', 'gender',
'year_of_birth', 'level_of_education', 'mailing_address', 'goals', 'honor_code']
@override_settings(
REGISTRATION_EXTRA_FIELDS={
@@ -1286,24 +1259,9 @@ class RegistrationViewTestV1(ThirdPartyAuthTestMixin, UserAPITestCase):
# Verify that all fields render in the correct order
form_desc = json.loads(response.content.decode('utf-8'))
field_names = [field["name"] for field in form_desc["fields"]]
assert field_names == [
"email",
"name",
"username",
"password",
"favorite_movie",
"favorite_editor",
"city",
"state",
"country",
"gender",
"year_of_birth",
"level_of_education",
"mailing_address",
"goals",
"honor_code",
]
assert field_names == ['email', 'name', 'username', 'password', 'favorite_movie', 'favorite_editor', 'city',
'state', 'country', 'gender', 'year_of_birth', 'level_of_education',
'mailing_address', 'goals', 'honor_code']
def test_register(self):
# Create a new registration
@@ -1426,8 +1384,8 @@ class RegistrationViewTestV1(ThirdPartyAuthTestMixin, UserAPITestCase):
sent_email = mail.outbox[0]
assert sent_email.to == [self.EMAIL]
assert sent_email.subject ==\
"Action Required: Activate your {platform} account".format(platform=settings.PLATFORM_NAME)
assert "high-quality {platform} courses".format(platform=settings.PLATFORM_NAME) in sent_email.body
u'Action Required: Activate your {platform} account'.format(platform=settings.PLATFORM_NAME)
assert u'high-quality {platform} courses'.format(platform=settings.PLATFORM_NAME) in sent_email.body
@ddt.data(
{"email": ""},
@@ -1668,20 +1626,12 @@ class RegistrationViewTestV1(ThirdPartyAuthTestMixin, UserAPITestCase):
"""
Assert that the actual field and the expected field values match.
"""
self.assertIsNot(
actual_field, None,
msg=u"Could not find field {name}".format(name=expected_field["name"])
)
assert actual_field is not None, "Could not find field {name}".format(name=expected_field["name"])
for key in expected_field:
self.assertEqual(
actual_field[key], expected_field[key],
msg=u"Expected {expected} for {key} but got {actual} instead".format(
key=key,
actual=actual_field[key],
expected=expected_field[key]
)
)
assert actual_field[key] == expected_field[key], \
"Expected {expected} for {key} but got {actual} instead".format(
key=key, actual=actual_field[key], expected=expected_field[key])
def _populate_always_present_fields(self, field):
"""
@@ -1800,24 +1750,9 @@ class RegistrationViewTestV2(RegistrationViewTestV1):
form_desc = json.loads(response.content.decode('utf-8'))
field_names = [field["name"] for field in form_desc["fields"]]
assert field_names == [
"email",
"name",
"username",
"password",
"favorite_movie",
"favorite_editor",
"confirm_email",
"city",
"state",
"country",
"gender",
"year_of_birth",
"level_of_education",
"mailing_address",
"goals",
"honor_code",
]
assert field_names == ['email', 'name', 'username', 'password', 'favorite_movie', 'favorite_editor',
'confirm_email', 'city', 'state', 'country', 'gender', 'year_of_birth',
'level_of_education', 'mailing_address', 'goals', 'honor_code']
@override_settings(
REGISTRATION_EXTRA_FIELDS={
@@ -1864,22 +1799,10 @@ class RegistrationViewTestV2(RegistrationViewTestV1):
# Verify that all fields render in the correct order
form_desc = json.loads(response.content.decode('utf-8'))
field_names = [field["name"] for field in form_desc["fields"]]
assert field_names == [
"name",
"username",
"email",
"confirm_email",
"password",
"city",
"state",
"country",
"gender",
"year_of_birth",
"level_of_education",
"mailing_address",
"goals",
"honor_code",
]
assert field_names == ['name', 'username', 'email', 'confirm_email',
'password', 'city', 'state', 'country',
'gender', 'year_of_birth', 'level_of_education',
'mailing_address', 'goals', 'honor_code']
@override_settings(
REGISTRATION_EXTRA_FIELDS={
@@ -1903,24 +1826,10 @@ class RegistrationViewTestV2(RegistrationViewTestV1):
# Verify that all fields render in the correct order
form_desc = json.loads(response.content.decode('utf-8'))
field_names = [field["name"] for field in form_desc["fields"]]
assert field_names == [
"email",
"name",
"username",
"password",
"favorite_movie",
"favorite_editor",
"confirm_email",
"city",
"state",
"country",
"gender",
"year_of_birth",
"level_of_education",
"mailing_address",
"goals",
"honor_code",
]
assert field_names ==\
['email', 'name', 'username', 'password', 'favorite_movie', 'favorite_editor', 'confirm_email',
'city', 'state', 'country', 'gender', 'year_of_birth', 'level_of_education', 'mailing_address',
'goals', 'honor_code']
def test_registration_form_confirm_email(self):
self._assert_reg_field(
@@ -2090,7 +1999,7 @@ class ThirdPartyRegistrationTestMixin(ThirdPartyOAuthTestMixin, CacheIsolationTe
def _verify_user_existence(self, user_exists, social_link_exists, user_is_active=None, username=None):
"""Verifies whether the user object exists."""
users = User.objects.filter(username=(username if username else "test_username"))
self.assertEqual(users.exists(), user_exists)
assert users.exists() == user_exists
if user_exists:
assert users[0].is_active == user_is_active
self.assertEqual(
@@ -2417,6 +2326,6 @@ class RegistrationValidationViewTests(test_utils.ApiTestCase):
"""
for _ in range(int(settings.REGISTRATION_VALIDATION_RATELIMIT.split('/')[0])):
response = self.request_without_auth('post', self.path)
self.assertNotEqual(response.status_code, 403)
assert response.status_code != 403
response = self.request_without_auth('post', self.path)
assert response.status_code == 403

View File

@@ -649,14 +649,11 @@ class PasswordResetViewTest(UserAPITestCase):
form_desc = json.loads(response.content.decode('utf-8'))
assert form_desc['method'] == 'post'
assert form_desc['submit_url'] == reverse('password_change_request')
assert form_desc['fields'] == \
[{'name': 'email',
'defaultValue': '',
'type': 'email',
'required': True,
'label': 'Email',
'placeholder': 'username@domain.com',
'instructions': f'The email address you used to register with {settings.PLATFORM_NAME}',
assert form_desc['fields'] ==\
[{'name': 'email', 'defaultValue': '', 'type': 'email', 'required': True,
'label': 'Email', 'placeholder': 'username@domain.com',
'instructions': u'The email address you used to register with {platform_name}'
.format(platform_name=settings.PLATFORM_NAME),
'restrictions': {'min_length': EMAIL_MIN_LENGTH,
'max_length': EMAIL_MAX_LENGTH},
'errorMessages': {}, 'supplementalText': '',

View File

@@ -11,7 +11,7 @@ class FormTestMixin(object):
thereof) according to expected_valid
"""
form = self.FORM_CLASS(self.form_data, initial=getattr(self, 'initial', None))
self.assertEqual(form.is_valid(), expected_valid)
assert form.is_valid() == expected_valid
return form
def assert_error(self, expected_field, expected_message):
@@ -21,7 +21,7 @@ class FormTestMixin(object):
message
"""
form = self.get_form(expected_valid=False)
self.assertEqual(form.errors, {expected_field: [expected_message]})
assert form.errors == {expected_field: [expected_message]}
def assert_valid(self, expected_cleaned_data):
"""
@@ -36,4 +36,4 @@ class FormTestMixin(object):
that the given field in the cleaned data has the expected value
"""
form = self.get_form(expected_valid=True)
self.assertEqual(form.cleaned_data[field], expected_value)
assert form.cleaned_data[field] == expected_value

View File

@@ -62,7 +62,7 @@ class TestMaintenanceBannerViewDecorator(TestCase):
with override_waffle_switch(DISPLAY_MAINTENANCE_WARNING, active=display_warning):
banner_added, _ = self.add_maintenance_banner()
self.assertEqual(display_warning, banner_added)
assert display_warning == banner_added
@ddt.data(
"If there's somethin' strange in your neighborhood, who ya gonna call?!"
@@ -77,5 +77,5 @@ class TestMaintenanceBannerViewDecorator(TestCase):
with override_settings(MAINTENANCE_BANNER_TEXT=warning_message):
banner_added, banner_message = self.add_maintenance_banner()
self.assertTrue(banner_added)
self.assertEqual(warning_message, banner_message)
assert banner_added
assert warning_message == banner_message

View File

@@ -25,4 +25,4 @@ class TestClearRequestCache(TestCase):
@override_settings(CLEAR_REQUEST_CACHE_ON_TASK_COMPLETION=True)
def test_clear_cache_celery(self):
self._dummy_task.apply(args=(self,)).get()
self.assertFalse(self._get_cache().get_cached_response("cache_key").is_found)
assert not self._get_cache().get_cached_response('cache_key').is_found

View File

@@ -41,8 +41,8 @@ class UserMessagesTestCase(TestCase):
"""
PageLevelMessages.register_user_message(self.request, UserMessageType.INFO, message)
messages = list(PageLevelMessages.user_messages(self.request))
self.assertEqual(len(messages), 1)
self.assertEqual(messages[0].message_html, expected_message_html)
assert len(messages) == 1
assert messages[0].message_html == expected_message_html
@ddt.data(
(UserMessageType.ERROR, 'alert-danger', 'fa fa-warning'),
@@ -57,9 +57,9 @@ class UserMessagesTestCase(TestCase):
"""
PageLevelMessages.register_user_message(self.request, message_type, TEST_MESSAGE)
messages = list(PageLevelMessages.user_messages(self.request))
self.assertEqual(len(messages), 1)
self.assertEqual(messages[0].css_class, expected_css_class)
self.assertEqual(messages[0].icon_class, expected_icon_class)
assert len(messages) == 1
assert messages[0].css_class == expected_css_class
assert messages[0].icon_class == expected_icon_class
@ddt.data(
(normalize_repr(PageLevelMessages.register_error_message), UserMessageType.ERROR),
@@ -74,5 +74,5 @@ class UserMessagesTestCase(TestCase):
"""
register_message_function(self.request, TEST_MESSAGE)
messages = list(PageLevelMessages.user_messages(self.request))
self.assertEqual(len(messages), 1)
self.assertEqual(messages[0].type, expected_message_type)
assert len(messages) == 1
assert messages[0].type == expected_message_type

View File

@@ -16,15 +16,15 @@ class SystemUserTestCase(unittest.TestCase):
self.sysuser = SystemUser()
def test_system_user_is_anonymous(self):
self.assertIsInstance(self.sysuser, AnonymousUser)
self.assertTrue(self.sysuser.is_anonymous)
self.assertIsNone(self.sysuser.id)
assert isinstance(self.sysuser, AnonymousUser)
assert self.sysuser.is_anonymous
assert self.sysuser.id is None
def test_system_user_has_custom_unicode_representation(self):
self.assertNotEqual(six.text_type(self.sysuser), six.text_type(AnonymousUser()))
assert six.text_type(self.sysuser) != six.text_type(AnonymousUser())
def test_system_user_is_not_staff(self):
self.assertFalse(self.sysuser.is_staff)
assert not self.sysuser.is_staff
def test_system_user_is_not_superuser(self):
self.assertFalse(self.sysuser.is_superuser)
assert not self.sysuser.is_superuser

View File

@@ -27,21 +27,15 @@ class TestVerifiedTrackCourseForm(SharedModuleStoreTestCase):
'course_key': six.text_type(self.course.id), 'verified_cohort_name': 'Verified Learners', 'enabled': True
}
form = VerifiedTrackCourseForm(data=form_data)
self.assertTrue(form.is_valid())
assert form.is_valid()
def test_form_validation_failure(self):
form_data = {'course_key': self.FAKE_COURSE, 'verified_cohort_name': 'Verified Learners', 'enabled': True}
form = VerifiedTrackCourseForm(data=form_data)
self.assertFalse(form.is_valid())
self.assertEqual(
form.errors['course_key'],
['COURSE NOT FOUND. Please check that the course ID is valid.']
)
assert not form.is_valid()
assert form.errors['course_key'] == ['COURSE NOT FOUND. Please check that the course ID is valid.']
form_data = {'course_key': self.BAD_COURSE_KEY, 'verified_cohort_name': 'Verified Learners', 'enabled': True}
form = VerifiedTrackCourseForm(data=form_data)
self.assertFalse(form.is_valid())
self.assertEqual(
form.errors['course_key'],
['COURSE NOT FOUND. Please check that the course ID is valid.']
)
assert not form.is_valid()
assert form.errors['course_key'] == ['COURSE NOT FOUND. Please check that the course ID is valid.']

View File

@@ -36,24 +36,24 @@ class TestVerifiedTrackCohortedCourse(TestCase):
def test_course_enabled(self):
course_key = CourseKey.from_string(self.SAMPLE_COURSE)
# Test when no configuration exists
self.assertFalse(VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(course_key))
assert not VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(course_key)
# Enable for a course
config = VerifiedTrackCohortedCourse.objects.create(course_key=course_key, enabled=True)
config.save()
self.assertTrue(VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(course_key))
assert VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(course_key)
# Disable for the course
config.enabled = False
config.save()
self.assertFalse(VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(course_key))
assert not VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(course_key)
def test_unicode(self):
course_key = CourseKey.from_string(self.SAMPLE_COURSE)
# Enable for a course
config = VerifiedTrackCohortedCourse.objects.create(course_key=course_key, enabled=True)
config.save()
self.assertEqual(six.text_type(config), u"Course: {}, enabled: True".format(self.SAMPLE_COURSE))
assert six.text_type(config) == u'Course: {}, enabled: True'.format(self.SAMPLE_COURSE)
def test_verified_cohort_name(self):
cohort_name = 'verified cohort'
@@ -62,12 +62,12 @@ class TestVerifiedTrackCohortedCourse(TestCase):
course_key=course_key, enabled=True, verified_cohort_name=cohort_name
)
config.save()
self.assertEqual(VerifiedTrackCohortedCourse.verified_cohort_name_for_course(course_key), cohort_name)
assert VerifiedTrackCohortedCourse.verified_cohort_name_for_course(course_key) == cohort_name
def test_unset_verified_cohort_name(self):
fake_course_id = 'fake/course/key'
course_key = CourseKey.from_string(fake_course_id)
self.assertEqual(VerifiedTrackCohortedCourse.verified_cohort_name_for_course(course_key), None)
assert VerifiedTrackCohortedCourse.verified_cohort_name_for_course(course_key) is None
@skip_unless_lms
@@ -124,10 +124,10 @@ class TestMoveToVerified(SharedModuleStoreTestCase):
def _verify_no_automatic_cohorting(self):
""" Check that upgrading self.user to verified does not move them into a cohort. """
self._enroll_in_course()
self.assertIsNone(get_cohort(self.user, self.course.id, assign=False))
assert get_cohort(self.user, self.course.id, assign=False) is None
self._upgrade_enrollment()
self.assertIsNone(get_cohort(self.user, self.course.id, assign=False))
self.assertEqual(0, self.mocked_celery_task.call_count)
assert get_cohort(self.user, self.course.id, assign=False) is None
assert 0 == self.mocked_celery_task.call_count
def _unenroll(self):
""" Unenroll self.user from self.course. """
@@ -148,10 +148,10 @@ class TestMoveToVerified(SharedModuleStoreTestCase):
self._enable_cohorting()
self._create_verified_cohort()
# But do not enable the verified track cohorting feature.
self.assertFalse(VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(self.course.id))
assert not VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(self.course.id)
self._verify_no_automatic_cohorting()
# No logging occurs if feature is disabled for course.
self.assertFalse(error_logger.called)
assert not error_logger.called
@mock.patch('openedx.core.djangoapps.verified_track_content.models.log.error')
def test_cohorting_enabled_course_not_cohorted(self, error_logger):
@@ -161,10 +161,10 @@ class TestMoveToVerified(SharedModuleStoreTestCase):
"""
# Enable verified track cohorting feature, but course has not been marked as cohorting.
self._enable_verified_track_cohorting()
self.assertTrue(VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(self.course.id))
assert VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(self.course.id)
self._verify_no_automatic_cohorting()
self.assertTrue(error_logger.called)
self.assertIn("course is not cohorted", error_logger.call_args[0][0])
assert error_logger.called
assert 'course is not cohorted' in error_logger.call_args[0][0]
@mock.patch('openedx.core.djangoapps.verified_track_content.models.log.error')
def test_cohorting_enabled_missing_verified_cohort(self, error_logger):
@@ -177,11 +177,11 @@ class TestMoveToVerified(SharedModuleStoreTestCase):
self._enable_cohorting()
# Enable verified track cohorting feature
self._enable_verified_track_cohorting()
self.assertTrue(VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(self.course.id))
assert VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(self.course.id)
self._verify_no_automatic_cohorting()
self.assertTrue(error_logger.called)
assert error_logger.called
error_message = u"cohort named '%s' does not exist"
self.assertIn(error_message, error_logger.call_args[0][0])
assert error_message in error_logger.call_args[0][0]
@ddt.data(CourseMode.VERIFIED, CourseMode.CREDIT_MODE)
def test_automatic_cohorting_enabled(self, upgrade_mode):
@@ -195,15 +195,15 @@ class TestMoveToVerified(SharedModuleStoreTestCase):
self._create_verified_cohort()
# Enable verified track cohorting feature
self._enable_verified_track_cohorting()
self.assertTrue(VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(self.course.id))
assert VerifiedTrackCohortedCourse.is_verified_track_cohort_enabled(self.course.id)
self._enroll_in_course()
self.assertEqual(2, self.mocked_celery_task.call_count)
self.assertEqual(DEFAULT_COHORT_NAME, get_cohort(self.user, self.course.id, assign=False).name)
assert 2 == self.mocked_celery_task.call_count
assert DEFAULT_COHORT_NAME == get_cohort(self.user, self.course.id, assign=False).name
self._upgrade_enrollment(upgrade_mode)
self.assertEqual(4, self.mocked_celery_task.call_count)
self.assertEqual(DEFAULT_VERIFIED_COHORT_NAME, get_cohort(self.user, self.course.id, assign=False).name)
assert 4 == self.mocked_celery_task.call_count
assert DEFAULT_VERIFIED_COHORT_NAME == get_cohort(self.user, self.course.id, assign=False).name
def test_cohorting_enabled_multiple_random_cohorts(self):
"""
@@ -221,13 +221,13 @@ class TestMoveToVerified(SharedModuleStoreTestCase):
self._enable_verified_track_cohorting()
self._enroll_in_course()
self.assertIn(get_cohort(self.user, self.course.id, assign=False).name, ["Random 1", "Random 2"])
assert get_cohort(self.user, self.course.id, assign=False).name in ['Random 1', 'Random 2']
self._upgrade_enrollment()
self.assertEqual(DEFAULT_VERIFIED_COHORT_NAME, get_cohort(self.user, self.course.id, assign=False).name)
assert DEFAULT_VERIFIED_COHORT_NAME == get_cohort(self.user, self.course.id, assign=False).name
self._unenroll()
self._reenroll()
self.assertIn(get_cohort(self.user, self.course.id, assign=False).name, ["Random 1", "Random 2"])
assert get_cohort(self.user, self.course.id, assign=False).name in ['Random 1', 'Random 2']
def test_unenrolled(self):
"""
@@ -240,15 +240,15 @@ class TestMoveToVerified(SharedModuleStoreTestCase):
self._enable_verified_track_cohorting()
self._enroll_in_course()
self._upgrade_enrollment()
self.assertEqual(DEFAULT_VERIFIED_COHORT_NAME, get_cohort(self.user, self.course.id, assign=False).name)
assert DEFAULT_VERIFIED_COHORT_NAME == get_cohort(self.user, self.course.id, assign=False).name
# Un-enroll from the course and then re-enroll
self._unenroll()
self.assertEqual(DEFAULT_VERIFIED_COHORT_NAME, get_cohort(self.user, self.course.id, assign=False).name)
assert DEFAULT_VERIFIED_COHORT_NAME == get_cohort(self.user, self.course.id, assign=False).name
self._reenroll()
self.assertEqual(DEFAULT_COHORT_NAME, get_cohort(self.user, self.course.id, assign=False).name)
assert DEFAULT_COHORT_NAME == get_cohort(self.user, self.course.id, assign=False).name
self._upgrade_enrollment()
self.assertEqual(DEFAULT_VERIFIED_COHORT_NAME, get_cohort(self.user, self.course.id, assign=False).name)
assert DEFAULT_VERIFIED_COHORT_NAME == get_cohort(self.user, self.course.id, assign=False).name
def test_custom_verified_cohort_name(self):
"""
@@ -260,7 +260,7 @@ class TestMoveToVerified(SharedModuleStoreTestCase):
self._enable_verified_track_cohorting(cohort_name=custom_cohort_name)
self._enroll_in_course()
self._upgrade_enrollment()
self.assertEqual(custom_cohort_name, get_cohort(self.user, self.course.id, assign=False).name)
assert custom_cohort_name == get_cohort(self.user, self.course.id, assign=False).name
def test_custom_default_cohort_name(self):
"""
@@ -273,13 +273,13 @@ class TestMoveToVerified(SharedModuleStoreTestCase):
default_cohort = self._create_named_random_cohort(random_cohort_name)
self._enable_verified_track_cohorting()
self._enroll_in_course()
self.assertEqual(random_cohort_name, get_cohort(self.user, self.course.id, assign=False).name)
assert random_cohort_name == get_cohort(self.user, self.course.id, assign=False).name
self._upgrade_enrollment()
self.assertEqual(DEFAULT_VERIFIED_COHORT_NAME, get_cohort(self.user, self.course.id, assign=False).name)
assert DEFAULT_VERIFIED_COHORT_NAME == get_cohort(self.user, self.course.id, assign=False).name
# Un-enroll from the course. The learner stays in the verified cohort, but is no longer active.
self._unenroll()
self.assertEqual(DEFAULT_VERIFIED_COHORT_NAME, get_cohort(self.user, self.course.id, assign=False).name)
assert DEFAULT_VERIFIED_COHORT_NAME == get_cohort(self.user, self.course.id, assign=False).name
# Change the name of the "default" cohort.
modified_cohort_name = "renamed random cohort"
@@ -288,6 +288,6 @@ class TestMoveToVerified(SharedModuleStoreTestCase):
# Re-enroll in the course, which will downgrade the learner to audit.
self._reenroll()
self.assertEqual(modified_cohort_name, get_cohort(self.user, self.course.id, assign=False).name)
assert modified_cohort_name == get_cohort(self.user, self.course.id, assign=False).name
self._upgrade_enrollment()
self.assertEqual(DEFAULT_VERIFIED_COHORT_NAME, get_cohort(self.user, self.course.id, assign=False).name)
assert DEFAULT_VERIFIED_COHORT_NAME == get_cohort(self.user, self.course.id, assign=False).name

View File

@@ -7,7 +7,7 @@ from datetime import datetime, timedelta
import pytz
import six
import pytest
from common.djangoapps.course_modes.models import CourseMode
from common.djangoapps.student.models import CourseEnrollment
from common.djangoapps.student.tests.factories import UserFactory
@@ -32,13 +32,13 @@ class EnrollmentTrackUserPartitionTest(SharedModuleStoreTestCase):
def test_only_default_mode(self):
partition = create_enrollment_track_partition(self.course)
groups = partition.groups
self.assertEqual(1, len(groups))
self.assertEqual("Audit", groups[0].name)
assert 1 == len(groups)
assert 'Audit' == groups[0].name
def test_using_verified_track_cohort(self):
VerifiedTrackCohortedCourse.objects.create(course_key=self.course.id, enabled=True).save()
partition = create_enrollment_track_partition(self.course)
self.assertEqual(0, len(partition.groups))
assert 0 == len(partition.groups)
def test_multiple_groups(self):
create_mode(self.course, CourseMode.AUDIT, "Audit Enrollment Track", min_price=0)
@@ -53,19 +53,19 @@ class EnrollmentTrackUserPartitionTest(SharedModuleStoreTestCase):
partition = create_enrollment_track_partition(self.course)
groups = partition.groups
self.assertEqual(2, len(groups))
self.assertIsNotNone(self.get_group_by_name(partition, "Audit Enrollment Track"))
self.assertIsNotNone(self.get_group_by_name(partition, "Verified Enrollment Track"))
assert 2 == len(groups)
assert self.get_group_by_name(partition, 'Audit Enrollment Track') is not None
assert self.get_group_by_name(partition, 'Verified Enrollment Track') is not None
def test_to_json_supported(self):
user_partition_json = create_enrollment_track_partition(self.course).to_json()
self.assertEqual('Test Enrollment Track Partition', user_partition_json['name'])
self.assertEqual('enrollment_track', user_partition_json['scheme'])
self.assertEqual('Test partition for segmenting users by enrollment track', user_partition_json['description'])
assert 'Test Enrollment Track Partition' == user_partition_json['name']
assert 'enrollment_track' == user_partition_json['scheme']
assert 'Test partition for segmenting users by enrollment track' == user_partition_json['description']
def test_from_json_not_supported(self):
user_partition_json = create_enrollment_track_partition(self.course).to_json()
with self.assertRaises(ReadOnlyUserPartitionError):
with pytest.raises(ReadOnlyUserPartitionError):
UserPartition.from_json(user_partition_json)
def test_group_ids(self):
@@ -74,7 +74,7 @@ class EnrollmentTrackUserPartitionTest(SharedModuleStoreTestCase):
with group IDs associated with cohort and random user partitions).
"""
for mode in ENROLLMENT_GROUP_IDS:
self.assertLess(ENROLLMENT_GROUP_IDS[mode]['id'], MINIMUM_STATIC_PARTITION_ID)
assert ENROLLMENT_GROUP_IDS[mode]['id'] < MINIMUM_STATIC_PARTITION_ID
@staticmethod
def get_group_by_name(partition, name):
@@ -103,34 +103,34 @@ class EnrollmentTrackPartitionSchemeTest(SharedModuleStoreTestCase):
"""
Ensure that the scheme extension is correctly plugged in (via entry point in setup.py)
"""
self.assertEqual(UserPartition.get_scheme('enrollment_track'), EnrollmentTrackPartitionScheme)
assert UserPartition.get_scheme('enrollment_track') == EnrollmentTrackPartitionScheme
def test_create_user_partition(self):
user_partition = UserPartition.get_scheme('enrollment_track').create_user_partition(
301, "partition", "test partition", parameters={"course_id": six.text_type(self.course.id)}
)
self.assertEqual(type(user_partition), EnrollmentTrackUserPartition)
self.assertEqual(user_partition.name, "partition")
assert isinstance(user_partition, EnrollmentTrackUserPartition)
assert user_partition.name == 'partition'
groups = user_partition.groups
self.assertEqual(1, len(groups))
self.assertEqual("Audit", groups[0].name)
assert 1 == len(groups)
assert 'Audit' == groups[0].name
def test_not_enrolled(self):
self.assertIsNone(self._get_user_group())
assert self._get_user_group() is None
def test_default_enrollment(self):
CourseEnrollment.enroll(self.student, self.course.id)
self.assertEqual("Audit", self._get_user_group().name)
assert 'Audit' == self._get_user_group().name
def test_enrolled_in_nonexistent_mode(self):
CourseEnrollment.enroll(self.student, self.course.id, mode=CourseMode.VERIFIED)
self.assertEqual("Audit", self._get_user_group().name)
assert 'Audit' == self._get_user_group().name
def test_enrolled_in_verified(self):
create_mode(self.course, CourseMode.VERIFIED, "Verified Enrollment Track", min_price=1)
CourseEnrollment.enroll(self.student, self.course.id, mode=CourseMode.VERIFIED)
self.assertEqual("Verified Enrollment Track", self._get_user_group().name)
assert 'Verified Enrollment Track' == self._get_user_group().name
def test_enrolled_in_expired(self):
create_mode(
@@ -138,18 +138,18 @@ class EnrollmentTrackPartitionSchemeTest(SharedModuleStoreTestCase):
min_price=1, expiration_datetime=datetime.now(pytz.UTC) + timedelta(days=-1)
)
CourseEnrollment.enroll(self.student, self.course.id, mode=CourseMode.VERIFIED)
self.assertEqual("Verified Enrollment Track", self._get_user_group().name)
assert 'Verified Enrollment Track' == self._get_user_group().name
def test_enrolled_in_non_selectable(self):
create_mode(self.course, CourseMode.CREDIT_MODE, "Credit Enrollment Track", min_price=1)
CourseEnrollment.enroll(self.student, self.course.id, mode=CourseMode.CREDIT_MODE)
# The default mode is returned because Credit mode is filtered out, and no verified mode exists.
self.assertEqual("Audit", self._get_user_group().name)
assert 'Audit' == self._get_user_group().name
# Now create a verified mode and check that it is returned for the learner enrolled in Credit.
create_mode(self.course, CourseMode.VERIFIED, "Verified Enrollment Track", min_price=1)
self.assertEqual("Verified Enrollment Track", self._get_user_group().name)
assert 'Verified Enrollment Track' == self._get_user_group().name
def test_credit_after_upgrade_deadline(self):
create_mode(self.course, CourseMode.CREDIT_MODE, "Credit Enrollment Track", min_price=1)
@@ -162,12 +162,12 @@ class EnrollmentTrackPartitionSchemeTest(SharedModuleStoreTestCase):
self.course, CourseMode.VERIFIED, "Verified Enrollment Track", min_price=1,
expiration_datetime=datetime.now(pytz.UTC) + timedelta(days=-1)
)
self.assertEqual("Verified Enrollment Track", self._get_user_group().name)
assert 'Verified Enrollment Track' == self._get_user_group().name
def test_using_verified_track_cohort(self):
VerifiedTrackCohortedCourse.objects.create(course_key=self.course.id, enabled=True).save()
CourseEnrollment.enroll(self.student, self.course.id)
self.assertIsNone(self._get_user_group())
assert self._get_user_group() is None
def _get_user_group(self):
"""

View File

@@ -5,7 +5,7 @@ Tests for verified track content views.
import json
import six
import pytest
from django.http import Http404
from django.test.client import RequestFactory
@@ -35,7 +35,7 @@ class CohortingSettingsTestCase(SharedModuleStoreTestCase):
"""
request = RequestFactory().get("dummy_url")
request.user = UserFactory()
with self.assertRaises(Http404):
with pytest.raises(Http404):
cohorting_settings(request, six.text_type(self.course.id))
def test_cohorting_settings_enabled(self):
@@ -67,5 +67,5 @@ class CohortingSettingsTestCase(SharedModuleStoreTestCase):
request = RequestFactory().get("dummy_url")
request.user = AdminFactory()
response = cohorting_settings(request, six.text_type(self.course.id))
self.assertEqual(200, response.status_code)
self.assertEqual(expected_response, json.loads(response.content.decode('utf-8')))
assert 200 == response.status_code
assert expected_response == json.loads(response.content.decode('utf-8'))