Commit f51e409a authored by Anssi Kääriäinen's avatar Anssi Kääriäinen
Browse files

Fixed #13781 -- Improved select_related in inheritance situations

The select_related code got confused when it needed to travel a
reverse relation to a model which had different parent than the
originally travelled relation.

Thanks to Trac aliases shauncutts for report and ungenio for original
patch (committed patch is somewhat modified version of that).
parent 92d7f541
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -75,6 +75,7 @@ class Options(object):
        from django.db.backends.util import truncate_name

        cls._meta = self
        self.model = cls
        self.installed = re.sub('\.models$', '', cls.__module__) in settings.INSTALLED_APPS
        # First, construct the default values for these options.
        self.object_name = cls.__name__
@@ -464,7 +465,7 @@ class Options(object):
        a granparent or even more distant relation.
        """
        if not self.parents:
            return
            return None
        if model in self.parents:
            return [model]
        for parent in self.parents:
@@ -472,8 +473,7 @@ class Options(object):
            if res:
                res.insert(0, parent)
                return res
        raise TypeError('%r is not an ancestor of this model'
                % model._meta.module_name)
        return None

    def get_parent_list(self):
        """
+49 −34
Original line number Diff line number Diff line
@@ -1300,7 +1300,7 @@ class EmptyQuerySet(QuerySet):
    value_annotation = False

def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
                   only_load=None, local_only=False):
                   only_load=None, from_parent=None):
    """
    Helper function that recursively returns an information for a klass, to be
    used in get_cached_row.  It exists just to compute this information only
@@ -1320,8 +1320,10 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
     * only_load - if the query has had only() or defer() applied,
       this is the list of field names that will be returned. If None,
       the full field list for `klass` can be assumed.
     * local_only - Only populate local fields. This is used when
       following reverse select-related relations
     * from_parent - the parent model used to get to this model

    Note that when travelling from parent to child, we will only load child
    fields which aren't in the parent.
    """
    if max_depth and requested is None and cur_depth > max_depth:
        # We've recursed deeply enough; stop now.
@@ -1347,7 +1349,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
        for field, model in klass._meta.get_fields_with_model():
            if field.name not in load_fields:
                skip.add(field.attname)
            elif local_only and model is not None:
            elif from_parent and issubclass(from_parent, model.__class__):
                # Avoid loading fields already loaded for parent model for
                # child models.
                continue
            else:
                init_list.append(field.attname)
@@ -1361,16 +1365,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
    else:
        # Load all fields on klass

        # We trying to not populate field_names variable for perfomance reason.
        # If field_names variable is set, it is used to instantiate desired fields,
        # by passing **dict(zip(field_names, fields)) as kwargs to Model.__init__ method.
        # But kwargs version of Model.__init__ is slower, so we should avoid using
        # it when it is not really neccesary.
        if local_only and len(klass._meta.local_fields) != len(klass._meta.fields):
            field_count = len(klass._meta.local_fields)
            field_names = [f.attname for f in klass._meta.local_fields]
        else:
        field_count = len(klass._meta.fields)
        # Check if we need to skip some parent fields.
        if from_parent and len(klass._meta.local_fields) != len(klass._meta.fields):
            # Only load those fields which haven't been already loaded into
            # 'from_parent'.
            non_seen_models = [p for p in klass._meta.get_parent_list()
                               if not issubclass(from_parent, p)]
            # Load local fields, too...
            non_seen_models.append(klass)
            field_names = [f.attname for f in klass._meta.fields
                           if f.model in non_seen_models]
            field_count = len(field_names)
        # Try to avoid populating field_names variable for perfomance reasons.
        # If field_names variable is set, we use **kwargs based model init
        # which is slower than normal init.
        if field_count == len(klass._meta.fields):
            field_names = ()

    restricted = requested is not None
@@ -1392,8 +1402,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
            if o.field.unique and select_related_descend(o.field, restricted, requested,
                                                         only_load.get(o.model), reverse=True):
                next = requested[o.field.related_query_name()]
                parent = klass if issubclass(o.model, klass) else None
                klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1,
                                            requested=next, only_load=only_load, local_only=True)
                                            requested=next, only_load=only_load, from_parent=parent)
                reverse_related_fields.append((o.field, klass_info))
    if field_names:
        pk_idx = field_names.index(klass._meta.pk.attname)
@@ -1403,7 +1414,8 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
    return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx


def get_cached_row(row, index_start, using,  klass_info, offset=0):
def get_cached_row(row, index_start, using,  klass_info, offset=0,
                   parent_data=()):
    """
    Helper function that recursively returns an object with the specified
    related attributes already populated.
@@ -1420,11 +1432,14 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
           annotated results on `klass`.
         * using - the database alias on which the query is being executed.
         * klass_info - result of the get_klass_info function
         * parent_data - parent model data in format (field, value). Used
           to populate the non-local fields of child models.
    """
    if klass_info is None:
        return None
    klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info


    fields = row[index_start : index_start + field_count]
    # If the pk column is None (or the Oracle equivalent ''), then the related
    # object must be non-existent - set the relation to None.
