Commit be87f0b0 authored by Karen Tracey's avatar Karen Tracey
Browse files

Fixed #3615: Added support for loading fixtures with forward references on...

Fixed #3615: Added support for loading fixtures with forward references on database backends (such as MySQL/InnoDB) that do not support deferred constraint checking. Many thanks to jsdalton for coming up with a clever solution to this long-standing issue, and to jacob, ramiro, graham_king, and russellm for review/testing. (Apologies if I missed anyone else who helped here.)

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16590 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent e3c89346
Loading
Loading
Loading
Loading
+18 −6
Original line number Diff line number Diff line
# This is necessary in Python 2.5 to enable the with statement, in 2.6
# and up it is no longer necessary.
from __future__ import with_statement

import sys
import os
import gzip
@@ -166,12 +170,20 @@ class Command(BaseCommand):
                                    (format, fixture_name, humanize(fixture_dir)))
                            try:
                                objects = serializers.deserialize(format, fixture, using=using)

                                with connection.constraint_checks_disabled():
                                    for obj in objects:
                                        objects_in_fixture += 1
                                        if router.allow_syncdb(using, obj.object.__class__):
                                            loaded_objects_in_fixture += 1
                                            models.add(obj.object.__class__)
                                            obj.save(using=using)

                                # Since we disabled constraint checks, we must manually check for
                                # any invalid keys that might have been added
                                table_names = [model._meta.db_table for model in models]
                                connection.check_constraints(table_names=table_names)

                                loaded_object_count += loaded_objects_in_fixture
                                fixture_object_count += objects_in_fixture
                                label_found = True
+43 −0
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ try:
except ImportError:
    import dummy_thread as thread
from threading import local
from contextlib import contextmanager

from django.conf import settings
from django.db import DEFAULT_DB_ALIAS
@@ -238,6 +239,35 @@ class BaseDatabaseWrapper(local):
        if self.savepoint_state:
            self._savepoint_commit(sid)

    @contextmanager
    def constraint_checks_disabled(self):
        disabled = self.disable_constraint_checking()
        try:
            yield
        finally:
            if disabled:
                self.enable_constraint_checking()

    def disable_constraint_checking(self):
        """
        Backends can implement as needed to temporarily disable foreign key constraint
        checking.
        """
        pass

    def enable_constraint_checking(self):
        """
        Backends can implement as needed to re-enable foreign key constraint checking.
        """
        pass

    def check_constraints(self, table_names=None):
        """
        Backends can override this method if they can apply constraint checking (e.g. via "SET CONSTRAINTS
        ALL IMMEDIATE"). Should raise an IntegrityError if any invalid foreign key references are encountered.
        """
        pass

    def close(self):
        if self.connection is not None:
            self.connection.close()
@@ -869,6 +899,19 @@ class BaseDatabaseIntrospection(object):

        return sequence_list

    def get_key_columns(self, cursor, table_name):
        """
        Backends can override this to return a list of (column_name, referenced_table_name,
        referenced_column_name) for all key columns in given table.
        """
        raise NotImplementedError

    def get_primary_key_column(self, cursor, table_name):
        """
        Backends can override this to return the column name of the primary key for the given table.
        """
        raise NotImplementedError

