Commit 0a503110 authored by Ramiro Morales's avatar Ramiro Morales
Browse files

Fixed #20004 -- Moved non DB-related assertions to SimpleTestCase.

Thanks zalew for the suggestion and work on a patch.

Also updated, tweaked and fixed testing documentation.
parent 69523c1b
Loading
Loading
Loading
Loading
+259 −256
Original line number Diff line number Diff line
@@ -231,6 +231,10 @@ class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext):

class SimpleTestCase(ut2.TestCase):

    # The class we'll use for the test client self.client.
    # Can be overridden in derived classes.
    client_class = Client

    _warn_txt = ("save_warnings_state/restore_warnings_state "
        "django.test.*TestCase methods are deprecated. Use Python's "
        "warnings.catch_warnings context manager instead.")
@@ -263,249 +267,17 @@ class SimpleTestCase(ut2.TestCase):
                result.addError(self, sys.exc_info())
                return

    def _pre_setup(self):
        pass

    def _post_teardown(self):
        pass

    def save_warnings_state(self):
        """
        Saves the state of the warnings module
        """
        warnings.warn(self._warn_txt, DeprecationWarning, stacklevel=2)
        self._warnings_state = warnings.filters[:]

    def restore_warnings_state(self):
        """
        Restores the state of the warnings module to the state
        saved by save_warnings_state()
        """
        warnings.warn(self._warn_txt, DeprecationWarning, stacklevel=2)
        warnings.filters = self._warnings_state[:]

    def settings(self, **kwargs):
        """
        A context manager that temporarily sets a setting and reverts
        back to the original value when exiting the context.
        """
        return override_settings(**kwargs)

    def assertRaisesMessage(self, expected_exception, expected_message,
                           callable_obj=None, *args, **kwargs):
        """
        Asserts that the message in a raised exception matches the passed
        value.

        Args:
            expected_exception: Exception class expected to be raised.
            expected_message: expected error message string value.
            callable_obj: Function to be called.
            args: Extra args.
            kwargs: Extra kwargs.
        """
        return six.assertRaisesRegex(self, expected_exception,
                re.escape(expected_message), callable_obj, *args, **kwargs)

    def assertFieldOutput(self, fieldclass, valid, invalid, field_args=None,
            field_kwargs=None, empty_value=''):
        """
        Asserts that a form field behaves correctly with various inputs.

        Args:
            fieldclass: the class of the field to be tested.
            valid: a dictionary mapping valid inputs to their expected
                    cleaned values.
            invalid: a dictionary mapping invalid inputs to one or more
                    raised error messages.
            field_args: the args passed to instantiate the field
            field_kwargs: the kwargs passed to instantiate the field
            empty_value: the expected clean output for inputs in empty_values

        """
        if field_args is None:
            field_args = []
        if field_kwargs is None:
            field_kwargs = {}
        required = fieldclass(*field_args, **field_kwargs)
        optional = fieldclass(*field_args,
                              **dict(field_kwargs, required=False))
        # test valid inputs
        for input, output in valid.items():
            self.assertEqual(required.clean(input), output)
            self.assertEqual(optional.clean(input), output)
        # test invalid inputs
        for input, errors in invalid.items():
            with self.assertRaises(ValidationError) as context_manager:
                required.clean(input)
            self.assertEqual(context_manager.exception.messages, errors)

            with self.assertRaises(ValidationError) as context_manager:
                optional.clean(input)
            self.assertEqual(context_manager.exception.messages, errors)
        # test required inputs
        error_required = [force_text(required.error_messages['required'])]
        for e in required.empty_values:
            with self.assertRaises(ValidationError) as context_manager:
                required.clean(e)
            self.assertEqual(context_manager.exception.messages,
                             error_required)
            self.assertEqual(optional.clean(e), empty_value)
        # test that max_length and min_length are always accepted
        if issubclass(fieldclass, CharField):
            field_kwargs.update({'min_length':2, 'max_length':20})
            self.assertTrue(isinstance(fieldclass(*field_args, **field_kwargs),
                                       fieldclass))

    def assertHTMLEqual(self, html1, html2, msg=None):
        """
        Asserts that two HTML snippets are semantically the same.
        Whitespace in most cases is ignored, and attribute ordering is not
        significant. The passed-in arguments must be valid HTML.
        """
        dom1 = assert_and_parse_html(self, html1, msg,
            'First argument is not valid HTML:')
        dom2 = assert_and_parse_html(self, html2, msg,
            'Second argument is not valid HTML:')

        if dom1 != dom2:
            standardMsg = '%s != %s' % (
                safe_repr(dom1, True), safe_repr(dom2, True))
            diff = ('\n' + '\n'.join(difflib.ndiff(
                           six.text_type(dom1).splitlines(),
                           six.text_type(dom2).splitlines())))
            standardMsg = self._truncateMessage(standardMsg, diff)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertHTMLNotEqual(self, html1, html2, msg=None):
        """Asserts that two HTML snippets are not semantically equivalent."""
        dom1 = assert_and_parse_html(self, html1, msg,
            'First argument is not valid HTML:')
        dom2 = assert_and_parse_html(self, html2, msg,
            'Second argument is not valid HTML:')

        if dom1 == dom2:
            standardMsg = '%s == %s' % (
                safe_repr(dom1, True), safe_repr(dom2, True))
            self.fail(self._formatMessage(msg, standardMsg))

    def assertInHTML(self, needle, haystack, count = None, msg_prefix=''):
        needle = assert_and_parse_html(self, needle, None,
            'First argument is not valid HTML:')
        haystack = assert_and_parse_html(self, haystack, None,
            'Second argument is not valid HTML:')
        real_count = haystack.count(needle)
        if count is not None:
            self.assertEqual(real_count, count,
                msg_prefix + "Found %d instances of '%s' in response"
                " (expected %d)" % (real_count, needle, count))
        else:
            self.assertTrue(real_count != 0,
                msg_prefix + "Couldn't find '%s' in response" % needle)

    def assertJSONEqual(self, raw, expected_data, msg=None):
        try:
            data = json.loads(raw)
        except ValueError:
            self.fail("First argument is not valid JSON: %r" % raw)
        if isinstance(expected_data, six.string_types):
            try:
                expected_data = json.loads(expected_data)
            except ValueError:
                self.fail("Second argument is not valid JSON: %r" % expected_data)
        self.assertEqual(data, expected_data, msg=msg)

    def assertXMLEqual(self, xml1, xml2, msg=None):
        """
        Asserts that two XML snippets are semantically the same.
        Whitespace in most cases is ignored, and attribute ordering is not
        significant. The passed-in arguments must be valid XML.
        """
        try:
            result = compare_xml(xml1, xml2)
        except Exception as e:
            standardMsg = 'First or second argument is not valid XML\n%s' % e
            self.fail(self._formatMessage(msg, standardMsg))
        else:
            if not result:
                standardMsg = '%s != %s' % (safe_repr(xml1, True), safe_repr(xml2, True))
                self.fail(self._formatMessage(msg, standardMsg))

    def assertXMLNotEqual(self, xml1, xml2, msg=None):
        """
        Asserts that two XML snippets are not semantically equivalent.
        Whitespace in most cases is ignored, and attribute ordering is not
        significant. The passed-in arguments must be valid XML.
        """
        try:
            result = compare_xml(xml1, xml2)
        except Exception as e:
            standardMsg = 'First or second argument is not valid XML\n%s' % e
            self.fail(self._formatMessage(msg, standardMsg))
        else:
            if result:
                standardMsg = '%s == %s' % (safe_repr(xml1, True), safe_repr(xml2, True))
                self.fail(self._formatMessage(msg, standardMsg))


