1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Fixed #26788 -- Fixed QuerySet.update() crash when updating a geometry to another one.

This commit is contained in:
Sergey Fedoseev 2017-04-10 22:26:26 +05:00 committed by Tim Graham
parent 64264c9a19
commit e7afef13f5
10 changed files with 68 additions and 75 deletions

View File

@ -71,7 +71,29 @@ class BaseSpatialOperations:
stored procedure call to the transformation function of the spatial
backend.
"""
raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method')
def transform_value(value, field):
return (
not (value is None or value.srid == field.srid) and
self.connection.features.supports_transform
)
if hasattr(value, 'as_sql'):
return (
'%s(%%s, %s)' % (self.spatial_function_name('Transform'), f.srid)
if transform_value(value.output_field, f)
else '%s'
)
if transform_value(value, f):
# Add Transform() to the SQL placeholder.
return '%s(%s(%%s,%s), %s)' % (
self.spatial_function_name('Transform'),
self.from_text, value.srid, f.srid,
)
elif self.connection.features.has_spatialrefsys_table:
return '%s(%%s,%s)' % (self.from_text, f.srid)
else:
# For backwards compatibility on MySQL (#27464).
return '%s(%%s)' % self.from_text
def check_expression_support(self, expression):
if isinstance(expression, self.disallowed_aggregates):

View File

@ -86,18 +86,6 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
def geo_db_type(self, f):
return f.geom_type
def get_geom_placeholder(self, f, value, compiler):
"""
The placeholder here has to include MySQL's WKT constructor. Because
MySQL does not support spatial transformations, there is no need to
modify the placeholder based on the contents of the given value.
"""
if hasattr(value, 'as_sql'):
placeholder, _ = compiler.compile(value)
else:
placeholder = '%s(%%s)' % self.from_text
return placeholder
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
if isinstance(expression.output_field, GeometryField) and self.uses_invalid_empty_geometry_collection:

View File

@ -187,33 +187,9 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
return [dist_param]
def get_geom_placeholder(self, f, value, compiler):
"""
Provide a proper substitution value for Geometries that are not in the
SRID of the field. Specifically, this routine will substitute in the
SDO_CS.TRANSFORM() function call.
"""
tranform_func = self.spatial_function_name('Transform')
if value is None:
return 'NULL'
def transform_value(val, srid):
return val.srid != srid
if hasattr(value, 'as_sql'):
if transform_value(value, f.srid):
placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
else:
placeholder = '%s'
# No geometry value used for F expression, substitute in
# the column name instead.
sql, _ = compiler.compile(value)
return placeholder % sql
else:
if transform_value(value, f.srid):
return '%s(SDO_GEOMETRY(%%s, %s), %s)' % (tranform_func, value.srid, f.srid)
else:
return 'SDO_GEOMETRY(%%s, %s)' % f.srid
return super().get_geom_placeholder(f, value, compiler)
def spatial_aggregate_name(self, agg_name):
"""

View File

@ -292,6 +292,12 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
substitute in the ST_Transform() function call.
"""
tranform_func = self.spatial_function_name('Transform')
if hasattr(value, 'as_sql'):
if value.field.srid == f.srid:
placeholder = '%s'
else:
placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
return placeholder
# Get the srid for this object
if value is None:
@ -310,13 +316,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
else:
placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
if hasattr(value, 'as_sql'):
# If this is an F expression, then we don't really want
# a placeholder and instead substitute in the column
# of the expression.
sql, _ = compiler.compile(value)
placeholder = placeholder % sql
return placeholder
def _get_postgis_func(self, func):

View File

@ -152,32 +152,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
dist_param = value
return [dist_param]
def get_geom_placeholder(self, f, value, compiler):
"""
Provide a proper substitution value for Geometries that are not in the
SRID of the field. Specifically, this routine will substitute in the
Transform() and GeomFromText() function call(s).
"""
tranform_func = self.spatial_function_name('Transform')
def transform_value(value, srid):
return not (value is None or value.srid == srid)
if hasattr(value, 'as_sql'):
if transform_value(value, f.srid):
placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
else:
placeholder = '%s'
# No geometry value used for F expression, substitute in
# the column name instead.
sql, _ = compiler.compile(value)
return placeholder % sql
else:
if transform_value(value, f.srid):
# Adding Transform() to the SQL placeholder.
return '%s(%s(%%s,%s), %s)' % (tranform_func, self.from_text, value.srid, f.srid)
else:
return '%s(%%s,%s)' % (self.from_text, f.srid)
def _get_spatialite_func(self, func):
"""
Helper routine for calling SpatiaLite functions and returning

