Commit 5506653b authored by Alex Gaynor's avatar Alex Gaynor
Browse files

Fixed #5416 -- Added TestCase.assertNumQueries, which tests that a given...

Fixed #5416 -- Added TestCase.assertNumQueries, which tests that a given function executes the correct number of queries.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@14183 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent ceef628c
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@ class BaseDatabaseWrapper(local):
        self.settings_dict = settings_dict
        self.alias = alias
        self.vendor = 'unknown'
        self.use_debug_cursor = None

    def __eq__(self, other):
        return self.settings_dict == other.settings_dict
@@ -74,7 +75,8 @@ class BaseDatabaseWrapper(local):
    def cursor(self):
        from django.conf import settings
        cursor = self._cursor()
        if settings.DEBUG:
        if (self.use_debug_cursor or
            (self.use_debug_cursor is None and settings.DEBUG)):
            return self.make_debug_cursor(cursor)
        return cursor

+44 −0
Original line number Diff line number Diff line
import re
import sys
from urlparse import urlsplit, urlunsplit
from xml.dom.minidom import parseString, Node

@@ -205,6 +206,33 @@ class DocTestRunner(doctest.DocTestRunner):
        for conn in connections:
            transaction.rollback_unless_managed(using=conn)

class _AssertNumQueriesContext(object):
    def __init__(self, test_case, num, connection):
        self.test_case = test_case
        self.num = num
        self.connection = connection

    def __enter__(self):
        self.old_debug_cursor = self.connection.use_debug_cursor
        self.connection.use_debug_cursor = True
        self.starting_queries = len(self.connection.queries)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is not None:
            return

        self.connection.use_debug_cursor = self.old_debug_cursor
        final_queries = len(self.connection.queries)
        executed = final_queries - self.starting_queries

        self.test_case.assertEqual(
            executed, self.num, "%d queries executed, %d expected" % (
                executed, self.num
            )
        )


class TransactionTestCase(unittest.TestCase):
    # The class we'll use for the test client self.client.
    # Can be overridden in derived classes.
@@ -469,6 +497,22 @@ class TransactionTestCase(unittest.TestCase):
    def assertQuerysetEqual(self, qs, values, transform=repr):
        return self.assertEqual(map(transform, qs), values)

    def assertNumQueries(self, num, func=None, *args, **kwargs):
        using = kwargs.pop("using", DEFAULT_DB_ALIAS)
        connection = connections[using]

        context = _AssertNumQueriesContext(self, num, connection)
        if func is None:
            return context

        # Basically emulate the `with` statement here.

        context.__enter__()
        try:
            func(*args, **kwargs)
        finally:
            context.__exit__(*sys.exc_info())

def connections_support_transactions():
    """
    Returns True if all connections support transactions.  This is messy
+26 −0
Original line number Diff line number Diff line
@@ -1372,6 +1372,32 @@ cause of an failure in your test suite.
    implicit ordering, you will need to apply a ``order_by()`` clause to your
    queryset to ensure that the test will pass reliably.

.. method:: TestCase.assertNumQueries(num, func, *args, **kwargs):

    .. versionadded:: 1.3

    Asserts that when ``func`` is called with ``*args`` and ``**kwargs`` that
    ``num`` database queries are executed.

    If a ``"using"`` key is present in ``kwargs`` it is used as the database
    alias for which to check the number of queries.  If you wish to call a
    function with a ``using`` parameter you can do it by wrapping the call with
    a ``lambda`` to add an extra parameter::

        self.assertNumQueries(7, lambda: my_function(using=7))

    If you're using Python 2.5 or greater you can also use this as a context
    manager::

        # This is necessary in Python 2.5 to enable the with statement, in 2.6
        # and up it is no longer necessary.
        from __future__ import with_statement

        with self.assertNumQueries(2):
            Person.objects.create(name="Aaron")
            Person.objects.create(name="Daniel")


.. _topics-testing-email:

E-mail services
+60 −58
Original line number Diff line number Diff line
from django.test import TestCase
from django.conf import settings
from django import db

from models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species

@@ -36,36 +34,34 @@ class SelectRelatedTests(TestCase):
        # queries so we'll set it to True here and reset it at the end of the
        # test case.
        self.create_base_data()
        settings.DEBUG = True
        db.reset_queries()

    def tearDown(self):
        settings.DEBUG = False

    def test_access_fks_without_select_related(self):
        """
        Normally, accessing FKs doesn't fill in related objects
        """
        def test():
            fly = Species.objects.get(name="melanogaster")
            domain = fly.genus.family.order.klass.phylum.kingdom.domain
            self.assertEqual(domain.name, 'Eukaryota')
        self.assertEqual(len(db.connection.queries), 8)
        self.assertNumQueries(8, test)

    def test_access_fks_with_select_related(self):
        """
        A select_related() call will fill in those related objects without any
        extra queries
        """
        def test():
            person = Species.objects.select_related(depth=10).get(name="sapiens")
            domain = person.genus.family.order.klass.phylum.kingdom.domain
            self.assertEqual(domain.name, 'Eukaryota')
        self.assertEqual(len(db.connection.queries), 1)
        self.assertNumQueries(1, test)

    def test_list_without_select_related(self):
        """
        select_related() also of course applies to entire lists, not just
        items. This test verifies the expected behavior without select_related.
        """
        def test():
            world = Species.objects.all()
            families = [o.genus.family.name for o in world]
            self.assertEqual(families, [
@@ -74,13 +70,14 @@ class SelectRelatedTests(TestCase):
                'Fabaceae',
                'Amanitacae',
            ])
        self.assertEqual(len(db.connection.queries), 9)
        self.assertNumQueries(9, test)

    def test_list_with_select_related(self):
        """
        select_related() also of course applies to entire lists, not just
        items. This test verifies the expected behavior with select_related.
        """
        def test():
            world = Species.objects.all().select_related()
            families = [o.genus.family.name for o in world]
            self.assertEqual(families, [
@@ -89,20 +86,21 @@ class SelectRelatedTests(TestCase):
                'Fabaceae',
                'Amanitacae',
            ])
        self.assertEqual(len(db.connection.queries), 1)
        self.assertNumQueries(1, test)

    def test_depth(self, depth=1, expected=7):
        """
        The "depth" argument to select_related() will stop the descent at a
        particular level.
        """
        def test():
            pea = Species.objects.select_related(depth=depth).get(name="sativum")
            self.assertEqual(
                pea.genus.family.order.klass.phylum.kingdom.domain.name,
                'Eukaryota'
            )
        # Notice: one fewer queries than above because of depth=1
        self.assertEqual(len(db.connection.queries), expected)
        self.assertNumQueries(expected, test)

    def test_larger_depth(self):
        """
