mirror of
				https://github.com/django/django.git
				synced 2025-10-30 09:06:13 +00:00 
			
		
		
		
	Fixed #25605 -- Made GIS DB functions accept geometric expressions, not only values, in all positions.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							e487ffd3f0
						
					
				
				
					commit
					bde86ce9ae
				
			| @@ -1,8 +1,6 @@ | ||||
| from decimal import Decimal | ||||
|  | ||||
| from django.contrib.gis.db.models.fields import ( | ||||
|     BaseSpatialField, GeometryField, RasterField, | ||||
| ) | ||||
| from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField | ||||
| from django.contrib.gis.db.models.sql import AreaField | ||||
| from django.contrib.gis.geometry.backend import Geometry | ||||
| from django.contrib.gis.measure import ( | ||||
| @@ -13,6 +11,7 @@ from django.db.models import ( | ||||
|     BooleanField, FloatField, IntegerField, TextField, Transform, | ||||
| ) | ||||
| from django.db.models.expressions import Func, Value | ||||
| from django.db.models.functions import Cast | ||||
|  | ||||
| NUMERIC_TYPES = (int, float, Decimal) | ||||
|  | ||||
| @@ -20,26 +19,37 @@ NUMERIC_TYPES = (int, float, Decimal) | ||||
| class GeoFuncMixin: | ||||
|     function = None | ||||
|     output_field_class = None | ||||
|     geom_param_pos = 0 | ||||
|     geom_param_pos = (0,) | ||||
|  | ||||
|     def __init__(self, *expressions, **extra): | ||||
|         if 'output_field' not in extra and self.output_field_class: | ||||
|             extra['output_field'] = self.output_field_class() | ||||
|         super().__init__(*expressions, **extra) | ||||
|  | ||||
|         # Ensure that value expressions are geometric. | ||||
|         for pos in self.geom_param_pos: | ||||
|             expr = self.source_expressions[pos] | ||||
|             if not isinstance(expr, Value): | ||||
|                 continue | ||||
|             try: | ||||
|                 output_field = expr.output_field | ||||
|             except FieldError: | ||||
|                 output_field = None | ||||
|             geom = expr.value | ||||
|             if not isinstance(geom, Geometry) or output_field and not isinstance(output_field, GeometryField): | ||||
|                 raise TypeError("%s function requires a geometric argument in position %d." % (self.name, pos + 1)) | ||||
|             if not geom.srid and not output_field: | ||||
|                 raise ValueError("SRID is required for all geometries.") | ||||
|             if not output_field: | ||||
|                 self.source_expressions[pos] = Value(geom, output_field=GeometryField(srid=geom.srid)) | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         return self.__class__.__name__ | ||||
|  | ||||
|     @property | ||||
|     def srid(self): | ||||
|         expr = self.source_expressions[self.geom_param_pos] | ||||
|         if hasattr(expr, 'srid'): | ||||
|             return expr.srid | ||||
|         try: | ||||
|             return expr.field.srid | ||||
|         except (AttributeError, FieldError): | ||||
|             return None | ||||
|         return self.source_expressions[self.geom_param_pos[0]].field.srid | ||||
|  | ||||
|     @property | ||||
|     def geo_field(self): | ||||
| @@ -48,19 +58,28 @@ class GeoFuncMixin: | ||||
|     def as_sql(self, compiler, connection, function=None, **extra_context): | ||||
|         if not self.function and not function: | ||||
|             function = connection.ops.spatial_function_name(self.name) | ||||
|         if any(isinstance(field, RasterField) for field in self.get_source_fields()): | ||||
|             raise TypeError("Geometry functions not supported for raster fields.") | ||||
|         return super().as_sql(compiler, connection, function=function, **extra_context) | ||||
|  | ||||
|     def resolve_expression(self, *args, **kwargs): | ||||
|         res = super().resolve_expression(*args, **kwargs) | ||||
|         base_srid = res.srid | ||||
|         if not base_srid: | ||||
|             raise TypeError("Geometry functions can only operate on geometric content.") | ||||
|  | ||||
|         for pos, expr in enumerate(res.source_expressions[1:], start=1): | ||||
|             if isinstance(expr, GeomValue) and expr.srid != base_srid: | ||||
|                 # Automatic SRID conversion so objects are comparable | ||||
|         # Ensure that expressions are geometric. | ||||
|         source_fields = res.get_source_fields() | ||||
|         for pos in self.geom_param_pos: | ||||
|             field = source_fields[pos] | ||||
|             if not isinstance(field, GeometryField): | ||||
|                 raise TypeError( | ||||
|                     "%s function requires a GeometryField in position %s, got %s." % ( | ||||
|                         self.name, pos + 1, type(field).__name__, | ||||
|                     ) | ||||
|                 ) | ||||
|  | ||||
|         base_srid = res.srid | ||||
|         for pos in self.geom_param_pos[1:]: | ||||
|             expr = res.source_expressions[pos] | ||||
|             expr_srid = expr.output_field.srid | ||||
|             if expr_srid != base_srid: | ||||
|                 # Automatic SRID conversion so objects are comparable. | ||||
|                 res.source_expressions[pos] = Transform(expr, base_srid).resolve_expression(*args, **kwargs) | ||||
|         return res | ||||
|  | ||||
| @@ -78,34 +97,16 @@ class GeoFunc(GeoFuncMixin, Func): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class GeomValue(Value): | ||||
|     geography = False | ||||
| class GeomOutputGeoFunc(GeoFunc): | ||||
|     def __init__(self, *expressions, **extra): | ||||
|         if 'output_field' not in extra: | ||||
|             extra['output_field'] = GeometryField() | ||||
|         super(GeomOutputGeoFunc, self).__init__(*expressions, **extra) | ||||
|  | ||||
|     @property | ||||
|     def srid(self): | ||||
|         return self.value.srid | ||||
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
|         return '%s(%%s, %s)' % (connection.ops.from_text, self.srid), [connection.ops.Adapter(self.value)] | ||||
|  | ||||
|     def as_mysql(self, compiler, connection): | ||||
|         return '%s(%%s)' % (connection.ops.from_text), [connection.ops.Adapter(self.value)] | ||||
|  | ||||
|     def as_postgresql(self, compiler, connection): | ||||
|         if self.geography: | ||||
|             self.value = connection.ops.Adapter(self.value, geography=self.geography) | ||||
|         else: | ||||
|             self.value = connection.ops.Adapter(self.value) | ||||
|         return super().as_sql(compiler, connection) | ||||
|  | ||||
|  | ||||
| class GeoFuncWithGeoParam(GeoFunc): | ||||
|     def __init__(self, expression, geom, *expressions, **extra): | ||||
|         if not isinstance(geom, Geometry): | ||||
|             raise TypeError("Please provide a geometry object.") | ||||
|         if not hasattr(geom, 'srid') or not geom.srid: | ||||
|             raise ValueError("Please provide a geometry attribute with a defined SRID.") | ||||
|         super().__init__(expression, GeomValue(geom), *expressions, **extra) | ||||
|     def resolve_expression(self, *args, **kwargs): | ||||
|         res = super().resolve_expression(*args, **kwargs) | ||||
|         res.output_field.srid = res.srid | ||||
|         return res | ||||
|  | ||||
|  | ||||
| class SQLiteDecimalToFloatMixin: | ||||
| @@ -181,7 +182,7 @@ class AsGeoJSON(GeoFunc): | ||||
|  | ||||
|  | ||||
| class AsGML(GeoFunc): | ||||
|     geom_param_pos = 1 | ||||
|     geom_param_pos = (1,) | ||||
|     output_field_class = TextField | ||||
|  | ||||
|     def __init__(self, expression, version=2, precision=8, **extra): | ||||
| @@ -230,12 +231,13 @@ class BoundingCircle(OracleToleranceMixin, GeoFunc): | ||||
|         return super(BoundingCircle, clone).as_oracle(compiler, connection) | ||||
|  | ||||
|  | ||||
| class Centroid(OracleToleranceMixin, GeoFunc): | ||||
| class Centroid(OracleToleranceMixin, GeomOutputGeoFunc): | ||||
|     arity = 1 | ||||
|  | ||||
|  | ||||
| class Difference(OracleToleranceMixin, GeoFuncWithGeoParam): | ||||
| class Difference(OracleToleranceMixin, GeomOutputGeoFunc): | ||||
|     arity = 2 | ||||
|     geom_param_pos = (0, 1) | ||||
|  | ||||
|  | ||||
| class DistanceResultMixin: | ||||
| @@ -259,7 +261,8 @@ class DistanceResultMixin: | ||||
|         return value | ||||
|  | ||||
|  | ||||
| class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFuncWithGeoParam): | ||||
| class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): | ||||
|     geom_param_pos = (0, 1) | ||||
|     output_field_class = FloatField | ||||
|     spheroid = None | ||||
|  | ||||
| @@ -273,13 +276,18 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFuncWithGeoParam): | ||||
|     def as_postgresql(self, compiler, connection): | ||||
|         function = None | ||||
|         geo_field = GeometryField(srid=self.srid)  # Fake field to get SRID info | ||||
|         if self.source_is_geography(): | ||||
|             # Set parameters as geography if base field is geography | ||||
|             for pos, expr in enumerate( | ||||
|                     self.source_expressions[self.geom_param_pos + 1:], start=self.geom_param_pos + 1): | ||||
|                 if isinstance(expr, GeomValue): | ||||
|                     expr.geography = True | ||||
|         elif geo_field.geodetic(connection): | ||||
|         expr2 = self.source_expressions[1] | ||||
|         geography = self.source_is_geography() | ||||
|         if expr2.output_field.geography != geography: | ||||
|             if isinstance(expr2, Value): | ||||
|                 expr2.output_field.geography = geography | ||||
|             else: | ||||
|                 self.source_expressions[1] = Cast( | ||||
|                     expr2, | ||||
|                     GeometryField(srid=expr2.output_field.srid, geography=geography), | ||||
|                 ) | ||||
|  | ||||
|         if not geography and geo_field.geodetic(connection): | ||||
|             # Geometry fields with geodetic (lon/lat) coordinates need special distance functions | ||||
|             if self.spheroid: | ||||
|                 # DistanceSpheroid is more accurate and resource intensive than DistanceSphere | ||||
| @@ -305,11 +313,11 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFuncWithGeoParam): | ||||
|         return super().as_sql(compiler, connection, **extra_context) | ||||
|  | ||||
|  | ||||
| class Envelope(GeoFunc): | ||||
| class Envelope(GeomOutputGeoFunc): | ||||
|     arity = 1 | ||||
|  | ||||
|  | ||||
| class ForceRHR(GeoFunc): | ||||
| class ForceRHR(GeomOutputGeoFunc): | ||||
|     arity = 1 | ||||
|  | ||||
|  | ||||
| @@ -323,8 +331,9 @@ class GeoHash(GeoFunc): | ||||
|         super().__init__(*expressions, **extra) | ||||
|  | ||||
|  | ||||
| class Intersection(OracleToleranceMixin, GeoFuncWithGeoParam): | ||||
| class Intersection(OracleToleranceMixin, GeomOutputGeoFunc): | ||||
|     arity = 2 | ||||
|     geom_param_pos = (0, 1) | ||||
|  | ||||
|  | ||||
| @BaseSpatialField.register_lookup | ||||
| @@ -392,7 +401,7 @@ class NumPoints(GeoFunc): | ||||
|     arity = 1 | ||||
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
|         if self.source_expressions[self.geom_param_pos].output_field.geom_type != 'LINESTRING': | ||||
|         if self.source_expressions[self.geom_param_pos[0]].output_field.geom_type != 'LINESTRING': | ||||
|             if not connection.features.supports_num_points_poly: | ||||
|                 raise TypeError('NumPoints can only operate on LineString content on this database.') | ||||
|         return super().as_sql(compiler, connection) | ||||
| @@ -419,7 +428,7 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): | ||||
|         return super().as_sql(compiler, connection) | ||||
|  | ||||
|  | ||||
| class PointOnSurface(OracleToleranceMixin, GeoFunc): | ||||
| class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc): | ||||
|     arity = 1 | ||||
|  | ||||
|  | ||||
| @@ -427,7 +436,7 @@ class Reverse(GeoFunc): | ||||
|     arity = 1 | ||||
|  | ||||
|  | ||||
| class Scale(SQLiteDecimalToFloatMixin, GeoFunc): | ||||
| class Scale(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc): | ||||
|     def __init__(self, expression, x, y, z=0.0, **extra): | ||||
|         expressions = [ | ||||
|             expression, | ||||
| @@ -439,7 +448,7 @@ class Scale(SQLiteDecimalToFloatMixin, GeoFunc): | ||||
|         super().__init__(*expressions, **extra) | ||||
|  | ||||
|  | ||||
| class SnapToGrid(SQLiteDecimalToFloatMixin, GeoFunc): | ||||
| class SnapToGrid(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc): | ||||
|     def __init__(self, expression, *args, **extra): | ||||
|         nargs = len(args) | ||||
|         expressions = [expression] | ||||
| @@ -460,11 +469,12 @@ class SnapToGrid(SQLiteDecimalToFloatMixin, GeoFunc): | ||||
|         super().__init__(*expressions, **extra) | ||||
|  | ||||
|  | ||||
| class SymDifference(OracleToleranceMixin, GeoFuncWithGeoParam): | ||||
| class SymDifference(OracleToleranceMixin, GeomOutputGeoFunc): | ||||
|     arity = 2 | ||||
|     geom_param_pos = (0, 1) | ||||
|  | ||||
|  | ||||
| class Transform(GeoFunc): | ||||
| class Transform(GeomOutputGeoFunc): | ||||
|     def __init__(self, expression, srid, **extra): | ||||
|         expressions = [ | ||||
|             expression, | ||||
| @@ -477,7 +487,7 @@ class Transform(GeoFunc): | ||||
|     @property | ||||
|     def srid(self): | ||||
|         # Make srid the resulting srid of the transformation | ||||
|         return self.source_expressions[self.geom_param_pos + 1].value | ||||
|         return self.source_expressions[1].value | ||||
|  | ||||
|  | ||||
| class Translate(Scale): | ||||
| @@ -488,5 +498,6 @@ class Translate(Scale): | ||||
|         return super().as_sqlite(compiler, connection) | ||||
|  | ||||
|  | ||||
| class Union(OracleToleranceMixin, GeoFuncWithGeoParam): | ||||
| class Union(OracleToleranceMixin, GeomOutputGeoFunc): | ||||
|     arity = 2 | ||||
|     geom_param_pos = (0, 1) | ||||
|   | ||||
| @@ -76,6 +76,8 @@ class Lookup: | ||||
|  | ||||
|     def process_lhs(self, compiler, connection, lhs=None): | ||||
|         lhs = lhs or self.lhs | ||||
|         if hasattr(lhs, 'resolve_expression'): | ||||
|             lhs = lhs.resolve_expression(compiler.query) | ||||
|         return compiler.compile(lhs) | ||||
|  | ||||
|     def process_rhs(self, compiler, connection): | ||||
|   | ||||
| @@ -429,6 +429,9 @@ class DistanceFunctionsTests(TestCase): | ||||
|         self.assertTrue( | ||||
|             SouthTexasInterstate.objects.annotate(length=Length('path')).filter(length__gt=4000).exists() | ||||
|         ) | ||||
|         # Length with an explicit geometry value. | ||||
|         qs = Interstate.objects.annotate(length=Length(i10.path)) | ||||
|         self.assertAlmostEqual(qs.first().length.m, len_m2, 2) | ||||
|  | ||||
|     @skipUnlessDBFeature("has_Perimeter_function") | ||||
|     def test_perimeter(self): | ||||
|   | ||||
| @@ -2,7 +2,9 @@ import re | ||||
| from decimal import Decimal | ||||
|  | ||||
| from django.contrib.gis.db.models import functions | ||||
| from django.contrib.gis.geos import LineString, Point, Polygon, fromstr | ||||
| from django.contrib.gis.geos import ( | ||||
|     GEOSGeometry, LineString, Point, Polygon, fromstr, | ||||
| ) | ||||
| from django.contrib.gis.measure import Area | ||||
| from django.db import connection | ||||
| from django.db.models import Sum | ||||
| @@ -494,7 +496,48 @@ class GISFunctionsTests(TestCase): | ||||
|  | ||||
|     @skipUnlessDBFeature("has_Union_function") | ||||
|     def test_union(self): | ||||
|         """Union with all combinations of geometries/geometry fields.""" | ||||
|         geom = Point(-95.363151, 29.763374, srid=4326) | ||||
|         ptown = City.objects.annotate(union=functions.Union('point', geom)).get(name='Dallas') | ||||
|  | ||||
|         union = City.objects.annotate(union=functions.Union('point', geom)).get(name='Dallas').union | ||||
|         expected = fromstr('MULTIPOINT(-96.801611 32.782057,-95.363151 29.763374)', srid=4326) | ||||
|         self.assertTrue(expected.equals(ptown.union)) | ||||
|         self.assertTrue(expected.equals(union)) | ||||
|  | ||||
|         union = City.objects.annotate(union=functions.Union(geom, 'point')).get(name='Dallas').union | ||||
|         self.assertTrue(expected.equals(union)) | ||||
|  | ||||
|         union = City.objects.annotate(union=functions.Union('point', 'point')).get(name='Dallas').union | ||||
|         expected = GEOSGeometry('POINT(-96.801611 32.782057)', srid=4326) | ||||
|         self.assertTrue(expected.equals(union)) | ||||
|  | ||||
|         union = City.objects.annotate(union=functions.Union(geom, geom)).get(name='Dallas').union | ||||
|         self.assertTrue(geom.equals(union)) | ||||
|  | ||||
|     @skipUnlessDBFeature("has_Union_function", "has_Transform_function") | ||||
|     def test_union_mixed_srid(self): | ||||
|         """The result SRID depends on the order of parameters.""" | ||||
|         geom = Point(61.42915, 55.15402, srid=4326) | ||||
|         geom_3857 = geom.transform(3857, clone=True) | ||||
|         tol = 0.001 | ||||
|  | ||||
|         for city in City.objects.annotate(union=functions.Union('point', geom_3857)): | ||||
|             expected = city.point | geom | ||||
|             self.assertTrue(city.union.equals_exact(expected, tol)) | ||||
|             self.assertEqual(city.union.srid, 4326) | ||||
|  | ||||
|         for city in City.objects.annotate(union=functions.Union(geom_3857, 'point')): | ||||
|             expected = geom_3857 | city.point.transform(3857, clone=True) | ||||
|             self.assertTrue(expected.equals_exact(city.union, tol)) | ||||
|             self.assertEqual(city.union.srid, 3857) | ||||
|  | ||||
|     def test_argument_validation(self): | ||||
|         with self.assertRaisesMessage(ValueError, 'SRID is required for all geometries.'): | ||||
|             City.objects.annotate(geo=functions.GeoFunc(Point(1, 1))) | ||||
|  | ||||
|         msg = 'GeoFunc function requires a GeometryField in position 1, got CharField.' | ||||
|         with self.assertRaisesMessage(TypeError, msg): | ||||
|             City.objects.annotate(geo=functions.GeoFunc('name')) | ||||
|  | ||||
|         msg = 'GeoFunc function requires a geometric argument in position 1.' | ||||
|         with self.assertRaisesMessage(TypeError, msg): | ||||
|             City.objects.annotate(union=functions.GeoFunc(1, 'point')).get(name='Dallas') | ||||
|   | ||||
| @@ -120,9 +120,19 @@ class GeographyFunctionTests(TestCase): | ||||
|         else: | ||||
|             ref_dists = [0, 4891.20, 8071.64, 9123.95] | ||||
|         htown = City.objects.get(name='Houston') | ||||
|         qs = Zipcode.objects.annotate(distance=Distance('poly', htown.point)) | ||||
|         qs = Zipcode.objects.annotate( | ||||
|             distance=Distance('poly', htown.point), | ||||
|             distance2=Distance(htown.point, 'poly'), | ||||
|         ) | ||||
|         for z, ref in zip(qs, ref_dists): | ||||
|             self.assertAlmostEqual(z.distance.m, ref, 2) | ||||
|  | ||||
|         if postgis: | ||||
|             # PostGIS casts geography to geometry when distance2 is calculated. | ||||
|             ref_dists = [0, 4899.68, 8081.30, 9115.15] | ||||
|         for z, ref in zip(qs, ref_dists): | ||||
|             self.assertAlmostEqual(z.distance2.m, ref, 2) | ||||
|  | ||||
|         if not spatialite: | ||||
|             # Distance function combined with a lookup. | ||||
|             hzip = Zipcode.objects.get(code='77002') | ||||
|   | ||||
| @@ -271,7 +271,7 @@ class RasterFieldTest(TransactionTestCase): | ||||
|  | ||||
|     def test_isvalid_lookup_with_raster_error(self): | ||||
|         qs = RasterModel.objects.filter(rast__isvalid=True) | ||||
|         msg = 'Geometry functions not supported for raster fields.' | ||||
|         msg = 'IsValid function requires a GeometryField in position 1, got RasterField.' | ||||
|         with self.assertRaisesMessage(TypeError, msg): | ||||
|             qs.count() | ||||
|  | ||||
| @@ -336,11 +336,11 @@ class RasterFieldTest(TransactionTestCase): | ||||
|         """ | ||||
|         point = GEOSGeometry("SRID=3086;POINT (-697024.9213808845 683729.1705516104)") | ||||
|         rast = GDALRaster(json.loads(JSON_RASTER)) | ||||
|         msg = "Please provide a geometry object." | ||||
|         msg = "Distance function requires a geometric argument in position 2." | ||||
|         with self.assertRaisesMessage(TypeError, msg): | ||||
|             RasterModel.objects.annotate(distance_from_point=Distance("geom", rast)) | ||||
|         with self.assertRaisesMessage(TypeError, msg): | ||||
|             RasterModel.objects.annotate(distance_from_point=Distance("rastprojected", rast)) | ||||
|         msg = "Geometry functions not supported for raster fields." | ||||
|         msg = "Distance function requires a GeometryField in position 1, got RasterField." | ||||
|         with self.assertRaisesMessage(TypeError, msg): | ||||
|             RasterModel.objects.annotate(distance_from_point=Distance("rastprojected", point)).count() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user