Commit 056ace0f authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

[1.5.x] Fixed #19547 -- Caching of related instances.

When &'ing or |'ing querysets, wrong values could be cached, and crashes
could happen.

Thanks Marc Tamlyn for figuring out the problem and writing the patch.

Backport of 07fbc6ae.
parent da2cdd3a
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -497,7 +497,7 @@ class ForeignRelatedObjectsDescriptor(object):
                except (AttributeError, KeyError):
                    db = self._db or router.db_for_read(self.model, instance=self.instance)
                    qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
                    qs._known_related_object = (rel_field.name, self.instance)
                    qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
                    return qs

            def get_prefetch_query_set(self, instances):
+22 −7
Original line number Diff line number Diff line
@@ -44,7 +44,7 @@ class QuerySet(object):
        self._for_write = False
        self._prefetch_related_lookups = []
        self._prefetch_done = False
        self._known_related_object = None       # (attname, rel_obj)
        self._known_related_objects = {}        # {rel_field, {pk: rel_obj}}

    ########################
    # PYTHON MAGIC METHODS #
@@ -221,6 +221,7 @@ class QuerySet(object):
        if isinstance(other, EmptyQuerySet):
            return other._clone()
        combined = self._clone()
        combined._merge_known_related_objects(other)
        combined.query.combine(other.query, sql.AND)
        return combined

@@ -229,6 +230,7 @@ class QuerySet(object):
        combined = self._clone()
        if isinstance(other, EmptyQuerySet):
            return combined
        combined._merge_known_related_objects(other)
        combined.query.combine(other.query, sql.OR)
        return combined

@@ -289,10 +291,9 @@ class QuerySet(object):
                    init_list.append(field.attname)
            model_cls = deferred_class_factory(self.model, skip)

        # Cache db, model and known_related_object outside the loop
        # Cache db and model outside the loop
        db = self.db
        model = self.model
        kro_attname, kro_instance = self._known_related_object or (None, None)
        compiler = self.query.get_compiler(using=db)
        if fill_cache:
            klass_info = get_klass_info(model, max_depth=max_depth,
@@ -323,9 +324,16 @@ class QuerySet(object):
                for i, aggregate in enumerate(aggregate_select):
                    setattr(obj, aggregate, row[i + aggregate_start])

            # Add the known related object to the model, if there is one
            if kro_instance:
                setattr(obj, kro_attname, kro_instance)
            # Add the known related objects to the model, if there are any
            if self._known_related_objects:
                for field, rel_objs in self._known_related_objects.items():
                    pk = getattr(obj, field.get_attname())
                    try:
                        rel_obj = rel_objs[pk]
                    except KeyError:
                        pass               # may happen in qs1 | qs2 scenarios
                    else:
                        setattr(obj, field.name, rel_obj)

            yield obj

@@ -902,7 +910,7 @@ class QuerySet(object):
        c = klass(model=self.model, query=query, using=self._db)
        c._for_write = self._for_write
        c._prefetch_related_lookups = self._prefetch_related_lookups[:]
        c._known_related_object = self._known_related_object
        c._known_related_objects = self._known_related_objects
        c.__dict__.update(kwargs)
        if setup and hasattr(c, '_setup_query'):
            c._setup_query()
@@ -942,6 +950,13 @@ class QuerySet(object):
        """
        pass

    def _merge_known_related_objects(self, other):
        """
        Keep track of all known related objects from either QuerySet instance.
        """
        for field, objects in other._known_related_objects.items():
            self._known_related_objects.setdefault(field, {}).update(objects)

    def _setup_aggregate_query(self, aggregates):
        """
        Prepare the query for computing a result that contains aggregate annotations.
+11 −0
Original line number Diff line number Diff line
@@ -13,11 +13,19 @@
            "name": "Tourney 2"
            }
        },
    {
        "pk": 1,
        "model": "known_related_objects.organiser",
        "fields": {
            "name": "Organiser 1"
            }
        },
    {
        "pk": 1,
        "model": "known_related_objects.pool",
        "fields": {
            "tournament": 1,
            "organiser": 1,
            "name": "T1 Pool 1"
            }
        },
