fix: send thread_created signal after transaction commit (#37675)

Prevents notification failures with MySQL backend by ensuring signals
are only sent after database transactions commit. This fixes race
conditions where Celery workers couldn't see newly created threads.

- Added send_signal_after_commit() helper function
- Updated both thread creation paths to use the helper

Co-authored-by: Taimoor  Ahmed <taimoor.ahmed@A006-01711.local>
This commit is contained in:
Taimoor Ahmed
2025-11-26 11:43:45 +05:00
committed by GitHub
parent 8ad4d42e3b
commit 56f7da908a
5 changed files with 142 additions and 49 deletions

View File

@@ -180,7 +180,8 @@ class ThreadActionGroupIdTestCase(
with mock.patch(
"openedx.core.djangoapps.django_comment_common.signals.thread_flagged.send"
) as signal_mock:
response = self.call_view("flag_abuse_for_thread", "update_thread_flag")
with self.captureOnCommitCallbacks(execute=True):
response = self.call_view("flag_abuse_for_thread", "update_thread_flag")
self._assert_json_response_contains_group_info(response)
self.assertEqual(signal_mock.call_count, 1)
response = self.call_view("un_flag_abuse_for_thread", "update_thread_flag")
@@ -471,10 +472,15 @@ class ViewsTestCase(
def assert_discussion_signals(self, signal, user=None):
if user is None:
user = self.student
# Use captureOnCommitCallbacks to execute on_commit callbacks during tests,
# since signals are now deferred until after transaction commit.
# Order matters: assert_signal_sent must be outer context so the signal
# fires (via captureOnCommitCallbacks) before the assertion check.
with self.assert_signal_sent(
views, signal, sender=None, user=user, exclude_args=("post",)
):
yield
with self.captureOnCommitCallbacks(execute=True):
yield
def test_create_thread(self):
with self.assert_discussion_signals("thread_created"):
@@ -1218,7 +1224,8 @@ class CommentActionTestCase(CohortedTestCase, MockForumApiMixin):
with mock.patch(
"openedx.core.djangoapps.django_comment_common.signals.comment_flagged.send"
) as signal_mock:
self.call_view("flag_abuse_for_comment", "update_comment_flag")
with self.captureOnCommitCallbacks(execute=True):
self.call_view("flag_abuse_for_comment", "update_comment_flag")
self.assertEqual(signal_mock.call_count, 1)

View File

@@ -50,6 +50,7 @@ from lms.djangoapps.discussion.django_comment_client.utils import (
prepare_content,
sanitize_body
)
from lms.djangoapps.discussion.rest_api.utils import send_signal_after_commit
from openedx.core.djangoapps.django_comment_common.signals import (
comment_created,
comment_deleted,
@@ -587,7 +588,10 @@ def create_thread(request, course_id, commentable_id):
thread.save()
thread_created.send(sender=None, user=user, post=thread)
# Use send_signal_after_commit() to ensure the signal is sent only after the transaction commits.
send_signal_after_commit(
lambda: thread_created.send(sender=None, user=user, post=thread)
)
# patch for backward compatibility to comments service
if 'pinned' not in thread.attributes:
@@ -598,7 +602,9 @@ def create_thread(request, course_id, commentable_id):
if follow:
cc_user = cc.User.from_django_user(user)
cc_user.follow(thread, course_id)
thread_followed.send(sender=None, user=user, post=thread)
send_signal_after_commit(
lambda: thread_followed.send(sender=None, user=user, post=thread)
)
data = thread.to_dict()
@@ -645,7 +651,9 @@ def update_thread(request, course_id, thread_id):
thread.save()
thread_edited.send(sender=None, user=user, post=thread)
send_signal_after_commit(
lambda: thread_edited.send(sender=None, user=user, post=thread)
)
track_thread_edited_event(request, course, thread, None)
if request.headers.get('x-requested-with') == 'XMLHttpRequest':
@@ -688,7 +696,9 @@ def _create_comment(request, course_key, thread_id=None, parent_id=None):
)
comment.save(params={"course_id": str(course_key)})
comment_created.send(sender=None, user=user, post=comment)
send_signal_after_commit(
lambda: comment_created.send(sender=None, user=user, post=comment)
)
followed = post.get('auto_subscribe', 'false').lower() == 'true'
@@ -729,7 +739,9 @@ def delete_thread(request, course_id, thread_id):
course = get_course_with_access(request.user, 'load', course_key)
thread = cc.Thread.find(thread_id)
thread.delete(course_id=course_id)
thread_deleted.send(sender=None, user=request.user, post=thread)
send_signal_after_commit(
lambda: thread_deleted.send(sender=None, user=request.user, post=thread)
)
track_thread_deleted_event(request, course, thread)
return JsonResponse(prepare_content(thread.to_dict(), course_key))
@@ -751,7 +763,9 @@ def update_comment(request, course_id, comment_id):
comment.body = sanitize_body(request.POST["body"])
comment.save(params={"course_id": course_id})
comment_edited.send(sender=None, user=request.user, post=comment)
send_signal_after_commit(
lambda: comment_edited.send(sender=None, user=request.user, post=comment)
)
track_comment_edited_event(request, course, comment, None)
if request.headers.get('x-requested-with') == 'XMLHttpRequest':
@@ -776,7 +790,9 @@ def endorse_comment(request, course_id, comment_id):
comment.endorsed = endorsed
comment.endorsement_user_id = user.id
comment.save(params={"course_id": course_id})
comment_endorsed.send(sender=None, user=user, post=comment)
send_signal_after_commit(
lambda: comment_endorsed.send(sender=None, user=user, post=comment)
)
track_forum_response_mark_event(request, course, comment, endorsed)
return JsonResponse(prepare_content(comment.to_dict(), course_key))
@@ -828,7 +844,9 @@ def delete_comment(request, course_id, comment_id):
course = get_course_with_access(request.user, 'load', course_key)
comment = cc.Comment.find(comment_id)
comment.delete(course_id=course_id)
comment_deleted.send(sender=None, user=request.user, post=comment)
send_signal_after_commit(
lambda: comment_deleted.send(sender=None, user=request.user, post=comment)
)
track_comment_deleted_event(request, course, comment)
return JsonResponse(prepare_content(comment.to_dict(), course_key))
@@ -847,7 +865,9 @@ def _vote_or_unvote(request, course_id, obj, value='up', undo_vote=False):
# (People could theoretically downvote by handcrafting AJAX requests.)
else:
user.vote(obj, value, course_id)
thread_voted.send(sender=None, user=request.user, post=obj)
send_signal_after_commit(
lambda: thread_voted.send(sender=None, user=request.user, post=obj)
)
track_voted_event(request, course, obj, value, undo_vote)
return JsonResponse(prepare_content(obj.to_dict(), course_key))
@@ -861,7 +881,9 @@ def vote_for_comment(request, course_id, comment_id, value):
"""
comment = cc.Comment.find(comment_id)
result = _vote_or_unvote(request, course_id, comment, value)
comment_voted.send(sender=None, user=request.user, post=comment)
send_signal_after_commit(
lambda: comment_voted.send(sender=None, user=request.user, post=comment)
)
return result
@@ -914,7 +936,9 @@ def flag_abuse_for_thread(request, course_id, thread_id):
thread = cc.Thread.find(thread_id)
thread.flagAbuse(user, thread, course_id)
track_discussion_reported_event(request, course, thread)
thread_flagged.send(sender='flag_abuse_for_thread', user=request.user, post=thread)
send_signal_after_commit(
lambda: thread_flagged.send(sender='flag_abuse_for_thread', user=request.user, post=thread)
)
return JsonResponse(prepare_content(thread.to_dict(), course_key))
@@ -953,7 +977,9 @@ def flag_abuse_for_comment(request, course_id, comment_id):
comment = cc.Comment.find(comment_id)
comment.flagAbuse(user, comment, course_id)
track_discussion_reported_event(request, course, comment)
comment_flagged.send(sender='flag_abuse_for_comment', user=request.user, post=comment)
send_signal_after_commit(
lambda: comment_flagged.send(sender='flag_abuse_for_comment', user=request.user, post=comment)
)
return JsonResponse(prepare_content(comment.to_dict(), course_key))
@@ -1019,7 +1045,9 @@ def follow_thread(request, course_id, thread_id): # lint-amnesty, pylint: disab
course = get_course_by_id(course_key)
thread = cc.Thread.find(thread_id)
user.follow(thread, course_id=course_id)
thread_followed.send(sender=None, user=request.user, post=thread)
send_signal_after_commit(
lambda: thread_followed.send(sender=None, user=request.user, post=thread)
)
track_thread_followed_event(request, course, thread, True)
return JsonResponse({})
@@ -1051,7 +1079,9 @@ def unfollow_thread(request, course_id, thread_id): # lint-amnesty, pylint: dis
user = cc.User.from_django_user(request.user)
thread = cc.Thread.find(thread_id)
user.unfollow(thread, course_id=course_id)
thread_unfollowed.send(sender=None, user=request.user, post=thread)
send_signal_after_commit(
lambda: thread_unfollowed.send(sender=None, user=request.user, post=thread)
)
track_thread_followed_event(request, course, thread, False)
return JsonResponse({})

View File

@@ -128,6 +128,7 @@ from .utils import (
discussion_open_for_user,
get_usernames_for_course,
get_usernames_from_search_string,
send_signal_after_commit,
set_attribute,
is_posting_allowed,
can_user_notify_all_learners, is_captcha_enabled, get_captcha_site_key_by_platform
@@ -1382,7 +1383,9 @@ def _handle_following_field(form_value, user, cc_content, request):
else:
user.unfollow(cc_content)
signal = thread_followed if form_value else thread_unfollowed
signal.send(sender=None, user=user, post=cc_content)
send_signal_after_commit(
lambda: signal.send(sender=None, user=user, post=cc_content)
)
track_thread_followed_event(request, course, cc_content, form_value)
@@ -1395,9 +1398,13 @@ def _handle_abuse_flagged_field(form_value, user, cc_content, request):
track_discussion_reported_event(request, course, cc_content)
if ENABLE_DISCUSSIONS_MFE.is_enabled(course_key):
if cc_content.type == 'thread':
thread_flagged.send(sender='flag_abuse_for_thread', user=user, post=cc_content)
send_signal_after_commit(
lambda: thread_flagged.send(sender='flag_abuse_for_thread', user=user, post=cc_content)
)
else:
comment_flagged.send(sender='flag_abuse_for_comment', user=user, post=cc_content)
send_signal_after_commit(
lambda: comment_flagged.send(sender='flag_abuse_for_comment', user=user, post=cc_content)
)
else:
remove_all = bool(is_privileged_user(course_key, User.objects.get(id=user.id)))
cc_content.unFlagAbuse(user, cc_content, remove_all)
@@ -1407,7 +1414,9 @@ def _handle_abuse_flagged_field(form_value, user, cc_content, request):
def _handle_voted_field(form_value, cc_content, api_content, request, context):
"""vote or undo vote on thread/comment"""
signal = thread_voted if cc_content.type == 'thread' else comment_voted
signal.send(sender=None, user=context["request"].user, post=cc_content)
send_signal_after_commit(
lambda: signal.send(sender=None, user=context["request"].user, post=cc_content)
)
if form_value:
context["cc_requester"].vote(cc_content, "up")
api_content["vote_count"] += 1
@@ -1452,7 +1461,9 @@ def _handle_comment_signals(update_data, comment, user, sender=None):
"""
for key, value in update_data.items():
if key == "endorsed" and value is True:
comment_endorsed.send(sender=sender, user=user, post=comment)
send_signal_after_commit(
lambda: comment_endorsed.send(sender=sender, user=user, post=comment)
)
def create_thread(request, thread_data):
@@ -1502,7 +1513,10 @@ def create_thread(request, thread_data):
raise ValidationError(dict(list(serializer.errors.items()) + list(actions_form.errors.items())))
serializer.save()
cc_thread = serializer.instance
thread_created.send(sender=None, user=user, post=cc_thread, notify_all_learners=notify_all_learners)
# Use send_signal_after_commit() to ensure the signal is sent only after the transaction commits.
send_signal_after_commit(
lambda: thread_created.send(sender=None, user=user, post=cc_thread, notify_all_learners=notify_all_learners)
)
api_thread = serializer.data
_do_extra_actions(api_thread, cc_thread, list(thread_data.keys()), actions_form, context, request)
@@ -1550,7 +1564,9 @@ def create_comment(request, comment_data):
context["cc_requester"].follow(cc_thread)
serializer.save()
cc_comment = serializer.instance
comment_created.send(sender=None, user=request.user, post=cc_comment)
send_signal_after_commit(
lambda: comment_created.send(sender=None, user=request.user, post=cc_comment)
)
api_comment = serializer.data
_do_extra_actions(api_comment, cc_comment, list(comment_data.keys()), actions_form, context, request)
track_comment_created_event(request, course, cc_comment, cc_thread["commentable_id"], followed=False,
@@ -1586,7 +1602,9 @@ def update_thread(request, thread_id, update_data):
if set(update_data) - set(actions_form.fields):
serializer.save()
# signal to update Teams when a user edits a thread
thread_edited.send(sender=None, user=request.user, post=cc_thread)
send_signal_after_commit(
lambda: thread_edited.send(sender=None, user=request.user, post=cc_thread)
)
api_thread = serializer.data
_do_extra_actions(api_thread, cc_thread, list(update_data.keys()), actions_form, context, request)
@@ -1635,7 +1653,9 @@ def update_comment(request, comment_id, update_data):
# Only save comment object if some of the edited fields are in the comment data, not extra actions
if set(update_data) - set(actions_form.fields):
serializer.save()
comment_edited.send(sender=None, user=request.user, post=cc_comment)
send_signal_after_commit(
lambda: comment_edited.send(sender=None, user=request.user, post=cc_comment)
)
api_comment = serializer.data
_do_extra_actions(api_comment, cc_comment, list(update_data.keys()), actions_form, context, request)
_handle_comment_signals(update_data, cc_comment, request.user)
@@ -1823,7 +1843,9 @@ def delete_thread(request, thread_id):
cc_thread, context = _get_thread_and_context(request, thread_id)
if can_delete(cc_thread, context):
cc_thread.delete()
thread_deleted.send(sender=None, user=request.user, post=cc_thread)
send_signal_after_commit(
lambda: thread_deleted.send(sender=None, user=request.user, post=cc_thread)
)
track_thread_deleted_event(request, context["course"], cc_thread)
else:
raise PermissionDenied
@@ -1848,7 +1870,9 @@ def delete_comment(request, comment_id):
cc_comment, context = _get_comment_and_context(request, comment_id)
if can_delete(cc_comment, context):
cc_comment.delete()
comment_deleted.send(sender=None, user=request.user, post=cc_comment)
send_signal_after_commit(
lambda: comment_deleted.send(sender=None, user=request.user, post=cc_comment)
)
track_comment_deleted_event(request, context["course"], cc_comment)
else:
raise PermissionDenied

View File

@@ -273,7 +273,8 @@ class CreateThreadTest(
with self.assert_signal_sent(
api, "thread_created", sender=None, user=self.user, exclude_args=("post", "notify_all_learners")
):
actual = create_thread(self.request, self.minimal_data)
with self.captureOnCommitCallbacks(execute=True):
actual = create_thread(self.request, self.minimal_data)
expected = self.expected_thread_data(
{
"id": "test_id",
@@ -352,7 +353,8 @@ class CreateThreadTest(
with self.assert_signal_sent(
api, "thread_created", sender=None, user=self.user, exclude_args=("post", "notify_all_learners")
):
actual = create_thread(self.request, self.minimal_data)
with self.captureOnCommitCallbacks(execute=True):
actual = create_thread(self.request, self.minimal_data)
expected = self.expected_thread_data(
{
"author_label": "Moderator",
@@ -428,7 +430,8 @@ class CreateThreadTest(
with self.assert_signal_sent(
api, "thread_created", sender=None, user=self.user, exclude_args=("post", "notify_all_learners")
):
create_thread(self.request, data)
with self.captureOnCommitCallbacks(execute=True):
create_thread(self.request, data)
event_name, event_data = mock_emit.call_args[0]
assert event_name == "edx.forum.thread.created"
assert event_data == {
@@ -678,7 +681,8 @@ class CreateCommentTest(
with self.assert_signal_sent(
api, "comment_created", sender=None, user=self.user, exclude_args=("post",)
):
actual = create_comment(self.request, data)
with self.captureOnCommitCallbacks(execute=True):
actual = create_comment(self.request, data)
expected = {
"id": "test_comment",
"thread_id": "test_thread",
@@ -785,7 +789,8 @@ class CreateCommentTest(
with self.assert_signal_sent(
api, "comment_created", sender=None, user=self.user, exclude_args=("post",)
):
actual = create_comment(self.request, data)
with self.captureOnCommitCallbacks(execute=True):
actual = create_comment(self.request, data)
expected = {
"id": "test_comment",
"thread_id": "test_thread",
@@ -1118,9 +1123,10 @@ class UpdateThreadTest(
with self.assert_signal_sent(
api, "thread_edited", sender=None, user=self.user, exclude_args=("post",)
):
actual = update_thread(
self.request, "test_thread", {"raw_body": "Edited body"}
)
with self.captureOnCommitCallbacks(execute=True):
actual = update_thread(
self.request, "test_thread", {"raw_body": "Edited body"}
)
assert actual == self.expected_thread_data(
{
@@ -1436,13 +1442,13 @@ class UpdateThreadTest(
self.register_thread()
data = {"following": new_following}
signal_name = "thread_followed" if new_following else "thread_unfollowed"
mock_path = (
f"openedx.core.djangoapps.django_comment_common.signals.{signal_name}.send"
)
# Patch at the api module level where the signal is imported and used
mock_path = f"lms.djangoapps.discussion.rest_api.api.{signal_name}"
with mock.patch(mock_path) as signal_patch:
result = update_thread(self.request, "test_thread", data)
with self.captureOnCommitCallbacks(execute=True):
result = update_thread(self.request, "test_thread", data)
if old_following != new_following:
self.assertEqual(signal_patch.call_count, 1)
self.assertEqual(signal_patch.send.call_count, 1)
assert result["following"] == new_following
if old_following == new_following:
@@ -1782,9 +1788,10 @@ class UpdateCommentTest(
with self.assert_signal_sent(
api, "comment_edited", sender=None, user=self.user, exclude_args=("post",)
):
actual = update_comment(
self.request, "test_comment", {"raw_body": "Edited body"}
)
with self.captureOnCommitCallbacks(execute=True):
actual = update_comment(
self.request, "test_comment", {"raw_body": "Edited body"}
)
expected = {
"anonymous": False,
"anonymous_to_peers": False,
@@ -2207,7 +2214,7 @@ class UpdateCommentTest(
)
@ddt.unpack
@mock.patch(
"openedx.core.djangoapps.django_comment_common.signals.comment_endorsed.send"
"lms.djangoapps.discussion.rest_api.api.comment_endorsed.send"
)
def test_endorsed_access(
self, role_name, is_thread_author, thread_type, is_comment_author, endorsed_mock
@@ -2226,7 +2233,8 @@ class UpdateCommentTest(
thread_type == "discussion" or not is_thread_author
)
try:
update_comment(self.request, "test_comment", {"endorsed": True})
with self.captureOnCommitCallbacks(execute=True):
update_comment(self.request, "test_comment", {"endorsed": True})
self.assertEqual(endorsed_mock.call_count, 1)
assert not expected_error
except ValidationError as err:
@@ -2354,7 +2362,8 @@ class DeleteThreadTest(
with self.assert_signal_sent(
api, "thread_deleted", sender=None, user=self.user, exclude_args=("post",)
):
assert delete_thread(self.request, self.thread_id) is None
with self.captureOnCommitCallbacks(execute=True):
assert delete_thread(self.request, self.thread_id) is None
self.check_mock_called("delete_thread")
params = {
"thread_id": self.thread_id,
@@ -2540,7 +2549,8 @@ class DeleteCommentTest(
with self.assert_signal_sent(
api, "comment_deleted", sender=None, user=self.user, exclude_args=("post",)
):
assert delete_comment(self.request, self.comment_id) is None
with self.captureOnCommitCallbacks(execute=True):
assert delete_comment(self.request, self.comment_id) is None
self.check_mock_called("delete_comment")
params = {
"comment_id": self.comment_id,

View File

@@ -3,13 +3,14 @@ Utils for discussion API.
"""
import logging
from datetime import datetime
from typing import Dict, List
from typing import Callable, Dict, List
import requests
from crum import get_current_request
from django.conf import settings
from django.contrib.auth.models import User # lint-amnesty, pylint: disable=imported-auth-user
from django.core.paginator import Paginator
from django.db import transaction
from django.db.models.functions import Length
from pytz import UTC
@@ -496,3 +497,24 @@ def get_captcha_site_key_by_platform(platform: str) -> str | None:
Get reCAPTCHA site key based on the platform.
"""
return settings.RECAPTCHA_SITE_KEYS.get(platform, None)
def send_signal_after_commit(signal_func: Callable):
"""
Schedule a signal to be sent after the current database transaction commits.
This helper ensures that signals are only sent after the transaction commits,
preventing race conditions where async tasks (like Celery workers) may try to
access database records before they are visible (especially important for MySQL
backend with transaction isolation).
Args:
signal_func: A callable that sends the signal. This will be executed
after the transaction commits.
Example:
send_signal_after_commit(
lambda: thread_created.send(sender=None, user=user, post=thread, notify_all_learners=False)
)
"""
transaction.on_commit(signal_func)