diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py index a9a90253ec..24bd45c1da 100644 --- a/django/db/backends/postgresql/introspection.py +++ b/django/db/backends/postgresql/introspection.py @@ -152,7 +152,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): def get_constraints(self, cursor, table_name): """ - Retrieves any constraints or keys (unique, pk, fk, check, index) across one or more columns. + Retrieve any constraints or keys (unique, pk, fk, check, index) across + one or more columns. Also retrieve the definition of expression-based + indexes. """ constraints = {} # Loop over the key table, collecting things as constraints. The column @@ -191,15 +193,20 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None, "check": kind == "c", "index": False, + "definition": None, } # Now get indexes cursor.execute(""" SELECT indexname, array_agg(attname), indisunique, indisprimary, - array_agg(ordering), amname + array_agg(ordering), amname, exprdef FROM ( SELECT c2.relname as indexname, idx.*, attr.attname, am.amname, + CASE + WHEN idx.indexprs IS NOT NULL THEN + pg_get_indexdef(idx.indexrelid) + END AS exprdef, CASE WHEN am.amcanorder THEN CASE (option & 1) @@ -217,18 +224,19 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): LEFT JOIN pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key WHERE c.relname = %s ) s2 - GROUP BY indexname, indisunique, indisprimary, amname; + GROUP BY indexname, indisunique, indisprimary, amname, exprdef; """, [table_name]) - for index, columns, unique, primary, orders, type_ in cursor.fetchall(): + for index, columns, unique, primary, orders, type_, definition in cursor.fetchall(): if index not in constraints: constraints[index] = { - "columns": columns, - "orders": orders, + "columns": columns if columns != [None] else [], + "orders": orders if orders != [None] else [], "primary_key": primary, "unique": unique, "foreign_key": None, "check": False, "index": True, "type": type_, + "definition": definition, } return constraints diff --git a/tests/gis_tests/gis_migrations/test_operations.py b/tests/gis_tests/gis_migrations/test_operations.py index 84735d8fbd..ff1a621c6c 100644 --- a/tests/gis_tests/gis_migrations/test_operations.py +++ b/tests/gis_tests/gis_migrations/test_operations.py @@ -67,10 +67,17 @@ class OperationTests(TransactionTestCase): expected_count ) - def assertSpatialIndexExists(self, table, column): + def assertSpatialIndexExists(self, table, column, raster=False): with connection.cursor() as cursor: constraints = connection.introspection.get_constraints(cursor, table) - self.assertIn([column], [c['columns'] for c in constraints.values()]) + if raster: + self.assertTrue(any( + 'st_convexhull(%s)' % column in c['definition'] + for c in constraints.values() + if c['definition'] is not None + )) + else: + self.assertIn([column], [c['columns'] for c in constraints.values()]) def alter_gis_model(self, migration_class, model_name, field_name, blank=False, field_class=None): @@ -111,7 +118,7 @@ class OperationTests(TransactionTestCase): # Test spatial indices when available if self.has_spatial_indexes: - self.assertSpatialIndexExists('gis_neighborhood', 'heatmap') + self.assertSpatialIndexExists('gis_neighborhood', 'heatmap', raster=True) @skipIfDBFeature('supports_raster') def test_create_raster_model_on_db_without_raster_support(self): @@ -159,7 +166,7 @@ class OperationTests(TransactionTestCase): # Test spatial indices when available if self.has_spatial_indexes: - self.assertSpatialIndexExists('gis_neighborhood', 'heatmap') + self.assertSpatialIndexExists('gis_neighborhood', 'heatmap', raster=True) def test_remove_geom_field(self): """ @@ -189,7 +196,7 @@ class OperationTests(TransactionTestCase): self.assertSpatialIndexExists('gis_neighborhood', 'geom') if connection.features.supports_raster: - self.assertSpatialIndexExists('gis_neighborhood', 'rast') + self.assertSpatialIndexExists('gis_neighborhood', 'rast', raster=True) @property def has_spatial_indexes(self):