class TransactionTestCase(SimpleTestCase):

    # The class we'll use for the test client self.client.
    # Can be overridden in derived classes.
    client_class = Client

    # Subclasses can ask for resetting of auto increment sequence before each
    # test case
    reset_sequences = False

    def _pre_setup(self):
        """Performs any pre-test setup. This includes:

            * Flushing the database.
            * If the Test Case class has a 'fixtures' member, installing the
              named fixtures.
           * If the Test Case class has a 'urls' member, replace the
             ROOT_URLCONF with it.
           * Clearing the mail test outbox.
        """
        self.client = self.client_class()
        self._fixture_setup()
        self._urlconf_setup()
        mail.outbox = []

    def _databases_names(self, include_mirrors=True):
        # If the test case has a multi_db=True flag, act on all databases,
        # including mirrors or not. Otherwise, just on the default DB.
        if getattr(self, 'multi_db', False):
            return [alias for alias in connections
                    if include_mirrors or not connections[alias].settings_dict['TEST_MIRROR']]
        else:
            return [DEFAULT_DB_ALIAS]

    def _reset_sequences(self, db_name):
        conn = connections[db_name]
        if conn.features.supports_sequence_reset:
            sql_list = \
                conn.ops.sequence_reset_by_name_sql(no_style(),
                                                    conn.introspection.sequence_list())
            if sql_list:
                with transaction.commit_on_success_unless_managed(using=db_name):
                    cursor = conn.cursor()
                    for sql in sql_list:
                        cursor.execute(sql)

    def _fixture_setup(self):
        for db_name in self._databases_names(include_mirrors=False):
            # Reset sequences
            if self.reset_sequences:
                self._reset_sequences(db_name)

            if hasattr(self, 'fixtures'):
                # We have to use this slightly awkward syntax due to the fact
                # that we're using *args and **kwargs together.
                call_command('loaddata', *self.fixtures,
                             **{'verbosity': 0, 'database': db_name, 'skip_validation': True})

    def _urlconf_setup(self):
        set_urlconf(None)
        if hasattr(self, 'urls'):
