Commit 07fbc6ae authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Fixed #19547 -- Caching of related instances.

When &'ing or |'ing querysets, wrong values could be cached, and crashes
could happen.

Thanks Marc Tamlyn for figuring out the problem and writing the patch.
parent 695b2089
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -496,7 +496,7 @@ class ForeignRelatedObjectsDescriptor(object):
                except (AttributeError, KeyError):
                    db = self._db or router.db_for_read(self.model, instance=self.instance)
                    qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
                    qs._known_related_object = (rel_field.name, self.instance)
                    qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
                    return qs

            def get_prefetch_query_set(self, instances):
+22 −7
Original line number Diff line number Diff line
@@ -44,7 +44,7 @@ class QuerySet(object):
        self._for_write = False
        self._prefetch_related_lookups = []
        self._prefetch_done = False
        self._known_related_object = None       # (attname, rel_obj)
        self._known_related_objects = {}        # {rel_field, {pk: rel_obj}}

    ########################
    # PYTHON MAGIC METHODS #
@@ -221,6 +221,7 @@ class QuerySet(object):
        if isinstance(other, EmptyQuerySet):
            return other._clone()
        combined = self._clone()
        combined._merge_known_related_objects(other)
        combined.query.combine(other.query, sql.AND)
        return combined

@@ -229,6 +230,7 @@ class QuerySet(object):
        combined = self._clone()
        if isinstance(other, EmptyQuerySet):
            return combined
        combined._merge_known_related_objects(other)
        combined.query.combine(other.query, sql.OR)
        return combined

@@ -289,10 +291,9 @@ class QuerySet(object):
                    init_list.append(field.attname)
            model_cls = deferred_class_factory(self.model, skip)

        # Cache db, model and known_related_object outside the loop
        # Cache db and model outside the loop
        db = self.db
        model = self.model
        kro_attname, kro_instance = self._known_related_object or (None, None)
        compiler = self.query.get_compiler(using=db)
        if fill_cache:
            klass_info = get_klass_info(model, max_depth=max_depth,
@@ -323,9 +324,16 @@ class QuerySet(object):
                for i, aggregate in enumerate(aggregate_select):
                    setattr(obj, aggregate, row[i + aggregate_start])

            # Add the known related object to the model, if there is one
            if kro_instance:
                setattr(obj, kro_attname, kro_instance)
            # Add the known related objects to the model, if there are any
            if self._known_related_objects:
                for field, rel_objs in self._known_related_objects.items():
                    pk = getattr(obj, field.get_attname())
                    try:
                        rel_obj = rel_objs[pk]
                    except KeyError:
                        pass               # may happen in qs1 | qs2 scenarios
                    else:
                        setattr(obj, field.name, rel_obj)

            yield obj

@@ -902,7 +910,7 @@ class QuerySet(object):
        c = klass(model=self.model, query=query, using=self._db)
        c._for_write = self._for_write
        c._prefetch_related_lookups = self._prefetch_related_lookups[:]
        c._known_related_object = self._known_related_object
        c._known_related_objects = self._known_related_objects
        c.__dict__.update(kwargs)
        if setup and hasattr(c, '_setup_query'):
            c._setup_query()
@@ -942,6 +950,13 @@ class QuerySet(object):
        """
        pass

    def _merge_known_related_objects(self, other):
        """
        Keep track of all known related objects from either QuerySet instance.
        """
        for field, objects in other._known_related_objects.items():
            self._known_related_objects.setdefault(field, {}).update(objects)

    def _setup_aggregate_query(self, aggregates):
        """
        Prepare the query for computing a result that contains aggregate annotations.
+11 −0
Original line number Diff line number Diff line
@@ -13,11 +13,19 @@
            "name": "Tourney 2"
            }
        },
    {
        "pk": 1,
        "model": "known_related_objects.organiser",
        "fields": {
            "name": "Organiser 1"
            }
        },
    {
        "pk": 1,
        "model": "known_related_objects.pool",
        "fields": {
            "tournament": 1,
            "organiser": 1,
            "name": "T1 Pool 1"
            }
        },
