Commit 29628e0b authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Factored out common code in database backends.

parent 64d0f89a
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ from contextlib import contextmanager

from django.conf import settings
from django.db import DEFAULT_DB_ALIAS
from django.db.backends.signals import connection_created
from django.db.backends import util
from django.db.transaction import TransactionManagementError
from django.utils.functional import cached_property
@@ -52,6 +53,17 @@ class BaseDatabaseWrapper(object):

    __hash__ = object.__hash__

    def _valid_connection(self):
        return self.connection is not None

    def _cursor(self):
        if not self._valid_connection():
            conn_params = self.get_connection_params()
            self.connection = self.get_new_connection(conn_params)
            self.init_connection_state()
            connection_created.send(sender=self.__class__, connection=self)
        return self.create_cursor()

    def _commit(self):
        if self.connection is not None:
            return self.connection.commit()
+1 −9
Original line number Diff line number Diff line
@@ -33,19 +33,16 @@ from MySQLdb.constants import FIELD_TYPE, CLIENT
from django.conf import settings
from django.db import utils
from django.db.backends import *
from django.db.backends.signals import connection_created
from django.db.backends.mysql.client import DatabaseClient
from django.db.backends.mysql.creation import DatabaseCreation
from django.db.backends.mysql.introspection import DatabaseIntrospection
from django.db.backends.mysql.validation import DatabaseValidation
from django.utils.encoding import force_str
from django.utils.functional import cached_property
from django.utils.safestring import SafeBytes, SafeText
from django.utils import six
from django.utils import timezone

# Raise exceptions for database warnings if DEBUG is on
from django.conf import settings
if settings.DEBUG:
    warnings.filterwarnings("error", category=Database.Warning)

@@ -454,12 +451,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        cursor.execute('SET SQL_AUTO_IS_NULL = 0')
        cursor.close()

    def _cursor(self):
        if not self._valid_connection():
            conn_params = self.get_connection_params()
            self.connection = self.get_new_connection(conn_params)
            self.init_connection_state()
            connection_created.send(sender=self.__class__, connection=self)
    def create_cursor(self):
        cursor = self.connection.cursor()
        return CursorWrapper(cursor)

+4 −16
Original line number Diff line number Diff line
@@ -48,7 +48,6 @@ except ImportError as e:
from django.conf import settings
from django.db import utils
from django.db.backends import *
from django.db.backends.signals import connection_created
from django.db.backends.oracle.client import DatabaseClient
from django.db.backends.oracle.creation import DatabaseCreation
from django.db.backends.oracle.introspection import DatabaseIntrospection
@@ -521,9 +520,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
        self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')

    def _valid_connection(self):
        return self.connection is not None

    def _connect_string(self):
        settings_dict = self.settings_dict
        if not settings_dict['HOST'].strip():
@@ -537,8 +533,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        return "%s/%s@%s" % (settings_dict['USER'],
                             settings_dict['PASSWORD'], dsn)

    def create_cursor(self, conn):
        return FormatStylePlaceholderCursor(conn)
    def create_cursor(self):
        return FormatStylePlaceholderCursor(self.connection)

    def get_connection_params(self):
        conn_params = self.settings_dict['OPTIONS'].copy()
@@ -551,7 +547,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        return Database.connect(conn_string, **conn_params)

    def init_connection_state(self):
        cursor = self.create_cursor(self.connection)
        cursor = self.create_cursor()
        # Set the territory first. The territory overrides NLS_DATE_FORMAT
        # and NLS_TIMESTAMP_FORMAT to the territory default. When all of
        # these are set in single statement it isn't clear what is supposed
@@ -572,7 +568,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
            # This check is performed only once per DatabaseWrapper
            # instance per thread, since subsequent connections will use
            # the same settings.
            cursor = self.create_cursor(self.connection)
            cursor = self.create_cursor()
            try:
                cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s"
                               % self._standard_operators['contains'],
@@ -602,14 +598,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
            # stmtcachesize is available only in 4.3.2 and up.
            pass

    def _cursor(self):
        if not self._valid_connection():
            conn_params = self.get_connection_params()
            self.connection = self.get_new_connection(conn_params)
            self.init_connection_state()
            connection_created.send(sender=self.__class__, connection=self)
        return self.create_cursor(self.connection)

    # Oracle doesn't support savepoint commits.  Ignore them.
    def _savepoint_commit(self, sid):
        pass
+1 −7
Original line number Diff line number Diff line
@@ -8,7 +8,6 @@ import sys

from django.db import utils
from django.db.backends import *
from django.db.backends.signals import connection_created
from django.db.backends.postgresql_psycopg2.operations import DatabaseOperations
from django.db.backends.postgresql_psycopg2.client import DatabaseClient
from django.db.backends.postgresql_psycopg2.creation import DatabaseCreation
@@ -205,12 +204,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        self.connection.set_isolation_level(self.isolation_level)
        self._get_pg_version()

    def _cursor(self):
        if self.connection is None:
            conn_params = self.get_connection_params()
            self.connection = self.get_new_connection(conn_params)
            self.init_connection_state()
            connection_created.send(sender=self.__class__, connection=self)
    def create_cursor(self):
        cursor = self.connection.cursor()
        cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None
        return CursorWrapper(cursor)
+1 −8
Original line number Diff line number Diff line
@@ -14,7 +14,6 @@ import sys

from django.db import utils
from django.db.backends import *
from django.db.backends.signals import connection_created
from django.db.backends.sqlite3.client import DatabaseClient
from django.db.backends.sqlite3.creation import DatabaseCreation
from django.db.backends.sqlite3.introspection import DatabaseIntrospection
@@ -344,13 +343,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
    def init_connection_state(self):
        pass

    def _cursor(self):
        if self.connection is None:
            conn_params = self.get_connection_params()
            self.connection = self.get_new_connection(conn_params)
            self.init_connection_state()
            connection_created.send(sender=self.__class__, connection=self)

    def create_cursor(self):
        return self.connection.cursor(factory=SQLiteCursorWrapper)

    def check_constraints(self, table_names=None):