Commit c5a77721 authored by Loic Bistuer's avatar Loic Bistuer
Browse files

Merged ManyRelatedObjectsDescriptor and ReverseManyRelatedObjectsDescriptor

and made all "many" related objects descriptors inherit from
ForeignRelatedObjectsDescriptor.
parent d652906a
Loading
Loading
Loading
Loading
+36 −63
Original line number Diff line number Diff line
@@ -8,9 +8,12 @@ from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
from django.db import DEFAULT_DB_ALIAS, connection, models, router, transaction
from django.db.models import DO_NOTHING, signals
from django.db.models.base import ModelBase
from django.db.models.fields.related import ForeignObject, ForeignObjectRel
from django.db.models.fields.related import (
    ForeignObject, ForeignObjectRel, ForeignRelatedObjectsDescriptor,
)
from django.db.models.query_utils import PathInfo
from django.utils.encoding import python_2_unicode_compatible, smart_text
from django.utils.functional import cached_property


@python_2_unicode_compatible
@@ -358,7 +361,7 @@ class GenericRelation(ForeignObject):
        # Save a reference to which model this class is on for future use
        self.model = cls
        # Add the descriptor for the relation
        setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self, self.for_concrete_model))
        setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self.rel))

    def set_attributes_from_rel(self):
        pass
@@ -393,7 +396,7 @@ class GenericRelation(ForeignObject):
        })


class ReverseGenericRelatedObjectsDescriptor(object):
class ReverseGenericRelatedObjectsDescriptor(ForeignRelatedObjectsDescriptor):
    """
    This class provides the functionality that makes the related-object
    managers available as attributes on a model class, for fields that have
@@ -402,87 +405,57 @@ class ReverseGenericRelatedObjectsDescriptor(object):
    "article.publications", the publications attribute is a
    ReverseGenericRelatedObjectsDescriptor instance.
    """
    def __init__(self, field, for_concrete_model=True):
        self.field = field
        self.for_concrete_model = for_concrete_model

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self

        # Dynamically create a class that subclasses the related model's
        # default manager.
        rel_model = self.field.rel.to
        superclass = rel_model._default_manager.__class__
        RelatedManager = create_generic_related_manager(superclass)

        qn = connection.ops.quote_name
        content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(
            instance, for_concrete_model=self.for_concrete_model)

        join_cols = self.field.get_joining_columns(reverse_join=True)[0]
        manager = RelatedManager(
            model=rel_model,
            instance=instance,
            source_col_name=qn(join_cols[0]),
            target_col_name=qn(join_cols[1]),
            content_type=content_type,
            content_type_field_name=self.field.content_type_field_name,
            object_id_field_name=self.field.object_id_field_name,
            prefetch_cache_name=self.field.attname,
        )

        return manager

    def __set__(self, instance, value):
        manager = self.__get__(instance)
        manager.set(value)
    @cached_property
    def related_manager_cls(self):
        return create_generic_related_manager(
            self.rel.to._default_manager.__class__,
            self.rel,
        )


def create_generic_related_manager(superclass):
def create_generic_related_manager(superclass, rel):
    """
    Factory function for a manager that subclasses 'superclass' (which is a
    Manager) and adds behavior for generic related objects.
    """

    class GenericRelatedObjectManager(superclass):
        def __init__(self, model=None, instance=None, symmetrical=None,
                     source_col_name=None, target_col_name=None, content_type=None,
                     content_type_field_name=None, object_id_field_name=None,
                     prefetch_cache_name=None):

        def __init__(self, instance=None):
            super(GenericRelatedObjectManager, self).__init__()
            self.model = model
            self.content_type = content_type
            self.symmetrical = symmetrical

            self.instance = instance
            self.source_col_name = source_col_name
            self.target_col_name = target_col_name
            self.content_type_field_name = content_type_field_name
            self.object_id_field_name = object_id_field_name
            self.prefetch_cache_name = prefetch_cache_name
            self.pk_val = self.instance._get_pk_val()

            self.model = rel.to

            content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(
                instance, for_concrete_model=rel.field.for_concrete_model)
            self.content_type = content_type

            qn = connection.ops.quote_name
            join_cols = rel.field.get_joining_columns(reverse_join=True)[0]
            self.source_col_name = qn(join_cols[0])
            self.target_col_name = qn(join_cols[1])

            self.content_type_field_name = rel.field.content_type_field_name
            self.object_id_field_name = rel.field.object_id_field_name
            self.prefetch_cache_name = rel.field.attname
            self.pk_val = instance._get_pk_val()

            self.core_filters = {
                '%s__pk' % content_type_field_name: content_type.id,
                '%s' % object_id_field_name: instance._get_pk_val(),
                '%s__pk' % self.content_type_field_name: content_type.id,
                self.object_id_field_name: self.pk_val,
            }

        def __call__(self, **kwargs):
            # We use **kwargs rather than a kwarg argument to enforce the
            # `manager='manager_name'` syntax.
            manager = getattr(self.model, kwargs.pop('manager'))
            manager_class = create_generic_related_manager(manager.__class__)
            return manager_class(
                model=self.model,
                instance=self.instance,
                symmetrical=self.symmetrical,
                source_col_name=self.source_col_name,
                target_col_name=self.target_col_name,
                content_type=self.content_type,
                content_type_field_name=self.content_type_field_name,
                object_id_field_name=self.object_id_field_name,
                prefetch_cache_name=self.prefetch_cache_name,
            )
            manager_class = create_generic_related_manager(manager.__class__, rel)
            return manager_class(instance=self.instance)
        do_not_call_in_templates = True

        def __str__(self):
