Loading django/contrib/gis/db/backends/postgis/schema.py +27 −22 Original line number Diff line number Diff line Loading @@ -10,36 +10,27 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): sql_alter_geometry_column_not_null = "ALTER TABLE %(table)s ALTER COLUMN %(column)s SET NOT NULL" sql_add_spatial_index = "CREATE INDEX %(index)s ON %(table)s USING %(index_type)s (%(column)s %(ops)s)" def __init__(self, *args, **kwargs): super(PostGISSchemaEditor, self).__init__(*args, **kwargs) self.geometry_sql = [] def geo_quote_name(self, name): return self.connection.ops.geo_quote_name(name) def create_model(self, model): def column_sql(self, model, field, include_default=False): from django.contrib.gis.db.models.fields import GeometryField # Do model creation first super(PostGISSchemaEditor, self).create_model(model) # Now add any spatial field SQL sqls = [] for field in model._meta.local_fields: if isinstance(field, GeometryField): sqls.extend(self.spatial_field_sql(model, field)) for sql in sqls: self.execute(sql) def spatial_field_sql(self, model, field): """ Takes a GeometryField and returns a list of SQL to execute to create its spatial indexes. """ output = [] if not isinstance(field, GeometryField): return super(PostGISSchemaEditor, self).column_sql(model, field, include_default) if field.geography or self.connection.ops.geometry: # Geography and Geometry (PostGIS 2.0+) columns are # created normally. pass column_sql = super(PostGISSchemaEditor, self).column_sql(model, field, include_default) else: column_sql = None, None # Geometry columns are created by the `AddGeometryColumn` # stored procedure. output.append( self.geometry_sql.append( self.sql_add_geometry_column % { "table": self.geo_quote_name(model._meta.db_table), "column": self.geo_quote_name(field.column), Loading @@ -48,8 +39,9 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): "dim": field.dim, } ) if not field.null: output.append( self.geometry_sql.append( self.sql_alter_geometry_column_not_null % { "table": self.quote_name(model._meta.db_table), "column": self.quote_name(field.column), Loading @@ -72,7 +64,7 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): index_ops = '' else: index_ops = self.geom_index_ops output.append( self.geometry_sql.append( self.sql_add_spatial_index % { "index": self.quote_name('%s_%s_id' % (model._meta.db_table, field.column)), "table": self.quote_name(model._meta.db_table), Loading @@ -81,5 +73,18 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): "ops": index_ops, } ) return column_sql return output def create_model(self, model): super(PostGISSchemaEditor, self).create_model(model) # Create geometry columns for sql in self.geometry_sql: self.execute(sql) self.geometry_sql = [] def add_field(self, model, field): super(PostGISSchemaEditor, self).add_field(model, field) # Create geometry columns for sql in self.geometry_sql: self.execute(sql) self.geometry_sql = [] django/contrib/gis/tests/gis_migrations/test_operations.py 0 → 100644 +67 −0 Original line number Diff line number Diff line from __future__ import unicode_literals from unittest import skipUnless from django.contrib.gis.tests.utils import HAS_SPATIAL_DB from django.db import connection, migrations, models from django.db.migrations.migration import Migration from django.db.migrations.state import ProjectState from django.test import TransactionTestCase if HAS_SPATIAL_DB: from django.contrib.gis.db.models import fields @skipUnless(HAS_SPATIAL_DB, "Spatial db is required.") class OperationTests(TransactionTestCase): available_apps = ["django.contrib.gis.tests.gis_migrations"] def get_table_description(self, table): with connection.cursor() as cursor: return connection.introspection.get_table_description(cursor, table) def assertColumnExists(self, table, column): self.assertIn(column, [c.name for c in self.get_table_description(table)]) def apply_operations(self, app_label, project_state, operations): migration = Migration('name', app_label) migration.operations = operations with connection.schema_editor() as editor: return migration.apply(project_state, editor) def set_up_test_model(self): operations = [migrations.CreateModel( "Neighborhood", [ ("id", models.AutoField(primary_key=True)), ('name', models.CharField(max_length=100, unique=True)), ('geom', fields.MultiPolygonField(srid=4326, null=True)), ], )] return self.apply_operations('gis', ProjectState(), operations) def test_add_gis_field(self): """ Tests the AddField operation with a GIS-enabled column. """ project_state = self.set_up_test_model() operation = migrations.AddField( "Neighborhood", "path", fields.LineStringField(srid=4326, null=True, blank=True), ) new_state = project_state.clone() operation.state_forwards("gis", new_state) with connection.schema_editor() as editor: operation.database_forwards("gis", editor, project_state, new_state) self.assertColumnExists("gis_neighborhood", "path") # Test GeometryColumns when available try: from django.contrib.gis.models import GeometryColumns except ImportError: return self.assertEqual( GeometryColumns.objects.filter(**{GeometryColumns.table_name_col(): "gis_neighborhood"}).count(), 2 ) Loading
django/contrib/gis/db/backends/postgis/schema.py +27 −22 Original line number Diff line number Diff line Loading @@ -10,36 +10,27 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): sql_alter_geometry_column_not_null = "ALTER TABLE %(table)s ALTER COLUMN %(column)s SET NOT NULL" sql_add_spatial_index = "CREATE INDEX %(index)s ON %(table)s USING %(index_type)s (%(column)s %(ops)s)" def __init__(self, *args, **kwargs): super(PostGISSchemaEditor, self).__init__(*args, **kwargs) self.geometry_sql = [] def geo_quote_name(self, name): return self.connection.ops.geo_quote_name(name) def create_model(self, model): def column_sql(self, model, field, include_default=False): from django.contrib.gis.db.models.fields import GeometryField # Do model creation first super(PostGISSchemaEditor, self).create_model(model) # Now add any spatial field SQL sqls = [] for field in model._meta.local_fields: if isinstance(field, GeometryField): sqls.extend(self.spatial_field_sql(model, field)) for sql in sqls: self.execute(sql) def spatial_field_sql(self, model, field): """ Takes a GeometryField and returns a list of SQL to execute to create its spatial indexes. """ output = [] if not isinstance(field, GeometryField): return super(PostGISSchemaEditor, self).column_sql(model, field, include_default) if field.geography or self.connection.ops.geometry: # Geography and Geometry (PostGIS 2.0+) columns are # created normally. pass column_sql = super(PostGISSchemaEditor, self).column_sql(model, field, include_default) else: column_sql = None, None # Geometry columns are created by the `AddGeometryColumn` # stored procedure. output.append( self.geometry_sql.append( self.sql_add_geometry_column % { "table": self.geo_quote_name(model._meta.db_table), "column": self.geo_quote_name(field.column), Loading @@ -48,8 +39,9 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): "dim": field.dim, } ) if not field.null: output.append( self.geometry_sql.append( self.sql_alter_geometry_column_not_null % { "table": self.quote_name(model._meta.db_table), "column": self.quote_name(field.column), Loading @@ -72,7 +64,7 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): index_ops = '' else: index_ops = self.geom_index_ops output.append( self.geometry_sql.append( self.sql_add_spatial_index % { "index": self.quote_name('%s_%s_id' % (model._meta.db_table, field.column)), "table": self.quote_name(model._meta.db_table), Loading @@ -81,5 +73,18 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): "ops": index_ops, } ) return column_sql return output def create_model(self, model): super(PostGISSchemaEditor, self).create_model(model) # Create geometry columns for sql in self.geometry_sql: self.execute(sql) self.geometry_sql = [] def add_field(self, model, field): super(PostGISSchemaEditor, self).add_field(model, field) # Create geometry columns for sql in self.geometry_sql: self.execute(sql) self.geometry_sql = []
django/contrib/gis/tests/gis_migrations/test_operations.py 0 → 100644 +67 −0 Original line number Diff line number Diff line from __future__ import unicode_literals from unittest import skipUnless from django.contrib.gis.tests.utils import HAS_SPATIAL_DB from django.db import connection, migrations, models from django.db.migrations.migration import Migration from django.db.migrations.state import ProjectState from django.test import TransactionTestCase if HAS_SPATIAL_DB: from django.contrib.gis.db.models import fields @skipUnless(HAS_SPATIAL_DB, "Spatial db is required.") class OperationTests(TransactionTestCase): available_apps = ["django.contrib.gis.tests.gis_migrations"] def get_table_description(self, table): with connection.cursor() as cursor: return connection.introspection.get_table_description(cursor, table) def assertColumnExists(self, table, column): self.assertIn(column, [c.name for c in self.get_table_description(table)]) def apply_operations(self, app_label, project_state, operations): migration = Migration('name', app_label) migration.operations = operations with connection.schema_editor() as editor: return migration.apply(project_state, editor) def set_up_test_model(self): operations = [migrations.CreateModel( "Neighborhood", [ ("id", models.AutoField(primary_key=True)), ('name', models.CharField(max_length=100, unique=True)), ('geom', fields.MultiPolygonField(srid=4326, null=True)), ], )] return self.apply_operations('gis', ProjectState(), operations) def test_add_gis_field(self): """ Tests the AddField operation with a GIS-enabled column. """ project_state = self.set_up_test_model() operation = migrations.AddField( "Neighborhood", "path", fields.LineStringField(srid=4326, null=True, blank=True), ) new_state = project_state.clone() operation.state_forwards("gis", new_state) with connection.schema_editor() as editor: operation.database_forwards("gis", editor, project_state, new_state) self.assertColumnExists("gis_neighborhood", "path") # Test GeometryColumns when available try: from django.contrib.gis.models import GeometryColumns except ImportError: return self.assertEqual( GeometryColumns.objects.filter(**{GeometryColumns.table_name_col(): "gis_neighborhood"}).count(), 2 )