Commit 51fed81e authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

[1.7.x] Reorganized backends tests.

This reduces the number of explicit vendor checks.

Backport of d6672631 from master
parent e244e456
Loading
Loading
Loading
Loading
+146 −185
Original line number Diff line number Diff line
@@ -32,6 +32,7 @@ from . import models


class DummyBackendTest(TestCase):

    def test_no_databases(self):
        """
        Test that empty DATABASES setting default to the dummy backend.
@@ -42,18 +43,15 @@ class DummyBackendTest(TestCase):
            'django.db.backends.dummy')


class OracleChecks(unittest.TestCase):
@unittest.skipUnless(connection.vendor == 'oracle', "Test only for Oracle")
class OracleTests(unittest.TestCase):

    @unittest.skipUnless(connection.vendor == 'oracle',
                         "No need to check Oracle quote_name semantics")
    def test_quote_name(self):
        # Check that '%' chars are escaped for query execution.
        name = '"SOME%NAME"'
        quoted_name = connection.ops.quote_name(name)
        self.assertEqual(quoted_name % (), name)

    @unittest.skipUnless(connection.vendor == 'oracle',
                         "No need to check Oracle cursor semantics")
    def test_dbms_session(self):
        # If the backend is Oracle, test that we can call a standard
        # stored procedure through our cursor wrapper.
@@ -63,8 +61,6 @@ class OracleChecks(unittest.TestCase):
            cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
                            [convert_unicode('_django_testing!')])

    @unittest.skipUnless(connection.vendor == 'oracle',
                         "No need to check Oracle cursor semantics")
    def test_cursor_var(self):
        # If the backend is Oracle, test that we can pass cursor variables
        # as query parameters.
@@ -75,8 +71,6 @@ class OracleChecks(unittest.TestCase):
            cursor.execute("BEGIN %s := 'X'; END; ", [var])
            self.assertEqual(var.getvalue(), 'X')

    @unittest.skipUnless(connection.vendor == 'oracle',
                         "No need to check Oracle cursor semantics")
    def test_long_string(self):
        # If the backend is Oracle, test that we can save a text longer
        # than 4000 chars and read it properly
@@ -89,8 +83,6 @@ class OracleChecks(unittest.TestCase):
            self.assertEqual(long_str, row[0].read())
            cursor.execute('DROP TABLE ltext')

    @unittest.skipUnless(connection.vendor == 'oracle',
                         "No need to check Oracle connection semantics")
    def test_client_encoding(self):
        # If the backend is Oracle, test that the client encoding is set
        # correctly.  This was broken under Cygwin prior to r14781.
@@ -98,8 +90,6 @@ class OracleChecks(unittest.TestCase):
        self.assertEqual(connection.connection.encoding, "UTF-8")
        self.assertEqual(connection.connection.nencoding, "UTF-8")

    @unittest.skipUnless(connection.vendor == 'oracle',
                         "No need to check Oracle connection semantics")
    def test_order_of_nls_parameters(self):
        # an 'almost right' datetime should work with configured
        # NLS parameters as per #18465.
@@ -111,11 +101,11 @@ class OracleChecks(unittest.TestCase):
            self.assertEqual(cursor.fetchone()[0], 1)


@unittest.skipUnless(connection.vendor == 'sqlite', "Test only for SQLite")
class SQLiteTests(TestCase):

    longMessage = True

    @unittest.skipUnless(connection.vendor == 'sqlite',
                        "Test valid only for SQLite")
    def test_autoincrement(self):
        """
        Check that auto_increment fields are created with the AUTOINCREMENT
