Commit bfa080f4 authored by Russell Keith-Magee's avatar Russell Keith-Magee
Browse files

Fixed #12937 -- Corrected the operation of select_related() when following an...

Fixed #12937 -- Corrected the operation of select_related() when following an reverse relation on an inherited model. Thanks to subsume for the report.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@12814 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent 4528f398
Loading
Loading
Loading
Loading
+27 −6
Original line number Diff line number Diff line
@@ -1113,7 +1113,7 @@ class EmptyQuerySet(QuerySet):


def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
                   requested=None, offset=0, only_load=None):
                   requested=None, offset=0, only_load=None, local_only=False):
    """
    Helper function that recursively returns an object with the specified
    related attributes already populated.
@@ -1141,6 +1141,8 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
     * only_load - if the query has had only() or defer() applied,
       this is the list of field names that will be returned. If None,
       the full field list for `klass` can be assumed.
     * local_only - Only populate local fields. This is used when building
       following reverse select-related relations
    """
    if max_depth and requested is None and cur_depth > max_depth:
        # We've recursed deeply enough; stop now.
@@ -1153,9 +1155,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
        skip = set()
        init_list = []
        # Build the list of fields that *haven't* been requested
        for field in klass._meta.fields:
        for field, model in klass._meta.get_fields_with_model():
            if field.name not in load_fields:
                skip.add(field.name)
            elif local_only and model is not None:
                continue
            else:
                init_list.append(field.attname)
        # Retrieve all the requested fields
@@ -1174,7 +1178,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,

    else:
        # Load all fields on klass
        field_count = len(klass._meta.fields)
        if local_only:
            field_names = [f.attname for f in klass._meta.local_fields]
        else:
            field_names = [f.attname for f in klass._meta.fields]
        field_count = len(field_names)
        fields = row[index_start : index_start + field_count]
        # If all the select_related columns are None, then the related
        # object must be non-existent - set the relation to None.
@@ -1182,7 +1190,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
        if fields == (None,) * field_count:
            obj = None
        else:
            obj = klass(*fields)
            obj = klass(**dict(zip(field_names, fields)))

    # If an object was retrieved, set the database state.
    if obj:
@@ -1229,7 +1237,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
            next = requested[f.related_query_name()]
            # Recursively retrieve the data for the related object
            cached_row = get_cached_row(model, row, index_end, using,
                max_depth, cur_depth+1, next)
                max_depth, cur_depth+1, next, local_only=True)
            # If the recursive descent found an object, populate the
            # descriptor caches relevant to the object
            if cached_row:
@@ -1242,7 +1250,20 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
                    # If the related object exists, populate
                    # the descriptor cache.
                    setattr(rel_obj, f.get_cache_name(), obj)

                    # Now populate all the non-local field values
                    # on the related object
                    for rel_field,rel_model in rel_obj._meta.get_fields_with_model():
                        if rel_model is not None:
                            setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
                            # populate the field cache for any related object
                            # that has already been retrieved
                            if rel_field.rel:
                                try:
                                    cached_obj = getattr(obj, rel_field.get_cache_name())
                                    setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
                                except AttributeError:
                                    # Related object hasn't been cached yet
                                    pass
    return obj, index_end

def delete_objects(seen_objs, using):
+4 −2
Original line number Diff line number Diff line
@@ -215,7 +215,7 @@ class SQLCompiler(object):
        return result

    def get_default_columns(self, with_aliases=False, col_aliases=None,
            start_alias=None, opts=None, as_pairs=False):
            start_alias=None, opts=None, as_pairs=False, local_only=False):
        """
        Computes the default columns for selecting every field in the base
        model. Will sometimes be called to pull in related models (e.g. via
@@ -240,6 +240,8 @@ class SQLCompiler(object):
        if start_alias:
            seen = {None: start_alias}
        for field, model in opts.get_fields_with_model():
            if local_only and model is not None:
                continue
            if start_alias:
                try:
                    alias = seen[model]
@@ -643,7 +645,7 @@ class SQLCompiler(object):
                )
                used.add(alias)
                columns, aliases = self.get_default_columns(start_alias=alias,
                    opts=model._meta, as_pairs=True)
                    opts=model._meta, as_pairs=True, local_only=True)
                self.query.related_select_cols.extend(columns)
                self.query.related_select_fields.extend(model._meta.fields)

+1 −2
Original line number Diff line number Diff line
@@ -43,8 +43,7 @@ class StatDetails(models.Model):


class AdvancedUserStat(UserStat):
    pass

    karma = models.IntegerField()

class Image(models.Model):
    name = models.CharField(max_length=100)
+9 −6
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ class ReverseSelectRelatedTestCase(TestCase):

        user2 = User.objects.create(username="bob")
        results2 = UserStatResult.objects.create(results='moar results')
        advstat = AdvancedUserStat.objects.create(user=user2, posts=200,
        advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
                                                  results=results2)
        StatDetails.objects.create(base_stats=advstat, comments=250)

@@ -74,13 +74,16 @@ class ReverseSelectRelatedTestCase(TestCase):
        self.assertQueries(2)

    def test_follow_from_child_class(self):
        stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200)
        stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200)
        self.assertEqual(stat.statdetails.comments, 250)
        self.assertEqual(stat.user.username, 'bob')
        self.assertQueries(1)

    def test_follow_inheritance(self):
        stat = UserStat.objects.select_related('advanceduserstat').get(posts=200)
        stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
        self.assertEqual(stat.advanceduserstat.posts, 200)
        self.assertEqual(stat.user.username, 'bob')
        self.assertEqual(stat.advanceduserstat.user.username, 'bob')
        self.assertQueries(1)

    def test_nullable_relation(self):