Commit 2f35c6f1 authored by Florian Hahn's avatar Florian Hahn Committed by Anssi Kääriäinen
Browse files

Fixed #14930 -- values_list() failure on qs ordered by extra column

Thanks lsaffre for the report and simon29, vicould, and Florian Hahn
for the patch.

Some changes done by committer.
parent 9da9b3eb
Loading
Loading
Loading
Loading
+30 −15
Original line number Diff line number Diff line
@@ -22,6 +22,12 @@ class SQLCompiler(object):
        self.connection = connection
        self.using = using
        self.quote_cache = {}
        # When ordering a queryset with distinct on a column not part of the
        # select set, the ordering column needs to be added to the select
        # clause. This information is needed both in SQL construction and
        # masking away the ordering selects from the returned row.
        self.ordering_aliases = []
        self.ordering_params = []

    def pre_sql_setup(self):
        """
@@ -74,7 +80,7 @@ class SQLCompiler(object):
        # another run of it.
        self.refcounts_before = self.query.alias_refcount.copy()
        out_cols, s_params = self.get_columns(with_col_aliases)
        ordering, ordering_group_by = self.get_ordering()
        ordering, o_params, ordering_group_by = self.get_ordering()

        distinct_fields = self.get_distinct()

@@ -95,9 +101,10 @@ class SQLCompiler(object):

        if self.query.distinct:
            result.append(self.connection.ops.distinct_sql(distinct_fields))

        result.append(', '.join(out_cols + self.query.ordering_aliases))
        params.extend(o_params)
        result.append(', '.join(out_cols + self.ordering_aliases))
        params.extend(s_params)
        params.extend(self.ordering_params)

        result.append('FROM')
        result.extend(from_)
@@ -319,7 +326,6 @@ class SQLCompiler(object):
                result.append("%s.%s" % (qn(alias), qn2(col)))
        return result


    def get_ordering(self):
        """
        Returns a tuple containing a list representing the SQL elements in the
@@ -357,7 +363,9 @@ class SQLCompiler(object):
        # the table/column pairs we use and discard any after the first use.
        processed_pairs = set()

        for field in ordering:
        params = []
        ordering_params = []
        for pos, field in enumerate(ordering):
            if field == '?':
                result.append(self.connection.ops.random_function_sql())
                continue
@@ -384,7 +392,7 @@ class SQLCompiler(object):
                    if not distinct or elt in select_aliases:
                        result.append('%s %s' % (elt, order))
                        group_by.append((elt, []))
            elif get_order_dir(field)[0] not in self.query.extra_select:
            elif get_order_dir(field)[0] not in self.query.extra:
                # 'col' is of the form 'field' or 'field1__field2' or
                # '-field1__field2__field', etc.
                for table, cols, order in self.find_ordering_name(field,
@@ -399,12 +407,19 @@ class SQLCompiler(object):
                            group_by.append((elt, []))
            else:
                elt = qn2(col)
                if col not in self.query.extra_select:
                    sql = "(%s) AS %s" % (self.query.extra[col][0], elt)
                    ordering_aliases.append(sql)
                    ordering_params.extend(self.query.extra[col][1])
                else:
                    if distinct and col not in select_aliases:
                        ordering_aliases.append(elt)
                        ordering_params.extend(params)
                result.append('%s %s' % (elt, order))
                group_by.append(self.query.extra_select[col])
        self.query.ordering_aliases = ordering_aliases
        return result, group_by
                group_by.append(self.query.extra[col])
        self.ordering_aliases = ordering_aliases
        self.ordering_params = ordering_params
        return result, params, group_by

    def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
            already_seen=None):