@@ -129,10 +119,147 @@ class SQLiteTests(TestCase):
            match.group(1), "Wrong SQL used to create an auto-increment "
            "column on SQLite")

    def test_aggregation(self):
        """
        #19360: Raise NotImplementedError when aggregating on date/time fields.
        """
        for aggregate in (Sum, Avg, Variance, StdDev):
            self.assertRaises(NotImplementedError,
                models.Item.objects.all().aggregate, aggregate('time'))
            self.assertRaises(NotImplementedError,
                models.Item.objects.all().aggregate, aggregate('date'))
            self.assertRaises(NotImplementedError,
                models.Item.objects.all().aggregate, aggregate('last_modified'))


    def test_convert_values_to_handle_null_value(self):
        convert_values = DatabaseOperations(connection).convert_values
        self.assertIsNone(convert_values(None, AutoField(primary_key=True)))
        self.assertIsNone(convert_values(None, DateField()))
        self.assertIsNone(convert_values(None, DateTimeField()))
        self.assertIsNone(convert_values(None, DecimalField()))
        self.assertIsNone(convert_values(None, IntegerField()))
        self.assertIsNone(convert_values(None, TimeField()))


@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
class PostgreSQLTests(TestCase):

    def assert_parses(self, version_string, version):
        self.assertEqual(pg_version._parse_version(version_string), version)

    def test_parsing(self):
        """Test PostgreSQL version parsing from `SELECT version()` output"""
        self.assert_parses("PostgreSQL 8.3 beta4", 80300)
        self.assert_parses("PostgreSQL 8.3", 80300)
        self.assert_parses("EnterpriseDB 8.3", 80300)
        self.assert_parses("PostgreSQL 8.3.6", 80306)
        self.assert_parses("PostgreSQL 8.4beta1", 80400)
        self.assert_parses("PostgreSQL 8.3.1 on i386-apple-darwin9.2.2, compiled by GCC i686-apple-darwin9-gcc-4.0.1 (GCC) 4.0.1 (Apple Inc. build 5478)", 80301)

    def test_version_detection(self):
        """Test PostgreSQL version detection"""

        # Helper mocks
        class CursorMock(object):
            "Very simple mock of DB-API cursor"
            def execute(self, arg):
                pass

            def fetchone(self):
                return ["PostgreSQL 8.3"]

            def __enter__(self):
                return self

            def __exit__(self, type, value, traceback):
                pass

        class OlderConnectionMock(object):
            "Mock of psycopg2 (< 2.0.12) connection"
            def cursor(self):
                return CursorMock()

        # psycopg2 < 2.0.12 code path
        conn = OlderConnectionMock()
        self.assertEqual(pg_version.get_version(conn), 80300)

    def test_connect_and_rollback(self):
        """
        PostgreSQL shouldn't roll back SET TIME ZONE, even if the first
        transaction is rolled back (#17062).
        """
        databases = copy.deepcopy(settings.DATABASES)
        new_connections = ConnectionHandler(databases)
        new_connection = new_connections[DEFAULT_DB_ALIAS]
        try:
            # Ensure the database default time zone is different than
            # the time zone in new_connection.settings_dict. We can
            # get the default time zone by reset & show.
            cursor = new_connection.cursor()
            cursor.execute("RESET TIMEZONE")
            cursor.execute("SHOW TIMEZONE")
            db_default_tz = cursor.fetchone()[0]
            new_tz = 'Europe/Paris' if db_default_tz == 'UTC' else 'UTC'
            new_connection.close()

            # Fetch a new connection with the new_tz as default
            # time zone, run a query and rollback.
            new_connection.settings_dict['TIME_ZONE'] = new_tz
            new_connection.enter_transaction_management()
            cursor = new_connection.cursor()
            new_connection.rollback()

            # Now let's see if the rollback rolled back the SET TIME ZONE.
            cursor.execute("SHOW TIMEZONE")
            tz = cursor.fetchone()[0]
            self.assertEqual(new_tz, tz)
        finally:
            new_connection.close()

    def test_connect_non_autocommit(self):
        """
        The connection wrapper shouldn't believe that autocommit is enabled
        after setting the time zone when AUTOCOMMIT is False (#21452).
        """
        databases = copy.deepcopy(settings.DATABASES)
        databases[DEFAULT_DB_ALIAS]['AUTOCOMMIT'] = False
        new_connections = ConnectionHandler(databases)
        new_connection = new_connections[DEFAULT_DB_ALIAS]
        try:
            # Open a database connection.
            new_connection.cursor()
            self.assertFalse(new_connection.get_autocommit())
        finally:
            new_connection.close()

    def _select(self, val):
        with connection.cursor() as cursor:
            cursor.execute("SELECT %s", (val,))
            return cursor.fetchone()[0]

    def test_select_ascii_array(self):
        a = ["awef"]
        b = self._select(a)
        self.assertEqual(a[0], b[0])

    def test_select_unicode_array(self):
        a = ["ᄲawef"]
        b = self._select(a)
        self.assertEqual(a[0], b[0])

    def test_lookup_cast(self):
        from django.db.backends.postgresql_psycopg2.operations import DatabaseOperations

        do = DatabaseOperations(connection=None)
        for lookup in ('iexact', 'contains', 'icontains', 'startswith',
                       'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'):
            self.assertIn('::text', do.lookup_cast(lookup))


