Commit 3daa9d60 authored by Niclas Olofsson's avatar Niclas Olofsson Committed by Tim Graham
Browse files

Fixed #10414 -- Made select_related() fail on invalid field names.

parent b27db97b
Loading
Loading
Loading
Loading
+46 −9
Original line number Diff line number Diff line
from itertools import chain
import warnings

from django.core.exceptions import FieldError
@@ -599,6 +600,14 @@ class SQLCompiler(object):
        (for example, cur_depth=1 means we are looking at models with direct
        connections to the root model).
        """
        def _get_field_choices():
            direct_choices = (f.name for (f, _) in opts.get_fields_with_model() if f.rel)
            reverse_choices = (
                f.field.related_query_name()
                for f in opts.get_all_related_objects() if f.field.unique
            )
            return chain(direct_choices, reverse_choices)

        if not restricted and self.query.max_depth and cur_depth > self.query.max_depth:
            # We've recursed far enough; bail out.
            return
@@ -611,6 +620,7 @@ class SQLCompiler(object):

        # Setup for the case when only particular related fields should be
        # included in the related selection.
        fields_found = set()
        if requested is None:
            if isinstance(self.query.select_related, dict):
                requested = self.query.select_related
@@ -619,6 +629,24 @@ class SQLCompiler(object):
                restricted = False

        for f, model in opts.get_fields_with_model():
            fields_found.add(f.name)

            if restricted:
                next = requested.get(f.name, {})
                if not f.rel:
                    # If a non-related field is used like a relation,
                    # or if a single non-relational field is given.
                    if next or (cur_depth == 1 and f.name in requested):
                        raise FieldError(
                            "Non-relational field given in select_related: '%s'. "
                            "Choices are: %s" % (
                                f.name,
                                ", ".join(_get_field_choices()) or '(none)',
                            )
                        )
            else:
                next = False

            # The get_fields_with_model() returns None for fields that live
            # in the field's local model. So, for those fields we want to use
            # the f.model - that is the field's local model.
@@ -632,13 +660,9 @@ class SQLCompiler(object):
            columns, _ = self.get_default_columns(start_alias=alias,
                    opts=f.rel.to._meta, as_pairs=True)
            self.query.related_select_cols.extend(
                SelectInfo((col[0], col[1].column), col[1]) for col in columns)
            if restricted:
                next = requested.get(f.name, {})
            else:
                next = False
            self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
                    next, restricted)
                SelectInfo((col[0], col[1].column), col[1]) for col in columns
            )
            self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, next, restricted)

        if restricted:
            related_fields = [
@@ -651,8 +675,10 @@ class SQLCompiler(object):
                                              only_load.get(model), reverse=True):
                    continue

                _, _, _, joins, _ = self.query.setup_joins(
                    [f.related_query_name()], opts, root_alias)
                related_field_name = f.related_query_name()
                fields_found.add(related_field_name)

                _, _, _, joins, _ = self.query.setup_joins([related_field_name], opts, root_alias)
                alias = joins[-1]
                from_parent = (opts.model if issubclass(model, opts.model)
                               else None)
@@ -664,6 +690,17 @@ class SQLCompiler(object):
                self.fill_related_selections(model._meta, alias, cur_depth + 1,
                                             next, restricted)

            fields_not_found = set(requested.keys()).difference(fields_found)
            if fields_not_found:
                invalid_fields = ("'%s'" % s for s in fields_not_found)
                raise FieldError(
                    'Invalid field name(s) given in select_related: %s. '
                    'Choices are: %s' % (
                        ', '.join(invalid_fields),
                        ', '.join(_get_field_choices()) or '(none)',
                    )
                )

    def deferred_to_columns(self):
        """
        Converts the self.deferred_loading data structure to mapping of table
+18 −0
Original line number Diff line number Diff line
@@ -681,6 +681,24 @@ lookups::
    ...
    ValueError: Cannot query "<Book: Django>": Must be "Author" instance.

``select_related()`` now checks given fields
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``select_related()`` now validates that the given fields actually exist.
Previously, nonexistent fields were silently ignored. Now, an error is raised::

    >>> book = Book.objects.select_related('nonexistent_field')
    Traceback (most recent call last):
    ...
    FieldError: Invalid field name(s) given in select_related: 'nonexistent_field'

The validation also makes sure that the given field is relational::

    >>> book = Book.objects.select_related('name')
    Traceback (most recent call last):
    ...
    FieldError: Non-relational field given in select_related: 'name'

Default ``EmailField.max_length`` increased to 254
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

