Commit 616b3022 authored by Carl Meyer's avatar Carl Meyer
Browse files

Fixed #7539, #13067 -- Added on_delete argument to ForeignKey to control...

Fixed #7539, #13067 -- Added on_delete argument to ForeignKey to control cascade behavior. Also refactored deletion for efficiency and code clarity. Many thanks to Johannes Dollinger and Michael Glassford for extensive work on the patch, and to Alex Gaynor, Russell Keith-Magee, and Jacob Kaplan-Moss for review.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@14507 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent 3ba3294c
Loading
Loading
Loading
Loading
+5 −1
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ from django import template
from django.core.exceptions import PermissionDenied
from django.contrib.admin import helpers
from django.contrib.admin.util import get_deleted_objects, model_ngettext
from django.db import router
from django.shortcuts import render_to_response
from django.utils.encoding import force_unicode
from django.utils.translation import ugettext_lazy, ugettext as _
@@ -27,9 +28,12 @@ def delete_selected(modeladmin, request, queryset):
    if not modeladmin.has_delete_permission(request):
        raise PermissionDenied

    using = router.db_for_write(modeladmin.model)

    # Populate deletable_objects, a data structure of all related objects that
    # will also be deleted.
    deletable_objects, perms_needed = get_deleted_objects(queryset, opts, request.user, modeladmin.admin_site)
    deletable_objects, perms_needed = get_deleted_objects(
        queryset, opts, request.user, modeladmin.admin_site, using)

    # The user has already confirmed the deletion.
    # Do the deletion and return a None to display the change list view again.
+5 −2
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from django.contrib.admin.util import unquote, flatten_fieldsets, get_deleted_ob
from django.contrib import messages
from django.views.decorators.csrf import csrf_protect
from django.core.exceptions import PermissionDenied, ValidationError
from django.db import models, transaction
from django.db import models, transaction, router
from django.db.models.fields import BLANK_CHOICE_DASH
from django.http import Http404, HttpResponse, HttpResponseRedirect
from django.shortcuts import get_object_or_404, render_to_response
@@ -1110,9 +1110,12 @@ class ModelAdmin(BaseModelAdmin):
        if obj is None:
            raise Http404(_('%(name)s object with primary key %(key)r does not exist.') % {'name': force_unicode(opts.verbose_name), 'key': escape(object_id)})

        using = router.db_for_write(self.model)

        # Populate deleted_objects, a data structure of all related objects that
        # will also be deleted.
        (deleted_objects, perms_needed) = get_deleted_objects((obj,), opts, request.user, self.admin_site)
        (deleted_objects, perms_needed) = get_deleted_objects(
            [obj], opts, request.user, self.admin_site, using)

        if request.POST: # The user has already confirmed the deletion.
            if perms_needed:
+65 −108
Original line number Diff line number Diff line
from django.db import models
from django.db.models.deletion import Collector
from django.db.models.related import RelatedObject
from django.forms.forms import pretty_name
from django.utils import formats
@@ -10,6 +11,7 @@ from django.utils.translation import ungettext
from django.core.urlresolvers import reverse, NoReverseMatch
from django.utils.datastructures import SortedDict


def quote(s):
    """
    Ensure that primary key values do not confuse the admin URLs by escaping
@@ -26,6 +28,7 @@ def quote(s):
            res[i] = '_%02X' % ord(c)
    return ''.join(res)


def unquote(s):
    """
    Undo the effects of quote(). Based heavily on urllib.unquote().
@@ -46,6 +49,7 @@ def unquote(s):
            myappend('_' + item)
    return "".join(res)


def flatten_fieldsets(fieldsets):
    """Returns a list of field names from an admin fieldsets structure."""
    field_names = []
@@ -58,9 +62,24 @@ def flatten_fieldsets(fieldsets):
                field_names.append(field)
    return field_names

def _format_callback(obj, user, admin_site, perms_needed):

