Commit 0586c061 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Cloned databases for running tests in parallel.

parent cd9fcd4e
Loading
Loading
Loading
Loading
+48 −4
Original line number Diff line number Diff line
@@ -190,13 +190,56 @@ class BaseDatabaseCreation(object):

        return test_database_name

    def destroy_test_db(self, old_database_name, verbosity=1, keepdb=False):
    def clone_test_db(self, number, verbosity=1, autoclobber=False, keepdb=False):
        """
        Clone a test database.
        """
        source_database_name = self.connection.settings_dict['NAME']

        if verbosity >= 1:
            test_db_repr = ''
            action = 'Cloning test database'
            if verbosity >= 2:
                test_db_repr = " ('%s')" % source_database_name
            if keepdb:
                action = 'Using existing clone'
            print("%s for alias '%s'%s..." % (action, self.connection.alias, test_db_repr))

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. See create_test_db for details.
        self._clone_test_db(number, verbosity, keepdb)

    def get_test_db_clone_settings(self, number):
        """
        Return a modified connection settings dict for the n-th clone of a DB.
        """
        # When this function is called, the test database has been created
        # already and its name has been copied to settings_dict['NAME'] so
        # we don't need to call _get_test_db_name.
        orig_settings_dict = self.connection.settings_dict
        new_settings_dict = orig_settings_dict.copy()
        new_settings_dict['NAME'] = '{}_{}'.format(orig_settings_dict['NAME'], number)
        return new_settings_dict

    def _clone_test_db(self, number, verbosity, keepdb=False):
        """
        Internal implementation - duplicate the test db tables.
        """
        raise NotImplementedError(
            "The database backend doesn't support cloning databases. "
            "Disable the option to run tests in parallel processes.")

    def destroy_test_db(self, old_database_name=None, verbosity=1, keepdb=False, number=None):
        """
        Destroy a test database, prompting the user for confirmation if the
        database already exists.
        """
        self.connection.close()
        if number is None:
            test_database_name = self.connection.settings_dict['NAME']
        else:
            test_database_name = self.get_test_db_clone_settings(number)['NAME']

        if verbosity >= 1:
            test_db_repr = ''
            action = 'Destroying'
@@ -213,6 +256,7 @@ class BaseDatabaseCreation(object):
            self._destroy_test_db(test_database_name, verbosity)

        # Restore the original database name
        if old_database_name is not None:
            settings.DATABASES[self.connection.alias]["NAME"] = old_database_name
            self.connection.settings_dict["NAME"] = old_database_name

+36 −0
Original line number Diff line number Diff line
import subprocess
import sys

from django.db.backends.base.creation import BaseDatabaseCreation

from .client import DatabaseClient


class DatabaseCreation(BaseDatabaseCreation):

@@ -11,3 +16,34 @@ class DatabaseCreation(BaseDatabaseCreation):
        if test_settings['COLLATION']:
            suffix.append('COLLATE %s' % test_settings['COLLATION'])
        return ' '.join(suffix)

    def _clone_test_db(self, number, verbosity, keepdb=False):
        qn = self.connection.ops.quote_name
        source_database_name = self.connection.settings_dict['NAME']
        target_database_name = self.get_test_db_clone_settings(number)['NAME']

        with self._nodb_connection.cursor() as cursor:
            try:
                cursor.execute("CREATE DATABASE %s" % qn(target_database_name))
            except Exception as e:
                if keepdb:
                    return
                try:
                    if verbosity >= 1:
                        print("Destroying old test database '%s'..." % self.connection.alias)
                    cursor.execute("DROP DATABASE %s" % qn(target_database_name))
                    cursor.execute("CREATE DATABASE %s" % qn(target_database_name))
                except Exception as e:
                    sys.stderr.write("Got an error recreating the test database: %s\n" % e)
                    sys.exit(2)

        dump_cmd = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict)
        dump_cmd[0] = 'mysqldump'
        dump_cmd[-1] = source_database_name
        load_cmd = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict)
        load_cmd[-1] = target_database_name

        dump_proc = subprocess.Popen(dump_cmd, stdout=subprocess.PIPE)
        load_proc = subprocess.Popen(load_cmd, stdin=dump_proc.stdout, stdout=subprocess.PIPE)
        dump_proc.stdout.close()    # allow dump_proc to receive a SIGPIPE if load_proc exits.
        load_proc.communicate()
+28 −0
Original line number Diff line number Diff line
import sys

from django.db.backends.base.creation import BaseDatabaseCreation


@@ -11,3 +13,29 @@ class DatabaseCreation(BaseDatabaseCreation):
        if test_settings['CHARSET']:
            return "WITH ENCODING '%s'" % test_settings['CHARSET']
        return ''

    def _clone_test_db(self, number, verbosity, keepdb=False):
        # CREATE DATABASE ... WITH TEMPLATE ... requires closing connections
        # to the template database.
        self.connection.close()

        qn = self.connection.ops.quote_name
        source_database_name = self.connection.settings_dict['NAME']
        target_database_name = self.get_test_db_clone_settings(number)['NAME']

        with self._nodb_connection.cursor() as cursor:
            try:
                cursor.execute("CREATE DATABASE %s WITH TEMPLATE %s" % (
                    qn(target_database_name), qn(source_database_name)))
            except Exception as e:
                if keepdb:
                    return
                try:
                    if verbosity >= 1:
                        print("Destroying old test database '%s'..." % self.connection.alias)
                    cursor.execute("DROP DATABASE %s" % qn(target_database_name))
                    cursor.execute("CREATE DATABASE %s WITH TEMPLATE %s" % (
                        qn(target_database_name), qn(source_database_name)))
                except Exception as e:
                    sys.stderr.write("Got an error cloning the test database: %s\n" % e)
                    sys.exit(2)
