1
0
mirror of https://github.com/django/django.git synced 2025-10-26 23:26:08 +00:00

Fixed #28353 -- Fixed some GIS functions when queryset is evaluated more than once.

Reverted test for refs #27603 in favor of using FuncTestMixin.
This commit is contained in:
Sergey Fedoseev
2017-09-11 20:56:39 +05:00
committed by Tim Graham
parent 99e65d6488
commit 3905cfa1a5
12 changed files with 176 additions and 86 deletions

View File

@@ -1,5 +1,8 @@
from django.contrib.gis.db.models import GeometryField
from django.contrib.gis.db.models.functions import Distance
from django.contrib.gis.measure import (
Area as AreaMeasure, Distance as DistanceMeasure,
)
from django.utils.functional import cached_property
@@ -135,3 +138,24 @@ class BaseSpatialOperations:
'Subclasses of BaseSpatialOperations must provide a '
'get_geometry_converter() method.'
)
def get_area_att_for_field(self, field):
if field.geodetic(self.connection):
if self.connection.features.supports_area_geodetic:
return 'sq_m'
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
else:
units_name = field.units_name(self.connection)
if units_name:
return AreaMeasure.unit_attname(units_name)
def get_distance_att_for_field(self, field):
dist_att = None
if field.geodetic(self.connection):
if self.connection.features.supports_distance_geodetic:
dist_att = 'm'
else:
units = field.units_name(self.connection)
if units:
dist_att = DistanceMeasure.unit_attname(units)
return dist_att

View File

@@ -212,3 +212,6 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
geom.srid = srid
return geom
return converter
def get_area_att_for_field(self, field):
return 'sq_m'

View File

@@ -389,3 +389,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
def converter(value, expression, connection):
return None if value is None else GEOSGeometryBase(read(value), geom_class)
return converter
def get_area_att_for_field(self, field):
return 'sq_m'

View File

