mirror of
https://github.com/django/django.git
synced 2024-12-22 17:16:24 +00:00
Fixed #26788 -- Fixed QuerySet.update() crash when updating a geometry to another one.
This commit is contained in:
parent
64264c9a19
commit
e7afef13f5
@ -71,7 +71,29 @@ class BaseSpatialOperations:
|
||||
stored procedure call to the transformation function of the spatial
|
||||
backend.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method')
|
||||
def transform_value(value, field):
|
||||
return (
|
||||
not (value is None or value.srid == field.srid) and
|
||||
self.connection.features.supports_transform
|
||||
)
|
||||
|
||||
if hasattr(value, 'as_sql'):
|
||||
return (
|
||||
'%s(%%s, %s)' % (self.spatial_function_name('Transform'), f.srid)
|
||||
if transform_value(value.output_field, f)
|
||||
else '%s'
|
||||
)
|
||||
if transform_value(value, f):
|
||||
# Add Transform() to the SQL placeholder.
|
||||
return '%s(%s(%%s,%s), %s)' % (
|
||||
self.spatial_function_name('Transform'),
|
||||
self.from_text, value.srid, f.srid,
|
||||
)
|
||||
elif self.connection.features.has_spatialrefsys_table:
|
||||
return '%s(%%s,%s)' % (self.from_text, f.srid)
|
||||
else:
|
||||
# For backwards compatibility on MySQL (#27464).
|
||||
return '%s(%%s)' % self.from_text
|
||||
|
||||
def check_expression_support(self, expression):
|
||||
if isinstance(expression, self.disallowed_aggregates):
|
||||
|
@ -86,18 +86,6 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
|
||||
def geo_db_type(self, f):
|
||||
return f.geom_type
|
||||
|
||||
def get_geom_placeholder(self, f, value, compiler):
|
||||
"""
|
||||
The placeholder here has to include MySQL's WKT constructor. Because
|
||||
MySQL does not support spatial transformations, there is no need to
|
||||
modify the placeholder based on the contents of the given value.
|
||||
"""
|
||||
if hasattr(value, 'as_sql'):
|
||||
placeholder, _ = compiler.compile(value)
|
||||
else:
|
||||
placeholder = '%s(%%s)' % self.from_text
|
||||
return placeholder
|
||||
|
||||
def get_db_converters(self, expression):
|
||||
converters = super().get_db_converters(expression)
|
||||
if isinstance(expression.output_field, GeometryField) and self.uses_invalid_empty_geometry_collection:
|
||||
|
@ -187,33 +187,9 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
|
||||
return [dist_param]
|
||||
|
||||
def get_geom_placeholder(self, f, value, compiler):
|
||||
"""
|
||||
Provide a proper substitution value for Geometries that are not in the
|
||||
SRID of the field. Specifically, this routine will substitute in the
|
||||
SDO_CS.TRANSFORM() function call.
|
||||
"""
|
||||
tranform_func = self.spatial_function_name('Transform')
|
||||
|
||||
if value is None:
|
||||
return 'NULL'
|
||||
|
||||
def transform_value(val, srid):
|
||||
return val.srid != srid
|
||||
|
||||
if hasattr(value, 'as_sql'):
|
||||
if transform_value(value, f.srid):
|
||||
placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
|
||||
else:
|
||||
placeholder = '%s'
|
||||
# No geometry value used for F expression, substitute in
|
||||
# the column name instead.
|
||||
sql, _ = compiler.compile(value)
|
||||
return placeholder % sql
|
||||
else:
|
||||
if transform_value(value, f.srid):
|
||||
return '%s(SDO_GEOMETRY(%%s, %s), %s)' % (tranform_func, value.srid, f.srid)
|
||||
else:
|
||||
return 'SDO_GEOMETRY(%%s, %s)' % f.srid
|
||||
return super().get_geom_placeholder(f, value, compiler)
|
||||
|
||||
def spatial_aggregate_name(self, agg_name):
|
||||
"""
|
||||
|
@ -292,6 +292,12 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
||||
substitute in the ST_Transform() function call.
|
||||
"""
|
||||
tranform_func = self.spatial_function_name('Transform')
|
||||
if hasattr(value, 'as_sql'):
|
||||
if value.field.srid == f.srid:
|
||||
placeholder = '%s'
|
||||
else:
|
||||
placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
|
||||
return placeholder
|
||||
|
||||
# Get the srid for this object
|
||||
if value is None:
|
||||
@ -310,13 +316,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
||||
else:
|
||||
placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
|
||||
|
||||
if hasattr(value, 'as_sql'):
|
||||
# If this is an F expression, then we don't really want
|
||||
# a placeholder and instead substitute in the column
|
||||
# of the expression.
|
||||
sql, _ = compiler.compile(value)
|
||||
placeholder = placeholder % sql
|
||||
|
||||
return placeholder
|
||||
|
||||
def _get_postgis_func(self, func):
|
||||
|
@ -152,32 +152,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
|
||||
dist_param = value
|
||||
return [dist_param]
|
||||
|
||||
def get_geom_placeholder(self, f, value, compiler):
|
||||
"""
|
||||
Provide a proper substitution value for Geometries that are not in the
|
||||
SRID of the field. Specifically, this routine will substitute in the
|
||||
Transform() and GeomFromText() function call(s).
|
||||
"""
|
||||
tranform_func = self.spatial_function_name('Transform')
|
||||
|
||||
def transform_value(value, srid):
|
||||
return not (value is None or value.srid == srid)
|
||||
if hasattr(value, 'as_sql'):
|
||||
if transform_value(value, f.srid):
|
||||
placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
|
||||
else:
|
||||
placeholder = '%s'
|
||||
# No geometry value used for F expression, substitute in
|
||||
# the column name instead.
|
||||
sql, _ = compiler.compile(value)
|
||||
return placeholder % sql
|
||||
else:
|
||||
if transform_value(value, f.srid):
|
||||
# Adding Transform() to the SQL placeholder.
|
||||
return '%s(%s(%%s,%s), %s)' % (tranform_func, self.from_text, value.srid, f.srid)
|
||||
else:
|
||||
return '%s(%%s,%s)' % (self.from_text, f.srid)
|
||||
|
||||
def _get_spatialite_func(self, func):
|
||||
"""
|
||||
Helper routine for calling SpatiaLite functions and returning
|
||||
|
@ -68,6 +68,8 @@ class GISLookup(Lookup):
|
||||
if not hasattr(geo_fld, 'srid'):
|
||||
raise ValueError('No geographic field found in expression.')
|
||||
self.rhs.srid = geo_fld.srid
|
||||
sql, _ = compiler.compile(geom)
|
||||
return connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) % sql, []
|
||||
elif isinstance(self.rhs, Expression):
|
||||
raise ValueError('Complex expressions not supported for spatial fields.')
|
||||
elif isinstance(self.rhs, (list, tuple)):
|
||||
|
@ -233,7 +233,7 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
return value
|
||||
|
||||
def binary_placeholder_sql(self, value):
|
||||
return '_binary %s' if value is not None else '%s'
|
||||
return '_binary %s' if value is not None and not hasattr(value, 'as_sql') else '%s'
|
||||
|
||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||
lhs_sql, lhs_params = lhs
|
||||
|
@ -1167,7 +1167,7 @@ class SQLUpdateCompiler(SQLCompiler):
|
||||
name = field.column
|
||||
if hasattr(val, 'as_sql'):
|
||||
sql, params = self.compile(val)
|
||||
values.append('%s = %s' % (qn(name), sql))
|
||||
values.append('%s = %s' % (qn(name), placeholder % sql))
|
||||
update_params.extend(params)
|
||||
elif val is not None:
|
||||
values.append('%s = %s' % (qn(name), placeholder))
|
||||
|
@ -101,3 +101,9 @@ class NonConcreteField(models.IntegerField):
|
||||
class NonConcreteModel(NamedModel):
|
||||
non_concrete = NonConcreteField()
|
||||
point = models.PointField(geography=True)
|
||||
|
||||
|
||||
class ManyPointModel(NamedModel):
|
||||
point1 = models.PointField()
|
||||
point2 = models.PointField()
|
||||
point3 = models.PointField(srid=3857)
|
||||
|
@ -1,11 +1,12 @@
|
||||
from unittest import skipUnless
|
||||
|
||||
from django.contrib.gis.db.models import GeometryField, Value, functions
|
||||
from django.contrib.gis.db.models import F, GeometryField, Value, functions
|
||||
from django.contrib.gis.geos import Point, Polygon
|
||||
from django.db import connection
|
||||
from django.test import TestCase, skipUnlessDBFeature
|
||||
|
||||
from ..utils import postgis
|
||||
from .models import City
|
||||
from .models import City, ManyPointModel
|
||||
|
||||
|
||||
@skipUnlessDBFeature('gis_enabled')
|
||||
@ -29,3 +30,28 @@ class GeoExpressionsTests(TestCase):
|
||||
p = Polygon(((1, 1), (1, 2), (2, 2), (2, 1), (1, 1)))
|
||||
area = City.objects.annotate(a=functions.Area(Value(p, GeometryField(srid=4326, geography=True)))).first().a
|
||||
self.assertAlmostEqual(area.sq_km, 12305.1, 0)
|
||||
|
||||
def test_update_from_other_field(self):
|
||||
p1 = Point(1, 1, srid=4326)
|
||||
p2 = Point(2, 2, srid=4326)
|
||||
obj = ManyPointModel.objects.create(
|
||||
point1=p1,
|
||||
point2=p2,
|
||||
point3=p2.transform(3857, clone=True),
|
||||
)
|
||||
# Updating a point to a point of the same SRID.
|
||||
ManyPointModel.objects.filter(pk=obj.pk).update(point2=F('point1'))
|
||||
obj.refresh_from_db()
|
||||
self.assertEqual(obj.point2, p1)
|
||||
# Updating a point to a point with a different SRID.
|
||||
if connection.features.supports_transform:
|
||||
ManyPointModel.objects.filter(pk=obj.pk).update(point3=F('point1'))
|
||||
obj.refresh_from_db()
|
||||
self.assertTrue(obj.point3.equals_exact(p1.transform(3857, clone=True), 0.1))
|
||||
|
||||
@skipUnlessDBFeature('has_Translate_function')
|
||||
def test_update_with_expression(self):
|
||||
city = City.objects.create(point=Point(1, 1, srid=4326))
|
||||
City.objects.filter(pk=city.pk).update(point=functions.Translate('point', 1, 1))
|
||||
city.refresh_from_db()
|
||||
self.assertEqual(city.point, Point(2, 2, srid=4326))
|
||||
|
Loading…
Reference in New Issue
Block a user