+80 −153
Original line number Diff line number Diff line
@@ -678,25 +678,28 @@ class ReverseSingleRelatedObjectDescriptor(object):
            setattr(value, self.field.rel.get_cache_name(), instance)


def create_foreign_related_manager(superclass, rel_field, rel_model):
def create_foreign_related_manager(superclass, rel):
    class RelatedManager(superclass):
        def __init__(self, instance):
            super(RelatedManager, self).__init__()

            self.instance = instance
            self.core_filters = {rel_field.name: instance}
            self.model = rel_model
            self.model = rel.related_model
            self.field = rel.field

            self.core_filters = {self.field.name: instance}

        def __call__(self, **kwargs):
            # We use **kwargs rather than a kwarg argument to enforce the
            # `manager='manager_name'` syntax.
            manager = getattr(self.model, kwargs.pop('manager'))
            manager_class = create_foreign_related_manager(manager.__class__, rel_field, rel_model)
            manager_class = create_foreign_related_manager(manager.__class__, rel)
            return manager_class(self.instance)
        do_not_call_in_templates = True

        def get_queryset(self):
            try:
                return self.instance._prefetched_objects_cache[rel_field.related_query_name()]
                return self.instance._prefetched_objects_cache[self.field.related_query_name()]
            except (AttributeError, KeyError):
                db = self._db or router.db_for_read(self.model, instance=self.instance)
                empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls
@@ -705,11 +708,11 @@ def create_foreign_related_manager(superclass, rel_field, rel_model):
                if self._db:
                    qs = qs.using(self._db)
                qs = qs.filter(**self.core_filters)
                for field in rel_field.foreign_related_fields:
                for field in self.field.foreign_related_fields:
                    val = getattr(self.instance, field.attname)
                    if val is None or (val == '' and empty_strings_as_null):
                        return qs.none()
                qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
                qs._known_related_objects = {self.field: {self.instance.pk: self.instance}}
                return qs

        def get_prefetch_queryset(self, instances, queryset=None):
@@ -719,18 +722,18 @@ def create_foreign_related_manager(superclass, rel_field, rel_model):
            queryset._add_hints(instance=instances[0])
            queryset = queryset.using(queryset._db or self._db)

            rel_obj_attr = rel_field.get_local_related_value
            instance_attr = rel_field.get_foreign_related_value
            rel_obj_attr = self.field.get_local_related_value
            instance_attr = self.field.get_foreign_related_value
            instances_dict = {instance_attr(inst): inst for inst in instances}
            query = {'%s__in' % rel_field.name: instances}
            query = {'%s__in' % self.field.name: instances}
            queryset = queryset.filter(**query)

            # Since we just bypassed this class' get_queryset(), we must manage
            # the reverse relation manually.
            for rel_obj in queryset:
                instance = instances_dict[rel_obj_attr(rel_obj)]
                setattr(rel_obj, rel_field.name, instance)
            cache_name = rel_field.related_query_name()
                setattr(rel_obj, self.field.name, instance)
            cache_name = self.field.related_query_name()
            return queryset, rel_obj_attr, instance_attr, False, cache_name

        def add(self, *objs):