@@ -3,9 +3,6 @@ from decimal import Decimal
from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
from django.contrib.gis.db.models.sql import AreaField, DistanceField
from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import (
Area as AreaMeasure, Distance as DistanceMeasure,
)
from django.core.exceptions import FieldError
from django.db.models import (
BooleanField, FloatField, IntegerField, TextField, Transform,
@@ -121,29 +118,16 @@ class OracleToleranceMixin:
class Area(OracleToleranceMixin, GeoFunc):
output_field_class = AreaField
arity = 1
def as_sql(self, compiler, connection, **extra_context):
if connection.ops.geography:
self.output_field.area_att = 'sq_m'
else:
# Getting the area units of the geographic field.
if self.geo_field.geodetic(connection):
if connection.features.supports_area_geodetic:
self.output_field.area_att = 'sq_m'
else:
# TODO: Do we want to support raw number areas for geodetic fields?
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
else:
units_name = self.geo_field.units_name(connection)
if units_name:
self.output_field.area_att = AreaMeasure.unit_attname(units_name)
return super().as_sql(compiler, connection, **extra_context)
@cached_property
def output_field(self):
return AreaField(self.geo_field)
def as_oracle(self, compiler, connection):
self.output_field = AreaField('sq_m') # Oracle returns area in units of meters.
return super().as_oracle(compiler, connection)
def as_sql(self, compiler, connection, **extra_context):
if not connection.features.supports_area_geodetic and self.geo_field.geodetic(connection):
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
return super().as_sql(compiler, connection, **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
if self.geo_field.geodetic(connection):
@@ -237,27 +221,13 @@ class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
class DistanceResultMixin:
output_field_class = DistanceField
@cached_property
def output_field(self):
return DistanceField(self.geo_field)
def source_is_geography(self):
return self.geo_field.geography and self.geo_field.srid == 4326
def distance_att(self, connection):
dist_att = None
if self.geo_field.geodetic(connection):
if connection.features.supports_distance_geodetic:
dist_att = 'm'
else:
units = self.geo_field.units_name(connection)
if units:
dist_att = DistanceMeasure.unit_attname(units)
return dist_att
def as_sql(self, compiler, connection, **extra_context):
clone = self.copy()
clone.output_field.distance_att = self.distance_att(connection)
return super(DistanceResultMixin, clone).as_sql(compiler, connection, **extra_context)
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
geom_param_pos = (0, 1)
@@ -266,19 +236,19 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
def __init__(self, expr1, expr2, spheroid=None, **extra):
expressions = [expr1, expr2]
if spheroid is not None:
self.spheroid = spheroid
expressions += (self._handle_param(spheroid, 'spheroid', bool),)
self.spheroid = self._handle_param(spheroid, 'spheroid', bool)
super().__init__(*expressions, **extra)
def as_postgresql(self, compiler, connection):
clone = self.copy()
function = None
expr2 = self.source_expressions[1]
expr2 = clone.source_expressions[1]
geography = self.source_is_geography()
if expr2.output_field.geography != geography:
if isinstance(expr2, Value):
expr2.output_field.geography = geography
else:
self.source_expressions[1] = Cast(
clone.source_expressions[1] = Cast(
expr2,
GeometryField(srid=expr2.output_field.srid, geography=geography),
)
@@ -289,19 +259,12 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
# DistanceSpheroid is more accurate and resource intensive than DistanceSphere
function = connection.ops.spatial_function_name('DistanceSpheroid')
# Replace boolean param by the real spheroid of the base field
self.source_expressions[2] = Value(self.geo_field.spheroid(connection))
clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
else:
function = connection.ops.spatial_function_name('DistanceSphere')
return super().as_sql(compiler, connection, function=function)
def as_oracle(self, compiler, connection):
if self.spheroid:
self.source_expressions.pop(2)
return super().as_oracle(compiler, connection)
return super(Distance, clone).as_sql(compiler, connection, function=function)
def as_sqlite(self, compiler, connection, **extra_context):
if self.spheroid:
self.source_expressions.pop(2)
if self.geo_field.geodetic(connection):
# SpatiaLite returns NULL instead of zero on geodetic coordinates
extra_context['template'] = 'COALESCE(%(function)s(%(expressions)s, %(spheroid)s), 0)'
@@ -360,18 +323,19 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
return super().as_sql(compiler, connection, **extra_context)
def as_postgresql(self, compiler, connection):
clone = self.copy()
function = None
if self.source_is_geography():
self.source_expressions.append(Value(self.spheroid))
clone.source_expressions.append(Value(self.spheroid))
elif self.geo_field.geodetic(connection):
# Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
function = connection.ops.spatial_function_name('LengthSpheroid')
self.source_expressions.append(Value(self.geo_field.spheroid(connection)))
clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
else:
dim = min(f.dim for f in self.get_source_fields() if f)
if dim > 2:
function = connection.ops.length3d
return super().as_sql(compiler, connection, function=function)
return super(Length, clone).as_sql(compiler, connection, function=function)
def as_sqlite(self, compiler, connection):
function = None
@@ -482,10 +446,11 @@ class Transform(GeomOutputGeoFunc):
class Translate(Scale):
def as_sqlite(self, compiler, connection):
clone = self.copy()
if len(self.source_expressions) < 4:
# Always provide the z parameter for ST_Translate
self.source_expressions.append(Value(0))
return super().as_sqlite(compiler, connection)
clone.source_expressions.append(Value(0))
return super(Translate, clone).as_sqlite(compiler, connection)
class Union(OracleToleranceMixin, GeomOutputGeoFunc):

View File

@@ -10,9 +10,9 @@ from django.db import models
class AreaField(models.FloatField):
"Wrapper for Area values."
def __init__(self, area_att=None):
def __init__(self, geo_field):
super().__init__()
self.area_att = area_att
self.geo_field = geo_field
def get_prep_value(self, value):
if not isinstance(value, Area):
@@ -20,19 +20,21 @@ class AreaField(models.FloatField):
return value
def get_db_prep_value(self, value, connection, prepared=False):
if value is None or not self.area_att:
return value
return getattr(value, self.area_att)
if value is None:
return
area_att = connection.ops.get_area_att_for_field(self.geo_field)
return getattr(value, area_att) if area_att else value
def from_db_value(self, value, expression, connection):
if value is None:
return
# If the database returns a Decimal, convert it to a float as expected
# by the Python geometric objects.
if isinstance(value, Decimal):
value = float(value)
# If the units are known, convert value into area measure.
if value is not None and self.area_att:
value = Area(**{self.area_att: value})
return value
area_att = connection.ops.get_area_att_for_field(self.geo_field)
return Area(**{area_att: value}) if area_att else value
def get_internal_type(self):
return 'AreaField'
@@ -40,9 +42,9 @@ class AreaField(models.FloatField):
class DistanceField(models.FloatField):
"Wrapper for Distance values."
def __init__(self, distance_att=None):
def __init__(self, geo_field):
super().__init__()
self.distance_att = distance_att
self.geo_field = geo_field
def get_prep_value(self, value):
if isinstance(value, Distance):
@@ -52,14 +54,16 @@ class DistanceField(models.FloatField):
def get_db_prep_value(self, value, connection, prepared=False):
if not isinstance(value, Distance):
return value
if not self.distance_att:
distance_att = connection.ops.get_distance_att_for_field(self.geo_field)
if not distance_att:
raise ValueError('Distance measure is supplied, but units are unknown for result.')
return getattr(value, self.distance_att)
return getattr(value, distance_att)
def from_db_value(self, value, expression, connection):
if value is None or not self.distance_att:
return value
return Distance(**{self.distance_att: value})
if value is None:
return
distance_att = connection.ops.get_distance_att_for_field(self.geo_field)
return Distance(**{distance_att: value}) if distance_att else value
def get_internal_type(self):
return 'DistanceField'