Commit 8c12d51e authored by Andrew Godwin's avatar Andrew Godwin
Browse files

Fixed #22487: Optional rollback emulation for migrated apps

parent 8721adcb
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -578,6 +578,10 @@ DEFAULT_EXCEPTION_REPORTER_FILTER = 'django.views.debug.SafeExceptionReporterFil
# The name of the class to use to run the test suite
TEST_RUNNER = 'django.test.runner.DiscoverRunner'

# Apps that don't need to be serialized at test database creation time
# (only apps with migrations are to start with)
TEST_NON_SERIALIZED_APPS = []

############
# FIXTURES #
############
+4 −5
Original line number Diff line number Diff line
@@ -22,10 +22,9 @@ class Command(NoArgsCommand):
        make_option('--no-initial-data', action='store_false', dest='load_initial_data', default=True,
            help='Tells Django not to load any initial data after database synchronization.'),
    )
    help = ('Returns the database to the state it was in immediately after '
           'migrate was first executed. This means that all data will be removed '
           'from the database, any post-migration handlers will be '
           're-executed, and the initial_data fixture will be re-installed.')
    help = ('Removes ALL DATA from the database, including data added during '
           'migrations. Unmigrated apps will also have their initial_data '
           'fixture reloaded. Does not achieve a "fresh install" state.')

    def handle_noargs(self, **options):
        database = options.get('database')
@@ -54,7 +53,7 @@ class Command(NoArgsCommand):
        if interactive:
            confirm = input("""You have requested a flush of the database.
This will IRREVERSIBLY DESTROY all data currently in the %r database,
and return each table to a fresh state.
and return each table to an empty state.
Are you sure you want to do this?

    Type 'yes' to continue, or 'no' to cancel: """ % connection.settings_dict['NAME'])
+1 −1
Original line number Diff line number Diff line
@@ -27,7 +27,7 @@ class Command(BaseCommand):
        addrport = options.get('addrport')

        # Create a test database.
        db_name = connection.creation.create_test_db(verbosity=verbosity, autoclobber=not interactive)
        db_name = connection.creation.create_test_db(verbosity=verbosity, autoclobber=not interactive, serialize=False)

        # Import the fixture data into the test database.
        call_command('loaddata', *fixture_labels, **{'verbosity': verbosity})
+63 −14
Original line number Diff line number Diff line
@@ -7,6 +7,11 @@ from django.db.utils import load_backend
from django.utils.encoding import force_bytes
from django.utils.functional import cached_property
from django.utils.six.moves import input
from django.utils.six import StringIO
from django.core.management.commands.dumpdata import sort_dependencies
from django.db import router
from django.apps import apps
from django.core import serializers

from .utils import truncate_name

@@ -332,7 +337,7 @@ class BaseDatabaseCreation(object):
            ";",
        ]

    def create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
    def create_test_db(self, verbosity=1, autoclobber=False, keepdb=False, serialize=True):
        """
        Creates a test database, prompting the user for confirmation if the
        database already exists. Returns the name of the test database created.
@@ -364,25 +369,31 @@ class BaseDatabaseCreation(object):
        settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
        self.connection.settings_dict["NAME"] = test_database_name

        # Report migrate messages at one level lower than that requested.
        # We report migrate messages at one level lower than that requested.
        # This ensures we don't get flooded with messages during testing
        # (unless you really ask to be flooded)
        call_command('migrate',
        # (unless you really ask to be flooded).
        call_command(
            'migrate',
            verbosity=max(verbosity - 1, 0),
            interactive=False,
            database=self.connection.alias,
            load_initial_data=False,
            test_database=True)

        # We need to then do a flush to ensure that any data installed by
        # custom SQL has been removed. The only test data should come from
        # test fixtures, or autogenerated from post_migrate triggers.
        # This has the side effect of loading initial data (which was
        # intentionally skipped in the syncdb).
        call_command('flush',
            test_database=True,
        )

        # We then serialize the current state of the database into a string
        # and store it on the connection. This slightly horrific process is so people
        # who are testing on databases without transactions or who are using
        # a TransactionTestCase still get a clean database on every test run.
        if serialize:
            self.connection._test_serialized_contents = self.serialize_db_to_string()

        # Finally, we flush the database to clean
        call_command(
            'flush',
            verbosity=max(verbosity - 1, 0),
            interactive=False,
            database=self.connection.alias)
            database=self.connection.alias
        )

        call_command('createcachetable', database=self.connection.alias)

@@ -391,6 +402,44 @@ class BaseDatabaseCreation(object):

        return test_database_name

    def serialize_db_to_string(self):
        """
        Serializes all data in the database into a JSON string.
        Designed only for test runner usage; will not handle large
        amounts of data.
        """
        # Build list of all apps to serialize
        from django.db.migrations.loader import MigrationLoader
        loader = MigrationLoader(self.connection)
        app_list = []
        for app_config in apps.get_app_configs():
            if (
                app_config.models_module is not None and
                app_config.label in loader.migrated_apps and
                app_config.name not in settings.TEST_NON_SERIALIZED_APPS
            ):
                app_list.append((app_config, None))
        # Make a function to iteratively return every object
        def get_objects():
            for model in sort_dependencies(app_list):
                if not model._meta.proxy and router.allow_migrate(self.connection.alias, model):
                    queryset = model._default_manager.using(self.connection.alias).order_by(model._meta.pk.name)
                    for obj in queryset.iterator():
                        yield obj
        # Serialise to a string
        out = StringIO()
        serializers.serialize("json", get_objects(), indent=None, stream=out)
        return out.getvalue()

    def deserialize_db_from_string(self, data):
        """
        Reloads the database with data from a string generated by
        the serialize_db_to_string method.
        """
        data = StringIO(data)
        for obj in serializers.deserialize("json", data, using=self.connection.alias):
            obj.save()

    def _get_test_db_name(self):
        """
        Internal implementation - returns the name of the test DB that will be
+5 −1
Original line number Diff line number Diff line
@@ -298,7 +298,11 @@ def setup_databases(verbosity, interactive, keepdb=False, **kwargs):
            connection = connections[alias]
            if test_db_name is None:
                test_db_name = connection.creation.create_test_db(
                    verbosity, autoclobber=not interactive, keepdb=keepdb)
                    verbosity,
                    autoclobber=not interactive,
                    keepdb=keepdb,
                    serialize=connection.settings_dict.get("TEST_SERIALIZE", True),
                )
                destroy = True
            else:
                connection.settings_dict['NAME'] = test_db_name
Loading