From e7b9256a233cdab59d42f40657b3531b4959459d Mon Sep 17 00:00:00 2001 From: Matthew Moon Date: Wed, 6 Jun 2018 20:03:29 -0500 Subject: [PATCH 1/5] Added DBRef cls checks for abstract document dereferencing --- mongoengine/dereference.py | 5 ++++ mongoengine/fields.py | 55 +++++++++++++++++++++----------------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 18b365ccd..742be8e6b 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -104,6 +104,11 @@ def _find_references(self, items, depth=0): # LazyReference inherits DBRef but should not be dereferenced here ! continue elif isinstance(v, DBRef): + # Honor the DBRef cls field if present for abstract documents + if hasattr(v, 'cls'): + reference_map.setdefault(get_document(v.cls), set()).add(v.id) + else: + reference_map.setdefault(field.document_type, set()).add(v.id) reference_map.setdefault(field.document_type, set()).add(v.id) elif isinstance(v, (dict, SON)) and '_ref' in v: reference_map.setdefault(get_document(v['_cls']), set()).add(v['_ref'].id) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index a661874a5..11d19bea8 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -221,9 +221,9 @@ def validate_domain_part(self, domain_part): # Validate IPv4/IPv6, e.g. user@[192.168.0.1] if ( - self.allow_ip_domain and - domain_part[0] == '[' and - domain_part[-1] == ']' + self.allow_ip_domain and + domain_part[0] == '[' and + domain_part[-1] == ']' ): for addr_family in (socket.AF_INET, socket.AF_INET6): try: @@ -616,8 +616,8 @@ class EmbeddedDocumentField(BaseField): def __init__(self, document_type, **kwargs): # XXX ValidationError raised outside of the "validate" method. if not ( - isinstance(document_type, six.string_types) or - issubclass(document_type, EmbeddedDocument) + isinstance(document_type, six.string_types) or + issubclass(document_type, EmbeddedDocument) ): self.error('Invalid embedded document class provided to an ' 'EmbeddedDocumentField') @@ -819,10 +819,10 @@ def prepare_query_value(self, op, value): # If the value is iterable and it's not a string nor a # BaseDocument, call prepare_query_value for each of its items. if ( - op in ('set', 'unset', None) and - hasattr(value, '__iter__') and - not isinstance(value, six.string_types) and - not isinstance(value, BaseDocument) + op in ('set', 'unset', None) and + hasattr(value, '__iter__') and + not isinstance(value, six.string_types) and + not isinstance(value, BaseDocument) ): return [self.field.prepare_query_value(op, v) for v in value] @@ -1035,8 +1035,8 @@ def __init__(self, document_type, dbref=False, """ # XXX ValidationError raised outside of the "validate" method. if ( - not isinstance(document_type, six.string_types) and - not issubclass(document_type, Document) + not isinstance(document_type, six.string_types) and + not issubclass(document_type, Document) ): self.error('Argument to ReferenceField constructor must be a ' 'document class or a string') @@ -1137,8 +1137,8 @@ def validate(self, value): 'saved to the database') if ( - self.document_type._meta.get('abstract') and - not isinstance(value, self.document_type) + self.document_type._meta.get('abstract') and + not isinstance(value, self.document_type) ): self.error( '%s is not an instance of abstract reference type %s' % ( @@ -1168,8 +1168,8 @@ def __init__(self, document_type, fields=None, auto_sync=True, **kwargs): # XXX ValidationError raised outside of the "validate" method. if ( - not isinstance(document_type, six.string_types) and - not issubclass(document_type, Document) + not isinstance(document_type, six.string_types) and + not issubclass(document_type, Document) ): self.error('Argument to CachedReferenceField constructor must be a' ' document class or a string') @@ -1640,8 +1640,8 @@ def __get__(self, instance, owner): def __set__(self, instance, value): key = self.name if ( - (hasattr(value, 'read') and not isinstance(value, GridFSProxy)) or - isinstance(value, (six.binary_type, six.string_types)) + (hasattr(value, 'read') and not isinstance(value, GridFSProxy)) or + isinstance(value, (six.binary_type, six.string_types)) ): # using "FileField() = file/string" notation grid_file = instance._data.get(self.name) @@ -1720,13 +1720,13 @@ def put(self, file_obj, **kwargs): if (kwargs.get('progressive') and isinstance(kwargs.get('progressive'), bool) and - img_format == 'JPEG'): + img_format == 'JPEG'): progressive = True else: progressive = False if (field.size and (img.size[0] > field.size['width'] or - img.size[1] > field.size['height'])): + img.size[1] > field.size['height'])): size = field.size if size['force']: @@ -2069,7 +2069,7 @@ def validate(self, value): self.error('Value (%s) must be a two-dimensional point' % repr(value)) elif (not isinstance(value[0], (float, int)) or - not isinstance(value[1], (float, int))): + not isinstance(value[1], (float, int))): self.error( 'Both values (%s) in point must be float or int' % repr(value)) @@ -2228,8 +2228,8 @@ def __init__(self, document_type, passthrough=False, dbref=False, """ # XXX ValidationError raised outside of the "validate" method. if ( - not isinstance(document_type, six.string_types) and - not issubclass(document_type, Document) + not isinstance(document_type, six.string_types) and + not issubclass(document_type, Document) ): self.error('Argument to LazyReferenceField constructor must be a ' 'document class or a string') @@ -2257,7 +2257,11 @@ def build_lazyref(self, value): if isinstance(value, self.document_type): value = LazyReference(self.document_type, value.pk, passthrough=self.passthrough) elif isinstance(value, DBRef): - value = LazyReference(self.document_type, value.id, passthrough=self.passthrough) + # Honor cls field on DBRef for Abstract Documents + if hasattr(value, 'cls'): + value = LazyReference(get_document(value.cls), value.id, passthrough=self.passthrough) + else: + LazyReference(self.document_type, value.id, passthrough=self.passthrough) else: # value is the primary key of the referenced document value = LazyReference(self.document_type, value, passthrough=self.passthrough) @@ -2302,7 +2306,10 @@ def validate(self, value): pk = value.pk elif isinstance(value, DBRef): # TODO: check collection ? - collection = self.document_type._get_collection_name() + if hasattr(value, 'cls'): + collection = get_document(value.cls)._get_collection_name() + else: + collection = self.document_type._get_collection_name() if value.collection != collection: self.error("DBRef on bad collection (must be on `%s`)" % collection) pk = value.id From 24fabc2abaa2ba777f5058df8a620fa13357a12c Mon Sep 17 00:00:00 2001 From: Matthew Moon Date: Wed, 6 Jun 2018 20:12:57 -0500 Subject: [PATCH 2/5] fixed code format issues --- mongoengine/dereference.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 742be8e6b..6a500254b 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -104,12 +104,10 @@ def _find_references(self, items, depth=0): # LazyReference inherits DBRef but should not be dereferenced here ! continue elif isinstance(v, DBRef): - # Honor the DBRef cls field if present for abstract documents if hasattr(v, 'cls'): reference_map.setdefault(get_document(v.cls), set()).add(v.id) else: reference_map.setdefault(field.document_type, set()).add(v.id) - reference_map.setdefault(field.document_type, set()).add(v.id) elif isinstance(v, (dict, SON)) and '_ref' in v: reference_map.setdefault(get_document(v['_cls']), set()).add(v['_ref'].id) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: From 4adc8f243140f3c5de7b5bb84690031289b76469 Mon Sep 17 00:00:00 2001 From: Matthew Moon Date: Wed, 6 Jun 2018 20:18:00 -0500 Subject: [PATCH 3/5] fixed code format issues --- mongoengine/fields.py | 44 +++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 11d19bea8..6e652ec16 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -221,9 +221,9 @@ def validate_domain_part(self, domain_part): # Validate IPv4/IPv6, e.g. user@[192.168.0.1] if ( - self.allow_ip_domain and - domain_part[0] == '[' and - domain_part[-1] == ']' + self.allow_ip_domain and + domain_part[0] == '[' and + domain_part[-1] == ']' ): for addr_family in (socket.AF_INET, socket.AF_INET6): try: @@ -616,8 +616,8 @@ class EmbeddedDocumentField(BaseField): def __init__(self, document_type, **kwargs): # XXX ValidationError raised outside of the "validate" method. if not ( - isinstance(document_type, six.string_types) or - issubclass(document_type, EmbeddedDocument) + isinstance(document_type, six.string_types) or + issubclass(document_type, EmbeddedDocument) ): self.error('Invalid embedded document class provided to an ' 'EmbeddedDocumentField') @@ -819,10 +819,10 @@ def prepare_query_value(self, op, value): # If the value is iterable and it's not a string nor a # BaseDocument, call prepare_query_value for each of its items. if ( - op in ('set', 'unset', None) and - hasattr(value, '__iter__') and - not isinstance(value, six.string_types) and - not isinstance(value, BaseDocument) + op in ('set', 'unset', None) and + hasattr(value, '__iter__') and + not isinstance(value, six.string_types) and + not isinstance(value, BaseDocument) ): return [self.field.prepare_query_value(op, v) for v in value] @@ -1035,8 +1035,8 @@ def __init__(self, document_type, dbref=False, """ # XXX ValidationError raised outside of the "validate" method. if ( - not isinstance(document_type, six.string_types) and - not issubclass(document_type, Document) + not isinstance(document_type, six.string_types) and + not issubclass(document_type, Document) ): self.error('Argument to ReferenceField constructor must be a ' 'document class or a string') @@ -1137,8 +1137,8 @@ def validate(self, value): 'saved to the database') if ( - self.document_type._meta.get('abstract') and - not isinstance(value, self.document_type) + self.document_type._meta.get('abstract') and + not isinstance(value, self.document_type) ): self.error( '%s is not an instance of abstract reference type %s' % ( @@ -1168,8 +1168,8 @@ def __init__(self, document_type, fields=None, auto_sync=True, **kwargs): # XXX ValidationError raised outside of the "validate" method. if ( - not isinstance(document_type, six.string_types) and - not issubclass(document_type, Document) + not isinstance(document_type, six.string_types) and + not issubclass(document_type, Document) ): self.error('Argument to CachedReferenceField constructor must be a' ' document class or a string') @@ -1640,8 +1640,8 @@ def __get__(self, instance, owner): def __set__(self, instance, value): key = self.name if ( - (hasattr(value, 'read') and not isinstance(value, GridFSProxy)) or - isinstance(value, (six.binary_type, six.string_types)) + (hasattr(value, 'read') and not isinstance(value, GridFSProxy)) or + isinstance(value, (six.binary_type, six.string_types)) ): # using "FileField() = file/string" notation grid_file = instance._data.get(self.name) @@ -1720,13 +1720,13 @@ def put(self, file_obj, **kwargs): if (kwargs.get('progressive') and isinstance(kwargs.get('progressive'), bool) and - img_format == 'JPEG'): + img_format == 'JPEG'): progressive = True else: progressive = False if (field.size and (img.size[0] > field.size['width'] or - img.size[1] > field.size['height'])): + img.size[1] > field.size['height'])): size = field.size if size['force']: @@ -2069,7 +2069,7 @@ def validate(self, value): self.error('Value (%s) must be a two-dimensional point' % repr(value)) elif (not isinstance(value[0], (float, int)) or - not isinstance(value[1], (float, int))): + not isinstance(value[1], (float, int))): self.error( 'Both values (%s) in point must be float or int' % repr(value)) @@ -2228,8 +2228,8 @@ def __init__(self, document_type, passthrough=False, dbref=False, """ # XXX ValidationError raised outside of the "validate" method. if ( - not isinstance(document_type, six.string_types) and - not issubclass(document_type, Document) + not isinstance(document_type, six.string_types) and + not issubclass(document_type, Document) ): self.error('Argument to LazyReferenceField constructor must be a ' 'document class or a string') From d5bd14705fc2dd1497fdee7db915484caa1e6d5c Mon Sep 17 00:00:00 2001 From: Matthew Moon Date: Fri, 31 Aug 2018 21:28:09 -0500 Subject: [PATCH 4/5] updates to dereference to fix abstract document dereference --- mongoengine/dereference.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 6a500254b..f665a4103 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -121,7 +121,10 @@ def _find_references(self, items, depth=0): # LazyReference inherits DBRef but should not be dereferenced here ! continue elif isinstance(item, DBRef): - reference_map.setdefault(item.collection, set()).add(item.id) + if hasattr(item, 'cls'): + reference_map.setdefault(get_document(item.cls), set()).add(item.id) + else: + reference_map.setdefault(item.collection, set()).add(item.id) elif isinstance(item, (dict, SON)) and '_ref' in item: reference_map.setdefault(get_document(item['_cls']), set()).add(item['_ref'].id) elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth: From de02cfe606448af5056e23ff1adec8144650355a Mon Sep 17 00:00:00 2001 From: Matthew Moon Date: Fri, 31 Aug 2018 22:23:36 -0500 Subject: [PATCH 5/5] fixes for abstract document dereferencing --- mongoengine/dereference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index f665a4103..c4b792559 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -115,7 +115,8 @@ def _find_references(self, items, depth=0): references = self._find_references(v, depth) for key, refs in references.iteritems(): if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)): - key = field_cls + if not field_cls._meta.get('abstract', False): + key = field_cls reference_map.setdefault(key, set()).update(refs) elif isinstance(item, LazyReference): # LazyReference inherits DBRef but should not be dereferenced here !