Commit 97774429 authored by Anssi Kääriäinen's avatar Anssi Kääriäinen
Browse files

Fixed #19385 again, now with real code changes

The commit of 266de5f9 included only
tests, this time also code changes included...
parent 266de5f9
Loading
Loading
Loading
Loading
+45 −57
Original line number Diff line number Diff line
@@ -8,10 +8,11 @@ from functools import partial

from django.core.exceptions import ObjectDoesNotExist
from django.db import connection
from django.db.models import signals
from django.db import models, router, DEFAULT_DB_ALIAS
from django.db.models.fields.related import RelatedField, Field, ManyToManyRel
from django.db.models import signals
from django.db.models.fields.related import ForeignObject, ForeignObjectRel
from django.db.models.related import PathInfo
from django.db.models.sql.where import Constraint
from django.forms import ModelForm
from django.forms.models import BaseModelFormSet, modelformset_factory, save_instance
from django.contrib.admin.options import InlineModelAdmin, flatten_fieldsets
@@ -149,17 +150,14 @@ class GenericForeignKey(six.with_metaclass(RenameGenericForeignKeyMethods)):
        setattr(instance, self.fk_field, fk)
        setattr(instance, self.cache_attr, value)

class GenericRelation(RelatedField, Field):
class GenericRelation(ForeignObject):
    """Provides an accessor to generic related objects (e.g. comments)"""

    def __init__(self, to, **kwargs):
        kwargs['verbose_name'] = kwargs.get('verbose_name', None)
        kwargs['rel'] = GenericRel(to,
                            related_name=kwargs.pop('related_name', None),
                            limit_choices_to=kwargs.pop('limit_choices_to', None),
                            symmetrical=kwargs.pop('symmetrical', True))


        kwargs['rel'] = GenericRel(
            self, to, related_name=kwargs.pop('related_name', None),
            limit_choices_to=kwargs.pop('limit_choices_to', None),)
        # Override content-type/object-id field names on the related class
        self.object_id_field_name = kwargs.pop("object_id_field", "object_id")
        self.content_type_field_name = kwargs.pop("content_type_field", "content_type")
@@ -167,47 +165,44 @@ class GenericRelation(RelatedField, Field):
        kwargs['blank'] = True
        kwargs['editable'] = False
        kwargs['serialize'] = False
        Field.__init__(self, **kwargs)

    def get_path_info(self):
        from_field = self.model._meta.pk
        # This construct is somewhat of an abuse of ForeignObject. This field
        # represents a relation from pk to object_id field. But, this relation
        # isn't direct, the join is generated reverse along foreign key. So,
        # the from_field is object_id field, to_field is pk because of the
        # reverse join.
        super(GenericRelation, self).__init__(
            to, to_fields=[],
            from_fields=[self.object_id_field_name], **kwargs)

    def resolve_related_fields(self):
        self.to_fields = [self.model._meta.pk.name]
        return [(self.rel.to._meta.get_field_by_name(self.object_id_field_name)[0],
                 self.model._meta.pk)]

    def get_reverse_path_info(self):
        opts = self.rel.to._meta
        target = opts.get_field_by_name(self.object_id_field_name)[0]
        # Note that we are using different field for the join_field
        # than from_field or to_field. This is a hack, but we need the
        # GenericRelation to generate the extra SQL.
        return ([PathInfo(from_field, target, self.model._meta, opts, self, True, False)],
                opts, target, self)
        return [PathInfo(self.model._meta, opts, (target,), self.rel, True, False)]

    def get_choices_default(self):
        return Field.get_choices(self, include_blank=False)
        return super(GenericRelation, self).get_choices(include_blank=False)

    def value_to_string(self, obj):
        qs = getattr(obj, self.name).all()
        return smart_text([instance._get_pk_val() for instance in qs])

    def m2m_db_table(self):
        return self.rel.to._meta.db_table

    def m2m_column_name(self):
        return self.object_id_field_name

    def m2m_reverse_name(self):
        return self.rel.to._meta.pk.column

    def m2m_target_field_name(self):
        return self.model._meta.pk.name

    def m2m_reverse_target_field_name(self):
        return self.rel.to._meta.pk.name
    def get_joining_columns(self, reverse_join=False):
        if not reverse_join:
            # This error message is meant for the user, and from user
            # perspective this is a reverse join along the GenericRelation.
            raise ValueError('Joining in reverse direction not allowed.')
        return super(GenericRelation, self).get_joining_columns(reverse_join)

    def contribute_to_class(self, cls, name):
        super(GenericRelation, self).contribute_to_class(cls, name)

        super(GenericRelation, self).contribute_to_class(cls, name, virtual_only=True)
        # Save a reference to which model this class is on for future use
        self.model = cls

        # Add the descriptor for the m2m relation
        # Add the descriptor for the relation
        setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self))

    def contribute_to_related_class(self, cls, related):