@@ -1434,7 +1449,6 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
        obj = klass(**dict(zip(field_names, fields)))
    else:
        obj = klass(*fields)

    # If an object was retrieved, set the database state.
    if obj:
        obj._state.db = using
@@ -1464,28 +1478,29 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
    # Only handle the restricted case - i.e., don't do a depth
    # descent into reverse relations unless explicitly requested
    for f, klass_info in reverse_related_fields:
        # Transfer data from this object to childs.
        parent_data = []
        for rel_field, rel_model in klass_info[0]._meta.get_fields_with_model():
            if rel_model is not None and isinstance(obj, rel_model):
                parent_data.append((rel_field, getattr(obj, rel_field.attname)))
        # Recursively retrieve the data for the related object
        cached_row = get_cached_row(row, index_end, using, klass_info)
        cached_row = get_cached_row(row, index_end, using, klass_info,
                                   parent_data=parent_data)
        # If the recursive descent found an object, populate the
        # descriptor caches relevant to the object
        if cached_row:
            rel_obj, index_end = cached_row
            if obj is not None:
                # If the field is unique, populate the
                # reverse descriptor cache
                # populate the reverse descriptor cache
                setattr(obj, f.related.get_cache_name(), rel_obj)
            if rel_obj is not None:
                # If the related object exists, populate
                # the descriptor cache.
                setattr(rel_obj, f.get_cache_name(), obj)
                # Now populate all the non-local field values
                # on the related object
                for rel_field, rel_model in rel_obj._meta.get_fields_with_model():
                    if rel_model is not None:
                        setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
                        # populate the field cache for any related object
                        # that has already been retrieved
                # Populate related object caches using parent data.
                for rel_field, _ in parent_data:
                    if rel_field.rel:
                        setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
                        try:
                            cached_obj = getattr(obj, rel_field.get_cache_name())
                            setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