@@ -26,6 +34,7 @@
        "model": "known_related_objects.pool",
        "fields": {
            "tournament": 1,
            "organiser": 1,
            "name": "T1 Pool 2"
            }
        },
@@ -34,6 +43,7 @@
        "model": "known_related_objects.pool",
        "fields": {
            "tournament": 2,
            "organiser": 1,
            "name": "T2 Pool 1"
            }
        },
@@ -42,6 +52,7 @@
        "model": "known_related_objects.pool",
        "fields": {
            "tournament": 2,
            "organiser": 1,
            "name": "T2 Pool 2"
            }
        },
+4 −0
Original line number Diff line number Diff line
@@ -9,9 +9,13 @@ from django.db import models
class Tournament(models.Model):
    name = models.CharField(max_length=30)

class Organiser(models.Model):
    name = models.CharField(max_length=30)

class Pool(models.Model):
    name = models.CharField(max_length=30)
    tournament = models.ForeignKey(Tournament)
    organiser = models.ForeignKey(Organiser)

class PoolStyle(models.Model):
    name = models.CharField(max_length=30)
+41 −1
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ from __future__ import absolute_import

from django.test import TestCase

from .models import Tournament, Pool, PoolStyle
from .models import Tournament, Organiser, Pool, PoolStyle

class ExistingRelatedInstancesTests(TestCase):
    fixtures = ['tournament.json']
@@ -27,6 +27,46 @@ class ExistingRelatedInstancesTests(TestCase):
            pool2 = tournaments[1].pool_set.all()[0]
            self.assertIs(tournaments[1], pool2.tournament)

    def test_queryset_or(self):
        tournament_1 = Tournament.objects.get(pk=1)
        tournament_2 = Tournament.objects.get(pk=2)
        with self.assertNumQueries(1):
            pools = tournament_1.pool_set.all() | tournament_2.pool_set.all()
            related_objects = set(pool.tournament for pool in pools)
            self.assertEqual(related_objects, set((tournament_1, tournament_2)))

    def test_queryset_or_different_cached_items(self):
        tournament = Tournament.objects.get(pk=1)
        organiser = Organiser.objects.get(pk=1)
        with self.assertNumQueries(1):
            pools = tournament.pool_set.all() | organiser.pool_set.all()
            first = pools.filter(pk=1)[0]
            self.assertIs(first.tournament, tournament)
            self.assertIs(first.organiser, organiser)

    def test_queryset_or_only_one_with_precache(self):
        tournament_1 = Tournament.objects.get(pk=1)
        tournament_2 = Tournament.objects.get(pk=2)
        # 2 queries here as pool id 3 has tournament 2, which is not cached
        with self.assertNumQueries(2):
            pools = tournament_1.pool_set.all() | Pool.objects.filter(pk=3)
            related_objects = set(pool.tournament for pool in pools)
            self.assertEqual(related_objects, set((tournament_1, tournament_2)))
        # and the other direction
        with self.assertNumQueries(2):
            pools = Pool.objects.filter(pk=3) | tournament_1.pool_set.all()
            related_objects = set(pool.tournament for pool in pools)
            self.assertEqual(related_objects, set((tournament_1, tournament_2)))

    def test_queryset_and(self):
        tournament = Tournament.objects.get(pk=1)
        organiser = Organiser.objects.get(pk=1)
        with self.assertNumQueries(1):
            pools = tournament.pool_set.all() & organiser.pool_set.all()
            first = pools.filter(pk=1)[0]
            self.assertIs(first.tournament, tournament)
            self.assertIs(first.organiser, organiser)

    def test_one_to_one(self):
        with self.assertNumQueries(2):
            style = PoolStyle.objects.get(pk=1)