mirror of
https://github.com/django/django.git
synced 2025-10-24 06:06:09 +00:00
Fixed #24214 -- Added GIS functions to replace geoqueryset's methods
Thanks Simon Charette and Tim Graham for the reviews.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from functools import partial
|
||||
|
||||
from django.contrib.gis.db.models import aggregates
|
||||
@@ -59,11 +60,11 @@ class BaseSpatialFeatures(object):
|
||||
# `has_<name>_method` (defined in __init__) which accesses connection.ops
|
||||
# to determine GIS method availability.
|
||||
geoqueryset_methods = (
|
||||
'area', 'centroid', 'difference', 'distance', 'distance_spheroid',
|
||||
'envelope', 'force_rhr', 'geohash', 'gml', 'intersection', 'kml',
|
||||
'length', 'num_geom', 'perimeter', 'point_on_surface', 'reverse',
|
||||
'scale', 'snap_to_grid', 'svg', 'sym_difference', 'transform',
|
||||
'translate', 'union', 'unionagg',
|
||||
'area', 'bounding_circle', 'centroid', 'difference', 'distance',
|
||||
'distance_spheroid', 'envelope', 'force_rhr', 'geohash', 'gml',
|
||||
'intersection', 'kml', 'length', 'mem_size', 'num_geom', 'num_points',
|
||||
'perimeter', 'point_on_surface', 'reverse', 'scale', 'snap_to_grid',
|
||||
'svg', 'sym_difference', 'transform', 'translate', 'union', 'unionagg',
|
||||
)
|
||||
|
||||
# Specifies whether the Collect and Extent aggregates are supported by the database
|
||||
@@ -86,5 +87,13 @@ class BaseSpatialFeatures(object):
|
||||
setattr(self.__class__, 'has_%s_method' % method,
|
||||
property(partial(BaseSpatialFeatures.has_ops_method, method=method)))
|
||||
|
||||
def __getattr__(self, name):
|
||||
m = re.match(r'has_(\w*)_function$', name)
|
||||
if m:
|
||||
func_name = m.group(1)
|
||||
if func_name not in self.connection.ops.unsupported_functions:
|
||||
return True
|
||||
return False
|
||||
|
||||
def has_ops_method(self, method):
|
||||
return getattr(self.connection.ops, method, False)
|
||||
|
||||
@@ -22,6 +22,7 @@ class BaseSpatialOperations(object):
|
||||
geometry = False
|
||||
|
||||
area = False
|
||||
bounding_circle = False
|
||||
centroid = False
|
||||
difference = False
|
||||
distance = False
|
||||
@@ -30,7 +31,6 @@ class BaseSpatialOperations(object):
|
||||
envelope = False
|
||||
force_rhr = False
|
||||
mem_size = False
|
||||
bounding_circle = False
|
||||
num_geom = False
|
||||
num_points = False
|
||||
perimeter = False
|
||||
@@ -48,6 +48,22 @@ class BaseSpatialOperations(object):
|
||||
# Aggregates
|
||||
disallowed_aggregates = ()
|
||||
|
||||
geom_func_prefix = ''
|
||||
|
||||
# Mapping between Django function names and backend names, when names do not
|
||||
# match; used in spatial_function_name().
|
||||
function_names = {}
|
||||
|
||||
# Blacklist/set of known unsupported functions of the backend
|
||||
unsupported_functions = {
|
||||
'Area', 'AsGeoHash', 'AsGeoJSON', 'AsGML', 'AsKML', 'AsSVG',
|
||||
'BoundingCircle', 'Centroid', 'Difference', 'Distance', 'Envelope',
|
||||
'ForceRHR', 'Intersection', 'Length', 'MemSize', 'NumGeometries',
|
||||
'NumPoints', 'Perimeter', 'PointOnSurface', 'Reverse', 'Scale',
|
||||
'SnapToGrid', 'SymDifference', 'Transform', 'Translate',
|
||||
'Union',
|
||||
}
|
||||
|
||||
# Serialization
|
||||
geohash = False
|
||||
geojson = False
|
||||
@@ -108,9 +124,14 @@ class BaseSpatialOperations(object):
|
||||
def spatial_aggregate_name(self, agg_name):
|
||||
raise NotImplementedError('Aggregate support not implemented for this spatial backend.')
|
||||
|
||||
def spatial_function_name(self, func_name):
|
||||
if func_name in self.unsupported_functions:
|
||||
raise NotImplementedError("This backend doesn't support the %s function." % func_name)
|
||||
return self.function_names.get(func_name, self.geom_func_prefix + func_name)
|
||||
|
||||
# Routines for getting the OGC-compliant models.
|
||||
def geometry_columns(self):
|
||||
raise NotImplementedError('subclasses of BaseSpatialOperations must a provide geometry_columns() method')
|
||||
raise NotImplementedError('Subclasses of BaseSpatialOperations must provide a geometry_columns() method.')
|
||||
|
||||
def spatial_ref_sys(self):
|
||||
raise NotImplementedError('subclasses of BaseSpatialOperations must a provide spatial_ref_sys() method')
|
||||
|
||||
@@ -8,12 +8,13 @@ from psycopg2.extensions import ISQLQuote
|
||||
|
||||
|
||||
class PostGISAdapter(object):
|
||||
def __init__(self, geom):
|
||||
def __init__(self, geom, geography=False):
|
||||
"Initializes on the geometry."
|
||||
# Getting the WKB (in string form, to allow easy pickling of
|
||||
# the adaptor) and the SRID from the geometry.
|
||||
self.ewkb = bytes(geom.ewkb)
|
||||
self.srid = geom.srid
|
||||
self.geography = geography
|
||||
self._adapter = Binary(self.ewkb)
|
||||
|
||||
def __conform__(self, proto):
|
||||
@@ -44,4 +45,7 @@ class PostGISAdapter(object):
|
||||
def getquoted(self):
|
||||
"Returns a properly quoted string for use in PostgreSQL/PostGIS."
|
||||
# psycopg will figure out whether to use E'\\000' or '\000'
|
||||
return str('ST_GeomFromEWKB(%s)' % self._adapter.getquoted().decode())
|
||||
return str('%s(%s)' % (
|
||||
'ST_GeogFromWKB' if self.geography else 'ST_GeomFromEWKB',
|
||||
self._adapter.getquoted().decode())
|
||||
)
|
||||
|
||||
@@ -88,6 +88,13 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
||||
'distance_lte': PostGISDistanceOperator(func='ST_Distance', op='<=', geography=True),
|
||||
}
|
||||
|
||||
unsupported_functions = set()
|
||||
function_names = {
|
||||
'BoundingCircle': 'ST_MinimumBoundingCircle',
|
||||
'MemSize': 'ST_Mem_Size',
|
||||
'NumPoints': 'ST_NPoints',
|
||||
}
|
||||
|
||||
def __init__(self, connection):
|
||||
super(PostGISOperations, self).__init__(connection)
|
||||
|
||||
|
||||
351
django/contrib/gis/db/models/functions.py
Normal file
351
django/contrib/gis/db/models/functions.py
Normal file
@@ -0,0 +1,351 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from django.contrib.gis.db.models.fields import GeometryField
|
||||
from django.contrib.gis.db.models.sql import AreaField
|
||||
from django.contrib.gis.geos.geometry import GEOSGeometry
|
||||
from django.contrib.gis.measure import (
|
||||
Area as AreaMeasure, Distance as DistanceMeasure,
|
||||
)
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models import FloatField, IntegerField, TextField
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.utils import six
|
||||
|
||||
NUMERIC_TYPES = six.integer_types + (float, Decimal)
|
||||
|
||||
|
||||
class GeoFunc(Func):
|
||||
function = None
|
||||
output_field_class = None
|
||||
geom_param_pos = 0
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if 'output_field' not in extra and self.output_field_class:
|
||||
extra['output_field'] = self.output_field_class()
|
||||
super(GeoFunc, self).__init__(*expressions, **extra)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
@property
|
||||
def srid(self):
|
||||
expr = self.source_expressions[self.geom_param_pos]
|
||||
if hasattr(expr, 'srid'):
|
||||
return expr.srid
|
||||
try:
|
||||
return expr.field.srid
|
||||
except (AttributeError, FieldError):
|
||||
return None
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if self.function is None:
|
||||
self.function = connection.ops.spatial_function_name(self.name)
|
||||
return super(GeoFunc, self).as_sql(compiler, connection)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
res = super(GeoFunc, self).resolve_expression(*args, **kwargs)
|
||||
base_srid = res.srid
|
||||
if not base_srid:
|
||||
raise TypeError("Geometry functions can only operate on geometric content.")
|
||||
|
||||
for pos, expr in enumerate(res.source_expressions[1:], start=1):
|
||||
if isinstance(expr, GeomValue) and expr.srid != base_srid:
|
||||
# Automatic SRID conversion so objects are comparable
|
||||
res.source_expressions[pos] = Transform(expr, base_srid).resolve_expression(*args, **kwargs)
|
||||
return res
|
||||
|
||||
def _handle_param(self, value, param_name='', check_types=None):
|
||||
if not hasattr(value, 'resolve_expression'):
|
||||
if check_types and not isinstance(value, check_types):
|
||||
raise TypeError(
|
||||
"The %s parameter has the wrong type: should be %s." % (
|
||||
param_name, str(check_types))
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class GeomValue(Value):
|
||||
geography = False
|
||||
|
||||
@property
|
||||
def srid(self):
|
||||
return self.value.srid
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if self.geography:
|
||||
self.value = connection.ops.Adapter(self.value, geography=self.geography)
|
||||
else:
|
||||
self.value = connection.ops.Adapter(self.value)
|
||||
return super(GeomValue, self).as_sql(compiler, connection)
|
||||
|
||||
|
||||
class GeoFuncWithGeoParam(GeoFunc):
|
||||
def __init__(self, expression, geom, *expressions, **extra):
|
||||
if not hasattr(geom, 'srid'):
|
||||
# Try to interpret it as a geometry input
|
||||
try:
|
||||
geom = GEOSGeometry(geom)
|
||||
except Exception:
|
||||
raise ValueError("This function requires a geometric parameter.")
|
||||
if not geom.srid:
|
||||
raise ValueError("Please provide a geometry attribute with a defined SRID.")
|
||||
geom = GeomValue(geom)
|
||||
super(GeoFuncWithGeoParam, self).__init__(expression, geom, *expressions, **extra)
|
||||
|
||||
|
||||
class Area(GeoFunc):
|
||||
def as_sql(self, compiler, connection):
|
||||
if connection.ops.oracle:
|
||||
self.output_field = AreaField('sq_m') # Oracle returns area in units of meters.
|
||||
else:
|
||||
if connection.ops.geography:
|
||||
# Geography fields support area calculation, returns square meters.
|
||||
self.output_field = AreaField('sq_m')
|
||||
elif not self.output_field.geodetic(connection):
|
||||
# Getting the area units of the geographic field.
|
||||
self.output_field = AreaField(
|
||||
AreaMeasure.unit_attname(self.output_field.units_name(connection)))
|
||||
else:
|
||||
# TODO: Do we want to support raw number areas for geodetic fields?
|
||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||
return super(Area, self).as_sql(compiler, connection)
|
||||
|
||||
|
||||
class AsGeoJSON(GeoFunc):
|
||||
output_field_class = TextField
|
||||
|
||||
def __init__(self, expression, bbox=False, crs=False, precision=8, **extra):
|
||||
expressions = [expression]
|
||||
if precision is not None:
|
||||
expressions.append(self._handle_param(precision, 'precision', six.integer_types))
|
||||
options = 0
|
||||
if crs and bbox:
|
||||
options = 3
|
||||
elif bbox:
|
||||
options = 1
|
||||
elif crs:
|
||||
options = 2
|
||||
if options:
|
||||
expressions.append(options)
|
||||
super(AsGeoJSON, self).__init__(*expressions, **extra)
|
||||
|
||||
|
||||
class AsGML(GeoFunc):
|
||||
geom_param_pos = 1
|
||||
output_field_class = TextField
|
||||
|
||||
def __init__(self, expression, version=2, precision=8, **extra):
|
||||
expressions = [version, expression]
|
||||
if precision is not None:
|
||||
expressions.append(self._handle_param(precision, 'precision', six.integer_types))
|
||||
super(AsGML, self).__init__(*expressions, **extra)
|
||||
|
||||
|
||||
class AsKML(AsGML):
|
||||
pass
|
||||
|
||||
|
||||
class AsSVG(GeoFunc):
|
||||
output_field_class = TextField
|
||||
|
||||
def __init__(self, expression, relative=False, precision=8, **extra):
|
||||
relative = relative if hasattr(relative, 'resolve_expression') else int(relative)
|
||||
expressions = [
|
||||
expression,
|
||||
relative,
|
||||
self._handle_param(precision, 'precision', six.integer_types),
|
||||
]
|
||||
super(AsSVG, self).__init__(*expressions, **extra)
|
||||
|
||||
|
||||
class BoundingCircle(GeoFunc):
|
||||
def __init__(self, expression, num_seg=48, **extra):
|
||||
super(BoundingCircle, self).__init__(*[expression, num_seg], **extra)
|
||||
|
||||
|
||||
class Centroid(GeoFunc):
|
||||
pass
|
||||
|
||||
|
||||
class Difference(GeoFuncWithGeoParam):
|
||||
pass
|
||||
|
||||
|
||||
class DistanceResultMixin(object):
|
||||
def convert_value(self, value, expression, connection, context):
|
||||
if value is None:
|
||||
return None
|
||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||
if geo_field.geodetic(connection):
|
||||
dist_att = 'm'
|
||||
else:
|
||||
dist_att = DistanceMeasure.unit_attname(geo_field.units_name(connection))
|
||||
return DistanceMeasure(**{dist_att: value})
|
||||
|
||||
|
||||
class Distance(DistanceResultMixin, GeoFuncWithGeoParam):
|
||||
output_field_class = FloatField
|
||||
spheroid = None
|
||||
|
||||
def __init__(self, expr1, expr2, spheroid=None, **extra):
|
||||
expressions = [expr1, expr2]
|
||||
if spheroid is not None:
|
||||
self.spheroid = spheroid
|
||||
expressions += (self._handle_param(spheroid, 'spheroid', bool),)
|
||||
super(Distance, self).__init__(*expressions, **extra)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||
src_field = self.get_source_fields()[0]
|
||||
geography = src_field.geography and self.srid == 4326
|
||||
if geography:
|
||||
# Set parameters as geography if base field is geography
|
||||
for pos, expr in enumerate(
|
||||
self.source_expressions[self.geom_param_pos + 1:], start=self.geom_param_pos + 1):
|
||||
if isinstance(expr, GeomValue):
|
||||
expr.geography = True
|
||||
elif geo_field.geodetic(connection):
|
||||
# Geometry fields with geodetic (lon/lat) coordinates need special distance functions
|
||||
if self.spheroid:
|
||||
self.function = 'ST_Distance_Spheroid' # More accurate, resource intensive
|
||||
# Replace boolean param by the real spheroid of the base field
|
||||
self.source_expressions[2] = Value(geo_field._spheroid)
|
||||
else:
|
||||
self.function = 'ST_Distance_Sphere'
|
||||
return super(Distance, self).as_sql(compiler, connection)
|
||||
|
||||
|
||||
class Envelope(GeoFunc):
|
||||
pass
|
||||
|
||||
|
||||
class ForceRHR(GeoFunc):
|
||||
pass
|
||||
|
||||
|
||||
class GeoHash(GeoFunc):
|
||||
output_field_class = TextField
|
||||
|
||||
def __init__(self, expression, precision=None, **extra):
|
||||
expressions = [expression]
|
||||
if precision is not None:
|
||||
expressions.append(self._handle_param(precision, 'precision', six.integer_types))
|
||||
super(GeoHash, self).__init__(*expressions, **extra)
|
||||
|
||||
|
||||
class Intersection(GeoFuncWithGeoParam):
|
||||
pass
|
||||
|
||||
|
||||
class Length(DistanceResultMixin, GeoFunc):
|
||||
output_field_class = FloatField
|
||||
|
||||
def __init__(self, expr1, spheroid=True, **extra):
|
||||
self.spheroid = spheroid
|
||||
super(Length, self).__init__(expr1, **extra)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||
src_field = self.get_source_fields()[0]
|
||||
geography = src_field.geography and self.srid == 4326
|
||||
if geography:
|
||||
self.source_expressions.append(Value(self.spheroid))
|
||||
elif geo_field.geodetic(connection):
|
||||
# Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
|
||||
self.function = 'ST_Length_Spheroid'
|
||||
self.source_expressions.append(Value(geo_field._spheroid))
|
||||
else:
|
||||
dim = min(f.dim for f in self.get_source_fields() if f)
|
||||
if dim > 2:
|
||||
self.function = connection.ops.length3d
|
||||
return super(Length, self).as_sql(compiler, connection)
|
||||
|
||||
|
||||
class MemSize(GeoFunc):
|
||||
output_field_class = IntegerField
|
||||
|
||||
|
||||
class NumGeometries(GeoFunc):
|
||||
output_field_class = IntegerField
|
||||
|
||||
|
||||
class NumPoints(GeoFunc):
|
||||
output_field_class = IntegerField
|
||||
|
||||
|
||||
class Perimeter(DistanceResultMixin, GeoFunc):
|
||||
output_field_class = FloatField
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
dim = min(f.dim for f in self.get_source_fields())
|
||||
if dim > 2:
|
||||
self.function = connection.ops.perimeter3d
|
||||
return super(Perimeter, self).as_sql(compiler, connection)
|
||||
|
||||
|
||||
class PointOnSurface(GeoFunc):
|
||||
pass
|
||||
|
||||
|
||||
class Reverse(GeoFunc):
|
||||
pass
|
||||
|
||||
|
||||
class Scale(GeoFunc):
|
||||
def __init__(self, expression, x, y, z=0.0, **extra):
|
||||
expressions = [
|
||||
expression,
|
||||
self._handle_param(x, 'x', NUMERIC_TYPES),
|
||||
self._handle_param(y, 'y', NUMERIC_TYPES),
|
||||
]
|
||||
if z != 0.0:
|
||||
expressions.append(self._handle_param(z, 'z', NUMERIC_TYPES))
|
||||
super(Scale, self).__init__(*expressions, **extra)
|
||||
|
||||
|
||||
class SnapToGrid(GeoFunc):
|
||||
def __init__(self, expression, *args, **extra):
|
||||
nargs = len(args)
|
||||
expressions = [expression]
|
||||
if nargs in (1, 2):
|
||||
expressions.extend(
|
||||
[self._handle_param(arg, '', NUMERIC_TYPES) for arg in args]
|
||||
)
|
||||
elif nargs == 4:
|
||||
# Reverse origin and size param ordering
|
||||
expressions.extend(
|
||||
[self._handle_param(arg, '', NUMERIC_TYPES) for arg in args[2:]]
|
||||
)
|
||||
expressions.extend(
|
||||
[self._handle_param(arg, '', NUMERIC_TYPES) for arg in args[0:2]]
|
||||
)
|
||||
else:
|
||||
raise ValueError('Must provide 1, 2, or 4 arguments to `SnapToGrid`.')
|
||||
super(SnapToGrid, self).__init__(*expressions, **extra)
|
||||
|
||||
|
||||
class SymDifference(GeoFuncWithGeoParam):
|
||||
pass
|
||||
|
||||
|
||||
class Transform(GeoFunc):
|
||||
def __init__(self, expression, srid, **extra):
|
||||
expressions = [
|
||||
expression,
|
||||
self._handle_param(srid, 'srid', six.integer_types),
|
||||
]
|
||||
super(Transform, self).__init__(*expressions, **extra)
|
||||
|
||||
@property
|
||||
def srid(self):
|
||||
# Make srid the resulting srid of the transformation
|
||||
return self.source_expressions[self.geom_param_pos + 1].value
|
||||
|
||||
|
||||
class Translate(Scale):
|
||||
pass
|
||||
|
||||
|
||||
class Union(GeoFuncWithGeoParam):
|
||||
pass
|
||||
Reference in New Issue
Block a user