Commit 5e2d3846 authored by Russell Keith-Magee's avatar Russell Keith-Magee
Browse files

Fixed #10847 -- Modified handling of extra() to use a masking strategy, rather...

Fixed #10847 -- Modified handling of extra() to use a masking strategy, rather than last-minute trimming. Thanks to Tai Lee for the report, and Alex Gaynor for his work on the patch.

This enables querysets with an extra clause to be used in an __in filter; as a side effect, it also means that as_sql() now returns the correct result for any query with an extra clause.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@10648 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent 17958fa7
Loading
Loading
Loading
Loading
+9 −8
Original line number Diff line number Diff line
@@ -715,9 +715,6 @@ class ValuesQuerySet(QuerySet):

    def iterator(self):
        # Purge any extra columns that haven't been explicitly asked for
        if self.extra_names is not None:
            self.query.trim_extra_select(self.extra_names)

        extra_names = self.query.extra_select.keys()
        field_names = self.field_names
        aggregate_names = self.query.aggregate_select.keys()
@@ -741,13 +738,18 @@ class ValuesQuerySet(QuerySet):
        if self._fields:
            self.extra_names = []
            self.aggregate_names = []
            if not self.query.extra_select and not self.query.aggregate_select:
            if not self.query.extra and not self.query.aggregates:
                # Short cut - if there are no extra or aggregates, then
                # the values() clause must be just field names.
                self.field_names = list(self._fields)
            else:
                self.query.default_cols = False
                self.field_names = []
                for f in self._fields:
                    if self.query.extra_select.has_key(f):
                    # we inspect the full extra_select list since we might
                    # be adding back an extra select item that we hadn't
                    # had selected previously.
                    if self.query.extra.has_key(f):
                        self.extra_names.append(f)
                    elif self.query.aggregate_select.has_key(f):
                        self.aggregate_names.append(f)
@@ -760,6 +762,8 @@ class ValuesQuerySet(QuerySet):
            self.aggregate_names = None

        self.query.select = []
        if self.extra_names is not None:
            self.query.set_extra_mask(self.extra_names)
        self.query.add_fields(self.field_names, False)
        if self.aggregate_names is not None:
            self.query.set_aggregate_mask(self.aggregate_names)
@@ -816,9 +820,6 @@ class ValuesQuerySet(QuerySet):

class ValuesListQuerySet(ValuesQuerySet):
    def iterator(self):
        if self.extra_names is not None:
            self.query.trim_extra_select(self.extra_names)

        if self.flat and len(self._fields) == 1:
            for row in self.query.results_iter():
                yield row[0]
+55 −20
Original line number Diff line number Diff line
@@ -88,7 +88,10 @@ class BaseQuery(object):

        # These are for extensions. The contents are more or less appended
        # verbatim to the appropriate clause.
        self.extra_select = SortedDict()  # Maps col_alias -> (col_sql, params).
        self.extra = SortedDict()  # Maps col_alias -> (col_sql, params).
        self.extra_select_mask = None
        self._extra_select_cache = None

        self.extra_tables = ()
        self.extra_where = ()
        self.extra_params = ()
@@ -214,13 +217,21 @@ class BaseQuery(object):
        if self.aggregate_select_mask is None:
            obj.aggregate_select_mask = None
        else:
            obj.aggregate_select_mask = self.aggregate_select_mask[:]
            obj.aggregate_select_mask = self.aggregate_select_mask.copy()
        if self._aggregate_select_cache is None:
            obj._aggregate_select_cache = None
        else:
            obj._aggregate_select_cache = self._aggregate_select_cache.copy()
        obj.max_depth = self.max_depth
        obj.extra_select = self.extra_select.copy()
        obj.extra = self.extra.copy()
        if self.extra_select_mask is None:
            obj.extra_select_mask = None
        else:
            obj.extra_select_mask = self.extra_select_mask.copy()
        if self._extra_select_cache is None:
            obj._extra_select_cache = None
        else:
            obj._extra_select_cache = self._extra_select_cache.copy()
        obj.extra_tables = self.extra_tables
        obj.extra_where = self.extra_where
        obj.extra_params = self.extra_params
@@ -325,7 +336,7 @@ class BaseQuery(object):
            query = self
            self.select = []
            self.default_cols = False
            self.extra_select = {}
            self.extra = {}
            self.remove_inherited_models()

        query.clear_ordering(True)
@@ -540,13 +551,20 @@ class BaseQuery(object):
            # It would be nice to be able to handle this, but the queries don't
            # really make sense (or return consistent value sets). Not worth
            # the extra complexity when you can write a real query instead.
            if self.extra_select and rhs.extra_select:
            if self.extra and rhs.extra:
                raise ValueError("When merging querysets using 'or', you "
                        "cannot have extra(select=...) on both sides.")
            if self.extra_where and rhs.extra_where:
                raise ValueError("When merging querysets using 'or', you "
                        "cannot have extra(where=...) on both sides.")
        self.extra_select.update(rhs.extra_select)
        self.extra.update(rhs.extra)
        extra_select_mask = set()
        if self.extra_select_mask is not None:
            extra_select_mask.update(self.extra_select_mask)
        if rhs.extra_select_mask is not None:
            extra_select_mask.update(rhs.extra_select_mask)
        if extra_select_mask:
            self.set_extra_mask(extra_select_mask)
        self.extra_tables += rhs.extra_tables
        self.extra_where += rhs.extra_where
        self.extra_params += rhs.extra_params
