Commit ebabd772 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Ensured a connection is established when checking the database version.

Fixed a test broken by 21765c0a. Refs #18135.
parent 9a3988ca
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -352,6 +352,18 @@ class BaseDatabaseWrapper(object):
    def make_debug_cursor(self, cursor):
        return util.CursorDebugWrapper(cursor, self)

    @contextmanager
    def temporary_connection(self):
        # Ensure a connection is established, and avoid leaving a dangling
        # connection, for operations outside of the request-response cycle.
        must_close = self.connection is None
        cursor = self.cursor()
        try:
            yield
        finally:
            cursor.close()
            if must_close:
                self.close()

class BaseDatabaseFeatures(object):
    allows_group_by_pk = False
+2 −1
Original line number Diff line number Diff line
@@ -453,6 +453,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):

    @cached_property
    def mysql_version(self):
        with self.temporary_connection():
            server_info = self.connection.get_server_info()
        match = server_version_re.match(server_info)
        if not match:
+3 −1
Original line number Diff line number Diff line
@@ -623,8 +623,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):

    @cached_property
    def oracle_version(self):
        with self.temporary_connection():
            version = self.connection.version
        try:
            return int(self.connection.version.split('.')[0])
            return int(version.split('.')[0])
        except ValueError:
            return None

+2 −1
Original line number Diff line number Diff line
@@ -152,6 +152,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):

    @cached_property
    def pg_version(self):
        with self.temporary_connection():
            return get_version(self.connection)

    def get_connection_params(self):
+1 −2
Original line number Diff line number Diff line
@@ -195,8 +195,7 @@ class DatabaseOperations(BaseDatabaseOperations):
        NotImplementedError if this is the database in use.
        """
        if aggregate.sql_function in ('STDDEV_POP', 'VAR_POP'):
            pg_version = self.connection.pg_version
            if pg_version >= 80200 and pg_version <= 80204:
            if 80200 <= self.connection.pg_version <= 80204:
                raise NotImplementedError('PostgreSQL 8.2 to 8.2.4 is known to have a faulty implementation of %s. Please upgrade your version of PostgreSQL.' % aggregate.sql_function)

    def max_name_length(self):