@@ -741,42 +744,42 @@ def create_foreign_related_manager(superclass, rel_field, rel_model):
                    if not isinstance(obj, self.model):
                        raise TypeError("'%s' instance expected, got %r" %
                                        (self.model._meta.object_name, obj))
                    setattr(obj, rel_field.name, self.instance)
                    setattr(obj, self.field.name, self.instance)
                    obj.save()
        add.alters_data = True

        def create(self, **kwargs):
            kwargs[rel_field.name] = self.instance
            kwargs[self.field.name] = self.instance
            db = router.db_for_write(self.model, instance=self.instance)
            return super(RelatedManager, self.db_manager(db)).create(**kwargs)
        create.alters_data = True

        def get_or_create(self, **kwargs):
            kwargs[rel_field.name] = self.instance
            kwargs[self.field.name] = self.instance
            db = router.db_for_write(self.model, instance=self.instance)
            return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
        get_or_create.alters_data = True

        def update_or_create(self, **kwargs):
            kwargs[rel_field.name] = self.instance
            kwargs[self.field.name] = self.instance
            db = router.db_for_write(self.model, instance=self.instance)
            return super(RelatedManager, self.db_manager(db)).update_or_create(**kwargs)
        update_or_create.alters_data = True

        # remove() and clear() are only provided if the ForeignKey can have a value of null.
        if rel_field.null:
        if rel.field.null:
            def remove(self, *objs, **kwargs):
                if not objs:
                    return
                bulk = kwargs.pop('bulk', True)
                val = rel_field.get_foreign_related_value(self.instance)
                val = self.field.get_foreign_related_value(self.instance)
                old_ids = set()
                for obj in objs:
                    # Is obj actually part of this descriptor set?
                    if rel_field.get_local_related_value(obj) == val:
                    if self.field.get_local_related_value(obj) == val:
                        old_ids.add(obj.pk)
                    else:
                        raise rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance))
                        raise self.field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance))
                self._clear(self.filter(pk__in=old_ids), bulk)
            remove.alters_data = True

@@ -790,12 +793,12 @@ def create_foreign_related_manager(superclass, rel_field, rel_model):
                queryset = queryset.using(db)
                if bulk:
                    # `QuerySet.update()` is intrinsically atomic.
                    queryset.update(**{rel_field.name: None})
                    queryset.update(**{self.field.name: None})
                else:
                    with transaction.atomic(using=db, savepoint=False):
                        for obj in queryset:
                            setattr(obj, rel_field.name, None)
                            obj.save(update_fields=[rel_field.name])
                            setattr(obj, self.field.name, None)
                            obj.save(update_fields=[self.field.name])
            _clear.alters_data = True

        def set(self, objs, **kwargs):
@@ -805,7 +808,7 @@ def create_foreign_related_manager(superclass, rel_field, rel_model):

            clear = kwargs.pop('clear', False)

            if rel_field.null:
            if self.field.null:
                db = router.db_for_write(self.model, instance=self.instance)
                with transaction.atomic(using=db, savepoint=False):
                    if clear:
@@ -835,8 +838,16 @@ class ForeignRelatedObjectsDescriptor(object):
    # multiple "remote" values and have a ForeignKey pointed at them by
    # some other model. In the example "poll.choice_set", the choice_set
    # attribute is a ForeignRelatedObjectsDescriptor instance.
    def __init__(self, related):
        self.related = related   # RelatedObject instance
    def __init__(self, rel):
        self.rel = rel
        self.field = rel.field

    @cached_property
    def related_manager_cls(self):
        return create_foreign_related_manager(
            self.rel.related_model._default_manager.__class__,
            self.rel,
        )

    def __get__(self, instance, instance_type=None):
        if instance is None:
@@ -848,49 +859,47 @@ class ForeignRelatedObjectsDescriptor(object):
        manager = self.__get__(instance)
        manager.set(value)

    @cached_property
    def related_manager_cls(self):
        # Dynamically create a class that subclasses the related model's default
        # manager.
        return create_foreign_related_manager(
            self.related.related_model._default_manager.__class__,
            self.related.field,
            self.related.related_model,
        )

def create_many_related_manager(superclass, rel, reverse):

