1
0
mirror of https://github.com/django/django.git synced 2025-03-13 19:00:45 +00:00

Fixed #26788 -- Fixed QuerySet.update() crash when updating a geometry to another one.

This commit is contained in:
Sergey Fedoseev 2017-04-10 22:26:26 +05:00 committed by Tim Graham
parent 64264c9a19
commit e7afef13f5
10 changed files with 68 additions and 75 deletions

View File

@ -71,7 +71,29 @@ class BaseSpatialOperations:
stored procedure call to the transformation function of the spatial stored procedure call to the transformation function of the spatial
backend. backend.
""" """
raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method') def transform_value(value, field):
return (
not (value is None or value.srid == field.srid) and
self.connection.features.supports_transform
)
if hasattr(value, 'as_sql'):
return (
'%s(%%s, %s)' % (self.spatial_function_name('Transform'), f.srid)
if transform_value(value.output_field, f)
else '%s'
)
if transform_value(value, f):
# Add Transform() to the SQL placeholder.
return '%s(%s(%%s,%s), %s)' % (
self.spatial_function_name('Transform'),
self.from_text, value.srid, f.srid,
)
elif self.connection.features.has_spatialrefsys_table:
return '%s(%%s,%s)' % (self.from_text, f.srid)
else:
# For backwards compatibility on MySQL (#27464).
return '%s(%%s)' % self.from_text
def check_expression_support(self, expression): def check_expression_support(self, expression):
if isinstance(expression, self.disallowed_aggregates): if isinstance(expression, self.disallowed_aggregates):

View File

@ -86,18 +86,6 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
def geo_db_type(self, f): def geo_db_type(self, f):
return f.geom_type return f.geom_type
def get_geom_placeholder(self, f, value, compiler):
"""
The placeholder here has to include MySQL's WKT constructor. Because
MySQL does not support spatial transformations, there is no need to
modify the placeholder based on the contents of the given value.
"""
if hasattr(value, 'as_sql'):
placeholder, _ = compiler.compile(value)
else:
placeholder = '%s(%%s)' % self.from_text
return placeholder
def get_db_converters(self, expression): def get_db_converters(self, expression):
converters = super().get_db_converters(expression) converters = super().get_db_converters(expression)
if isinstance(expression.output_field, GeometryField) and self.uses_invalid_empty_geometry_collection: if isinstance(expression.output_field, GeometryField) and self.uses_invalid_empty_geometry_collection:

View File

@ -187,33 +187,9 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
return [dist_param] return [dist_param]
def get_geom_placeholder(self, f, value, compiler): def get_geom_placeholder(self, f, value, compiler):
"""
Provide a proper substitution value for Geometries that are not in the
SRID of the field. Specifically, this routine will substitute in the
SDO_CS.TRANSFORM() function call.
"""
tranform_func = self.spatial_function_name('Transform')
if value is None: if value is None:
return 'NULL' return 'NULL'
return super().get_geom_placeholder(f, value, compiler)
def transform_value(val, srid):
return val.srid != srid
if hasattr(value, 'as_sql'):
if transform_value(value, f.srid):
placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
else:
placeholder = '%s'
# No geometry value used for F expression, substitute in
# the column name instead.
sql, _ = compiler.compile(value)
return placeholder % sql
else:
if transform_value(value, f.srid):
return '%s(SDO_GEOMETRY(%%s, %s), %s)' % (tranform_func, value.srid, f.srid)
else:
return 'SDO_GEOMETRY(%%s, %s)' % f.srid
def spatial_aggregate_name(self, agg_name): def spatial_aggregate_name(self, agg_name):
""" """

View File

@ -292,6 +292,12 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
substitute in the ST_Transform() function call. substitute in the ST_Transform() function call.
""" """
tranform_func = self.spatial_function_name('Transform') tranform_func = self.spatial_function_name('Transform')
if hasattr(value, 'as_sql'):
if value.field.srid == f.srid:
placeholder = '%s'
else:
placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
return placeholder
# Get the srid for this object # Get the srid for this object
if value is None: if value is None:
@ -310,13 +316,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
else: else:
placeholder = '%s(%%s, %s)' % (tranform_func, f.srid) placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
if hasattr(value, 'as_sql'):
# If this is an F expression, then we don't really want
# a placeholder and instead substitute in the column
# of the expression.
sql, _ = compiler.compile(value)
placeholder = placeholder % sql
return placeholder return placeholder
def _get_postgis_func(self, func): def _get_postgis_func(self, func):

View File

@ -152,32 +152,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
dist_param = value dist_param = value
return [dist_param] return [dist_param]
def get_geom_placeholder(self, f, value, compiler):
"""
Provide a proper substitution value for Geometries that are not in the
SRID of the field. Specifically, this routine will substitute in the
Transform() and GeomFromText() function call(s).
"""
tranform_func = self.spatial_function_name('Transform')
def transform_value(value, srid):
return not (value is None or value.srid == srid)
if hasattr(value, 'as_sql'):
if transform_value(value, f.srid):
placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
else:
placeholder = '%s'
# No geometry value used for F expression, substitute in
# the column name instead.
sql, _ = compiler.compile(value)
return placeholder % sql
else:
if transform_value(value, f.srid):
# Adding Transform() to the SQL placeholder.
return '%s(%s(%%s,%s), %s)' % (tranform_func, self.from_text, value.srid, f.srid)
else:
return '%s(%%s,%s)' % (self.from_text, f.srid)
def _get_spatialite_func(self, func): def _get_spatialite_func(self, func):
""" """
Helper routine for calling SpatiaLite functions and returning Helper routine for calling SpatiaLite functions and returning

