mirror of
https://github.com/django/django.git
synced 2025-08-10 03:49:11 +00:00
Fixed #28353 -- Fixed some GIS functions when queryset is evaluated more than once.
Reverted test for refs #27603 in favor of using FuncTestMixin.
This commit is contained in:
parent
99e65d6488
commit
3905cfa1a5
@ -1,5 +1,8 @@
|
|||||||
from django.contrib.gis.db.models import GeometryField
|
from django.contrib.gis.db.models import GeometryField
|
||||||
from django.contrib.gis.db.models.functions import Distance
|
from django.contrib.gis.db.models.functions import Distance
|
||||||
|
from django.contrib.gis.measure import (
|
||||||
|
Area as AreaMeasure, Distance as DistanceMeasure,
|
||||||
|
)
|
||||||
from django.utils.functional import cached_property
|
from django.utils.functional import cached_property
|
||||||
|
|
||||||
|
|
||||||
@ -135,3 +138,24 @@ class BaseSpatialOperations:
|
|||||||
'Subclasses of BaseSpatialOperations must provide a '
|
'Subclasses of BaseSpatialOperations must provide a '
|
||||||
'get_geometry_converter() method.'
|
'get_geometry_converter() method.'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_area_att_for_field(self, field):
|
||||||
|
if field.geodetic(self.connection):
|
||||||
|
if self.connection.features.supports_area_geodetic:
|
||||||
|
return 'sq_m'
|
||||||
|
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||||
|
else:
|
||||||
|
units_name = field.units_name(self.connection)
|
||||||
|
if units_name:
|
||||||
|
return AreaMeasure.unit_attname(units_name)
|
||||||
|
|
||||||
|
def get_distance_att_for_field(self, field):
|
||||||
|
dist_att = None
|
||||||
|
if field.geodetic(self.connection):
|
||||||
|
if self.connection.features.supports_distance_geodetic:
|
||||||
|
dist_att = 'm'
|
||||||
|
else:
|
||||||
|
units = field.units_name(self.connection)
|
||||||
|
if units:
|
||||||
|
dist_att = DistanceMeasure.unit_attname(units)
|
||||||
|
return dist_att
|
||||||
|
@ -212,3 +212,6 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
|
|||||||
geom.srid = srid
|
geom.srid = srid
|
||||||
return geom
|
return geom
|
||||||
return converter
|
return converter
|
||||||
|
|
||||||
|
def get_area_att_for_field(self, field):
|
||||||
|
return 'sq_m'
|
||||||
|
@ -389,3 +389,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
|||||||
def converter(value, expression, connection):
|
def converter(value, expression, connection):
|
||||||
return None if value is None else GEOSGeometryBase(read(value), geom_class)
|
return None if value is None else GEOSGeometryBase(read(value), geom_class)
|
||||||
return converter
|
return converter
|
||||||
|
|
||||||
|
def get_area_att_for_field(self, field):
|
||||||
|
return 'sq_m'
|
||||||
|
@ -3,9 +3,6 @@ from decimal import Decimal
|
|||||||
from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
|
from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
|
||||||
from django.contrib.gis.db.models.sql import AreaField, DistanceField
|
from django.contrib.gis.db.models.sql import AreaField, DistanceField
|
||||||
from django.contrib.gis.geometry.backend import Geometry
|
from django.contrib.gis.geometry.backend import Geometry
|
||||||
from django.contrib.gis.measure import (
|
|
||||||
Area as AreaMeasure, Distance as DistanceMeasure,
|
|
||||||
)
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db.models import (
|
from django.db.models import (
|
||||||
BooleanField, FloatField, IntegerField, TextField, Transform,
|
BooleanField, FloatField, IntegerField, TextField, Transform,
|
||||||
@ -121,29 +118,16 @@ class OracleToleranceMixin:
|
|||||||
|
|
||||||
|
|
||||||
class Area(OracleToleranceMixin, GeoFunc):
|
class Area(OracleToleranceMixin, GeoFunc):
|
||||||
output_field_class = AreaField
|
|
||||||
arity = 1
|
arity = 1
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, **extra_context):
|
@cached_property
|
||||||
if connection.ops.geography:
|
def output_field(self):
|
||||||
self.output_field.area_att = 'sq_m'
|
return AreaField(self.geo_field)
|
||||||
else:
|
|
||||||
# Getting the area units of the geographic field.
|
|
||||||
if self.geo_field.geodetic(connection):
|
|
||||||
if connection.features.supports_area_geodetic:
|
|
||||||
self.output_field.area_att = 'sq_m'
|
|
||||||
else:
|
|
||||||
# TODO: Do we want to support raw number areas for geodetic fields?
|
|
||||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
|
||||||
else:
|
|
||||||
units_name = self.geo_field.units_name(connection)
|
|
||||||
if units_name:
|
|
||||||
self.output_field.area_att = AreaMeasure.unit_attname(units_name)
|
|
||||||
return super().as_sql(compiler, connection, **extra_context)
|
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_sql(self, compiler, connection, **extra_context):
|
||||||
self.output_field = AreaField('sq_m') # Oracle returns area in units of meters.
|
if not connection.features.supports_area_geodetic and self.geo_field.geodetic(connection):
|
||||||
return super().as_oracle(compiler, connection)
|
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||||
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection, **extra_context):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
if self.geo_field.geodetic(connection):
|
if self.geo_field.geodetic(connection):
|
||||||
@ -237,27 +221,13 @@ class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
|
|||||||
|
|
||||||
|
|
||||||
class DistanceResultMixin:
|
class DistanceResultMixin:
|
||||||
output_field_class = DistanceField
|
@cached_property
|
||||||
|
def output_field(self):
|
||||||
|
return DistanceField(self.geo_field)
|
||||||
|
|
||||||
def source_is_geography(self):
|
def source_is_geography(self):
|
||||||
return self.geo_field.geography and self.geo_field.srid == 4326
|
return self.geo_field.geography and self.geo_field.srid == 4326
|
||||||
|
|
||||||
def distance_att(self, connection):
|
|
||||||
dist_att = None
|
|
||||||
if self.geo_field.geodetic(connection):
|
|
||||||
if connection.features.supports_distance_geodetic:
|
|
||||||
dist_att = 'm'
|
|
||||||
else:
|
|
||||||
units = self.geo_field.units_name(connection)
|
|
||||||
if units:
|
|
||||||
dist_att = DistanceMeasure.unit_attname(units)
|
|
||||||
return dist_att
|
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, **extra_context):
|
|
||||||
clone = self.copy()
|
|
||||||
clone.output_field.distance_att = self.distance_att(connection)
|
|
||||||
return super(DistanceResultMixin, clone).as_sql(compiler, connection, **extra_context)
|
|
||||||
|
|
||||||
|
|
||||||
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
geom_param_pos = (0, 1)
|
geom_param_pos = (0, 1)
|
||||||
@ -266,19 +236,19 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||||||
def __init__(self, expr1, expr2, spheroid=None, **extra):
|
def __init__(self, expr1, expr2, spheroid=None, **extra):
|
||||||
expressions = [expr1, expr2]
|
expressions = [expr1, expr2]
|
||||||
if spheroid is not None:
|
if spheroid is not None:
|
||||||
self.spheroid = spheroid
|
self.spheroid = self._handle_param(spheroid, 'spheroid', bool)
|
||||||
expressions += (self._handle_param(spheroid, 'spheroid', bool),)
|
|
||||||
super().__init__(*expressions, **extra)
|
super().__init__(*expressions, **extra)
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection):
|
||||||
|
clone = self.copy()
|
||||||
function = None
|
function = None
|
||||||
expr2 = self.source_expressions[1]
|
expr2 = clone.source_expressions[1]
|
||||||
geography = self.source_is_geography()
|
geography = self.source_is_geography()
|
||||||
if expr2.output_field.geography != geography:
|
if expr2.output_field.geography != geography:
|
||||||
if isinstance(expr2, Value):
|
if isinstance(expr2, Value):
|
||||||
expr2.output_field.geography = geography
|
expr2.output_field.geography = geography
|
||||||
else:
|
else:
|
||||||
self.source_expressions[1] = Cast(
|
clone.source_expressions[1] = Cast(
|
||||||
expr2,
|
expr2,
|
||||||
GeometryField(srid=expr2.output_field.srid, geography=geography),
|
GeometryField(srid=expr2.output_field.srid, geography=geography),
|
||||||
)
|
)
|
||||||
@ -289,19 +259,12 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||||||
# DistanceSpheroid is more accurate and resource intensive than DistanceSphere
|
# DistanceSpheroid is more accurate and resource intensive than DistanceSphere
|
||||||
function = connection.ops.spatial_function_name('DistanceSpheroid')
|
function = connection.ops.spatial_function_name('DistanceSpheroid')
|
||||||
# Replace boolean param by the real spheroid of the base field
|
# Replace boolean param by the real spheroid of the base field
|
||||||
self.source_expressions[2] = Value(self.geo_field.spheroid(connection))
|
clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
|
||||||
else:
|
else:
|
||||||
function = connection.ops.spatial_function_name('DistanceSphere')
|
function = connection.ops.spatial_function_name('DistanceSphere')
|
||||||
return super().as_sql(compiler, connection, function=function)
|
return super(Distance, clone).as_sql(compiler, connection, function=function)
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
|
||||||
if self.spheroid:
|
|
||||||
self.source_expressions.pop(2)
|
|
||||||
return super().as_oracle(compiler, connection)
|
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection, **extra_context):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
if self.spheroid:
|
|
||||||
self.source_expressions.pop(2)
|
|
||||||
if self.geo_field.geodetic(connection):
|
if self.geo_field.geodetic(connection):
|
||||||
# SpatiaLite returns NULL instead of zero on geodetic coordinates
|
# SpatiaLite returns NULL instead of zero on geodetic coordinates
|
||||||
extra_context['template'] = 'COALESCE(%(function)s(%(expressions)s, %(spheroid)s), 0)'
|
extra_context['template'] = 'COALESCE(%(function)s(%(expressions)s, %(spheroid)s), 0)'
|
||||||
@ -360,18 +323,19 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||||||
return super().as_sql(compiler, connection, **extra_context)
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection):
|
||||||
|
clone = self.copy()
|
||||||
function = None
|
function = None
|
||||||
if self.source_is_geography():
|
if self.source_is_geography():
|
||||||
self.source_expressions.append(Value(self.spheroid))
|
clone.source_expressions.append(Value(self.spheroid))
|
||||||
elif self.geo_field.geodetic(connection):
|
elif self.geo_field.geodetic(connection):
|
||||||
# Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
|
# Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
|
||||||
function = connection.ops.spatial_function_name('LengthSpheroid')
|
function = connection.ops.spatial_function_name('LengthSpheroid')
|
||||||
self.source_expressions.append(Value(self.geo_field.spheroid(connection)))
|
clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
|
||||||
else:
|
else:
|
||||||
dim = min(f.dim for f in self.get_source_fields() if f)
|
dim = min(f.dim for f in self.get_source_fields() if f)
|
||||||
if dim > 2:
|
if dim > 2:
|
||||||
function = connection.ops.length3d
|
function = connection.ops.length3d
|
||||||
return super().as_sql(compiler, connection, function=function)
|
return super(Length, clone).as_sql(compiler, connection, function=function)
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection):
|
||||||
function = None
|
function = None
|
||||||
@ -482,10 +446,11 @@ class Transform(GeomOutputGeoFunc):
|
|||||||
|
|
||||||
class Translate(Scale):
|
class Translate(Scale):
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection):
|
||||||
|
clone = self.copy()
|
||||||
if len(self.source_expressions) < 4:
|
if len(self.source_expressions) < 4:
|
||||||
# Always provide the z parameter for ST_Translate
|
# Always provide the z parameter for ST_Translate
|
||||||
self.source_expressions.append(Value(0))
|
clone.source_expressions.append(Value(0))
|
||||||
return super().as_sqlite(compiler, connection)
|
return super(Translate, clone).as_sqlite(compiler, connection)
|
||||||
|
|
||||||
|
|
||||||
class Union(OracleToleranceMixin, GeomOutputGeoFunc):
|
class Union(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||||
|
@ -10,9 +10,9 @@ from django.db import models
|
|||||||
|
|
||||||
class AreaField(models.FloatField):
|
class AreaField(models.FloatField):
|
||||||
"Wrapper for Area values."
|
"Wrapper for Area values."
|
||||||
def __init__(self, area_att=None):
|
def __init__(self, geo_field):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.area_att = area_att
|
self.geo_field = geo_field
|
||||||
|
|
||||||
def get_prep_value(self, value):
|
def get_prep_value(self, value):
|
||||||
if not isinstance(value, Area):
|
if not isinstance(value, Area):
|
||||||
@ -20,19 +20,21 @@ class AreaField(models.FloatField):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
def get_db_prep_value(self, value, connection, prepared=False):
|
def get_db_prep_value(self, value, connection, prepared=False):
|
||||||
if value is None or not self.area_att:
|
if value is None:
|
||||||
return value
|
return
|
||||||
return getattr(value, self.area_att)
|
area_att = connection.ops.get_area_att_for_field(self.geo_field)
|
||||||
|
return getattr(value, area_att) if area_att else value
|
||||||
|
|
||||||
def from_db_value(self, value, expression, connection):
|
def from_db_value(self, value, expression, connection):
|
||||||
|
if value is None:
|
||||||
|
return
|
||||||
# If the database returns a Decimal, convert it to a float as expected
|
# If the database returns a Decimal, convert it to a float as expected
|
||||||
# by the Python geometric objects.
|
# by the Python geometric objects.
|
||||||
if isinstance(value, Decimal):
|
if isinstance(value, Decimal):
|
||||||
value = float(value)
|
value = float(value)
|
||||||
# If the units are known, convert value into area measure.
|
# If the units are known, convert value into area measure.
|
||||||
if value is not None and self.area_att:
|
area_att = connection.ops.get_area_att_for_field(self.geo_field)
|
||||||
value = Area(**{self.area_att: value})
|
return Area(**{area_att: value}) if area_att else value
|
||||||
return value
|
|
||||||
|
|
||||||
def get_internal_type(self):
|
def get_internal_type(self):
|
||||||
return 'AreaField'
|
return 'AreaField'
|
||||||
@ -40,9 +42,9 @@ class AreaField(models.FloatField):
|
|||||||
|
|
||||||
class DistanceField(models.FloatField):
|
class DistanceField(models.FloatField):
|
||||||
"Wrapper for Distance values."
|
"Wrapper for Distance values."
|
||||||
def __init__(self, distance_att=None):
|
def __init__(self, geo_field):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.distance_att = distance_att
|
self.geo_field = geo_field
|
||||||
|
|
||||||
def get_prep_value(self, value):
|
def get_prep_value(self, value):
|
||||||
if isinstance(value, Distance):
|
if isinstance(value, Distance):
|
||||||
@ -52,14 +54,16 @@ class DistanceField(models.FloatField):
|
|||||||
def get_db_prep_value(self, value, connection, prepared=False):
|
def get_db_prep_value(self, value, connection, prepared=False):
|
||||||
if not isinstance(value, Distance):
|
if not isinstance(value, Distance):
|
||||||
return value
|
return value
|
||||||
if not self.distance_att:
|
distance_att = connection.ops.get_distance_att_for_field(self.geo_field)
|
||||||
|
if not distance_att:
|
||||||
raise ValueError('Distance measure is supplied, but units are unknown for result.')
|
raise ValueError('Distance measure is supplied, but units are unknown for result.')
|
||||||
return getattr(value, self.distance_att)
|
return getattr(value, distance_att)
|
||||||
|
|
||||||
def from_db_value(self, value, expression, connection):
|
def from_db_value(self, value, expression, connection):
|
||||||
if value is None or not self.distance_att:
|
if value is None:
|
||||||
return value
|
return
|
||||||
return Distance(**{self.distance_att: value})
|
distance_att = connection.ops.get_distance_att_for_field(self.geo_field)
|
||||||
|
return Distance(**{distance_att: value}) if distance_att else value
|
||||||
|
|
||||||
def get_internal_type(self):
|
def get_internal_type(self):
|
||||||
return 'DistanceField'
|
return 'DistanceField'
|
||||||
|
@ -9,7 +9,9 @@ from django.db import connection
|
|||||||
from django.db.models import F, Q
|
from django.db.models import F, Q
|
||||||
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||||
|
|
||||||
from ..utils import mysql, no_oracle, oracle, postgis, spatialite
|
from ..utils import (
|
||||||
|
FuncTestMixin, mysql, no_oracle, oracle, postgis, spatialite,
|
||||||
|
)
|
||||||
from .models import (
|
from .models import (
|
||||||
AustraliaCity, CensusZipcode, Interstate, SouthTexasCity, SouthTexasCityFt,
|
AustraliaCity, CensusZipcode, Interstate, SouthTexasCity, SouthTexasCityFt,
|
||||||
SouthTexasInterstate, SouthTexasZipcode,
|
SouthTexasInterstate, SouthTexasZipcode,
|
||||||
@ -262,7 +264,7 @@ Perimeter(geom1) | OK | :-(
|
|||||||
''' # NOQA
|
''' # NOQA
|
||||||
|
|
||||||
|
|
||||||
class DistanceFunctionsTests(TestCase):
|
class DistanceFunctionsTests(FuncTestMixin, TestCase):
|
||||||
fixtures = ['initial']
|
fixtures = ['initial']
|
||||||
|
|
||||||
@skipUnlessDBFeature("has_Area_function")
|
@skipUnlessDBFeature("has_Area_function")
|
||||||
|
@ -8,6 +8,7 @@ from django.contrib.gis.db.models.functions import (
|
|||||||
from django.contrib.gis.geos import GEOSGeometry, LineString, Point, Polygon
|
from django.contrib.gis.geos import GEOSGeometry, LineString, Point, Polygon
|
||||||
from django.test import TestCase, skipUnlessDBFeature
|
from django.test import TestCase, skipUnlessDBFeature
|
||||||
|
|
||||||
|
from ..utils import FuncTestMixin
|
||||||
from .models import (
|
from .models import (
|
||||||
City3D, Interstate2D, Interstate3D, InterstateProj2D, InterstateProj3D,
|
City3D, Interstate2D, Interstate3D, InterstateProj2D, InterstateProj3D,
|
||||||
MultiPoint3D, Point2D, Point3D, Polygon2D, Polygon3D,
|
MultiPoint3D, Point2D, Point3D, Polygon2D, Polygon3D,
|
||||||
@ -205,7 +206,7 @@ class Geo3DTest(Geo3DLoadingHelper, TestCase):
|
|||||||
|
|
||||||
|
|
||||||
@skipUnlessDBFeature("supports_3d_functions")
|
@skipUnlessDBFeature("supports_3d_functions")
|
||||||
class Geo3DFunctionsTests(Geo3DLoadingHelper, TestCase):
|
class Geo3DFunctionsTests(FuncTestMixin, Geo3DLoadingHelper, TestCase):
|
||||||
def test_kml(self):
|
def test_kml(self):
|
||||||
"""
|
"""
|
||||||
Test KML() function with Z values.
|
Test KML() function with Z values.
|
||||||
|
@ -12,11 +12,11 @@ from django.db import connection
|
|||||||
from django.db.models import Sum
|
from django.db.models import Sum
|
||||||
from django.test import TestCase, skipUnlessDBFeature
|
from django.test import TestCase, skipUnlessDBFeature
|
||||||
|
|
||||||
from ..utils import mysql, oracle, postgis, spatialite
|
from ..utils import FuncTestMixin, mysql, oracle, postgis, spatialite
|
||||||
from .models import City, Country, CountryWebMercator, State, Track
|
from .models import City, Country, CountryWebMercator, State, Track
|
||||||
|
|
||||||
|
|
||||||
class GISFunctionsTests(TestCase):
|
class GISFunctionsTests(FuncTestMixin, TestCase):
|
||||||
"""
|
"""
|
||||||
Testing functions from django/contrib/gis/db/models/functions.py.
|
Testing functions from django/contrib/gis/db/models/functions.py.
|
||||||
Area/Distance/Length/Perimeter are tested in distapp/tests.
|
Area/Distance/Length/Perimeter are tested in distapp/tests.
|
||||||
@ -127,11 +127,8 @@ class GISFunctionsTests(TestCase):
|
|||||||
City.objects.annotate(kml=functions.AsKML('name'))
|
City.objects.annotate(kml=functions.AsKML('name'))
|
||||||
|
|
||||||
# Ensuring the KML is as expected.
|
# Ensuring the KML is as expected.
|
||||||
qs = City.objects.annotate(kml=functions.AsKML('point', precision=9))
|
ptown = City.objects.annotate(kml=functions.AsKML('point', precision=9)).get(name='Pueblo')
|
||||||
ptown = qs.get(name='Pueblo')
|
|
||||||
self.assertEqual('<Point><coordinates>-104.609252,38.255001</coordinates></Point>', ptown.kml)
|
self.assertEqual('<Point><coordinates>-104.609252,38.255001</coordinates></Point>', ptown.kml)
|
||||||
# Same result if the queryset is evaluated again.
|
|
||||||
self.assertEqual(qs.get(name='Pueblo').kml, ptown.kml)
|
|
||||||
|
|
||||||
@skipUnlessDBFeature("has_AsSVG_function")
|
@skipUnlessDBFeature("has_AsSVG_function")
|
||||||
def test_assvg(self):
|
def test_assvg(self):
|
||||||
|
@ -11,7 +11,7 @@ from django.db import connection
|
|||||||
from django.db.models.functions import Cast
|
from django.db.models.functions import Cast
|
||||||
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||||
|
|
||||||
from ..utils import oracle, postgis, spatialite
|
from ..utils import FuncTestMixin, oracle, postgis, spatialite
|
||||||
from .models import City, County, Zipcode
|
from .models import City, County, Zipcode
|
||||||
|
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ class GeographyTest(TestCase):
|
|||||||
self.assertEqual(state, c.state)
|
self.assertEqual(state, c.state)
|
||||||
|
|
||||||
|
|
||||||
class GeographyFunctionTests(TestCase):
|
class GeographyFunctionTests(FuncTestMixin, TestCase):
|
||||||
fixtures = ['initial']
|
fixtures = ['initial']
|
||||||
|
|
||||||
@skipUnlessDBFeature("supports_extent_aggr")
|
@skipUnlessDBFeature("supports_extent_aggr")
|
||||||
|
@ -7,9 +7,9 @@ from django.test import SimpleTestCase
|
|||||||
class FieldsTests(SimpleTestCase):
|
class FieldsTests(SimpleTestCase):
|
||||||
|
|
||||||
def test_area_field_deepcopy(self):
|
def test_area_field_deepcopy(self):
|
||||||
field = AreaField()
|
field = AreaField(None)
|
||||||
self.assertEqual(copy.deepcopy(field), field)
|
self.assertEqual(copy.deepcopy(field), field)
|
||||||
|
|
||||||
def test_distance_field_deepcopy(self):
|
def test_distance_field_deepcopy(self):
|
||||||
field = DistanceField()
|
field = DistanceField(None)
|
||||||
self.assertEqual(copy.deepcopy(field), field)
|
self.assertEqual(copy.deepcopy(field), field)
|
||||||
|
52
tests/gis_tests/test_gis_tests_utils.py
Normal file
52
tests/gis_tests/test_gis_tests_utils.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
from django.db import connection, models
|
||||||
|
from django.db.models.expressions import Func
|
||||||
|
from django.test import SimpleTestCase
|
||||||
|
|
||||||
|
from .utils import FuncTestMixin
|
||||||
|
|
||||||
|
|
||||||
|
def test_mutation(raises=True):
|
||||||
|
def wrapper(mutation_func):
|
||||||
|
def test(test_case_instance, *args, **kwargs):
|
||||||
|
class TestFunc(Func):
|
||||||
|
output_field = models.IntegerField()
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.attribute = 'initial'
|
||||||
|
super().__init__('initial', ['initial'])
|
||||||
|
|
||||||
|
def as_sql(self, *args, **kwargs):
|
||||||
|
mutation_func(self)
|
||||||
|
return '', ()
|
||||||
|
|
||||||
|
if raises:
|
||||||
|
msg = 'TestFunc Func was mutated during compilation.'
|
||||||
|
with test_case_instance.assertRaisesMessage(AssertionError, msg):
|
||||||
|
getattr(TestFunc(), 'as_' + connection.vendor)(None, None)
|
||||||
|
else:
|
||||||
|
getattr(TestFunc(), 'as_' + connection.vendor)(None, None)
|
||||||
|
|
||||||
|
return test
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class FuncTestMixinTests(FuncTestMixin, SimpleTestCase):
|
||||||
|
@test_mutation()
|
||||||
|
def test_mutated_attribute(func):
|
||||||
|
func.attribute = 'mutated'
|
||||||
|
|
||||||
|
@test_mutation()
|
||||||
|
def test_mutated_expressions(func):
|
||||||
|
func.source_expressions.clear()
|
||||||
|
|
||||||
|
@test_mutation()
|
||||||
|
def test_mutated_expression(func):
|
||||||
|
func.source_expressions[0].name = 'mutated'
|
||||||
|
|
||||||
|
@test_mutation()
|
||||||
|
def test_mutated_expression_deep(func):
|
||||||
|
func.source_expressions[1].value[0] = 'mutated'
|
||||||
|
|
||||||
|
@test_mutation(raises=False)
|
||||||
|
def test_not_mutated(func):
|
||||||
|
pass
|
@ -1,8 +1,11 @@
|
|||||||
|
import copy
|
||||||
import unittest
|
import unittest
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import DEFAULT_DB_ALIAS, connection
|
from django.db import DEFAULT_DB_ALIAS, connection
|
||||||
|
from django.db.models.expressions import Func
|
||||||
|
|
||||||
|
|
||||||
def skipUnlessGISLookup(*gis_lookups):
|
def skipUnlessGISLookup(*gis_lookups):
|
||||||
@ -56,3 +59,39 @@ elif spatialite:
|
|||||||
from django.contrib.gis.db.backends.spatialite.models import SpatialiteSpatialRefSys as SpatialRefSys
|
from django.contrib.gis.db.backends.spatialite.models import SpatialiteSpatialRefSys as SpatialRefSys
|
||||||
else:
|
else:
|
||||||
SpatialRefSys = None
|
SpatialRefSys = None
|
||||||
|
|
||||||
|
|
||||||
|
class FuncTestMixin:
|
||||||
|
"""Assert that Func expressions aren't mutated during their as_sql()."""
|
||||||
|
def setUp(self):
|
||||||
|
def as_sql_wrapper(original_as_sql):
|
||||||
|
def inner(*args, **kwargs):
|
||||||
|
func = original_as_sql.__self__
|
||||||
|
# Resolve output_field before as_sql() so touching it in
|
||||||
|
# as_sql() won't change __dict__.
|
||||||
|
func.output_field
|
||||||
|
__dict__original = copy.deepcopy(func.__dict__)
|
||||||
|
result = original_as_sql(*args, **kwargs)
|
||||||
|
msg = '%s Func was mutated during compilation.' % func.__class__.__name__
|
||||||
|
self.assertEqual(func.__dict__, __dict__original, msg)
|
||||||
|
return result
|
||||||
|
return inner
|
||||||
|
|
||||||
|
def __getattribute__(self, name):
|
||||||
|
if name != vendor_impl:
|
||||||
|
return __getattribute__original(self, name)
|
||||||
|
try:
|
||||||
|
as_sql = __getattribute__original(self, vendor_impl)
|
||||||
|
except AttributeError:
|
||||||
|
as_sql = __getattribute__original(self, 'as_sql')
|
||||||
|
return as_sql_wrapper(as_sql)
|
||||||
|
|
||||||
|
vendor_impl = 'as_' + connection.vendor
|
||||||
|
__getattribute__original = Func.__getattribute__
|
||||||
|
self.func_patcher = mock.patch.object(Func, '__getattribute__', __getattribute__)
|
||||||
|
self.func_patcher.start()
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
super().tearDown()
|
||||||
|
self.func_patcher.stop()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user