Commit 952ba523 authored by Simon Charette's avatar Simon Charette
Browse files

Added a context manager to capture queries while testing.

Also made some import cleanups while I was there.

Refs #10399.
parent 203c17c2
Loading
Loading
Loading
Loading
+11 −24
Original line number Diff line number Diff line
@@ -24,31 +24,30 @@ from django.core.exceptions import ValidationError, ImproperlyConfigured
from django.core.handlers.wsgi import WSGIHandler
from django.core.management import call_command
from django.core.management.color import no_style
from django.core.signals import request_started
from django.core.servers.basehttp import (WSGIRequestHandler, WSGIServer,
    WSGIServerException)
from django.core.urlresolvers import clear_url_caches
from django.core.validators import EMPTY_VALUES
from django.db import (transaction, connection, connections, DEFAULT_DB_ALIAS,
    reset_queries)
from django.db import connection, connections, DEFAULT_DB_ALIAS, transaction
from django.forms.fields import CharField
from django.http import QueryDict
from django.test import _doctest as doctest
from django.test.client import Client
from django.test.html import HTMLParseError, parse_html
from django.test.signals import template_rendered
from django.test.utils import (override_settings, compare_xml, strip_quotes)
from django.test.utils import ContextList
from django.utils import unittest as ut2
from django.test.utils import (CaptureQueriesContext, ContextList,
    override_settings, compare_xml, strip_quotes)
from django.utils import six, unittest as ut2
from django.utils.encoding import force_text
from django.utils import six
from django.utils.unittest import skipIf # Imported here for backward compatibility
from django.utils.unittest.util import safe_repr
from django.utils.unittest import skipIf
from django.views.static import serve


__all__ = ('DocTestRunner', 'OutputChecker', 'TestCase', 'TransactionTestCase',
           'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature')


normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)",
                                lambda m: "Decimal(\"%s\")" % m.groups()[0], s)
@@ -168,28 +167,17 @@ class DocTestRunner(doctest.DocTestRunner):
            transaction.rollback_unless_managed(using=conn)