@@ -514,28 +286,7 @@ class TransactionTestCase(SimpleTestCase):
            clear_url_caches()

    def _post_teardown(self):
        """ Performs any post-test things. This includes:

            * Putting back the original ROOT_URLCONF if it was changed.
            * Force closing the connection, so that the next test gets
              a clean cursor.
        """
        self._fixture_teardown()
        self._urlconf_teardown()
        # Some DB cursors include SQL statements as part of cursor
        # creation. If you have a test that does rollback, the effect
        # of these statements is lost, which can effect the operation
        # of tests (e.g., losing a timezone setting causing objects to
        # be created with the wrong time).
        # To make sure this doesn't happen, get a clean connection at the
        # start of every test.
        for conn in connections.all():
            conn.close()

    def _fixture_teardown(self):
        for db in self._databases_names(include_mirrors=False):
            call_command('flush', verbosity=0, interactive=False, database=db,
                         skip_validation=True, reset_sequences=False)

    def _urlconf_teardown(self):
        set_urlconf(None)
@@ -543,6 +294,28 @@ class TransactionTestCase(SimpleTestCase):
            settings.ROOT_URLCONF = self._old_root_urlconf
            clear_url_caches()

    def save_warnings_state(self):
        """
        Saves the state of the warnings module
        """
        warnings.warn(self._warn_txt, DeprecationWarning, stacklevel=2)
        self._warnings_state = warnings.filters[:]

    def restore_warnings_state(self):
        """
        Restores the state of the warnings module to the state
        saved by save_warnings_state()
        """
        warnings.warn(self._warn_txt, DeprecationWarning, stacklevel=2)
        warnings.filters = self._warnings_state[:]

    def settings(self, **kwargs):
        """
        A context manager that temporarily sets a setting and reverts
        back to the original value when exiting the context.
        """
        return override_settings(**kwargs)

    def assertRedirects(self, response, expected_url, status_code=302,
                        target_status_code=200, host=None, msg_prefix=''):
        """Asserts that a response redirected to a specific URL, and that the
@@ -787,6 +560,236 @@ class TransactionTestCase(SimpleTestCase):
            msg_prefix + "Template '%s' was used unexpectedly in rendering"
            " the response" % template_name)

    def assertRaisesMessage(self, expected_exception, expected_message,
                           callable_obj=None, *args, **kwargs):
        """
        Asserts that the message in a raised exception matches the passed
        value.

        Args:
            expected_exception: Exception class expected to be raised.
            expected_message: expected error message string value.
            callable_obj: Function to be called.
            args: Extra args.
            kwargs: Extra kwargs.
        """
        return six.assertRaisesRegex(self, expected_exception,
                re.escape(expected_message), callable_obj, *args, **kwargs)

    def assertFieldOutput(self, fieldclass, valid, invalid, field_args=None,
            field_kwargs=None, empty_value=''):
        """
        Asserts that a form field behaves correctly with various inputs.

        Args:
            fieldclass: the class of the field to be tested.
            valid: a dictionary mapping valid inputs to their expected
                    cleaned values.
            invalid: a dictionary mapping invalid inputs to one or more
                    raised error messages.
            field_args: the args passed to instantiate the field
            field_kwargs: the kwargs passed to instantiate the field
            empty_value: the expected clean output for inputs in empty_values

        """
        if field_args is None:
            field_args = []
        if field_kwargs is None:
            field_kwargs = {}
        required = fieldclass(*field_args, **field_kwargs)
        optional = fieldclass(*field_args,
                              **dict(field_kwargs, required=False))
        # test valid inputs
        for input, output in valid.items():
            self.assertEqual(required.clean(input), output)
            self.assertEqual(optional.clean(input), output)
        # test invalid inputs
        for input, errors in invalid.items():
            with self.assertRaises(ValidationError) as context_manager:
                required.clean(input)
            self.assertEqual(context_manager.exception.messages, errors)

            with self.assertRaises(ValidationError) as context_manager:
                optional.clean(input)
            self.assertEqual(context_manager.exception.messages, errors)
        # test required inputs
        error_required = [force_text(required.error_messages['required'])]
        for e in required.empty_values:
            with self.assertRaises(ValidationError) as context_manager:
                required.clean(e)
            self.assertEqual(context_manager.exception.messages,
                             error_required)
            self.assertEqual(optional.clean(e), empty_value)
        # test that max_length and min_length are always accepted
        if issubclass(fieldclass, CharField):
            field_kwargs.update({'min_length':2, 'max_length':20})
            self.assertTrue(isinstance(fieldclass(*field_args, **field_kwargs),
                                       fieldclass))

    def assertHTMLEqual(self, html1, html2, msg=None):
        """
        Asserts that two HTML snippets are semantically the same.
        Whitespace in most cases is ignored, and attribute ordering is not
        significant. The passed-in arguments must be valid HTML.
        """
        dom1 = assert_and_parse_html(self, html1, msg,
            'First argument is not valid HTML:')
        dom2 = assert_and_parse_html(self, html2, msg,
            'Second argument is not valid HTML:')

        if dom1 != dom2:
            standardMsg = '%s != %s' % (
                safe_repr(dom1, True), safe_repr(dom2, True))
            diff = ('\n' + '\n'.join(difflib.ndiff(
                           six.text_type(dom1).splitlines(),
                           six.text_type(dom2).splitlines())))
            standardMsg = self._truncateMessage(standardMsg, diff)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertHTMLNotEqual(self, html1, html2, msg=None):
        """Asserts that two HTML snippets are not semantically equivalent."""
        dom1 = assert_and_parse_html(self, html1, msg,
            'First argument is not valid HTML:')
        dom2 = assert_and_parse_html(self, html2, msg,
            'Second argument is not valid HTML:')

        if dom1 == dom2:
            standardMsg = '%s == %s' % (
                safe_repr(dom1, True), safe_repr(dom2, True))
            self.fail(self._formatMessage(msg, standardMsg))

    def assertInHTML(self, needle, haystack, count=None, msg_prefix=''):
        needle = assert_and_parse_html(self, needle, None,
            'First argument is not valid HTML:')
        haystack = assert_and_parse_html(self, haystack, None,
            'Second argument is not valid HTML:')
        real_count = haystack.count(needle)
        if count is not None:
            self.assertEqual(real_count, count,
                msg_prefix + "Found %d instances of '%s' in response"
                " (expected %d)" % (real_count, needle, count))
        else:
            self.assertTrue(real_count != 0,
                msg_prefix + "Couldn't find '%s' in response" % needle)

    def assertJSONEqual(self, raw, expected_data, msg=None):
        try:
            data = json.loads(raw)
        except ValueError:
            self.fail("First argument is not valid JSON: %r" % raw)
        if isinstance(expected_data, six.string_types):
            try:
                expected_data = json.loads(expected_data)
            except ValueError:
                self.fail("Second argument is not valid JSON: %r" % expected_data)
        self.assertEqual(data, expected_data, msg=msg)

    def assertXMLEqual(self, xml1, xml2, msg=None):
        """
        Asserts that two XML snippets are semantically the same.
        Whitespace in most cases is ignored, and attribute ordering is not
        significant. The passed-in arguments must be valid XML.
        """
        try:
            result = compare_xml(xml1, xml2)
        except Exception as e:
            standardMsg = 'First or second argument is not valid XML\n%s' % e
            self.fail(self._formatMessage(msg, standardMsg))
        else:
            if not result:
                standardMsg = '%s != %s' % (safe_repr(xml1, True), safe_repr(xml2, True))
                self.fail(self._formatMessage(msg, standardMsg))

    def assertXMLNotEqual(self, xml1, xml2, msg=None):
        """
        Asserts that two XML snippets are not semantically equivalent.
        Whitespace in most cases is ignored, and attribute ordering is not
        significant. The passed-in arguments must be valid XML.
        """
        try:
            result = compare_xml(xml1, xml2)
        except Exception as e:
            standardMsg = 'First or second argument is not valid XML\n%s' % e
            self.fail(self._formatMessage(msg, standardMsg))
        else:
            if result:
                standardMsg = '%s == %s' % (safe_repr(xml1, True), safe_repr(xml2, True))
                self.fail(self._formatMessage(msg, standardMsg))


class TransactionTestCase(SimpleTestCase):

    # Subclasses can ask for resetting of auto increment sequence before each
    # test case
    reset_sequences = False

    def _pre_setup(self):
        """Performs any pre-test setup. This includes:

           * Flushing the database.
           * If the Test Case class has a 'fixtures' member, installing the
             named fixtures.
        """
        super(TransactionTestCase, self)._pre_setup()
        self._fixture_setup()

    def _databases_names(self, include_mirrors=True):
        # If the test case has a multi_db=True flag, act on all databases,
        # including mirrors or not. Otherwise, just on the default DB.
        if getattr(self, 'multi_db', False):
            return [alias for alias in connections
                    if include_mirrors or not connections[alias].settings_dict['TEST_MIRROR']]
        else:
            return [DEFAULT_DB_ALIAS]

    def _reset_sequences(self, db_name):
        conn = connections[db_name]
        if conn.features.supports_sequence_reset:
            sql_list = \
                conn.ops.sequence_reset_by_name_sql(no_style(),
                                                    conn.introspection.sequence_list())
            if sql_list:
                with transaction.commit_on_success_unless_managed(using=db_name):
                    cursor = conn.cursor()
                    for sql in sql_list:
                        cursor.execute(sql)

    def _fixture_setup(self):
        for db_name in self._databases_names(include_mirrors=False):
            # Reset sequences
            if self.reset_sequences:
                self._reset_sequences(db_name)

            if hasattr(self, 'fixtures'):
                # We have to use this slightly awkward syntax due to the fact
                # that we're using *args and **kwargs together.
                call_command('loaddata', *self.fixtures,
                             **{'verbosity': 0, 'database': db_name, 'skip_validation': True})

    def _post_teardown(self):
        """Performs any post-test things. This includes:

           * Putting back the original ROOT_URLCONF if it was changed.
           * Force closing the connection, so that the next test gets
             a clean cursor.
        """
        self._fixture_teardown()
        super(TransactionTestCase, self)._post_teardown()
        # Some DB cursors include SQL statements as part of cursor
        # creation. If you have a test that does rollback, the effect
        # of these statements is lost, which can effect the operation
        # of tests (e.g., losing a timezone setting causing objects to
        # be created with the wrong time).
        # To make sure this doesn't happen, get a clean connection at the
        # start of every test.
        for conn in connections.all():
            conn.close()

    def _fixture_teardown(self):
        for db_name in self._databases_names(include_mirrors=False):
            call_command('flush', verbosity=0, interactive=False, database=db_name,
                         skip_validation=True, reset_sequences=False)

    def assertQuerysetEqual(self, qs, values, transform=repr, ordered=True):
        items = six.moves.map(transform, qs)
        if not ordered:
@@ -841,14 +844,14 @@ class TestCase(TransactionTestCase):
        # Remove this when the legacy transaction management goes away.
        disable_transaction_methods()

        for db in self._databases_names(include_mirrors=False):
        for db_name in self._databases_names(include_mirrors=False):
            if hasattr(self, 'fixtures'):
                try:
                    call_command('loaddata', *self.fixtures,
                                 **{
                                    'verbosity': 0,
                                    'commit': False,
                                    'database': db,
                                    'database': db_name,
                                    'skip_validation': True,
                                 })
                except Exception:
+2 −2
Original line number Diff line number Diff line
@@ -503,8 +503,8 @@ of the process of creating polls.
message: "No polls are available." and verifies the ``latest_poll_list`` is
empty. Note that the :class:`django.test.TestCase` class provides some
additional assertion methods. In these examples, we use
:meth:`~django.test.TestCase.assertContains()` and
:meth:`~django.test.TestCase.assertQuerysetEqual()`.
:meth:`~django.test.SimpleTestCase.assertContains()` and
:meth:`~django.test.TransactionTestCase.assertQuerysetEqual()`.

In ``test_index_view_with_a_past_poll``, we create a poll and verify that it
appears in the list.
+1 −1

File changed.

Preview size limit exceeded, changes collapsed.

+1 −1

File changed.

Preview size limit exceeded, changes collapsed.

+1 −1

File changed.

Preview size limit exceeded, changes collapsed.

Loading