Commit dd9c6bc3 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Introduced getters for connection.autocommit and .needs_rollback.

They ensure that the attributes aren't accessed in conditions where they
don't contain a valid value.

Fixed #20666.
parent 2c406818
Loading
Loading
Loading
Loading
+23 −11
Original line number Diff line number Diff line
@@ -204,7 +204,7 @@ class BaseDatabaseWrapper(object):

    def _savepoint_allowed(self):
        # Savepoints cannot be created outside a transaction
        return self.features.uses_savepoints and not self.autocommit
        return self.features.uses_savepoints and not self.get_autocommit()

    ##### Generic savepoint management methods #####

@@ -279,15 +279,13 @@ class BaseDatabaseWrapper(object):
        """
        self.validate_no_atomic_block()

        self.ensure_connection()

        self.transaction_state.append(managed)

        if not managed and self.is_dirty() and not forced:
            self.commit()
            self.set_clean()

        if managed == self.autocommit:
        if managed == self.get_autocommit():
            self.set_autocommit(not managed)

    def leave_transaction_management(self):
@@ -298,8 +296,6 @@ class BaseDatabaseWrapper(object):
        """
        self.validate_no_atomic_block()

        self.ensure_connection()

        if self.transaction_state:
            del self.transaction_state[-1]
        else:
@@ -313,14 +309,21 @@ class BaseDatabaseWrapper(object):

        if self._dirty:
            self.rollback()
            if managed == self.autocommit:
            if managed == self.get_autocommit():
                self.set_autocommit(not managed)
            raise TransactionManagementError(
                "Transaction managed block ended with pending COMMIT/ROLLBACK")

        if managed == self.autocommit:
        if managed == self.get_autocommit():
            self.set_autocommit(not managed)

    def get_autocommit(self):
        """
        Check the autocommit state.
        """
        self.ensure_connection()
        return self.autocommit

    def set_autocommit(self, autocommit):
        """
        Enable or disable autocommit.
@@ -330,13 +333,22 @@ class BaseDatabaseWrapper(object):
        self._set_autocommit(autocommit)
        self.autocommit = autocommit

    def get_rollback(self):
        """
        Get the "needs rollback" flag -- for *advanced use* only.
        """
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block.")
        return self.needs_rollback

    def set_rollback(self, rollback):
        """
        Set or unset the "needs rollback" flag -- for *advanced use* only.
        """
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "needs_rollback doesn't work outside of an 'atomic' block.")
                "The rollback flag doesn't work outside of an 'atomic' block.")
        self.needs_rollback = rollback

    def validate_no_atomic_block(self):
@@ -370,7 +382,7 @@ class BaseDatabaseWrapper(object):
        to decide in a managed block of code to decide whether there are open
        changes waiting for commit.
        """
        if not self.autocommit:
        if not self.get_autocommit():
            self._dirty = True

    def set_clean(self):
@@ -436,7 +448,7 @@ class BaseDatabaseWrapper(object):
        if self.connection is not None:
            # If the application didn't restore the original autocommit setting,
            # don't take chances, drop the connection.
            if self.autocommit != self.settings_dict['AUTOCOMMIT']:
            if self.get_autocommit() != self.settings_dict['AUTOCOMMIT']:
                self.close()
                return

+4 −8
Original line number Diff line number Diff line
@@ -123,7 +123,7 @@ def get_autocommit(using=None):
    """
    Get the autocommit status of the connection.
    """
    return get_connection(using).autocommit
    return get_connection(using).get_autocommit()

def set_autocommit(autocommit, using=None):
    """
@@ -175,7 +175,7 @@ def get_rollback(using=None):
    """
    Gets the "needs rollback" flag -- for *advanced use* only.
    """
    return get_connection(using).needs_rollback
    return get_connection(using).get_rollback()

def set_rollback(rollback, using=None):
    """
@@ -229,15 +229,11 @@ class Atomic(object):
    def __enter__(self):
        connection = get_connection(self.using)

        # Ensure we have a connection to the database before testing
        # autocommit status.
        connection.ensure_connection()

        if not connection.in_atomic_block:
            # Reset state when entering an outermost atomic block.
            connection.commit_on_exit = True
            connection.needs_rollback = False
            if not connection.autocommit:
            if not connection.get_autocommit():
                # Some database adapters (namely sqlite3) don't handle
                # transactions and savepoints properly when autocommit is off.
                # Turning autocommit back on isn't an option; it would trigger
@@ -500,7 +496,7 @@ def commit_on_success_unless_managed(using=None, savepoint=False):
    legacy behavior.
    """
    connection = get_connection(using)
    if connection.autocommit or connection.in_atomic_block:
    if connection.get_autocommit() or connection.in_atomic_block:
        return atomic(using, savepoint)
    else:
        def entering(using):