def get_deleted_objects(objs, opts, user, admin_site, using):
    """
    Find all objects related to ``objs`` that should also be deleted. ``objs``
    must be a homogenous iterable of objects (e.g. a QuerySet).

    Returns a nested list of strings suitable for display in the
    template with the ``unordered_list`` filter.

    """
    collector = NestedObjects(using=using)
    collector.collect(objs)
    perms_needed = set()

    def format_callback(obj):
        has_admin = obj.__class__ in admin_site._registry
        opts = obj._meta

        if has_admin:
            admin_url = reverse('%s:%s_%s_change'
                                % (admin_site.name,
@@ -82,120 +101,55 @@ def _format_callback(obj, user, admin_site, perms_needed):
            return u'%s: %s' % (capfirst(opts.verbose_name),
                                force_unicode(obj))

def get_deleted_objects(objs, opts, user, admin_site):
    """
    Find all objects related to ``objs`` that should also be
    deleted. ``objs`` should be an iterable of objects.

    Returns a nested list of strings suitable for display in the
    template with the ``unordered_list`` filter.

    """
    collector = NestedObjects()
    for obj in objs:
        # TODO using a private model API!
        obj._collect_sub_objects(collector)

    perms_needed = set()

    to_delete = collector.nested(_format_callback,
                                 user=user,
                                 admin_site=admin_site,
                                 perms_needed=perms_needed)
    to_delete = collector.nested(format_callback)

    return to_delete, perms_needed


class NestedObjects(object):
    """
    A directed acyclic graph collection that exposes the add() API
    expected by Model._collect_sub_objects and can present its data as
    a nested list of objects.

    """
    def __init__(self):
        # Use object keys of the form (model, pk) because actual model
        # objects may not be unique

        # maps object key to list of child keys
        self.children = SortedDict()

        # maps object key to parent key
        self.parents = SortedDict()

        # maps object key to actual object
        self.seen = SortedDict()

    def add(self, model, pk, obj,
            parent_model=None, parent_obj=None, nullable=False):
        """
        Add item ``obj`` to the graph. Returns True (and does nothing)
        if the item has been seen already.

        The ``parent_obj`` argument must already exist in the graph; if
        not, it's ignored (but ``obj`` is still added with no
        parent). In any case, Model._collect_sub_objects (for whom
        this API exists) will never pass a parent that hasn't already
        been added itself.

        These restrictions in combination ensure the graph will remain
        acyclic (but can have multiple roots).

        ``model``, ``pk``, and ``parent_model`` arguments are ignored
        in favor of the appropriate lookups on ``obj`` and
        ``parent_obj``; unlike CollectedObjects, we can't maintain
        independence from the knowledge that we're operating on model
        instances, and we don't want to allow for inconsistency.

        ``nullable`` arg is ignored: it doesn't affect how the tree of
        collected objects should be nested for display.
        """
        model, pk = type(obj), obj._get_pk_val()

        # auto-created M2M models don't interest us
        if model._meta.auto_created:
            return True
class NestedObjects(Collector):
    def __init__(self, *args, **kwargs):
        super(NestedObjects, self).__init__(*args, **kwargs)
        self.edges = {} # {from_instance: [to_instances]}

        key = model, pk
    def add_edge(self, source, target):
        self.edges.setdefault(source, []).append(target)

        if key in self.seen:
            return True
        self.seen.setdefault(key, obj)
    def collect(self, objs, source_attr=None, **kwargs):
        for obj in objs:
            if source_attr:
                self.add_edge(getattr(obj, source_attr), obj)
            else:
                self.add_edge(None, obj)
        return super(NestedObjects, self).collect(objs, source_attr=source_attr, **kwargs)

        if parent_obj is not None:
            parent_model, parent_pk = (type(parent_obj),
                                       parent_obj._get_pk_val())
            parent_key = (parent_model, parent_pk)
            if parent_key in self.seen:
                self.children.setdefault(parent_key, list()).append(key)
                self.parents.setdefault(key, parent_key)
    def related_objects(self, related, objs):
        qs = super(NestedObjects, self).related_objects(related, objs)
        return qs.select_related(related.field.name)

    def _nested(self, key, format_callback=None, **kwargs):
        obj = self.seen[key]
    def _nested(self, obj, seen, format_callback):
        if obj in seen:
            return []
        seen.add(obj)
        children = []
        for child in self.edges.get(obj, ()):
            children.extend(self._nested(child, seen, format_callback))
        if format_callback:
            ret = [format_callback(obj, **kwargs)]
            ret = [format_callback(obj)]
        else:
            ret = [obj]

        children = []
        for child in self.children.get(key, ()):
            children.extend(self._nested(child, format_callback, **kwargs))
        if children:
            ret.append(children)

        return ret

    def nested(self, format_callback=None, **kwargs):
    def nested(self, format_callback=None):
        """
        Return the graph as a nested list.

        Passes **kwargs back to the format_callback as kwargs.

        """
        seen = set()
        roots = []
        for key in self.seen.keys():
            if key not in self.parents:
                roots.extend(self._nested(key, format_callback, **kwargs))
        for root in self.edges.get(None, ()):
            roots.extend(self._nested(root, seen, format_callback))
        return roots


@@ -218,6 +172,7 @@ def model_format_dict(obj):
        'verbose_name_plural': force_unicode(opts.verbose_name_plural)
    }


def model_ngettext(obj, n=None):
    """
    Return the appropriate `verbose_name` or `verbose_name_plural` value for
@@ -236,6 +191,7 @@ def model_ngettext(obj, n=None):
    singular, plural = d["verbose_name"], d["verbose_name_plural"]
    return ungettext(singular, plural, n or 0)


def lookup_field(name, obj, model_admin=None):
    opts = obj._meta
    try:
@@ -262,6 +218,7 @@ def lookup_field(name, obj, model_admin=None):
        value = getattr(obj, name)
    return f, attr, value


def label_for_field(name, model, model_admin=None, return_attr=False):
    attr = None
    try:
+17 −1
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ Classes allowing "generic" relations through ContentType and object-id fields.
from django.core.exceptions import ObjectDoesNotExist
from django.db import connection
from django.db.models import signals
from django.db import models, router
from django.db import models, router, DEFAULT_DB_ALIAS
from django.db.models.fields.related import RelatedField, Field, ManyToManyRel
from django.db.models.loading import get_model
from django.forms import ModelForm
@@ -13,6 +13,9 @@ from django.forms.models import BaseModelFormSet, modelformset_factory, save_ins
from django.contrib.admin.options import InlineModelAdmin, flatten_fieldsets
from django.utils.encoding import smart_unicode

from django.contrib.contenttypes.models import ContentType


class GenericForeignKey(object):
    """
    Provides a generic relation to any object through content-type/object-id
@@ -167,6 +170,19 @@ class GenericRelation(RelatedField, Field):
        return [("%s__%s" % (prefix, self.content_type_field_name),
            content_type)]

    def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
        """
        Return all objects related to ``objs`` via this ``GenericRelation``.

        """
        return self.rel.to._base_manager.db_manager(using).filter(**{
                "%s__pk" % self.content_type_field_name:
                    ContentType.objects.db_manager(using).get_for_model(self.model).pk,
                "%s__in" % self.object_id_field_name:
                    [obj.pk for obj in objs]
                })


class ReverseGenericRelatedObjectsDescriptor(object):
    """
    This class provides the functionality that makes the related-object
+8 −0
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ def get_validation_errors(outfile, app=None):
    from django.db import models, connection
    from django.db.models.loading import get_app_errors
    from django.db.models.fields.related import RelatedObject
    from django.db.models.deletion import SET_NULL, SET_DEFAULT

    e = ModelErrorCollection(outfile)

@@ -85,6 +86,13 @@ def get_validation_errors(outfile, app=None):
            # Perform any backend-specific field validation.
            connection.validation.validate_field(e, opts, f)

            # Check if the on_delete behavior is sane
            if f.rel and hasattr(f.rel, 'on_delete'):
                if f.rel.on_delete == SET_NULL and not f.null:
                    e.add(opts, "'%s' specifies on_delete=SET_NULL, but cannot be null." % f.name)
                elif f.rel.on_delete == SET_DEFAULT and not f.has_default():
                    e.add(opts, "'%s' specifies on_delete=SET_DEFAULT, but has no default value." % f.name)

            # Check to see if the related field will clash with any existing
            # fields, m2m fields, m2m related objects or related objects
            if f.rel:
Loading