mirror of
				https://github.com/django/django.git
				synced 2025-10-31 09:41:08 +00:00 
			
		
		
		
	Fixed #28432 -- Allowed geometry expressions to be used with distance lookups.
Distance lookups use the Distance function for decreased code redundancy.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							c7d58c6f43
						
					
				
				
					commit
					38af496b98
				
			| @@ -1,3 +1,6 @@ | ||||
| from django.contrib.gis.db.models.functions import Distance | ||||
|  | ||||
|  | ||||
| class BaseSpatialOperations: | ||||
|     # Quick booleans for the type of this spatial backend, and | ||||
|     # an attribute for the spatial database version tuple (if applicable) | ||||
| @@ -113,3 +116,5 @@ class BaseSpatialOperations: | ||||
|  | ||||
|     def spatial_ref_sys(self): | ||||
|         raise NotImplementedError('subclasses of BaseSpatialOperations must a provide spatial_ref_sys() method') | ||||
|  | ||||
|     distance_expr_for_lookup = staticmethod(Distance) | ||||
|   | ||||
| @@ -26,10 +26,6 @@ class SDOOperator(SpatialOperator): | ||||
|     sql_template = "%(func)s(%(lhs)s, %(rhs)s) = 'TRUE'" | ||||
|  | ||||
|  | ||||
| class SDODistance(SpatialOperator): | ||||
|     sql_template = "SDO_GEOM.SDO_DISTANCE(%%(lhs)s, %%(rhs)s, %s) %%(op)s %%(value)s" % DEFAULT_TOLERANCE | ||||
|  | ||||
|  | ||||
| class SDODWithin(SpatialOperator): | ||||
|     sql_template = "SDO_WITHIN_DISTANCE(%(lhs)s, %(rhs)s, %%s) = 'TRUE'" | ||||
|  | ||||
| @@ -104,10 +100,6 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): | ||||
|         'relate': SDORelate(),  # Oracle uses a different syntax, e.g., 'mask=inside+touch' | ||||
|         'touches': SDOOperator(func='SDO_TOUCH'), | ||||
|         'within': SDOOperator(func='SDO_INSIDE'), | ||||
|         'distance_gt': SDODistance(op='>'), | ||||
|         'distance_gte': SDODistance(op='>='), | ||||
|         'distance_lt': SDODistance(op='<'), | ||||
|         'distance_lte': SDODistance(op='<='), | ||||
|         'dwithin': SDODWithin(), | ||||
|     } | ||||
|  | ||||
|   | ||||
| @@ -5,10 +5,12 @@ from django.contrib.gis.db.backends.base.operations import ( | ||||
|     BaseSpatialOperations, | ||||
| ) | ||||
| from django.contrib.gis.db.backends.utils import SpatialOperator | ||||
| from django.contrib.gis.db.models import GeometryField, RasterField | ||||
| from django.contrib.gis.gdal import GDALRaster | ||||
| from django.contrib.gis.measure import Distance | ||||
| from django.core.exceptions import ImproperlyConfigured | ||||
| from django.db.backends.postgresql.operations import DatabaseOperations | ||||
| from django.db.models import Func, Value | ||||
| from django.db.utils import ProgrammingError | ||||
| from django.utils.functional import cached_property | ||||
| from django.utils.version import get_version_tuple | ||||
| @@ -77,26 +79,18 @@ class PostGISOperator(SpatialOperator): | ||||
|         return template_params | ||||
|  | ||||
|  | ||||
| class PostGISDistanceOperator(PostGISOperator): | ||||
|     sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s' | ||||
| class ST_Polygon(Func): | ||||
|     function = 'ST_Polygon' | ||||
|  | ||||
|     def as_sql(self, connection, lookup, template_params, sql_params): | ||||
|         if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection): | ||||
|             template_params = self.check_raster(lookup, template_params) | ||||
|             sql_template = self.sql_template | ||||
|             if len(lookup.rhs_params) == 2 and lookup.rhs_params[-1] == 'spheroid': | ||||
|                 template_params.update({ | ||||
|                     'op': self.op, | ||||
|                     'func': connection.ops.spatial_function_name('DistanceSpheroid'), | ||||
|                 }) | ||||
|                 sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %(value)s' | ||||
|                 # Using DistanceSpheroid requires the spheroid of the field as | ||||
|                 # a parameter. | ||||
|                 sql_params.insert(1, lookup.lhs.output_field.spheroid(connection)) | ||||
|             else: | ||||
|                 template_params.update({'op': self.op, 'func': connection.ops.spatial_function_name('DistanceSphere')}) | ||||
|             return sql_template % template_params, sql_params | ||||
|         return super().as_sql(connection, lookup, template_params, sql_params) | ||||
|     def __init__(self, expr): | ||||
|         super().__init__(expr) | ||||
|         expr = self.source_expressions[0] | ||||
|         if isinstance(expr, Value) and not expr._output_field_or_none: | ||||
|             self.source_expressions[0] = Value(expr.value, output_field=RasterField(srid=expr.value.srid)) | ||||
|  | ||||
|     @cached_property | ||||
|     def output_field(self): | ||||
|         return GeometryField(srid=self.source_expressions[0].field.srid) | ||||
|  | ||||
|  | ||||
| class PostGISOperations(BaseSpatialOperations, DatabaseOperations): | ||||
| @@ -134,10 +128,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): | ||||
|         'touches': PostGISOperator(func='ST_Touches', raster=BILATERAL), | ||||
|         'within': PostGISOperator(func='ST_Within', raster=BILATERAL), | ||||
|         'dwithin': PostGISOperator(func='ST_DWithin', geography=True, raster=BILATERAL), | ||||
|         'distance_gt': PostGISDistanceOperator(func='ST_Distance', op='>', geography=True), | ||||
|         'distance_gte': PostGISDistanceOperator(func='ST_Distance', op='>=', geography=True), | ||||
|         'distance_lt': PostGISDistanceOperator(func='ST_Distance', op='<', geography=True), | ||||
|         'distance_lte': PostGISDistanceOperator(func='ST_Distance', op='<=', geography=True), | ||||
|     } | ||||
|  | ||||
|     unsupported_functions = set() | ||||
| @@ -375,3 +365,19 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): | ||||
|     def parse_raster(self, value): | ||||
|         """Convert a PostGIS HEX String into a dict readable by GDALRaster.""" | ||||
|         return from_pgraster(value) | ||||
|  | ||||
|     def distance_expr_for_lookup(self, lhs, rhs, **kwargs): | ||||
|         return super().distance_expr_for_lookup( | ||||
|             self._normalize_distance_lookup_arg(lhs), | ||||
|             self._normalize_distance_lookup_arg(rhs), | ||||
|             **kwargs | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _normalize_distance_lookup_arg(arg): | ||||
|         is_raster = ( | ||||
|             arg.field.geom_type == 'RASTER' | ||||
|             if hasattr(arg, 'field') else | ||||
|             isinstance(arg, GDALRaster) | ||||
|         ) | ||||
|         return ST_Polygon(arg) if is_raster else arg | ||||
|   | ||||
| @@ -17,20 +17,6 @@ from django.utils.functional import cached_property | ||||
| from django.utils.version import get_version_tuple | ||||
|  | ||||
|  | ||||
| class SpatiaLiteDistanceOperator(SpatialOperator): | ||||
|     def as_sql(self, connection, lookup, template_params, sql_params): | ||||
|         if lookup.lhs.output_field.geodetic(connection): | ||||
|             # SpatiaLite returns NULL instead of zero on geodetic coordinates | ||||
|             sql_template = 'COALESCE(%(func)s(%(lhs)s, %(rhs)s, %%s), 0) %(op)s %(value)s' | ||||
|             template_params.update({ | ||||
|                 'op': self.op, | ||||
|                 'func': connection.ops.spatial_function_name('Distance'), | ||||
|             }) | ||||
|             sql_params.insert(1, len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid') | ||||
|             return sql_template % template_params, sql_params | ||||
|         return super().as_sql(connection, lookup, template_params, sql_params) | ||||
|  | ||||
|  | ||||
| class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): | ||||
|     name = 'spatialite' | ||||
|     spatialite = True | ||||
| @@ -68,10 +54,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): | ||||
|         'exact': SpatialOperator(func='Equals'), | ||||
|         # Distance predicates | ||||
|         'dwithin': SpatialOperator(func='PtDistWithin'), | ||||
|         'distance_gt': SpatiaLiteDistanceOperator(func='Distance', op='>'), | ||||
|         'distance_gte': SpatiaLiteDistanceOperator(func='Distance', op='>='), | ||||
|         'distance_lt': SpatiaLiteDistanceOperator(func='Distance', op='<'), | ||||
|         'distance_lte': SpatiaLiteDistanceOperator(func='Distance', op='<='), | ||||
|     } | ||||
|  | ||||
|     disallowed_aggregates = (aggregates.Extent3D,) | ||||
|   | ||||
| @@ -305,22 +305,13 @@ class DistanceLookupBase(GISLookup): | ||||
|         if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid': | ||||
|             self.process_band_indices() | ||||
|  | ||||
|     def process_rhs(self, compiler, connection): | ||||
|         params = [connection.ops.Adapter(self.rhs)] | ||||
|         # Getting the distance parameter in the units of the field. | ||||
|     def process_distance(self, compiler, connection): | ||||
|         dist_param = self.rhs_params[0] | ||||
|         if hasattr(dist_param, 'resolve_expression'): | ||||
|             dist_param = dist_param.resolve_expression(compiler.query) | ||||
|             sql, expr_params = compiler.compile(dist_param) | ||||
|             self.template_params['value'] = sql | ||||
|             params.extend(expr_params) | ||||
|         else: | ||||
|             params += connection.ops.get_distance( | ||||
|                 self.lhs.output_field, self.rhs_params, | ||||
|                 self.lookup_name, | ||||
|             ) | ||||
|         rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler) | ||||
|         return (rhs, params) | ||||
|         return ( | ||||
|             compiler.compile(dist_param.resolve_expression(compiler.query)) | ||||
|             if hasattr(dist_param, 'resolve_expression') else | ||||
|             ('%s', connection.ops.get_distance(self.lhs.output_field, self.rhs_params, self.lookup_name)) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @BaseSpatialField.register_lookup | ||||
| @@ -328,22 +319,44 @@ class DWithinLookup(DistanceLookupBase): | ||||
|     lookup_name = 'dwithin' | ||||
|     sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)' | ||||
|  | ||||
|     def process_rhs(self, compiler, connection): | ||||
|         dist_sql, dist_params = self.process_distance(compiler, connection) | ||||
|         self.template_params['value'] = dist_sql | ||||
|         rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, self.rhs, compiler) | ||||
|         return rhs, [connection.ops.Adapter(self.rhs)] + dist_params | ||||
|  | ||||
|  | ||||
| class DistanceLookupFromFunction(DistanceLookupBase): | ||||
|     def as_sql(self, compiler, connection): | ||||
|         spheroid = (len(self.rhs_params) == 2 and self.rhs_params[-1] == 'spheroid') or None | ||||
|         distance_expr = connection.ops.distance_expr_for_lookup(self.lhs, self.rhs, spheroid=spheroid) | ||||
|         sql, params = compiler.compile(distance_expr.resolve_expression(compiler.query)) | ||||
|         dist_sql, dist_params = self.process_distance(compiler, connection) | ||||
|         return ( | ||||
|             '%(func)s %(op)s %(dist)s' % {'func': sql, 'op': self.op, 'dist': dist_sql}, | ||||
|             params + dist_params, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @BaseSpatialField.register_lookup | ||||
| class DistanceGTLookup(DistanceLookupBase): | ||||
| class DistanceGTLookup(DistanceLookupFromFunction): | ||||
|     lookup_name = 'distance_gt' | ||||
|     op = '>' | ||||
|  | ||||
|  | ||||
| @BaseSpatialField.register_lookup | ||||
| class DistanceGTELookup(DistanceLookupBase): | ||||
| class DistanceGTELookup(DistanceLookupFromFunction): | ||||
|     lookup_name = 'distance_gte' | ||||
|     op = '>=' | ||||
|  | ||||
|  | ||||
| @BaseSpatialField.register_lookup | ||||
| class DistanceLTLookup(DistanceLookupBase): | ||||
| class DistanceLTLookup(DistanceLookupFromFunction): | ||||
|     lookup_name = 'distance_lt' | ||||
|     op = '<' | ||||
|  | ||||
|  | ||||
| @BaseSpatialField.register_lookup | ||||
| class DistanceLTELookup(DistanceLookupBase): | ||||
| class DistanceLTELookup(DistanceLookupFromFunction): | ||||
|     lookup_name = 'distance_lte' | ||||
|     op = '<=' | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| from django.contrib.gis.db.models.functions import ( | ||||
|     Area, Distance, Length, Perimeter, Transform, | ||||
|     Area, Distance, Intersection, Length, Perimeter, Transform, | ||||
| ) | ||||
| from django.contrib.gis.geos import GEOSGeometry, LineString, Point | ||||
| from django.contrib.gis.measure import D  # alias for Distance | ||||
| @@ -206,6 +206,13 @@ class DistanceTest(TestCase): | ||||
|             ).order_by('name') | ||||
|             self.assertEqual(self.get_names(qs), ['Canberra', 'Hobart', 'Melbourne']) | ||||
|  | ||||
|         # With a complex geometry expression | ||||
|         self.assertFalse(SouthTexasCity.objects.filter(point__distance_gt=(Intersection('point', 'point'), 0))) | ||||
|         self.assertEqual( | ||||
|             SouthTexasCity.objects.filter(point__distance_lte=(Intersection('point', 'point'), 0)).count(), | ||||
|             SouthTexasCity.objects.count(), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| ''' | ||||
| ============================= | ||||
|   | ||||
		Reference in New Issue
	
	Block a user