diff --git a/openedx/core/djangoapps/schedules/resolvers.py b/openedx/core/djangoapps/schedules/resolvers.py index 4ad324c77a..be8a938df1 100644 --- a/openedx/core/djangoapps/schedules/resolvers.py +++ b/openedx/core/djangoapps/schedules/resolvers.py @@ -78,7 +78,6 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver): def __attrs_post_init__(self): # TODO: in the next refactor of this task, pass in current_datetime instead of reproducing it here self.current_datetime = self.target_datetime - datetime.timedelta(days=self.day_offset) - self.exclude_orgs, self.org_list = self.get_course_org_filter() def send(self, msg_type): for (user, language, context) in self.schedules_for_bin(): @@ -133,11 +132,7 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver): **schedule_day_equals_target_day_filter ).order_by(order_by) - if self.org_list is not None: - if self.exclude_orgs: - schedules = schedules.exclude(enrollment__course__org__in=self.org_list) - else: - schedules = schedules.filter(enrollment__course__org__in=self.org_list) + schedules = self.filter_by_org(schedules) if "read_replica" in settings.DATABASES: schedules = schedules.using("read_replica") @@ -153,7 +148,7 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver): return schedules - def get_course_org_filter(self): + def filter_by_org(self, schedules): """ Given the configuration of sites, get the list of orgs that should be included or excluded from this send. @@ -165,7 +160,6 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver): try: site_config = self.site.configuration org_list = site_config.get_value('course_org_filter') - exclude_orgs = False if not org_list: not_orgs = set() for other_site_config in SiteConfiguration.objects.all(): @@ -175,15 +169,13 @@ class BinnedSchedulesBaseResolver(PrefixedDebugLoggerMixin, RecipientResolver): not_orgs.add(other) else: not_orgs.update(other) - org_list = list(not_orgs) - exclude_orgs = True + return schedules.exclude(enrollment__course__org__in=not_orgs) elif not isinstance(org_list, list): - org_list = [org_list] + return schedules.filter(enrollment__course__org=org_list) except SiteConfiguration.DoesNotExist: - org_list = None - exclude_orgs = False + return schedules - return exclude_orgs, org_list + return schedules.filter(enrollment__course__org__in=org_list) def schedules_for_bin(self): schedules = self.get_schedules_with_target_date_by_bin_and_orgs() diff --git a/openedx/core/djangoapps/schedules/tests/test_resolvers.py b/openedx/core/djangoapps/schedules/tests/test_resolvers.py index 3144a45b02..e64844c375 100644 --- a/openedx/core/djangoapps/schedules/tests/test_resolvers.py +++ b/openedx/core/djangoapps/schedules/tests/test_resolvers.py @@ -30,28 +30,40 @@ class TestBinnedSchedulesBaseResolver(CacheIsolationTestCase): bin_num=2, ) - @ddt.unpack @ddt.data( - ('course1', ['course1']), - (['course1', 'course2'], ['course1', 'course2']) + 'course1' ) - def test_get_course_org_filter_include(self, course_org_filter, expected_org_list): + def test_get_course_org_filter_equal(self, course_org_filter): self.site_config.values['course_org_filter'] = course_org_filter self.site_config.save() - exclude_orgs, org_list = self.resolver.get_course_org_filter() - assert not exclude_orgs - assert org_list == expected_org_list + mock_query = Mock() + result = self.resolver.filter_by_org(mock_query) + self.assertEqual(result, mock_query.filter.return_value) + mock_query.filter.assert_called_once_with(enrollment__course__org=course_org_filter) @ddt.unpack @ddt.data( - (None, []), - ('course1', [u'course1']), - (['course1', 'course2'], [u'course1', u'course2']) + (['course1', 'course2'], ['course1', 'course2']) ) - def test_get_course_org_filter_exclude(self, course_org_filter, expected_org_list): + def test_get_course_org_filter_include__in(self, course_org_filter, expected_org_list): + self.site_config.values['course_org_filter'] = course_org_filter + self.site_config.save() + mock_query = Mock() + result = self.resolver.filter_by_org(mock_query) + self.assertEqual(result, mock_query.filter.return_value) + mock_query.filter.assert_called_once_with(enrollment__course__org__in=expected_org_list) + + @ddt.unpack + @ddt.data( + (None, set([])), + ('course1', set([u'course1'])), + (['course1', 'course2'], set([u'course1', u'course2'])) + ) + def test_get_course_org_filter_exclude__in(self, course_org_filter, expected_org_list): SiteConfigurationFactory.create( values={'course_org_filter': course_org_filter}, ) - exclude_orgs, org_list = self.resolver.get_course_org_filter() - assert exclude_orgs - self.assertItemsEqual(org_list, expected_org_list) + mock_query = Mock() + result = self.resolver.filter_by_org(mock_query) + mock_query.exclude.assert_called_once_with(enrollment__course__org__in=expected_org_list) + self.assertEqual(result, mock_query.exclude.return_value)