@@ -219,21 +214,18 @@ class GenericRelation(RelatedField, Field):
    def get_internal_type(self):
        return "ManyToManyField"

    def db_type(self, connection):
        # Since we're simulating a ManyToManyField, in effect, best return the
        # same db_type as well.
        return None

    def get_content_type(self):
        """
        Returns the content type associated with this field's model.
        """
        return ContentType.objects.get_for_model(self.model)

    def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias):
        extra_col = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0].column
        contenttype = self.get_content_type().pk
        return " AND %s.%s = %%s" % (qn(rhs_alias), qn(extra_col)), [contenttype]
    def get_extra_restriction(self, where_class, alias, remote_alias):
        field = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0]
        contenttype_pk = self.get_content_type().pk
        cond = where_class()
        cond.add((Constraint(remote_alias, field.column, field), 'exact', contenttype_pk), 'AND')
        return cond

    def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
        """
@@ -273,12 +265,12 @@ class ReverseGenericRelatedObjectsDescriptor(object):
        qn = connection.ops.quote_name
        content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(instance)

        join_cols = self.field.get_joining_columns(reverse_join=True)[0]
        manager = RelatedManager(
            model = rel_model,
            instance = instance,
            symmetrical = (self.field.rel.symmetrical and instance.__class__ == rel_model),
            source_col_name = qn(self.field.m2m_column_name()),
            target_col_name = qn(self.field.m2m_reverse_name()),
            source_col_name = qn(join_cols[0]),
            target_col_name = qn(join_cols[1]),
            content_type = content_type,
            content_type_field_name = self.field.content_type_field_name,
            object_id_field_name = self.field.object_id_field_name,
@@ -378,14 +370,10 @@ def create_generic_related_manager(superclass):

    return GenericRelatedObjectManager

class GenericRel(ManyToManyRel):
    def __init__(self, to, related_name=None, limit_choices_to=None, symmetrical=True):
        self.to = to
        self.related_name = related_name
        self.limit_choices_to = limit_choices_to or {}
        self.symmetrical = symmetrical
        self.multiple = True
        self.through = None
class GenericRel(ForeignObjectRel):

    def __init__(self, field, to, related_name=None, limit_choices_to=None):
        super(GenericRel, self).__init__(field, to, related_name, limit_choices_to)

class BaseGenericInlineFormSet(BaseModelFormSet):
    """
+10 −2
Original line number Diff line number Diff line
@@ -153,8 +153,16 @@ def get_validation_errors(outfile, app=None):
                    continue

                # Make sure the related field specified by a ForeignKey is unique
                if not f.rel.to._meta.get_field(f.rel.field_name).unique:
                    e.add(opts, "Field '%s' under model '%s' must have a unique=True constraint." % (f.rel.field_name, f.rel.to.__name__))
                if f.requires_unique_target:
                    if len(f.foreign_related_fields) > 1:
                        has_unique_field = False
                        for rel_field in f.foreign_related_fields:
                            has_unique_field = has_unique_field or rel_field.unique
                        if not has_unique_field:
                            e.add(opts, "Field combination '%s' under model '%s' must have a unique=True constraint" % (','.join([rel_field.name for rel_field in f.foreign_related_fields]), f.rel.to.__name__))
                    else:
                        if not f.foreign_related_fields[0].unique:
                            e.add(opts, "Field '%s' under model '%s' must have a unique=True constraint." % (f.foreign_related_fields[0].name, f.rel.to.__name__))

                rel_opts = f.rel.to._meta
                rel_name = f.related.get_accessor_name()
+6 −0
Original line number Diff line number Diff line
@@ -17,6 +17,12 @@ class SQLCompiler(compiler.SQLCompiler):
            values.append(value)
        return row[:index_extra_select] + tuple(values)

    def as_subquery_condition(self, alias, columns):
        qn = self.quote_name_unless_alias
        qn2 = self.connection.ops.quote_name
        sql, params = self.as_sql()
        return '(%s) IN (%s)' % (', '.join(['%s.%s' % (qn(alias), qn2(column)) for column in columns]), sql), params

