Commit 48c4ea41 authored by Claude Paroz's avatar Claude Paroz
Browse files

Used migration framework in GIS test tearDown

parent 3c06b2f2
Loading
Loading
Loading
Loading
+8 −1
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ class PostGISSchemaEditor(DatabaseSchemaEditor):
    sql_drop_geometry_column = "SELECT DropGeometryColumn(%(table)s, %(column)s)"
    sql_alter_geometry_column_not_null = "ALTER TABLE %(table)s ALTER COLUMN %(column)s SET NOT NULL"
    sql_add_spatial_index = "CREATE INDEX %(index)s ON %(table)s USING %(index_type)s (%(column)s %(ops)s)"
    sql_clear_geometry_columns = "DELETE FROM geometry_columns WHERE f_table_name = %(table)s"

    def __init__(self, *args, **kwargs):
        super(PostGISSchemaEditor, self).__init__(*args, **kwargs)
@@ -46,7 +47,7 @@ class PostGISSchemaEditor(DatabaseSchemaEditor):
                    self.sql_alter_geometry_column_not_null % {
                        "table": self.quote_name(model._meta.db_table),
                        "column": self.quote_name(field.column),
                    },
                    }
                )

        if field.spatial_index:
@@ -83,6 +84,12 @@ class PostGISSchemaEditor(DatabaseSchemaEditor):
            self.execute(sql)
        self.geometry_sql = []

    def delete_model(self, model):
        super(PostGISSchemaEditor, self).delete_model(model)
        self.execute(self.sql_clear_geometry_columns % {
            "table": self.geo_quote_name(model._meta.db_table),
        })

    def add_field(self, model, field):
        super(PostGISSchemaEditor, self).add_field(model, field)
        # Create geometry columns
+3 −10
Original line number Diff line number Diff line
@@ -6,7 +6,6 @@ from django.contrib.gis.tests.utils import HAS_SPATIAL_DB
from django.db import connection, migrations, models
from django.db.migrations.migration import Migration
from django.db.migrations.state import ProjectState
from django.db.utils import DatabaseError
from django.test import TransactionTestCase

if HAS_SPATIAL_DB:
@@ -24,15 +23,7 @@ class OperationTests(TransactionTestCase):

    def tearDown(self):
        # Delete table after testing
        with connection.cursor() as cursor:
            try:
                cursor.execute("DROP TABLE %s" % connection.ops.quote_name("gis_neighborhood"))
            except DatabaseError:
                pass
            else:
                if HAS_GEOMETRY_COLUMNS:
                    cursor.execute("DELETE FROM geometry_columns WHERE %s = %%s" % (
                        GeometryColumns.table_name_col(),), ["gis_neighborhood"])
        self.apply_operations('gis', self.current_state, [migrations.DeleteModel("Neighborhood")])
        super(OperationTests, self).tearDown()

    def get_table_description(self, table):
@@ -84,6 +75,7 @@ class OperationTests(TransactionTestCase):
                GeometryColumns.objects.filter(**{GeometryColumns.table_name_col(): "gis_neighborhood"}).count(),
                2
            )
        self.current_state = new_state

    def test_remove_gis_field(self):
        """
@@ -103,3 +95,4 @@ class OperationTests(TransactionTestCase):
                GeometryColumns.objects.filter(**{GeometryColumns.table_name_col(): "gis_neighborhood"}).count(),
                0
            )
        self.current_state = new_state