Loading django/db/backends/postgresql_psycopg2/schema.py +54 −1 Original line number Diff line number Diff line Loading @@ -2,4 +2,57 @@ from django.db.backends.schema import BaseDatabaseSchemaEditor class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): pass sql_create_sequence = "CREATE SEQUENCE %(sequence)s" sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE" sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s" def _alter_column_type_sql(self, table, column, type): """ Makes ALTER TYPE with SERIAL make sense. """ if type.lower() == "serial": sequence_name = "%s_%s_seq" % (table, column) return ( ( self.sql_alter_column_type % { "column": self.quote_name(column), "type": "integer", }, [], ), [ ( self.sql_delete_sequence % { "sequence": sequence_name, }, [], ), ( self.sql_create_sequence % { "sequence": sequence_name, }, [], ), ( self.sql_alter_column % { "table": table, "changes": self.sql_alter_column_default % { "column": column, "default": "nextval('%s')" % sequence_name, } }, [], ), ( self.sql_set_sequence_max % { "table": table, "column": column, "sequence": sequence_name, }, [], ), ], ) else: return super(DatabaseSchemaEditor, self)._alter_column_type_sql(table, column, type) django/db/backends/schema.py +55 −10 Original line number Diff line number Diff line Loading @@ -498,6 +498,18 @@ class BaseDatabaseSchemaEditor(object): "name": fk_name, } ) # Drop incoming FK constraints if we're a primary key and things are going # to change. if old_field.primary_key and new_field.primary_key and old_type != new_type: for rel in new_field.model._meta.get_all_related_objects(): rel_fk_names = self._constraint_names(rel.model, [rel.field.column], foreign_key=True) for fk_name in rel_fk_names: self.execute( self.sql_delete_fk % { "table": self.quote_name(rel.model._meta.db_table), "name": fk_name, } ) # Change check constraints? if old_db_params['check'] != new_db_params['check'] and old_db_params['check']: constraint_names = self._constraint_names(model, [old_field.column], check=True) Loading @@ -524,15 +536,12 @@ class BaseDatabaseSchemaEditor(object): }) # Next, start accumulating actions to do actions = [] post_actions = [] # Type change? if old_type != new_type: actions.append(( self.sql_alter_column_type % { "column": self.quote_name(new_field.column), "type": new_type, }, [], )) fragment, other_actions = self._alter_column_type_sql(model._meta.db_table, new_field.column, new_type) actions.append(fragment) post_actions.extend(other_actions) # Default change? old_default = self.effective_default(old_field) new_default = self.effective_default(new_field) Loading Loading @@ -596,6 +605,9 @@ class BaseDatabaseSchemaEditor(object): }, params, ) if post_actions: for sql, params in post_actions: self.execute(sql, params) # Added a unique? if not old_field.unique and new_field.unique: self.execute( Loading @@ -619,7 +631,7 @@ class BaseDatabaseSchemaEditor(object): # referring to us. rels_to_update = [] if old_field.primary_key and new_field.primary_key and old_type != new_type: rels_to_update.extend(model._meta.get_all_related_objects()) rels_to_update.extend(new_field.model._meta.get_all_related_objects()) # Changed to become primary key? # Note that we don't detect unsetting of a PK, as we assume another field # will always come along and replace it. Loading Loading @@ -647,8 +659,8 @@ class BaseDatabaseSchemaEditor(object): } ) # Update all referencing columns rels_to_update.extend(model._meta.get_all_related_objects()) # Handle out type alters on the other end of rels from the PK stuff above rels_to_update.extend(new_field.model._meta.get_all_related_objects()) # Handle our type alters on the other end of rels from the PK stuff above for rel in rels_to_update: rel_db_params = rel.field.db_parameters(connection=self.connection) rel_type = rel_db_params['type'] Loading @@ -672,6 +684,18 @@ class BaseDatabaseSchemaEditor(object): "to_column": self.quote_name(new_field.rel.get_related_field().column), } ) # Rebuild FKs that pointed to us if we previously had to drop them if old_field.primary_key and new_field.primary_key and old_type != new_type: for rel in new_field.model._meta.get_all_related_objects(): self.execute( self.sql_create_fk % { "table": self.quote_name(rel.model._meta.db_table), "name": self._create_index_name(rel.model, [rel.field.column], suffix="_fk"), "column": self.quote_name(rel.field.column), "to_table": self.quote_name(model._meta.db_table), "to_column": self.quote_name(new_field.column), } ) # Does it have check constraints we need to add? if old_db_params['check'] != new_db_params['check'] and new_db_params['check']: self.execute( Loading @@ -686,6 +710,27 @@ class BaseDatabaseSchemaEditor(object): if self.connection.features.connection_persists_old_columns: self.connection.close() def _alter_column_type_sql(self, table, column, type): """ Hook to specialise column type alteration for different backends, for cases when a creation type is different to an alteration type (e.g. SERIAL in PostgreSQL, PostGIS fields). Should return two things; an SQL fragment of (sql, params) to insert into an ALTER TABLE statement, and a list of extra (sql, params) tuples to run once the field is altered. """ return ( ( self.sql_alter_column_type % { "column": self.quote_name(column), "type": type, }, [], ), [], ) def _alter_many_to_many(self, model, old_field, new_field, strict): """ Alters M2Ms to repoint their to= endpoints. Loading django/db/migrations/operations/fields.py +7 −5 Original line number Diff line number Diff line Loading @@ -24,9 +24,10 @@ class AddField(Operation): state.models[app_label, self.model_name.lower()].fields.append((self.name, field)) def database_forwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0]) schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0]) def database_backwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.render().get_model(app_label, self.model_name) Loading Loading @@ -73,9 +74,10 @@ class RemoveField(Operation): schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0]) def database_backwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0]) schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0]) def describe(self): return "Remove field %s from %s" % (self.name, self.model_name) Loading Loading @@ -107,7 +109,7 @@ class AlterField(Operation): to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.alter_field( to_model, from_model, from_model._meta.get_field_by_name(self.name)[0], to_model._meta.get_field_by_name(self.name)[0], ) Loading Loading @@ -153,7 +155,7 @@ class RenameField(Operation): to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.alter_field( to_model, from_model, from_model._meta.get_field_by_name(self.old_name)[0], to_model._meta.get_field_by_name(self.new_name)[0], ) Loading @@ -163,7 +165,7 @@ class RenameField(Operation): to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.alter_field( to_model, from_model, from_model._meta.get_field_by_name(self.new_name)[0], to_model._meta.get_field_by_name(self.old_name)[0], ) Loading tests/migrations/test_operations.py +56 −1 Original line number Diff line number Diff line import unittest from django.db import connection, models, migrations, router from django.db.models.fields import NOT_PROVIDED from django.db.transaction import atomic Loading @@ -13,7 +14,7 @@ class OperationTests(MigrationTestBase): both forwards and backwards. """ def set_up_test_model(self, app_label, second_model=False): def set_up_test_model(self, app_label, second_model=False, related_model=False): """ Creates a test model state and database table. """ Loading @@ -38,6 +39,14 @@ class OperationTests(MigrationTestBase): )] if second_model: operations.append(migrations.CreateModel("Stable", [("id", models.AutoField(primary_key=True))])) if related_model: operations.append(migrations.CreateModel( "Rider", [ ("id", models.AutoField(primary_key=True)), ("pony", models.ForeignKey("Pony")), ], )) project_state = ProjectState() for operation in operations: operation.state_forwards(app_label, project_state) Loading Loading @@ -269,6 +278,52 @@ class OperationTests(MigrationTestBase): operation.database_backwards("test_alfl", editor, new_state, project_state) self.assertColumnNotNull("test_alfl_pony", "pink") def test_alter_field_pk(self): """ Tests the AlterField operation on primary keys (for things like PostgreSQL's SERIAL weirdness) """ project_state = self.set_up_test_model("test_alflpk") # Test the state alteration operation = migrations.AlterField("Pony", "id", models.IntegerField(primary_key=True)) new_state = project_state.clone() operation.state_forwards("test_alflpk", new_state) self.assertIsInstance(project_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.AutoField) self.assertIsInstance(new_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.IntegerField) # Test the database alteration with connection.schema_editor() as editor: operation.database_forwards("test_alflpk", editor, project_state, new_state) # And test reversal with connection.schema_editor() as editor: operation.database_backwards("test_alflpk", editor, new_state, project_state) @unittest.skipUnless(connection.features.supports_foreign_keys, "No FK support") def test_alter_field_pk_fk(self): """ Tests the AlterField operation on primary keys changes any FKs pointing to it. """ project_state = self.set_up_test_model("test_alflpkfk", related_model=True) # Test the state alteration operation = migrations.AlterField("Pony", "id", models.FloatField(primary_key=True)) new_state = project_state.clone() operation.state_forwards("test_alflpkfk", new_state) self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField) self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField) # Test the database alteration id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0] self.assertEqual(id_type, fk_type) with connection.schema_editor() as editor: operation.database_forwards("test_alflpkfk", editor, project_state, new_state) id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0] self.assertEqual(id_type, fk_type) # And test reversal with connection.schema_editor() as editor: operation.database_backwards("test_alflpkfk", editor, new_state, project_state) id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0] self.assertEqual(id_type, fk_type) def test_rename_field(self): """ Tests the RenameField operation. Loading tests/schema/tests.py +1 −0 Original line number Diff line number Diff line Loading @@ -636,6 +636,7 @@ class SchemaTests(TransactionTestCase): # Alter to change the PK new_field = SlugField(primary_key=True) new_field.set_attributes_from_name("slug") new_field.model = Tag with connection.schema_editor() as editor: editor.remove_field(Tag, Tag._meta.get_field_by_name("id")[0]) editor.alter_field( Loading Loading
django/db/backends/postgresql_psycopg2/schema.py +54 −1 Original line number Diff line number Diff line Loading @@ -2,4 +2,57 @@ from django.db.backends.schema import BaseDatabaseSchemaEditor class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): pass sql_create_sequence = "CREATE SEQUENCE %(sequence)s" sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE" sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s" def _alter_column_type_sql(self, table, column, type): """ Makes ALTER TYPE with SERIAL make sense. """ if type.lower() == "serial": sequence_name = "%s_%s_seq" % (table, column) return ( ( self.sql_alter_column_type % { "column": self.quote_name(column), "type": "integer", }, [], ), [ ( self.sql_delete_sequence % { "sequence": sequence_name, }, [], ), ( self.sql_create_sequence % { "sequence": sequence_name, }, [], ), ( self.sql_alter_column % { "table": table, "changes": self.sql_alter_column_default % { "column": column, "default": "nextval('%s')" % sequence_name, } }, [], ), ( self.sql_set_sequence_max % { "table": table, "column": column, "sequence": sequence_name, }, [], ), ], ) else: return super(DatabaseSchemaEditor, self)._alter_column_type_sql(table, column, type)
django/db/backends/schema.py +55 −10 Original line number Diff line number Diff line Loading @@ -498,6 +498,18 @@ class BaseDatabaseSchemaEditor(object): "name": fk_name, } ) # Drop incoming FK constraints if we're a primary key and things are going # to change. if old_field.primary_key and new_field.primary_key and old_type != new_type: for rel in new_field.model._meta.get_all_related_objects(): rel_fk_names = self._constraint_names(rel.model, [rel.field.column], foreign_key=True) for fk_name in rel_fk_names: self.execute( self.sql_delete_fk % { "table": self.quote_name(rel.model._meta.db_table), "name": fk_name, } ) # Change check constraints? if old_db_params['check'] != new_db_params['check'] and old_db_params['check']: constraint_names = self._constraint_names(model, [old_field.column], check=True) Loading @@ -524,15 +536,12 @@ class BaseDatabaseSchemaEditor(object): }) # Next, start accumulating actions to do actions = [] post_actions = [] # Type change? if old_type != new_type: actions.append(( self.sql_alter_column_type % { "column": self.quote_name(new_field.column), "type": new_type, }, [], )) fragment, other_actions = self._alter_column_type_sql(model._meta.db_table, new_field.column, new_type) actions.append(fragment) post_actions.extend(other_actions) # Default change? old_default = self.effective_default(old_field) new_default = self.effective_default(new_field) Loading Loading @@ -596,6 +605,9 @@ class BaseDatabaseSchemaEditor(object): }, params, ) if post_actions: for sql, params in post_actions: self.execute(sql, params) # Added a unique? if not old_field.unique and new_field.unique: self.execute( Loading @@ -619,7 +631,7 @@ class BaseDatabaseSchemaEditor(object): # referring to us. rels_to_update = [] if old_field.primary_key and new_field.primary_key and old_type != new_type: rels_to_update.extend(model._meta.get_all_related_objects()) rels_to_update.extend(new_field.model._meta.get_all_related_objects()) # Changed to become primary key? # Note that we don't detect unsetting of a PK, as we assume another field # will always come along and replace it. Loading Loading @@ -647,8 +659,8 @@ class BaseDatabaseSchemaEditor(object): } ) # Update all referencing columns rels_to_update.extend(model._meta.get_all_related_objects()) # Handle out type alters on the other end of rels from the PK stuff above rels_to_update.extend(new_field.model._meta.get_all_related_objects()) # Handle our type alters on the other end of rels from the PK stuff above for rel in rels_to_update: rel_db_params = rel.field.db_parameters(connection=self.connection) rel_type = rel_db_params['type'] Loading @@ -672,6 +684,18 @@ class BaseDatabaseSchemaEditor(object): "to_column": self.quote_name(new_field.rel.get_related_field().column), } ) # Rebuild FKs that pointed to us if we previously had to drop them if old_field.primary_key and new_field.primary_key and old_type != new_type: for rel in new_field.model._meta.get_all_related_objects(): self.execute( self.sql_create_fk % { "table": self.quote_name(rel.model._meta.db_table), "name": self._create_index_name(rel.model, [rel.field.column], suffix="_fk"), "column": self.quote_name(rel.field.column), "to_table": self.quote_name(model._meta.db_table), "to_column": self.quote_name(new_field.column), } ) # Does it have check constraints we need to add? if old_db_params['check'] != new_db_params['check'] and new_db_params['check']: self.execute( Loading @@ -686,6 +710,27 @@ class BaseDatabaseSchemaEditor(object): if self.connection.features.connection_persists_old_columns: self.connection.close() def _alter_column_type_sql(self, table, column, type): """ Hook to specialise column type alteration for different backends, for cases when a creation type is different to an alteration type (e.g. SERIAL in PostgreSQL, PostGIS fields). Should return two things; an SQL fragment of (sql, params) to insert into an ALTER TABLE statement, and a list of extra (sql, params) tuples to run once the field is altered. """ return ( ( self.sql_alter_column_type % { "column": self.quote_name(column), "type": type, }, [], ), [], ) def _alter_many_to_many(self, model, old_field, new_field, strict): """ Alters M2Ms to repoint their to= endpoints. Loading
django/db/migrations/operations/fields.py +7 −5 Original line number Diff line number Diff line Loading @@ -24,9 +24,10 @@ class AddField(Operation): state.models[app_label, self.model_name.lower()].fields.append((self.name, field)) def database_forwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0]) schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0]) def database_backwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.render().get_model(app_label, self.model_name) Loading Loading @@ -73,9 +74,10 @@ class RemoveField(Operation): schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0]) def database_backwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0]) schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0]) def describe(self): return "Remove field %s from %s" % (self.name, self.model_name) Loading Loading @@ -107,7 +109,7 @@ class AlterField(Operation): to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.alter_field( to_model, from_model, from_model._meta.get_field_by_name(self.name)[0], to_model._meta.get_field_by_name(self.name)[0], ) Loading Loading @@ -153,7 +155,7 @@ class RenameField(Operation): to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.alter_field( to_model, from_model, from_model._meta.get_field_by_name(self.old_name)[0], to_model._meta.get_field_by_name(self.new_name)[0], ) Loading @@ -163,7 +165,7 @@ class RenameField(Operation): to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.alter_field( to_model, from_model, from_model._meta.get_field_by_name(self.new_name)[0], to_model._meta.get_field_by_name(self.old_name)[0], ) Loading
tests/migrations/test_operations.py +56 −1 Original line number Diff line number Diff line import unittest from django.db import connection, models, migrations, router from django.db.models.fields import NOT_PROVIDED from django.db.transaction import atomic Loading @@ -13,7 +14,7 @@ class OperationTests(MigrationTestBase): both forwards and backwards. """ def set_up_test_model(self, app_label, second_model=False): def set_up_test_model(self, app_label, second_model=False, related_model=False): """ Creates a test model state and database table. """ Loading @@ -38,6 +39,14 @@ class OperationTests(MigrationTestBase): )] if second_model: operations.append(migrations.CreateModel("Stable", [("id", models.AutoField(primary_key=True))])) if related_model: operations.append(migrations.CreateModel( "Rider", [ ("id", models.AutoField(primary_key=True)), ("pony", models.ForeignKey("Pony")), ], )) project_state = ProjectState() for operation in operations: operation.state_forwards(app_label, project_state) Loading Loading @@ -269,6 +278,52 @@ class OperationTests(MigrationTestBase): operation.database_backwards("test_alfl", editor, new_state, project_state) self.assertColumnNotNull("test_alfl_pony", "pink") def test_alter_field_pk(self): """ Tests the AlterField operation on primary keys (for things like PostgreSQL's SERIAL weirdness) """ project_state = self.set_up_test_model("test_alflpk") # Test the state alteration operation = migrations.AlterField("Pony", "id", models.IntegerField(primary_key=True)) new_state = project_state.clone() operation.state_forwards("test_alflpk", new_state) self.assertIsInstance(project_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.AutoField) self.assertIsInstance(new_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.IntegerField) # Test the database alteration with connection.schema_editor() as editor: operation.database_forwards("test_alflpk", editor, project_state, new_state) # And test reversal with connection.schema_editor() as editor: operation.database_backwards("test_alflpk", editor, new_state, project_state) @unittest.skipUnless(connection.features.supports_foreign_keys, "No FK support") def test_alter_field_pk_fk(self): """ Tests the AlterField operation on primary keys changes any FKs pointing to it. """ project_state = self.set_up_test_model("test_alflpkfk", related_model=True) # Test the state alteration operation = migrations.AlterField("Pony", "id", models.FloatField(primary_key=True)) new_state = project_state.clone() operation.state_forwards("test_alflpkfk", new_state) self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField) self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField) # Test the database alteration id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0] self.assertEqual(id_type, fk_type) with connection.schema_editor() as editor: operation.database_forwards("test_alflpkfk", editor, project_state, new_state) id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0] self.assertEqual(id_type, fk_type) # And test reversal with connection.schema_editor() as editor: operation.database_backwards("test_alflpkfk", editor, new_state, project_state) id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0] self.assertEqual(id_type, fk_type) def test_rename_field(self): """ Tests the RenameField operation. Loading
tests/schema/tests.py +1 −0 Original line number Diff line number Diff line Loading @@ -636,6 +636,7 @@ class SchemaTests(TransactionTestCase): # Alter to change the PK new_field = SlugField(primary_key=True) new_field.set_attributes_from_name("slug") new_field.model = Tag with connection.schema_editor() as editor: editor.remove_field(Tag, Tag._meta.get_field_by_name("id")[0]) editor.alter_field( Loading