@@ -116,11 +114,12 @@ class SelectRelatedTests(TestCase):
        The "depth" argument to select_related() will stop the descent at a
        particular level. This can be used on lists as well.
        """
        def test():
            world = Species.objects.all().select_related(depth=2)
            orders = [o.genus.family.order.name for o in world]
            self.assertEqual(orders,
                ['Diptera', 'Primates', 'Fabales', 'Agaricales'])
        self.assertEqual(len(db.connection.queries), 5)
        self.assertNumQueries(5, test)

    def test_select_related_with_extra(self):
        s = Species.objects.all().select_related(depth=1)\
@@ -136,28 +135,31 @@ class SelectRelatedTests(TestCase):
        In this case, we explicitly say to select the 'genus' and
        'genus.family' models, leading to the same number of queries as before.
        """
        def test():
            world = Species.objects.select_related('genus__family')
            families = [o.genus.family.name for o in world]
            self.assertEqual(families,
                ['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae'])
        self.assertEqual(len(db.connection.queries), 1)
        self.assertNumQueries(1, test)

    def test_more_certain_fields(self):
        """
        In this case, we explicitly say to select the 'genus' and
        'genus.family' models, leading to the same number of queries as before.
        """
        def test():
            world = Species.objects.filter(genus__name='Amanita')\
                .select_related('genus__family')
            orders = [o.genus.family.order.name for o in world]
            self.assertEqual(orders, [u'Agaricales'])
        self.assertEqual(len(db.connection.queries), 2)
        self.assertNumQueries(2, test)

    def test_field_traversal(self):
        def test():
            s = Species.objects.all().select_related('genus__family__order'
                ).order_by('id')[0:1].get().genus.family.order.name
            self.assertEqual(s, u'Diptera')
        self.assertEqual(len(db.connection.queries), 1)
        self.assertNumQueries(1, test)

    def test_depth_fields_fails(self):
        self.assertRaises(TypeError,
+18 −27
Original line number Diff line number Diff line
@@ -2,9 +2,11 @@ import datetime

from django.conf import settings
from django.db import connection
from django.test import TestCase
from django.utils import unittest

from models import CustomPKModel, UniqueTogetherModel, UniqueFieldsModel, UniqueForDateModel, ModelToValidate
from models import (CustomPKModel, UniqueTogetherModel, UniqueFieldsModel,
    UniqueForDateModel, ModelToValidate)


class GetUniqueCheckTests(unittest.TestCase):
@@ -51,37 +53,26 @@ class GetUniqueCheckTests(unittest.TestCase):
            ), m._get_unique_checks(exclude='start_date')
        )

class PerformUniqueChecksTest(unittest.TestCase):
    def setUp(self):
        # Set debug to True to gain access to connection.queries.
        self._old_debug, settings.DEBUG = settings.DEBUG, True
        super(PerformUniqueChecksTest, self).setUp()

    def tearDown(self):
        # Restore old debug value.
        settings.DEBUG = self._old_debug
        super(PerformUniqueChecksTest, self).tearDown()

class PerformUniqueChecksTest(TestCase):
    def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self):
        # Regression test for #12560
        query_count = len(connection.queries)
        def test():
            mtv = ModelToValidate(number=10, name='Some Name')
            setattr(mtv, '_adding', True)
            mtv.full_clean()
        self.assertEqual(query_count, len(connection.queries))
        self.assertNumQueries(0, test)

    def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self):
        # Regression test for #12560
        query_count = len(connection.queries)
        def test():
            mtv = ModelToValidate(number=10, name='Some Name', id=123)
            setattr(mtv, '_adding', True)
            mtv.full_clean()
        self.assertEqual(query_count + 1, len(connection.queries))
        self.assertNumQueries(1, test)

    def test_primary_key_unique_check_not_performed_when_not_adding(self):
        # Regression test for #12132
        query_count= len(connection.queries)
        def test():
            mtv = ModelToValidate(number=10, name='Some Name')
            mtv.full_clean()
        self.assertEqual(query_count, len(connection.queries))
        self.assertNumQueries(0, test)
Loading