diff --git a/lms/djangoapps/discussion/rest_api/render.py b/lms/djangoapps/discussion/rest_api/render.py
index d2d52b2074..303b26299c 100644
--- a/lms/djangoapps/discussion/rest_api/render.py
+++ b/lms/djangoapps/discussion/rest_api/render.py
@@ -4,88 +4,18 @@ Content rendering functionality
Note that this module is designed to imitate the front end behavior as
implemented in Markdown.Sanitizer.js.
"""
-
-
-import re
-
+import bleach
import markdown
-# These patterns could be more flexible about things like attributes and
-# whitespace, but this is imitating Markdown.Sanitizer.js, so it uses the
-# patterns defined therein.
-TAG_PATTERN = re.compile(r"<[^>]*>?")
-SANITIZED_TAG_PATTERN = re.compile(r"<(/?)(\w+)[^>]*>")
-ALLOWED_BASIC_TAG_PATTERN = re.compile(
- r"^(?(b|blockquote|code|del|dd|dl|dt|em|h1|h2|h3|i|kbd|li|ol|p|pre|s|sup|sub|strong|strike|ul)>|<(br|hr)\s?/?>)$"
-)
-ALLOWED_A_PATTERN = re.compile(
- r'^(]+")?\s?>|)$'
-)
-ALLOWED_IMG_PATTERN = re.compile(
- r'^(]*")?(\stitle="[^"<>]*")?\s?/?>)$'
-)
-
-
-def _sanitize_tag(match):
- """Return the tag if it is allowed or the empty string otherwise"""
- tag = match.group(0)
- if (
- ALLOWED_BASIC_TAG_PATTERN.match(tag) or
- ALLOWED_A_PATTERN.match(tag) or
- ALLOWED_IMG_PATTERN.match(tag)
- ):
- return tag
- else:
- return ""
-
-
-def _sanitize_html(source):
- """
- Return source with all non-allowed tags removed, preserving the text content
- """
- return TAG_PATTERN.sub(_sanitize_tag, source)
-
-
-def _remove_unpaired_tags(source):
- """
- Return source with all unpaired tags removed, preserving the text content
-
- source should have already been sanitized
- """
- tag_matches = list(SANITIZED_TAG_PATTERN.finditer(source))
- if not tag_matches:
- return source
- tag_stack = []
- tag_name_stack = []
- text_stack = [source[:tag_matches[0].start()]]
- for i, match in enumerate(tag_matches):
- tag_name = match.group(2)
- following_text = (
- source[match.end():tag_matches[i + 1].start()] if i + 1 < len(tag_matches) else
- source[match.end():]
- )
- if tag_name in ["p", "img", "br", "li", "hr"]: # tags that don't require closing
- text_stack[-1] += match.group(0) + following_text
- elif match.group(1): # end tag
- if tag_name in tag_name_stack: # paired with a start tag somewhere
- # pop tags until we find the matching one, keeping the non-tag text
- while True:
- popped_tag_name = tag_name_stack.pop()
- popped_tag = tag_stack.pop()
- popped_text = text_stack.pop()
- if popped_tag_name == tag_name:
- text_stack[-1] += popped_tag + popped_text + match.group(0)
- break
- else:
- text_stack[-1] += popped_text
- # else unpaired; drop the tag
- text_stack[-1] += following_text
- else: # start tag
- tag_stack.append(match.group(0))
- tag_name_stack.append(tag_name)
- text_stack.append(following_text)
- return "".join(text_stack)
+ALLOWED_TAGS = bleach.ALLOWED_TAGS + [
+ 'br', 'dd', 'del', 'dl', 'dt', 'h1', 'h2', 'h3', 'h4', 'hr', 'img', 'kbd', 'p', 'pre', 's',
+ 'strike', 'sub', 'sup'
+]
+ALLOWED_PROTOCOLS = ["http", "https", "ftp", "mailto"]
+ALLOWED_ATTRIBUTES = {
+ "a": ["href", "title"],
+ "img": ["src", "alt", "title", "width", "height"],
+}
def render_body(raw_body):
@@ -95,13 +25,17 @@ def render_body(raw_body):
This includes the following steps:
* Convert Markdown to HTML
- * Strip non-whitelisted HTML
- * Remove unbalanced HTML tags
+ * Sanitise HTML using bleach
Note that this does not prevent Markdown syntax inside a MathJax block from
being processed, which the forums JavaScript code does.
"""
- rendered = markdown.markdown(raw_body)
- rendered = _sanitize_html(rendered)
- rendered = _remove_unpaired_tags(rendered)
- return rendered
+ rendered_html = markdown.markdown(raw_body)
+ sanitised_html = bleach.clean(
+ rendered_html,
+ tags=ALLOWED_TAGS,
+ protocols=ALLOWED_PROTOCOLS,
+ strip=True,
+ attributes=ALLOWED_ATTRIBUTES
+ )
+ return sanitised_html
diff --git a/lms/djangoapps/discussion/rest_api/tests/test_render.py b/lms/djangoapps/discussion/rest_api/tests/test_render.py
index e5fe26a3dd..ef4fad3242 100644
--- a/lms/djangoapps/discussion/rest_api/tests/test_render.py
+++ b/lms/djangoapps/discussion/rest_api/tests/test_render.py
@@ -3,9 +3,8 @@ Tests for content rendering
"""
-from unittest import TestCase
-
import ddt
+from django.test import TestCase
from lms.djangoapps.discussion.rest_api.render import render_body
@@ -29,15 +28,14 @@ class RenderBodyTest(TestCase):
)
@ddt.unpack
def test_markdown_inline(self, delimiter, tag):
- assert render_body('{delimiter}some text{delimiter}'.format(delimiter=delimiter)) == \
- '
<{tag}>some text{tag}>
'.format(tag=tag) + assert render_body(f'{delimiter}some text{delimiter}') == f'<{tag}>some text{tag}>
' @ddt.data( "b", "blockquote", "code", "del", "dd", "dl", "dt", "em", "h1", "h2", "h3", "i", "kbd", "li", "ol", "p", "pre", "s", "sup", "sub", "strong", "strike", "ul" ) def test_openclose_tag(self, tag): - raw_body = "<{tag}>some text{tag}>".format(tag=tag) + raw_body = f"<{tag}>some text{tag}>" is_inline_tag = tag in ["b", "code", "del", "em", "i", "kbd", "s", "sup", "sub", "strong", "strike"] rendered_body = _add_p_tags(raw_body) if is_inline_tag else raw_body assert render_body(raw_body) == rendered_body @@ -49,40 +47,59 @@ class RenderBodyTest(TestCase): rendered_body = _add_p_tags(raw_body) if is_inline_tag else raw_body assert render_body(raw_body) == rendered_body - @ddt.data("http", "https", "ftp") - def test_allowed_a_tag(self, protocol): + @ddt.data( + ("http", True), + ("https", True), + ("ftp", True), + ("gopher", False), + ("file", False), + ("data", False), + ) + @ddt.unpack + def test_protocols_a_tag(self, protocol, is_allowed): raw_body = f'baz' - assert render_body(raw_body) == _add_p_tags(raw_body) + cleaned_body = 'baz' + rendered = render_body(raw_body) + if is_allowed: + assert rendered == _add_p_tags(raw_body) + else: + assert rendered == _add_p_tags(cleaned_body) - def test_disallowed_a_tag(self): - raw_body = 'link content' - assert render_body(raw_body) == 'link content
' - - @ddt.data("http", "https") - def test_allowed_img_tag(self, protocol): - raw_body = 'foo
bar
foo
foo
foobar
foobar
'), + ) + @ddt.unpack + def test_unpaired_tags(self, tag, rendered_output): raw_body = f"foo<{tag}>bar" - assert render_body(raw_body) == _add_p_tags(raw_body) - - def test_unpaired_start_tag(self): - assert render_body('foobar') == 'foobar
' - - def test_unpaired_end_tag(self): - assert render_body('foobar') == 'foobar
' + assert render_body(raw_body) == rendered_output def test_interleaved_tags(self): - assert render_body('foobarbazquuxgreg') == 'foobarbazquuxgreg
' + self.assertHTMLEqual( + render_body('foobarbazquuxgreg'), + 'foobarbazquuxgreg
', + )