Commit 220993bc authored by Malcolm Tredinnick's avatar Malcolm Tredinnick
Browse files

Added savepoint support to the transaction code.

This is a no-op for most databases. Only necessary on PostgreSQL so that we can
do things which will possibly intentionally raise an IntegrityError and not
have to rollback the entire transaction. Not supported for PostgreSQL versions
prior to 8.0, so should be used sparingly in internal Django code.


git-svn-id: http://code.djangoproject.com/svn/django/trunk@8314 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent e73bf2bd
Loading
Loading
Loading
Loading
+43 −8
Original line number Diff line number Diff line
@@ -31,6 +31,21 @@ class BaseDatabaseWrapper(local):
        if self.connection is not None:
            return self.connection.rollback()

    def _savepoint(self, sid):
        if not self.features.uses_savepoints:
            return
        self.connection.cursor().execute(self.ops.savepoint_create_sql(sid))

    def _savepoint_rollback(self, sid):
        if not self.features.uses_savepoints:
            return
        self.connection.cursor().execute(self.ops.savepoint_rollback_sql(sid))

    def _savepoint_commit(self, sid):
        if not self.features.uses_savepoints:
            return
        self.connection.cursor().execute(self.ops.savepoint_commit_sql(sid))

    def close(self):
        if self.connection is not None:
            self.connection.close()
@@ -55,6 +70,7 @@ class BaseDatabaseFeatures(object):
    update_can_self_select = True
    interprets_empty_strings_as_nulls = False
    can_use_chunked_reads = True
    uses_savepoints = False

class BaseDatabaseOperations(object):
    """
@@ -226,6 +242,26 @@ class BaseDatabaseOperations(object):
        """
        raise NotImplementedError

    def savepoint_create_sql(self, sid):
        """
        Returns the SQL for starting a new savepoint. Only required if the
        "uses_savepoints" feature is True. The "sid" parameter is a string
        for the savepoint id.
        """
        raise NotImplementedError

    def savepoint_commit_sql(self, sid):
        """
        Returns the SQL for committing the given savepoint.
        """
        raise NotImplementedError

    def savepoint_rollback_sql(self, sid):
        """
        Returns the SQL for rolling back the given savepoint.
        """
        raise NotImplementedError

    def sql_flush(self, style, tables, sequences):
        """
        Returns a list of SQL statements required to remove all data from
@@ -394,7 +430,6 @@ class BaseDatabaseIntrospection(object):

        return sequence_list

        
class BaseDatabaseClient(object):
    """
    This class encapsualtes all backend-specific methods for opening a
+5 −2
Original line number Diff line number Diff line
@@ -63,6 +63,9 @@ class UnicodeCursorWrapper(object):
    def __iter__(self):
        return iter(self.cursor)

class DatabaseFeatures(BaseDatabaseFeatures):
    uses_savepoints = True

class DatabaseWrapper(BaseDatabaseWrapper):
    operators = {
        'exact': '= %s',
@@ -84,7 +87,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
    def __init__(self, *args, **kwargs):
        super(DatabaseWrapper, self).__init__(*args, **kwargs)

        self.features = BaseDatabaseFeatures()
        self.features = DatabaseFeatures()
        self.ops = DatabaseOperations()
        self.client = DatabaseClient()
        self.creation = DatabaseCreation(self)
+10 −0
Original line number Diff line number Diff line
@@ -124,3 +124,13 @@ class DatabaseOperations(BaseDatabaseOperations):
                    style.SQL_KEYWORD('FROM'),
                    style.SQL_TABLE(qn(f.m2m_db_table()))))
        return output

    def savepoint_create_sql(self, sid):
        return "SAVEPOINT %s" % sid

    def savepoint_commit_sql(self, sid):
        return "RELEASE SAVEPOINT %s" % sid

    def savepoint_rollback_sql(self, sid):
        return "ROLLBACK TO SAVEPOINT %s" % sid
+1 −0
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ psycopg2.extensions.register_adapter(SafeUnicode, psycopg2.extensions.QuotedStri

class DatabaseFeatures(BaseDatabaseFeatures):
    needs_datetime_string_cast = False
    uses_savepoints = True

class DatabaseOperations(PostgresqlDatabaseOperations):
    def last_executed_query(self, cursor, sql, params):
+33 −2
Original line number Diff line number Diff line
@@ -30,9 +30,10 @@ class TransactionManagementError(Exception):
    """
    pass

# The state is a dictionary of lists. The key to the dict is the current
# The states are dictionaries of lists. The key to the dict is the current
# thread and the list is handled as a stack of values.
state = {}
savepoint_state = {}

# The dirty flag is set by *_unless_managed functions to denote that the
# code under transaction management has changed things to require a
@@ -164,6 +165,36 @@ def rollback():
    connection._rollback()
    set_clean()

def savepoint():
    """
    Creates a savepoint (if supported and required by the backend) inside the
    current transaction. Returns an identifier for the savepoint that will be
    used for the subsequent rollback or commit.
    """
    thread_ident = thread.get_ident()
    if thread_ident in savepoint_state:
        savepoint_state[thread_ident].append(None)
    else:
        savepoint_state[thread_ident] = [None]
    tid = str(thread_ident).replace('-', '')
    sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident]))
    connection._savepoint(sid)
    return sid

def savepoint_rollback(sid):
    """
    Rolls back the most recent savepoint (if one exists). Does nothing if
    savepoints are not supported.
    """
    connection._savepoint_rollback(sid)

def savepoint_commit(sid):
    """
    Commits the most recent savepoint (if one exists). Does nothing if
    savepoints are not supported.
    """
    connection._savepoint_commit(sid)

##############
# DECORATORS #
##############
Loading