From bde86ce9ae17ee52aa5be9b74b64422f5219530d Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Sat, 1 Apr 2017 18:47:49 +0500 Subject: [PATCH] Fixed #25605 -- Made GIS DB functions accept geometric expressions, not only values, in all positions. --- django/contrib/gis/db/models/functions.py | 147 ++++++++++-------- django/db/models/lookups.py | 2 + tests/gis_tests/distapp/tests.py | 3 + tests/gis_tests/geoapp/test_functions.py | 49 +++++- tests/gis_tests/geogapp/tests.py | 12 +- tests/gis_tests/rasterapp/test_rasterfield.py | 6 +- 6 files changed, 144 insertions(+), 75 deletions(-) diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index 36d1644ddd..dcd09472e3 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -1,8 +1,6 @@ from decimal import Decimal -from django.contrib.gis.db.models.fields import ( - BaseSpatialField, GeometryField, RasterField, -) +from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField from django.contrib.gis.db.models.sql import AreaField from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.measure import ( @@ -13,6 +11,7 @@ from django.db.models import ( BooleanField, FloatField, IntegerField, TextField, Transform, ) from django.db.models.expressions import Func, Value +from django.db.models.functions import Cast NUMERIC_TYPES = (int, float, Decimal) @@ -20,26 +19,37 @@ NUMERIC_TYPES = (int, float, Decimal) class GeoFuncMixin: function = None output_field_class = None - geom_param_pos = 0 + geom_param_pos = (0,) def __init__(self, *expressions, **extra): if 'output_field' not in extra and self.output_field_class: extra['output_field'] = self.output_field_class() super().__init__(*expressions, **extra) + # Ensure that value expressions are geometric. + for pos in self.geom_param_pos: + expr = self.source_expressions[pos] + if not isinstance(expr, Value): + continue + try: + output_field = expr.output_field + except FieldError: + output_field = None + geom = expr.value + if not isinstance(geom, Geometry) or output_field and not isinstance(output_field, GeometryField): + raise TypeError("%s function requires a geometric argument in position %d." % (self.name, pos + 1)) + if not geom.srid and not output_field: + raise ValueError("SRID is required for all geometries.") + if not output_field: + self.source_expressions[pos] = Value(geom, output_field=GeometryField(srid=geom.srid)) + @property def name(self): return self.__class__.__name__ @property def srid(self): - expr = self.source_expressions[self.geom_param_pos] - if hasattr(expr, 'srid'): - return expr.srid - try: - return expr.field.srid - except (AttributeError, FieldError): - return None + return self.source_expressions[self.geom_param_pos[0]].field.srid @property def geo_field(self): @@ -48,19 +58,28 @@ class GeoFuncMixin: def as_sql(self, compiler, connection, function=None, **extra_context): if not self.function and not function: function = connection.ops.spatial_function_name(self.name) - if any(isinstance(field, RasterField) for field in self.get_source_fields()): - raise TypeError("Geometry functions not supported for raster fields.") return super().as_sql(compiler, connection, function=function, **extra_context) def resolve_expression(self, *args, **kwargs): res = super().resolve_expression(*args, **kwargs) - base_srid = res.srid - if not base_srid: - raise TypeError("Geometry functions can only operate on geometric content.") - for pos, expr in enumerate(res.source_expressions[1:], start=1): - if isinstance(expr, GeomValue) and expr.srid != base_srid: - # Automatic SRID conversion so objects are comparable + # Ensure that expressions are geometric. + source_fields = res.get_source_fields() + for pos in self.geom_param_pos: + field = source_fields[pos] + if not isinstance(field, GeometryField): + raise TypeError( + "%s function requires a GeometryField in position %s, got %s." % ( + self.name, pos + 1, type(field).__name__, + ) + ) + + base_srid = res.srid + for pos in self.geom_param_pos[1:]: + expr = res.source_expressions[pos] + expr_srid = expr.output_field.srid + if expr_srid != base_srid: + # Automatic SRID conversion so objects are comparable. res.source_expressions[pos] = Transform(expr, base_srid).resolve_expression(*args, **kwargs) return res @@ -78,34 +97,16 @@ class GeoFunc(GeoFuncMixin, Func): pass -class GeomValue(Value): - geography = False +class GeomOutputGeoFunc(GeoFunc): + def __init__(self, *expressions, **extra): + if 'output_field' not in extra: + extra['output_field'] = GeometryField() + super(GeomOutputGeoFunc, self).__init__(*expressions, **extra) - @property - def srid(self): - return self.value.srid - - def as_sql(self, compiler, connection): - return '%s(%%s, %s)' % (connection.ops.from_text, self.srid), [connection.ops.Adapter(self.value)] - - def as_mysql(self, compiler, connection): - return '%s(%%s)' % (connection.ops.from_text), [connection.ops.Adapter(self.value)] - - def as_postgresql(self, compiler, connection): - if self.geography: - self.value = connection.ops.Adapter(self.value, geography=self.geography) - else: - self.value = connection.ops.Adapter(self.value) - return super().as_sql(compiler, connection) - - -class GeoFuncWithGeoParam(GeoFunc): - def __init__(self, expression, geom, *expressions, **extra): - if not isinstance(geom, Geometry): - raise TypeError("Please provide a geometry object.") - if not hasattr(geom, 'srid') or not geom.srid: - raise ValueError("Please provide a geometry attribute with a defined SRID.") - super().__init__(expression, GeomValue(geom), *expressions, **extra) + def resolve_expression(self, *args, **kwargs): + res = super().resolve_expression(*args, **kwargs) + res.output_field.srid = res.srid + return res class SQLiteDecimalToFloatMixin: @@ -181,7 +182,7 @@ class AsGeoJSON(GeoFunc): class AsGML(GeoFunc): - geom_param_pos = 1 + geom_param_pos = (1,) output_field_class = TextField def __init__(self, expression, version=2, precision=8, **extra): @@ -230,12 +231,13 @@ class BoundingCircle(OracleToleranceMixin, GeoFunc): return super(BoundingCircle, clone).as_oracle(compiler, connection) -class Centroid(OracleToleranceMixin, GeoFunc): +class Centroid(OracleToleranceMixin, GeomOutputGeoFunc): arity = 1 -class Difference(OracleToleranceMixin, GeoFuncWithGeoParam): +class Difference(OracleToleranceMixin, GeomOutputGeoFunc): arity = 2 + geom_param_pos = (0, 1) class DistanceResultMixin: @@ -259,7 +261,8 @@ class DistanceResultMixin: return value -class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFuncWithGeoParam): +class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): + geom_param_pos = (0, 1) output_field_class = FloatField spheroid = None @@ -273,13 +276,18 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFuncWithGeoParam): def as_postgresql(self, compiler, connection): function = None geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info - if self.source_is_geography(): - # Set parameters as geography if base field is geography - for pos, expr in enumerate( - self.source_expressions[self.geom_param_pos + 1:], start=self.geom_param_pos + 1): - if isinstance(expr, GeomValue): - expr.geography = True - elif geo_field.geodetic(connection): + expr2 = self.source_expressions[1] + geography = self.source_is_geography() + if expr2.output_field.geography != geography: + if isinstance(expr2, Value): + expr2.output_field.geography = geography + else: + self.source_expressions[1] = Cast( + expr2, + GeometryField(srid=expr2.output_field.srid, geography=geography), + ) + + if not geography and geo_field.geodetic(connection): # Geometry fields with geodetic (lon/lat) coordinates need special distance functions if self.spheroid: # DistanceSpheroid is more accurate and resource intensive than DistanceSphere @@ -305,11 +313,11 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFuncWithGeoParam): return super().as_sql(compiler, connection, **extra_context) -class Envelope(GeoFunc): +class Envelope(GeomOutputGeoFunc): arity = 1 -class ForceRHR(GeoFunc): +class ForceRHR(GeomOutputGeoFunc): arity = 1 @@ -323,8 +331,9 @@ class GeoHash(GeoFunc): super().__init__(*expressions, **extra) -class Intersection(OracleToleranceMixin, GeoFuncWithGeoParam): +class Intersection(OracleToleranceMixin, GeomOutputGeoFunc): arity = 2 + geom_param_pos = (0, 1) @BaseSpatialField.register_lookup @@ -392,7 +401,7 @@ class NumPoints(GeoFunc): arity = 1 def as_sql(self, compiler, connection): - if self.source_expressions[self.geom_param_pos].output_field.geom_type != 'LINESTRING': + if self.source_expressions[self.geom_param_pos[0]].output_field.geom_type != 'LINESTRING': if not connection.features.supports_num_points_poly: raise TypeError('NumPoints can only operate on LineString content on this database.') return super().as_sql(compiler, connection) @@ -419,7 +428,7 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): return super().as_sql(compiler, connection) -class PointOnSurface(OracleToleranceMixin, GeoFunc): +class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc): arity = 1 @@ -427,7 +436,7 @@ class Reverse(GeoFunc): arity = 1 -class Scale(SQLiteDecimalToFloatMixin, GeoFunc): +class Scale(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc): def __init__(self, expression, x, y, z=0.0, **extra): expressions = [ expression, @@ -439,7 +448,7 @@ class Scale(SQLiteDecimalToFloatMixin, GeoFunc): super().__init__(*expressions, **extra) -class SnapToGrid(SQLiteDecimalToFloatMixin, GeoFunc): +class SnapToGrid(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc): def __init__(self, expression, *args, **extra): nargs = len(args) expressions = [expression] @@ -460,11 +469,12 @@ class SnapToGrid(SQLiteDecimalToFloatMixin, GeoFunc): super().__init__(*expressions, **extra) -class SymDifference(OracleToleranceMixin, GeoFuncWithGeoParam): +class SymDifference(OracleToleranceMixin, GeomOutputGeoFunc): arity = 2 + geom_param_pos = (0, 1) -class Transform(GeoFunc): +class Transform(GeomOutputGeoFunc): def __init__(self, expression, srid, **extra): expressions = [ expression, @@ -477,7 +487,7 @@ class Transform(GeoFunc): @property def srid(self): # Make srid the resulting srid of the transformation - return self.source_expressions[self.geom_param_pos + 1].value + return self.source_expressions[1].value class Translate(Scale): @@ -488,5 +498,6 @@ class Translate(Scale): return super().as_sqlite(compiler, connection) -class Union(OracleToleranceMixin, GeoFuncWithGeoParam): +class Union(OracleToleranceMixin, GeomOutputGeoFunc): arity = 2 + geom_param_pos = (0, 1) diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index d96c4468f5..c37fcabba4 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -76,6 +76,8 @@ class Lookup: def process_lhs(self, compiler, connection, lhs=None): lhs = lhs or self.lhs + if hasattr(lhs, 'resolve_expression'): + lhs = lhs.resolve_expression(compiler.query) return compiler.compile(lhs) def process_rhs(self, compiler, connection): diff --git a/tests/gis_tests/distapp/tests.py b/tests/gis_tests/distapp/tests.py index de00bafd70..31987d9ce8 100644 --- a/tests/gis_tests/distapp/tests.py +++ b/tests/gis_tests/distapp/tests.py @@ -429,6 +429,9 @@ class DistanceFunctionsTests(TestCase): self.assertTrue( SouthTexasInterstate.objects.annotate(length=Length('path')).filter(length__gt=4000).exists() ) + # Length with an explicit geometry value. + qs = Interstate.objects.annotate(length=Length(i10.path)) + self.assertAlmostEqual(qs.first().length.m, len_m2, 2) @skipUnlessDBFeature("has_Perimeter_function") def test_perimeter(self): diff --git a/tests/gis_tests/geoapp/test_functions.py b/tests/gis_tests/geoapp/test_functions.py index 21b0c283e4..34deb88f4a 100644 --- a/tests/gis_tests/geoapp/test_functions.py +++ b/tests/gis_tests/geoapp/test_functions.py @@ -2,7 +2,9 @@ import re from decimal import Decimal from django.contrib.gis.db.models import functions -from django.contrib.gis.geos import LineString, Point, Polygon, fromstr +from django.contrib.gis.geos import ( + GEOSGeometry, LineString, Point, Polygon, fromstr, +) from django.contrib.gis.measure import Area from django.db import connection from django.db.models import Sum @@ -494,7 +496,48 @@ class GISFunctionsTests(TestCase): @skipUnlessDBFeature("has_Union_function") def test_union(self): + """Union with all combinations of geometries/geometry fields.""" geom = Point(-95.363151, 29.763374, srid=4326) - ptown = City.objects.annotate(union=functions.Union('point', geom)).get(name='Dallas') + + union = City.objects.annotate(union=functions.Union('point', geom)).get(name='Dallas').union expected = fromstr('MULTIPOINT(-96.801611 32.782057,-95.363151 29.763374)', srid=4326) - self.assertTrue(expected.equals(ptown.union)) + self.assertTrue(expected.equals(union)) + + union = City.objects.annotate(union=functions.Union(geom, 'point')).get(name='Dallas').union + self.assertTrue(expected.equals(union)) + + union = City.objects.annotate(union=functions.Union('point', 'point')).get(name='Dallas').union + expected = GEOSGeometry('POINT(-96.801611 32.782057)', srid=4326) + self.assertTrue(expected.equals(union)) + + union = City.objects.annotate(union=functions.Union(geom, geom)).get(name='Dallas').union + self.assertTrue(geom.equals(union)) + + @skipUnlessDBFeature("has_Union_function", "has_Transform_function") + def test_union_mixed_srid(self): + """The result SRID depends on the order of parameters.""" + geom = Point(61.42915, 55.15402, srid=4326) + geom_3857 = geom.transform(3857, clone=True) + tol = 0.001 + + for city in City.objects.annotate(union=functions.Union('point', geom_3857)): + expected = city.point | geom + self.assertTrue(city.union.equals_exact(expected, tol)) + self.assertEqual(city.union.srid, 4326) + + for city in City.objects.annotate(union=functions.Union(geom_3857, 'point')): + expected = geom_3857 | city.point.transform(3857, clone=True) + self.assertTrue(expected.equals_exact(city.union, tol)) + self.assertEqual(city.union.srid, 3857) + + def test_argument_validation(self): + with self.assertRaisesMessage(ValueError, 'SRID is required for all geometries.'): + City.objects.annotate(geo=functions.GeoFunc(Point(1, 1))) + + msg = 'GeoFunc function requires a GeometryField in position 1, got CharField.' + with self.assertRaisesMessage(TypeError, msg): + City.objects.annotate(geo=functions.GeoFunc('name')) + + msg = 'GeoFunc function requires a geometric argument in position 1.' + with self.assertRaisesMessage(TypeError, msg): + City.objects.annotate(union=functions.GeoFunc(1, 'point')).get(name='Dallas') diff --git a/tests/gis_tests/geogapp/tests.py b/tests/gis_tests/geogapp/tests.py index 551fda4b83..3450847504 100644 --- a/tests/gis_tests/geogapp/tests.py +++ b/tests/gis_tests/geogapp/tests.py @@ -120,9 +120,19 @@ class GeographyFunctionTests(TestCase): else: ref_dists = [0, 4891.20, 8071.64, 9123.95] htown = City.objects.get(name='Houston') - qs = Zipcode.objects.annotate(distance=Distance('poly', htown.point)) + qs = Zipcode.objects.annotate( + distance=Distance('poly', htown.point), + distance2=Distance(htown.point, 'poly'), + ) for z, ref in zip(qs, ref_dists): self.assertAlmostEqual(z.distance.m, ref, 2) + + if postgis: + # PostGIS casts geography to geometry when distance2 is calculated. + ref_dists = [0, 4899.68, 8081.30, 9115.15] + for z, ref in zip(qs, ref_dists): + self.assertAlmostEqual(z.distance2.m, ref, 2) + if not spatialite: # Distance function combined with a lookup. hzip = Zipcode.objects.get(code='77002') diff --git a/tests/gis_tests/rasterapp/test_rasterfield.py b/tests/gis_tests/rasterapp/test_rasterfield.py index 1e4ffbfbd7..68df5dc0aa 100644 --- a/tests/gis_tests/rasterapp/test_rasterfield.py +++ b/tests/gis_tests/rasterapp/test_rasterfield.py @@ -271,7 +271,7 @@ class RasterFieldTest(TransactionTestCase): def test_isvalid_lookup_with_raster_error(self): qs = RasterModel.objects.filter(rast__isvalid=True) - msg = 'Geometry functions not supported for raster fields.' + msg = 'IsValid function requires a GeometryField in position 1, got RasterField.' with self.assertRaisesMessage(TypeError, msg): qs.count() @@ -336,11 +336,11 @@ class RasterFieldTest(TransactionTestCase): """ point = GEOSGeometry("SRID=3086;POINT (-697024.9213808845 683729.1705516104)") rast = GDALRaster(json.loads(JSON_RASTER)) - msg = "Please provide a geometry object." + msg = "Distance function requires a geometric argument in position 2." with self.assertRaisesMessage(TypeError, msg): RasterModel.objects.annotate(distance_from_point=Distance("geom", rast)) with self.assertRaisesMessage(TypeError, msg): RasterModel.objects.annotate(distance_from_point=Distance("rastprojected", rast)) - msg = "Geometry functions not supported for raster fields." + msg = "Distance function requires a GeometryField in position 1, got RasterField." with self.assertRaisesMessage(TypeError, msg): RasterModel.objects.annotate(distance_from_point=Distance("rastprojected", point)).count()