From e7afef13f594eb667f2709c0ef7bca98452ab32b Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Mon, 10 Apr 2017 22:26:26 +0500 Subject: [PATCH] Fixed #26788 -- Fixed QuerySet.update() crash when updating a geometry to another one. --- .../gis/db/backends/base/operations.py | 24 ++++++++++++++- .../gis/db/backends/mysql/operations.py | 12 -------- .../gis/db/backends/oracle/operations.py | 26 +--------------- .../gis/db/backends/postgis/operations.py | 13 ++++---- .../gis/db/backends/spatialite/operations.py | 26 ---------------- django/contrib/gis/db/models/lookups.py | 2 ++ django/db/backends/mysql/operations.py | 2 +- django/db/models/sql/compiler.py | 2 +- tests/gis_tests/geoapp/models.py | 6 ++++ tests/gis_tests/geoapp/test_expressions.py | 30 +++++++++++++++++-- 10 files changed, 68 insertions(+), 75 deletions(-) diff --git a/django/contrib/gis/db/backends/base/operations.py b/django/contrib/gis/db/backends/base/operations.py index 0732d517c6..1b314a9d2b 100644 --- a/django/contrib/gis/db/backends/base/operations.py +++ b/django/contrib/gis/db/backends/base/operations.py @@ -71,7 +71,29 @@ class BaseSpatialOperations: stored procedure call to the transformation function of the spatial backend. """ - raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method') + def transform_value(value, field): + return ( + not (value is None or value.srid == field.srid) and + self.connection.features.supports_transform + ) + + if hasattr(value, 'as_sql'): + return ( + '%s(%%s, %s)' % (self.spatial_function_name('Transform'), f.srid) + if transform_value(value.output_field, f) + else '%s' + ) + if transform_value(value, f): + # Add Transform() to the SQL placeholder. + return '%s(%s(%%s,%s), %s)' % ( + self.spatial_function_name('Transform'), + self.from_text, value.srid, f.srid, + ) + elif self.connection.features.has_spatialrefsys_table: + return '%s(%%s,%s)' % (self.from_text, f.srid) + else: + # For backwards compatibility on MySQL (#27464). + return '%s(%%s)' % self.from_text def check_expression_support(self, expression): if isinstance(expression, self.disallowed_aggregates): diff --git a/django/contrib/gis/db/backends/mysql/operations.py b/django/contrib/gis/db/backends/mysql/operations.py index 0ea9f8f274..1c850b39d1 100644 --- a/django/contrib/gis/db/backends/mysql/operations.py +++ b/django/contrib/gis/db/backends/mysql/operations.py @@ -86,18 +86,6 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations): def geo_db_type(self, f): return f.geom_type - def get_geom_placeholder(self, f, value, compiler): - """ - The placeholder here has to include MySQL's WKT constructor. Because - MySQL does not support spatial transformations, there is no need to - modify the placeholder based on the contents of the given value. - """ - if hasattr(value, 'as_sql'): - placeholder, _ = compiler.compile(value) - else: - placeholder = '%s(%%s)' % self.from_text - return placeholder - def get_db_converters(self, expression): converters = super().get_db_converters(expression) if isinstance(expression.output_field, GeometryField) and self.uses_invalid_empty_geometry_collection: diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index 926209ac40..50f5bed8b9 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -187,33 +187,9 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): return [dist_param] def get_geom_placeholder(self, f, value, compiler): - """ - Provide a proper substitution value for Geometries that are not in the - SRID of the field. Specifically, this routine will substitute in the - SDO_CS.TRANSFORM() function call. - """ - tranform_func = self.spatial_function_name('Transform') - if value is None: return 'NULL' - - def transform_value(val, srid): - return val.srid != srid - - if hasattr(value, 'as_sql'): - if transform_value(value, f.srid): - placeholder = '%s(%%s, %s)' % (tranform_func, f.srid) - else: - placeholder = '%s' - # No geometry value used for F expression, substitute in - # the column name instead. - sql, _ = compiler.compile(value) - return placeholder % sql - else: - if transform_value(value, f.srid): - return '%s(SDO_GEOMETRY(%%s, %s), %s)' % (tranform_func, value.srid, f.srid) - else: - return 'SDO_GEOMETRY(%%s, %s)' % f.srid + return super().get_geom_placeholder(f, value, compiler) def spatial_aggregate_name(self, agg_name): """ diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py index 49662236c0..71302d9eb7 100644 --- a/django/contrib/gis/db/backends/postgis/operations.py +++ b/django/contrib/gis/db/backends/postgis/operations.py @@ -292,6 +292,12 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): substitute in the ST_Transform() function call. """ tranform_func = self.spatial_function_name('Transform') + if hasattr(value, 'as_sql'): + if value.field.srid == f.srid: + placeholder = '%s' + else: + placeholder = '%s(%%s, %s)' % (tranform_func, f.srid) + return placeholder # Get the srid for this object if value is None: @@ -310,13 +316,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): else: placeholder = '%s(%%s, %s)' % (tranform_func, f.srid) - if hasattr(value, 'as_sql'): - # If this is an F expression, then we don't really want - # a placeholder and instead substitute in the column - # of the expression. - sql, _ = compiler.compile(value) - placeholder = placeholder % sql - return placeholder def _get_postgis_func(self, func): diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py index 42b079bbfa..087a879d07 100644 --- a/django/contrib/gis/db/backends/spatialite/operations.py +++ b/django/contrib/gis/db/backends/spatialite/operations.py @@ -152,32 +152,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): dist_param = value return [dist_param] - def get_geom_placeholder(self, f, value, compiler): - """ - Provide a proper substitution value for Geometries that are not in the - SRID of the field. Specifically, this routine will substitute in the - Transform() and GeomFromText() function call(s). - """ - tranform_func = self.spatial_function_name('Transform') - - def transform_value(value, srid): - return not (value is None or value.srid == srid) - if hasattr(value, 'as_sql'): - if transform_value(value, f.srid): - placeholder = '%s(%%s, %s)' % (tranform_func, f.srid) - else: - placeholder = '%s' - # No geometry value used for F expression, substitute in - # the column name instead. - sql, _ = compiler.compile(value) - return placeholder % sql - else: - if transform_value(value, f.srid): - # Adding Transform() to the SQL placeholder. - return '%s(%s(%%s,%s), %s)' % (tranform_func, self.from_text, value.srid, f.srid) - else: - return '%s(%%s,%s)' % (self.from_text, f.srid) - def _get_spatialite_func(self, func): """ Helper routine for calling SpatiaLite functions and returning diff --git a/django/contrib/gis/db/models/lookups.py b/django/contrib/gis/db/models/lookups.py index 2024df45b0..c34e3391e1 100644 --- a/django/contrib/gis/db/models/lookups.py +++ b/django/contrib/gis/db/models/lookups.py @@ -68,6 +68,8 @@ class GISLookup(Lookup): if not hasattr(geo_fld, 'srid'): raise ValueError('No geographic field found in expression.') self.rhs.srid = geo_fld.srid + sql, _ = compiler.compile(geom) + return connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) % sql, [] elif isinstance(self.rhs, Expression): raise ValueError('Complex expressions not supported for spatial fields.') elif isinstance(self.rhs, (list, tuple)): diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py index b47136df26..c1d0451a54 100644 --- a/django/db/backends/mysql/operations.py +++ b/django/db/backends/mysql/operations.py @@ -233,7 +233,7 @@ class DatabaseOperations(BaseDatabaseOperations): return value def binary_placeholder_sql(self, value): - return '_binary %s' if value is not None else '%s' + return '_binary %s' if value is not None and not hasattr(value, 'as_sql') else '%s' def subtract_temporals(self, internal_type, lhs, rhs): lhs_sql, lhs_params = lhs diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index f32d106def..14a727e998 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1167,7 +1167,7 @@ class SQLUpdateCompiler(SQLCompiler): name = field.column if hasattr(val, 'as_sql'): sql, params = self.compile(val) - values.append('%s = %s' % (qn(name), sql)) + values.append('%s = %s' % (qn(name), placeholder % sql)) update_params.extend(params) elif val is not None: values.append('%s = %s' % (qn(name), placeholder)) diff --git a/tests/gis_tests/geoapp/models.py b/tests/gis_tests/geoapp/models.py index 363f3deaf0..b555165e56 100644 --- a/tests/gis_tests/geoapp/models.py +++ b/tests/gis_tests/geoapp/models.py @@ -101,3 +101,9 @@ class NonConcreteField(models.IntegerField): class NonConcreteModel(NamedModel): non_concrete = NonConcreteField() point = models.PointField(geography=True) + + +class ManyPointModel(NamedModel): + point1 = models.PointField() + point2 = models.PointField() + point3 = models.PointField(srid=3857) diff --git a/tests/gis_tests/geoapp/test_expressions.py b/tests/gis_tests/geoapp/test_expressions.py index c18d07f0e8..72f9a37dc4 100644 --- a/tests/gis_tests/geoapp/test_expressions.py +++ b/tests/gis_tests/geoapp/test_expressions.py @@ -1,11 +1,12 @@ from unittest import skipUnless -from django.contrib.gis.db.models import GeometryField, Value, functions +from django.contrib.gis.db.models import F, GeometryField, Value, functions from django.contrib.gis.geos import Point, Polygon +from django.db import connection from django.test import TestCase, skipUnlessDBFeature from ..utils import postgis -from .models import City +from .models import City, ManyPointModel @skipUnlessDBFeature('gis_enabled') @@ -29,3 +30,28 @@ class GeoExpressionsTests(TestCase): p = Polygon(((1, 1), (1, 2), (2, 2), (2, 1), (1, 1))) area = City.objects.annotate(a=functions.Area(Value(p, GeometryField(srid=4326, geography=True)))).first().a self.assertAlmostEqual(area.sq_km, 12305.1, 0) + + def test_update_from_other_field(self): + p1 = Point(1, 1, srid=4326) + p2 = Point(2, 2, srid=4326) + obj = ManyPointModel.objects.create( + point1=p1, + point2=p2, + point3=p2.transform(3857, clone=True), + ) + # Updating a point to a point of the same SRID. + ManyPointModel.objects.filter(pk=obj.pk).update(point2=F('point1')) + obj.refresh_from_db() + self.assertEqual(obj.point2, p1) + # Updating a point to a point with a different SRID. + if connection.features.supports_transform: + ManyPointModel.objects.filter(pk=obj.pk).update(point3=F('point1')) + obj.refresh_from_db() + self.assertTrue(obj.point3.equals_exact(p1.transform(3857, clone=True), 0.1)) + + @skipUnlessDBFeature('has_Translate_function') + def test_update_with_expression(self): + city = City.objects.create(point=Point(1, 1, srid=4326)) + City.objects.filter(pk=city.pk).update(point=functions.Translate('point', 1, 1)) + city.refresh_from_db() + self.assertEqual(city.point, Point(2, 2, srid=4326))