Commit 6d6fe61b authored by Luke Plant's avatar Luke Plant
Browse files

Cleanups to related manager code, especially in use of closures.

The related manager classes are defined within functions, and the methods
had inconsistent and confusing usage of closures vs. parameters on self to
retrieve needed information. Everything is stored on self now.

Also some methods were not using super() where they should have been.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16913 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent e3a6ac8f
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -266,7 +266,7 @@ def create_generic_related_manager(superclass):
                '%s__pk' % self.content_type_field_name : self.content_type.id,
                '%s__exact' % self.object_id_field_name : self.pk_val,
            }
            return superclass.get_query_set(self).using(db).filter(**query)
            return super(GenericRelatedObjectManager, self).get_query_set().using(db).filter(**query)

        def add(self, *objs):
            for obj in objs:
+36 −32
Original line number Diff line number Diff line
@@ -414,88 +414,90 @@ class ForeignRelatedObjectsDescriptor(object):
        Creates the managers used by other methods (__get__() and delete()).
        """
        rel_field = self.related.field
        rel_model = self.related.model

        class RelatedManager(superclass):
            def __init__(self, model=None, core_filters=None, instance=None):
            def __init__(self, model=None, core_filters=None, instance=None,
                         rel_field=None):
                super(RelatedManager, self).__init__()
                self.model = model
                self.core_filters = core_filters
                self.instance = instance
                self.rel_field = rel_field

            def get_query_set(self):
                db = self._db or router.db_for_read(rel_model, instance=self.instance)
                return superclass.get_query_set(self).using(db).filter(**(self.core_filters))
                db = self._db or router.db_for_read(self.model, instance=self.instance)
                return super(RelatedManager, self).get_query_set().using(db).filter(**(self.core_filters))

            def add(self, *objs):
                for obj in objs:
                    if not isinstance(obj, self.model):
                        raise TypeError("'%s' instance expected" % self.model._meta.object_name)
                    setattr(obj, rel_field.name, self.instance)
                    setattr(obj, self.rel_field.name, self.instance)
                    obj.save()
            add.alters_data = True

            def create(self, **kwargs):
                kwargs[rel_field.name] = self.instance
                db = router.db_for_write(rel_model, instance=self.instance)
                kwargs[self.rel_field.name] = self.instance
                db = router.db_for_write(self.model, instance=self.instance)
                return super(RelatedManager, self.db_manager(db)).create(**kwargs)
            create.alters_data = True

            def get_or_create(self, **kwargs):
                # Update kwargs with the related object that this
                # ForeignRelatedObjectsDescriptor knows about.
                kwargs[rel_field.name] = self.instance
                db = router.db_for_write(rel_model, instance=self.instance)
                kwargs[self.rel_field.name] = self.instance
                db = router.db_for_write(self.model, instance=self.instance)
                return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
            get_or_create.alters_data = True

            # remove() and clear() are only provided if the ForeignKey can have a value of null.
            if rel_field.null:
                def remove(self, *objs):
                    val = getattr(self.instance, rel_field.rel.get_related_field().attname)
                    val = getattr(self.instance, self.rel_field.rel.get_related_field().attname)
                    for obj in objs:
                        # Is obj actually part of this descriptor set?
                        if getattr(obj, rel_field.attname) == val:
                            setattr(obj, rel_field.name, None)
                        if getattr(obj, self.rel_field.attname) == val:
                            setattr(obj, self.rel_field.name, None)
                            obj.save()
                        else:
                            raise rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance))
                            raise self.rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance))
                remove.alters_data = True

                def clear(self):
                    self.update(**{rel_field.name: None})
                    self.update(**{self.rel_field.name: None})
                clear.alters_data = True

        attname = rel_field.rel.get_related_field().name
        return RelatedManager(model=self.related.model,
                              core_filters = {'%s__%s' % (rel_field.name, attname):
                                                  getattr(instance, attname)},
                              instance=instance
                              instance=instance,
                              rel_field=rel_field,
                              )


def create_many_related_manager(superclass, rel):
    """Creates a manager that subclasses 'superclass' (which is a Manager)
    and adds behavior for many-to-many related objects."""
    through = rel.through
    class ManyRelatedManager(superclass):
        def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
                     source_field_name=None, target_field_name=None, reverse=False):
                     source_field_name=None, target_field_name=None, reverse=False,
                     through=None):
            super(ManyRelatedManager, self).__init__()
            self.core_filters = core_filters
            self.model = model
            self.symmetrical = symmetrical
            self.core_filters = core_filters
            self.instance = instance
            self.symmetrical = symmetrical
            self.source_field_name = source_field_name
            self.target_field_name = target_field_name
            self.reverse = reverse
            self.through = through
            self._pk_val = self.instance.pk
            self.reverse = reverse
            if self._pk_val is None:
                raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__)

        def get_query_set(self):
            db = self._db or router.db_for_read(self.instance.__class__, instance=self.instance)
            return superclass.get_query_set(self).using(db)._next_is_sticky().filter(**(self.core_filters))
            return super(ManyRelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**(self.core_filters))

        # If the ManyToMany relation has an intermediary model,
        # the add and remove methods do not exist.
@@ -527,8 +529,8 @@ def create_many_related_manager(superclass, rel):
        def create(self, **kwargs):
            # This check needs to be done here, since we can't later remove this
            # from the method lookup table, as we do with add and remove.
            if not rel.through._meta.auto_created:
                opts = through._meta
            if not self.through._meta.auto_created:
                opts = self.through._meta
                raise AttributeError("Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
            db = router.db_for_write(self.instance.__class__, instance=self.instance)
            new_obj = super(ManyRelatedManager, self.db_manager(db)).create(**kwargs)
@@ -577,7 +579,7 @@ def create_many_related_manager(superclass, rel):
                if self.reverse or source_field_name == self.source_field_name:
                    # Don't send the signal when we are inserting the
                    # duplicate data row for symmetrical reverse entries.
                    signals.m2m_changed.send(sender=rel.through, action='pre_add',
                    signals.m2m_changed.send(sender=self.through, action='pre_add',
                        instance=self.instance, reverse=self.reverse,
                        model=self.model, pk_set=new_ids, using=db)
                # Add the ones that aren't there already
@@ -591,7 +593,7 @@ def create_many_related_manager(superclass, rel):
                if self.reverse or source_field_name == self.source_field_name:
                    # Don't send the signal when we are inserting the
                    # duplicate data row for symmetrical reverse entries.
                    signals.m2m_changed.send(sender=rel.through, action='post_add',
                    signals.m2m_changed.send(sender=self.through, action='post_add',
                        instance=self.instance, reverse=self.reverse,
                        model=self.model, pk_set=new_ids, using=db)

@@ -615,7 +617,7 @@ def create_many_related_manager(superclass, rel):
                if self.reverse or source_field_name == self.source_field_name:
                    # Don't send the signal when we are deleting the
                    # duplicate data row for symmetrical reverse entries.
                    signals.m2m_changed.send(sender=rel.through, action="pre_remove",
                    signals.m2m_changed.send(sender=self.through, action="pre_remove",
                        instance=self.instance, reverse=self.reverse,
                        model=self.model, pk_set=old_ids, using=db)
                # Remove the specified objects from the join table
@@ -626,7 +628,7 @@ def create_many_related_manager(superclass, rel):
                if self.reverse or source_field_name == self.source_field_name:
                    # Don't send the signal when we are deleting the
                    # duplicate data row for symmetrical reverse entries.
                    signals.m2m_changed.send(sender=rel.through, action="post_remove",
                    signals.m2m_changed.send(sender=self.through, action="post_remove",
                        instance=self.instance, reverse=self.reverse,
                        model=self.model, pk_set=old_ids, using=db)

@@ -636,7 +638,7 @@ def create_many_related_manager(superclass, rel):
            if self.reverse or source_field_name == self.source_field_name:
                # Don't send the signal when we are clearing the
                # duplicate data rows for symmetrical reverse entries.
                signals.m2m_changed.send(sender=rel.through, action="pre_clear",
                signals.m2m_changed.send(sender=self.through, action="pre_clear",
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=None, using=db)
            self.through._default_manager.using(db).filter(**{
@@ -645,7 +647,7 @@ def create_many_related_manager(superclass, rel):
            if self.reverse or source_field_name == self.source_field_name:
                # Don't send the signal when we are clearing the
                # duplicate data rows for symmetrical reverse entries.
                signals.m2m_changed.send(sender=rel.through, action="post_clear",
                signals.m2m_changed.send(sender=self.through, action="post_clear",
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=None, using=db)

@@ -678,7 +680,8 @@ class ManyRelatedObjectsDescriptor(object):
            symmetrical=False,
            source_field_name=self.related.field.m2m_reverse_field_name(),
            target_field_name=self.related.field.m2m_field_name(),
            reverse=True
            reverse=True,
            through=self.related.field.rel.through,
        )

        return manager
@@ -730,7 +733,8 @@ class ReverseManyRelatedObjectsDescriptor(object):
            symmetrical=self.field.rel.symmetrical,
            source_field_name=self.field.m2m_field_name(),
            target_field_name=self.field.m2m_reverse_field_name(),
            reverse=False
            reverse=False,
            through=self.field.rel.through,
        )

        return manager