Loading django/test/testcases.py +11 −24 Original line number Diff line number Diff line Loading @@ -24,31 +24,30 @@ from django.core.exceptions import ValidationError, ImproperlyConfigured from django.core.handlers.wsgi import WSGIHandler from django.core.management import call_command from django.core.management.color import no_style from django.core.signals import request_started from django.core.servers.basehttp import (WSGIRequestHandler, WSGIServer, WSGIServerException) from django.core.urlresolvers import clear_url_caches from django.core.validators import EMPTY_VALUES from django.db import (transaction, connection, connections, DEFAULT_DB_ALIAS, reset_queries) from django.db import connection, connections, DEFAULT_DB_ALIAS, transaction from django.forms.fields import CharField from django.http import QueryDict from django.test import _doctest as doctest from django.test.client import Client from django.test.html import HTMLParseError, parse_html from django.test.signals import template_rendered from django.test.utils import (override_settings, compare_xml, strip_quotes) from django.test.utils import ContextList from django.utils import unittest as ut2 from django.test.utils import (CaptureQueriesContext, ContextList, override_settings, compare_xml, strip_quotes) from django.utils import six, unittest as ut2 from django.utils.encoding import force_text from django.utils import six from django.utils.unittest import skipIf # Imported here for backward compatibility from django.utils.unittest.util import safe_repr from django.utils.unittest import skipIf from django.views.static import serve __all__ = ('DocTestRunner', 'OutputChecker', 'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature') normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s) normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)", lambda m: "Decimal(\"%s\")" % m.groups()[0], s) Loading Loading @@ -168,28 +167,17 @@ class DocTestRunner(doctest.DocTestRunner): transaction.rollback_unless_managed(using=conn) class _AssertNumQueriesContext(object): class _AssertNumQueriesContext(CaptureQueriesContext): def __init__(self, test_case, num, connection): self.test_case = test_case self.num = num self.connection = connection def __enter__(self): self.old_debug_cursor = self.connection.use_debug_cursor self.connection.use_debug_cursor = True self.starting_queries = len(self.connection.queries) request_started.disconnect(reset_queries) return self super(_AssertNumQueriesContext, self).__init__(connection) def __exit__(self, exc_type, exc_value, traceback): self.connection.use_debug_cursor = self.old_debug_cursor request_started.connect(reset_queries) if exc_type is not None: return final_queries = len(self.connection.queries) executed = final_queries - self.starting_queries super(_AssertNumQueriesContext, self).__exit__(exc_type, exc_value, traceback) executed = len(self) self.test_case.assertEqual( executed, self.num, "%d queries executed, %d expected" % ( executed, self.num Loading Loading @@ -1051,7 +1039,6 @@ class LiveServerThread(threading.Thread): http requests. """ if self.connections_override: from django.db import connections # Override this thread's database connections with the ones # provided by the main thread. for alias, conn in self.connections_override.items(): Loading django/test/utils.py +39 −0 Original line number Diff line number Diff line Loading @@ -4,6 +4,8 @@ from xml.dom.minidom import parseString, Node from django.conf import settings, UserSettingsHolder from django.core import mail from django.core.signals import request_started from django.db import reset_queries from django.template import Template, loader, TemplateDoesNotExist from django.template.loaders import cached from django.test.signals import template_rendered, setting_changed Loading Loading @@ -339,5 +341,42 @@ def strip_quotes(want, got): got = got.strip()[2:-1] return want, got def str_prefix(s): return s % {'_': '' if six.PY3 else 'u'} class CaptureQueriesContext(object): """ Context manager that captures queries executed by the specified connection. """ def __init__(self, connection): self.connection = connection def __iter__(self): return iter(self.captured_queries) def __getitem__(self, index): return self.captured_queries[index] def __len__(self): return len(self.captured_queries) @property def captured_queries(self): return self.connection.queries[self.initial_queries:self.final_queries] def __enter__(self): self.use_debug_cursor = self.connection.use_debug_cursor self.connection.use_debug_cursor = True self.initial_queries = len(self.connection.queries) self.final_queries = None request_started.disconnect(reset_queries) return self def __exit__(self, exc_type, exc_value, traceback): self.connection.use_debug_cursor = self.use_debug_cursor request_started.connect(reset_queries) if exc_type is not None: return self.final_queries = len(self.connection.queries) tests/test_utils/tests.py +58 −9 Original line number Diff line number Diff line # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals import warnings from django.db import connection from django.forms import EmailField, IntegerField from django.http import HttpResponse from django.template.loader import render_to_string from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test.html import HTMLParseError, parse_html from django.test.utils import CaptureQueriesContext from django.utils import six from django.utils.unittest import skip Loading Loading @@ -94,6 +98,60 @@ class AssertQuerysetEqualTests(TestCase): ) class CaptureQueriesContextManagerTests(TestCase): urls = 'test_utils.urls' def setUp(self): self.person_pk = six.text_type(Person.objects.create(name='test').pk) def test_simple(self): with CaptureQueriesContext(connection) as captured_queries: Person.objects.get(pk=self.person_pk) self.assertEqual(len(captured_queries), 1) self.assertIn(self.person_pk, captured_queries[0]['sql']) with CaptureQueriesContext(connection) as captured_queries: pass self.assertEqual(0, len(captured_queries)) def test_within(self): with CaptureQueriesContext(connection) as captured_queries: Person.objects.get(pk=self.person_pk) self.assertEqual(len(captured_queries), 1) self.assertIn(self.person_pk, captured_queries[0]['sql']) def test_nested(self): with CaptureQueriesContext(connection) as captured_queries: Person.objects.count() with CaptureQueriesContext(connection) as nested_captured_queries: Person.objects.count() self.assertEqual(1, len(nested_captured_queries)) self.assertEqual(2, len(captured_queries)) def test_failure(self): with self.assertRaises(TypeError): with CaptureQueriesContext(connection): raise TypeError def test_with_client(self): with CaptureQueriesContext(connection) as captured_queries: self.client.get("/test_utils/get_person/%s/" % self.person_pk) self.assertEqual(len(captured_queries), 1) self.assertIn(self.person_pk, captured_queries[0]['sql']) with CaptureQueriesContext(connection) as captured_queries: self.client.get("/test_utils/get_person/%s/" % self.person_pk) self.assertEqual(len(captured_queries), 1) self.assertIn(self.person_pk, captured_queries[0]['sql']) with CaptureQueriesContext(connection) as captured_queries: self.client.get("/test_utils/get_person/%s/" % self.person_pk) self.client.get("/test_utils/get_person/%s/" % self.person_pk) self.assertEqual(len(captured_queries), 2) self.assertIn(self.person_pk, captured_queries[0]['sql']) self.assertIn(self.person_pk, captured_queries[1]['sql']) class AssertNumQueriesContextManagerTests(TestCase): urls = 'test_utils.urls' Loading Loading @@ -219,7 +277,6 @@ class SaveRestoreWarningState(TestCase): # In reality this test could be satisfied by many broken implementations # of save_warnings_state/restore_warnings_state (e.g. just # warnings.resetwarnings()) , but it is difficult to test more. import warnings with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) Loading @@ -245,7 +302,6 @@ class SaveRestoreWarningState(TestCase): class HTMLEqualTests(TestCase): def test_html_parser(self): from django.test.html import parse_html element = parse_html('<div><p>Hello</p></div>') self.assertEqual(len(element.children), 1) self.assertEqual(element.children[0].name, 'p') Loading @@ -259,7 +315,6 @@ class HTMLEqualTests(TestCase): self.assertEqual(dom[0], 'foo') def test_parse_html_in_script(self): from django.test.html import parse_html parse_html('<script>var a = "<p" + ">";</script>'); parse_html(''' <script> Loading @@ -275,8 +330,6 @@ class HTMLEqualTests(TestCase): self.assertEqual(dom.children[0], "<p>foo</p> '</scr'+'ipt>' <span>bar</span>") def test_self_closing_tags(self): from django.test.html import parse_html self_closing_tags = ('br' , 'hr', 'input', 'img', 'meta', 'spacer', 'link', 'frame', 'base', 'col') for tag in self_closing_tags: Loading Loading @@ -400,7 +453,6 @@ class HTMLEqualTests(TestCase): </html>""") def test_html_contain(self): from django.test.html import parse_html # equal html contains each other dom1 = parse_html('<p>foo') dom2 = parse_html('<p>foo</p>') Loading @@ -424,7 +476,6 @@ class HTMLEqualTests(TestCase): self.assertTrue(dom1 in dom2) def test_count(self): from django.test.html import parse_html # equal html contains each other one time dom1 = parse_html('<p>foo') dom2 = parse_html('<p>foo</p>') Loading Loading @@ -459,7 +510,6 @@ class HTMLEqualTests(TestCase): self.assertEqual(dom2.count(dom1), 0) def test_parsing_errors(self): from django.test.html import HTMLParseError, parse_html with self.assertRaises(AssertionError): self.assertHTMLEqual('<p>', '') with self.assertRaises(AssertionError): Loading Loading @@ -488,7 +538,6 @@ class HTMLEqualTests(TestCase): self.assertContains(response, '<p "whats" that>') def test_unicode_handling(self): from django.http import HttpResponse response = HttpResponse('<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>') self.assertContains(response, '<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>', html=True) Loading Loading
django/test/testcases.py +11 −24 Original line number Diff line number Diff line Loading @@ -24,31 +24,30 @@ from django.core.exceptions import ValidationError, ImproperlyConfigured from django.core.handlers.wsgi import WSGIHandler from django.core.management import call_command from django.core.management.color import no_style from django.core.signals import request_started from django.core.servers.basehttp import (WSGIRequestHandler, WSGIServer, WSGIServerException) from django.core.urlresolvers import clear_url_caches from django.core.validators import EMPTY_VALUES from django.db import (transaction, connection, connections, DEFAULT_DB_ALIAS, reset_queries) from django.db import connection, connections, DEFAULT_DB_ALIAS, transaction from django.forms.fields import CharField from django.http import QueryDict from django.test import _doctest as doctest from django.test.client import Client from django.test.html import HTMLParseError, parse_html from django.test.signals import template_rendered from django.test.utils import (override_settings, compare_xml, strip_quotes) from django.test.utils import ContextList from django.utils import unittest as ut2 from django.test.utils import (CaptureQueriesContext, ContextList, override_settings, compare_xml, strip_quotes) from django.utils import six, unittest as ut2 from django.utils.encoding import force_text from django.utils import six from django.utils.unittest import skipIf # Imported here for backward compatibility from django.utils.unittest.util import safe_repr from django.utils.unittest import skipIf from django.views.static import serve __all__ = ('DocTestRunner', 'OutputChecker', 'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature') normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s) normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)", lambda m: "Decimal(\"%s\")" % m.groups()[0], s) Loading Loading @@ -168,28 +167,17 @@ class DocTestRunner(doctest.DocTestRunner): transaction.rollback_unless_managed(using=conn) class _AssertNumQueriesContext(object): class _AssertNumQueriesContext(CaptureQueriesContext): def __init__(self, test_case, num, connection): self.test_case = test_case self.num = num self.connection = connection def __enter__(self): self.old_debug_cursor = self.connection.use_debug_cursor self.connection.use_debug_cursor = True self.starting_queries = len(self.connection.queries) request_started.disconnect(reset_queries) return self super(_AssertNumQueriesContext, self).__init__(connection) def __exit__(self, exc_type, exc_value, traceback): self.connection.use_debug_cursor = self.old_debug_cursor request_started.connect(reset_queries) if exc_type is not None: return final_queries = len(self.connection.queries) executed = final_queries - self.starting_queries super(_AssertNumQueriesContext, self).__exit__(exc_type, exc_value, traceback) executed = len(self) self.test_case.assertEqual( executed, self.num, "%d queries executed, %d expected" % ( executed, self.num Loading Loading @@ -1051,7 +1039,6 @@ class LiveServerThread(threading.Thread): http requests. """ if self.connections_override: from django.db import connections # Override this thread's database connections with the ones # provided by the main thread. for alias, conn in self.connections_override.items(): Loading
django/test/utils.py +39 −0 Original line number Diff line number Diff line Loading @@ -4,6 +4,8 @@ from xml.dom.minidom import parseString, Node from django.conf import settings, UserSettingsHolder from django.core import mail from django.core.signals import request_started from django.db import reset_queries from django.template import Template, loader, TemplateDoesNotExist from django.template.loaders import cached from django.test.signals import template_rendered, setting_changed Loading Loading @@ -339,5 +341,42 @@ def strip_quotes(want, got): got = got.strip()[2:-1] return want, got def str_prefix(s): return s % {'_': '' if six.PY3 else 'u'} class CaptureQueriesContext(object): """ Context manager that captures queries executed by the specified connection. """ def __init__(self, connection): self.connection = connection def __iter__(self): return iter(self.captured_queries) def __getitem__(self, index): return self.captured_queries[index] def __len__(self): return len(self.captured_queries) @property def captured_queries(self): return self.connection.queries[self.initial_queries:self.final_queries] def __enter__(self): self.use_debug_cursor = self.connection.use_debug_cursor self.connection.use_debug_cursor = True self.initial_queries = len(self.connection.queries) self.final_queries = None request_started.disconnect(reset_queries) return self def __exit__(self, exc_type, exc_value, traceback): self.connection.use_debug_cursor = self.use_debug_cursor request_started.connect(reset_queries) if exc_type is not None: return self.final_queries = len(self.connection.queries)
tests/test_utils/tests.py +58 −9 Original line number Diff line number Diff line # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals import warnings from django.db import connection from django.forms import EmailField, IntegerField from django.http import HttpResponse from django.template.loader import render_to_string from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test.html import HTMLParseError, parse_html from django.test.utils import CaptureQueriesContext from django.utils import six from django.utils.unittest import skip Loading Loading @@ -94,6 +98,60 @@ class AssertQuerysetEqualTests(TestCase): ) class CaptureQueriesContextManagerTests(TestCase): urls = 'test_utils.urls' def setUp(self): self.person_pk = six.text_type(Person.objects.create(name='test').pk) def test_simple(self): with CaptureQueriesContext(connection) as captured_queries: Person.objects.get(pk=self.person_pk) self.assertEqual(len(captured_queries), 1) self.assertIn(self.person_pk, captured_queries[0]['sql']) with CaptureQueriesContext(connection) as captured_queries: pass self.assertEqual(0, len(captured_queries)) def test_within(self): with CaptureQueriesContext(connection) as captured_queries: Person.objects.get(pk=self.person_pk) self.assertEqual(len(captured_queries), 1) self.assertIn(self.person_pk, captured_queries[0]['sql']) def test_nested(self): with CaptureQueriesContext(connection) as captured_queries: Person.objects.count() with CaptureQueriesContext(connection) as nested_captured_queries: Person.objects.count() self.assertEqual(1, len(nested_captured_queries)) self.assertEqual(2, len(captured_queries)) def test_failure(self): with self.assertRaises(TypeError): with CaptureQueriesContext(connection): raise TypeError def test_with_client(self): with CaptureQueriesContext(connection) as captured_queries: self.client.get("/test_utils/get_person/%s/" % self.person_pk) self.assertEqual(len(captured_queries), 1) self.assertIn(self.person_pk, captured_queries[0]['sql']) with CaptureQueriesContext(connection) as captured_queries: self.client.get("/test_utils/get_person/%s/" % self.person_pk) self.assertEqual(len(captured_queries), 1) self.assertIn(self.person_pk, captured_queries[0]['sql']) with CaptureQueriesContext(connection) as captured_queries: self.client.get("/test_utils/get_person/%s/" % self.person_pk) self.client.get("/test_utils/get_person/%s/" % self.person_pk) self.assertEqual(len(captured_queries), 2) self.assertIn(self.person_pk, captured_queries[0]['sql']) self.assertIn(self.person_pk, captured_queries[1]['sql']) class AssertNumQueriesContextManagerTests(TestCase): urls = 'test_utils.urls' Loading Loading @@ -219,7 +277,6 @@ class SaveRestoreWarningState(TestCase): # In reality this test could be satisfied by many broken implementations # of save_warnings_state/restore_warnings_state (e.g. just # warnings.resetwarnings()) , but it is difficult to test more. import warnings with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) Loading @@ -245,7 +302,6 @@ class SaveRestoreWarningState(TestCase): class HTMLEqualTests(TestCase): def test_html_parser(self): from django.test.html import parse_html element = parse_html('<div><p>Hello</p></div>') self.assertEqual(len(element.children), 1) self.assertEqual(element.children[0].name, 'p') Loading @@ -259,7 +315,6 @@ class HTMLEqualTests(TestCase): self.assertEqual(dom[0], 'foo') def test_parse_html_in_script(self): from django.test.html import parse_html parse_html('<script>var a = "<p" + ">";</script>'); parse_html(''' <script> Loading @@ -275,8 +330,6 @@ class HTMLEqualTests(TestCase): self.assertEqual(dom.children[0], "<p>foo</p> '</scr'+'ipt>' <span>bar</span>") def test_self_closing_tags(self): from django.test.html import parse_html self_closing_tags = ('br' , 'hr', 'input', 'img', 'meta', 'spacer', 'link', 'frame', 'base', 'col') for tag in self_closing_tags: Loading Loading @@ -400,7 +453,6 @@ class HTMLEqualTests(TestCase): </html>""") def test_html_contain(self): from django.test.html import parse_html # equal html contains each other dom1 = parse_html('<p>foo') dom2 = parse_html('<p>foo</p>') Loading @@ -424,7 +476,6 @@ class HTMLEqualTests(TestCase): self.assertTrue(dom1 in dom2) def test_count(self): from django.test.html import parse_html # equal html contains each other one time dom1 = parse_html('<p>foo') dom2 = parse_html('<p>foo</p>') Loading Loading @@ -459,7 +510,6 @@ class HTMLEqualTests(TestCase): self.assertEqual(dom2.count(dom1), 0) def test_parsing_errors(self): from django.test.html import HTMLParseError, parse_html with self.assertRaises(AssertionError): self.assertHTMLEqual('<p>', '') with self.assertRaises(AssertionError): Loading Loading @@ -488,7 +538,6 @@ class HTMLEqualTests(TestCase): self.assertContains(response, '<p "whats" that>') def test_unicode_handling(self): from django.http import HttpResponse response = HttpResponse('<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>') self.assertContains(response, '<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>', html=True) Loading