mirror of
https://github.com/django/django.git
synced 2025-10-26 23:26:08 +00:00
Fixed #28353 -- Fixed some GIS functions when queryset is evaluated more than once.
Reverted test for refs #27603 in favor of using FuncTestMixin.
This commit is contained in:
committed by
Tim Graham
parent
99e65d6488
commit
3905cfa1a5
@@ -1,5 +1,8 @@
|
||||
from django.contrib.gis.db.models import GeometryField
|
||||
from django.contrib.gis.db.models.functions import Distance
|
||||
from django.contrib.gis.measure import (
|
||||
Area as AreaMeasure, Distance as DistanceMeasure,
|
||||
)
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
@@ -135,3 +138,24 @@ class BaseSpatialOperations:
|
||||
'Subclasses of BaseSpatialOperations must provide a '
|
||||
'get_geometry_converter() method.'
|
||||
)
|
||||
|
||||
def get_area_att_for_field(self, field):
|
||||
if field.geodetic(self.connection):
|
||||
if self.connection.features.supports_area_geodetic:
|
||||
return 'sq_m'
|
||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||
else:
|
||||
units_name = field.units_name(self.connection)
|
||||
if units_name:
|
||||
return AreaMeasure.unit_attname(units_name)
|
||||
|
||||
def get_distance_att_for_field(self, field):
|
||||
dist_att = None
|
||||
if field.geodetic(self.connection):
|
||||
if self.connection.features.supports_distance_geodetic:
|
||||
dist_att = 'm'
|
||||
else:
|
||||
units = field.units_name(self.connection)
|
||||
if units:
|
||||
dist_att = DistanceMeasure.unit_attname(units)
|
||||
return dist_att
|
||||
|
||||
@@ -212,3 +212,6 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
|
||||
geom.srid = srid
|
||||
return geom
|
||||
return converter
|
||||
|
||||
def get_area_att_for_field(self, field):
|
||||
return 'sq_m'
|
||||
|
||||
@@ -389,3 +389,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
||||
def converter(value, expression, connection):
|
||||
return None if value is None else GEOSGeometryBase(read(value), geom_class)
|
||||
return converter
|
||||
|
||||
def get_area_att_for_field(self, field):
|
||||
return 'sq_m'
|
||||
|
||||
@@ -3,9 +3,6 @@ from decimal import Decimal
|
||||
from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
|
||||
from django.contrib.gis.db.models.sql import AreaField, DistanceField
|
||||
from django.contrib.gis.geometry.backend import Geometry
|
||||
from django.contrib.gis.measure import (
|
||||
Area as AreaMeasure, Distance as DistanceMeasure,
|
||||
)
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models import (
|
||||
BooleanField, FloatField, IntegerField, TextField, Transform,
|
||||
@@ -121,29 +118,16 @@ class OracleToleranceMixin:
|
||||
|
||||
|
||||
class Area(OracleToleranceMixin, GeoFunc):
|
||||
output_field_class = AreaField
|
||||
arity = 1
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if connection.ops.geography:
|
||||
self.output_field.area_att = 'sq_m'
|
||||
else:
|
||||
# Getting the area units of the geographic field.
|
||||
if self.geo_field.geodetic(connection):
|
||||
if connection.features.supports_area_geodetic:
|
||||
self.output_field.area_att = 'sq_m'
|
||||
else:
|
||||
# TODO: Do we want to support raw number areas for geodetic fields?
|
||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||
else:
|
||||
units_name = self.geo_field.units_name(connection)
|
||||
if units_name:
|
||||
self.output_field.area_att = AreaMeasure.unit_attname(units_name)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return AreaField(self.geo_field)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
self.output_field = AreaField('sq_m') # Oracle returns area in units of meters.
|
||||
return super().as_oracle(compiler, connection)
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if not connection.features.supports_area_geodetic and self.geo_field.geodetic(connection):
|
||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if self.geo_field.geodetic(connection):
|
||||
@@ -237,27 +221,13 @@ class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||
|
||||
|
||||
class DistanceResultMixin:
|
||||
output_field_class = DistanceField
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return DistanceField(self.geo_field)
|
||||
|
||||
def source_is_geography(self):
|
||||
return self.geo_field.geography and self.geo_field.srid == 4326
|
||||
|
||||
def distance_att(self, connection):
|
||||
dist_att = None
|
||||
if self.geo_field.geodetic(connection):
|
||||
if connection.features.supports_distance_geodetic:
|
||||
dist_att = 'm'
|
||||
else:
|
||||
units = self.geo_field.units_name(connection)
|
||||
if units:
|
||||
dist_att = DistanceMeasure.unit_attname(units)
|
||||
return dist_att
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
clone = self.copy()
|
||||
clone.output_field.distance_att = self.distance_att(connection)
|
||||
return super(DistanceResultMixin, clone).as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||
geom_param_pos = (0, 1)
|
||||
@@ -266,19 +236,19 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||
def __init__(self, expr1, expr2, spheroid=None, **extra):
|
||||
expressions = [expr1, expr2]
|
||||
if spheroid is not None:
|
||||
self.spheroid = spheroid
|
||||
expressions += (self._handle_param(spheroid, 'spheroid', bool),)
|
||||
self.spheroid = self._handle_param(spheroid, 'spheroid', bool)
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
clone = self.copy()
|
||||
function = None
|
||||
expr2 = self.source_expressions[1]
|
||||
expr2 = clone.source_expressions[1]
|
||||
geography = self.source_is_geography()
|
||||
if expr2.output_field.geography != geography:
|
||||
if isinstance(expr2, Value):
|
||||
expr2.output_field.geography = geography
|
||||
else:
|
||||
self.source_expressions[1] = Cast(
|
||||
clone.source_expressions[1] = Cast(
|
||||
expr2,
|
||||
GeometryField(srid=expr2.output_field.srid, geography=geography),
|
||||
)
|
||||
@@ -289,19 +259,12 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||
# DistanceSpheroid is more accurate and resource intensive than DistanceSphere
|
||||
function = connection.ops.spatial_function_name('DistanceSpheroid')
|
||||
# Replace boolean param by the real spheroid of the base field
|
||||
self.source_expressions[2] = Value(self.geo_field.spheroid(connection))
|
||||
clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
|
||||
else:
|
||||
function = connection.ops.spatial_function_name('DistanceSphere')
|
||||
return super().as_sql(compiler, connection, function=function)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
if self.spheroid:
|
||||
self.source_expressions.pop(2)
|
||||
return super().as_oracle(compiler, connection)
|
||||
return super(Distance, clone).as_sql(compiler, connection, function=function)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if self.spheroid:
|
||||
self.source_expressions.pop(2)
|
||||
if self.geo_field.geodetic(connection):
|
||||
# SpatiaLite returns NULL instead of zero on geodetic coordinates
|
||||
extra_context['template'] = 'COALESCE(%(function)s(%(expressions)s, %(spheroid)s), 0)'
|
||||
@@ -360,18 +323,19 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
clone = self.copy()
|
||||
function = None
|
||||
if self.source_is_geography():
|
||||
self.source_expressions.append(Value(self.spheroid))
|
||||
clone.source_expressions.append(Value(self.spheroid))
|
||||
elif self.geo_field.geodetic(connection):
|
||||
# Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
|
||||
function = connection.ops.spatial_function_name('LengthSpheroid')
|
||||
self.source_expressions.append(Value(self.geo_field.spheroid(connection)))
|
||||
clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
|
||||
else:
|
||||
dim = min(f.dim for f in self.get_source_fields() if f)
|
||||
if dim > 2:
|
||||
function = connection.ops.length3d
|
||||
return super().as_sql(compiler, connection, function=function)
|
||||
return super(Length, clone).as_sql(compiler, connection, function=function)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
function = None
|
||||
@@ -482,10 +446,11 @@ class Transform(GeomOutputGeoFunc):
|
||||
|
||||
class Translate(Scale):
|
||||
def as_sqlite(self, compiler, connection):
|
||||
clone = self.copy()
|
||||
if len(self.source_expressions) < 4:
|
||||
# Always provide the z parameter for ST_Translate
|
||||
self.source_expressions.append(Value(0))
|
||||
return super().as_sqlite(compiler, connection)
|
||||
clone.source_expressions.append(Value(0))
|
||||
return super(Translate, clone).as_sqlite(compiler, connection)
|
||||
|
||||
|
||||
class Union(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||
|
||||
@@ -10,9 +10,9 @@ from django.db import models
|
||||
|
||||
class AreaField(models.FloatField):
|
||||
"Wrapper for Area values."
|
||||
def __init__(self, area_att=None):
|
||||
def __init__(self, geo_field):
|
||||
super().__init__()
|
||||
self.area_att = area_att
|
||||
self.geo_field = geo_field
|
||||
|
||||
def get_prep_value(self, value):
|
||||
if not isinstance(value, Area):
|
||||
@@ -20,19 +20,21 @@ class AreaField(models.FloatField):
|
||||
return value
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if value is None or not self.area_att:
|
||||
return value
|
||||
return getattr(value, self.area_att)
|
||||
if value is None:
|
||||
return
|
||||
area_att = connection.ops.get_area_att_for_field(self.geo_field)
|
||||
return getattr(value, area_att) if area_att else value
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
if value is None:
|
||||
return
|
||||
# If the database returns a Decimal, convert it to a float as expected
|
||||
# by the Python geometric objects.
|
||||
if isinstance(value, Decimal):
|
||||
value = float(value)
|
||||
# If the units are known, convert value into area measure.
|
||||
if value is not None and self.area_att:
|
||||
value = Area(**{self.area_att: value})
|
||||
return value
|
||||
area_att = connection.ops.get_area_att_for_field(self.geo_field)
|
||||
return Area(**{area_att: value}) if area_att else value
|
||||
|
||||
def get_internal_type(self):
|
||||
return 'AreaField'
|
||||
@@ -40,9 +42,9 @@ class AreaField(models.FloatField):
|
||||
|
||||
class DistanceField(models.FloatField):
|
||||
"Wrapper for Distance values."
|
||||
def __init__(self, distance_att=None):
|
||||
def __init__(self, geo_field):
|
||||
super().__init__()
|
||||
self.distance_att = distance_att
|
||||
self.geo_field = geo_field
|
||||
|
||||
def get_prep_value(self, value):
|
||||
if isinstance(value, Distance):
|
||||
@@ -52,14 +54,16 @@ class DistanceField(models.FloatField):
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if not isinstance(value, Distance):
|
||||
return value
|
||||
if not self.distance_att:
|
||||
distance_att = connection.ops.get_distance_att_for_field(self.geo_field)
|
||||
if not distance_att:
|
||||
raise ValueError('Distance measure is supplied, but units are unknown for result.')
|
||||
return getattr(value, self.distance_att)
|
||||
return getattr(value, distance_att)
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
if value is None or not self.distance_att:
|
||||
return value
|
||||
return Distance(**{self.distance_att: value})
|
||||
if value is None:
|
||||
return
|
||||
distance_att = connection.ops.get_distance_att_for_field(self.geo_field)
|
||||
return Distance(**{distance_att: value}) if distance_att else value
|
||||
|
||||
def get_internal_type(self):
|
||||
return 'DistanceField'
|
||||
|
||||
Reference in New Issue
Block a user