View File

@ -68,6 +68,8 @@ class GISLookup(Lookup):
if not hasattr(geo_fld, 'srid'): if not hasattr(geo_fld, 'srid'):
raise ValueError('No geographic field found in expression.') raise ValueError('No geographic field found in expression.')
self.rhs.srid = geo_fld.srid self.rhs.srid = geo_fld.srid
sql, _ = compiler.compile(geom)
return connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) % sql, []
elif isinstance(self.rhs, Expression): elif isinstance(self.rhs, Expression):
raise ValueError('Complex expressions not supported for spatial fields.') raise ValueError('Complex expressions not supported for spatial fields.')
elif isinstance(self.rhs, (list, tuple)): elif isinstance(self.rhs, (list, tuple)):

View File

@ -233,7 +233,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return value return value
def binary_placeholder_sql(self, value): def binary_placeholder_sql(self, value):
return '_binary %s' if value is not None else '%s' return '_binary %s' if value is not None and not hasattr(value, 'as_sql') else '%s'
def subtract_temporals(self, internal_type, lhs, rhs): def subtract_temporals(self, internal_type, lhs, rhs):
lhs_sql, lhs_params = lhs lhs_sql, lhs_params = lhs

View File

@ -1167,7 +1167,7 @@ class SQLUpdateCompiler(SQLCompiler):
name = field.column name = field.column
if hasattr(val, 'as_sql'): if hasattr(val, 'as_sql'):
sql, params = self.compile(val) sql, params = self.compile(val)
values.append('%s = %s' % (qn(name), sql)) values.append('%s = %s' % (qn(name), placeholder % sql))
update_params.extend(params) update_params.extend(params)
elif val is not None: elif val is not None:
values.append('%s = %s' % (qn(name), placeholder)) values.append('%s = %s' % (qn(name), placeholder))

View File

@ -101,3 +101,9 @@ class NonConcreteField(models.IntegerField):
class NonConcreteModel(NamedModel): class NonConcreteModel(NamedModel):
non_concrete = NonConcreteField() non_concrete = NonConcreteField()
point = models.PointField(geography=True) point = models.PointField(geography=True)
class ManyPointModel(NamedModel):
point1 = models.PointField()
point2 = models.PointField()
point3 = models.PointField(srid=3857)

View File

@ -1,11 +1,12 @@
from unittest import skipUnless from unittest import skipUnless
from django.contrib.gis.db.models import GeometryField, Value, functions from django.contrib.gis.db.models import F, GeometryField, Value, functions
from django.contrib.gis.geos import Point, Polygon from django.contrib.gis.geos import Point, Polygon
from django.db import connection
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from ..utils import postgis from ..utils import postgis
from .models import City from .models import City, ManyPointModel
@skipUnlessDBFeature('gis_enabled') @skipUnlessDBFeature('gis_enabled')
@ -29,3 +30,28 @@ class GeoExpressionsTests(TestCase):
p = Polygon(((1, 1), (1, 2), (2, 2), (2, 1), (1, 1))) p = Polygon(((1, 1), (1, 2), (2, 2), (2, 1), (1, 1)))
area = City.objects.annotate(a=functions.Area(Value(p, GeometryField(srid=4326, geography=True)))).first().a area = City.objects.annotate(a=functions.Area(Value(p, GeometryField(srid=4326, geography=True)))).first().a
self.assertAlmostEqual(area.sq_km, 12305.1, 0) self.assertAlmostEqual(area.sq_km, 12305.1, 0)
def test_update_from_other_field(self):
p1 = Point(1, 1, srid=4326)
p2 = Point(2, 2, srid=4326)
obj = ManyPointModel.objects.create(
point1=p1,
point2=p2,
point3=p2.transform(3857, clone=True),
)
# Updating a point to a point of the same SRID.
ManyPointModel.objects.filter(pk=obj.pk).update(point2=F('point1'))
obj.refresh_from_db()
self.assertEqual(obj.point2, p1)
# Updating a point to a point with a different SRID.
if connection.features.supports_transform:
ManyPointModel.objects.filter(pk=obj.pk).update(point3=F('point1'))
obj.refresh_from_db()
self.assertTrue(obj.point3.equals_exact(p1.transform(3857, clone=True), 0.1))
@skipUnlessDBFeature('has_Translate_function')
def test_update_with_expression(self):
city = City.objects.create(point=Point(1, 1, srid=4326))
City.objects.filter(pk=city.pk).update(point=functions.Translate('point', 1, 1))
city.refresh_from_db()
self.assertEqual(city.point, Point(2, 2, srid=4326))