diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index f8ab73d0d..7b2545829 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -291,7 +291,7 @@ def get_text_score(self): return self._data['_text_score'] - def to_mongo(self, use_db_field=True, fields=None): + def to_mongo(self, use_db_field=True, fields=None, populate=[]): """ Return as SON data ready for use with MongoDB. """ @@ -302,10 +302,19 @@ def to_mongo(self, use_db_field=True, fields=None): data['_id'] = None data['_cls'] = self._class_name + Document = _import_class("Document") + ReferenceField = _import_class("ReferenceField") + ListField = _import_class("ListField") + # only root fields ['test1.a', 'test2'] => ['test1', 'test2'] root_fields = set([f.split('.')[0] for f in fields]) + if populate: + populate_list = [_field.split('.') for _field in populate] + populate_domain = dict((_field[0], _field[1:]) for _field in populate_list) + for field_name in self: + if root_fields and field_name not in root_fields: continue @@ -331,6 +340,27 @@ def to_mongo(self, use_db_field=True, fields=None): value = field.to_mongo(value, **ex_vars) + if populate and field_name in populate_domain.keys(): + if isinstance(field, ListField): + _obj = [] + for ref in value: + if isinstance(ref, ObjectId): + _ref_model = field.field.document_type + _obj.append(_ref_model.objects.get(id=ref).to_mongo(populate=['.'.join(populate_domain[field_name])])) + + elif isinstance(ref, Document): + _obj.append(ref.to_mongo(populate=['.'.join(populate_domain[field_name])])) + + else: + _obj.append(ref) + + value = _obj + + if isinstance(field, ReferenceField): + _ref_model = field.document_type + _obj = _ref_model.objects.get(id=value).to_mongo(populate=['.'.join(populate_domain[field_name])]) + value = _obj + # Handle self generating fields if value is None and field._auto_gen: value = field.generate() @@ -399,9 +429,11 @@ def to_json(self, *args, **kwargs): :param use_db_field: Serialize field names as they appear in MongoDB (as opposed to attribute names on this document). Defaults to True. + :param populate: """ + populate = kwargs.pop('populate', []) use_db_field = kwargs.pop('use_db_field', True) - return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs) + return json_util.dumps(self.to_mongo(use_db_field, populate=populate), *args, **kwargs) @classmethod def from_json(cls, json_data, created=False): diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py index 110f1e14d..bdfa80cf2 100644 --- a/tests/document/json_serialisation.py +++ b/tests/document/json_serialisation.py @@ -60,6 +60,29 @@ def __eq__(self, other): self.assertEqual(doc, Doc.from_json(doc.to_json())) + def test_populate(self): + class VerySimple(Document): + num = StringField(default='2') + + class Simple(Document): + val = StringField(default='1') + inner = ReferenceField(VerySimple) + + class Doc(Document): + simple_ref = ReferenceField(Simple) + list_ref = ListField(ReferenceField(Simple)) + + docum = Doc(simple_ref=Simple(inner=VerySimple().save()).save()) + docum.list_ref=[Simple(inner=VerySimple().save()).save(), Simple().save()] + docum.save() + _json = docum.to_json( populate=['simple_ref','list_ref.inner']) + + import json + _json = json.loads(_json) + print(_json) + self.assertEqual(_json['simple_ref']['val'],'1') + self.assertEqual(_json['list_ref'][0]['inner']['num'],'2') + def test_json_complex(self): if pymongo.version_tuple[0] <= 2 and pymongo.version_tuple[1] <= 3: