diff --git a/common/djangoapps/track/backends/tests/test_mongodb.py b/common/djangoapps/track/backends/tests/test_mongodb.py index 2ba18a22a7..c3e032c89c 100644 --- a/common/djangoapps/track/backends/tests/test_mongodb.py +++ b/common/djangoapps/track/backends/tests/test_mongodb.py @@ -2,7 +2,7 @@ from __future__ import absolute_import from uuid import uuid4 -import pymongo +from mock import patch from django.test import TestCase @@ -11,33 +11,30 @@ from track.backends.mongodb import MongoBackend class TestMongoBackend(TestCase): def setUp(self): - # Use a random database name to prevent problems with tests running - # simultenousely against the same mongo instance - database = '_track_backends_mongodb_{0}'.format(uuid4().hex) - collection = '_test' + self.mongo_patcher = patch('track.backends.mongodb.MongoClient') + self.addCleanup(self.mongo_patcher.stop) + self.mongo_patcher.start() - self.connection = pymongo.MongoClient() - self.database = self.connection[database] - self.collection = self.database[collection] - - # During tests, wait until mongo acknowledged the write - write_concern = 1 - - self.backend = MongoBackend( - database=database, - collection=collection, - w=write_concern - ) + self.backend = MongoBackend() def test_mongo_backend(self): - self.backend.send({'test': 1}) - self.backend.send({'test': 2}) + events = [{'test': 1}, {'test': 2}] - # Get all the objects in the db ignoring _id - results = list(self.collection.find({}, {'_id': False})) + self.backend.send(events[0]) + self.backend.send(events[1]) - self.assertEqual(len(results), 2) - self.assertEqual(results, [{'test': 1}, {'test': 2}]) + # Check if we inserted events into the database - def tearDown(self): - self.connection.drop_database(self.database) + calls = self.backend.collection.insert.mock_calls + + self.assertEqual(len(calls), 2) + + # Unpack the arguments and check if the events were used + # as the first argument to collection.insert + + def first_argument(call): + _, args, _ = call + return args[0] + + self.assertEqual(events[0], first_argument(calls[0])) + self.assertEqual(events[1], first_argument(calls[1]))