@@ -26,6 +34,7 @@
        "model": "known_related_objects.pool",
        "fields": {
            "tournament": 1,
            "organiser": 1,
            "name": "T1 Pool 2"
            }
        },
@@ -34,6 +43,7 @@
        "model": "known_related_objects.pool",
        "fields": {
            "tournament": 2,
            "organiser": 1,
            "name": "T2 Pool 1"
            }
        },
@@ -42,6 +52,7 @@
        "model": "known_related_objects.pool",
        "fields": {
            "tournament": 2,
            "organiser": 1,
            "name": "T2 Pool 2"
            }
        },
+4 −0
Original line number Diff line number Diff line
@@ -9,9 +9,13 @@ from django.db import models
class Tournament(models.Model):
    name = models.CharField(max_length=30)

class Organiser(models.Model):
    name = models.CharField(max_length=30)

class Pool(models.Model):
    name = models.CharField(max_length=30)
    tournament = models.ForeignKey(Tournament)
    organiser = models.ForeignKey(Organiser)

class PoolStyle(models.Model):
    name = models.CharField(max_length=30)
+41 −1
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ from __future__ import absolute_import

from django.test import TestCase

from .models import Tournament, Pool, PoolStyle
from .models import Tournament, Organiser, Pool, PoolStyle

class ExistingRelatedInstancesTests(TestCase):
    fixtures = ['tournament.json']
@@ -27,6 +27,46 @@ class ExistingRelatedInstancesTests(TestCase):
            pool2 = tournaments[1].pool_set.all()[0]
            self.assertIs(tournaments[1], pool2.tournament)

    def test_queryset_or(self):
        tournament_1 = Tournament.objects.get(pk=1)
        tournament_2 = Tournament.objects.get(pk=2)
        with self.assertNumQueries(1):
            pools = tournament_1.pool_set.all() | tournament_2.pool_set.all()
            related_objects = set(pool.tournament for pool in pools)
            self.assertEqual(related_objects, set((tournament_1, tournament_2)))

    def test_queryset_or_different_cached_items(self):
        tournament = Tournament.objects.get(pk=1)
        organiser = Organiser.objects.get(pk=1)
        with self.assertNumQueries(1):
            pools = tournament.pool_set.all() | organiser.pool_set.all()
            first = pools.filter(pk=1)[0]
            self.assertIs(first.tournament, tournament)
            self.assertIs(first.organiser, organiser)

    def test_queryset_or_only_one_with_precache(self):
        tournament_1 = Tournament.objects.get(pk=1)
        tournament_2 = Tournament.objects.get(pk=2)
        # 2 queries here as pool id 3 has tournament 2, which is not cached
        with self.assertNumQueries(2):
            pools = tournament_1.pool_set.all() | Pool.objects.filter(pk=3)
            related_objects = set(pool.tournament for pool in pools)
            self.assertEqual(related_objects, set((tournament_1, tournament_2)))
        # and the other direction
        with self.assertNumQueries(2):
            pools = Pool.objects.filter(pk=3) | tournament_1.pool_set.all()
            related_objects = set(pool.tournament for pool in pools)
            self.assertEqual(related_objects, set((tournament_1, tournament_2)))

    def test_queryset_and(self):
        tournament = Tournament.objects.get(pk=1)
        organiser = Organiser.objects.get(pk=1)
        with self.assertNumQueries(1):
            pools = tournament.pool_set.all() & organiser.pool_set.all()
            first = pools.filter(pk=1)[0]
            self.assertIs(first.tournament, tournament)
            self.assertIs(first.organiser, organiser)

    def test_one_to_one(self):
        with self.assertNumQueries(2):
            style = PoolStyle.objects.get(pk=1)