Commit 50328f0a authored by Anssi Kääriäinen's avatar Anssi Kääriäinen
Browse files

Fixed #19861 -- Transaction ._dirty flag improvement

There were a couple of errors in ._dirty flag handling:
  * It started as None, but was never reset to None.
  * The _dirty flag was sometimes used to indicate if the connection
    was inside transaction management, but this was not done
    consistently. This also meant the flag had three separate values.
  * The None value had a special meaning, causing for example inability
    to commit() on new connection unless enter/leave tx management was
    done.
  * The _dirty was tracking "connection in transaction" state, but only
    in managed transactions.
  * Some tests never reset the transaction state of the used connection.
  * And some additional less important changes.

This commit has some potential for regressions, but as the above list
shows, the current situation isn't perfect either.
parent 21089416
Loading
Loading
Loading
Loading
+18 −25
Original line number Diff line number Diff line
@@ -41,7 +41,10 @@ class BaseDatabaseWrapper(object):
        # Transaction related attributes
        self.transaction_state = []
        self.savepoint_state = 0
        self._dirty = None
        # Tracks if the connection is believed to be in transaction. This is
        # set somewhat aggressively, as the DBAPI doesn't make it easy to
        # deduce if the connection is in transaction or not.
        self._dirty = False
        self._thread_ident = thread.get_ident()
        self.allow_thread_sharing = allow_thread_sharing

@@ -118,8 +121,7 @@ class BaseDatabaseWrapper(object):
        stack.
        """
        if self._dirty:
            self._rollback()
            self._dirty = False
            self.rollback()
        while self.transaction_state:
            self.leave_transaction_management()

@@ -137,9 +139,6 @@ class BaseDatabaseWrapper(object):
            self.transaction_state.append(self.transaction_state[-1])
        else:
            self.transaction_state.append(settings.TRANSACTIONS_MANAGED)

        if self._dirty is None:
            self._dirty = False
        self._enter_transaction_management(managed)

    def leave_transaction_management(self):
@@ -153,14 +152,16 @@ class BaseDatabaseWrapper(object):
        else:
            raise TransactionManagementError(
                "This code isn't under transaction management")
        # The _leave_transaction_management hook can change the dirty flag,
        # so memoize it.
        dirty = self._dirty
        # We will pass the next status (after leaving the previous state
        # behind) to subclass hook.
        self._leave_transaction_management(self.is_managed())
        if self._dirty:
        if dirty:
            self.rollback()
            raise TransactionManagementError(
                "Transaction managed block ended with pending COMMIT/ROLLBACK")
        self._dirty = False

    def validate_thread_sharing(self):
        """
@@ -190,11 +191,7 @@ class BaseDatabaseWrapper(object):
        to decide in a managed block of code to decide whether there are open
        changes waiting for commit.
        """
        if self._dirty is not None:
        self._dirty = True
        else:
            raise TransactionManagementError("This code isn't under transaction "
                "management")

    def set_clean(self):
        """
@@ -202,10 +199,7 @@ class BaseDatabaseWrapper(object):
        to decide in a managed block of code to decide whether a commit or rollback
        should happen.
        """
        if self._dirty is not None:
        self._dirty = False
        else:
            raise TransactionManagementError("This code isn't under transaction management")
        self.clean_savepoints()

    def clean_savepoints(self):
@@ -233,8 +227,7 @@ class BaseDatabaseWrapper(object):
        if top:
            top[-1] = flag
            if not flag and self.is_dirty():
                self._commit()
                self.set_clean()
                self.commit()
        else:
            raise TransactionManagementError("This code isn't under transaction "
                "management")
@@ -245,7 +238,7 @@ class BaseDatabaseWrapper(object):
        """
        self.validate_thread_sharing()
        if not self.is_managed():
            self._commit()
            self.commit()
            self.clean_savepoints()
        else:
            self.set_dirty()
@@ -256,7 +249,7 @@ class BaseDatabaseWrapper(object):
        """
        self.validate_thread_sharing()
        if not self.is_managed():
            self._rollback()
            self.rollback()
        else:
            self.set_dirty()

@@ -343,6 +336,7 @@ class BaseDatabaseWrapper(object):
        if self.connection is not None:
            self.connection.close()
            self.connection = None
        self.set_clean()

    def cursor(self):
        self.validate_thread_sharing()
@@ -485,14 +479,13 @@ class BaseDatabaseFeatures(object):
            self.connection.managed(True)
            cursor = self.connection.cursor()
            cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
            self.connection._commit()
            self.connection.commit()
            cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
            self.connection._rollback()
            self.connection.rollback()
            cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
            count, = cursor.fetchone()
            cursor.execute('DROP TABLE ROLLBACK_TEST')
            self.connection._commit()
            self.connection._dirty = False
            self.connection.commit()
        finally:
            self.connection.leave_transaction_management()
        return count == 0
+1 −1
Original line number Diff line number Diff line
@@ -385,8 +385,8 @@ class BaseDatabaseCreation(object):
        # Create the test database and connect to it. We need to autocommit
        # if the database supports it because PostgreSQL doesn't allow
        # CREATE/DROP DATABASE statements within transactions.
        cursor = self.connection.cursor()
        self._prepare_for_test_db_ddl()
        cursor = self.connection.cursor()
        try:
            cursor.execute(
                "CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
+9 −0
Original line number Diff line number Diff line
@@ -149,6 +149,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
                exc_info=sys.exc_info()
            )
            raise
        finally:
            self.set_clean()

    @cached_property
    def pg_version(self):
@@ -233,10 +235,17 @@ class DatabaseWrapper(BaseDatabaseWrapper):
        try:
            if self.connection is not None:
                self.connection.set_isolation_level(level)
            if level == psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT:
                self.set_clean()
        finally:
            self.isolation_level = level
            self.features.uses_savepoints = bool(level)

    def set_dirty(self):
        if ((self.transaction_state and self.transaction_state[-1]) or
                not self.features.uses_autocommit):
            super(DatabaseWrapper, self).set_dirty()

    def _commit(self):
        if self.connection is not None:
            try:
+2 −0
Original line number Diff line number Diff line
@@ -82,6 +82,8 @@ class DatabaseCreation(BaseDatabaseCreation):

    def _prepare_for_test_db_ddl(self):
        """Rollback and close the active transaction."""
        # Make sure there is an open connection.
        self.connection.cursor()
        self.connection.connection.rollback()
        self.connection.connection.set_isolation_level(
                psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
+3 −7
Original line number Diff line number Diff line
@@ -19,13 +19,9 @@ class CursorWrapper(object):
        self.cursor = cursor
        self.db = db

    def set_dirty(self):
        if self.db.is_managed():
            self.db.set_dirty()

    def __getattr__(self, attr):
        if attr in ('execute', 'executemany', 'callproc'):
            self.set_dirty()
            self.db.set_dirty()
        return getattr(self.cursor, attr)

    def __iter__(self):
@@ -35,7 +31,7 @@ class CursorWrapper(object):
class CursorDebugWrapper(CursorWrapper):

    def execute(self, sql, params=()):
        self.set_dirty()
        self.db.set_dirty()
        start = time()
        try:
            return self.cursor.execute(sql, params)
@@ -52,7 +48,7 @@ class CursorDebugWrapper(CursorWrapper):
            )

    def executemany(self, sql, param_list):
        self.set_dirty()
        self.db.set_dirty()
        start = time()
        try:
            return self.cursor.executemany(sql, param_list)
Loading