+8 −5
Original line number Diff line number Diff line
@@ -240,7 +240,7 @@ class SQLCompiler(object):
        return result

    def get_default_columns(self, with_aliases=False, col_aliases=None,
            start_alias=None, opts=None, as_pairs=False, local_only=False):
            start_alias=None, opts=None, as_pairs=False, from_parent=None):
        """
        Computes the default columns for selecting every field in the base
        model. Will sometimes be called to pull in related models (e.g. via
@@ -265,7 +265,8 @@ class SQLCompiler(object):
        if start_alias:
            seen = {None: start_alias}
        for field, model in opts.get_fields_with_model():
            if local_only and model is not None:
            if from_parent and model is not None and issubclass(from_parent, model):
                # Avoid loading data for already loaded parents.
                continue
            if start_alias:
                try:
@@ -686,11 +687,13 @@ class SQLCompiler(object):
                    (alias, table, f.rel.get_related_field().column, f.column),
                    promote=True
                )
                from_parent = (opts.model if issubclass(model, opts.model)
                               else None)
                columns, aliases = self.get_default_columns(start_alias=alias,
                    opts=model._meta, as_pairs=True, local_only=True)
                    opts=model._meta, as_pairs=True, from_parent=from_parent)
                self.query.related_select_cols.extend(
                    SelectInfo(col, field) for col, field in zip(columns, model._meta.fields))

                    SelectInfo(col, field) for col, field
                    in zip(columns, model._meta.fields))
                next = requested.get(f.related_query_name(), {})
                # Use True here because we are looking at the _reverse_ side of
                # the relation, which is always nullable.
+42 −0
Original line number Diff line number Diff line
@@ -51,6 +51,7 @@ class StatDetails(models.Model):
class AdvancedUserStat(UserStat):
    karma = models.IntegerField()


class Image(models.Model):
    name = models.CharField(max_length=100)

@@ -58,3 +59,44 @@ class Image(models.Model):
class Product(models.Model):
    name = models.CharField(max_length=100)
    image = models.OneToOneField(Image, null=True)


@python_2_unicode_compatible
class Parent1(models.Model):
    name1 = models.CharField(max_length=50)

    def __str__(self):
        return self.name1


@python_2_unicode_compatible
class Parent2(models.Model):
    # Avoid having two "id" fields in the Child1 subclass
    id2 = models.AutoField(primary_key=True)
    name2 = models.CharField(max_length=50)

    def __str__(self):
        return self.name2


@python_2_unicode_compatible
class Child1(Parent1, Parent2):
    value = models.IntegerField()

    def __str__(self):
        return self.name1


@python_2_unicode_compatible
class Child2(Parent1):
    parent2 = models.OneToOneField(Parent2)
    value = models.IntegerField()

    def __str__(self):
        return self.name1

class Child3(Child2):
    value3 = models.IntegerField()

class Child4(Child1):
    value4 = models.IntegerField()
+101 −1
Original line number Diff line number Diff line
from __future__ import absolute_import

from django.test import TestCase
from django.utils import unittest

from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
    AdvancedUserStat, Image, Product)
    AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2, Child3,
    Child4)


class ReverseSelectRelatedTestCase(TestCase):
@@ -21,6 +23,14 @@ class ReverseSelectRelatedTestCase(TestCase):
        advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
                                                  results=results2)
        StatDetails.objects.create(base_stats=advstat, comments=250)
        p1 = Parent1(name1="Only Parent1")
        p1.save()
        c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2", value=1)
        c1.save()
        p2 = Parent2(name2="Child2 Parent2")
        p2.save()
        c2 = Child2(name1="Child2 Parent1", parent2=p2, value=2)
        c2.save()

    def test_basic(self):
        with self.assertNumQueries(1):
@@ -108,3 +118,93 @@ class ReverseSelectRelatedTestCase(TestCase):
            image = Image.objects.select_related('product').get()
            with self.assertRaises(Product.DoesNotExist):
                image.product

    def test_parent_only(self):
        with self.assertNumQueries(1):
            p = Parent1.objects.select_related('child1').get(name1="Only Parent1")
        with self.assertNumQueries(0):
            with self.assertRaises(Child1.DoesNotExist):
                p.child1

    def test_multiple_subclass(self):
        with self.assertNumQueries(1):
            p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1")
            self.assertEqual(p.child1.name2, 'Child1 Parent2')

    def test_onetoone_with_subclass(self):
        with self.assertNumQueries(1):
            p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2")
            self.assertEqual(p.child2.name1, 'Child2 Parent1')

    def test_onetoone_with_two_subclasses(self):
        with self.assertNumQueries(1):
            p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child2 Parent2")
            self.assertEqual(p.child2.name1, 'Child2 Parent1')
            with self.assertRaises(Child3.DoesNotExist):
                p.child2.child3
        p3 = Parent2(name2="Child3 Parent2")
        p3.save()
        c2 = Child3(name1="Child3 Parent1", parent2=p3, value=2, value3=3)
        c2.save()
        with self.assertNumQueries(1):
            p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child3 Parent2")
            self.assertEqual(p.child2.name1, 'Child3 Parent1')
            self.assertEqual(p.child2.child3.value3, 3)
            self.assertEqual(p.child2.child3.value, p.child2.value)
            self.assertEqual(p.child2.name1, p.child2.child3.name1)

    def test_multiinheritance_two_subclasses(self):
        with self.assertNumQueries(1):
            p = Parent1.objects.select_related('child1', 'child1__child4').get(name1="Child1 Parent1")
            self.assertEqual(p.child1.name2, 'Child1 Parent2')
            self.assertEqual(p.child1.name1, p.name1)
            with self.assertRaises(Child4.DoesNotExist):
                p.child1.child4
        Child4(name1='n1', name2='n2', value=1, value4=4).save()
        with self.assertNumQueries(1):
            p = Parent2.objects.select_related('child1', 'child1__child4').get(name2="n2")
            self.assertEqual(p.name2, 'n2')
            self.assertEqual(p.child1.name1, 'n1')
            self.assertEqual(p.child1.name2, p.name2)
            self.assertEqual(p.child1.value, 1)
            self.assertEqual(p.child1.child4.name1, p.child1.name1)
            self.assertEqual(p.child1.child4.name2, p.child1.name2)
            self.assertEqual(p.child1.child4.value, p.child1.value)
            self.assertEqual(p.child1.child4.value4, 4)

    @unittest.expectedFailure
    def test_inheritance_deferred(self):
        c = Child4.objects.create(name1='n1', name2='n2', value=1, value4=4)
        with self.assertNumQueries(1):
            p = Parent2.objects.select_related('child1').only(
                'id2',  'child1__value').get(name2="n2")
            self.assertEqual(p.id2, c.id2)
            self.assertEqual(p.child1.value, 1)
        p = Parent2.objects.select_related('child1').only(
            'id2',  'child1__value').get(name2="n2")
        with self.assertNumQueries(1):
            self.assertEquals(p.name2, 'n2')
        p = Parent2.objects.select_related('child1').only(
            'id2',  'child1__value').get(name2="n2")
        with self.assertNumQueries(1):
            self.assertEquals(p.child1.name2, 'n2')

    @unittest.expectedFailure
    def test_inheritance_deferred2(self):
        c = Child4.objects.create(name1='n1', name2='n2', value=1, value4=4)
        qs = Parent2.objects.select_related('child1', 'child4').only(
            'id2',  'child1__value', 'child1__child4__value4')
        with self.assertNumQueries(1):
            p = qs.get(name2="n2")
            self.assertEqual(p.id2, c.id2)
            self.assertEqual(p.child1.value, 1)
            self.assertEqual(p.child1.child4.value4, 4)
            self.assertEqual(p.child1.child4.id2, c.id2)
        p = qs.get(name2="n2")
        with self.assertNumQueries(1):
            self.assertEquals(p.child1.name2, 'n2')
        p = qs.get(name2="n2")
        with self.assertNumQueries(1):
            self.assertEquals(p.child1.name1, 'n1')
        with self.assertNumQueries(1):
            self.assertEquals(p.child1.child4.name1, 'n1')