class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
    pass

+1 −1
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ from django.db.models.aggregates import *
from django.db.models.fields import *
from django.db.models.fields.subclassing import SubfieldBase
from django.db.models.fields.files import FileField, ImageField
from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel
from django.db.models.fields.related import ForeignKey, ForeignObject, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel
from django.db.models.deletion import CASCADE, PROTECT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING, ProtectedError
from django.db.models import signals
from django.utils.decorators import wraps
+11 −8
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ from django.conf import settings
from django.core.exceptions import (ObjectDoesNotExist,
    MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS)
from django.db.models.fields import AutoField, FieldDoesNotExist
from django.db.models.fields.related import (ManyToOneRel,
from django.db.models.fields.related import (ForeignObjectRel, ManyToOneRel,
    OneToOneField, add_lazy_relation)
from django.db import (router, transaction, DatabaseError,
    DEFAULT_DB_ALIAS)
@@ -333,12 +333,12 @@ class Model(six.with_metaclass(ModelBase)):
        # The reason for the kwargs check is that standard iterator passes in by
        # args, and instantiation for iteration is 33% faster.
        args_len = len(args)
        if args_len > len(self._meta.fields):
        if args_len > len(self._meta.concrete_fields):
            # Daft, but matches old exception sans the err msg.
            raise IndexError("Number of args exceeds number of fields")

        fields_iter = iter(self._meta.fields)
        if not kwargs:
            fields_iter = iter(self._meta.concrete_fields)
            # The ordering of the zip calls matter - zip throws StopIteration
            # when an iter throws it. So if the first iter throws it, the second
            # is *not* consumed. We rely on this, so don't change the order
@@ -347,6 +347,7 @@ class Model(six.with_metaclass(ModelBase)):
                setattr(self, field.attname, val)
        else:
            # Slower, kwargs-ready version.
            fields_iter = iter(self._meta.fields)
            for val, field in zip(args, fields_iter):
                setattr(self, field.attname, val)
                kwargs.pop(field.name, None)
@@ -363,11 +364,12 @@ class Model(six.with_metaclass(ModelBase)):
            # data-descriptor object (DeferredAttribute) without triggering its
            # __get__ method.
            if (field.attname not in kwargs and
                    isinstance(self.__class__.__dict__.get(field.attname), DeferredAttribute)):
                    (isinstance(self.__class__.__dict__.get(field.attname), DeferredAttribute)
                     or field.column is None)):
                # This field will be populated on request.
                continue
            if kwargs:
                if isinstance(field.rel, ManyToOneRel):
                if isinstance(field.rel, ForeignObjectRel):
                    try:
                        # Assume object instance was passed in.
                        rel_obj = kwargs.pop(field.name)
@@ -394,6 +396,7 @@ class Model(six.with_metaclass(ModelBase)):
                        val = field.get_default()
            else:
                val = field.get_default()

            if is_related_object:
                # If we are passed a related instance, set it using the
                # field.name instead of field.attname (e.g. "user" instead of
@@ -528,7 +531,7 @@ class Model(six.with_metaclass(ModelBase)):
        # automatically do a "update_fields" save on the loaded fields.
        elif not force_insert and self._deferred and using == self._state.db:
            field_names = set()
            for field in self._meta.fields:
            for field in self._meta.concrete_fields:
                if not field.primary_key and not hasattr(field, 'through'):
                    field_names.add(field.attname)
            deferred_fields = [
@@ -614,7 +617,7 @@ class Model(six.with_metaclass(ModelBase)):
        for a single table.
        """
        meta = cls._meta
        non_pks = [f for f in meta.local_fields if not f.primary_key]
        non_pks = [f for f in meta.local_concrete_fields if not f.primary_key]

        if update_fields:
            non_pks = [f for f in non_pks
@@ -652,7 +655,7 @@ class Model(six.with_metaclass(ModelBase)):
                    **{field.name: getattr(self, field.attname)}).count()
                self._order = order_value

            fields = meta.local_fields
            fields = meta.local_concrete_fields
            if not pk_set:
                fields = [f for f in fields if not isinstance(f, AutoField)]

Loading