def create_many_related_manager(superclass, rel):
    """Creates a manager that subclasses 'superclass' (which is a Manager)
    and adds behavior for many-to-many related objects."""
    class ManyRelatedManager(superclass):
        def __init__(self, model=None, query_field_name=None, instance=None, symmetrical=None,
                     source_field_name=None, target_field_name=None, reverse=False,
                     through=None, prefetch_cache_name=None):
        def __init__(self, instance=None):
            super(ManyRelatedManager, self).__init__()
            self.model = model
            self.query_field_name = query_field_name

            source_field = through._meta.get_field(source_field_name)
            source_related_fields = source_field.related_fields
            self.instance = instance

            self.core_filters = {}
            for lh_field, rh_field in source_related_fields:
                self.core_filters['%s__%s' % (query_field_name, rh_field.name)] = getattr(instance, rh_field.attname)
            if not reverse:
                self.model = rel.to
                self.query_field_name = rel.field.related_query_name()
                self.prefetch_cache_name = rel.field.name
                self.source_field_name = rel.field.m2m_field_name()
                self.target_field_name = rel.field.m2m_reverse_field_name()
                self.symmetrical = rel.symmetrical
            else:
                self.model = rel.related_model
                self.query_field_name = rel.field.name
                self.prefetch_cache_name = rel.field.related_query_name()
                self.source_field_name = rel.field.m2m_reverse_field_name()
                self.target_field_name = rel.field.m2m_field_name()
                self.symmetrical = False

            self.instance = instance
            self.symmetrical = symmetrical
            self.source_field = source_field
            self.target_field = through._meta.get_field(target_field_name)
            self.source_field_name = source_field_name
            self.target_field_name = target_field_name
            self.through = rel.through
            self.reverse = reverse
            self.through = through
            self.prefetch_cache_name = prefetch_cache_name
            self.related_val = source_field.get_foreign_related_value(instance)

            self.source_field = self.through._meta.get_field(self.source_field_name)
            self.target_field = self.through._meta.get_field(self.target_field_name)

            self.core_filters = {}
            for lh_field, rh_field in self.source_field.related_fields:
                self.core_filters['%s__%s' % (self.query_field_name, rh_field.name)] = getattr(instance, rh_field.attname)

            self.related_val = self.source_field.get_foreign_related_value(instance)
            if None in self.related_val:
                raise ValueError('"%r" needs to have a value for field "%s" before '
                                 'this many-to-many relationship can be used.' %
                                 (instance, source_field_name))
                                 (instance, self.source_field_name))
            # Even if this relation is not to pk, we require still pk value.
            # The wish is that the instance has been already saved to DB,
            # although having a pk value isn't a guarantee of that.
@@ -903,18 +912,8 @@ def create_many_related_manager(superclass, rel):
            # We use **kwargs rather than a kwarg argument to enforce the
            # `manager='manager_name'` syntax.
            manager = getattr(self.model, kwargs.pop('manager'))
            manager_class = create_many_related_manager(manager.__class__, rel)
            return manager_class(
                model=self.model,
                query_field_name=self.query_field_name,
                instance=self.instance,
                symmetrical=self.symmetrical,
                source_field_name=self.source_field_name,
                target_field_name=self.target_field_name,
                reverse=self.reverse,
                through=self.through,
                prefetch_cache_name=self.prefetch_cache_name,
            )
            manager_class = create_many_related_manager(manager.__class__, rel, reverse)
            return manager_class(instance=self.instance)
        do_not_call_in_templates = True

        def _build_remove_filters(self, removed_vals):
@@ -1198,107 +1197,38 @@ def create_many_related_manager(superclass, rel):
    return ManyRelatedManager


class ManyRelatedObjectsDescriptor(object):
    # This class provides the functionality that makes the related-object
    # managers available as attributes on a model class, for fields that have
    # multiple "remote" values and have a ManyToManyField pointed at them by
    # some other model (rather than having a ManyToManyField themselves).
    # In the example "publication.article_set", the article_set attribute is a
    # ManyRelatedObjectsDescriptor instance.
    def __init__(self, related):
        self.related = related   # RelatedObject instance
class ManyRelatedObjectsDescriptor(ForeignRelatedObjectsDescriptor):

    @cached_property
    def related_manager_cls(self):
        # Dynamically create a class that subclasses the related
        # model's default manager.
        return create_many_related_manager(
            self.related.related_model._default_manager.__class__,
            self.related.field.rel
        )

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self

        rel_model = self.related.related_model

        manager = self.related_manager_cls(
            model=rel_model,
            query_field_name=self.related.field.name,
            prefetch_cache_name=self.related.field.related_query_name(),
            instance=instance,
            symmetrical=False,
            source_field_name=self.related.field.m2m_reverse_field_name(),
            target_field_name=self.related.field.m2m_field_name(),
            reverse=True,
            through=self.related.field.rel.through,
        )

        return manager

    def __set__(self, instance, value):
        manager = self.__get__(instance)
        manager.set(value)

    def __init__(self, rel, reverse=False):
        super(ManyRelatedObjectsDescriptor, self).__init__(rel)