class BaseDatabaseClient(object):
    """
    This class encapsulates all backend-specific methods for opening a
+1 −0
Original line number Diff line number Diff line
@@ -34,6 +34,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
    get_table_description = complain
    get_relations = complain
    get_indexes = complain
    get_key_columns = complain

class DatabaseWrapper(BaseDatabaseWrapper):
    operators = {}
+49 −0
Original line number Diff line number Diff line
@@ -349,3 +349,52 @@ class DatabaseWrapper(BaseDatabaseWrapper):
                raise Exception('Unable to determine MySQL version from version string %r' % self.connection.get_server_info())
            self.server_version = tuple([int(x) for x in m.groups()])
        return self.server_version

    def disable_constraint_checking(self):
        """
        Disables foreign key checks, primarily for use in adding rows with forward references. Always returns True,
        to indicate constraint checks need to be re-enabled.
        """
        self.cursor().execute('SET foreign_key_checks=0')
        return True

    def enable_constraint_checking(self):
        """
        Re-enable foreign key checks after they have been disabled.
        """
        self.cursor().execute('SET foreign_key_checks=1')

    def check_constraints(self, table_names=None):
        """
        Checks each table name in table-names for rows with invalid foreign key references. This method is
        intended to be used in conjunction with `disable_constraint_checking()` and `enable_constraint_checking()`, to
        determine if rows with invalid references were entered while constraint checks were off.

        Raises an IntegrityError on the first invalid foreign key reference encountered (if any) and provides
        detailed information about the invalid reference in the error message.

        Backends can override this method if they can more directly apply constraint checking (e.g. via "SET CONSTRAINTS
        ALL IMMEDIATE")
        """
        cursor = self.cursor()
        if table_names is None:
            table_names = self.introspection.get_table_list(cursor)
        for table_name in table_names:
            primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
            if not primary_key_column_name:
                continue
            key_columns = self.introspection.get_key_columns(cursor, table_name)
            for column_name, referenced_table_name, referenced_column_name in key_columns:
                cursor.execute("""
                    SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
                    LEFT JOIN `%s` as REFERRED
                    ON (REFERRING.`%s` = REFERRED.`%s`)
                    WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL"""
                    % (primary_key_column_name, column_name, table_name, referenced_table_name,
                    column_name, referenced_column_name, column_name, referenced_column_name))
                for bad_row in cursor.fetchall():
                    raise utils.IntegrityError("The row in table '%s' with primary key '%s' has an invalid "
                        "foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s."
                        % (table_name, bad_row[0],
                        table_name, column_name, bad_row[1],
                        referenced_table_name, referenced_column_name))
+24 −10
Original line number Diff line number Diff line
@@ -51,10 +51,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
        representing all relationships to the given table. Indexes are 0-based.
        """
        my_field_dict = self._name_to_index(cursor, table_name)
        constraints = []
        constraints = self.get_key_columns(cursor, table_name)
        relations = {}
        for my_fieldname, other_table, other_field in constraints:
            other_field_index = self._name_to_index(cursor, other_table)[other_field]
            my_field_index = my_field_dict[my_fieldname]
            relations[my_field_index] = (other_field_index, other_table)
        return relations

    def get_key_columns(self, cursor, table_name):
        """
        Returns a list of (column_name, referenced_table_name, referenced_column_name) for all
        key columns in given table.
        """
        key_columns = []
        try:
            # This should work for MySQL 5.0.
            cursor.execute("""
                SELECT column_name, referenced_table_name, referenced_column_name
                FROM information_schema.key_column_usage
@@ -62,7 +73,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
                    AND table_schema = DATABASE()
                    AND referenced_table_name IS NOT NULL
                    AND referenced_column_name IS NOT NULL""", [table_name])
            constraints.extend(cursor.fetchall())
            key_columns.extend(cursor.fetchall())
        except (ProgrammingError, OperationalError):
            # Fall back to "SHOW CREATE TABLE", for previous MySQL versions.
            # Go through all constraints and save the equal matches.
@@ -74,14 +85,17 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
                    if match == None:
                        break
                    pos = match.end()
                    constraints.append(match.groups())
                    key_columns.append(match.groups())
        return key_columns

        for my_fieldname, other_table, other_field in constraints:
            other_field_index = self._name_to_index(cursor, other_table)[other_field]
            my_field_index = my_field_dict[my_fieldname]
            relations[my_field_index] = (other_field_index, other_table)

        return relations
    def get_primary_key_column(self, cursor, table_name):
        """
        Returns the name of the primary key column for the given table
        """
        for column in self.get_indexes(cursor, table_name).iteritems():
            if column[1]['primary_key']:
                return column[0]
        return None

    def get_indexes(self, cursor, table_name):
        """
Loading