Commit d63e5503 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Reordered methods in database wrappers.

* Grouped related methods together -- with banner comments :/
* Described which methods are intended to be implemented in backends.
* Added docstrings.
* Used the same order in all wrappers.
parent c5a25c27
Loading
Loading
Loading
Loading
+170 −115
Original line number Diff line number Diff line
@@ -40,20 +40,24 @@ class BaseDatabaseWrapper(object):
        self.alias = alias
        self.use_debug_cursor = None

        # Transaction related attributes
        self.transaction_state = []
        # Savepoint management related attributes
        self.savepoint_state = 0

        # Transaction management related attributes
        self.transaction_state = []
        # Tracks if the connection is believed to be in transaction. This is
        # set somewhat aggressively, as the DBAPI doesn't make it easy to
        # deduce if the connection is in transaction or not.
        self._dirty = False
        self._thread_ident = thread.get_ident()
        self.allow_thread_sharing = allow_thread_sharing

        # Connection termination related attributes
        self.close_at = None
        self.errors_occurred = False

        # Thread-safety related attributes
        self.allow_thread_sharing = allow_thread_sharing
        self._thread_ident = thread.get_ident()

    def __eq__(self, other):
        return self.alias == other.alias

@@ -63,21 +67,26 @@ class BaseDatabaseWrapper(object):
    def __hash__(self):
        return hash(self.alias)

    def wrap_database_errors(self):
        return DatabaseErrorWrapper(self)
    ##### Backend-specific methods for creating connections and cursors #####

    def get_connection_params(self):
        """Returns a dict of parameters suitable for get_new_connection."""
        raise NotImplementedError

    def get_new_connection(self, conn_params):
        """Opens a connection to the database."""
        raise NotImplementedError

    def init_connection_state(self):
        """Initializes the database connection settings."""
        raise NotImplementedError

    def create_cursor(self):
        """Creates a cursor. Assumes that a connection is established."""
        raise NotImplementedError

    ##### Backend-specific wrappers for PEP-249 connection methods #####

    def _cursor(self):
        with self.wrap_database_errors():
            if self.connection is None:
@@ -107,20 +116,48 @@ class BaseDatabaseWrapper(object):
            with self.wrap_database_errors():
                return self.connection.close()

    def _enter_transaction_management(self, managed):
    ##### Generic wrappers for PEP-249 connection methods #####

    def cursor(self):
        """
        A hook for backend-specific changes required when entering manual
        transaction handling.
        Creates a cursor, opening a connection if necessary.
        """
        pass
        self.validate_thread_sharing()
        if (self.use_debug_cursor or
            (self.use_debug_cursor is None and settings.DEBUG)):
            cursor = self.make_debug_cursor(self._cursor())
        else:
            cursor = util.CursorWrapper(self._cursor(), self)
        return cursor

    def _leave_transaction_management(self, managed):
    def commit(self):
        """
        A hook for backend-specific changes required when leaving manual
        transaction handling. Will usually be implemented only when
        _enter_transaction_management() is also required.
        Does the commit itself and resets the dirty flag.
        """
        pass
        self.validate_thread_sharing()
        self._commit()
        self.set_clean()

    def rollback(self):
        """
        Does the rollback itself and resets the dirty flag.
        """
        self.validate_thread_sharing()
        self._rollback()
        self.set_clean()

    def close(self):
        """
        Closes the connection to the database.
        """
        self.validate_thread_sharing()
        try:
            self._close()
        finally:
            self.connection = None
        self.set_clean()

    ##### Backend-specific savepoint management methods #####

    def _savepoint(self, sid):
        if not self.features.uses_savepoints:
@@ -137,15 +174,65 @@ class BaseDatabaseWrapper(object):
            return
        self.cursor().execute(self.ops.savepoint_commit_sql(sid))

    def abort(self):
    ##### Generic savepoint management methods #####

    def savepoint(self):
        """
        Roll back any ongoing transaction and clean the transaction state
        stack.
        Creates a savepoint (if supported and required by the backend) inside the
        current transaction. Returns an identifier for the savepoint that will be
        used for the subsequent rollback or commit.
        """
        if self._dirty:
            self.rollback()
        while self.transaction_state:
            self.leave_transaction_management()
        thread_ident = thread.get_ident()

        self.savepoint_state += 1

        tid = str(thread_ident).replace('-', '')
        sid = "s%s_x%d" % (tid, self.savepoint_state)
        self._savepoint(sid)
        return sid

    def savepoint_rollback(self, sid):
        """
        Rolls back the most recent savepoint (if one exists). Does nothing if
        savepoints are not supported.
        """
        self.validate_thread_sharing()
        if self.savepoint_state:
            self._savepoint_rollback(sid)

    def savepoint_commit(self, sid):
        """
        Commits the most recent savepoint (if one exists). Does nothing if
        savepoints are not supported.
        """
        self.validate_thread_sharing()
        if self.savepoint_state:
            self._savepoint_commit(sid)

    def clean_savepoints(self):
        """
        Resets the counter used to generate unique savepoint ids in this thread.
        """
        self.savepoint_state = 0

    ##### Backend-specific transaction management methods #####

    def _enter_transaction_management(self, managed):
        """
        A hook for backend-specific changes required when entering manual
        transaction handling.
        """
        pass

    def _leave_transaction_management(self, managed):
        """
        A hook for backend-specific changes required when leaving manual
        transaction handling. Will usually be implemented only when
        _enter_transaction_management() is also required.
        """
        pass

    ##### Generic transaction management methods #####

    def enter_transaction_management(self, managed=True):
        """
@@ -185,20 +272,15 @@ class BaseDatabaseWrapper(object):
            raise TransactionManagementError(
                "Transaction managed block ended with pending COMMIT/ROLLBACK")

    def validate_thread_sharing(self):
    def abort(self):
        """
        Validates that the connection isn't accessed by another thread than the
        one which originally created it, unless the connection was explicitly
        authorized to be shared between threads (via the `allow_thread_sharing`
        property). Raises an exception if the validation fails.
        Roll back any ongoing transaction and clean the transaction state
        stack.
        """
        if (not self.allow_thread_sharing
            and self._thread_ident != thread.get_ident()):
                raise DatabaseError("DatabaseWrapper objects created in a "
                    "thread can only be used in that same thread. The object "
                    "with alias '%s' was created in thread id %s and this is "
                    "thread id %s."
                    % (self.alias, self._thread_ident, thread.get_ident()))
        if self._dirty:
            self.rollback()
        while self.transaction_state:
            self.leave_transaction_management()

    def is_dirty(self):
        """
