diff --git a/tests/document/instance.py b/tests/document/instance.py index 555cf6ace..1278cf28f 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -27,7 +27,7 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), '../fields/mongoengine.png') -__all__ = ("InstanceTest",) +__all__ = ("InstanceTest", "DBFieldMappingTest") class InstanceTest(unittest.TestCase): @@ -3214,5 +3214,187 @@ class BlogPost(Document): self.assertEqual(blog.tags, [["value1", 123]]) +class DBFieldMappingTest(unittest.TestCase): + + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + class Fields(object): + w1 = BooleanField(db_field='w2') + + x1 = BooleanField(db_field='x2') + x2 = BooleanField(db_field='x3') + + y1 = BooleanField(db_field='y0') + y2 = BooleanField(db_field='y1') + + z1 = BooleanField(db_field='z2') + z2 = BooleanField(db_field='z1') + + class Doc(Fields, Document): + pass + + class DynDoc(Fields, DynamicDocument): + pass + + self.Doc = Doc + self.DynDoc = DynDoc + + def tearDown(self): + for collection in self.db.collection_names(): + if 'system.' in collection: + continue + self.db.drop_collection(collection) + + def test_setting_fields_in_constructor_of_strict_doc_uses_model_names(self): + doc = self.Doc(z1=True, z2=False) + self.assertEqual((doc.z1, doc.z2), (True, False)) + + def test_setting_fields_in_constructor_of_dyn_doc_uses_model_names(self): + doc = self.DynDoc(z1=True, z2=False) + self.assertEqual((doc.z1, doc.z2), (True, False)) + + def test_setting_unknown_field_in_constructor_of_strict_doc_raises_exception(self): + with self.assertRaises(FieldDoesNotExist): + doc = self.Doc(w2=True) + + def test_setting_unknown_field_in_constructor_of_dyn_doc_does_not_overwrite_model_fields(self): + doc = self.DynDoc(w2=True) + self.assertEqual((doc.w1, doc.w2), (None, True)) + + def test_unknown_fields_of_strict_doc_do_not_overwrite_dbfields_1(self): + doc = self.Doc() + doc.w2 = True + doc.x3 = True + doc.y0 = True + doc.save() + reloaded = self.Doc.objects.get(id=doc.id) + self.assertEqual((reloaded.w1, reloaded.x1, reloaded.x2, reloaded.y1, reloaded.y2), (None,) * 5) + + def test_unknown_fields_of_strict_doc_do_not_overwrite_dbfields_2(self): + doc = self.Doc() + doc.w1 = True + doc.w2 = False + doc.x1 = True + doc.x2 = True + doc.x3 = False + doc.y0 = False + doc.y1 = True + doc.y2 = True + doc.save() + reloaded = self.Doc.objects.get(id=doc.id) + self.assertEqual((reloaded.w1, reloaded.x1, reloaded.x2, reloaded.y1, reloaded.y2), (True,) * 5) + + def test_unknown_fields_of_dyn_doc_do_not_overwrite_dbfields_1(self): + doc = self.DynDoc() + doc.w2 = True + doc.x3 = True + doc.y0 = True + doc.save() + reloaded = self.DynDoc.objects.get(id=doc.id) + self.assertEqual((reloaded.w1, reloaded.x1, reloaded.x2, reloaded.y1, reloaded.y2), (None,) * 5) + + def test_unknown_fields_of_dyn_doc_do_not_overwrite_dbfields_2(self): + doc = self.DynDoc() + doc.w1 = True + doc.w2 = False + doc.x1 = True + doc.x2 = True + doc.x3 = False + doc.y0 = False + doc.y1 = True + doc.y2 = True + doc.save() + reloaded = self.DynDoc.objects.get(id=doc.id) + self.assertEqual((reloaded.w1, reloaded.x1, reloaded.x2, reloaded.y1, reloaded.y2), (True,) * 5) + + def test_dbfields_do_not_overwrite_modelfields_of_strict_doc(self): + doc = self.Doc() + doc.save() + doc._get_collection().update({'_id': doc.id}, {'$set': dict(w1=True, x1=True, y2=True)}) + reloaded = self.Doc.objects.get(id=doc.id) + self.assertEqual((reloaded.w1, reloaded.x1, reloaded.x2, reloaded.y1, reloaded.y2), (None,) * 5) + + def test_dbfields_do_not_overwrite_modelfields_of_dyn_doc(self): + doc = self.DynDoc() + doc.save() + doc._get_collection().update({'_id': doc.id}, {'$set': dict(w1=True, x1=True, y2=True)}) + reloaded = self.DynDoc.objects.get(id=doc.id) + self.assertEqual((reloaded.w1, reloaded.x1, reloaded.x2, reloaded.y1, reloaded.y2), (None,) * 5) + + def test_dbfields_are_loaded_to_the_right_modelfield_for_strict_doc_1(self): + doc = self.Doc() + doc.x1 = True + doc.y1 = True + doc.z1 = True + doc.save() + reloaded = self.Doc.objects.get(id=doc.id) + self.assertEqual( + (reloaded.x1, reloaded.x2, reloaded.y1, reloaded.y2, reloaded.z1, reloaded.z2), + (doc.x1, doc.x2, doc.y1, doc.y2, doc.z1, doc.z2)) + + def test_dbfields_are_loaded_to_the_right_modelfield_for_strict_doc_2(self): + doc = self.Doc() + doc.x2 = True + doc.y2 = True + doc.z2 = True + doc.save() + reloaded = self.Doc.objects.get(id=doc.id) + self.assertEqual( + (reloaded.x1, reloaded.x2, reloaded.y1, reloaded.y2, reloaded.z1, reloaded.z2), + (doc.x1, doc.x2, doc.y1, doc.y2, doc.z1, doc.z2)) + + def test_dbfields_are_loaded_to_the_right_modelfield_for_strict_doc_3(self): + doc = self.Doc() + doc.x1 = True + doc.x2 = False + doc.y1 = True + doc.x2 = False + doc.z1 = True + doc.z2 = False + doc.save() + reloaded = self.Doc.objects.get(id=doc.id) + self.assertEqual( + (reloaded.x1, reloaded.x2, reloaded.y1, reloaded.y2, reloaded.z1, reloaded.z2), + (doc.x1, doc.x2, doc.y1, doc.y2, doc.z1, doc.z2)) + + def test_dbfields_are_loaded_to_the_right_modelfield_for_dyn_doc_1(self): + doc = self.DynDoc() + doc.x1 = True + doc.y1 = True + doc.z1 = True + doc.save() + reloaded = self.DynDoc.objects.get(id=doc.id) + self.assertEqual( + (reloaded.x1, reloaded.x2, reloaded.y1, reloaded.y2, reloaded.z1, reloaded.z2), + (doc.x1, doc.x2, doc.y1, doc.y2, doc.z1, doc.z2)) + + def test_dbfields_are_loaded_to_the_right_modelfield_for_dyn_doc_2(self): + doc = self.DynDoc() + doc.x2 = True + doc.y2 = True + doc.z2 = True + doc.save() + reloaded = self.DynDoc.objects.get(id=doc.id) + self.assertEqual( + (reloaded.x1, reloaded.x2, reloaded.y1, reloaded.y2, reloaded.z1, reloaded.z2), + (doc.x1, doc.x2, doc.y1, doc.y2, doc.z1, doc.z2)) + + def test_dbfields_are_loaded_to_the_right_modelfield_for_dyn_doc_3(self): + doc = self.DynDoc() + doc.x1 = True + doc.x2 = False + doc.y1 = True + doc.x2 = False + doc.z1 = True + doc.z2 = False + doc.save() + reloaded = self.DynDoc.objects.get(id=doc.id) + self.assertEqual( + (reloaded.x1, reloaded.x2, reloaded.y1, reloaded.y2, reloaded.z1, reloaded.z2), + (doc.x1, doc.x2, doc.y1, doc.y2, doc.z1, doc.z2)) + if __name__ == '__main__': unittest.main()