+1 −1
Original line number Diff line number Diff line
@@ -181,7 +181,7 @@ class NonAggregateAnnotationTestCase(TestCase):
            other_chain=F('chain'),
            is_open=Value(True, BooleanField()),
            book_isbn=F('books__isbn')
        ).select_related('store').order_by('book_isbn').filter(chain='Westfield')
        ).order_by('book_isbn').filter(chain='Westfield')

        self.assertQuerysetEqual(
            qs, [
+41 −0
Original line number Diff line number Diff line
@@ -10,6 +10,9 @@ the select-related behavior will traverse.
from django.db import models
from django.utils.encoding import python_2_unicode_compatible

from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.contrib.contenttypes.models import ContentType

# Who remembers high school biology?


@@ -94,3 +97,41 @@ class HybridSpecies(models.Model):

    def __str__(self):
        return self.name


@python_2_unicode_compatible
class Topping(models.Model):
    name = models.CharField(max_length=30)

    def __str__(self):
        return self.name


@python_2_unicode_compatible
class Pizza(models.Model):
    name = models.CharField(max_length=100)
    toppings = models.ManyToManyField(Topping)

    def __str__(self):
        return self.name


@python_2_unicode_compatible
class TaggedItem(models.Model):
    tag = models.CharField(max_length=30)

    content_type = models.ForeignKey(ContentType, related_name='select_related_tagged_items')
    object_id = models.PositiveIntegerField()
    content_object = GenericForeignKey('content_type', 'object_id')

    def __str__(self):
        return self.tag


@python_2_unicode_compatible
class Bookmark(models.Model):
    url = models.URLField()
    tags = GenericRelation(TaggedItem)

    def __str__(self):
        return self.url
+55 −1
Original line number Diff line number Diff line
from __future__ import unicode_literals

from django.test import TestCase
from django.core.exceptions import FieldError

from .models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies
from .models import (
    Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies,
    Pizza, TaggedItem, Bookmark,
)


class SelectRelatedTests(TestCase):
@@ -126,6 +130,12 @@ class SelectRelatedTests(TestCase):
            orders = [o.genus.family.order.name for o in world]
            self.assertEqual(orders, ['Agaricales'])

    def test_single_related_field(self):
        with self.assertNumQueries(1):
            species = Species.objects.select_related('genus__name')
            names = [s.genus.name for s in species]
            self.assertEqual(sorted(names), ['Amanita', 'Drosophila', 'Homo', 'Pisum'])

    def test_field_traversal(self):
        with self.assertNumQueries(1):
            s = (Species.objects.all()
@@ -152,3 +162,47 @@ class SelectRelatedTests(TestCase):
            obj = queryset[0]
            self.assertEqual(obj.parent_1, parent_1)
            self.assertEqual(obj.parent_2, parent_2)


class SelectRelatedValidationTests(TestCase):
    """
    select_related() should thrown an error on fields that do not exist and
    non-relational fields.
    """
    non_relational_error = "Non-relational field given in select_related: '%s'. Choices are: %s"
    invalid_error = "Invalid field name(s) given in select_related: '%s'. Choices are: %s"

    def test_non_relational_field(self):
        with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', 'genus')):
            list(Species.objects.select_related('name__some_field'))

        with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', 'genus')):
            list(Species.objects.select_related('name'))

        with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', '(none)')):
            list(Domain.objects.select_related('name'))

    def test_many_to_many_field(self):
        with self.assertRaisesMessage(FieldError, self.invalid_error % ('toppings', '(none)')):
            list(Pizza.objects.select_related('toppings'))

    def test_reverse_relational_field(self):
        with self.assertRaisesMessage(FieldError, self.invalid_error % ('child_1', 'genus')):
            list(Species.objects.select_related('child_1'))

    def test_invalid_field(self):
        with self.assertRaisesMessage(FieldError, self.invalid_error % ('invalid_field', 'genus')):
            list(Species.objects.select_related('invalid_field'))

        with self.assertRaisesMessage(FieldError, self.invalid_error % ('related_invalid_field', 'family')):
            list(Species.objects.select_related('genus__related_invalid_field'))

        with self.assertRaisesMessage(FieldError, self.invalid_error % ('invalid_field', '(none)')):
            list(Domain.objects.select_related('invalid_field'))

    def test_generic_relations(self):
        with self.assertRaisesMessage(FieldError, self.invalid_error % ('tags', '')):
            list(Bookmark.objects.select_related('tags'))

        with self.assertRaisesMessage(FieldError, self.invalid_error % ('content_object', 'content_type')):
            list(TaggedItem.objects.select_related('content_object'))
Loading