+34 −0
Original line number Diff line number Diff line
import os
import shutil
import sys

from django.core.exceptions import ImproperlyConfigured
@@ -47,6 +48,39 @@ class DatabaseCreation(BaseDatabaseCreation):
                    sys.exit(1)
        return test_database_name

    def get_test_db_clone_settings(self, number):
        orig_settings_dict = self.connection.settings_dict
        source_database_name = orig_settings_dict['NAME']
        if self.connection.is_in_memory_db(source_database_name):
            return orig_settings_dict
        else:
            new_settings_dict = orig_settings_dict.copy()
            root, ext = os.path.splitext(orig_settings_dict['NAME'])
            new_settings_dict['NAME'] = '{}_{}.{}'.format(root, number, ext)
            return new_settings_dict

    def _clone_test_db(self, number, verbosity, keepdb=False):
        source_database_name = self.connection.settings_dict['NAME']
        target_database_name = self.get_test_db_clone_settings(number)['NAME']
        # Forking automatically makes a copy of an in-memory database.
        if not self.connection.is_in_memory_db(source_database_name):
            # Erase the old test database
            if os.access(target_database_name, os.F_OK):
                if keepdb:
                    return
                if verbosity >= 1:
                    print("Destroying old test database '%s'..." % target_database_name)
                try:
                    os.remove(target_database_name)
                except Exception as e:
                    sys.stderr.write("Got an error deleting the old test database: %s\n" % e)
                    sys.exit(2)
            try:
                shutil.copy(source_database_name, target_database_name)
            except Exception as e:
                sys.stderr.write("Got an error cloning the test database: %s\n" % e)
                sys.exit(2)

    def _destroy_test_db(self, test_database_name, verbosity):
        if test_database_name and not self.connection.is_in_memory_db(test_database_name):
            # Remove the SQLite database file
+48 −5
Original line number Diff line number Diff line
import collections
import ctypes
import itertools
import logging
import multiprocessing
@@ -158,12 +159,36 @@ def default_test_processes():
        return multiprocessing.cpu_count()


_worker_id = 0


def _init_worker(counter):
    """
    Switch to databases dedicated to this worker.

    This helper lives at module-level because of the multiprocessing module's
    requirements.
    """

    global _worker_id

    with counter.get_lock():
        counter.value += 1
        _worker_id = counter.value

    for alias in connections:
        connection = connections[alias]
        settings_dict = connection.creation.get_test_db_clone_settings(_worker_id)
        connection.settings_dict = settings_dict
        connection.close()


def _run_subsuite(args):
    """
    Run a suite of tests with a RemoteTestRunner and return a RemoteTestResult.

    This helper lives at module-level and its arguments are wrapped in a tuple
    because of idiosyncrasies of Python's multiprocessing module.
    because of the multiprocessing module's requirements.
    """
    subsuite_index, subsuite, failfast = args
    runner = RemoteTestRunner(failfast=failfast)
@@ -211,7 +236,11 @@ class ParallelTestSuite(unittest.TestSuite):
        if tblib is not None:
            tblib.pickling_support.install()

        pool = multiprocessing.Pool(processes=self.processes)
        counter = multiprocessing.Value(ctypes.c_int, 0)
        pool = multiprocessing.Pool(
            processes=self.processes,
            initializer=_init_worker,
            initargs=[counter])
        args = [
            (index, subsuite, self.failfast)
            for index, subsuite in enumerate(self.subsuites)
@@ -368,7 +397,7 @@ class DiscoverRunner(object):
    def setup_databases(self, **kwargs):
        return setup_databases(
            self.verbosity, self.interactive, self.keepdb, self.debug_sql,
            **kwargs
            self.parallel, **kwargs
        )

    def get_resultclass(self):
@@ -388,6 +417,13 @@ class DiscoverRunner(object):
        """
        for connection, old_name, destroy in old_config:
            if destroy:
                if self.parallel > 1:
                    for index in range(self.parallel):
                        connection.creation.destroy_test_db(
                            number=index + 1,
                            verbosity=self.verbosity,
                            keepdb=self.keepdb,
                        )
                connection.creation.destroy_test_db(old_name, self.verbosity, self.keepdb)

    def teardown_test_environment(self, **kwargs):
@@ -581,7 +617,7 @@ def get_unique_databases():
    return test_databases


def setup_databases(verbosity, interactive, keepdb=False, debug_sql=False, **kwargs):
def setup_databases(verbosity, interactive, keepdb=False, debug_sql=False, parallel=0, **kwargs):
    """
    Creates the test databases.
    """
@@ -599,11 +635,18 @@ def setup_databases(verbosity, interactive, keepdb=False, debug_sql=False, **kwa
            if first_alias is None:
                first_alias = alias
                connection.creation.create_test_db(
                    verbosity,
                    verbosity=verbosity,
                    autoclobber=not interactive,
                    keepdb=keepdb,
                    serialize=connection.settings_dict.get("TEST", {}).get("SERIALIZE", True),
                )
                if parallel > 1:
                    for index in range(parallel):
                        connection.creation.clone_test_db(
                            number=index + 1,
                            verbosity=verbosity,
                            keepdb=keepdb,
                        )
            # Configure all other connections as mirrors of the first one
            else:
                connections[alias].creation.set_as_test_mirror(
Loading