@@ -224,12 +306,6 @@ class BaseDatabaseWrapper(object):
        self._dirty = False
        self.clean_savepoints()

    def clean_savepoints(self):
        """
        Resets the counter used to generate unique savepoint ids in this thread.
        """
        self.savepoint_state = 0

    def is_managed(self):
        """
        Checks whether the transaction manager is in manual or in auto state.
@@ -275,57 +351,13 @@ class BaseDatabaseWrapper(object):
        else:
            self.set_dirty()

    def commit(self):
        """
        Does the commit itself and resets the dirty flag.
        """
        self.validate_thread_sharing()
        self._commit()
        self.set_clean()

    def rollback(self):
        """
        This function does the rollback itself and resets the dirty flag.
        """
        self.validate_thread_sharing()
        self._rollback()
        self.set_clean()

    def savepoint(self):
        """
        Creates a savepoint (if supported and required by the backend) inside the
        current transaction. Returns an identifier for the savepoint that will be
        used for the subsequent rollback or commit.
        """
        thread_ident = thread.get_ident()

        self.savepoint_state += 1

        tid = str(thread_ident).replace('-', '')
        sid = "s%s_x%d" % (tid, self.savepoint_state)
        self._savepoint(sid)
        return sid

    def savepoint_rollback(self, sid):
        """
        Rolls back the most recent savepoint (if one exists). Does nothing if
        savepoints are not supported.
        """
        self.validate_thread_sharing()
        if self.savepoint_state:
            self._savepoint_rollback(sid)

    def savepoint_commit(self, sid):
        """
        Commits the most recent savepoint (if one exists). Does nothing if
        savepoints are not supported.
        """
        self.validate_thread_sharing()
        if self.savepoint_state:
            self._savepoint_commit(sid)
    ##### Foreign key constraints checks handling #####

    @contextmanager
    def constraint_checks_disabled(self):
        """
        Context manager that disables foreign key constraint checking.
        """
        disabled = self.disable_constraint_checking()
        try:
            yield
@@ -335,33 +367,40 @@ class BaseDatabaseWrapper(object):

    def disable_constraint_checking(self):
        """
        Backends can implement as needed to temporarily disable foreign key constraint
        checking.
        Backends can implement as needed to temporarily disable foreign key
        constraint checking.
        """
        pass

    def enable_constraint_checking(self):
        """
        Backends can implement as needed to re-enable foreign key constraint checking.
        Backends can implement as needed to re-enable foreign key constraint
        checking.
        """
        pass

    def check_constraints(self, table_names=None):
        """
        Backends can override this method if they can apply constraint checking (e.g. via "SET CONSTRAINTS
        ALL IMMEDIATE"). Should raise an IntegrityError if any invalid foreign key references are encountered.
        Backends can override this method if they can apply constraint
        checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
        IntegrityError if any invalid foreign key references are encountered.
        """
        pass

    def close(self):
        self.validate_thread_sharing()
        try:
            self._close()
        finally:
            self.connection = None
        self.set_clean()
    ##### Connection termination handling #####

    def is_usable(self):
        """
        Tests if the database connection is usable.
        This function may assume that self.connection is not None.
        """
        raise NotImplementedError

    def close_if_unusable_or_obsolete(self):
        """
        Closes the current connection if unrecoverable errors have occurred,
        or if it outlived its maximum age.
        """
        if self.connection is not None:
            if self.errors_occurred:
                if self.is_usable():
@@ -373,30 +412,45 @@ class BaseDatabaseWrapper(object):
                self.close()
                return

    def is_usable(self):
        """
        Test if the database connection is usable.
    ##### Thread safety handling #####

        This function may assume that self.connection is not None.
    def validate_thread_sharing(self):
        """
        raise NotImplementedError
        Validates that the connection isn't accessed by another thread than the
        one which originally created it, unless the connection was explicitly
        authorized to be shared between threads (via the `allow_thread_sharing`
        property). Raises an exception if the validation fails.
        """
        if not (self.allow_thread_sharing
                or self._thread_ident == thread.get_ident()):
            raise DatabaseError("DatabaseWrapper objects created in a "
                "thread can only be used in that same thread. The object "
                "with alias '%s' was created in thread id %s and this is "
                "thread id %s."
                % (self.alias, self._thread_ident, thread.get_ident()))

    def cursor(self):
        self.validate_thread_sharing()
        if (self.use_debug_cursor or
            (self.use_debug_cursor is None and settings.DEBUG)):
            cursor = self.make_debug_cursor(self._cursor())
        else:
            cursor = util.CursorWrapper(self._cursor(), self)
        return cursor
    ##### Miscellaneous #####

    def wrap_database_errors(self):
        """
        Context manager and decorator that re-throws backend-specific database
        exceptions using Django's common wrappers.
        """
        return DatabaseErrorWrapper(self)

    def make_debug_cursor(self, cursor):
        """
        Creates a cursor that logs all queries in self.queries.
        """
        return util.CursorDebugWrapper(cursor, self)

    @contextmanager
    def temporary_connection(self):
        # Ensure a connection is established, and avoid leaving a dangling
        # connection, for operations outside of the request-response cycle.
        """
        Context manager that ensures that a connection is established, and
        if it opened one, closes it to avoid leaving a dangling connection.
        This is useful for operations outside of the request-response cycle.
        """
        must_close = self.connection is None
        cursor = self.cursor()
        try:
@@ -406,6 +460,7 @@ class BaseDatabaseWrapper(object):
            if must_close:
                self.close()


class BaseDatabaseFeatures(object):
    allows_group_by_pk = False
    # True if django.db.backend.utils.typecast_timestamp is used on values
+7 −7
Original line number Diff line number Diff line
@@ -48,19 +48,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
    # implementations. Anything that tries to actually
    # do something raises complain; anything that tries
    # to rollback or undo something raises ignore.
    _cursor = complain
    _commit = complain
    _rollback = ignore
    enter_transaction_management = complain
    leave_transaction_management = ignore
    _close = ignore
    _savepoint = ignore
    _savepoint_commit = complain
    _savepoint_rollback = ignore
    _enter_transaction_management = complain
    _leave_transaction_management = ignore
    set_dirty = complain
    set_clean = complain
    commit_unless_managed = complain
    rollback_unless_managed = ignore
    savepoint = ignore
    savepoint_commit = complain
    savepoint_rollback = ignore
    close = ignore
    cursor = complain

    def __init__(self, *args, **kwargs):
        super(DatabaseWrapper, self).__init__(*args, **kwargs)
+17 −17
Original line number Diff line number Diff line
@@ -439,29 +439,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        cursor = self.connection.cursor()
        return CursorWrapper(cursor)

    def is_usable(self):
        try:
            self.connection.ping()
        except DatabaseError:
            return False
        else:
            return True

    def _rollback(self):
        try:
            BaseDatabaseWrapper._rollback(self)
        except Database.NotSupportedError:
            pass

    @cached_property
    def mysql_version(self):
        with self.temporary_connection():
            server_info = self.connection.get_server_info()
        match = server_version_re.match(server_info)
        if not match:
            raise Exception('Unable to determine MySQL version from version string %r' % server_info)
        return tuple([int(x) for x in match.groups()])

    def disable_constraint_checking(self):
        """
        Disables foreign key checks, primarily for use in adding rows with forward references. Always returns True,
@@ -510,3 +493,20 @@ class DatabaseWrapper(BaseDatabaseWrapper):
                        % (table_name, bad_row[0],
                        table_name, column_name, bad_row[1],
                        referenced_table_name, referenced_column_name))

    def is_usable(self):
        try:
            self.connection.ping()
        except DatabaseError:
            return False
        else:
            return True

    @cached_property
    def mysql_version(self):
        with self.temporary_connection():
            server_info = self.connection.get_server_info()
        match = server_version_re.match(server_info)
        if not match:
            raise Exception('Unable to determine MySQL version from version string %r' % server_info)
        return tuple([int(x) for x in match.groups()])
+26 −26
Original line number Diff line number Diff line
@@ -515,14 +515,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        self.introspection = DatabaseIntrospection(self)
        self.validation = BaseDatabaseValidation(self)

    def check_constraints(self, table_names=None):
        """
        To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
        are returned to deferred.
        """
        self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
        self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')

    def _connect_string(self):
        settings_dict = self.settings_dict
        if not settings_dict['HOST'].strip():
@@ -536,9 +528,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        return "%s/%s@%s" % (settings_dict['USER'],
                             settings_dict['PASSWORD'], dsn)

    def create_cursor(self):
        return FormatStylePlaceholderCursor(self.connection)

    def get_connection_params(self):
        conn_params = self.settings_dict['OPTIONS'].copy()
        if 'use_returning_into' in conn_params:
@@ -598,21 +587,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
            # stmtcachesize is available only in 4.3.2 and up.
            pass

    def is_usable(self):
        try:
            if hasattr(self.connection, 'ping'):    # Oracle 10g R2 and higher
                self.connection.ping()
            else:
                # Use a cx_Oracle cursor directly, bypassing Django's utilities.
                self.connection.cursor().execute("SELECT 1 FROM DUAL")
        except DatabaseError:
            return False
        else:
            return True

    # Oracle doesn't support savepoint commits.  Ignore them.
    def _savepoint_commit(self, sid):
        pass
    def create_cursor(self):
        return FormatStylePlaceholderCursor(self.connection)

    def _commit(self):
        if self.connection is not None:
@@ -632,6 +608,30 @@ class DatabaseWrapper(BaseDatabaseWrapper):
                    six.reraise(utils.IntegrityError, utils.IntegrityError(*tuple(e.args)), sys.exc_info()[2])
                raise

    # Oracle doesn't support savepoint commits.  Ignore them.
    def _savepoint_commit(self, sid):
        pass

    def check_constraints(self, table_names=None):
        """
        To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
        are returned to deferred.
        """
        self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
        self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')

    def is_usable(self):
        try:
            if hasattr(self.connection, 'ping'):    # Oracle 10g R2 and higher
                self.connection.ping()
            else:
                # Use a cx_Oracle cursor directly, bypassing Django's utilities.
                self.connection.cursor().execute("SELECT 1 FROM DUAL")
        except DatabaseError:
            return False
        else:
            return True

    @cached_property
    def oracle_version(self):
        with self.temporary_connection():
+41 −41
Original line number Diff line number Diff line
@@ -91,40 +91,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        self.introspection = DatabaseIntrospection(self)
        self.validation = BaseDatabaseValidation(self)

    def check_constraints(self, table_names=None):
        """
        To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
        are returned to deferred.
        """
        self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
        self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')

    def close(self):
        self.validate_thread_sharing()
        if self.connection is None:
            return

        try:
            self.connection.close()
            self.connection = None
        except Database.Error:
            # In some cases (database restart, network connection lost etc...)
            # the connection to the database is lost without giving Django a
            # notification. If we don't set self.connection to None, the error
            # will occur a every request.
            self.connection = None
            logger.warning('psycopg2 error while closing the connection.',
                exc_info=sys.exc_info()
            )
            raise
        finally:
            self.set_clean()

    @cached_property
    def pg_version(self):
        with self.temporary_connection():
            return get_version(self.connection)

    def get_connection_params(self):
        settings_dict = self.settings_dict
        if not settings_dict['NAME']:
@@ -177,14 +143,26 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None
        return cursor

    def is_usable(self):
    def close(self):
        self.validate_thread_sharing()
        if self.connection is None:
            return

        try:
            # Use a psycopg cursor directly, bypassing Django's utilities.
            self.connection.cursor().execute("SELECT 1")
        except DatabaseError:
            return False
        else:
            return True
            self.connection.close()
            self.connection = None
        except Database.Error:
            # In some cases (database restart, network connection lost etc...)
            # the connection to the database is lost without giving Django a
            # notification. If we don't set self.connection to None, the error
            # will occur a every request.
            self.connection = None
            logger.warning('psycopg2 error while closing the connection.',
                exc_info=sys.exc_info()
            )
            raise
        finally:
            self.set_clean()

    def _enter_transaction_management(self, managed):
        """
@@ -222,3 +200,25 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        if ((self.transaction_state and self.transaction_state[-1]) or
                not self.features.uses_autocommit):
            super(DatabaseWrapper, self).set_dirty()

    def check_constraints(self, table_names=None):
        """
        To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
        are returned to deferred.
        """
        self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
        self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')

    def is_usable(self):
        try:
            # Use a psycopg cursor directly, bypassing Django's utilities.
            self.connection.cursor().execute("SELECT 1")
        except DatabaseError:
            return False
        else:
            return True

    @cached_property
    def pg_version(self):
        with self.temporary_connection():
            return get_version(self.connection)
Loading