@@ -764,13 +779,13 @@ class SQLCompiler(object):
        if not result_type:
            return cursor
        if result_type == SINGLE:
            if self.query.ordering_aliases:
                return cursor.fetchone()[:-len(self.query.ordering_aliases)]
            if self.ordering_aliases:
                return cursor.fetchone()[:-len(self.ordering_aliases)]
            return cursor.fetchone()

        # The MULTI case.
        if self.query.ordering_aliases:
            result = order_modified_iter(cursor, len(self.query.ordering_aliases),
        if self.ordering_aliases:
            result = order_modified_iter(cursor, len(self.ordering_aliases),
                    self.connection.features.empty_fetchmany_value)
        else:
            result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
+0 −2
Original line number Diff line number Diff line
@@ -115,7 +115,6 @@ class Query(object):
        self.default_cols = True
        self.default_ordering = True
        self.standard_ordering = True
        self.ordering_aliases = []
        self.used_aliases = set()
        self.filter_is_sticky = False
        self.included_inherited_models = {}
@@ -227,7 +226,6 @@ class Query(object):
        obj.default_ordering = self.default_ordering
        obj.standard_ordering = self.standard_ordering
        obj.included_inherited_models = self.included_inherited_models.copy()
        obj.ordering_aliases = []
        obj.select = self.select[:]
        obj.related_select_cols = []
        obj.tables = self.tables[:]
+53 −4
Original line number Diff line number Diff line
@@ -1976,13 +1976,62 @@ class EmptyQuerySetTests(TestCase):


class ValuesQuerysetTests(BaseQuerysetTest):
    def test_flat_values_lits(self):
    def setUp(self):
        Number.objects.create(num=72)
        self.identity = lambda x: x

    def test_flat_values_list(self):
        qs = Number.objects.values_list("num")
        qs = qs.values_list("num", flat=True)
        self.assertValueQuerysetEqual(
            qs, [72]
        )
        self.assertValueQuerysetEqual(qs, [72])

    def test_extra_values(self):
        # testing for ticket 14930 issues
        qs = Number.objects.extra(select=SortedDict([('value_plus_x', 'num+%s'),
                                                     ('value_minus_x', 'num-%s')]),
                                  select_params=(1, 2))
        qs = qs.order_by('value_minus_x')
        qs = qs.values('num')
        self.assertQuerysetEqual(qs, [{'num': 72}], self.identity)

    def test_extra_values_order_twice(self):
        # testing for ticket 14930 issues
        qs = Number.objects.extra(select={'value_plus_one': 'num+1', 'value_minus_one': 'num-1'})
        qs = qs.order_by('value_minus_one').order_by('value_plus_one')
        qs = qs.values('num')
        self.assertQuerysetEqual(qs, [{'num': 72}], self.identity)

    def test_extra_values_order_multiple(self):
        # Postgres doesn't allow constants in order by, so check for that.
        qs = Number.objects.extra(select={
            'value_plus_one': 'num+1',
            'value_minus_one': 'num-1',
            'constant_value': '1'
        })
        qs = qs.order_by('value_plus_one', 'value_minus_one', 'constant_value')
        qs = qs.values('num')
        self.assertQuerysetEqual(qs, [{'num': 72}], self.identity)

    def test_extra_values_order_in_extra(self):
        # testing for ticket 14930 issues
        qs = Number.objects.extra(
            select={'value_plus_one': 'num+1', 'value_minus_one': 'num-1'},
            order_by=['value_minus_one'])
        qs = qs.values('num')

    def test_extra_values_list(self):
        # testing for ticket 14930 issues
        qs = Number.objects.extra(select={'value_plus_one': 'num+1'})
        qs = qs.order_by('value_plus_one')
        qs = qs.values_list('num')
        self.assertQuerysetEqual(qs, [(72,)], self.identity)

    def test_flat_extra_values_list(self):
        # testing for ticket 14930 issues
        qs = Number.objects.extra(select={'value_plus_one': 'num+1'})
        qs = qs.order_by('value_plus_one')
        qs = qs.values_list('num', flat=True)
        self.assertQuerysetEqual(qs, [72], self.identity)


class WeirdQuerysetSlicingTests(BaseQuerysetTest):