class ReverseManyRelatedObjectsDescriptor(object):
    # This class provides the functionality that makes the related-object
    # managers available as attributes on a model class, for fields that have
    # multiple "remote" values and have a ManyToManyField defined in their
    # model (rather than having another model pointed *at* them).
    # In the example "article.publications", the publications attribute is a
    # ReverseManyRelatedObjectsDescriptor instance.
    def __init__(self, m2m_field):
        self.field = m2m_field
        self.reverse = reverse

    @property
    def through(self):
        # through is provided so that you have easy access to the through
        # model (Book.authors.through) for inlines, etc. This is done as
        # a property to ensure that the fully resolved value is returned.
        return self.field.rel.through
        return self.rel.through

    @cached_property
    def related_manager_cls(self):
        model = self.rel.related_model if self.reverse else self.rel.to
        # Dynamically create a class that subclasses the related model's
        # default manager.
        return create_many_related_manager(
            self.field.rel.to._default_manager.__class__,
            self.field.rel
        )

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self

        manager = self.related_manager_cls(
            model=self.field.rel.to,
            query_field_name=self.field.related_query_name(),
            prefetch_cache_name=self.field.name,
            instance=instance,
            symmetrical=self.field.rel.symmetrical,
            source_field_name=self.field.m2m_field_name(),
            target_field_name=self.field.m2m_reverse_field_name(),
            reverse=False,
            through=self.field.rel.through,
            model._default_manager.__class__,
            self.rel,
            reverse=self.reverse,
        )

        return manager

    def __set__(self, instance, value):
        if not self.field.rel.through._meta.auto_created:
            opts = self.field.rel.through._meta
            raise AttributeError(
                "Cannot set values on a ManyToManyField which specifies an "
                "intermediary model.  Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)
            )

        manager = self.__get__(instance)
        manager.set(value)
class ForeignObjectRel(object):


class ForeignObjectRel(object):
    # Field flags
    auto_created = True
    concrete = False
@@ -1505,10 +1435,6 @@ class ManyToManyRel(ForeignObjectRel):
        self.symmetrical = symmetrical
        self.db_constraint = db_constraint

    def is_hidden(self):
        "Should the related object be hidden?"
        return self.related_name is not None and self.related_name[-1] == '+'

    def get_related_field(self):
        """
        Returns the field in the 'to' object to which this relationship is tied.
@@ -2599,7 +2525,8 @@ class ManyToManyField(RelatedField):
            self.rel.through = create_many_to_many_intermediary_model(self, cls)

        # Add the descriptor for the m2m relation
        setattr(cls, self.name, ReverseManyRelatedObjectsDescriptor(self))
        # Add the descriptor for the m2m relation.
        setattr(cls, self.name, ManyRelatedObjectsDescriptor(self.rel, reverse=False))

        # Set up the accessor for the m2m table name for the relation
        self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta)
@@ -2615,7 +2542,7 @@ class ManyToManyField(RelatedField):
        # Internal M2Ms (i.e., those with a related name ending with '+')
        # and swapped models don't get a related descriptor.
        if not self.rel.is_hidden() and not related.related_model._meta.swapped:
            setattr(cls, related.get_accessor_name(), ManyRelatedObjectsDescriptor(related))
            setattr(cls, related.get_accessor_name(), ManyRelatedObjectsDescriptor(self.rel, reverse=True))

        # Set up the accessors for the column names on the m2m table
        self.m2m_column_name = curry(self._get_m2m_attr, related, 'column')
+1 −1
Original line number Diff line number Diff line
@@ -424,7 +424,7 @@ class ManyToManyTests(TestCase):
    def test_reverse_assign_with_queryset(self):
        # Ensure that querysets used in M2M assignments are pre-evaluated
        # so their value isn't affected by the clearing operation in
        # ReverseManyRelatedObjectsDescriptor.__set__. Refs #19816.
        # ManyRelatedObjectsDescriptor.__set__. Refs #19816.
        self.p1.article_set = [self.a1, self.a2]

        qs = self.p1.article_set.filter(headline='Django lets you build Web apps easily')
+3 −3

File changed.

Preview size limit exceeded, changes collapsed.