View File

@ -68,6 +68,8 @@ class GISLookup(Lookup):
if not hasattr(geo_fld, 'srid'):
raise ValueError('No geographic field found in expression.')
self.rhs.srid = geo_fld.srid
sql, _ = compiler.compile(geom)
return connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) % sql, []
elif isinstance(self.rhs, Expression):
raise ValueError('Complex expressions not supported for spatial fields.')
elif isinstance(self.rhs, (list, tuple)):

View File

@ -233,7 +233,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return value
def binary_placeholder_sql(self, value):
return '_binary %s' if value is not None else '%s'
return '_binary %s' if value is not None and not hasattr(value, 'as_sql') else '%s'
def subtract_temporals(self, internal_type, lhs, rhs):
lhs_sql, lhs_params = lhs

View File

@ -1167,7 +1167,7 @@ class SQLUpdateCompiler(SQLCompiler):
name = field.column
if hasattr(val, 'as_sql'):
sql, params = self.compile(val)
values.append('%s = %s' % (qn(name), sql))
values.append('%s = %s' % (qn(name), placeholder % sql))
update_params.extend(params)
elif val is not None:
values.append('%s = %s' % (qn(name), placeholder))

View File

@ -101,3 +101,9 @@ class NonConcreteField(models.IntegerField):
class NonConcreteModel(NamedModel):
non_concrete = NonConcreteField()
point = models.PointField(geography=True)
class ManyPointModel(NamedModel):
point1 = models.PointField()
point2 = models.PointField()
point3 = models.PointField(srid=3857)

View File

@ -1,11 +1,12 @@
from unittest import skipUnless
from django.contrib.gis.db.models import GeometryField, Value, functions
from django.contrib.gis.db.models import F, GeometryField, Value, functions
from django.contrib.gis.geos import Point, Polygon
from django.db import connection
from django.test import TestCase, skipUnlessDBFeature
from ..utils import postgis
from .models import City
from .models import City, ManyPointModel
@skipUnlessDBFeature('gis_enabled')
@ -29,3 +30,28 @@ class GeoExpressionsTests(TestCase):
p = Polygon(((1, 1), (1, 2), (2, 2), (2, 1), (1, 1)))
area = City.objects.annotate(a=functions.Area(Value(p, GeometryField(srid=4326, geography=True)))).first().a
self.assertAlmostEqual(area.sq_km, 12305.1, 0)
def test_update_from_other_field(self):
p1 = Point(1, 1, srid=4326)
p2 = Point(2, 2, srid=4326)
obj = ManyPointModel.objects.create(
point1=p1,
point2=p2,
point3=p2.transform(3857, clone=True),
)
# Updating a point to a point of the same SRID.
ManyPointModel.objects.filter(pk=obj.pk).update(point2=F('point1'))
obj.refresh_from_db()
self.assertEqual(obj.point2, p1)
# Updating a point to a point with a different SRID.
if connection.features.supports_transform:
ManyPointModel.objects.filter(pk=obj.pk).update(point3=F('point1'))
obj.refresh_from_db()
self.assertTrue(obj.point3.equals_exact(p1.transform(3857, clone=True), 0.1))
@skipUnlessDBFeature('has_Translate_function')
def test_update_with_expression(self):
city = City.objects.create(point=Point(1, 1, srid=4326))
City.objects.filter(pk=city.pk).update(point=functions.Translate('point', 1, 1))
city.refresh_from_db()
self.assertEqual(city.point, Point(2, 2, srid=4326))