@unittest.skipUnless(connection.vendor == 'mysql', "Test only for MySQL")
class MySQLTests(TestCase):
    @unittest.skipUnless(connection.vendor == 'mysql',
                        "Test valid only for MySQL")

    def test_autoincrement(self):
        """
        Check that auto_increment fields are reset correctly by sql_flush().
@@ -226,6 +353,7 @@ class LastExecutedQueryTest(TestCase):


class ParameterHandlingTest(TestCase):

    def test_bad_parameter_count(self):
        "An executemany call with too many/not enough parameters will raise an exception (Refs #12612)"
        cursor = connection.cursor()
@@ -286,6 +414,7 @@ class LongNameTest(TestCase):


class SequenceResetTest(TestCase):

    def test_generic_relation(self):
        "Sequence names are correct when resetting generic relations (Ref #13941)"
        # Create an object with a manually specified PK
@@ -303,105 +432,6 @@ class SequenceResetTest(TestCase):
        self.assertTrue(obj.pk > 10)


class PostgresVersionTest(TestCase):
    def assert_parses(self, version_string, version):
        self.assertEqual(pg_version._parse_version(version_string), version)

    def test_parsing(self):
        """Test PostgreSQL version parsing from `SELECT version()` output"""
        self.assert_parses("PostgreSQL 8.3 beta4", 80300)
        self.assert_parses("PostgreSQL 8.3", 80300)
        self.assert_parses("EnterpriseDB 8.3", 80300)
        self.assert_parses("PostgreSQL 8.3.6", 80306)
        self.assert_parses("PostgreSQL 8.4beta1", 80400)
        self.assert_parses("PostgreSQL 8.3.1 on i386-apple-darwin9.2.2, compiled by GCC i686-apple-darwin9-gcc-4.0.1 (GCC) 4.0.1 (Apple Inc. build 5478)", 80301)

    def test_version_detection(self):
        """Test PostgreSQL version detection"""

        # Helper mocks
        class CursorMock(object):
            "Very simple mock of DB-API cursor"
            def execute(self, arg):
                pass

            def fetchone(self):
                return ["PostgreSQL 8.3"]

            def __enter__(self):
                return self

            def __exit__(self, type, value, traceback):
                pass

        class OlderConnectionMock(object):
            "Mock of psycopg2 (< 2.0.12) connection"
            def cursor(self):
                return CursorMock()

        # psycopg2 < 2.0.12 code path
        conn = OlderConnectionMock()
        self.assertEqual(pg_version.get_version(conn), 80300)


class PostgresNewConnectionTests(TestCase):

    @unittest.skipUnless(
        connection.vendor == 'postgresql',
        "This test applies only to PostgreSQL")
    def test_connect_and_rollback(self):
        """
        PostgreSQL shouldn't roll back SET TIME ZONE, even if the first
        transaction is rolled back (#17062).
        """
        databases = copy.deepcopy(settings.DATABASES)
        new_connections = ConnectionHandler(databases)
        new_connection = new_connections[DEFAULT_DB_ALIAS]
        try:
            # Ensure the database default time zone is different than
            # the time zone in new_connection.settings_dict. We can
            # get the default time zone by reset & show.
            cursor = new_connection.cursor()
            cursor.execute("RESET TIMEZONE")
            cursor.execute("SHOW TIMEZONE")
            db_default_tz = cursor.fetchone()[0]
            new_tz = 'Europe/Paris' if db_default_tz == 'UTC' else 'UTC'
            new_connection.close()

            # Fetch a new connection with the new_tz as default
            # time zone, run a query and rollback.
            new_connection.settings_dict['TIME_ZONE'] = new_tz
            new_connection.enter_transaction_management()
            cursor = new_connection.cursor()
            new_connection.rollback()

            # Now let's see if the rollback rolled back the SET TIME ZONE.
            cursor.execute("SHOW TIMEZONE")
            tz = cursor.fetchone()[0]
            self.assertEqual(new_tz, tz)
        finally:
            new_connection.close()

    @unittest.skipUnless(
        connection.vendor == 'postgresql',
        "This test applies only to PostgreSQL")
    def test_connect_non_autocommit(self):
        """
        The connection wrapper shouldn't believe that autocommit is enabled
        after setting the time zone when AUTOCOMMIT is False (#21452).
        """
        databases = copy.deepcopy(settings.DATABASES)
        databases[DEFAULT_DB_ALIAS]['AUTOCOMMIT'] = False
        new_connections = ConnectionHandler(databases)
        new_connection = new_connections[DEFAULT_DB_ALIAS]
        try:
            # Open a database connection.
            new_connection.cursor()
            self.assertFalse(new_connection.get_autocommit())
        finally:
            new_connection.close()


# This test needs to run outside of a transaction, otherwise closing the
# connection would implicitly rollback and cause problems during teardown.
class ConnectionCreatedSignalTest(TransactionTestCase):
@@ -464,54 +494,6 @@ class EscapingChecksDebug(EscapingChecks):
    pass


class SqliteAggregationTests(TestCase):
    """
    #19360: Raise NotImplementedError when aggregating on date/time fields.
    """
    @unittest.skipUnless(connection.vendor == 'sqlite',
                         "No need to check SQLite aggregation semantics")
    def test_aggregation(self):
        for aggregate in (Sum, Avg, Variance, StdDev):
            self.assertRaises(NotImplementedError,
                models.Item.objects.all().aggregate, aggregate('time'))
            self.assertRaises(NotImplementedError,
                models.Item.objects.all().aggregate, aggregate('date'))
            self.assertRaises(NotImplementedError,
                models.Item.objects.all().aggregate, aggregate('last_modified'))


class SqliteChecks(TestCase):

    @unittest.skipUnless(connection.vendor == 'sqlite',
                         "No need to do SQLite checks")
    def test_convert_values_to_handle_null_value(self):
        database_operations = DatabaseOperations(connection)
        self.assertEqual(
            None,
            database_operations.convert_values(None, AutoField(primary_key=True))
        )
        self.assertEqual(
            None,
            database_operations.convert_values(None, DateField())
        )
        self.assertEqual(
            None,
            database_operations.convert_values(None, DateTimeField())
        )
        self.assertEqual(
            None,
            database_operations.convert_values(None, DecimalField())
        )
        self.assertEqual(
            None,
            database_operations.convert_values(None, IntegerField())
        )
        self.assertEqual(
            None,
            database_operations.convert_values(None, TimeField())
        )


class BackendTestCase(TestCase):

    def create_squares_with_executemany(self, args):
@@ -1036,24 +1018,3 @@ class BackendUtilTests(TestCase):
              '0.1')
        equal('0.1234567890', 12, 0,
              '0')


@unittest.skipUnless(
    connection.vendor == 'postgresql',
    "This test applies only to PostgreSQL")
class UnicodeArrayTestCase(TestCase):

    def select(self, val):
        cursor = connection.cursor()
        cursor.execute("select %s", (val,))
        return cursor.fetchone()[0]

    def test_select_ascii_array(self):
        a = ["awef"]
        b = self.select(a)
        self.assertEqual(a[0], b[0])

    def test_select_unicode_array(self):
        a = ["ᄲawef"]
        b = self.select(a)
        self.assertEqual(a[0], b[0])