class _AssertNumQueriesContext(object):
class _AssertNumQueriesContext(CaptureQueriesContext):
    def __init__(self, test_case, num, connection):
        self.test_case = test_case
        self.num = num
        self.connection = connection

    def __enter__(self):
        self.old_debug_cursor = self.connection.use_debug_cursor
        self.connection.use_debug_cursor = True
        self.starting_queries = len(self.connection.queries)
        request_started.disconnect(reset_queries)
        return self
        super(_AssertNumQueriesContext, self).__init__(connection)

    def __exit__(self, exc_type, exc_value, traceback):
        self.connection.use_debug_cursor = self.old_debug_cursor
        request_started.connect(reset_queries)
        if exc_type is not None:
            return

        final_queries = len(self.connection.queries)
        executed = final_queries - self.starting_queries

        super(_AssertNumQueriesContext, self).__exit__(exc_type, exc_value, traceback)
        executed = len(self)
        self.test_case.assertEqual(
            executed, self.num, "%d queries executed, %d expected" % (
                executed, self.num
@@ -1051,7 +1039,6 @@ class LiveServerThread(threading.Thread):
        http requests.
        """
        if self.connections_override:
            from django.db import connections
            # Override this thread's database connections with the ones
            # provided by the main thread.
            for alias, conn in self.connections_override.items():
+39 −0
Original line number Diff line number Diff line
@@ -4,6 +4,8 @@ from xml.dom.minidom import parseString, Node

from django.conf import settings, UserSettingsHolder
from django.core import mail
from django.core.signals import request_started
from django.db import reset_queries
from django.template import Template, loader, TemplateDoesNotExist
from django.template.loaders import cached
from django.test.signals import template_rendered, setting_changed
@@ -339,5 +341,42 @@ def strip_quotes(want, got):
        got = got.strip()[2:-1]
    return want, got


def str_prefix(s):
    return s % {'_': '' if six.PY3 else 'u'}


class CaptureQueriesContext(object):
    """
    Context manager that captures queries executed by the specified connection.
    """
    def __init__(self, connection):
        self.connection = connection

    def __iter__(self):
        return iter(self.captured_queries)

    def __getitem__(self, index):
        return self.captured_queries[index]

    def __len__(self):
        return len(self.captured_queries)

    @property
    def captured_queries(self):
        return self.connection.queries[self.initial_queries:self.final_queries]

    def __enter__(self):
        self.use_debug_cursor = self.connection.use_debug_cursor
        self.connection.use_debug_cursor = True
        self.initial_queries = len(self.connection.queries)
        self.final_queries = None
        request_started.disconnect(reset_queries)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.connection.use_debug_cursor = self.use_debug_cursor
        request_started.connect(reset_queries)
        if exc_type is not None:
            return
        self.final_queries = len(self.connection.queries)
+58 −9
Original line number Diff line number Diff line
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
import warnings

from django.db import connection
from django.forms import EmailField, IntegerField
from django.http import HttpResponse
from django.template.loader import render_to_string
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from django.test.html import HTMLParseError, parse_html
from django.test.utils import CaptureQueriesContext
from django.utils import six
from django.utils.unittest import skip

@@ -94,6 +98,60 @@ class AssertQuerysetEqualTests(TestCase):
        )


class CaptureQueriesContextManagerTests(TestCase):
    urls = 'test_utils.urls'

    def setUp(self):
        self.person_pk = six.text_type(Person.objects.create(name='test').pk)

    def test_simple(self):
        with CaptureQueriesContext(connection) as captured_queries:
            Person.objects.get(pk=self.person_pk)
        self.assertEqual(len(captured_queries), 1)
        self.assertIn(self.person_pk, captured_queries[0]['sql'])

        with CaptureQueriesContext(connection) as captured_queries:
            pass
        self.assertEqual(0, len(captured_queries))

    def test_within(self):
        with CaptureQueriesContext(connection) as captured_queries:
            Person.objects.get(pk=self.person_pk)
            self.assertEqual(len(captured_queries), 1)
            self.assertIn(self.person_pk, captured_queries[0]['sql'])

    def test_nested(self):
        with CaptureQueriesContext(connection) as captured_queries:
            Person.objects.count()
            with CaptureQueriesContext(connection) as nested_captured_queries:
                Person.objects.count()
        self.assertEqual(1, len(nested_captured_queries))
        self.assertEqual(2, len(captured_queries))

    def test_failure(self):
        with self.assertRaises(TypeError):
            with CaptureQueriesContext(connection):
                raise TypeError

    def test_with_client(self):
        with CaptureQueriesContext(connection) as captured_queries:
            self.client.get("/test_utils/get_person/%s/" % self.person_pk)
        self.assertEqual(len(captured_queries), 1)
        self.assertIn(self.person_pk, captured_queries[0]['sql'])

        with CaptureQueriesContext(connection) as captured_queries:
            self.client.get("/test_utils/get_person/%s/" % self.person_pk)
        self.assertEqual(len(captured_queries), 1)
        self.assertIn(self.person_pk, captured_queries[0]['sql'])

        with CaptureQueriesContext(connection) as captured_queries:
            self.client.get("/test_utils/get_person/%s/" % self.person_pk)
            self.client.get("/test_utils/get_person/%s/" % self.person_pk)
        self.assertEqual(len(captured_queries), 2)
        self.assertIn(self.person_pk, captured_queries[0]['sql'])
        self.assertIn(self.person_pk, captured_queries[1]['sql'])


class AssertNumQueriesContextManagerTests(TestCase):
    urls = 'test_utils.urls'

@@ -219,7 +277,6 @@ class SaveRestoreWarningState(TestCase):
        # In reality this test could be satisfied by many broken implementations
        # of save_warnings_state/restore_warnings_state (e.g. just
        # warnings.resetwarnings()) , but it is difficult to test more.
        import warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)

@@ -245,7 +302,6 @@ class SaveRestoreWarningState(TestCase):

class HTMLEqualTests(TestCase):
    def test_html_parser(self):
        from django.test.html import parse_html
        element = parse_html('<div><p>Hello</p></div>')
        self.assertEqual(len(element.children), 1)
        self.assertEqual(element.children[0].name, 'p')
@@ -259,7 +315,6 @@ class HTMLEqualTests(TestCase):
        self.assertEqual(dom[0], 'foo')

    def test_parse_html_in_script(self):
        from django.test.html import parse_html
        parse_html('<script>var a = "<p" + ">";</script>');
        parse_html('''
            <script>
@@ -275,8 +330,6 @@ class HTMLEqualTests(TestCase):
        self.assertEqual(dom.children[0], "<p>foo</p> '</scr'+'ipt>' <span>bar</span>")

    def test_self_closing_tags(self):
        from django.test.html import parse_html

        self_closing_tags = ('br' , 'hr', 'input', 'img', 'meta', 'spacer',
            'link', 'frame', 'base', 'col')
        for tag in self_closing_tags:
@@ -400,7 +453,6 @@ class HTMLEqualTests(TestCase):
        </html>""")

    def test_html_contain(self):
        from django.test.html import parse_html
        # equal html contains each other
        dom1 = parse_html('<p>foo')
        dom2 = parse_html('<p>foo</p>')
@@ -424,7 +476,6 @@ class HTMLEqualTests(TestCase):
        self.assertTrue(dom1 in dom2)

    def test_count(self):
        from django.test.html import parse_html
        # equal html contains each other one time
        dom1 = parse_html('<p>foo')
        dom2 = parse_html('<p>foo</p>')
@@ -459,7 +510,6 @@ class HTMLEqualTests(TestCase):
        self.assertEqual(dom2.count(dom1), 0)

    def test_parsing_errors(self):
        from django.test.html import HTMLParseError, parse_html
        with self.assertRaises(AssertionError):
            self.assertHTMLEqual('<p>', '')
        with self.assertRaises(AssertionError):
@@ -488,7 +538,6 @@ class HTMLEqualTests(TestCase):
            self.assertContains(response, '<p "whats" that>')

    def test_unicode_handling(self):
        from django.http import HttpResponse
        response = HttpResponse('<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>')
        self.assertContains(response, '<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>', html=True)