@@ -2011,7 +2029,7 @@ class BaseQuery(object):
        except MultiJoin:
            raise FieldError("Invalid field name: '%s'" % name)
        except FieldError:
            names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys()
            names = opts.get_all_field_names() + self.extra.keys() + self.aggregate_select.keys()
            names.sort()
            raise FieldError("Cannot resolve keyword %r into field. "
                    "Choices are: %s" % (name, ", ".join(names)))
@@ -2139,7 +2157,7 @@ class BaseQuery(object):
                    pos = entry.find("%s", pos + 2)
                select_pairs[name] = (entry, entry_params)
            # This is order preserving, since self.extra_select is a SortedDict.
            self.extra_select.update(select_pairs)
            self.extra.update(select_pairs)
        if where:
            self.extra_where += tuple(where)
        if params:
@@ -2213,22 +2231,26 @@ class BaseQuery(object):
        """
        target[model] = set([f.name for f in fields])

    def trim_extra_select(self, names):
        """
        Removes any aliases in the extra_select dictionary that aren't in
        'names'.

        This is needed if we are selecting certain values that don't incldue
        all of the extra_select names.
        """
        for key in set(self.extra_select).difference(set(names)):
            del self.extra_select[key]

    def set_aggregate_mask(self, names):
        "Set the mask of aggregates that will actually be returned by the SELECT"
        self.aggregate_select_mask = names
        if names is None:
            self.aggregate_select_mask = None
        else:
            self.aggregate_select_mask = set(names)
        self._aggregate_select_cache = None

    def set_extra_mask(self, names):
        """
        Set the mask of extra select items that will be returned by SELECT,
        we don't actually remove them from the Query since they might be used
        later
        """
        if names is None:
            self.extra_select_mask = None
        else:
            self.extra_select_mask = set(names)
        self._extra_select_cache = None

    def _aggregate_select(self):
        """The SortedDict of aggregate columns that are not masked, and should
        be used in the SELECT clause.
@@ -2247,6 +2269,19 @@ class BaseQuery(object):
            return self.aggregates
    aggregate_select = property(_aggregate_select)

    def _extra_select(self):
        if self._extra_select_cache is not None:
            return self._extra_select_cache
        elif self.extra_select_mask is not None:
            self._extra_select_cache = SortedDict([
                (k,v) for k,v in self.extra.items()
                if k in self.extra_select_mask
            ])
            return self._extra_select_cache
        else:
            return self.extra
    extra_select = property(_extra_select)

    def set_start(self, start):
        """
        Sets the table from which to start joining. The start position is
+2 −2
Original line number Diff line number Diff line
@@ -178,7 +178,7 @@ class UpdateQuery(Query):
        # from other tables.
        query = self.clone(klass=Query)
        query.bump_prefix()
        query.extra_select = {}
        query.extra = {}
        query.select = []
        query.add_fields([query.model._meta.pk.name])
        must_pre_select = count > 1 and not self.connection.features.update_can_self_select
@@ -409,7 +409,7 @@ class DateQuery(Query):
        self.select = [select]
        self.select_fields = [None]
        self.select_related = False # See #7097.
        self.extra_select = {}
        self.extra = {}
        self.distinct = True
        self.order_by = order == 'ASC' and [1] or [-1]

+17 −1
Original line number Diff line number Diff line
@@ -35,6 +35,9 @@ class TestObject(models.Model):
    second = models.CharField(max_length=20)
    third = models.CharField(max_length=20)

    def __unicode__(self):
        return u'TestObject: %s,%s,%s' % (self.first,self.second,self.third)

__test__ = {"API_TESTS": """
# Regression tests for #7314 and #7372

@@ -189,6 +192,19 @@ True
>>> TestObject.objects.extra(select=SortedDict((('foo','first'),('bar','second'),('whiz','third')))).values_list('whiz', 'first', 'bar', 'id')
[(u'third', u'first', u'second', 1)]

"""}
# Regression for #10847: the list of extra columns can always be accurately evaluated.
# Using an inner query ensures that as_sql() is producing correct output
# without requiring full evaluation and execution of the inner query.
>>> TestObject.objects.extra(select={'extra': 1}).values('pk')
[{'pk': 1}]

>>> TestObject.objects.filter(pk__in=TestObject.objects.extra(select={'extra': 1}).values('pk'))
[<TestObject: TestObject: first,second,third>]

>>> TestObject.objects.values('pk').extra(select={'extra': 1})
[{'pk': 1}]

>>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1}))
[<TestObject: TestObject: first,second,third>]

"""}