diff --git a/openedx/core/djangoapps/coursegraph/management/commands/dump_to_neo4j.py b/openedx/core/djangoapps/coursegraph/management/commands/dump_to_neo4j.py index 65cabcb459..32f96ccad4 100644 --- a/openedx/core/djangoapps/coursegraph/management/commands/dump_to_neo4j.py +++ b/openedx/core/djangoapps/coursegraph/management/commands/dump_to_neo4j.py @@ -113,7 +113,6 @@ class ModuleStoreSerializer(object): items = modulestore().get_items(course_id) # create nodes - nodes = [] for item in items: fields, block_type = self.serialize_item(item) @@ -121,19 +120,20 @@ class ModuleStoreSerializer(object): fields[field_name] = self.coerce_types(value) node = Node(block_type, 'item', **fields) - nodes.append(node) location_to_node[item.location] = node # create relationships relationships = [] for item in items: - for child_loc in item.get_children(): + for index, child_loc in enumerate(item.get_children()): parent_node = location_to_node.get(item.location) child_node = location_to_node.get(child_loc.location) + child_node["index"] = index if parent_node is not None and child_node is not None: relationship = Relationship(parent_node, "PARENT_OF", child_node) relationships.append(relationship) + nodes = location_to_node.values() return nodes, relationships @staticmethod diff --git a/openedx/core/djangoapps/coursegraph/management/commands/tests/test_dump_to_neo4j.py b/openedx/core/djangoapps/coursegraph/management/commands/tests/test_dump_to_neo4j.py index 0eb9f5d620..3b69cbac09 100644 --- a/openedx/core/djangoapps/coursegraph/management/commands/tests/test_dump_to_neo4j.py +++ b/openedx/core/djangoapps/coursegraph/management/commands/tests/test_dump_to_neo4j.py @@ -230,6 +230,23 @@ class TestModuleStoreSerializer(TestDumpToNeo4jCommandBase): self.assertEqual(len(nodes), 9) self.assertEqual(len(relationships), 7) + def test_nodes_have_indices(self): + """ + Test that we add index values on nodes + """ + nodes, relationships = self.mss.serialize_course(self.course.id) + + # the html node should have 0 index, and the problem should have 1 + html_nodes = [node for node in nodes if node['block_type'] == 'html'] + self.assertEqual(len(html_nodes), 1) + problem_nodes = [node for node in nodes if node['block_type'] == 'problem'] + self.assertEqual(len(problem_nodes), 1) + html_node = html_nodes[0] + problem_node = problem_nodes[0] + + self.assertEqual(html_node['index'], 0) + self.assertEqual(problem_node['index'], 1) + @ddt.data( (1, 1), (object, ""),